In [None]:
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop, RandomRotation, ColorJitter, RandomGrayscale, RandomApply
from torch.utils.data import DataLoader

import timm
from tqdm import tqdm

import torch
import torch.nn as nn

from dataset import load_full_isic

import math

In [13]:
# Exploration only: display the DeiT-Base architecture for reference
# (the SimSiam training below uses the smaller deit_tiny_patch16_224 backbone).
timm.create_model(model_name="deit_base_patch16_224")

VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity(

In [None]:
#### CONFIGURATION ####
epochs = 100
num_workers = 0
batch_size = 64
pin_memory = False
seed = 0  # fixed seed so weight init, shuffling and augmentations are reproducible

# Prefer CUDA, then Apple-silicon MPS, and finally fall back to CPU so the
# notebook still runs on machines that have neither accelerator.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

torch.manual_seed(seed)

In [None]:
class SimSiamAugmentations:
    """Produce two independently augmented views of one image for SimSiam.

    Args:
        global_crops_scale: (min, max) area fraction passed to RandomResizedCrop.
        size: output side length of the square crop.
    """

    def __init__(self, global_crops_scale=(0.2, 1.0), size=224):
        self.global_crops_scale = global_crops_scale
        self.image_size = size

        pipeline = [
            RandomHorizontalFlip(),
            RandomResizedCrop(self.image_size, scale=global_crops_scale),
            RandomRotation(10),
            RandomApply([ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5),
            RandomGrayscale(p=0.2),
            ToTensor(),
            # Standard ImageNet per-channel mean / std.
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        self.augmentations = Compose(pipeline)

    def __call__(self, x):
        """Return (view_1, view_2): two independent random augmentations of ``x``."""
        first_view = self.augmentations(x)
        second_view = self.augmentations(x)
        return first_view, second_view


# Non-augmenting transform: tensor conversion + ImageNet normalization only.
norm_only = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# NOTE(review): load_full_isic presumably applies the first transform to the
# training split and the second to a held-out split; only the first split is
# used here for self-supervised pretraining.
train_dataset, _ = load_full_isic(SimSiamAugmentations(), norm_only)

# drop_last=True: the SimSiam predictor contains BatchNorm1d, which raises in
# training mode on a single-sample batch, and tiny tail batches would give
# noisy batch statistics anyway.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
    drop_last=True,
)

In [None]:
class SimSiamWrapper(nn.Module):
    """SimSiam model: a shared encoder followed by an MLP predictor.

    The encoder's ``head`` attribute is replaced with ``nn.Identity`` so the
    encoder returns the features that would otherwise feed its classifier
    (``z``); a two-layer MLP predictor maps them to ``p``.

    Args:
        base_encoder: backbone module exposing a ``head`` attribute (e.g. a timm
            ViT). It is modified in place.
        dim: feature size produced by the encoder once its head is removed.
        pred_dim: hidden width of the predictor MLP.
    """

    def __init__(self, base_encoder, dim, pred_dim):
        super().__init__()

        # Drop the classification head so the backbone emits raw features.
        base_encoder.head = nn.Identity()
        self.encoder = base_encoder

        hidden_layers = [
            nn.Linear(dim, pred_dim, bias=False),
            nn.BatchNorm1d(pred_dim),
            nn.ReLU(inplace=True),
        ]
        self.predictor = nn.Sequential(*hidden_layers, nn.Linear(pred_dim, dim))

    def forward(self, x1, x2):
        """Encode both views and predict each view's features.

        Returns:
            ``(p1, p2, z1, z2)``. The z's are detached so they act as
            stop-gradient targets in the symmetric loss.
        """
        z1, z2 = self.encoder(x1), self.encoder(x2)
        p1, p2 = self.predictor(z1), self.predictor(z2)
        return p1, p2, z1.detach(), z2.detach()

In [None]:
# Backbone: a randomly initialised DeiT-Tiny ViT. Its feature width is read
# from the model (timm's `num_features`) instead of being hard-coded, so
# switching architectures cannot silently desynchronise `dim` from the encoder.
encoder_name = 'deit_tiny_patch16_224'
base_encoder = timm.create_model(encoder_name, pretrained=False)
dim = base_encoder.num_features  # 192 for deit_tiny_patch16_224
pred_dim = 512  # hidden width of the SimSiam predictor MLP

model = SimSiamWrapper(base_encoder, dim, pred_dim).to(device)
model.train();  # trailing ';' suppresses the very long module repr

In [None]:
# Per-sample cosine similarity along the feature axis; the training loop
# negates and averages it to form the SimSiam loss.
criterion = nn.CosineSimilarity(dim=1).to(device)

# Linear scaling rule: base lr of 0.05 per 256 images.
lr = 0.05 * batch_size / 256
optimizer = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)

In [None]:
def adjust_learning_rate(optimizer, init_lr, epoch, args):
    """Set every param group's lr from a half-cosine schedule.

    The lr decays from ``init_lr`` at ``epoch == 0`` towards 0 at
    ``epoch == total_epochs``::

        lr = init_lr * 0.5 * (1 + cos(pi * epoch / total_epochs))

    Param groups carrying a truthy ``'fix_lr'`` entry are pinned to ``init_lr``.

    Args:
        optimizer: optimizer whose ``param_groups`` are updated in place.
        init_lr: base learning rate at the start of training.
        epoch: 0-based index of the epoch about to be trained.
        args: total number of epochs, either as a number or as an object with
            an ``epochs`` attribute (e.g. an argparse namespace).

    Returns:
        The scheduled (cosine) learning rate for ``epoch``.
    """
    # Take the schedule length from the argument rather than a notebook global,
    # so the function does not depend on hidden kernel state.
    total_epochs = getattr(args, 'epochs', args)
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / total_epochs))
    for param_group in optimizer.param_groups:
        if param_group.get('fix_lr', False):
            param_group['lr'] = init_lr
        else:
            param_group['lr'] = cur_lr
    return cur_lr

In [None]:
losses = []  # per-iteration training loss, kept for inspection after training
for e in range(epochs):
    # Set this epoch's lr *before* training on it. Calling the scheduler at the
    # end of the epoch made the cosine schedule lag by one epoch: the second
    # epoch reused the initial lr and the last scheduled value was never used.
    adjust_learning_rate(optimizer, lr, e, epochs)

    with tqdm(train_loader, unit='batch') as t:
        t.set_description(f"Epoch {e+1}")
        for images, _ in t:
            # Each sample yields two augmented views (SimSiamAugmentations).
            x1, x2 = images[0].to(device), images[1].to(device)

            p1, p2, z1, z2 = model(x1, x2)

            # Symmetric negative cosine similarity; z1/z2 come back detached
            # from the model (stop-gradient), so gradients flow through p1/p2.
            loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # One device->host sync per step; reuse the value for logging.
            loss_value = loss.item()
            losses.append(loss_value)
            t.set_postfix(loss=loss_value)