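# Introduction
This notebook uses **SimSiam**, a Self-Supervised Learning (SSL) method, to learn representations of the provided artworks. SSL learns data representations without labels by deriving a pseudo-supervision signal from the data itself, so it may yield "better weights" even when, as in this case, pretrained weights are prohibited.  
SimSiam passes two differently augmented views of the same image through a weight-sharing network and trains it to make the two outputs more similar. It needs no negative pairs, which keeps training simple.  
<img src="https://tech.fusic.co.jp/uploads/exploring_siam_arcs.png" width=60%>    
This notebook trains SimSiam with **lightly**, a library that makes SSL easy to run. Following the official tutorial, the goal is to learn representations and then inspect the distribution of the data through the resulting embeddings.  
(Note: take care not to use a pretrained model.) 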
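
To make the objective concrete, the following is a minimal NumPy sketch of the symmetric negative cosine similarity that SimSiam minimizes. The names `p0`, `z0`, `p1`, `z1` (predictions and projections of the two views) and both helper functions are illustrative, not part of lightly. The stop-gradient that the paper applies to the projections only affects the backward pass, so it does not appear in this forward-only sketch.

```python
import numpy as np

def neg_cosine(p, z):
    """Mean negative cosine similarity between the rows of p and z."""
    p = p / np.linalg.norm(p, axis=1, keepdims=True)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)
    return -np.mean(np.sum(p * z, axis=1))

def simsiam_loss(p0, z0, p1, z1):
    """Symmetric loss: each prediction is compared with the other view's projection."""
    return 0.5 * neg_cosine(p0, z1) + 0.5 * neg_cosine(p1, z0)

rng = np.random.default_rng(0)
z = rng.normal(size=(4, 8))
# identical outputs for both views reach the minimum of the loss (about -1)
print(simsiam_loss(z, z, z, z))
```

The loss ranges from -1 (perfectly aligned views) to 1 (opposite directions), which is why the training log in this notebook reports negative values.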
  
[References]  
[Exploring Simple Siamese Representation Learning](https://arxiv.org/abs/2011.10566)  
[A New Approach to Self-Supervised Learning / SimSiam: Exploring Simple Siamese Representation Learning](https://speakerdeck.com/sansandsoc/simsiam-exploring-simple-siamese-representation-learning) (in Japanese)  
[[Paper reading] Exploring Simple Siamese Representation Learning](https://tech.fusic.co.jp/posts/2020-12-25-ml-simsiam-representation-learning/) (in Japanese)  
[https://github.com/lightly-ai/lightly](https://github.com/lightly-ai/lightly)  
[Train SimSiam on satellite images](https://docs.lightly.ai/tutorials/package/tutorial_simsiam_esa.html)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# For Colab
!pip install lightly


In [None]:
# For Colab
# !unzip dataset_atmaCup11.zip
# !mkdir imgs
# !unzip photos.zip -d ./imgs/

# Representation learning with SimSiam

### Loading libraries

In [None]:
import math
import torch
import torch.nn as nn
import torchvision
import numpy as np
import lightly

### Config

- A larger `batch_size` may work better (see the paper).
- `num_ftrs` is the dimension of the embedding produced by the image model; it is 512 here because the backbone is `resnet18`.

In [None]:
use_amp = True

num_workers = 2
batch_size = 512
seed = 1
epochs = 500
input_size = 224

# dimension of the embeddings
num_ftrs = 512
# dimension of the output of the prediction and projection heads
out_dim = proj_hidden_dim = 512
# the prediction head uses a bottleneck architecture
pred_hidden_dim = 128
# use 2 layers in the projection head
num_mlp_layers = 2

### Seed / Path to the image data

In [None]:


# seed torch and numpy with the `seed` value defined in the config
torch.manual_seed(seed)
np.random.seed(seed)

# set the path to the dataset

path_to_data = '/content/drive/MyDrive/atmaCup/#11/dataset_atmaCup11/inputs/photos'

### DataLoader

- `collate_fn` specifies the data augmentations and the probability with which each one is applied.
  - hf_prob: Horizontal flip
  - vf_prob: Vertical flip
  - rr_prob: Random (+90 degree) rotation
  - min_scale: Minimum scale of the random crop
  - cj_prob: Color jitter
  - cj_bright: Brightness jitter
  - cj_contrast: Contrast jitter
  - cj_hue: Hue jitter
  - cj_sat: Saturation jitter
- Reference: https://docs.lightly.ai/lightly.data.html#lightly.data.collate.ImageCollateFunction
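
As a rough illustration of how per-augmentation probabilities produce two different views of one image, here is a small NumPy sketch covering only the flips and the 90-degree rotation. It is not lightly's implementation: `random_view` is a hypothetical helper, and cropping and color jitter are omitted.

```python
import numpy as np

def random_view(img, rng, hf_prob=0.5, vf_prob=0.5, rr_prob=0.5):
    """Apply each augmentation independently with its own probability."""
    if rng.random() < hf_prob:
        img = np.fliplr(img)   # horizontal flip
    if rng.random() < vf_prob:
        img = np.flipud(img)   # vertical flip
    if rng.random() < rr_prob:
        img = np.rot90(img)    # rotate by 90 degrees
    return img

rng = np.random.default_rng(0)
image = np.arange(16).reshape(4, 4)   # toy square "image"
x0 = random_view(image, rng)
x1 = random_view(image, rng)
# both views contain the same pixels, possibly rearranged
print(sorted(x0.ravel()) == sorted(x1.ravel()))
```

Training on such pairs asks the network to produce similar outputs regardless of orientation, so these probabilities should only be non-zero when that invariance is actually desired for the artworks.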

In [None]:
# define the augmentations for self-supervised learning
collate_fn = lightly.data.ImageCollateFunction(
    input_size=input_size,
    # require invariance to flips and rotations
    hf_prob=0.5,
    vf_prob=0.5,
    rr_prob=0.5,
    # lower bound on the scale of the random crop
    min_scale=0.5,
    # the color jitter arguments are left at the library defaults; uncomment
    # the lines below to set a weak jitter for small color changes
    # cj_prob=0.2,
    # cj_bright=0.1,
    # cj_contrast=0.1,
    # cj_hue=0.1,
    # cj_sat=0.1,
)

# create a lightly dataset for training, since the augmentations are handled
# by the collate function, there is no need to apply additional ones here
dataset_train_simsiam = lightly.data.LightlyDataset(
    input_dir=path_to_data
)

# create a dataloader for training
dataloader_train_simsiam = torch.utils.data.DataLoader(
    dataset_train_simsiam,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=num_workers
)

# create a torchvision transformation for embedding the dataset after training
# here, we resize the images to match the input size during training and apply
# a normalization of the color channel based on statistics from imagenet
test_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize((input_size, input_size)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=lightly.data.collate.imagenet_normalize['mean'],
        std=lightly.data.collate.imagenet_normalize['std'],
    )
])



# create a lightly dataset for embedding
dataset_test = lightly.data.LightlyDataset(
    input_dir=path_to_data,
    transform=test_transforms
)



# create a dataloader for embedding
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers
)

### Model

- The backbone is a ResNet18 with its final (fully connected) layer removed.
- Note: take care not to use a pretrained model.

In [None]:
# train a ResNet18 from scratch: pretrained weights must not be used here,
# so `pretrained` is set to False
resnet = torchvision.models.resnet18(pretrained=False)
backbone = nn.Sequential(*list(resnet.children())[:-1])

# create the SimSiam model using the backbone from above
model = lightly.models.SimSiam(
    backbone,
    num_ftrs=num_ftrs,
    #proj_hidden_dim=proj_hidden_dim, # use the default
    #pred_hidden_dim=pred_hidden_dim, # use the default
    #out_dim=out_dim, # use the default
    num_mlp_layers=2
)


### Loss / Optimizer

In [None]:
# SimSiam uses a symmetric negative cosine similarity loss
criterion = lightly.loss.SymNegCosineSimilarityLoss()

# scale the learning rate
lr = 0.05 * batch_size / 256
# use SGD with momentum and weight decay
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=5e-4
)

In [None]:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
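
The learning rate above follows the linear scaling rule and is then annealed with a cosine schedule. The sketch below reproduces both in plain Python so the values can be inspected without a training run. `scaled_lr` and `cosine_lr` are illustrative helpers, and the closed form assumes the scheduler is stepped once per epoch with a minimum learning rate of 0.

```python
import math

def scaled_lr(batch_size, base_lr=0.05, base_batch=256):
    """Linear scaling rule: the learning rate grows with the batch size."""
    return base_lr * batch_size / base_batch

def cosine_lr(epoch, total_epochs, lr_max, lr_min=0.0):
    """Closed-form cosine annealing from lr_max (epoch 0) to lr_min (last epoch)."""
    cos = math.cos(math.pi * epoch / total_epochs)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + cos)

lr_max = scaled_lr(512)   # 0.1 for batch_size = 512
for epoch in (0, 250, 500):
    print(epoch, round(cosine_lr(epoch, 500, lr_max), 4))
```

The schedule starts at the scaled rate, passes through half of it at the midpoint, and reaches zero at the final epoch.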

### Self-Supervised Learning with SimSiam

In [None]:
import os

dataset_root = '/content/drive/MyDrive/atmaCup/#11/dataset_atmaCup11'
assert dataset_root is not None
output_dir = os.path.join(dataset_root, "simsam_tutorial")
os.makedirs(output_dir, exist_ok=True)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

# scale the loss under AMP so that float16 gradients do not underflow
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

avg_losses = []
avg_loss = 0.
avg_output_std = 0.
for e in range(1, epochs + 1):

    for (x0, x1), _, _ in dataloader_train_simsiam:
        # move images to the gpu
        x0 = x0.to(device)
        x1 = x1.to(device)

        with torch.cuda.amp.autocast(enabled=use_amp):
            # run the model on both transforms of the images
            # each output y contains the projection and the prediction
            # for the corresponding input x
            y0, y1 = model(x0, x1)
            loss = criterion(y0, y1)

        # backpropagation on the scaled loss, then an unscaled optimizer step
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        # calculate the per-dimension standard deviation of the outputs
        # we can use this later to check whether the embeddings are collapsing
        output, _ = y0
        output = output.detach()
        output = torch.nn.functional.normalize(output, dim=1)

        output_std = torch.std(output, 0)
        output_std = output_std.mean()

        # use moving averages to track the loss and standard deviation
        w = 0.9
        avg_loss = w * avg_loss + (1 - w) * loss.item()
        avg_output_std = w * avg_output_std + (1 - w) * output_std.item()
    
    scheduler.step()

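    # the level of collapse is large if the standard deviation of the l2
    # normalized output is much smaller than 1 / sqrt(dim); dim is read from
    # the output itself because the model is built with its default out_dim,
    # which need not match the `out_dim` value in the config
    collapse_level = max(0., 1 - math.sqrt(output.shape[1]) * avg_output_std)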
    # print intermediate results
    print(f'[Epoch {e:3d}] '
        f'Loss = {avg_loss:.2f} | '
        f'Collapse Level: {collapse_level:.2f} / 1.00')
    
    avg_losses.append(avg_loss)
    
    if e % 50 == 0:
        model_path = os.path.join(output_dir, str(batch_size) + '_' + str(e) + '_' + str(input_size) + '_' + str(avg_loss) + '.pth')
        torch.save(model.state_dict(), model_path)



[Epoch   1] Loss = -0.07 | Collapse Level: 0.59 / 1.00
[Epoch   2] Loss = -0.43 | Collapse Level: 0.60 / 1.00
[Epoch   3] Loss = -0.52 | Collapse Level: 0.60 / 1.00
[Epoch   4] Loss = -0.51 | Collapse Level: 0.59 / 1.00
[Epoch   5] Loss = -0.44 | Collapse Level: 0.57 / 1.00
[Epoch   6] Loss = -0.44 | Collapse Level: 0.58 / 1.00
[Epoch   7] Loss = -0.44 | Collapse Level: 0.58 / 1.00
[Epoch   8] Loss = -0.43 | Collapse Level: 0.57 / 1.00
[Epoch   9] Loss = -0.44 | Collapse Level: 0.57 / 1.00
[Epoch  10] Loss = -0.47 | Collapse Level: 0.58 / 1.00
[Epoch  11] Loss = -0.47 | Collapse Level: 0.57 / 1.00
[Epoch  12] Loss = -0.47 | Collapse Level: 0.57 / 1.00
[Epoch  13] Loss = -0.47 | Collapse Level: 0.57 / 1.00
[Epoch  14] Loss = -0.47 | Collapse Level: 0.57 / 1.00
[Epoch  15] Loss = -0.46 | Collapse Level: 0.56 / 1.00
[Epoch  16] Loss = -0.52 | Collapse Level: 0.58 / 1.00
[Epoch  17] Loss = -0.48 | Collapse Level: 0.56 / 1.00
[Epoch  18] Loss = -0.51 | Collapse Level: 0.58 / 1.00
[Epoch  19
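
The collapse level printed above can be hard to interpret without a reference point. The following NumPy sketch applies the same idea to synthetic data: for well-spread L2-normalized vectors of dimension `d`, the per-dimension standard deviation is about `1 / sqrt(d)`, so the level is near 0, while identical outputs give a level near 1. The helper `collapse_level` and the synthetic arrays are illustrative only.

```python
import math
import numpy as np

def collapse_level(outputs):
    """1 - sqrt(d) * mean per-dimension std of the L2-normalized rows, floored at 0."""
    normed = outputs / np.linalg.norm(outputs, axis=1, keepdims=True)
    std = normed.std(axis=0).mean()
    return max(0.0, 1 - math.sqrt(outputs.shape[1]) * std)

rng = np.random.default_rng(0)
healthy = rng.normal(size=(1024, 64))                      # spread-out outputs
collapsed = np.tile(rng.normal(size=(1, 64)), (1024, 1))   # every row identical

print(round(collapse_level(healthy), 2))    # close to 0
print(round(collapse_level(collapsed), 2))  # close to 1
```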

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()

# one moving-average loss value is recorded per epoch
plt.plot(range(1, len(avg_losses) + 1), avg_losses)

fig.savefig(os.path.join(output_dir, "avg_losses_(batch_size: " + str(batch_size) + ", epochs: " + str(epochs) + ", input_size: " + str(input_size) + ").png"))

# Visualizing the embeddings

### Obtaining the embeddings

In [None]:
embeddings = []
filenames = []

# disable gradients for faster calculations
model.eval()
with torch.no_grad():
    for i, (x, _, fnames) in enumerate(dataloader_test):
        # move the images to the gpu
        x = x.to(device)
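        # embed the images with the SimSiam-trained backbone
        y = model.backbone(x)
        # flatten (N, 512, 1, 1) to (N, 512); unlike squeeze(), this keeps
        # the batch dimension even if the last batch holds a single image
        y = y.flatten(start_dim=1)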
        # store the embeddings and filenames in lists
        embeddings.append(y)
        filenames = filenames + list(fnames)

# concatenate the embeddings and convert to numpy
embeddings = torch.cat(embeddings, dim=0)
embeddings = embeddings.cpu().numpy()

### Loading libraries for plotting

In [None]:
# for plotting
import os
from PIL import Image

import matplotlib.pyplot as plt
import matplotlib.offsetbox as osb
from matplotlib import rcParams as rcp

# for resizing images to thumbnails
import torchvision.transforms.functional as functional

# for clustering and 2d representations
from sklearn import random_projection

### Dimensionality reduction and normalization of the embeddings

In [None]:
# for the scatter plot we want to transform the images to a two-dimensional
# vector space using a random Gaussian projection
projection = random_projection.GaussianRandomProjection(n_components=2)
embeddings_2d = projection.fit_transform(embeddings)

# normalize the embeddings to fit in the [0, 1] square
M = np.max(embeddings_2d, axis=0)
m = np.min(embeddings_2d, axis=0)
embeddings_2d = (embeddings_2d - m) / (M - m)
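
For intuition about what the cell above computes, here is a NumPy-only sketch of a Gaussian random projection followed by min-max normalization. It assumes the projection matrix entries are drawn from N(0, 1 / n_components), which is how scikit-learn documents `GaussianRandomProjection`. The array `embeddings_demo` is random stand-in data, not the real embeddings.

```python
import numpy as np

rng = np.random.default_rng(0)
embeddings_demo = rng.normal(size=(100, 512))   # stand-in for the real embeddings

# random Gaussian matrix with entries drawn from N(0, 1 / n_components)
n_components = 2
R = rng.normal(scale=1 / np.sqrt(n_components), size=(512, n_components))
projected = embeddings_demo @ R

# min-max normalization per axis maps every point into the [0, 1] square
lo, hi = projected.min(axis=0), projected.max(axis=0)
projected = (projected - lo) / (hi - lo)
print(projected.shape, projected.min(), projected.max())
```

A projection to only two dimensions cannot preserve all distances, so the resulting plot is best read as a quick overview rather than a faithful map of the embedding space.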

### Thumbnail scatter plot

In [None]:
def get_scatter_plot_with_thumbnails():
    """Creates a scatter plot with image overlays.
    """
    # initialize empty figure and add subplot
    fig = plt.figure(figsize=(12,12))
    fig.suptitle('SimSiam Scatter Plot')
    ax = fig.add_subplot(1, 1, 1)
    # shuffle images and find out which images to show
    shown_images_idx = []
    shown_images = np.array([[1., 1.]])
    iterator = [i for i in range(embeddings_2d.shape[0])]
    np.random.shuffle(iterator)
    for i in iterator:
        # only show image if it is sufficiently far away from the others
        dist = np.sum((embeddings_2d[i] - shown_images) ** 2, 1)
        if np.min(dist) < 1.5e-3:
            continue
        shown_images = np.r_[shown_images, [embeddings_2d[i]]]
        shown_images_idx.append(i)

    # plot image overlays
    for idx in shown_images_idx:
        thumbnail_size = int(rcp['figure.figsize'][0] * 5.)
        path = os.path.join(path_to_data, filenames[idx])
        img = Image.open(path)
        img = functional.resize(img, thumbnail_size)
        img = np.array(img)
        img_box = osb.AnnotationBbox(
            osb.OffsetImage(img, cmap=plt.cm.gray_r),
            embeddings_2d[idx],
            pad=0.2,
        )
        ax.add_artist(img_box)

    # set aspect ratio
    ratio = 1. / ax.get_data_ratio()
    ax.set_aspect(ratio, adjustable='box')


# get a scatter plot with thumbnail overlays
get_scatter_plot_with_thumbnails()

Paintings with relatively similar color tones and styles cluster together, which suggests that representation learning has worked.

### Visualizing similar images

In [None]:
def get_image_as_np_array(filename: str):
    """Loads the image with filename and returns it as a numpy array.

    """
    img = Image.open(filename)
    return np.asarray(img)[...,:3]


def get_image_as_np_array_with_frame(filename: str, w: int = 5):
    """Returns an image as a numpy array with a black frame of width w.

    """
    img = get_image_as_np_array(filename)
    ny, nx, _ = img.shape
    # create an empty image with padding for the frame
    framed_img = np.zeros((w + ny + w, w + nx + w, 3))
    framed_img = framed_img.astype(np.uint8)
    # put the original image in the middle of the new one
    framed_img[w:-w, w:-w] = img
    return framed_img


def plot_nearest_neighbors_3x3(example_image: str, i: int):
    """Plots the example image and its eight nearest neighbors.

    """
    n_subplots = 9
    # initialize empty figure
    fig = plt.figure()
    fig.suptitle(f"Nearest Neighbor Plot {i + 1}")
    # look up the index of the query image
    example_idx = filenames.index(example_image)
    # squared Euclidean distances from the query embedding to all embeddings
    distances = embeddings - embeddings[example_idx]
    distances = np.power(distances, 2).sum(-1).squeeze()
    # sort indices by distance; the query itself comes first with distance 0
    nearest_neighbors = np.argsort(distances)[:n_subplots]
    # show images
    for plot_offset, plot_idx in enumerate(nearest_neighbors):
        ax = fig.add_subplot(3, 3, plot_offset + 1)
        # get the corresponding filename
        fname = os.path.join(path_to_data, filenames[plot_idx])
        if plot_offset == 0:
            ax.set_title(f"Example Image")
            plt.imshow(get_image_as_np_array_with_frame(fname))
        else:
            plt.imshow(get_image_as_np_array(fname))
        # let's disable the axis
        plt.axis("off")

In [None]:
# show example images for each cluster
example_images = [
    '0a207830d8cca27de4be.jpg',
    '000bd5e82eb22f199f44.jpg',
    '4193ebdc9a860f646a40.jpg',
    '0cd8af895677b51c5897.jpg',
    '0a44488ae1db7d79d033.jpg',
]

for i, example_image in enumerate(example_images):
    plot_nearest_neighbors_3x3(example_image, i)

Portraits, flowers, and animals each appear to be retrieved as similar images for their respective query images.
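
The retrieval step used for these plots reduces to sorting squared Euclidean distances. The toy example below isolates that step on four hand-made 2D points so the result can be checked by eye. `nearest_neighbors` is an illustrative helper, not a function from the notebook.

```python
import numpy as np

def nearest_neighbors(embeddings, query_idx, k):
    """Indices of the k rows closest to the query row (the query itself comes first)."""
    diff = embeddings - embeddings[query_idx]
    sq_dist = np.sum(diff ** 2, axis=1)
    return np.argsort(sq_dist)[:k]

toy = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [0.0, 0.3]])
print(nearest_neighbors(toy, query_idx=0, k=3))  # [0 1 3]
```

The far-away point at index 2 is excluded, and the query is returned first because its distance to itself is zero.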

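## Using the trained model

The ResNet18 trained with SimSiam is available as `model.backbone`.  
It can be trained further with supervised learning, or its embeddings can be used directly at inference time.  
(Note that supervised learning requires adding a Linear layer as the final layer.)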
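
One possible way to attach such a Linear layer is sketched below. To keep the example self-contained, `backbone_demo` is a small stand-in with the same output shape as the ResNet18 backbone, `(N, 512, 1, 1)`, and `num_classes` is a hypothetical value. In the notebook, `model.backbone` would take the place of the stand-in.

```python
import torch
import torch.nn as nn

# stand-in for `model.backbone`: any module mapping images to (N, 512, 1, 1)
backbone_demo = nn.Sequential(
    nn.Conv2d(3, 512, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
)

num_classes = 4  # hypothetical number of target classes

classifier = nn.Sequential(
    backbone_demo,
    nn.Flatten(),                  # (N, 512, 1, 1) -> (N, 512)
    nn.Linear(512, num_classes),   # new supervised head
)

logits = classifier(torch.randn(2, 3, 32, 32))
print(logits.shape)  # torch.Size([2, 4])
```

For a regression target, the head would output a single value instead of `num_classes` logits.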

In [None]:
model.backbone