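# Introduction
This notebook uses **SimSiam**, a Self-Supervised Learning (SSL) method, to learn representations of the provided artworks. SSL learns representations by building pseudo-supervision from the data itself, without labels, so it may yield better initial weights even in a setting like this one, where pretrained weights are prohibited.  
SimSiam passes two differently augmented views of the same image through a weight-sharing network and trains it to bring their outputs closer in similarity. It needs no negative pairs, which keeps training simple.  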
<img src="https://tech.fusic.co.jp/uploads/exploring_siam_arcs.png" width=60%>    
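Here we use **lightly**, a library that makes SSL easy to run, to perform representation learning with SimSiam. Following the official tutorial, the goal is to train the representation and then inspect the distribution of the data through the resulting embeddings.  
(Note: do not use a pretrained model.)  
  
References:  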
[Exploring Simple Siamese Representation Learning](https://arxiv.org/abs/2011.10566)  
[A new approach to self-supervised learning / SimSiam: Exploring Simple Siamese Representation Learning (slides, in Japanese)](https://speakerdeck.com/sansandsoc/simsiam-exploring-simple-siamese-representation-learning)  
[Paper reading: Exploring Simple Siamese Representation Learning (blog post, in Japanese)](https://tech.fusic.co.jp/posts/2020-12-25-ml-simsiam-representation-learning/)  
[https://github.com/lightly-ai/lightly](https://github.com/lightly-ai/lightly)  
[Train SimSiam on satellite images](https://docs.lightly.ai/tutorials/package/tutorial_simsiam_esa.html)
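
As a rough illustration of the similarity objective described above, the sketch below computes a symmetric negative cosine similarity in NumPy. The arrays `p0`, `p1` (standing in for predictor outputs) and `z0`, `z1` (standing in for projector outputs) are made-up toy data, and the stop-gradient that SimSiam applies to `z` is only noted in a comment because NumPy has no autograd. This is not lightly's implementation.

```python
import numpy as np

def neg_cosine(p, z):
    """Mean negative cosine similarity between the rows of p and z."""
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return -np.mean(np.sum(p * z, axis=1))

def sym_neg_cosine(p0, z0, p1, z1):
    """Each view's prediction is compared with the other view's projection.
    In SimSiam, z0 and z1 are treated as constants (stop-gradient)."""
    return 0.5 * neg_cosine(p0, z1) + 0.5 * neg_cosine(p1, z0)

rng = np.random.default_rng(0)
z0 = rng.normal(size=(4, 8))   # toy projections of view 0
z1 = rng.normal(size=(4, 8))   # toy projections of view 1

# predictions that match the other view's projection exactly
loss_aligned = sym_neg_cosine(z1, z0, z0, z1)
# unrelated random predictions
loss_random = sym_neg_cosine(rng.normal(size=(4, 8)), z0,
                             rng.normal(size=(4, 8)), z1)
print(loss_aligned, loss_random)
```

The loss is bounded below by -1, which is reached when every prediction points in the same direction as the projection it is compared with.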

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# For Colab
!pip install lightly

Collecting lightly
  Downloading https://files.pythonhosted.org/packages/86/8f/0345b730f0f06a3bfb6b1e04ea1b132ef302cc00b6456bbfdb9c5a5015d6/lightly-1.1.15-py3-none-any.whl (240kB)
Collecting pytorch-lightning>=1.0.4
  Downloading https://files.pythonhosted.org/packages/48/5e/19c817ad2670c1d822642ed7bfc4d9d4c30c2f8eaefebcd575a3188d7319/pytorch_lightning-1.3.8-py3-none-any.whl (813kB)
Collecting lightly-utils==0.0.1
  Downloading https://files.pythonhosted.org/packages/17/a1/e36f214e3d22d8417b718f51cdc49853bd63a50584f13981926ef8c5a368/lightly_utils-0.0.1-py3-none-any.whl
Collecting tqdm>=4.44
  Downloading https://files.pythonhosted.org/packages/7a/ec/f8ff3ccfc4e59ce619a66a0bf29dc3b49c2e8c07de29d572e191c006eaa2/tqdm-4.61.2-py2.py3-none-any.whl (76kB)
Collecting hydra-core>=1.0.0

In [3]:
# For Colab
# !unzip dataset_atmaCup11.zip
# !mkdir imgs
# !unzip photos.zip -d ./imgs/

# Representation learning with SimSiam

### Loading libraries

In [4]:
import math
import torch
import torch.nn as nn
import torchvision
import numpy as np
import lightly

### Config

- A larger `batch_size` may work better (see the paper)
- `num_ftrs` is the embedding dimension of the image model; since we use `resnet18` here, it is 512

In [5]:
use_amp = True

num_workers = 2
batch_size = 512
seed = 1
epochs = 500
input_size = 224

# dimension of the embeddings
num_ftrs = 512
# dimension of the output of the prediction and projection heads
out_dim = proj_hidden_dim = 512
# the prediction head uses a bottleneck architecture
pred_hidden_dim = 128
# use 2 layers in the projection head
num_mlp_layers = 2

### Seed / path to the image data

In [6]:
# seed torch and numpy
torch.manual_seed(0)
np.random.seed(0)

# set the path to the dataset

path_to_data = '/content/drive/MyDrive/atmaCup/#11/dataset_atmaCup11/inputs/photos'

In [7]:
import os

dataset_root = '/content/drive/MyDrive/atmaCup/#11/dataset_atmaCup11'
assert dataset_root is not None
output_dir = os.path.join(dataset_root, "simsam_tutorial")
os.makedirs(output_dir, exist_ok=True)

### DataLoader

- `collate_fn` specifies the data augmentations and the probability with which each is applied.
  - hf_prob: probability of a horizontal flip
  - vf_prob: probability of a vertical flip
  - rr_prob: probability of a random (+90 degree) rotation
  - min_scale: minimum scale of the random crop
  - cj_prob: probability of color jitter
  - cj_bright: brightness jitter
  - cj_contrast: contrast jitter
  - cj_hue: hue jitter
  - cj_sat: saturation jitter
- Reference: https://docs.lightly.ai/lightly.data.html#lightly.data.collate.ImageCollateFunction
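
To make the role of these probabilities concrete, here is a minimal NumPy sketch in which each augmentation is applied independently with its own probability, so two calls on the same image usually give two different views. The function `random_view` and the toy 4x4 array are hypothetical illustrations; lightly's `ImageCollateFunction` is built on torchvision transforms and also performs cropping and color jitter, which are omitted here.

```python
import numpy as np

def random_view(img, rng, hf_prob=0.5, vf_prob=0.5, rr_prob=0.5):
    """Apply each augmentation independently with its own probability."""
    if rng.random() < hf_prob:
        img = img[:, ::-1]      # horizontal flip
    if rng.random() < vf_prob:
        img = img[::-1, :]      # vertical flip
    if rng.random() < rr_prob:
        img = np.rot90(img)     # rotate by 90 degrees
    return img

rng = np.random.default_rng(0)
image = np.arange(16).reshape(4, 4)   # toy square "image"

# two independently augmented views of the same image
view0 = random_view(image, rng)
view1 = random_view(image, rng)
print(view0)
print(view1)
```

Flips and rotations only rearrange pixels, so both views contain exactly the same values as the input.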

In [8]:
# define the augmentations for self-supervised learning
collate_fn = lightly.data.ImageCollateFunction(
    input_size=input_size,
    # require invariance to flips and rotations
    hf_prob=0.5,
    vf_prob=0.5,
    rr_prob=0.5,
    # minimum scale of the random crop; 0.5 keeps the cropping fairly mild
    min_scale=0.5,
    # the color-jitter overrides below are commented out,
    # so the library defaults apply
    # cj_prob=0.2,
    # cj_bright=0.1,
    # cj_contrast=0.1,
    # cj_hue=0.1,
    # cj_sat=0.1,
)

# create a lightly dataset for training, since the augmentations are handled
# by the collate function, there is no need to apply additional ones here
dataset_train_simsiam = lightly.data.LightlyDataset(
    input_dir=path_to_data
)

# create a dataloader for training
dataloader_train_simsiam = torch.utils.data.DataLoader(
    dataset_train_simsiam,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

# create a torchvision transformation for embedding the dataset after training
# here, we resize the images to match the input size during training and apply
# a normalization of the color channel based on statistics from imagenet
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    )
])



# create a lightly dataset for embedding
dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_data,
    transform=test_transforms
)



# create a dataloader for embedding
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

### Model

- The backbone is ResNet18 with its final layer removed
- Note: do not use a pretrained model

In [9]:
# pretrained weights are prohibited here, so the ResNet18 is created
# with pretrained=False and starts from random initialization
resnet = torchvision.models.resnet18(pretrained=False)
# drop the final fully connected layer and keep the rest as the backbone
backbone = nn.Sequential(*list(resnet.children())[:-1])

# create the SimSiam model using the backbone from above
model = lightly.models.SimSiam(
    backbone,
    num_ftrs=num_ftrs,
    #proj_hidden_dim=proj_hidden_dim, # use the default
    #pred_hidden_dim=pred_hidden_dim, # use the default
    #out_dim=out_dim, # use the default
    num_mlp_layers=2
)

# load weights saved from an earlier SimSiam training run; map_location='cpu'
# lets the checkpoint load even when no GPU is available
model.load_state_dict(torch.load(
    os.path.join(output_dir, '512_400_400_224_-0.8167325766863317.pth'),
    map_location='cpu'
))

<All keys matched successfully>

### Loss / Optimizer

In [10]:
# # SimSiam uses a symmetric negative cosine similarity loss
# criterion = lightly.loss.SymNegCosineSimilarityLoss()

# # scale the learning rate
# lr = 0.05 * batch_size / 256
# # use SGD with momentum and weight decay
# optimizer = torch.optim.SGD(
#     model.parameters(),
#     lr=lr,
#     momentum=0.9,
#     weight_decay=5e-4
# )

In [11]:
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
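
The learning-rate settings in the two cells above can be checked by hand. The sketch below uses only the standard library: `scaled_lr` reproduces the `0.05 * batch_size / 256` rule, and `cosine_annealed_lr` is the closed-form cosine curve that `CosineAnnealingLR` is documented to follow, assuming its default minimum learning rate of 0. Both helper names are hypothetical.

```python
import math

def scaled_lr(batch_size, base_lr=0.05, base_batch=256):
    """Linear scaling rule: lr = base_lr * batch_size / base_batch."""
    return base_lr * batch_size / base_batch

def cosine_annealed_lr(epoch, total_epochs, lr_max, lr_min=0.0):
    """Cosine curve from lr_max (epoch 0) down to lr_min (epoch total_epochs)."""
    cosine = math.cos(math.pi * epoch / total_epochs)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + cosine)

lr = scaled_lr(batch_size=512)
# learning rate at the start, the midpoint, and the end of 500 epochs
schedule = [cosine_annealed_lr(e, 500, lr) for e in (0, 250, 500)]
print(lr, schedule)
```

With `batch_size = 512` the initial learning rate is 0.1, which decays to about half of that at the midpoint and to 0 at the final epoch.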

### Self-supervised learning with SimSiam

In [20]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# avg_losses = []
# avg_loss = 0.
# avg_output_std = 0.
# for e in range(1, epochs + 1):

#     for (x0, x1), _, _ in dataloader_train_simsiam:
#         # move images to the gpu
#         x0 = x0.to(device)
#         x1 = x1.to(device)
        
#         with torch.cuda.amp.autocast(enabled=use_amp):
#             # run the model on both transforms of the images
#             # the output of the simsiam model is a y containing the predictions
#             # and projections for each input x
#             y0, y1 = model(x0, x1)

#             # backpropagation
#             loss = criterion(y0, y1)
#         loss.backward()

#         optimizer.step()
#         optimizer.zero_grad()

#         # calculate the per-dimension standard deviation of the outputs
#         # we can use this later to check whether the embeddings are collapsing
#         output, _ = y0
#         output = output.detach()
#         output = torch.nn.functional.normalize(output, dim=1)

#         output_std = torch.std(output, 0)
#         output_std = output_std.mean()

#         # use moving averages to track the loss and standard deviation
#         w = 0.9
#         avg_loss = w * avg_loss + (1 - w) * loss.item()
#         avg_output_std = w * avg_output_std + (1 - w) * output_std.item()
    
#     scheduler.step()

#     # the level of collapse is large if the standard deviation of the l2
#     # normalized output is much smaller than 1 / sqrt(dim)
#     collapse_level = max(0., 1 - math.sqrt(out_dim) * avg_output_std)
#     # print intermediate results
#     print(f'[Epoch {e:3d}] '
#         f'Loss = {avg_loss:.2f} | '
#         f'Collapse Level: {collapse_level:.2f} / 1.00')
    
#     avg_losses.append(avg_loss)
    
#     if e % 50 == 0:
#         model_path = os.path.join(output_dir, str(batch_size) + '_' + str(e) + '_' + str(input_size) + '_' + str(avg_loss) + '.pth')
#         torch.save(model.state_dict(), model_path)

SimSiam(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
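
The collapse check in the commented-out training loop relies on the fact that L2-normalized vectors spread evenly over the sphere have a per-dimension standard deviation close to `1 / sqrt(dim)`. A minimal NumPy sketch on synthetic data follows; the function `collapse_level` and both toy arrays are hypothetical, and `dim` is taken from the width of the array being measured.

```python
import numpy as np

def collapse_level(outputs):
    """max(0, 1 - sqrt(dim) * mean per-dimension std of the L2-normalized outputs)."""
    normalized = outputs / np.linalg.norm(outputs, axis=1, keepdims=True)
    output_std = normalized.std(axis=0).mean()
    return max(0.0, 1.0 - np.sqrt(outputs.shape[1]) * output_std)

rng = np.random.default_rng(0)
# directions spread over the sphere: no collapse expected
healthy = rng.normal(size=(2048, 64))
# every row identical: complete collapse
collapsed = np.tile(rng.normal(size=(1, 64)), (2048, 1))

print(collapse_level(healthy), collapse_level(collapsed))
```

A value near 0 means the outputs are well spread, and a value near 1 means they have shrunk to (almost) a single point.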

In [21]:
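# import matplotlib.pyplot as plt
# fig = plt.figure()

# # one point per epoch: avg_losses holds one entry per completed epoch
# plt.plot(list(range(1, len(avg_losses) + 1)), avg_losses)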

# fig.savefig(os.path.join(output_dir, "avg_losses_(batch_size: " + str(batch_size) + ", epochs: " + str(epochs) + ", input_size: " + str(input_size) + ").png"))

# Visualizing the embeddings

### Extracting the embeddings

In [22]:
embeddings = []
filenames = []

# switch to eval mode and disable gradients for faster inference
model.eval()
with torch.no_grad():
    for x, _, fnames in dataloader_test:
        # move the images to the gpu
        x = x.to(device)
        # embed the images with the SimSiam-trained backbone
        y = model.backbone(x)
        # (N, 512, 1, 1) -> (N, 512); unlike squeeze(), this keeps the
        # batch dimension even when the last batch holds a single image
        y = y.flatten(start_dim=1)
        # store the embeddings and filenames in lists
        embeddings.append(y)
        filenames.extend(fnames)

# concatenate the embeddings and convert to numpy
embeddings = torch.cat(embeddings, dim=0)
embeddings = embeddings.cpu().numpy()
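
One detail worth checking when flattening backbone outputs of shape `(N, 512, 1, 1)`: an argument-free `squeeze()` drops every size-1 axis, including the batch axis when `N == 1`. The NumPy sketch below uses zero arrays as stand-ins for real outputs to show the difference; `torch.Tensor.squeeze` behaves the same way in this respect.

```python
import numpy as np

# stand-ins for backbone outputs of shape (batch, 512, 1, 1)
full_batch = np.zeros((8, 512, 1, 1))
last_batch = np.zeros((1, 512, 1, 1))   # a final batch with a single image

squeezed_full = np.squeeze(full_batch)
squeezed_last = np.squeeze(last_batch)
reshaped_last = last_batch.reshape(len(last_batch), -1)

print(squeezed_full.shape)   # (8, 512): batch axis kept
print(squeezed_last.shape)   # (512,): batch axis lost
print(reshaped_last.shape)   # (1, 512): batch axis kept
```

Reshaping to `(N, -1)` keeps one row per image regardless of the batch size, so the per-batch results can always be concatenated along axis 0.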



# Dimensionality reduction to build features

In [23]:
from sklearn.decomposition import PCA

In [29]:
pca = PCA(n_components=0.95)
pca.fit(embeddings)
print('Explained variance retained: ', np.sum(pca.explained_variance_ratio_))
print('Number of principal components: ', pca.n_components_)

Explained variance retained:  0.9900422
Number of principal components:  277
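
When `n_components` is a float between 0 and 1, scikit-learn's `PCA` is documented to keep just enough components for the retained variance to exceed that fraction. The NumPy-only sketch below applies that selection rule to toy data; the helper `n_components_for` and the synthetic matrix `X` are hypothetical, and this is not scikit-learn's own code.

```python
import numpy as np

def n_components_for(X, threshold):
    """Smallest number of principal components whose cumulative
    explained-variance ratio exceeds `threshold`."""
    Xc = X - X.mean(axis=0)                          # center the data
    singular_values = np.linalg.svd(Xc, compute_uv=False)
    ratio = singular_values ** 2 / np.sum(singular_values ** 2)
    cumulative = np.cumsum(ratio)
    k = int(np.searchsorted(cumulative, threshold, side='right')) + 1
    return k, cumulative

rng = np.random.default_rng(0)
# toy data: 200 samples, 10 features whose scale halves with each column
X = rng.normal(size=(200, 10)) * (0.5 ** np.arange(10))

k, cumulative = n_components_for(X, 0.95)
print(k, cumulative[k - 1])
```

The retained variance printed this way normally lands only slightly above the requested threshold, because the rule stops at the first component that crosses it.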


In [25]:
reduced_embeddings = pca.transform(embeddings)

In [27]:
import pandas as pd
embeddings_df = pd.DataFrame(reduced_embeddings)

In [28]:
embeddings_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69
0,0.003461,-0.150288,0.304685,-0.147037,-0.071831,-0.127435,-0.248672,0.121594,0.024245,0.103395,0.035013,-0.043587,-0.045812,-0.008263,0.042553,-0.094155,0.067464,0.051796,0.094922,-0.052956,-0.051839,-0.036569,-0.041889,0.142896,0.016438,-0.029533,-0.064217,-0.063474,0.011841,-0.028218,0.056763,0.046593,0.01476,0.017831,-0.046627,0.050472,-0.030912,0.04523,-0.03468,-0.011978,0.03475,-0.01373,-0.008522,0.041006,-0.00571,0.024626,-0.057183,-0.047388,0.011142,-0.031013,0.014076,-0.007436,0.023994,0.001321,-0.017687,0.039743,-0.017727,0.033123,-0.008643,-0.010735,-0.025228,-0.00297,-0.019823,-0.035016,-0.038778,-0.005889,-0.019397,-0.009862,0.016188,0.032983
1,0.846163,0.18872,0.041107,-0.074348,-0.040203,0.316707,-0.015704,0.010054,-0.095898,0.073579,-0.125106,0.016968,0.01515,0.145727,0.069973,0.137528,0.050251,0.076274,0.003677,-0.079682,0.087663,-0.087817,-0.092455,-0.02235,-0.058843,-0.039975,-0.013696,-0.005117,-0.074898,-0.01996,-0.014327,-0.04035,0.058463,-0.00027,-0.06154,0.039419,-0.008369,-0.051989,-0.000839,0.041522,-0.012149,-0.020909,0.011928,0.033585,-0.014651,-0.005886,-0.015974,-0.014351,-0.005237,-0.016221,0.003751,-0.007383,-0.015334,0.023362,-0.00269,-0.01174,-0.008084,0.00446,-0.002597,-0.014108,-0.009685,0.009777,0.033897,0.024911,0.029379,0.007045,-0.010398,4.3e-05,-0.010896,0.010791
2,0.334645,0.082801,-0.177321,0.032885,-0.082624,0.070577,-0.139059,-0.135601,0.15824,0.195388,-0.127144,0.006541,-0.031079,-0.038906,-0.096654,-0.085794,-0.085019,-0.057718,-0.05766,0.114305,0.099554,-0.029927,-0.016152,0.033189,0.010654,0.045596,0.03989,0.071587,-0.015964,-0.052727,-0.025776,-0.013673,-0.052457,0.024115,0.011283,0.032363,0.030217,-0.000711,-0.022785,0.035338,0.017371,0.010884,0.036519,-0.004338,-0.018168,0.009135,-0.021639,-0.018614,0.003046,0.002753,-0.016183,-0.013654,0.001981,0.020637,-0.019201,-0.010563,-0.054542,0.00476,0.006031,0.01238,-0.012826,0.034228,-0.005718,-0.009643,0.023429,0.020714,-0.003778,0.001901,0.009067,0.004808
3,-0.338829,0.159673,0.281051,-0.016789,-0.050039,-0.08037,-0.060682,-0.117925,0.159276,-0.119535,-0.018292,-0.051316,-0.081617,0.091963,-0.019635,-0.069005,-0.091063,0.001234,0.037284,-0.009765,0.003249,0.095891,-0.022283,0.027075,-0.061662,-0.033879,0.073014,0.026859,-0.024101,0.040082,0.019078,0.075158,0.06696,-0.083126,-0.017341,-0.027031,0.01037,0.065689,0.000179,-0.00032,-0.025215,-0.028083,-0.014913,-0.027398,0.012249,0.019203,0.054922,-0.033294,-0.036666,0.012671,-0.004333,0.013353,-0.021304,-0.011959,-0.004409,0.015165,-0.039639,-0.026121,-0.013817,-0.038016,0.013369,-0.034846,-0.00352,-0.006172,0.014212,0.003809,0.013298,0.001129,-0.037384,0.018928
4,0.192325,0.386284,0.128252,-0.080947,-0.505994,-0.18661,0.137052,-0.125845,-0.110409,0.087463,0.081504,0.035112,-0.177484,-0.008483,0.012824,0.048505,0.028079,-0.001145,-0.16076,-0.070084,0.004053,-0.026629,0.071748,0.040905,-0.047137,-0.018471,0.010865,-0.010989,-0.003723,0.085117,-0.022194,-0.077892,0.044872,0.026881,-0.018167,0.053841,0.072248,0.040532,-0.007161,-0.055044,-0.001472,-0.000322,0.061478,0.001224,0.006597,-0.008796,0.026923,0.002469,-0.024579,0.012729,0.014676,-0.030754,0.009284,0.038492,0.058528,-0.015413,0.007336,-0.036646,0.028848,-0.005285,0.006043,-0.00983,-0.034745,0.014938,-0.006283,0.021162,-0.008704,0.017701,-0.007579,0.003753


In [30]:
filenames

['0009e50b94be36ccad39.jpg',
 '000bd5e82eb22f199f44.jpg',
 '0015f44de1854f617516.jpg',
 '002bff09b09998d0be65.jpg',
 '00309fb1ef05416f9c1f.jpg',
 '00388a678879ba1efa27.jpg',
 '003a1562e97f79ba96dc.jpg',
 '004890880e8e7431147b.jpg',
 '005e1e7c6496902d23f3.jpg',
 '00718c32602425f504c1.jpg',
 '0075ffcdf3fa548a44b9.jpg',
 '0079cb204f9b176c8752.jpg',
 '007c091616828798b5e1.jpg',
 '007f5e3620b458d77212.jpg',
 '007ffbdd2e0775b7b8ff.jpg',
 '0084d67b69a2368bf3e0.jpg',
 '00990614b43285e49f4a.jpg',
 '009c0c03893beba8dfcc.jpg',
 '00a3b7f15d64cd5da957.jpg',
 '00aff1bb77901f69c94a.jpg',
 '00b0aa868d22c818b1c3.jpg',
 '00b3b35e76c7750a36b7.jpg',
 '00b4675fa1e15a74d6c5.jpg',
 '00b4c4f45f0b4410a90e.jpg',
 '00bf812ffe8a62d45661.jpg',
 '00c3692aa0f3c1d100d9.jpg',
 '00c93e990e799fb3d8c9.jpg',
 '00cc73f3f314db5e2406.jpg',
 '00ccee20935853fd5990.jpg',
 '00cddf1a88c0f4261c23.jpg',
 '00cf9b7ea3168851af00.jpg',
 '00db4ef4b89904547a77.jpg',
 '00e68e54a775dc4a8420.jpg',
 '00efa016fe319d3687ef.jpg',
 '00f067fbaeda

In [31]:
embeddings_df['filenames'] = filenames

In [32]:
# strip the file extension to obtain the object id
embeddings_df['object_id'] = embeddings_df['filenames'].apply(lambda x: os.path.splitext(x)[0])

In [33]:
embeddings_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,filenames,object_id
0,0.003461,-0.150288,0.304685,-0.147037,-0.071831,-0.127435,-0.248672,0.121594,0.024245,0.103395,0.035013,-0.043587,-0.045812,-0.008263,0.042553,-0.094155,0.067464,0.051796,0.094922,-0.052956,-0.051839,-0.036569,-0.041889,0.142896,0.016438,-0.029533,-0.064217,-0.063474,0.011841,-0.028218,0.056763,0.046593,0.01476,0.017831,-0.046627,0.050472,-0.030912,0.04523,-0.03468,-0.011978,0.03475,-0.01373,-0.008522,0.041006,-0.00571,0.024626,-0.057183,-0.047388,0.011142,-0.031013,0.014076,-0.007436,0.023994,0.001321,-0.017687,0.039743,-0.017727,0.033123,-0.008643,-0.010735,-0.025228,-0.00297,-0.019823,-0.035016,-0.038778,-0.005889,-0.019397,-0.009862,0.016188,0.032983,0009e50b94be36ccad39.jpg,0009e50b94be36ccad39
1,0.846163,0.18872,0.041107,-0.074348,-0.040203,0.316707,-0.015704,0.010054,-0.095898,0.073579,-0.125106,0.016968,0.01515,0.145727,0.069973,0.137528,0.050251,0.076274,0.003677,-0.079682,0.087663,-0.087817,-0.092455,-0.02235,-0.058843,-0.039975,-0.013696,-0.005117,-0.074898,-0.01996,-0.014327,-0.04035,0.058463,-0.00027,-0.06154,0.039419,-0.008369,-0.051989,-0.000839,0.041522,-0.012149,-0.020909,0.011928,0.033585,-0.014651,-0.005886,-0.015974,-0.014351,-0.005237,-0.016221,0.003751,-0.007383,-0.015334,0.023362,-0.00269,-0.01174,-0.008084,0.00446,-0.002597,-0.014108,-0.009685,0.009777,0.033897,0.024911,0.029379,0.007045,-0.010398,4.3e-05,-0.010896,0.010791,000bd5e82eb22f199f44.jpg,000bd5e82eb22f199f44
2,0.334645,0.082801,-0.177321,0.032885,-0.082624,0.070577,-0.139059,-0.135601,0.15824,0.195388,-0.127144,0.006541,-0.031079,-0.038906,-0.096654,-0.085794,-0.085019,-0.057718,-0.05766,0.114305,0.099554,-0.029927,-0.016152,0.033189,0.010654,0.045596,0.03989,0.071587,-0.015964,-0.052727,-0.025776,-0.013673,-0.052457,0.024115,0.011283,0.032363,0.030217,-0.000711,-0.022785,0.035338,0.017371,0.010884,0.036519,-0.004338,-0.018168,0.009135,-0.021639,-0.018614,0.003046,0.002753,-0.016183,-0.013654,0.001981,0.020637,-0.019201,-0.010563,-0.054542,0.00476,0.006031,0.01238,-0.012826,0.034228,-0.005718,-0.009643,0.023429,0.020714,-0.003778,0.001901,0.009067,0.004808,0015f44de1854f617516.jpg,0015f44de1854f617516
3,-0.338829,0.159673,0.281051,-0.016789,-0.050039,-0.08037,-0.060682,-0.117925,0.159276,-0.119535,-0.018292,-0.051316,-0.081617,0.091963,-0.019635,-0.069005,-0.091063,0.001234,0.037284,-0.009765,0.003249,0.095891,-0.022283,0.027075,-0.061662,-0.033879,0.073014,0.026859,-0.024101,0.040082,0.019078,0.075158,0.06696,-0.083126,-0.017341,-0.027031,0.01037,0.065689,0.000179,-0.00032,-0.025215,-0.028083,-0.014913,-0.027398,0.012249,0.019203,0.054922,-0.033294,-0.036666,0.012671,-0.004333,0.013353,-0.021304,-0.011959,-0.004409,0.015165,-0.039639,-0.026121,-0.013817,-0.038016,0.013369,-0.034846,-0.00352,-0.006172,0.014212,0.003809,0.013298,0.001129,-0.037384,0.018928,002bff09b09998d0be65.jpg,002bff09b09998d0be65
4,0.192325,0.386284,0.128252,-0.080947,-0.505994,-0.18661,0.137052,-0.125845,-0.110409,0.087463,0.081504,0.035112,-0.177484,-0.008483,0.012824,0.048505,0.028079,-0.001145,-0.16076,-0.070084,0.004053,-0.026629,0.071748,0.040905,-0.047137,-0.018471,0.010865,-0.010989,-0.003723,0.085117,-0.022194,-0.077892,0.044872,0.026881,-0.018167,0.053841,0.072248,0.040532,-0.007161,-0.055044,-0.001472,-0.000322,0.061478,0.001224,0.006597,-0.008796,0.026923,0.002469,-0.024579,0.012729,0.014676,-0.030754,0.009284,0.038492,0.058528,-0.015413,0.007336,-0.036646,0.028848,-0.005285,0.006043,-0.00983,-0.034745,0.014938,-0.006283,0.021162,-0.008704,0.017701,-0.007579,0.003753,00309fb1ef05416f9c1f.jpg,00309fb1ef05416f9c1f


In [34]:
embeddings_df.to_csv(os.path.join(output_dir, "embeddings.csv"), index=False)

### Loading libraries for plotting

In [None]:
# # for plotting
# import os
# from PIL import Image

# import matplotlib.pyplot as plt
# import matplotlib.offsetbox as osb
# from matplotlib import rcParams as rcp

# # for resizing images to thumbnails
# import torchvision.transforms.functional as functional

# # for clustering and 2d representations
# from sklearn import random_projection

### Dimensionality reduction and normalization of the embeddings

In [None]:
# # for the scatter plot we want to transform the images to a two-dimensional
# # vector space using a random Gaussian projection
# projection = random_projection.GaussianRandomProjection(n_components=2)
# embeddings_2d = projection.fit_transform(embeddings)

# # normalize the embeddings to fit in the [0, 1] square
# M = np.max(embeddings_2d, axis=0)
# m = np.min(embeddings_2d, axis=0)
# embeddings_2d = (embeddings_2d - m) / (M - m)
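
The two steps in the commented-out cell above can be reproduced without scikit-learn. The sketch below multiplies toy embeddings by a random Gaussian matrix with entries drawn from N(0, 1 / n_components), which is the distribution `GaussianRandomProjection` is documented to use, and then min-max normalizes each axis. The array `embeddings_toy` is a made-up stand-in for the real embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings_toy = rng.normal(size=(100, 512))   # stand-in for the real embeddings

# Gaussian random projection: multiply by a random (512, 2) matrix
n_components = 2
projection_matrix = rng.normal(
    scale=1.0 / np.sqrt(n_components), size=(512, n_components)
)
embeddings_2d = embeddings_toy @ projection_matrix

# min-max normalize each axis so the points fit in the [0, 1] square
M = embeddings_2d.max(axis=0)
m = embeddings_2d.min(axis=0)
embeddings_2d = (embeddings_2d - m) / (M - m)
print(embeddings_2d.min(axis=0), embeddings_2d.max(axis=0))
```

A random projection needs no fitting and is cheap, but it does not try to preserve cluster structure the way t-SNE or UMAP would, so the resulting scatter plot is only a coarse view of the embedding space.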

### Scatter-plot visualization with thumbnails

In [None]:
# def get_scatter_plot_with_thumbnails():
#     """Creates a scatter plot with image overlays.
#     """
#     # initialize empty figure and add subplot
#     fig = plt.figure(figsize=(12,12))
#     fig.suptitle('SimSiam Scatter Plot')
#     ax = fig.add_subplot(1, 1, 1)
#     # shuffle images and find out which images to show
#     shown_images_idx = []
#     shown_images = np.array([[1., 1.]])
#     iterator = [i for i in range(embeddings_2d.shape[0])]
#     np.random.shuffle(iterator)
#     for i in iterator:
#         # only show image if it is sufficiently far away from the others
#         dist = np.sum((embeddings_2d[i] - shown_images) ** 2, 1)
#         if np.min(dist) < 1.5e-3:
#             continue
#         shown_images = np.r_[shown_images, [embeddings_2d[i]]]
#         shown_images_idx.append(i)

#     # plot image overlays
#     for idx in shown_images_idx:
#         thumbnail_size = int(rcp['figure.figsize'][0] * 5.)
#         path = os.path.join(path_to_data, filenames[idx])
#         img = Image.open(path)
#         img = functional.resize(img, thumbnail_size)
#         img = np.array(img)
#         img_box = osb.AnnotationBbox(
#             osb.OffsetImage(img, cmap=plt.cm.gray_r),
#             embeddings_2d[idx],
#             pad=0.2,
#         )
#         ax.add_artist(img_box)

#     # set aspect ratio
#     ratio = 1. / ax.get_data_ratio()
#     ax.set_aspect(ratio, adjustable='box')


# # get a scatter plot with thumbnail overlays
# get_scatter_plot_with_thumbnails()

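Paintings with relatively similar color tones and styles cluster together, which suggests that the representation learning worked.

### Visualizing similar images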

In [None]:
# def get_image_as_np_array(filename: str):
#     """Loads the image with filename and returns it as a numpy array.

#     """
#     img = Image.open(filename)
#     return np.asarray(img)[...,:3]


# def get_image_as_np_array_with_frame(filename: str, w: int = 5):
#     """Returns an image as a numpy array with a black frame of width w.

#     """
#     img = get_image_as_np_array(filename)
#     ny, nx, _ = img.shape
#     # create an empty image with padding for the frame
#     framed_img = np.zeros((w + ny + w, w + nx + w, 3))
#     framed_img = framed_img.astype(np.uint8)
#     # put the original image in the middle of the new one
#     framed_img[w:-w, w:-w] = img
#     return framed_img


# def plot_nearest_neighbors_3x3(example_image: str, i: int):
#     """Plots the example image and its eight nearest neighbors.

#     """
#     n_subplots = 9
#     # initialize empty figure
#     fig = plt.figure()
#     fig.suptitle(f"Nearest Neighbor Plot {i + 1}")
#     #
#     example_idx = filenames.index(example_image)
#     # get distances to the cluster center
#     distances = embeddings - embeddings[example_idx]
#     distances = np.power(distances, 2).sum(-1).squeeze()
#     # sort indices by distance to the center
#     nearest_neighbors = np.argsort(distances)[:n_subplots]
#     # show images
#     for plot_offset, plot_idx in enumerate(nearest_neighbors):
#         ax = fig.add_subplot(3, 3, plot_offset + 1)
#         # get the corresponding filename
#         fname = os.path.join(path_to_data, filenames[plot_idx])
#         if plot_offset == 0:
#             ax.set_title(f"Example Image")
#             plt.imshow(get_image_as_np_array_with_frame(fname))
#         else:
#             plt.imshow(get_image_as_np_array(fname))
#         # let's disable the axis
#         plt.axis("off")

In [None]:
# # show example images for each cluster
# example_images = [
#     '0a207830d8cca27de4be.jpg',
#     '000bd5e82eb22f199f44.jpg',
#     '4193ebdc9a860f646a40.jpg',
#     '0cd8af895677b51c5897.jpg',
#     '0a44488ae1db7d79d033.jpg',
# ]

# for i, example_image in enumerate(example_images):
#     plot_nearest_neighbors_3x3(example_image, i)

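Portraits, flowers, and animals each appear to be retrieved as similar images.

## Using the trained model

The ResNet18 trained with SimSiam is available as `model.backbone`.  
It can be used for further supervised training, or its embeddings can be used at inference time.  
(For supervised training, note that you need to add a Linear layer as the final layer.)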
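
A minimal sketch of adding such a Linear layer is shown below. The class name `BackboneWithHead`, the value `num_outputs=4`, and the tiny `toy_backbone` are all hypothetical; in the notebook the trained `model.backbone` would take the place of `toy_backbone`, which is used here only so the snippet runs on its own.

```python
import torch
import torch.nn as nn

class BackboneWithHead(nn.Module):
    """Backbone followed by a single Linear layer for supervised training."""

    def __init__(self, backbone, num_ftrs, num_outputs):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Linear(num_ftrs, num_outputs)

    def forward(self, x):
        # (N, num_ftrs, 1, 1) -> (N, num_ftrs)
        feats = self.backbone(x).flatten(start_dim=1)
        return self.head(feats)

# stand-in backbone that outputs (N, 512, 1, 1), like the truncated ResNet18
toy_backbone = nn.Sequential(
    nn.Conv2d(3, 512, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
)

clf = BackboneWithHead(toy_backbone, num_ftrs=512, num_outputs=4)
logits = clf(torch.randn(2, 3, 32, 32))
print(logits.shape)
```

Whether to fine-tune the whole network or freeze the backbone and train only the head is a separate choice that depends on how much labeled data is available.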

In [None]:
# model.backbone