In [7]:
# Action plan:
# 1. Train a distilled ResNet-18 using the class prototypes.
# 2. Apply the dmil strategy to the distilled model.
# 3. Evaluate the distilled model.



In [8]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
from torchvision import datasets, models, transforms

In [9]:
# Randomly initialised ResNet-18 (no pretrained weights). `weights=None` is the
# current torchvision spelling of the deprecated `pretrained=False`.
modelrn18 = models.resnet18(weights=None)

In [10]:
import clip
import numpy as np

In [11]:
# Load the CLIP ViT-B/32 model together with its image preprocessing pipeline,
# then report a few of its basic properties.
model, preprocess = clip.load("ViT-B/32")
input_resolution, context_length, vocab_size = (
    model.visual.input_resolution,
    model.context_length,
    model.vocab_size,
)

# Total parameter count, formatted with thousands separators.
print("Model parameters:", f"{sum(int(np.prod(p.shape)) for p in model.parameters()):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


In [12]:
# Display the image preprocessing pipeline returned by clip.load.
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f865f7508b0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [13]:
# CIFAR-100 train and test splits (downloaded to ./data if missing); every image
# goes through the CLIP preprocessing pipeline.
ds_train, ds_test = (
    datasets.CIFAR100(root='./data', train=is_train, download=True, transform=preprocess)
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# Random permutation of the 100 class ids. A dedicated, seeded generator keeps
# the class split identical across kernel restarts without touching the global
# torch RNG state.
CLASS_ORDER_SEED = 0
class_order = torch.randperm(100, generator=torch.Generator().manual_seed(CLASS_ORDER_SEED))

In [15]:
# Display the sampled class permutation.
class_order

tensor([61, 47, 99, 51, 94,  3, 56, 25, 70,  7, 96, 42, 38, 12, 78, 84, 71, 32,
        33, 62, 73, 90, 64, 82, 40, 21, 60, 16, 37, 58, 67,  9, 81,  1,  5, 57,
        48, 87, 53, 92, 45, 54, 59, 34, 66, 75, 46,  0,  2,  6, 55, 86, 50, 26,
        43, 49, 30,  8, 20, 89, 11, 18, 17, 52, 65, 72, 63, 76, 79, 41, 14, 10,
        69, 28, 77, 35, 80, 85, 97, 93, 74, 22, 23, 13, 24, 83, 19,  4, 29, 36,
        44, 15, 88, 91, 98, 31, 27, 68, 95, 39])

In [16]:
# Sanity check: list every distinct value in class_order with its occurrence count.
np.unique(class_order,return_counts=True)

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
        51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
        85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]))

In [17]:
# First 50 entries of the permutation; this slice is what selects the subsets below.
class_order[:50]

tensor([61, 47, 99, 51, 94,  3, 56, 25, 70,  7, 96, 42, 38, 12, 78, 84, 71, 32,
        33, 62, 73, 90, 64, 82, 40, 21, 60, 16, 37, 58, 67,  9, 81,  1,  5, 57,
        48, 87, 53, 92, 45, 54, 59, 34, 66, 75, 46,  0,  2,  6])

In [18]:
def instances_from_classes(dataset, class_order):
    """Return the indices of ``dataset`` whose label is one of ``class_order``.

    Args:
        dataset: Indexable dataset whose items are ``(sample, label)`` pairs.
            When it exposes a ``targets`` sequence and has no
            ``target_transform``, labels are read from ``targets`` directly so
            no sample is loaded or transformed just to inspect its label.
        class_order: Iterable of integer class ids to keep (list, NumPy array
            or 1-D tensor).

    Returns:
        list[int]: Matching positions in ascending order, suitable for a
        subset sampler.
    """
    # A set of plain ints gives O(1) membership tests regardless of the
    # container type the caller passed in.
    wanted = {int(c) for c in class_order}

    labels = getattr(dataset, "targets", None)
    if labels is None or getattr(dataset, "target_transform", None) is not None:
        # Fallback: read the label through __getitem__ (slow when the dataset
        # applies a heavy transform, but always consistent with what it yields).
        labels = (dataset[i][1] for i in range(len(dataset)))

    return [i for i, label in enumerate(labels) if int(label) in wanted]


In [19]:

# Indices of the train / test samples that belong to the first 50 classes of
# the permutation.
set_train, set_test = (
    instances_from_classes(split, class_order[:50])
    for split in (ds_train, ds_test)
)

In [20]:
# One loader per split: batches of 64, drawn in random order from the indices
# selected for that split.
dl_train, dl_test = (
    torch.utils.data.DataLoader(
        split,
        batch_size=64,
        sampler=torch.utils.data.SubsetRandomSampler(indices),
    )
    for split, indices in ((ds_train, set_train), (ds_test, set_test))
)

In [21]:
# Pull a single batch from the training loader to experiment with.
x,y = next(iter(dl_train))

In [22]:
# Display the batch just drawn (images and labels).
# NOTE(review): this prints the full tensors; x.shape / y.shape would be a more compact check.
x,y

(tensor([[[[-1.2959, -1.2959, -1.2959,  ..., -1.4273, -1.4273, -1.4273],
           [-1.2959, -1.2959, -1.2959,  ..., -1.4273, -1.4273, -1.4273],
           [-1.2959, -1.2959, -1.2959,  ..., -1.4273, -1.4273, -1.4273],
           ...,
           [-0.1718, -0.1718, -0.1718,  ..., -0.8288, -0.8288, -0.8288],
           [-0.1572, -0.1572, -0.1572,  ..., -0.8434, -0.8434, -0.8434],
           [-0.1572, -0.1572, -0.1572,  ..., -0.8434, -0.8434, -0.8434]],
 
          [[-1.4069, -1.4069, -1.4069,  ..., -1.5870, -1.5870, -1.5870],
           [-1.4069, -1.4069, -1.4069,  ..., -1.5870, -1.5870, -1.5870],
           [-1.4069, -1.4069, -1.4069,  ..., -1.5870, -1.5870, -1.5870],
           ...,
           [-0.6415, -0.6415, -0.6415,  ..., -1.0918, -1.0767, -1.0767],
           [-0.6415, -0.6415, -0.6415,  ..., -1.0918, -1.0767, -1.0767],
           [-0.6265, -0.6265, -0.6265,  ..., -1.0918, -1.0767, -1.0767]],
 
          [[-1.1816, -1.1816, -1.1816,  ..., -1.3380, -1.3380, -1.3380],
           [-

In [23]:
# Fresh, randomly initialised ResNet-18 that will act as the distillation
# student. `weights=None` is the current torchvision spelling of the deprecated
# `pretrained=False`.
modelrn18 = models.resnet18(weights=None)



In [24]:
# Replace the classification head with a 512 -> 512 linear layer so the network
# emits a 512-d vector that is compared against the CLIP image features below.
modelrn18.fc = nn.Linear(512,512)

In [25]:
# Teacher features: CLIP image embeddings for the sample batch.
# NOTE(review): no torch.no_grad() here, so autograd also tracks the CLIP forward pass — confirm that is intended.
clip_pred = model.encode_image(x)

In [26]:
# Student features: ResNet-18 output for the same batch.
rn18_pred = modelrn18(x)

In [27]:
# Mean-squared error between the student and teacher features on the sample batch.
loss = nn.MSELoss()
loss(rn18_pred,clip_pred)

tensor(0.4908, grad_fn=<MseLossBackward0>)

In [28]:
from tqdm import tqdm
import sklearn.metrics as metrics

In [29]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Distillation criterion used by the training loop: MSE between the ResNet-18
# and CLIP features. The second line just evaluates it on the sample batch.
criterion = nn.MSELoss()
criterion(rn18_pred,clip_pred)

In [31]:
# Adam over the ResNet-18 (student) parameters only; the CLIP model is not given to the optimizer.
opt = optim.Adam(modelrn18.parameters(), lr=1e-3)

In [32]:
# Freeze the CLIP image encoder: turn off gradient tracking for every parameter
# under model.visual. The trailing semicolon hides the module repr in the cell output.
model.visual.requires_grad_(False);

In [34]:
# Feature distillation: train the ResNet-18 student to reproduce the image
# embeddings of the frozen CLIP teacher with an MSE objective.

# Hard cap on the number of passes so the loop always terminates.
# NOTE(review): 10 is a placeholder — tune or replace with a convergence test.
max_epochs = 10

# Teacher and student must live on the same device as the input batches.
# `.to()` moves parameters in place, so the existing optimizer keeps pointing
# at the right tensors.
model.to(device)
modelrn18.to(device)
model.eval()  # the teacher is never trained; keep it in inference mode

stop = False
loss_train = []  # mean training loss per epoch
while (not stop):
    modelrn18.train()  # it is the student, not CLIP, that is being optimised
    lloss = []
    loop = tqdm(dl_train)
    for x,y in loop:
        x = x.to(device)  # labels (y) are not needed for feature distillation
        # No graph is needed through the teacher; skipping it saves memory.
        # The cast guards against CLIP returning half-precision features.
        with torch.no_grad():
            pred_clip = model.encode_image(x).float()
        pred_resnet = modelrn18(x)
        closs = criterion(pred_resnet,pred_clip)
        opt.zero_grad()
        closs.backward()
        opt.step()
        lloss.append(closs.item())
    loss_train.append(np.mean(lloss))
    modelrn18.eval()
    # Termination rule: stop once the epoch budget has been spent.
    stop = len(loss_train) >= max_epochs

  0%|          | 0/391 [00:00<?, ?it/s]

: 