# Utils

In [None]:
# utils
from torchvision import datasets
from tqdm import tqdm
from torch.utils.data import DataLoader
from datetime import datetime
import wandb
import torch
import torchvision.transforms as T
from torchvision import models
from torch import optim
import torch.nn as nn
from dataclasses import dataclass
from torch.optim.lr_scheduler import StepLR, MultiStepLR, CosineAnnealingLR
import torch.nn.functional as F


def set_random_seed(seed):
    """Seed the RNGs used by this notebook and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to Python's ``random`` module, the torch
            CPU generator and every visible CUDA device.
    """
    import random

    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels per run, which undermines
    # ``deterministic``; a later cell may still opt back in explicitly.
    torch.backends.cudnn.benchmark = False
    random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers all GPUs, not only the current device.
    torch.cuda.manual_seed_all(seed)


class WanDBWriter:
    """Small Weights & Biases logging helper shared by all training loops.

    Scalars, images, audio and text are logged under ``"<mode>/<name>"`` at the
    step most recently passed to :meth:`set_step`.
    """

    def __init__(self, config):
        """Log in to wandb and start a run.

        Args:
            config: Object with a ``wandb_project`` attribute. Its public,
                non-callable attributes (class- or instance-level) are sent to
                wandb as the run's config.

        Raises:
            ValueError: If ``config`` has no ``wandb_project`` attribute.
        """
        self.writer = None
        self.selected_module = ""

        wandb.login()

        if not hasattr(config, 'wandb_project'):
            raise ValueError("please specify project name for wandb")

        wandb.init(
            project=getattr(config, 'wandb_project'),
            # Pass an explicit dict: wandb converts non-dict configs with vars(),
            # which only sees instance attributes, so class-level values (e.g.
            # a @dataclass without annotated fields) were silently dropped.
            config=self._config_to_dict(config),
        )
        self.wandb = wandb

        self.step = 0
        self.mode = ""
        self.timer = datetime.now()

    @staticmethod
    def _config_to_dict(config):
        """Return the public, non-callable attributes of ``config`` as a dict."""
        values = {}
        for name in dir(config):
            if name.startswith('_'):
                continue
            value = getattr(config, name)
            if not callable(value):
                values[name] = value
        return values

    def set_step(self, step, mode="train"):
        """Set the step and mode used by subsequent log calls.

        For ``step != 0`` also logs ``<mode>/steps_per_sec`` from the wall time
        elapsed since the previous call.
        """
        self.mode = mode
        self.step = step
        if step != 0:
            elapsed = (datetime.now() - self.timer).total_seconds()
            # Coarse clocks can yield identical timestamps; skip instead of
            # raising ZeroDivisionError.
            if elapsed > 0:
                self.add_scalar("steps_per_sec", 1 / elapsed)
        self.timer = datetime.now()

    def finish(self):
        """Close the wandb run."""
        self.wandb.finish()

    def scalar_name(self, scalar_name):
        """Prefix ``scalar_name`` with the current mode, e.g. ``train/loss``."""
        return f"{self.mode}/{scalar_name}"

    def watch_model(self, model, criterion=None):
        """Ask wandb to watch ``model``; ``criterion`` is accepted but unused."""
        self.wandb.watch(model)

    def add_scalar(self, scalar_name, scalar):
        """Log one scalar value at the current step."""
        self.wandb.log({
            self.scalar_name(scalar_name): scalar,
        }, step=self.step)

    def add_scalars(self, tag, scalars):
        """Log a dict of scalars as ``<name>_<tag>_<mode>`` keys at the current step."""
        self.wandb.log({
            **{f"{scalar_name}_{tag}_{self.mode}": scalar for scalar_name, scalar in scalars.items()}
        }, step=self.step)

    def add_image(self, scalar_name, image):
        """Log ``image`` (anything ``wandb.Image`` accepts) at the current step."""
        self.wandb.log({
            self.scalar_name(scalar_name): self.wandb.Image(image)
        }, step=self.step)

    def add_audio(self, scalar_name, audio, sample_rate=None):
        """Log ``audio`` via ``wandb.Audio`` at the current step.

        NOTE(review): ``audio`` is transposed before logging; presumably it is a
        (channels, samples) array-like already on the CPU — confirm with callers.
        """
        audio = audio.T
        self.wandb.log({
            self.scalar_name(scalar_name): self.wandb.Audio(audio, sample_rate=sample_rate)
        }, step=self.step)

    def add_text(self, scalar_name, text):
        """Log ``text`` rendered as HTML at the current step."""
        self.wandb.log({
            self.scalar_name(scalar_name): self.wandb.Html(text)
        }, step=self.step)

    def add_pr_curve(self, scalar_name, scalar):
        raise NotImplementedError()

    def add_embedding(self, scalar_name, scalar):
        raise NotImplementedError()

# Supervised Baseline

In [None]:
# Per-channel mean / std used to normalize both pipelines.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training augmentation: random resized 224px crop and horizontal flip,
# followed by tensor conversion and normalization.
train_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(_imagenet_mean, _imagenet_std),
])

# Evaluation: deterministic resize to 256px, then a 224px centre crop.
test_transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(_imagenet_mean, _imagenet_std),
])

In [None]:
@dataclass
class SupervisedBaselineConfig:
    """Hyperparameters for the supervised ResNet-18 baseline on STL-10.

    Fields are type-annotated so ``@dataclass`` registers them (as in
    ``MOCO_config``). They then show up in ``vars()`` / ``dataclasses.asdict()``
    and can be overridden in the constructor.
    """
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    num_epochs: int = 90
    save_epochs: int = 10
    batch_size: int = 256
    optim: str = 'SGD'
    lr: float = 0.1
    momentum: float = 0.9
    nesterov: bool = False
    weight_decay: float = 1e-4

    scheduler: str = 'MultiStepLR'
    # Epochs at which MultiStepLR multiplies the lr by `gamma`. A tuple, because
    # dataclasses reject mutable (list) defaults.
    milestones: tuple = (30, 50, 70, 80)
    gamma: float = 0.1

In [None]:
config = SupervisedBaselineConfig()
set_random_seed(config.seed)
train_dataset = datasets.STL10('data', 'train', download=True, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

model = models.resnet18(num_classes=10)
model = model.to(config.device)

# Inverse of the Normalize step in train/test_transforms, so logged images show
# their real colours (the old `* 0.5 + 0.5` did not undo this normalization).
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _to_display(img):
    """Undo the normalization of a CHW tensor and return an HWC array in [0, 1]."""
    img = img.detach().float().cpu() * _std + _mean
    return img.clamp(0, 1).permute(1, 2, 0).numpy()


# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum,
                      weight_decay=config.weight_decay, nesterov=config.nesterov)
scheduler = MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # Running totals get their own names: previously the per-batch `loss` tensor
    # overwrote the accumulator, so only (twice) the last batch loss was logged.
    correct, loss_sum = 0, 0.0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += loss.item()

    scheduler.step()
    logger.add_scalar('train/loss', loss_sum / len(train_loader))
    # drop_last=True, so every training batch holds exactly batch_size samples.
    logger.add_scalar('train/accuracy', correct / len(train_loader) / config.batch_size)
    for k in range(3):
        logger.add_image(f'train/img{k}', _to_display(imgs[k]))

    # evaluate
    model.eval()
    correct, loss_sum, seen = 0, 0.0, 0
    with torch.no_grad():
        for imgs, labels in test_loader:
            imgs, labels = imgs.to(config.device), labels.to(config.device)

            with torch.cuda.amp.autocast():
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            # The test loader keeps its last (smaller) batch, so weight by the
            # real batch size and divide by the real sample count below.
            loss_sum += loss.item() * labels.size(0)
            seen += labels.size(0)

    logger.add_scalar('test/loss', loss_sum / seen)
    logger.add_scalar('test/accuracy', correct / seen)
    for k in range(3):
        logger.add_image(f'test/img{k}', _to_display(imgs[k]))
    model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'model_{epoch}.pth')

tqdm_bar.close()
logger.finish()

# SimCLR

In [None]:
# Colour-jitter strength multiplier.
s = 1.0
# Side length (pixels) of the random resized crop.
size = 96
color_jitter = T.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
rnd_color_jitter = T.RandomApply([color_jitter], p=0.8)
rnd_gray = T.RandomGrayscale(p=0.2)

# SimCLR augmentation pipeline. There is no Normalize, so tensors stay in [0, 1].
train_transforms = T.Compose([
    T.RandomResizedCrop(size),
    T.RandomHorizontalFlip(p=0.5),
    rnd_color_jitter,
    rnd_gray,
    # Blur kernel is ~10% of the crop size. It is derived from `size` (previously
    # hard-coded to 96) and rounded up to an odd number, as GaussianBlur
    # requires. For size=96 this is still 9.
    T.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
    T.ToTensor(),
])

class ContrastiveImages:
    """Turn one image into several independently augmented views.

    Calling the wrapper applies ``transform`` ``n_views`` times (2) to the same
    input and returns the results in a list.
    """

    def __init__(self, transform):
        self.n_views = 2
        self.transform = transform

    def __call__(self, img):
        views = []
        for _ in range(self.n_views):
            views.append(self.transform(img))
        return views

In [None]:
class NCE_loss(nn.Module):
    """SimCLR NT-Xent loss over two stacked views.

    ``outputs`` must hold the embeddings of both views stacked view-major along
    dim 0: rows ``[0, B)`` are view 1 and rows ``[B, 2B)`` are view 2. Row ``i``
    and row ``i + B`` form the positive pair; every other row is a negative.
    """

    def __init__(self, config, temperature=0.1):
        """
        Args:
            config: Object providing ``batch_size`` and ``device``.
            temperature: Divisor applied to the cosine similarities.
        """
        super(NCE_loss, self).__init__()
        self.config = config
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()
        self.n_views = 2
        self.shape = self.config.batch_size * self.n_views
        # Mask for the configured batch size. forward() builds its masks from the
        # actual input, so other batch sizes also work.
        self.diagonal = torch.eye(self.shape, dtype=torch.bool, device=self.config.device)

    def forward(self, outputs):
        """Compute the loss for ``outputs`` of shape ``(n_views * B, D)``.

        Returns:
            ``(loss, logits, labels)``. ``logits`` has shape
            ``(n_views * B, n_views * B - 1)``: column 0 holds the positive and
            the remaining columns hold the negatives, all divided by the
            temperature. ``labels`` is all zeros.

        Raises:
            ValueError: If the row count is not a multiple of ``n_views``.
        """
        total = outputs.shape[0]
        if total % self.n_views != 0:
            raise ValueError(
                f"expected a multiple of {self.n_views} rows, got {total}")
        batch_size = total // self.n_views
        device = outputs.device

        # Rows i and j are views of the same image iff their sample ids match.
        ids = torch.arange(batch_size, device=device).repeat(self.n_views)
        same_sample = ids.unsqueeze(0) == ids.unsqueeze(1)
        # Drop the trivial self-similarity on the diagonal -> (total, total - 1).
        not_self = ~torch.eye(total, dtype=torch.bool, device=device)
        same_sample = same_sample[not_self].view(total, -1)

        outputs = F.normalize(outputs)
        similarity = (outputs @ outputs.T)[not_self].view(total, -1)

        positive = similarity[same_sample].view(total, -1) / self.temperature
        negative = similarity[~same_sample].view(total, -1) / self.temperature

        # The positive sits in column 0, so the target class is always 0.
        logits = torch.cat([positive, negative], dim=1)
        labels = torch.zeros(total, dtype=torch.long, device=device)
        loss = self.criterion(logits, labels)

        return loss, logits, labels

In [None]:
@dataclass
class SimCLR_config:
    """Hyperparameters for SimCLR pre-training.

    Fields are type-annotated so ``@dataclass`` registers them (as in
    ``MOCO_config``). They then show up in ``vars()`` / ``dataclasses.asdict()``
    and can be overridden in the constructor.
    """
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 8
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    num_epochs: int = 200
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256
    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # Output size of the projection head.
    num_features: int = 512

    scheduler: str = 'CosineAnnealingLR'

In [None]:
config = SimCLR_config()
set_random_seed(config.seed)

train_dataset = datasets.STL10('data', 'unlabeled', download=True, transform=ContrastiveImages(train_transforms))
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)

test_dataset = datasets.STL10('data', 'train', download=True, transform=ContrastiveImages(train_transforms))
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)

# ResNet-18 encoder whose fc becomes the projection head: Linear -> ReLU -> original fc.
model = models.resnet18(num_classes=config.num_features)
in_features = model.fc.in_features
projection_g = nn.Sequential(
    nn.Linear(in_features, in_features),
    nn.ReLU(),
    model.fc
)
model.fc = projection_g
model.to(config.device)

optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
# NOTE(review): T_max is counted in batches (len(train_loader)), but
# scheduler.step() runs once per epoch (from epoch 9 on). Within num_epochs the
# cosine curve therefore covers only part of its period. Kept unchanged to
# preserve the schedule of existing runs — confirm the intended T_max.
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader))

# loss, optimizer and hyperparameters
current_step = 0
criterion = NCE_loss(config)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)


def _to_display(img):
    """CHW tensor in [0, 1] (these transforms have no Normalize) -> HWC numpy array."""
    return img.detach().float().cpu().clamp(0, 1).permute(1, 2, 0).numpy()


tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    for imgs, _ in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)
        # Stack both views view-major: (2 * batch, C, H, W), as NCE_loss expects.
        imgs = torch.cat(imgs, dim=0).to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss, logits, labels = criterion(outputs)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Fraction of rows whose positive (column 0) gets the highest logit.
        _, predicted = torch.max(logits.detach(), 1)
        accuracy = (predicted == labels).sum().item() / len(labels)

        logger.add_scalar('accuracy', accuracy)
        logger.add_scalar('loss', loss.item())

    if epoch >= 9:
        scheduler.step()

    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(3):
        logger.add_image(f'img{k}', _to_display(imgs[k]))

    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model.eval()
        logger.set_step(current_step, 'test')
        # Average over the whole loader and log once. Per-batch logs all landed
        # on the same step, so only one batch's value was ever recorded.
        loss_sum, accuracy_sum, n_batches = 0.0, 0.0, 0
        with torch.no_grad():
            for imgs, _ in test_loader:
                imgs = torch.cat(imgs, dim=0).to(config.device)

                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    loss, logits, labels = criterion(outputs)

                _, predicted = torch.max(logits, 1)
                accuracy_sum += (predicted == labels).sum().item() / len(labels)
                loss_sum += loss.item()
                n_batches += 1

        n_batches = max(n_batches, 1)
        logger.add_scalar('loss', loss_sum / n_batches)
        logger.add_scalar('accuracy', accuracy_sum / n_batches)
        for k in range(3):
            logger.add_image(f'img{k}', _to_display(imgs[k]))
        model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'model_SimCLR_2_{epoch}.pth')

tqdm_bar.close()
logger.finish()

# BYOL

In [None]:
# Colour-jitter strength multiplier.
s = 1.0
# Side length (pixels) of the random resized crop.
size = 96
color_jitter = T.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
rnd_color_jitter = T.RandomApply([color_jitter], p=0.8)
rnd_gray = T.RandomGrayscale(p=0.2)

# BYOL augmentation pipeline. There is no Normalize, so tensors stay in [0, 1].
train_transforms = T.Compose([
    T.RandomResizedCrop(size),
    T.RandomHorizontalFlip(p=0.5),
    rnd_color_jitter,
    rnd_gray,
    # Blur kernel is ~10% of the crop size. It is derived from `size` (previously
    # hard-coded to 96) and rounded up to an odd number, as GaussianBlur
    # requires. For size=96 this is still 9.
    T.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
    T.ToTensor(),
])

class ContrastiveImages:
    """Wrapper producing ``n_views`` (2) random augmentations of a single image.

    Returns a list with one entry per view, in call order.
    """

    def __init__(self, transform):
        self.transform = transform
        self.n_views = 2

    def __call__(self, img):
        augment = self.transform
        return list(map(lambda _: augment(img), range(self.n_views)))

In [None]:
@dataclass
class BYOL_config:
    """Hyperparameters for BYOL pre-training.

    Fields are type-annotated so ``@dataclass`` registers them (as in
    ``MOCO_config``). They then show up in ``vars()`` / ``dataclasses.asdict()``
    and can be overridden in the constructor.
    """
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 16
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    num_epochs: int = 1000
    save_epochs: int = 20
    eval_epochs: int = 20
    batch_size: int = 512

    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # Hidden width and output size of the projector / predictor MLPs.
    mlp_hidden_size: int = 4096
    projection_size: int = 256
    # Momentum of the exponential moving average that updates the target network.
    moving_average: float = 0.99

    # Checkpoint prefix; files are written as f"{model_save_name}_{epoch}.pth".
    model_save_name: str = "model_BYOL_oshibka"

    scheduler: str = 'CosineAnnealingLR'

In [None]:
class BYOL_loss(nn.Module):
    """Symmetric BYOL regression loss between online predictions and target projections."""

    def __init__(self, online_model, offline_model, projection):
        """
        Args:
            online_model: Trainable network; must expose ``z_std`` after each forward.
            offline_model: Target network, evaluated without gradients.
            projection: Predictor head applied on top of ``online_model``'s output.
        """
        super(BYOL_loss, self).__init__()
        self.online_model = online_model
        self.offline_model = offline_model
        # Use the predictor passed in (this used to read a notebook global).
        self.model_predict = projection

    def regression_loss(self, x, y):
        """Per-row ``2 - 2 * cos(x, y)``: squared distance of the L2-normalized rows."""
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        return 2 - 2 * (x * y).sum(dim=-1)

    def forward(self, imgs_view1, imgs_view2):
        """Return ``(mean symmetric loss, online z_std for view 1, for view 2)``.

        Each view's online prediction is regressed onto the target output of the
        other view.
        """
        online_network_out_1 = self.model_predict(self.online_model(imgs_view1))
        z_std1 = self.online_model.z_std
        online_network_out_2 = self.model_predict(self.online_model(imgs_view2))
        z_std2 = self.online_model.z_std

        with torch.no_grad():
            offline_network_out_1 = self.offline_model(imgs_view1)
            offline_network_out_2 = self.offline_model(imgs_view2)

        loss1 = self.regression_loss(online_network_out_1, offline_network_out_2)
        loss2 = self.regression_loss(online_network_out_2, offline_network_out_1)

        return (loss1 + loss2).mean(), z_std1, z_std2

In [None]:
class BYOL_network(nn.Module):
    """ResNet-18 encoder followed by an MLP projector, used for both BYOL networks.

    ``forward`` returns the projection. As a side effect it stores the scalar
    standard deviation of the encoder features in ``self.z_std``.
    """

    def __init__(self, config):
        super(BYOL_network, self).__init__()
        backbone = models.resnet18()
        self.in_features = backbone.fc.in_features
        # All ResNet-18 children except the last one (the fc classifier).
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        hidden = config.mlp_hidden_size
        self.projection = nn.Sequential(
            nn.Linear(self.in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Linear(hidden, config.projection_size)
        )
        self.z_std = None

    def forward(self, img):
        features = self.resnet(img)
        # Keep only the (batch, channels) dims; the trailing dims must be singletons.
        features = features.view(features.shape[0], features.shape[1])
        self.z_std = features.std()
        return self.projection(features)

In [None]:
config = BYOL_config()
set_random_seed(config.seed)
# loader
train_dataset = datasets.STL10('data', 'unlabeled', download=True, transform=ContrastiveImages(train_transforms))
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)

test_dataset = datasets.STL10('data', 'train', download=True, transform=ContrastiveImages(train_transforms))
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
# models: online network, target ("offline") network and the predictor head
model_online = BYOL_network(config)
model_offline = BYOL_network(config)
model_predict = nn.Sequential(
    nn.Linear(config.projection_size, config.mlp_hidden_size),
    nn.BatchNorm1d(config.mlp_hidden_size),
    nn.ReLU(),
    nn.Linear(config.mlp_hidden_size, config.projection_size)
)
model_online.to(config.device)
model_offline.to(config.device)
model_predict.to(config.device)

# The target starts as an exact copy of the online network and is changed only
# by the moving-average update in the training loop.
for param_online, param_offline in zip(model_online.parameters(), model_offline.parameters()):
    param_offline.data.copy_(param_online.data)
    param_offline.requires_grad = False
# loss, optimizer and hyperparameters
optimizer = optim.Adam(list(model_online.parameters()) + list(model_predict.parameters()),
                       lr=config.lr, weight_decay=config.weight_decay)
# NOTE(review): T_max is counted in batches (len(train_loader)), but
# scheduler.step() runs once per epoch (from epoch 9 on). Kept unchanged to
# preserve the schedule of existing runs — confirm the intended T_max.
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader))
current_step = 0
criterion = BYOL_loss(model_online, model_offline, model_predict)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)


def _to_display(img):
    """CHW tensor in [0, 1] (these transforms have no Normalize) -> HWC numpy array."""
    return img.detach().float().cpu().clamp(0, 1).permute(1, 2, 0).numpy()


# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    for (imgs_view1, imgs_view2), _ in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)
        imgs_view1, imgs_view2 = imgs_view1.to(config.device), imgs_view2.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss, z_std1, z_std2 = criterion(imgs_view1, imgs_view2)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        logger.add_scalar('loss', loss.item())
        logger.add_scalar('std of z1', z_std1.item())
        logger.add_scalar('std of z2', z_std2.item())

        # In-place target update: offline <- m * offline + (1 - m) * online.
        momentum = config.moving_average
        with torch.no_grad():
            for param_online, param_offline in zip(model_online.parameters(), model_offline.parameters()):
                param_offline.mul_(momentum).add_(param_online, alpha=1.0 - momentum)

    if epoch >= 9:
        scheduler.step()

    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    logger.add_image('img0', _to_display(imgs_view1[0]))
    logger.add_image('img1', _to_display(imgs_view2[0]))

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model_online.eval()
        model_predict.eval()
        logger.set_step(current_step, 'test')
        # Evaluate on test_loader (STL-10 'train' split). This loop previously
        # re-iterated train_loader, leaving test_loader unused. Metrics are
        # averaged and logged once instead of overwriting each other per batch.
        loss_sum, z_std1_sum, z_std2_sum, n_batches = 0.0, 0.0, 0.0, 0
        with torch.no_grad():
            for (imgs_view1, imgs_view2), _ in test_loader:
                imgs_view1, imgs_view2 = imgs_view1.to(config.device), imgs_view2.to(config.device)

                with torch.cuda.amp.autocast():
                    loss, z_std1, z_std2 = criterion(imgs_view1, imgs_view2)

                loss_sum += loss.item()
                z_std1_sum += z_std1.item()
                z_std2_sum += z_std2.item()
                n_batches += 1

        n_batches = max(n_batches, 1)
        logger.add_scalar('loss', loss_sum / n_batches)
        logger.add_scalar('std of z1', z_std1_sum / n_batches)
        logger.add_scalar('std of z2', z_std2_sum / n_batches)
        logger.add_image('img0', _to_display(imgs_view1[0]))
        logger.add_image('img1', _to_display(imgs_view2[0]))
        model_online.train()
        model_predict.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model_online.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

# MOCO

In [None]:
# Colour-jitter strength multiplier.
s = 1.0
# Side length (pixels) of the random resized crop.
size = 96
color_jitter = T.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
rnd_color_jitter = T.RandomApply([color_jitter], p=0.8)
rnd_gray = T.RandomGrayscale(p=0.2)

# MoCo augmentation pipeline, ending with per-channel mean/std normalization.
train_transforms = T.Compose([
    T.RandomResizedCrop(size),
    T.RandomHorizontalFlip(p=0.5),
    rnd_color_jitter,
    rnd_gray,
    # Blur kernel is ~10% of the crop size. It is derived from `size` (previously
    # hard-coded to 96) and rounded up to an odd number, as GaussianBlur
    # requires. For size=96 this is still 9.
    T.GaussianBlur(kernel_size=int(0.1 * size) // 2 * 2 + 1),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

class ContrastiveImages:
    """Produce two augmented views of the same image for contrastive training.

    ``transform`` is invoked ``n_views`` (2) times per call. Each result becomes
    one element of the returned list.
    """

    def __init__(self, transform):
        self.transform, self.n_views = transform, 2

    def __call__(self, img):
        return list(self.transform(img) for _view in range(self.n_views))

In [None]:
@dataclass
class MOCO_config:
    """Hyperparameters for MoCo pre-training.

    Fields are type-annotated, so ``@dataclass`` registers them as real fields:
    they appear in ``vars()`` / ``dataclasses.asdict()`` and in the generated
    ``__init__`` (in this declaration order).
    """
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 16
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    num_epochs: int = 200
    save_epochs: int = 10
    eval_epochs: int = 10
    # MOCO_network builds its queue update around exactly this many samples per batch.
    batch_size: int = 1024
    # Output dimension of the projection head, i.e. the size of each query/key vector.
    dim: int = 128
    # Queue length (number of stored negative keys); must be divisible by
    # batch_size. A larger alternative, 65536, was previously listed here.
    K: int = 16384
    # Divisor applied to the query/key similarity logits.
    temperature: float = 0.07
    # Momentum of the key-encoder moving-average update.
    moving_average: float = 0.999

    optim: str = 'SGD'
    lr: float = 0.03
    momentum: float = 0.9
    weight_decay: float = 1e-4
    
    # Checkpoint prefix; files are written as f"{model_save_name}_{epoch}.pth".
    model_save_name: str = 'MOCO_2'

In [None]:
class MOCO_network(nn.Module):
    """MoCo query/key ResNet-18 encoders with MLP heads and a queue of past keys.

    ``model_q`` is trained by backprop. ``model_k`` is a moving average of it
    and never receives gradients. A ``(dim, K)`` buffer of normalized past keys
    provides the negatives.
    """

    def __init__(self, config):
        """
        Args:
            config: Provides ``dim``, ``K``, ``batch_size``, ``temperature``,
                ``moving_average`` and ``device``.

        Raises:
            ValueError: If ``K`` is not a multiple of ``batch_size`` (the queue
                is overwritten one full batch at a time).
        """
        super(MOCO_network, self).__init__()
        if config.K % config.batch_size != 0:
            raise ValueError(
                f"K ({config.K}) must be a multiple of batch_size ({config.batch_size})")

        self.model_q = models.resnet18(num_classes=config.dim)
        self.model_k = models.resnet18(num_classes=config.dim)

        in_features = self.model_q.fc.in_features
        self.model_q.fc = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.Linear(in_features, config.dim)
        )
        self.model_k.fc = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(),
            nn.Linear(in_features, config.dim)
        )

        # The key encoder starts as an exact, gradient-free copy of the query encoder.
        for param_q, param_k in zip(self.model_q.parameters(), self.model_k.parameters()):
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # One unit-norm column per stored key.
        self.register_buffer("queue", F.normalize(torch.randn(config.dim, config.K), dim=0))
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

        self.N = config.batch_size
        self.C = config.dim
        self.T = config.temperature
        self.device = config.device
        self.K = config.K
        # Stored here so forward() does not read the notebook-global `config`.
        self.moving_average = config.moving_average

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """param_k <- m * param_k + (1 - m) * param_q for every parameter pair."""
        m = self.moving_average
        for param_q, param_k in zip(self.model_q.parameters(), self.model_k.parameters()):
            param_k.data = param_k.data * m + param_q.data * (1.0 - m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest ``N`` queue columns with ``keys`` (ring buffer)."""
        ptr = int(self.queue_ptr)
        self.queue[:, ptr:ptr + self.N] = keys.T
        self.queue_ptr[0] = (ptr + self.N) % self.K

    def forward(self, q_imgs, k_imgs):
        """Return ``(logits, labels)`` for the contrastive cross-entropy loss.

        ``logits`` has shape ``(N, 1 + K)``: column 0 is the query/positive-key
        similarity, and the remaining columns are similarities to the queued
        keys, all divided by the temperature. ``labels`` is all zeros.

        Only in training mode is the key encoder momentum-updated and the new
        keys enqueued, so evaluation leaves the model state unchanged.
        """
        q = F.normalize(self.model_q(q_imgs), dim=1)

        with torch.no_grad():
            if self.training:
                self._momentum_update_key_encoder()
            # Keys are computed without batch shuffling.
            k = F.normalize(self.model_k(k_imgs), dim=1)

        l_pos = torch.bmm(q.view(self.N, 1, self.C), k.view(self.N, self.C, 1)).squeeze(-1)
        # Clone the queue: mm saves it for backward, and it is modified in place below.
        l_neg = torch.mm(q.view(self.N, self.C), self.queue.clone().detach().view(self.C, self.K))
        logits = torch.cat([l_pos, l_neg], dim=1)
        logits /= self.T

        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)

        if self.training:
            self._dequeue_and_enqueue(k)

        return logits, labels

In [None]:
config = MOCO_config()
set_random_seed(config.seed)
# loader
train_dataset = datasets.STL10('data', 'unlabeled', download=True, transform=ContrastiveImages(train_transforms))
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)

test_dataset = datasets.STL10('data', 'train', download=True, transform=ContrastiveImages(train_transforms))
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
# models
model = MOCO_network(config)
model.to(config.device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), config.lr, momentum=config.momentum, weight_decay=config.weight_decay)
# NOTE(review): T_max is counted in batches (len(train_loader)), but
# scheduler.step() runs once per epoch (from epoch 9 on). Kept unchanged to
# preserve the schedule of existing runs — confirm the intended T_max.
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader))
# Deliberately trades the determinism requested by set_random_seed for speed.
torch.backends.cudnn.benchmark = True
current_step = 0
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# Inverse of the Normalize step in train_transforms, so logged images show their
# real colours (the old `* 0.5 + 0.5` did not undo this normalization).
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)


def _to_display(img):
    """Undo the normalization of a CHW tensor and return an HWC array in [0, 1]."""
    img = img.detach().float().cpu() * _std + _mean
    return img.clamp(0, 1).permute(1, 2, 0).numpy()


tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
model.train()
for epoch in range(config.num_epochs):
    for (q_imgs, k_imgs), _ in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)
        q_imgs, k_imgs = q_imgs.to(config.device), k_imgs.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            output, target = model(q_imgs, k_imgs)
            loss = criterion(output, target)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Fraction of queries whose positive key (column 0) scores highest.
        _, predicted = torch.max(output.detach(), 1)
        accuracy = (predicted == target).sum().item() / len(target)

        logger.add_scalar('loss', loss.item())
        logger.add_scalar('accuracy', accuracy)

    if epoch >= 9:
        scheduler.step()

    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    logger.add_image('img0', _to_display(q_imgs[0]))
    logger.add_image('img1', _to_display(k_imgs[0]))

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model.eval()
        logger.set_step(current_step, 'test')
        # Evaluate on test_loader (STL-10 'train' split). This loop previously
        # re-iterated train_loader. Metrics are averaged and logged once instead
        # of overwriting each other per batch.
        loss_sum, accuracy_sum, n_batches = 0.0, 0.0, 0
        with torch.no_grad():
            for (q_imgs, k_imgs), _ in test_loader:
                q_imgs, k_imgs = q_imgs.to(config.device), k_imgs.to(config.device)

                with torch.cuda.amp.autocast():
                    output, target = model(q_imgs, k_imgs)
                    loss = criterion(output, target)

                _, predicted = torch.max(output, 1)
                accuracy_sum += (predicted == target).sum().item() / len(target)
                loss_sum += loss.item()
                n_batches += 1

        n_batches = max(n_batches, 1)
        logger.add_scalar('loss', loss_sum / n_batches)
        logger.add_scalar('accuracy', accuracy_sum / n_batches)
        logger.add_image('img0', _to_display(q_imgs[0]))
        logger.add_image('img1', _to_display(k_imgs[0]))
        model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

# Linear probing

## SimCLR

In [None]:
@dataclass
class SimCLR_LP_config:
    """Hyperparameters for linear probing of a pretrained SimCLR encoder.

    Fields are type-annotated so ``@dataclass`` registers them (as in
    ``MOCO_config``). They then show up in ``vars()`` / ``dataclasses.asdict()``
    and can be overridden in the constructor. No LR scheduler is used.
    """
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    num_epochs: int = 200
    save_epochs: int = 20
    eval_epochs: int = 10
    batch_size: int = 256
    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # Must match the projection-head size of the pretrained checkpoint.
    num_features: int = 512
    model_load_path: str = "model_SimCLR_2_89.pth"
    # (sic) attribute name kept as-is because the training cell reads it.
    model_save_namne: str = "model_SimCLR_2_89_LP"

In [None]:
config = SimCLR_LP_config()
set_random_seed(config.seed)
train_dataset = datasets.STL10('data', 'train', download=True, transform=T.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=T.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

# Rebuild the SimCLR architecture (encoder + projection head) so the checkpoint loads.
model = models.resnet18(num_classes=config.num_features)
in_features = model.fc.in_features
projection_g = nn.Sequential(
    nn.Linear(in_features, in_features),
    nn.ReLU(),
    model.fc
)
model.fc = projection_g
# map_location lets a GPU-saved checkpoint load on the CPU fallback device too.
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
for param in model.parameters():
    param.requires_grad = False

# Replace the projection head with a trainable linear classifier.
model.fc = nn.Linear(in_features, 10)
model.to(config.device)
# Keep the whole network in eval mode. In train mode the frozen encoder's
# BatchNorm layers would keep updating their running statistics, so the
# features would not actually be frozen. The only trainable layer (a Linear)
# behaves the same in both modes.
model.eval()

# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
# Only the new classifier has trainable parameters.
optimizer = optim.Adam(model.fc.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)


def _to_display(img):
    """CHW tensor in [0, 1] (ToTensor only, no Normalize) -> HWC numpy array."""
    return img.detach().float().cpu().clamp(0, 1).permute(1, 2, 0).numpy()


# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # Separate accumulators: the per-batch `loss` tensor used to overwrite the
    # running sum, so only (twice) the last batch loss was logged.
    correct, loss_sum = 0, 0.0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += loss.item()

    logger.add_scalar('loss', loss_sum / len(train_loader))
    # drop_last=True, so every training batch holds exactly batch_size samples.
    logger.add_scalar('accuracy', correct / len(train_loader) / config.batch_size)
    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(3):
        logger.add_image(f'img{k}', _to_display(imgs[k]))

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        logger.set_step(current_step, 'test')
        correct, loss_sum, seen = 0, 0.0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)

                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                # The test loader keeps its last (smaller) batch: weight by real size.
                loss_sum += loss.item() * labels.size(0)
                seen += labels.size(0)

        logger.add_scalar('loss', loss_sum / seen)
        logger.add_scalar('accuracy', correct / seen)
        for k in range(3):
            logger.add_image(f'img{k}', _to_display(imgs[k]))

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_namne}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

## BYOL

In [None]:
@dataclass
class BYOL_LP_config:
    """Hyperparameters for the BYOL linear-probe run on STL10.

    Every field is type-annotated so ``@dataclass`` turns it into a real
    instance field. Without annotations the values were plain class
    attributes: ``vars(config)`` / ``dataclasses.asdict(config)`` returned
    nothing for the object that ``WanDBWriter`` hands to
    ``wandb.init(config=...)``, and nothing could be overridden at
    construction time. Now e.g. ``BYOL_LP_config(lr=1e-3)`` works.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 200
    save_epochs: int = 20
    eval_epochs: int = 10
    batch_size: int = 256

    # optimizer
    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # BYOL architecture; presumably read by BYOL_network (defined elsewhere) — confirm
    mlp_hidden_size: int = 4096
    projection_size: int = 256
    moving_average: float = 0.99

    # checkpoints
    model_load_path: str = "model_BYOL_oshibka_219.pth"
    model_save_name: str = "model_BYOL_oshibka_219_LP"

    # scheduler = 'CosineAnnealingLR'

In [None]:
class BYOL_network_LP(nn.Module):
    """Linear probe: a frozen BYOL backbone followed by a trainable 10-way linear head.

    Args:
        byol_model: backbone module. Its output is reshaped to
            (batch, in_features) with ``view``, so trailing dimensions must be 1
            (presumably a pooled (batch, in_features, 1, 1) tensor) — confirm.
        in_features: feature width produced by ``byol_model``.

    The backbone's parameters get ``requires_grad=False``. The backbone is also
    kept in eval mode permanently (see ``train``). Otherwise its BatchNorm
    layers would keep updating their running mean/var while the head trains,
    silently changing the "frozen" features.
    """

    def __init__(self, byol_model, in_features):
        super(BYOL_network_LP, self).__init__()
        self.byol_model = byol_model
        for param in self.byol_model.parameters():
            param.requires_grad = False
        self.fc = nn.Linear(in_features, 10)
        self.byol_model.eval()

    def train(self, mode=True):
        """Set train/eval mode on the head; the frozen backbone always stays in eval mode.

        ``requires_grad=False`` alone does not stop BatchNorm from updating its
        running statistics in train mode, so the mode must be pinned as well.
        ``eval()`` routes through here too.

        Returns:
            self, like ``nn.Module.train``.
        """
        super().train(mode)
        self.byol_model.eval()
        return self

    def forward(self, x):
        x = self.byol_model(x)
        x = x.view(x.shape[0], x.shape[1])
        return self.fc(x)

In [None]:
# Linear probe of the pretrained BYOL encoder on STL10 (10 classes).
config = BYOL_LP_config()
set_random_seed(config.seed)
train_dataset = datasets.STL10('data', 'train', download=True, transform=T.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=T.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

# Load the pretrained network onto CPU first (map_location) so a checkpoint
# saved from GPU tensors also loads when config.device falls back to 'cpu'.
# Then keep every child but the last and put the linear-probe head on top.
model = BYOL_network(config)
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
in_features = model.in_features
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_LP(model, in_features)
model.to(config.device)

# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # Sample-weighted running sums for the epoch. The per-batch loss has its own
    # name (`batch_loss`) so it cannot overwrite the accumulator. Reusing `loss`
    # for both made the logged value depend only on the last batch.
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            batch_loss = criterion(outputs, labels)
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += batch_loss.item() * labels.size(0)
        seen += labels.size(0)

    logger.add_scalar('loss', loss_sum / seen)
    logger.add_scalar('accuracy', correct / seen)
    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(min(3, len(imgs))):
        logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model.eval()
        # All test metrics are logged at the current training step. Set it once
        # rather than per batch (set_step also logs a per-call steps_per_sec).
        logger.set_step(current_step, 'test')
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)
                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    batch_loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                loss_sum += batch_loss.item() * labels.size(0)
                seen += labels.size(0)

        # test_loader keeps its last partial batch, so normalise by the number of
        # samples actually seen, not len(test_loader) * batch_size.
        logger.add_scalar('loss', loss_sum / seen)
        logger.add_scalar('accuracy', correct / seen)
        for k in range(min(3, len(imgs))):
            logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
        model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

## MOCO

In [None]:
@dataclass
class MOCO_LP_config:
    """Hyperparameters for the MoCo linear-probe run on STL10.

    All fields are annotated, so they are real dataclass fields (instance
    attributes, overridable in the constructor). Their declaration order is
    the positional order of the generated ``__init__``.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2  # DataLoader worker processes
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule: evaluate every eval_epochs epochs, checkpoint every save_epochs (0 disables)
    num_epochs: int = 200
    save_epochs: int = 20
    eval_epochs: int = 10
    batch_size: int = 256
    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # MoCo settings, presumably consumed by MOCO_network (defined elsewhere) when
    # rebuilding the pretrained model; names suggest embedding dim, queue size,
    # contrastive temperature and key-encoder momentum — confirm.
    dim: int = 128
    K: int = 16384
    temperature: float = 0.07
    moving_average: float = 0.999

    # checkpoint to probe and prefix for the probe's own checkpoints
    model_load_path: str = "MOCO_2_199.pth"
    model_save_name: str = "MOCO_2_199_LP"

In [None]:
# Linear probe of the pretrained MoCo query encoder on STL10 (10 classes).
config = MOCO_LP_config()
set_random_seed(config.seed)

transforms = T.Compose([
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = datasets.STL10('data', 'train', download=True, transform=transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

# models: take the query encoder, freeze it, and swap its head for a trainable
# 10-way linear classifier. map_location='cpu' lets a checkpoint saved from
# GPU tensors load when config.device falls back to 'cpu'.
model = MOCO_network(config)
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
model = model.model_q
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(model.fc[0].in_features, 10)
model.to(config.device)
# Linear-probe protocol: keep the frozen encoder in eval mode for the whole run.
# requires_grad=False does not stop BatchNorm from updating its running
# statistics in train mode, which would change the "frozen" features.
# model.fc is a plain Linear, so eval mode does not affect the trainable head.
model.eval()

# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=config.weight_decay)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # sample-weighted running sums for the epoch
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss_ = criterion(outputs, labels)
        scaler.scale(loss_).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += loss_.item() * labels.size(0)
        seen += labels.size(0)

    logger.add_scalar('loss', loss_sum / seen)
    logger.add_scalar('accuracy', correct / seen)
    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(min(3, len(imgs))):
        logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)

    # evaluate (the model is already in eval mode for the whole probe)
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        # set the logging step once; set_step also logs a per-call steps_per_sec
        logger.set_step(current_step, 'test')
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)
                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    loss_ = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                loss_sum += loss_.item() * labels.size(0)
                seen += labels.size(0)

        # test_loader keeps its last partial batch, so normalise by the number of
        # samples actually seen, not len(test_loader) * batch_size.
        logger.add_scalar('loss', loss_sum / seen)
        logger.add_scalar('accuracy', correct / seen)
        for k in range(min(3, len(imgs))):
            logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

# Fine-tuning

## SimCLR

In [None]:
@dataclass
class SimCLR_FT_config:
    """Hyperparameters for fine-tuning the SimCLR encoder on STL10.

    Fields are type-annotated so ``@dataclass`` makes them instance fields.
    ``vars(config)`` / ``dataclasses.asdict(config)`` (e.g. for the object
    passed to ``wandb.init(config=...)``) then see every value, and any field
    can be overridden in the constructor. ``milestones`` is a tuple because
    dataclass fields may not default to a list; MultiStepLR accepts any
    iterable of epochs.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 90
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256

    # SGD optimizer
    optim: str = 'SGD'
    nesterov: bool = True
    momentum: float = 0.9
    lr: float = 0.05
    weight_decay: float = 1e-4

    # model / checkpoints
    num_features: int = 512
    model_load_path: str = "model_SimCLR_2_89.pth"
    model_save_name: str = "model_SimCLR_2_89_FT"

    # LR schedule: multiply lr by gamma at each milestone epoch
    scheduler: str = 'MultiStepLR'
    milestones: tuple = (30, 50, 70, 80)
    gamma: float = 0.1

In [None]:
# Fine-tuning augmentation: random resized 96px crop plus a horizontal flip,
# then the standard ImageNet channel mean/std normalisation.
train_transforms = T.Compose(
    [
        T.RandomResizedCrop(size=96),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Evaluation: deterministic resize (shorter side 110) and a centred 96px crop,
# with the same normalisation as training.
test_transforms = T.Compose(
    [
        T.Resize(size=110),
        T.CenterCrop(size=96),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Fine-tune the pretrained SimCLR ResNet-18 end-to-end on STL10 (10 classes).
config = SimCLR_FT_config()
set_random_seed(config.seed)
train_dataset = datasets.STL10('data', 'train', download=True, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

# Rebuild the pretraining architecture (fc = Linear -> ReLU -> Linear projection
# head) so the checkpoint keys match, then replace the head with a 10-way
# classifier. map_location='cpu' lets a checkpoint saved from GPU tensors load
# when config.device falls back to 'cpu'.
model = models.resnet18(num_classes=config.num_features)
in_features = model.fc.in_features
projection_g = nn.Sequential(
    nn.Linear(in_features, in_features),
    nn.ReLU(),
    model.fc
)
model.fc = projection_g
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
model.fc = nn.Linear(in_features, 10)
model.to(config.device)

# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum,
                      weight_decay=config.weight_decay, nesterov=config.nesterov)
scheduler = MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # Sample-weighted running sums for the epoch. The per-batch loss has its own
    # name (`batch_loss`) so it cannot overwrite the accumulator. Reusing `loss`
    # for both made the logged value depend only on the last batch.
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            batch_loss = criterion(outputs, labels)
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += batch_loss.item() * labels.size(0)
        seen += labels.size(0)

    scheduler.step()
    logger.add_scalar('loss', loss_sum / seen)
    logger.add_scalar('accuracy', correct / seen)
    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(min(3, len(imgs))):
        logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model.eval()
        # set the logging step once; set_step also logs a per-call steps_per_sec
        logger.set_step(current_step, 'test')
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)
                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    batch_loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                loss_sum += batch_loss.item() * labels.size(0)
                seen += labels.size(0)

        # test_loader keeps its last partial batch, so normalise by the number of
        # samples actually seen, not len(test_loader) * batch_size.
        logger.add_scalar('loss', loss_sum / seen)
        logger.add_scalar('accuracy', correct / seen)
        for k in range(min(3, len(imgs))):
            logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
        model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

## BYOL

In [None]:
@dataclass
class BYOL_FT_config:
    """Hyperparameters for fine-tuning the BYOL encoder on STL10.

    Fields are type-annotated so ``@dataclass`` makes them instance fields.
    ``vars(config)`` / ``dataclasses.asdict(config)`` (e.g. for the object
    passed to ``wandb.init(config=...)``) then see every value, and any field
    can be overridden in the constructor. ``milestones`` is a tuple because
    dataclass fields may not default to a list; MultiStepLR accepts any
    iterable of epochs.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 90
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256

    # SGD optimizer
    optim: str = 'SGD'
    nesterov: bool = True
    momentum: float = 0.9
    lr: float = 0.05
    weight_decay: float = 1e-4

    # BYOL architecture; presumably read by BYOL_network (defined elsewhere) — confirm
    mlp_hidden_size: int = 4096
    projection_size: int = 256
    moving_average: float = 0.99

    # checkpoints
    model_load_path: str = "model_BYOL_oshibka_219.pth"
    model_save_name: str = "model_BYOL_oshibka_219_FT"

    # LR schedule: multiply lr by gamma at each milestone epoch
    scheduler: str = 'MultiStepLR'
    milestones: tuple = (30, 50, 70, 80)
    gamma: float = 0.1

In [None]:
# Fine-tuning augmentation: random resized 96px crop plus a horizontal flip,
# then the standard ImageNet channel mean/std normalisation.
train_transforms = T.Compose(
    [
        T.RandomResizedCrop(size=96),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Evaluation: deterministic resize (shorter side 110) and a centred 96px crop,
# with the same normalisation as training.
test_transforms = T.Compose(
    [
        T.Resize(size=110),
        T.CenterCrop(size=96),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
class BYOL_network_FT(nn.Module):
    """BYOL backbone followed by a 10-way linear classifier, all of it trainable.

    Args:
        byol_model: backbone module. Its output is reshaped to
            (batch, in_features) with ``view``, so trailing dimensions must be 1
            (presumably a pooled feature map) — confirm.
        in_features: feature width produced by ``byol_model``.
    """

    def __init__(self, byol_model, in_features):
        super().__init__()
        # registration order (backbone, then head) fixes the state_dict layout
        self.byol_model = byol_model
        self.fc = nn.Linear(in_features, 10)

    def forward(self, x):
        features = self.byol_model(x)
        batch, channels = features.shape[0], features.shape[1]
        flat = features.view(batch, channels)
        return self.fc(flat)

In [None]:
# Fine-tune the pretrained BYOL encoder end-to-end on STL10 (10 classes).
config = BYOL_FT_config()
set_random_seed(config.seed)
train_dataset = datasets.STL10('data', 'train', download=True, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

# Load onto CPU first (map_location) so a checkpoint saved from GPU tensors
# also loads when config.device falls back to 'cpu'. Then keep every child but
# the last and put a trainable linear classifier on top.
model = BYOL_network(config)
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
in_features = model.in_features
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.to(config.device)

# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum,
                      weight_decay=config.weight_decay, nesterov=config.nesterov)
scheduler = MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # Sample-weighted running sums for the epoch. The per-batch loss has its own
    # name (`batch_loss`) so it cannot overwrite the accumulator. Reusing `loss`
    # for both made the logged value depend only on the last batch.
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            batch_loss = criterion(outputs, labels)
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += batch_loss.item() * labels.size(0)
        seen += labels.size(0)

    scheduler.step()
    logger.add_scalar('loss', loss_sum / seen)
    logger.add_scalar('accuracy', correct / seen)
    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(min(3, len(imgs))):
        logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model.eval()
        # set the logging step once; set_step also logs a per-call steps_per_sec
        logger.set_step(current_step, 'test')
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)
                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    batch_loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                loss_sum += batch_loss.item() * labels.size(0)
                seen += labels.size(0)

        # test_loader keeps its last partial batch, so normalise by the number of
        # samples actually seen, not len(test_loader) * batch_size.
        logger.add_scalar('loss', loss_sum / seen)
        logger.add_scalar('accuracy', correct / seen)
        for k in range(min(3, len(imgs))):
            logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
        model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

## MOCO

In [None]:
@dataclass
class MOCO_FT_config:
    """Hyperparameters for fine-tuning the MoCo query encoder on STL10.

    The LR-schedule settings used to be unannotated class attributes. Those
    are invisible to ``vars(config)`` / ``dataclasses.asdict(config)``, e.g.
    for the object passed to ``wandb.init(config=...)``. They are now real
    fields, appended after the existing ones so the positional order of the
    generated ``__init__`` is unchanged for the earlier fields.
    ``milestones`` is a tuple because dataclass fields may not default to a
    list; MultiStepLR accepts any iterable of epochs.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 90
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256

    # SGD optimizer
    optim: str = 'SGD'
    nesterov: bool = True
    momentum: float = 0.9
    lr: float = 0.05
    weight_decay: float = 1e-4

    # MoCo settings, presumably consumed by MOCO_network (defined elsewhere) — confirm
    dim: int = 128
    K: int = 16384
    temperature: float = 0.07
    moving_average: float = 0.999

    # checkpoints
    model_load_path: str = "MOCO_2_199.pth"
    model_save_name: str = "MOCO_2_199_FT"

    # LR schedule: multiply lr by gamma at each milestone epoch
    scheduler: str = 'MultiStepLR'
    milestones: tuple = (30, 50, 70, 80)
    gamma: float = 0.1

In [None]:
# Fine-tuning augmentation: random resized 96px crop plus a horizontal flip,
# then the standard ImageNet channel mean/std normalisation.
train_transforms = T.Compose(
    [
        T.RandomResizedCrop(size=96),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Evaluation: deterministic resize (shorter side 110) and a centred 96px crop,
# with the same normalisation as training.
test_transforms = T.Compose(
    [
        T.Resize(size=110),
        T.CenterCrop(size=96),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Fine-tune the pretrained MoCo query encoder end-to-end on STL10 (10 classes).
config = MOCO_FT_config()
set_random_seed(config.seed)
train_dataset = datasets.STL10('data', 'train', download=True, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True)
test_dataset = datasets.STL10('data', 'test', download=True, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True)

# models: take the query encoder and swap its head for a 10-way classifier.
# map_location='cpu' lets a checkpoint saved from GPU tensors load when
# config.device falls back to 'cpu'.
model = MOCO_network(config)
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
model = model.model_q
model.fc = nn.Linear(model.fc[0].in_features, 10)
model.to(config.device)

# loss, optimizer and hyperparameters
current_step = 0
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=config.lr, momentum=config.momentum,
                      weight_decay=config.weight_decay, nesterov=config.nesterov)
scheduler = MultiStepLR(optimizer, milestones=config.milestones, gamma=config.gamma)
scaler = torch.cuda.amp.GradScaler()
logger = WanDBWriter(config)

# train
tqdm_bar = tqdm(total=config.num_epochs * len(train_loader) - current_step)
for epoch in range(config.num_epochs):
    # sample-weighted running sums for the epoch
    loss_sum, correct, seen = 0.0, 0, 0
    for imgs, labels in train_loader:
        current_step += 1
        tqdm_bar.update(1)
        logger.set_step(current_step)

        imgs, labels = imgs.to(config.device), labels.to(config.device)

        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(imgs)
            loss_ = criterion(outputs, labels)
        scaler.scale(loss_).backward()
        scaler.step(optimizer)
        scaler.update()

        _, predicted = torch.max(outputs.detach(), 1)
        correct += (predicted == labels).sum().item()
        loss_sum += loss_.item() * labels.size(0)
        seen += labels.size(0)

    scheduler.step()
    logger.add_scalar('loss', loss_sum / seen)
    logger.add_scalar('accuracy', correct / seen)
    logger.add_scalar('lr', optimizer.param_groups[0]["lr"])
    for k in range(min(3, len(imgs))):
        logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)

    # evaluate
    if config.eval_epochs != 0 and epoch % config.eval_epochs == config.eval_epochs - 1:
        model.eval()
        # set the logging step once; set_step also logs a per-call steps_per_sec
        logger.set_step(current_step, 'test')
        loss_sum, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in test_loader:
                imgs, labels = imgs.to(config.device), labels.to(config.device)
                with torch.cuda.amp.autocast():
                    outputs = model(imgs)
                    loss_ = criterion(outputs, labels)

                _, predicted = torch.max(outputs, 1)
                correct += (predicted == labels).sum().item()
                loss_sum += loss_.item() * labels.size(0)
                seen += labels.size(0)

        # test_loader keeps its last partial batch, so normalise by the number of
        # samples actually seen, not len(test_loader) * batch_size.
        logger.add_scalar('loss', loss_sum / seen)
        logger.add_scalar('accuracy', correct / seen)
        for k in range(min(3, len(imgs))):
            logger.add_image(f'img{k}', imgs[k].detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5)
        model.train()

    if config.save_epochs != 0 and epoch % config.save_epochs == config.save_epochs - 1:
        torch.save(model.state_dict(), f'{config.model_save_name}_{epoch}.pth')

tqdm_bar.close()
logger.finish()

# t-SNE

## Supervised

In [None]:
import pandas as pd
from sklearn.manifold import TSNE
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import random
import warnings
warnings.filterwarnings('ignore')

In [None]:
palette = sns.color_palette("dark", 10)


def draw_tsne(model, dataloder, title):
    """Embed every batch of ``dataloder`` with ``model`` and scatter-plot a 2-D t-SNE.

    Args:
        model: feature extractor. Its output is reshaped to (batch, features)
            with ``view``, so trailing dimensions must be 1 (e.g. a pooled
            ResNet trunk) — assumption, confirm for new models.
        dataloder: iterable of ``(images, labels)`` batches. Labels are the
            STL10 class indices 0-9 named in ``index_to_label`` below.
        title: plot title.

    Raises:
        ValueError: if ``dataloder`` yields no batches.

    Side effects: reseeds ``numpy`` and ``random`` (sklearn's TSNE falls back
    to the global NumPy RNG when no ``random_state`` is given), switches
    ``model`` to eval mode (not restored), and shows the figure.
    """
    np.random.seed(3407)
    random.seed(3407)
    model.eval()

    # Run on the model's own device instead of a hard-coded 'cuda:0' for the
    # accumulators plus the global ``config.device`` for the inputs. The two
    # could disagree, and 'cuda:0' fails outright on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    embed_chunks, label_chunks = [], []
    with torch.no_grad():
        for imgs, labels in dataloder:
            outputs = model(imgs.to(device))
            embed_chunks.append(outputs.view(outputs.shape[0], outputs.shape[1]).cpu())
            label_chunks.append(labels.cpu())
    if not embed_chunks:
        raise ValueError("dataloder yielded no batches")
    # A single concatenation at the end, instead of re-copying a growing tensor
    # for every batch. Labels keep their integer dtype.
    train_embeds = torch.cat(embed_chunks).float()
    train_labels = torch.cat(label_chunks)

    index_to_label = {
        0: "airplane",
        1: "bird",
        2: "car",
        3: "cat",
        4: "deer",
        5: "dog",
        6: "horse",
        7: "monkey",
        8: "ship",
        9: "truck"
    }

    feat_cols = [f'embed{i}' for i in range(train_embeds.shape[1])]

    df = pd.DataFrame(train_embeds.numpy(), columns=feat_cols)
    df['y'] = train_labels.numpy()
    df['label'] = df['y'].map(lambda i: index_to_label[int(i)])
    df = df.sort_values(by=['label'])

    # t-SNE runs on the label-sorted rows, so the positional assignment of the
    # results below lines up with df's current row order.
    data = df[feat_cols].values
    tsne = TSNE(n_jobs=-1, learning_rate='auto', init='pca')
    tsne_results = tsne.fit_transform(data)

    df['tsne-2d-one'] = tsne_results[:, 0]
    df['tsne-2d-two'] = tsne_results[:, 1]

    plt.figure(figsize=(16, 10))
    sns.scatterplot(
        x="tsne-2d-one", y="tsne-2d-two",
        hue="label",
        palette=palette,
        data=df,
        legend="full",
        alpha=0.3
    ).set(title=title)
    plt.show()

In [None]:
config = SupervisedBaselineConfig()
set_random_seed(config.seed)

# Supervised ResNet-18 baseline. map_location='cpu' lets a checkpoint saved
# from GPU tensors load even when config.device falls back to 'cpu'; the model
# is moved to config.device below.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_89.pth', map_location='cpu'))
# Drop the last child (torchvision resnet18's fc classifier) so the model
# yields pooled features for t-SNE.
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
# t-SNE of the STL10 train split under the deterministic evaluation transform.
train_dataset = datasets.STL10(root='data', split='train', download=True, transform=test_transforms)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False,
)

draw_tsne(model, train_loader, 'Supervised train')

In [None]:
# t-SNE of the STL10 test split under the deterministic evaluation transform.
test_dataset = datasets.STL10(root='data', split='test', download=True, transform=test_transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, test_loader, 'Supervised test')

## SimCLR

In [None]:
@dataclass
class SimCLR_config:
    """Hyperparameters describing the SimCLR checkpoint used for t-SNE.

    Fields are type-annotated so ``@dataclass`` makes them instance fields
    (visible to ``vars``/``dataclasses.asdict`` and overridable in the
    constructor). ``model_save_name`` is the correctly spelled name. The
    misspelled ``model_save_namne`` is kept with the same value so code that
    still reads it keeps working.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 8
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 200
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256
    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # width of the projection output (resnet18(num_classes=num_features))
    num_features: int = 512

    scheduler: str = 'CosineAnnealingLR'
    model_load_path: str = "model_SimCLR_2_89.pth"
    model_save_name: str = "model_SimCLR_29_LP"
    model_save_namne: str = "model_SimCLR_29_LP"  # legacy misspelling, kept for compatibility

In [None]:
config = SimCLR_config()
set_random_seed(config.seed)

# Rebuild the SimCLR pretraining architecture (fc = Linear -> ReLU -> Linear
# projection head) so the checkpoint keys match. map_location='cpu' lets a
# checkpoint saved from GPU tensors load even when config.device is 'cpu'.
model = models.resnet18(num_classes=config.num_features)
in_features = model.fc.in_features
projection_g = nn.Sequential(
    nn.Linear(in_features, in_features),
    nn.ReLU(),
    model.fc
)
model.fc = projection_g
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
# Drop the projection head (last child) so t-SNE sees the pooled encoder features.
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
# SimCLR-style augmentation pipeline. NOTE(review): the t-SNE cells below feed
# it to both the train and test loaders, so they embed randomly augmented views
# rather than clean images — confirm that is intended.
s = 1.0  # colour-jitter strength multiplier
size = 96  # output crop size in pixels
color_jitter = T.ColorJitter(brightness=0.8 * s, contrast=0.8 * s, saturation=0.8 * s, hue=0.2 * s)
rnd_color_jitter = T.RandomApply(transforms=[color_jitter], p=0.8)
rnd_gray = T.RandomGrayscale(p=0.2)

train_transforms = T.Compose(
    [
        T.RandomResizedCrop(size=size),
        T.RandomHorizontalFlip(p=0.5),
        rnd_color_jitter,
        rnd_gray,
        T.GaussianBlur(kernel_size=int(0.1 * size)),  # int(9.6) == 9, an odd kernel
        T.ToTensor(),
    ]
)

In [None]:
# t-SNE of the STL10 train split, embedded by the SimCLR encoder.
train_dataset = datasets.STL10(root='data', split='train', download=True, transform=train_transforms)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False,
)

draw_tsne(model, train_loader, 'SimCLR train')

In [None]:
# t-SNE of the STL10 test split, embedded by the SimCLR encoder.
test_dataset = datasets.STL10(root='data', split='test', download=True, transform=train_transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False,
)

draw_tsne(model, test_loader, 'SimCLR test')

## BYOL

In [None]:
@dataclass
class BYOL_config:
    """Hyperparameters describing the BYOL pretraining run / checkpoint.

    Fields are type-annotated so ``@dataclass`` makes them instance fields.
    ``vars(config)`` / ``dataclasses.asdict(config)`` then see every value,
    and any field can be overridden in the constructor.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 1000
    save_epochs: int = 20
    eval_epochs: int = 20
    batch_size: int = 512

    # optimizer
    optim: str = 'Adam'
    lr: float = 3e-4
    weight_decay: float = 1e-4

    # BYOL architecture; presumably read by BYOL_network (defined elsewhere) — confirm
    mlp_hidden_size: int = 4096
    projection_size: int = 256
    moving_average: float = 0.99

    # checkpoints / schedule
    model_save_name: str = "model_BYOL"
    model_load_path: str = "model_BYOL_oshibka_219.pth"
    scheduler: str = 'CosineAnnealingLR'

In [None]:
config = BYOL_config()
set_random_seed(config.seed)
# map_location='cpu' lets a checkpoint saved from GPU tensors load even when
# config.device falls back to 'cpu'; the model is moved to config.device below.
model = BYOL_network(config)
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
# Keep every child of BYOL_network except the last for feature extraction.
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
# t-SNE of the STL10 train split, embedded by the BYOL encoder.
train_dataset = datasets.STL10(root='data', split='train', download=True, transform=train_transforms)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False,
)

draw_tsne(model, train_loader, 'BYOL train')

In [None]:
# t-SNE of the STL10 test split, embedded by the BYOL encoder.
test_dataset = datasets.STL10(root='data', split='test', download=True, transform=train_transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
    pin_memory=True,
    drop_last=False,
)

draw_tsne(model, test_loader, 'BYOL test')

## MOCO

In [None]:
config = MOCO_config()
set_random_seed(config.seed)
# loaders: drop_last=False, as in the other t-SNE cells, so every sample is
# embedded instead of silently skipping the last partial batch.
train_dataset = datasets.STL10('data', 'train', download=True, transform=train_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=False)

test_dataset = datasets.STL10('data', 'test', download=True, transform=train_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=False)
# models: map_location='cpu' lets a checkpoint saved from GPU tensors load even
# when config.device falls back to 'cpu'.
model = MOCO_network(config)
model.load_state_dict(torch.load(config.model_load_path, map_location='cpu'))
# NOTE(review): keeps every child of MOCO_network except the last; what that
# yields depends on MOCO_network's child order (defined elsewhere) — confirm.
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
draw_tsne(model=model, dataloder=train_loader, title='MOCO train')  # MoCo embeddings, train split

In [None]:
draw_tsne(model=model, dataloder=test_loader, title='MOCO test')  # MoCo embeddings, test split

## SimCLR-FT

In [None]:
@dataclass
class SimCLR_FT_config:
    """Hyperparameters describing the fine-tuned SimCLR checkpoint used for t-SNE.

    Fields are type-annotated so ``@dataclass`` makes them instance fields
    (visible to ``vars``/``dataclasses.asdict`` and overridable in the
    constructor). ``milestones`` is a tuple because dataclass fields may not
    default to a list; MultiStepLR accepts any iterable of epochs.
    """
    # wandb / runtime
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # schedule (eval_epochs / save_epochs of 0 disable evaluation / checkpoints)
    num_epochs: int = 90
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256

    # SGD optimizer
    optim: str = 'SGD'
    nesterov: bool = True
    momentum: float = 0.9
    lr: float = 0.05
    weight_decay: float = 1e-4

    # model / checkpoint
    num_features: int = 512
    model_load_path: str = "model_SimCLR_2_89_FT_59.pth"

    # LR schedule: multiply lr by gamma at each milestone epoch
    scheduler: str = 'MultiStepLR'
    milestones: tuple = (30, 50, 70, 80)
    gamma: float = 0.1

In [None]:
# Fine-tuned SimCLR model: a torchvision ResNet-18 with a 10-way head, loaded from
# the FT checkpoint. Dropping the last child removes resnet18's final `fc` layer,
# so `model` returns the pooled backbone features used for t-SNE.
config = SimCLR_FT_config()
set_random_seed(config.seed)

model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load(config.model_load_path))
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
# t-SNE of fine-tuned SimCLR features on the STL-10 train split.
train_dataset = datasets.STL10(root='data', split='train', transform=test_transforms, download=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, train_loader, 'SimCLR-FT train')

In [None]:
# t-SNE of fine-tuned SimCLR features on the STL-10 test split.
test_dataset = datasets.STL10(root='data', split='test', transform=test_transforms, download=True)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, test_loader, 'SimCLR-FT test')

## BYOL-FT

In [None]:
@dataclass
class BYOL_FT_config:
    # Settings for evaluating the fine-tuned BYOL model.
    # No attribute is annotated, so @dataclass creates no fields: these are
    # plain class attributes shared by all instances.

    # --- logging / runtime ---
    wandb_project = 'SLL_HW2'
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    num_workers = 2
    seed = 3407

    # --- training schedule ---
    num_epochs = 90
    batch_size = 256
    eval_epochs = 10
    save_epochs = 10

    # --- optimiser ---
    optim = 'SGD'
    lr = 0.05
    momentum = 0.9
    nesterov = True
    weight_decay = 1e-4

    # --- learning-rate scheduler ---
    scheduler = 'MultiStepLR'
    milestones = [30, 50, 70, 80]
    gamma = 0.1

    # --- BYOL network hyper-parameters (read by BYOL_network) ---
    mlp_hidden_size = 4096
    projection_size = 256
    moving_average = 0.99

    # --- checkpoint ---
    model_load_path = "model_BYOL_oshibka_219_FT_79.pth"
    # model_save_name = "model_BYOL_oshibka_219_FT"

In [None]:
# Fine-tuned BYOL model. The FT checkpoint is loaded into
# BYOL_network_FT(<BYOL network minus its last child>, in_features), so that
# wrapper is rebuilt first. Its own last child is then dropped as well, and the
# remaining modules produce the features fed to t-SNE.
config = BYOL_FT_config()
set_random_seed(config.seed)

model = BYOL_network(config)
in_features = model.in_features  # read before the network is replaced by a Sequential
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load(config.model_load_path))
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
# t-SNE of fine-tuned BYOL features on the STL-10 train split.
train_dataset = datasets.STL10(root='data', split='train', transform=test_transforms, download=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, train_loader, 'BYOL-FT train')

In [None]:
# t-SNE of fine-tuned BYOL features on the STL-10 test split.
test_dataset = datasets.STL10(root='data', split='test', transform=test_transforms, download=True)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, test_loader, 'BYOL-FT test')

## MOCO-FT

In [None]:
@dataclass
class MOCO_FT_config:
    # Settings for the fine-tuned MOCO model evaluated in the cells below.
    # NOTE(review): `scheduler`, `milestones` and `gamma` carry no type
    # annotation, so @dataclass treats them as plain class attributes rather than
    # fields. They cannot be passed to MOCO_FT_config(...) and are absent from
    # repr/eq, unlike every other setting here.
    wandb_project: str = 'SLL_HW2'
    num_workers: int = 2
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    seed: int = 3407

    # training schedule / optimiser
    num_epochs: int = 90
    save_epochs: int = 10
    eval_epochs: int = 10
    batch_size: int = 256
    optim: str = 'SGD'
    nesterov: bool = True
    momentum: float = 0.9
    lr: float = 0.05
    weight_decay: float = 1e-4

    # MoCo hyper-parameters (presumably embedding dim, queue length K, softmax
    # temperature, key-encoder momentum — confirm against MOCO_network).
    dim: int = 128
    K: int = 16384
    temperature: float = 0.07
    moving_average: float = 0.999

    scheduler = 'MultiStepLR'
    milestones = [30, 50, 70, 80]
    gamma = 0.1
    
    model_load_path: str = "MOCO_2_199_FT_89.pth"
    model_save_name: str = "MOCO_2_199_FT"

In [None]:
# Eval-time preprocessing for the MOCO-FT cells: resize the shorter side to 110,
# centre-crop to 96x96, convert to a tensor and normalise with ImageNet statistics.
test_transforms = T.Compose(
    transforms=[
        T.Resize(size=110),
        T.CenterCrop(size=96),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Fine-tuned MOCO model: a torchvision ResNet-18 with a 10-way head, loaded from
# the FT checkpoint. The final `fc` child is dropped so `model` returns pooled
# backbone features for t-SNE.
config = MOCO_FT_config()
set_random_seed(config.seed)

model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load(config.model_load_path))
model = nn.Sequential(*list(model.children())[:-1])
model.to(config.device)

In [None]:
# t-SNE of fine-tuned MOCO features on the STL-10 train split.
train_dataset = datasets.STL10(root='data', split='train', transform=test_transforms, download=True)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, train_loader, 'MOCO-FT train')

In [None]:
# t-SNE of fine-tuned MOCO features on the STL-10 test split.
test_dataset = datasets.STL10(root='data', split='test', transform=test_transforms, download=True)
test_loader = DataLoader(
    test_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=config.num_workers,
    pin_memory=True,
)

draw_tsne(model, test_loader, 'MOCO-FT test')

# OOD robustness

In [None]:
# Class names by label index for STL-10 and CIFAR-10, plus the CIFAR-10 -> STL-10
# index translation used to score STL-10 classifiers on CIFAR-10 images.
index_to_label = dict(enumerate([
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]))

index_to_label_cifar = dict(enumerate([
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]))

# Position i holds the STL-10 index of CIFAR-10 class i (automobile -> car).
# CIFAR-10 has no monkey and STL-10 has no frog, so frog (6) is paired with monkey (7).
# NOTE(review): with this pairing frog images only count as correct when the model
# predicts "monkey"; consider excluding that class from the OOD accuracy instead.
cifar_index_to_stl_index = {
    cifar_idx: stl_idx
    for cifar_idx, stl_idx in enumerate([0, 2, 1, 3, 4, 5, 7, 6, 8, 9])
}

def get_accuracy_cifar(model, test_loader, name, label_map=None):
    """Print the top-1 accuracy of an STL-10 classifier on CIFAR-10 images.

    CIFAR-10 labels are translated into the model's label space before being
    compared with the arg-max prediction.

    Args:
        model: classifier returning one logit per (STL-10) class.
        test_loader: DataLoader yielding ``(imgs, labels)`` with CIFAR-10 labels.
            ``len(test_loader.dataset)`` is the accuracy denominator.
        name: label used in the printed message.
        label_map: mapping CIFAR-10 index -> model output index. Defaults to the
            notebook-level ``cifar_index_to_stl_index``.
    """
    if label_map is None:
        label_map = cifar_index_to_stl_index
    model.eval()
    # Use the device the model actually lives on instead of the notebook-level
    # `config`, which many cells rebind.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(config.device)

    correct = 0
    for imgs, labels in test_loader:
        labels = torch.tensor([label_map[int(x)] for x in labels])
        imgs, labels = imgs.to(device), labels.to(device)

        # CUDA autocast; explicitly disabled when the model is not on a GPU.
        with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
            with torch.no_grad():
                outputs = model(imgs)

        correct += (outputs.argmax(dim=1) == labels).sum().item()

    # Normalise by this loader's dataset, not by whatever global `test_dataset` holds.
    print(f'Accuracy for {name} = {correct / len(test_loader.dataset)} on CIFAR-10')

def get_accuracy_stl(model, test_loader, name):
    """Print the top-1 accuracy of ``model`` on an STL-10 loader.

    Args:
        model: classifier returning one logit per STL-10 class.
        test_loader: DataLoader yielding ``(imgs, labels)`` with STL-10 labels.
            ``len(test_loader.dataset)`` is the accuracy denominator.
        name: label used in the printed message.
    """
    model.eval()
    # Use the device the model actually lives on instead of the notebook-level
    # `config`, which many cells rebind.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(config.device)

    correct = 0
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        # CUDA autocast; explicitly disabled when the model is not on a GPU.
        with torch.cuda.amp.autocast(enabled=device.type == 'cuda'):
            with torch.no_grad():
                outputs = model(imgs)

        correct += (outputs.argmax(dim=1) == labels).sum().item()

    # Normalise by this loader's dataset, not by whatever global `test_dataset` holds.
    print(f'Accuracy for {name} = {correct / len(test_loader.dataset)} on STL-10\n')

## Supervised

In [None]:
# Supervised baseline for the OOD evaluation: 224x224 centre crops, then load
# the ResNet-18 checkpoint with its 10-way head.
# NOTE(review): these CIFAR-10 channel statistics are also applied to the STL-10
# test split in the next cell, while the sharpness section normalises the same
# model with ImageNet statistics — confirm which one matches training.
test_transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

config = SupervisedBaselineConfig()
set_random_seed(config.seed)

model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_89.pth'))
model.to(config.device)

In [None]:
# In-distribution accuracy on the STL-10 test split.
test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'Supervised')

# Out-of-distribution accuracy on the CIFAR-10 *test* split. CIFAR10's second
# parameter is the boolean `train`, so the split must be chosen with
# train=False: a positional 'test' string is truthy and selects the train split.
test_dataset = datasets.CIFAR10('data', train=False, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'Supervised')

## SimCLR-LP

In [None]:
# SimCLR linear-probe model: ResNet-18 with a 10-way head, loaded from the LP checkpoint.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_SimCLR_2_89_LP_199.pth'))
model.to(config.device)

# Upscale to 96 px without normalisation. In the next cell this is only used for
# CIFAR-10; the STL-10 loader there passes T.ToTensor() directly.
test_transforms = T.Compose([
    T.Resize(96),
    T.ToTensor()
])

In [None]:
# In-distribution accuracy on the STL-10 test split (images are already 96x96).
test_dataset = datasets.STL10('data', 'test', transform=T.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'SimCLR-LP')

# OOD accuracy on the CIFAR-10 test split; train=False is required because the
# second positional CIFAR10 parameter is the boolean `train`.
test_dataset = datasets.CIFAR10('data', train=False, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'SimCLR-LP')

## SimCLR-FT

In [None]:
# SimCLR fine-tuned model: ResNet-18 with a 10-way head, loaded from the FT checkpoint.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_SimCLR_2_89_FT_59.pth'))
model.to(config.device)

In [None]:
# STL-10: resize/centre-crop to 96 px and normalise with ImageNet statistics.
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'SimCLR-FT')

# CIFAR-10: same geometry, normalised with CIFAR-10's own channel statistics.
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

# train=False selects the CIFAR-10 test split (a positional 'test' would be
# taken as a truthy `train` flag and load the training split).
test_dataset = datasets.CIFAR10('data', train=False, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'SimCLR-FT')

## BYOL-LP

In [None]:
@dataclass
class BYOL_FT_config:
    # Redefinition of BYOL_FT_config for the OOD / sharpness cells below; it
    # replaces the earlier class of the same name.
    # No attribute is annotated, so @dataclass creates no fields: these are
    # plain class attributes shared by all instances.

    # --- logging / runtime ---
    wandb_project = 'SLL_HW2'
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    num_workers = 2
    seed = 3407

    # --- training schedule ---
    num_epochs = 90
    batch_size = 256
    eval_epochs = 10
    save_epochs = 10

    # --- optimiser ---
    optim = 'SGD'
    lr = 0.05
    momentum = 0.9
    nesterov = True
    weight_decay = 1e-4

    # --- learning-rate scheduler ---
    scheduler = 'MultiStepLR'
    milestones = [30, 50, 70, 80]
    gamma = 0.1

    # --- BYOL network hyper-parameters (read by BYOL_network) ---
    mlp_hidden_size = 4096
    projection_size = 256
    moving_average = 0.99

    # --- checkpoints ---
    model_load_path = "model_BYOL_oshibka_219.pth"
    model_save_name = "model_BYOL_oshibka_219_FT"

In [None]:
# BYOL linear-probe model: rebuild BYOL_network_FT around the BYOL network minus
# its last child (the structure the LP checkpoint is loaded into), then load it.
config = BYOL_FT_config()
model = BYOL_network(config)
in_features = model.in_features  # read before the network is replaced by a Sequential
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load('model_BYOL_oshibka_219_LP_199.pth'))
model.to(config.device)

In [None]:
# In-distribution accuracy on the STL-10 test split (images are already 96x96).
test_dataset = datasets.STL10('data', 'test', transform=T.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'BYOL-LP')

# CIFAR-10 images are upscaled to STL-10's 96 px.
test_transforms = T.Compose([
    T.Resize(96),
    T.ToTensor(),
])

# train=False selects the CIFAR-10 test split (the second positional parameter
# is the boolean `train`).
test_dataset = datasets.CIFAR10('data', train=False, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'BYOL-LP')

## BYOL-FT

In [None]:
# BYOL fine-tuned model: same BYOL_network_FT wrapper as the LP cell, loaded from
# the FT checkpoint.
config = BYOL_FT_config()
model = BYOL_network(config)
in_features = model.in_features  # read before the network is replaced by a Sequential
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load('model_BYOL_oshibka_219_FT_79.pth'))
model.to(config.device)

In [None]:
# STL-10: resize/centre-crop to 96 px and normalise with ImageNet statistics.
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'BYOL-FT')

# CIFAR-10: same geometry, normalised with CIFAR-10's own channel statistics.
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

# train=False selects the CIFAR-10 test split (the second positional parameter
# is the boolean `train`).
test_dataset = datasets.CIFAR10('data', train=False, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'BYOL-FT')

## MOCO-LP

In [None]:
# MOCO linear-probe model: ResNet-18 with a 10-way head, loaded from the LP checkpoint.
# `config` is whatever the previous cell left bound; only its device, batch size
# and worker count are read here and in the next cell.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('MOCO_2_199_LP_199.pth'))
model.to(config.device)

In [None]:
# STL-10: 96 px, normalised with ImageNet statistics (no centre crop for the LP model).
test_transforms = T.Compose([
    T.Resize(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'MOCO-LP')

# CIFAR-10: upscaled to 96 px, normalised with CIFAR-10's own channel statistics.
test_transforms = T.Compose([
    T.Resize(96),
    T.ToTensor(),
    T.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

# train=False selects the CIFAR-10 test split (the second positional parameter
# is the boolean `train`).
test_dataset = datasets.CIFAR10('data', train=False, download=True, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'MOCO-LP')

## MOCO-FT

In [None]:
# MOCO fine-tuned model: ResNet-18 with a 10-way head, loaded from the FT checkpoint.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('MOCO_2_199_FT_89.pth'))
model.to(config.device)

In [None]:
# STL-10: resize/centre-crop to 96 px and normalise with ImageNet statistics.
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_stl(model, test_loader, 'MOCO-FT')

# CIFAR-10: same geometry, normalised with CIFAR-10's own channel statistics.
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
])

# train=False selects the CIFAR-10 test split (the second positional parameter
# is the boolean `train`).
test_dataset = datasets.CIFAR10('data', train=False, transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

get_accuracy_cifar(model, test_loader, 'MOCO-FT')

# Optimum width (sharpness of the minimum)

In [None]:
from scipy.optimize import fmin_l_bfgs_b
from IPython.display import clear_output


def calc_loss(model, loader=None):
    """Return the mean cross-entropy loss of ``model`` over a data loader.

    Args:
        model: classifier to evaluate. It is switched to eval mode and moved to
            ``config.device``.
        loader: DataLoader yielding ``(imgs, labels)``. Defaults to the
            notebook-level ``train_loader``, the only source the original
            version supported.

    Returns:
        The sum of per-batch mean losses divided by the number of batches.
    """
    if loader is None:
        loader = train_loader
    model.eval()
    model.to(config.device)
    criterion = nn.CrossEntropyLoss()
    total = 0.0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(config.device), labels.to(config.device)
            total += criterion(model(imgs), labels).item()
    return total / len(loader)


def store_weights(x: np.ndarray):
    """Write a flat parameter vector into the notebook-level ``model`` in place.

    ``x`` is consumed in ``model.parameters()`` order. Each chunk is reshaped to
    the matching parameter's shape and copied into it, so every parameter keeps
    its current device and dtype.

    Args:
        x: 1-D array with exactly one entry per model parameter.

    Returns:
        The (global) model, now holding the weights from ``x``.

    Raises:
        ValueError: if ``x`` does not contain exactly one value per parameter.
    """
    flat = torch.as_tensor(x, dtype=torch.float32).reshape(-1)
    model.eval()
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    if flat.numel() != expected:
        raise ValueError(f'expected {expected} parameter values, got {flat.numel()}')

    offset = 0
    with torch.no_grad():
        for param in params:
            size = param.numel()
            # copy_ handles any CPU -> GPU transfer and leaves the Parameter
            # object (and its .grad) on its own device.
            param.copy_(flat[offset:offset + size].view_as(param))
            offset += size
    return model


def calc_minus_loss_in_x(x: np.ndarray):
    """Set the model weights to ``x`` and return the negated mean training loss.

    Value-only objective (no gradient) for a maximise-the-loss search.
    """
    return -1. * calc_loss(store_weights(x))


def calc_loss_and_grad(model):
    """Return the mean training loss of ``model`` and its gradient.

    One pass over the notebook-level ``train_loader`` in eval mode. Per-batch
    mean losses and their gradients are summed and then divided by the number
    of batches.

    Returns:
        ``(mean_loss, grad)``: ``mean_loss`` is a float; ``grad`` is a 1-D
        float64 CPU tensor in ``model.parameters()`` order.
    """
    model.eval()
    model.to(config.device)
    # Without this, .grad keeps accumulating across calls (every L-BFGS-B
    # evaluation reuses the same Parameter objects), corrupting the gradient.
    model.zero_grad(set_to_none=True)
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(config.device), labels.to(config.device)
        batch_loss = criterion(model(imgs), labels)
        batch_loss.backward()
        total_loss += batch_loss.item()

    n_batches = len(train_loader)
    print('loss calculated, it is', total_loss / n_batches)
    # Parameters that received no gradient (e.g. frozen ones) contribute zeros.
    grads = [
        (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
        for p in model.parameters()
    ]
    grad = torch.cat(grads).to(device='cpu', dtype=torch.float64)
    return total_loss / n_batches, grad / n_batches


def calc_minus_loss_and_grad_in_x(x: np.ndarray):
    """L-BFGS-B objective: negated mean training loss and its gradient at ``x``.

    Minimising this maximises the training loss. The value AND the gradient
    must both be negated; returning ``-loss`` with ``+grad`` hands the optimiser
    a gradient that points the wrong way.

    Returns:
        ``(-loss, -grad)`` with ``grad`` as a float64 NumPy array, as
        ``scipy.optimize.fmin_l_bfgs_b`` expects.
    """
    model = store_weights(x)
    loss, grad = calc_loss_and_grad(model)
    grad = np.asarray(grad, dtype=np.float64)
    return -1. * loss, -grad


def calc_sharpness(model, train_loader, name, epsilon=1e-3):
    """Estimate and print the sharpness of the minimum ``model`` sits in.

    Starting from the current weights ``x0``, L-BFGS-B (at most 10 iterations,
    memory 10) maximises the training loss inside the box
    ``x0 ± epsilon * (|x0| + 1)``. The result is
    ``(max_loss - loss(x0)) / (1 + loss(x0)) * 100``.

    The starting weights are restored afterwards; the optimiser would otherwise
    leave the last trial point in the model.

    NOTE(review): ``train_loader`` is not read here. ``calc_loss`` and the
    L-BFGS-B objective use the notebook-level ``train_loader`` global, and the
    objective writes trial weights into the global ``model`` via
    ``store_weights``, so ``model`` must be that same object.

    Args:
        model: the trained classifier.
        train_loader: unused (see note above).
        name: label used in the printed message.
        epsilon: relative half-width of the search box.

    Returns:
        The sharpness value.
    """
    params = list(model.parameters())
    # Snapshot the starting point so it can be restored after the search.
    saved = [p.detach().clone() for p in params]
    x0 = torch.cat([s.reshape(-1) for s in saved]).cpu().numpy()

    f_x0 = calc_loss(model)
    # Box constraints (the A = identity variant of the metric).
    radius = epsilon * (np.abs(x0) + 1)
    bounds = np.stack([x0 - radius, x0 + radius], axis=1)

    try:
        _, f_x_max, _ = fmin_l_bfgs_b(func=calc_minus_loss_and_grad_in_x, x0=x0,
                                      bounds=bounds, maxiter=10, m=10)
    finally:
        with torch.no_grad():
            for param, original in zip(params, saved):
                param.copy_(original)

    # The objective is the negated loss, so flip the sign back.
    f_x_max_real = -1.0 * f_x_max
    sharpness = (f_x_max_real - f_x0) / (1 + f_x0) * 100.
    print(f'Sharpness for {name} model = {sharpness}')
    return sharpness

## Supervised

In [None]:
# Supervised baseline, sharpness setup: 224x224 centre crops of the STL-10 train
# split with ImageNet statistics. calc_loss / calc_loss_and_grad read this global
# `train_loader`.
test_transforms = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

config = SupervisedBaselineConfig()
set_random_seed(config.seed)

train_dataset = datasets.STL10('data', 'train', download=True, transform=test_transforms)
# shuffle=False: the sharpness objective is evaluated many times. With
# shuffle=True and drop_last=True a different random subset would be dropped on
# every pass, making the loss noisy between evaluations.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

In [None]:
# Supervised baseline: load the checkpoint and measure sharpness with epsilon=1e-3.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_89.pth'))
model.to(config.device)

sharpness = calc_sharpness(model, train_loader, 'Supervised', epsilon=1e-3)

In [None]:
# Reload a fresh copy so this run starts from the saved checkpoint rather than
# whatever weights the previous cell left in `model`; then use epsilon=5e-4.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_89.pth'))
model.to(config.device)

sharpness = calc_sharpness(model, train_loader, 'Supervised', epsilon=5e-4)

## SimCLR-LP

In [None]:
# SimCLR linear-probe model and its un-normalised STL-10 train loader (the
# global `train_loader` read by calc_loss / calc_loss_and_grad).
config = SimCLR_LP_config()
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_SimCLR_2_89_LP_199.pth'))
model.to(config.device)

train_dataset = datasets.STL10('data', 'train', download=True, transform=T.ToTensor())
# shuffle=False keeps the set of batches (and the drop_last remainder) fixed, so
# repeated loss evaluations during the sharpness search are comparable.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

In [None]:
sharpness = calc_sharpness(model, train_loader, 'SimCLR-LP', epsilon=1e-3)  # SimCLR-LP, epsilon=1e-3

In [None]:
# Reload the checkpoint into the same model so this run starts from the saved
# weights, then measure with epsilon=5e-4.
model.load_state_dict(torch.load('model_SimCLR_2_89_LP_199.pth'))
model.to(config.device)

sharpness = calc_sharpness(model, train_loader, 'SimCLR-LP', epsilon=5e-4)

## SimCLR-FT

In [None]:
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_SimCLR_2_89_FT_59.pth'))
model.to(config.device)

test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# calc_loss / calc_loss_and_grad read the global `train_loader`. Rebuild it with
# this model's preprocessing; otherwise the SimCLR-LP loader (plain ToTensor, no
# normalisation) would still be the one measured.
train_dataset = datasets.STL10('data', 'train', download=True, transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

In [None]:
# The loss is measured on the global `train_loader` (calc_loss reads it), so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'SimCLR-FT', epsilon=1e-3)

In [None]:
# Fresh copy from the checkpoint, then epsilon=5e-4. The loss is measured on the
# global `train_loader` (calc_loss reads it), so pass that loader.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('model_SimCLR_2_89_FT_59.pth'))
model.to(config.device)

sharpness = calc_sharpness(model, train_loader, 'SimCLR-FT', epsilon=5e-4)

## BYOL-LP

In [None]:
# BYOL linear probe: un-normalised STL-10 train loader (the global read by
# calc_loss). shuffle=False keeps repeated loss evaluations comparable.
train_dataset = datasets.STL10('data', 'train', download=True, transform=T.ToTensor())
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

config = BYOL_FT_config()
model = BYOL_network(config)
in_features = model.in_features
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load('model_BYOL_oshibka_219_LP_199.pth'))
model.to(config.device)

# Pass the loader the loss is actually measured on (not the stale SimCLR-FT test loader).
sharpness = calc_sharpness(model, train_loader, 'BYOL-LP', epsilon=1e-3)

In [None]:
# Rebuild and reload the BYOL-LP model so the search starts from the checkpoint.
model = BYOL_network(config)
in_features = model.in_features
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load('model_BYOL_oshibka_219_LP_199.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'BYOL-LP', epsilon=5e-4)

## BYOL-FT

In [None]:
# BYOL fine-tuned: STL-10 train loader with the FT preprocessing (the global
# `train_loader` read by calc_loss / calc_loss_and_grad).
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = datasets.STL10('data', 'train', download=True, transform=test_transforms)
# shuffle=False keeps the batches (and the drop_last remainder) fixed between
# the repeated loss evaluations of the sharpness search.
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

In [None]:
config = BYOL_FT_config()
model = BYOL_network(config)
in_features = model.in_features
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load('model_BYOL_oshibka_219_FT_79.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'BYOL-FT', epsilon=1e-3)

In [None]:
# Rebuild and reload the BYOL-FT model so the search starts from the checkpoint.
model = BYOL_network(config)
in_features = model.in_features
model = nn.Sequential(*list(model.children())[:-1])
model = BYOL_network_FT(model, in_features)
model.load_state_dict(torch.load('model_BYOL_oshibka_219_FT_79.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'BYOL-FT', epsilon=5e-4)

## MOCO-LP

In [None]:
test_transforms = T.Compose([
    T.Resize(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# calc_loss / calc_loss_and_grad read the global `train_loader`. Rebuild it with
# the MOCO-LP preprocessing instead of reusing the loader left by the BYOL-FT
# cells (Resize(110) + CenterCrop(96)).
train_dataset = datasets.STL10('data', 'train', download=True, transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

In [None]:
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('MOCO_2_199_LP_199.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'MOCO-LP', epsilon=1e-3)

In [None]:
# Fresh copy from the checkpoint so the search starts from the saved weights.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('MOCO_2_199_LP_199.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'MOCO-LP', epsilon=5e-4)

## MOCO-FT



In [None]:
test_transforms = T.Compose([
    T.Resize(110),
    T.CenterCrop(96),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# calc_loss / calc_loss_and_grad read the global `train_loader`. Build it here
# explicitly instead of depending on which earlier cell last assigned it.
train_dataset = datasets.STL10('data', 'train', download=True, transform=test_transforms)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size,
                          shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=True)

test_dataset = datasets.STL10('data', 'test', transform=test_transforms)
test_loader = DataLoader(test_dataset, batch_size=config.batch_size,
                         shuffle=False, num_workers=config.num_workers, pin_memory=True, drop_last=False)

In [None]:
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('MOCO_2_199_FT_89.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'MOCO-FT', epsilon=1e-3)

In [None]:
# Fresh copy from the checkpoint so the search starts from the saved weights.
model = models.resnet18(num_classes=10)
model.load_state_dict(torch.load('MOCO_2_199_FT_89.pth'))
model.to(config.device)

# The loss is measured on the global `train_loader`, so pass that loader.
sharpness = calc_sharpness(model, train_loader, 'MOCO-FT', epsilon=5e-4)