In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.transforms as T
from matplotlib import pyplot as plt
%matplotlib inline

import sys
import os
sys.path.append(os.path.dirname(os.getcwd()) + "/src/")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def num_correct(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k)
        return res

def get_lr(step, total_steps, lr_max, lr_min):
    """Compute learning rate according to cosine annealing schedule."""
    return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))

In [10]:
weights = torch.load("../results/simclr_stl10_03-20-2023_16-16-57/checkpoints/checkpoint_epoch999.pt")["model_state"]
#weights = torch.load("../results/vi-20s-40refine-thresh0.1-infonce1e-2_02-24-2023_12-41-03/checkpoints/checkpoint_epoch599.pt")["model_state"]
#weights = torch.load("../results/vi-20s-infonce1e-2-constspeed-kl1e-4-attn_02-28-2023_10-41-44/checkpoints/checkpoint_epoch999.pt")["model_state"]
device = "cuda:0"

model = torch.hub.load("pytorch/vision:v0.10.0", "resnet18", pretrained=False).to(device)
model.fc = nn.Identity()
#model.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
#model.maxpool = nn.Identity()
model.requires_grad = False

own_state = model.state_dict()
for name, param in weights.items():
    name = name.replace("backbone.backbone_network.", "")
    if name not in own_state:
        continue
    if isinstance(param, nn.Parameter):
        # backwards compatibility for serialized parameters
        param = param.data
    own_state[name].copy_(param)

Using cache found in /home/kfallah/.cache/torch/hub/pytorch_vision_v0.10.0


In [11]:
t = T.Compose(
    [T.Resize(70, interpolation=3), T.CenterCrop(64), T.ToTensor(), T.Normalize((0.43, 0.42, 0.39), (0.27, 0.26, 0.27))]
)
train_data = torchvision.datasets.STL10("../datasets", split="train", transform=t)
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=512, drop_last=False)

train_data = torchvision.datasets.STL10("../datasets", split="test", transform=t)
val_dataloader = torch.utils.data.DataLoader(train_data, batch_size=512, shuffle=False)

train_x = []
train_y = []
for idx, batch in enumerate(train_dataloader):
    x, y = batch
    x = x.to(device)
    feat = model(x)
    train_x.append(feat.detach().cpu())
    train_y.append(y)
train_x = torch.cat(train_x)
train_y = torch.cat(train_y)

val_x = []
val_y = []
for idx, batch in enumerate(val_dataloader):
    x, y = batch
    x = x.to(device)
    feat = model(x)
    val_x.append(feat.detach().cpu())
    val_y.append(y)
val_x = torch.cat(val_x)
val_y = torch.cat(val_y)



In [23]:
clf = nn.Linear(512, 10).to(device)

lr_start, lr_end = 1e-2, 1e-6
gamma = (lr_end / lr_start) ** (1 / 500)
optimizer = torch.optim.Adam(clf.parameters(), lr=lr_start, weight_decay=5e-6)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

criterion = nn.CrossEntropyLoss().to(device)
train_x, train_y = train_x.to(device), train_y.to(device)
val_x, val_y = val_x.to(device), val_y.to(device)

for e in range(500):
    perm = torch.randperm(len(train_x)).view(-1, 500)
    for idx in perm:
        optimizer.zero_grad()
        criterion(clf(train_x[idx]), train_y[idx]).backward()
        optimizer.step()
        scheduler.step()


    if (e+1) % 100 == 0:
        y_pred = clf(val_x)
        pred_top = y_pred.topk(max([1, 5]), 1, largest=True, sorted=True).indices
        acc = {
            t: (pred_top[:, :t] == val_y[..., None]).float().sum(1).mean().cpu().item()
            for t in [1, 5]
        }
        print(f"Epoch {e}: " + str(acc[1]))

Epoch 99: 0.7738750576972961
Epoch 199: 0.7728750109672546
Epoch 299: 0.7715000510215759
Epoch 399: 0.7710000276565552
Epoch 499: 0.7710000276565552


In [6]:
for e in range(100):
    model.train()
    epoch_loss = []
    for idx, batch in enumerate(train_dataloader):
        x, y = batch
        x, y = x.to(device), y.to(device)
        # Send inputs through model
        feat = model(x)
        y_logit = linear_head(feat).squeeze(1)
        loss = F.cross_entropy(y_logit, y)

        epoch_loss.append(loss.item())

        # Backpropagate loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    model.eval()
    num_top1_correct = 0
    num_top5_correct = 0
    total = 0
    val_loss = []
    with torch.no_grad():
        for idx, batch in enumerate(val_dataloader):
            x, y = batch
            x, y = x.to(device), y.to(device)

            feat = model(x)
            y_logit = linear_head(feat).squeeze(1)
            loss = F.cross_entropy(y_logit, y)
            val_loss.append(loss.item())

            batch_top1, batch_top5 = num_correct(y_logit, y, topk=(1, 5))
            num_top1_correct += batch_top1.item()
            num_top5_correct += batch_top5.item()
            total += len(x)

    num_top1_acc = num_top1_correct / total
    num_top5_acc = num_top5_correct / total
    print(f"Epoch {e+1}, train loss: {np.mean(epoch_loss):.3E}, val loss: {np.mean(val_loss):.3E}, top1 acc: {num_top1_acc*100:.2f}%")

Epoch 1, train loss: 4.331E-01, val loss: 4.098E-01, top1 acc: 87.41%
Epoch 2, train loss: 4.073E-01, val loss: 3.950E-01, top1 acc: 87.57%
Epoch 3, train loss: 3.936E-01, val loss: 3.848E-01, top1 acc: 87.80%
Epoch 4, train loss: 3.812E-01, val loss: 3.730E-01, top1 acc: 87.89%
Epoch 5, train loss: 3.797E-01, val loss: 3.722E-01, top1 acc: 87.95%
Epoch 6, train loss: 3.764E-01, val loss: 3.660E-01, top1 acc: 88.14%
Epoch 7, train loss: 3.687E-01, val loss: 3.677E-01, top1 acc: 88.02%
Epoch 8, train loss: 3.646E-01, val loss: 3.614E-01, top1 acc: 88.19%
Epoch 9, train loss: 3.660E-01, val loss: 3.604E-01, top1 acc: 88.34%
Epoch 10, train loss: 3.628E-01, val loss: 3.523E-01, top1 acc: 88.43%
Epoch 11, train loss: 3.633E-01, val loss: 3.552E-01, top1 acc: 88.36%
Epoch 12, train loss: 3.621E-01, val loss: 3.511E-01, top1 acc: 88.52%
Epoch 13, train loss: 3.560E-01, val loss: 3.541E-01, top1 acc: 88.44%
Epoch 14, train loss: 3.546E-01, val loss: 3.456E-01, top1 acc: 88.59%
Epoch 15, train