In [1]:
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip, RandomResizedCrop, RandomRotation, ColorJitter, RandomGrayscale, RandomApply, Resize
from torch.utils.data import DataLoader, Subset

import timm
from tqdm import tqdm

import torch
import torch.nn as nn

from dataset import load_full_isic, MedMNIST

import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#### CONFIGURATION ####
epochs = 100
num_workers = 0
batch_size = 64
pin_memory = False
device = torch.device("cuda" if torch.cuda.is_available() else "mps")

In [3]:
class SimSiamAugmentations:
    def __init__(self, global_crops_scale=(0.2, 1.0), size=224):
        self.global_crops_scale = global_crops_scale
        self.image_size = size

        self.augmentations = Compose([
            RandomHorizontalFlip(),
            RandomResizedCrop(self.image_size, scale=global_crops_scale),
            RandomRotation(10),
            RandomApply([ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.5),
            RandomGrayscale(p=0.2),
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        return self.augmentations(x), self.augmentations(x)


norm_only = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset, train_knn_dataset = load_full_isic(SimSiamAugmentations(), norm_only)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

In [4]:
train, val, test = MedMNIST(64, "breastmnist", SimSiamAugmentations(), norm_only).get_loaders()

Using downloaded and verified file: /Users/lukaskuhn/.medmnist/breastmnist_224.npz
Using downloaded and verified file: /Users/lukaskuhn/.medmnist/breastmnist_224.npz
Using downloaded and verified file: /Users/lukaskuhn/.medmnist/breastmnist_224.npz


In [27]:
total_train_size = 10000

targets = torch.tensor(train_knn_dataset.labels)
# get all indices where targets is 1
positive_indices = torch.where(targets == 1)[0]

# get 10.000 random indices where targets is 0
negative_indices = torch.where(targets == 0)[0]
negative_indices_train = negative_indices[torch.randperm(negative_indices.size(0))[:total_train_size-len(positive_indices)]]

# combine positive and negative indices
indices = torch.cat([positive_indices, negative_indices_train])

train_dataset = Subset(train_knn_dataset, indices)

# get 50% of the positive indices
positive_indices_knn_val = positive_indices[torch.randperm(positive_indices.size(0))[:len(positive_indices)//2]]

# fill up to 1000 indices with negative indices
negative_indices_knn_val = negative_indices[torch.randperm(negative_indices.size(0))[:1000-len(positive_indices_knn_val)]]

knn_val_dataset = Subset(train_knn_dataset, torch.cat([positive_indices_knn_val, negative_indices_knn_val]))

# get the rest of the positive indices
positive_indices_knn_train = positive_indices[torch.randperm(positive_indices.size(0))[len(positive_indices)//2:]]

# fill up to 5000 indices with negative indices
negative_indices_knn_train = negative_indices[torch.randperm(negative_indices.size(0))[1000-len(positive_indices_knn_val):(5000-len(positive_indices_knn_train))+(1000-len(positive_indices_knn_val))]]

knn_train_dataset = Subset(train_knn_dataset, torch.cat([positive_indices_knn_train, negative_indices_knn_train]))

In [182]:
linear_probing_train = torch.cat([negative_indices[torch.randperm(negative_indices.size(0))[:393]], positive_indices])
linear_probing_dataset = Subset(train_knn_dataset, linear_probing_train)

linear_probing_loader = DataLoader(linear_probing_dataset, batch_size=64, shuffle=True, num_workers=num_workers, pin_memory=pin_memory)

In [174]:
knn_train_loader, knn_val_loader = DataLoader(knn_train_dataset, batch_size=64), DataLoader(knn_val_dataset, batch_size=64)

In [None]:
class SimSiamWrapper(nn.Module):
    def __init__(self, base_encoder, dim, pred_dim):
        super(SimSiamWrapper, self).__init__()

        self.encoder = base_encoder 
        self.encoder.head = nn.Identity() # if we remove the head we should be able to use this as is

        self.predictor = nn.Sequential(nn.Linear(dim, pred_dim, bias=False),
                                        nn.BatchNorm1d(pred_dim),
                                        nn.ReLU(inplace=True), # hidden layer
                                        nn.Linear(pred_dim, dim))
        
    def forward(self, x1, x2):
        z1 = self.encoder(x1)
        z2 = self.encoder(x2)

        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        return p1, p2, z1.detach(), z2.detach() # detach the z's as a stop-gradient

In [None]:
base_encoder, dim = timm.create_model('deit_tiny_patch16_224', pretrained=False), 192
model = SimSiamWrapper(base_encoder, dim, 512).to(device)
model.train()

In [None]:
criterion = nn.CosineSimilarity(dim=1).to(device)
lr = 0.05 * batch_size / 256
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
def adjust_learning_rate(optimizer, init_lr, epoch, args):
    """Decay the learning rate based on schedule"""
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    for param_group in optimizer.param_groups:
        if 'fix_lr' in param_group and param_group['fix_lr']:
            param_group['lr'] = init_lr
        else:
            param_group['lr'] = cur_lr

In [None]:
losses = []
for e in range(epochs):
    with tqdm(train_loader, unit='batch') as t:
        t.set_description(f"Epoch {e+1}")
        for images, _ in t:
            x1, x2 = images[0].to(device), images[1].to(device)

            p1, p2, z1, z2 = model(x1, x2)

            loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5
            
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            t.set_postfix(loss=loss.item())
        
        adjust_learning_rate(optimizer, lr, e, epochs)

In [42]:
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score


def compute_knn(backbone, data_loader_train, data_loader_val):
    device = next(backbone.parameters()).device

    data_loaders = {
        "train": data_loader_train,
        "val": data_loader_val,
    }

    lists = {
        "X_train": [],
        "y_train": [],
        "X_val": [],
        "y_val": [],
    }

    for name, data_loader in data_loaders.items():
        for imgs, y in data_loader:
            imgs = imgs.to(device)
            lists[f"X_{name}"].append(backbone(imgs).detach().cpu().numpy())
            lists[f"y_{name}"].append(y.detach().cpu().numpy())

    arrays = {k: np.concatenate(l) for k,l in lists.items()}
    
    estimator = KNeighborsClassifier(1)
    estimator.fit(arrays["X_train"], arrays["y_train"])
    y_val_pred = estimator.predict(arrays["X_val"])

    acc = accuracy_score(arrays["y_val"], y_val_pred)

    return acc, y_val_pred

In [176]:
base_encoder, dim = timm.create_model('deit_tiny_patch16_224', pretrained=False), 192
base_encoder.head = nn.Identity()

base_encoder.load_state_dict(torch.load("checkpoints/model_44.pt", map_location=device))
base_encoder.to(device)

  base_encoder.load_state_dict(torch.load("checkpoints/model_44.pt", map_location=device))


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=192, out_features=576, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=192, out_features=192, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=192, out_features=768, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): Identity()


In [177]:
acc, y_val_pred = compute_knn(base_encoder, knn_train_loader, knn_val_loader)

In [180]:
acc, y_val_pred.sum()

(0.878, np.int64(136))

In [183]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdvancedNN(nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim, dropout_rate=0.5):
        super(AdvancedNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.bn1 = nn.BatchNorm1d(hidden_dim1)
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.bn2 = nn.BatchNorm1d(hidden_dim2)
        self.fc3 = nn.Linear(hidden_dim2, output_dim)
        self.dropout = nn.Dropout(dropout_rate)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Hyperparameters
input_dim = 192  # Dimension of representations
hidden_dim1 = 256  # Number of neurons in the first hidden layer
hidden_dim2 = 128   # Number of neurons in the second hidden layer
output_dim = 2     # Number of classes
dropout_rate = 0.2 # Dropout rate

# Initialize the model
model = AdvancedNN(input_dim, hidden_dim1, hidden_dim2, output_dim, dropout_rate)
model.to(device)

AdvancedNN(
  (fc1): Linear(in_features=192, out_features=256, bias=True)
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=256, out_features=128, bias=True)
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc3): Linear(in_features=128, out_features=2, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [184]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
model.train()

for epoch in range(100):
    epoch_loss = []
    with tqdm(linear_probing_loader, unit="batch") as t:
        for x, y in t:
            x = base_encoder(x.to(device)).detach()
            y = y.to(device)

            optimizer.zero_grad()
            y_hat = model(x)
            loss = F.cross_entropy(y_hat, y)
            epoch_loss.append(loss.item())
            loss.backward()
            optimizer.step()
            t.set_postfix(loss=np.mean(epoch_loss))

100%|██████████| 13/13 [00:04<00:00,  3.23batch/s, loss=0.736]
100%|██████████| 13/13 [00:02<00:00,  4.94batch/s, loss=0.73] 
100%|██████████| 13/13 [00:02<00:00,  5.00batch/s, loss=0.711]
100%|██████████| 13/13 [00:02<00:00,  5.02batch/s, loss=0.703]
100%|██████████| 13/13 [00:02<00:00,  4.98batch/s, loss=0.701]
100%|██████████| 13/13 [00:02<00:00,  4.96batch/s, loss=0.692]
  8%|▊         | 1/13 [00:00<00:05,  2.37batch/s, loss=0.713]


KeyboardInterrupt: 