In [13]:
import torch 

device = torch.device("cuda:2")
target_z = torch.load("/home/sjoshi/mtt-distillation/target_rep/CIFAR100/train_rep_r50_128_dim.pt", map_location=device)

In [14]:
import torch
import torchvision 
from torchvision import transforms
from PIL import Image

def ColourDistortion(s=1.0):
    # s is the strength of color distortion.
    color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    color_distort = transforms.Compose([rnd_color_jitter, rnd_gray])
    return color_distort


mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
transform_train = transforms.Compose([
        transforms.RandomResizedCrop((32,32), interpolation=Image.BICUBIC),
        transforms.RandomHorizontalFlip(),
        ColourDistortion(s=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
dataset = torchvision.datasets.CIFAR100("/data", train=True, transform=transform_train)



In [15]:
from typing import Any
from torch.utils.data import Dataset 
class DatasetWithIndices(Dataset):
    def __init__(self, dataset) -> None:
        super().__init__()
        self.dataset = dataset
    
    def __getitem__(self, index: Any) -> Any:
        return index, self.dataset[index]

    def __len__(self):
        return len(self.dataset)

In [43]:
from networks import ConvNet, ResNet18
from utils import get_default_convnet_setting
trainloader = torch.utils.data.DataLoader(DatasetWithIndices(dataset), batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
device = torch.device("cuda:2")
net_width, net_depth, net_act, net_norm, net_pooling = get_default_convnet_setting()
model = ConvNet(channel=3, num_classes=128, net_depth=net_depth, net_act=net_act, net_width=net_width, net_norm=net_norm, net_pooling=net_pooling)

model = ResNet18(channel=3, num_classes=128)


In [44]:
from torch import nn, optim
from tqdm import tqdm 
import torch.nn.functional as F

criterion = nn.MSELoss()

model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
pbar = tqdm(range(10), desc="pre-training")
for epoch in pbar:
    loss_avg, num_exp = 0, 0
    for idx, datum in trainloader:
        img = datum[0].float().to(device)
        n_b = img.shape[0]
        student_z = model(img)
        
        student_dist = F.cosine_similarity(student_z.unsqueeze(1), student_z.unsqueeze(0), dim=2)
        teacher_z = target_z[idx]
        teacher_dist = F.cosine_similarity(teacher_z.unsqueeze(1), teacher_z.unsqueeze(0), dim=2)
        loss = torch.log(torch.sum(torch.exp(student_dist - teacher_dist)))

        loss_avg += loss.item()
        num_exp += n_b

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    loss_avg /= num_exp
    pbar.set_postfix_str(f"loss: {loss_avg}")

pre-training: 100%|██████████| 10/10 [02:02<00:00, 12.24s/it, loss: 0.21730676044464112]


In [56]:
from torch.utils.data import Subset 
import random

Z = []
Y = []
random.seed(0)
random_subset = list(random.sample(range(50000), 1000))
clf_cifar100 = Subset(torchvision.datasets.CIFAR100("/data", transform=transform), indices=random_subset)
clf_dataloader = torch.utils.data.DataLoader(clf_cifar100, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)
with torch.no_grad():
    for X, y in tqdm(clf_dataloader, desc="encoding"):
        Z.append(model.features(X.to(device)).view(-1, 2048))
        #Z.append(model(X.to(device)))
        Y.append(y.to(device))
Z = torch.cat(Z, dim=0)
Y = torch.cat(Y, dim=0)
print(Z.shape)
print(Y.shape)

encoding: 100%|██████████| 4/4 [00:00<00:00,  6.17it/s]

torch.Size([1000, 2048])
torch.Size([1000])





In [57]:
def train_clf(X, y, representation_dim, num_classes, device, reg_weight=1e-3, iter=500):
    print('\nL2 Regularization weight: %g' % reg_weight)

    criterion = nn.CrossEntropyLoss()
    n_lbfgs_steps = iter

    # Should be reset after each epoch for a completely independent evaluation
    clf = nn.Linear(representation_dim, num_classes).to(device)
    clf_optimizer = optim.LBFGS(clf.parameters())
    clf.train()

    t = tqdm(range(n_lbfgs_steps), desc='Loss: **** | Train Acc: ****% ', bar_format='{desc}{bar}{r_bar}')
    for _ in t:
        def closure():
            clf_optimizer.zero_grad()
            raw_scores = clf(X)
            loss = criterion(raw_scores, y)
            loss += reg_weight * clf.weight.pow(2).sum()
            loss.backward()

            _, predicted = raw_scores.max(1)
            correct = predicted.eq(y).sum().item()

            t.set_description('Loss: %.3f | Train Acc: %.3f%% ' % (loss, 100. * correct / y.shape[0]))

            return loss

        clf_optimizer.step(closure)

    return clf


def test_clf(testloader, device, net, clf, features=True):
    criterion = nn.CrossEntropyLoss()
    net.eval()
    clf.eval()
    test_clf_loss = 0
    correct = 0
    total = 0
    acc_per_point = []
    with torch.no_grad():
        t = tqdm(enumerate(testloader), total=len(testloader), desc='Loss: **** | Test Acc: ****% ',
                 bar_format='{desc}{bar}{r_bar}')
        for batch_idx, (inputs, targets) in t:
            inputs, targets = inputs.to(device), targets.to(device)
            representation = None
            if features:
                representation = net.features(inputs).view(-1, 2048)
            else:
                representation = net(inputs)
            # test_repr_loss = criterion(representation, targets)
            raw_scores = clf(representation)
            clf_loss = criterion(raw_scores, targets)
            test_clf_loss += clf_loss.item()
            _, predicted = raw_scores.max(1)
            total += targets.size(0)
            acc_per_point.append(predicted.eq(targets))
            correct += acc_per_point[-1].sum().item()
            t.set_description('Loss: %.3f | Test Acc: %.3f%% ' % (test_clf_loss / (batch_idx + 1), 100. * correct / total))
            
    acc = 100. * correct / total
    return acc, torch.cat(acc_per_point, dim=0).cpu().numpy()

def top5accuracy(output, target, topk=(5,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        print(correct)
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size).item())
        return res

In [58]:
cifar100_test = torchvision.datasets.CIFAR100("/data", train=False, transform=transform)
testloader = torch.utils.data.DataLoader(cifar100_test, batch_size=256, shuffle=False, num_workers=4, pin_memory=True)

In [59]:
clf = train_clf(Z, Y, Z.shape[1], 100, device, iter=100)


L2 Regularization weight: 0.001


Loss: 0.269 | Train Acc: 100.000% : ██████████| 100/100 [00:02<00:00, 35.34it/s]


In [60]:
acc, acc_per_point = test_clf(testloader, device, model, clf, features=True)

Loss: 4.679 | Test Acc: 11.230% : ██████████| 40/40 [00:00<00:00, 61.26it/s]
