In [8]:
import torch
import torch.nn as nn
import pathlib
import numpy as np
import torchvision.transforms as transforms

from tqdm import tqdm as tqdm
from simclr import SimCLR
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split
from torch.utils.data.dataloader import DataLoader
from flash.core.optimizers import LARS


from torch.utils.tensorboard import SummaryWriter
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules import NT_Xent


In [9]:
# Mini-batch size and the image size handed to the SimCLR transform.
batch_size = 128
image_size = 224

# Decide where the CIFAR-10 archive is stored. When this code runs as a
# script, `__file__` exists and the data lands next to the script. In an
# interactive session (e.g. Jupyter) `__file__` is undefined, so fall back
# to the current working directory.
if "__file__" not in globals():
    import os
    root = pathlib.Path(os.getcwd())
else:
    root = pathlib.Path(__file__).parent.resolve()

# CIFAR-10 (downloaded on demand) with the TransformsSimCLR transform applied
# to every image.
dataset = CIFAR10(
    root=root,
    download=True,
    transform=TransformsSimCLR(size=image_size),
)

# Fix the torch RNG seed before the loader is built.
torch.manual_seed(43)

# NOTE(review): shuffle is disabled, so batches are visited in dataset order
# every epoch -- confirm this is intended for contrastive training.
train_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True,   # drop the trailing batch that is smaller than batch_size
    num_workers=2,
    sampler=None,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to C:\Users\TestAccount\Desktop\repos\DD2412project\cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting C:\Users\TestAccount\Desktop\repos\DD2412project\cifar-10-python.tar.gz to C:\Users\TestAccount\Desktop\repos\DD2412project


In [10]:
# Training hyper-parameters.
epochs = 50
temperature = 0.2
projection_dim = 64

# Step index handed to train() for per-step TensorBoard logging.
global_step = 0

# ResNet-18 backbone fetched through torch.hub, created without pretrained weights.
encoder = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=False)

# Input width of the backbone's last fully-connected layer.
n_features = encoder.fc.in_features

model = SimCLR(encoder, projection_dim, n_features)

Downloading: "https://github.com/pytorch/vision/archive/v0.10.0.zip" to C:\Users\TestAccount/.cache\torch\hub\v0.10.0.zip


In [11]:
# Learning rate grows with the square root of the batch size.
learning_rate = 0.075 * np.sqrt(batch_size)
optimizer = LARS(model.parameters(), lr=learning_rate, weight_decay=1e-6)

# NT-Xent criterion built with the configured temperature.
criterion = NT_Xent(batch_size, temperature=temperature, world_size=1)

# Cosine annealing over `epochs` scheduler steps, decaying towards eta_min=0.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=epochs,
    eta_min=0,
    last_epoch=-1,
)

In [12]:
# TensorBoard summary writer; the training cells log their loss scalars through it.
writer = SummaryWriter()

In [48]:
from typing import Callable, Iterator, Dict
from torch import linalg as LA
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()  # use before @typechecked

import numpy as np
from scipy.stats import ortho_group

@typechecked
def random_ortho_matrix_gen(dim_a: int, dim_b: int) -> Iterator[TensorType["dim_a", "dim_b"]]:
    """Endlessly yield random `(dim_a, dim_b)` matrices with orthonormal columns.

    A random `dim_a x dim_a` orthogonal matrix is drawn with
    `scipy.stats.ortho_group` and cut into `dim_a // dim_b` disjoint column
    blocks. Each block is yielded as a float32 tensor, so a batch of
    `dim_a`-dimensional row vectors `Z` can be projected with `Z @ W`. A new
    orthogonal matrix is drawn once all blocks of the current one are used.

    Args:
        dim_a: Dimension of the space being projected from.
        dim_b: Number of projection directions per yielded matrix; must
            satisfy `1 <= dim_b <= dim_a`.

    Yields:
        Tensor of shape `(dim_a, dim_b)`.

    Raises:
        ValueError: If `dim_b` is not positive.
        AssertionError: If `dim_b > dim_a`.

    Because this is a generator, the argument checks run when the first
    matrix is requested, not when the generator object is created.
    """
    # A non-positive dim_b makes `range(dim_a // dim_b)` empty, which would
    # turn the loop below into an infinite loop that never yields.
    if dim_b < 1:
        raise ValueError(f"dim_b must be a positive integer, got {dim_b}")
    assert dim_a >= dim_b, f"Assuming we want projection matrices (broad and short), got dims {dim_a} x {dim_b}"

    blocks_per_draw = dim_a // dim_b
    while True:
        m = ortho_group.rvs(dim=dim_a)
        for i in range(blocks_per_draw):
            # Slice columns (not rows) so the result is (dim_a, dim_b), as the
            # return annotation states.
            yield torch.tensor(m[:, i*dim_b : (i+1)*dim_b], dtype=torch.float32)

@typechecked
def prior_sampler(name: str, batch_size: int, feature_dim: int) -> Callable[[], 
                                                                            TensorType["batch_size", "feature_dim"]]:
    """
    Build a zero-argument sampling function for a named prior distribution.

    Every call of the returned function draws a fresh tensor of shape
    `(batch_size, feature_dim)`.

    Supported names:
        "Uniform hypersphere": standard-normal rows rescaled to unit L2 norm.
        "Uniform hypercube":   entries drawn with `torch.rand`.
        "Normal distribution": entries drawn from N(0, 1).

    Args:
        name: One of the supported names above.
        batch_size: Number of rows in each sample.
        feature_dim: Number of columns in each sample.

    Returns:
        A callable taking no arguments and returning the sampled tensor.

    Raises:
        ValueError: If `name` is not a supported distribution.
    """
    shape = (batch_size, feature_dim)

    # The sampler returns a tensor, so it is annotated as a tensor. Annotating
    # it as a Callable made @typechecked reject every sample it produced.
    @typechecked
    def hypersphere_sampler() -> TensorType[batch_size, feature_dim]:
        X = torch.normal(mean=0, std=1, size=shape)
        # Divide each row by its own L2 norm to place it on the unit sphere.
        return X / LA.norm(X, dim=1).unsqueeze(1)
    
    d : Dict[str,  Callable[[], TensorType["batch_size", "feature_dim"]]] = {
         "Uniform hypersphere": hypersphere_sampler, 
         "Uniform hypercube": lambda: torch.rand(size=shape),
         "Normal distribution": lambda: torch.normal(mean=0, std=1, size=shape),
        }
    if name not in d:
        raise ValueError(f"Distr '{name}' not in {d.keys()}")
    return d[name]

class SWD_contrastiveloss(nn.Module):
    """Contrastive loss: alignment term plus a sliced-Wasserstein prior-matching term.

    The loss is `loss_align + SWD_lambda * loss_distr`, where

    * `loss_align` is the squared difference between the two views' embeddings,
      summed and divided by `batch_size * feature_dim`;
    * `loss_distr` projects the `2 * batch_size` embeddings and an equally
      sized sample from the prior onto `SWD_dim` random orthonormal
      directions, sorts each projected column, and sums the squared
      differences of the sorted values.

    Args:
        batch_size: Number of rows in each of `z_i` / `z_j`.
        feature_dim: Embedding dimension.
        prior_name: Name of the prior, as accepted by `prior_sampler`.
        normalize_before_align: If True, embeddings are L2-normalised row-wise
            before both loss terms are computed.
        SWD_dim: Number of random projection directions. Non-positive values
            (including the default -1) mean "use `feature_dim`".
        SWD_lambda: Weight of the distribution term.
    """

    @typechecked
    def __init__(self, batch_size: int, feature_dim: int, prior_name: str,
                 normalize_before_align: bool = True,
                 SWD_dim: int=-1, SWD_lambda: float = 1.):
        super(SWD_contrastiveloss, self).__init__()
        # The default -1 is a sentinel; resolve it to the full feature dimension.
        if SWD_dim <= 0:
            SWD_dim = feature_dim
        self.batch_size = batch_size
        self.feature_dim = feature_dim
        # forward() reads this attribute, so it must be stored here.
        self.SWD_dim = SWD_dim
        self.normalize_before_align = normalize_before_align
        # The prior sample has to match the concatenated (z_i, z_j) batch row for row.
        self.sample_prior = prior_sampler(prior_name, batch_size=2*batch_size, feature_dim=feature_dim)
        self.ortho_matrix_gen = random_ortho_matrix_gen(feature_dim, SWD_dim)
        self.lmbda = SWD_lambda
        
    @typechecked
    def forward(self, z_i: TensorType["batch_size", "feature_dim"],
                z_j: TensorType["batch_size", "feature_dim"]):
        """Compute the loss for two batches of embeddings.

        Args:
            z_i: Embeddings of the first view, shape `(batch_size, feature_dim)`.
            z_j: Embeddings of the second view, same shape as `z_i`.

        Returns:
            Scalar tensor `loss_align + SWD_lambda * loss_distr`.

        Raises:
            ValueError: If `z_i` does not have `batch_size` rows. The prior
                sample is created with a fixed `2 * batch_size` rows.
        """
        # Following "Algorithm 1" in the paper
        n : int = self.batch_size
        d = z_i.size(dim = -1)
        if z_i.size(dim = 0) != n:
            raise ValueError(
                f"Expected {n} embeddings per view, got {z_i.size(dim = 0)}"
            )
        
        # Project z_i/z_j onto the hypersphere (i.e. normalize). This is done
        # out of place so the caller's tensors are left untouched.
        if self.normalize_before_align:
            z_i = z_i / LA.norm(z_i, dim=1, keepdim=True)
            z_j = z_j / LA.norm(z_j, dim=1, keepdim=True)
        loss_align = ((z_i - z_j)**2).sum() / (n*d)

        # The distribution term uses the same (optionally normalized) embeddings.
        Z = torch.cat((z_i, z_j), dim=0)                           # (2n, feature_dim)
        P = self.sample_prior().to(device=Z.device, dtype=Z.dtype)  # (2n, feature_dim)
        # The generator is an iterator, so advance it with next().
        W = next(self.ortho_matrix_gen).to(device=Z.device, dtype=Z.dtype)
        # Make sure W is laid out as (feature_dim, SWD_dim) so `Z @ W` is valid.
        if tuple(W.shape) != (self.feature_dim, self.SWD_dim):
            W = W.T
        H_perp, P_perp = Z @ W, P @ W

        # Sort every projected column independently and compare the sorted
        # values element-wise, summing over all samples and directions.
        H_sorted, _ = torch.sort(H_perp, dim=0)
        P_sorted, _ = torch.sort(P_perp, dim=0)
        loss_distr = ((H_sorted - P_sorted)**2).sum()
        # NOTE(review): normaliser kept as in the original code; confirm
        # against the paper whether the sample count should appear here.
        loss_distr = loss_distr / (self.feature_dim * self.SWD_dim)
        return loss_align + self.lmbda * loss_distr
    
# Settings shared by all three criteria. The loss dimensions follow
# `projection_dim` (the size passed to SimCLR as its projection dimension)
# instead of repeating the literal 64, so the two cannot drift apart.
_swd_shared = dict(feature_dim=projection_dim, SWD_dim=projection_dim, SWD_lambda=5.0)

# Hypersphere prior: embeddings are normalized before the loss (the class default).
swd_sphere_crit = SWD_contrastiveloss(batch_size, prior_name="Uniform hypersphere", **_swd_shared)
# Normal and hypercube priors: embeddings are used unnormalized.
swd_normal_crit = SWD_contrastiveloss(batch_size, normalize_before_align=False,
                                      prior_name="Normal distribution", **_swd_shared)
swd_cube_crit = SWD_contrastiveloss(batch_size, normalize_before_align=False,
                                    prior_name="Uniform hypercube", **_swd_shared)

In [None]:
def train(global_step, loader, model, criterion, optimizer, writer):
    """Run one pass over `loader`, updating `model` after every batch.

    Args:
        global_step: Step index used for the first batch's TensorBoard entry.
            It is incremented locally once per batch; the caller's variable
            is not modified.
        loader: Iterable yielding `((view_a, view_b), label)` batches; the
            label is ignored.
        model: Callable returning four values for a pair of views; only the
            last two are passed to `criterion`.
        criterion: Callable mapping the two projections to a scalar loss tensor.
        optimizer: Optimizer whose `zero_grad()` / `step()` are called per batch.
        writer: Object with `add_scalar(tag, value, step)`.

    Returns:
        Sum of the per-batch loss values over the whole pass.
    """
    running_loss = 0
    for batch_idx, ((view_a, view_b), _) in enumerate(loader):
        optimizer.zero_grad()
        _, _, proj_a, proj_b = model(view_a, view_b)
        batch_loss = criterion(proj_a, proj_b)
        batch_loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for printing, logging and the total.
        loss_value = batch_loss.item()

        # Console progress every 50 batches.
        if batch_idx % 50 == 0:
            print(f"Step [{batch_idx}/{len(loader)}]\t Loss: {loss_value}")

        writer.add_scalar("Loss/train_epoch", loss_value, global_step)
        running_loss += loss_value
        global_step += 1
    return running_loss

for epoch in tqdm(range(epochs)):
    # train() increments its step counter only on its own local copy, so
    # compute each epoch's starting step here. Otherwise every epoch would
    # log its per-step losses at steps 0..len(train_loader)-1 again.
    global_step = epoch * len(train_loader)
    loss_epoch = train(global_step, train_loader, model, swd_cube_crit, optimizer, writer)
    # One scheduler step per epoch.
    scheduler.step()
    # Mean loss per batch. `train_loader` is the loader defined at module
    # level; the bare name `loader` exists only inside train().
    mean_loss = loss_epoch / len(train_loader)
    writer.add_scalar("Loss/train", mean_loss, epoch)
    print(
        f"Epoch [{epoch}/{epochs}]\t Loss: {mean_loss}\t"
    )

  0%|                                                                                           | 0/50 [00:00<?, ?it/s]