In [2]:


import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from sklearn.manifold import TSNE
from torchvision import datasets, models, transforms
from tqdm import tqdm


In [40]:
from resnetpass import resnet18_cbam

In [78]:
# Build a ResNet-18 with CBAM-style attention from the local `resnetpass` module.
# pretrained=False — presumably no pretrained weights are loaded; confirm in resnetpass.
# NOTE(review): this cell's execution count (78) is later than the training cell (75).
# If it was re-run after training, `model` is now an untrained network, and `opt`
# (built in cell 66 from the previous model's parameters) no longer points at it.
model = resnet18_cbam(pretrained=False)

In [44]:
# Inspect the architecture: the repr lists each ResNet stage and its channel-attention (`ca`) sub-modules.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (ca): ChannelAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (max_pool): AdaptiveMaxPool2d(output_size=1)
        (fc1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu1): ReLU()
        (fc2

In [45]:
# Run everything on the first CUDA GPU. The random class permutation is cached
# in `class_order_file`, so the class split is identical across kernel restarts.
class_order_file = 'class_order.pth'
device = torch.device('cuda', 0)

In [46]:
!pip install pytorch-metric-learning

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [47]:
# Alternative baseline kept for reference only: torchvision's plain ResNet-18 with
# its `fc` head replaced by a 512->512 linear layer. Not executed; the CBAM ResNet
# from `resnetpass` is the model trained below.
#model = models.resnet18(pretrained=False)
#model.fc = nn.Linear(512,512)

In [48]:
# Show the working directory; relative paths used below ('./data', 'class_order.pth', the checkpoint) resolve against it.
!pwd

/home/ubuntu/lialib/colab


In [49]:
# Report name and memory of CUDA device 0 (the device `device` was set to above).
torch.cuda.get_device_properties(0)

_CudaDeviceProperties(name='NVIDIA GeForce RTX 3060', major=8, minor=6, total_memory=12044MB, multi_processor_count=28)

In [50]:
import os

# Reuse the cached permutation of the 100 CIFAR-100 class ids if it exists.
# Otherwise draw a new one and cache it, so the "first 50 classes" split used
# later is the same on every run. (torch.load unpickles; the file is written by
# this very cell, so it is trusted.)
# Earlier experiments restored weights here with
# `model.load_state_dict(model_state_dict['rn18'])`; `model_state_dict` is not
# defined in this notebook.
_have_cached_order = os.path.isfile(class_order_file)
class_order = torch.load(class_order_file) if _have_cached_order else torch.randperm(100)
if not _have_cached_order:
    torch.save(class_order, class_order_file)

In [51]:
# Show the cached class permutation; its first 50 entries are the classes selected below.
class_order

tensor([30, 43, 94, 73, 34, 67, 62, 33, 61, 15, 78, 28, 41, 68, 57,  0,  3, 22,
         5,  2, 65, 26, 54, 32, 24, 83, 72, 13, 29, 49, 98, 20, 31, 82, 71, 44,
        42, 91, 40, 99, 23, 97,  6, 90, 45, 63, 60, 56, 87, 47, 79, 51, 58, 17,
        16, 96, 52, 75, 92, 69, 35, 38, 53, 25,  1, 80, 39, 88, 21, 74, 36, 19,
        11, 10, 76, 89, 59, 37,  9, 14, 81, 95, 77, 64, 70, 46, 55, 66, 86, 85,
        18, 12,  7, 48, 93,  4, 27,  8, 50, 84])

In [52]:
# Per-channel normalisation statistics. These values match the constants
# published with OpenAI CLIP's image preprocessing, not CIFAR-100's own statistics.
# NOTE(review): presumably chosen for the "clip+dmil" setup — confirm intentional.
_NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
_NORM_STD = (0.26862954, 0.26130258, 0.27577711)

# Shared by train and test: PIL image -> float tensor -> normalised tensor.
# No training-time augmentation is applied.
preprocess = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD)]
)

In [53]:
# CIFAR-100 train and test splits (downloaded into ./data on first use), both
# passed through the same `preprocess` pipeline. Train is built first, then test.
ds_train, ds_test = (
    datasets.CIFAR100(root='./data', train=is_train, download=True, transform=preprocess)
    for is_train in (True, False)
)

Files already downloaded and verified
Files already downloaded and verified


In [54]:
def _label_key(value):
    """Turn 0-d tensors / NumPy scalars into plain Python values so they hash and compare by value."""
    return value.item() if getattr(value, "ndim", None) == 0 else value


def instances_from_classes(dataset, class_order):
    """Return the indices of all samples in ``dataset`` whose label is in ``class_order``.

    Args:
        dataset: map-style dataset whose items are ``(x, label)`` pairs.
            If it exposes a ``targets`` sequence of the same length and has no
            ``target_transform`` (as torchvision's CIFAR datasets do), labels are
            read from ``targets`` directly. This avoids loading and transforming
            every image just to inspect its label. Otherwise every item is indexed.
        class_order: iterable of class ids to keep (list, NumPy array or torch tensor).

    Returns:
        List of int indices in ascending order (empty if nothing matches).
    """
    # A hashed set gives O(1) membership instead of an O(k) tensor scan per sample.
    wanted = {_label_key(c) for c in class_order}

    targets = getattr(dataset, "targets", None)
    use_targets = (
        targets is not None
        # With a target_transform, dataset[i] would return transformed labels
        # that differ from the raw `targets`, so fall back to indexing.
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    )
    if use_targets:
        labels = targets
    else:
        labels = (dataset[i][1] for i in range(len(dataset)))
    return [i for i, label in enumerate(labels) if _label_key(label) in wanted]


In [55]:
# Indices of every train/test sample whose label is among the first 50 classes
# of the cached permutation. The test indices are kept as a NumPy array because
# the next cell inspects `set_test.shape`.
set_train, set_test = (
    instances_from_classes(ds_train, class_order[:50]),
    np.array(instances_from_classes(ds_test, class_order[:50])),
)

In [56]:
# Sanity check: number of test samples in the 50 selected classes (5000 shown below, i.e. 100 per class).
set_test.shape

(5000,)

In [57]:
# Fixed-size random subsets of the selected train/test indices. The test subset
# backs the periodic evaluation (`dl_sub_test`) during training.
# A dedicated seeded generator makes the subsets reproducible across kernel
# restarts; the original unseeded global np.random state did not.
SUBSAMPLE_SIZE = 300
SUBSAMPLE_SEED = 0
_subsample_rng = np.random.default_rng(SUBSAMPLE_SEED)
sub_sample_train = sorted(_subsample_rng.permutation(set_train)[:SUBSAMPLE_SIZE])
sub_sample_test = sorted(_subsample_rng.permutation(set_test)[:SUBSAMPLE_SIZE])

In [58]:
def _subset_loader(dataset, indices, batch_size=64):
    """DataLoader over ``dataset`` restricted to ``indices``, drawn in a fresh random order each pass."""
    sampler = torch.utils.data.SubsetRandomSampler(indices)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# Full first-task splits, plus the small subsets used for quick evaluation.
dl_train = _subset_loader(ds_train, set_train)
dl_test = _subset_loader(ds_test, set_test)
dl_sub_train = _subset_loader(ds_train, sub_sample_train)
dl_sub_test = _subset_loader(ds_test, sub_sample_test)

In [59]:
# Pull one training batch to sanity-check tensor shapes and label values.
x,y = next(iter(dl_train))

In [60]:
# One label per sample in the batch (batch_size=64).
y.shape

torch.Size([64])

In [61]:
# Every label shown should belong to class_order[:50].
y

tensor([20, 94,  5, 31, 56, 20, 15, 78,  0, 94, 29, 65, 30, 31, 23, 30, 40, 73,
        87, 90, 68, 62, 68, 24, 34, 31, 82, 56, 94, 78, 47, 91, 99, 43, 44, 47,
        33,  6, 56, 31, 43, 32, 98, 44, 24, 98, 28, 28, 61, 68, 57, 57,  2, 65,
        31, 22, 91, 94, 24, 68, 43, 68, 97, 91])

In [62]:

# Standard multi-class cross-entropy on raw logits, averaged over the batch.
loss_func = nn.CrossEntropyLoss(reduction='mean')


In [63]:
!pip install wandb

Defaulting to user installation because normal site-packages is not writeable
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.2[0m[39;49m -> [0m[32;49m24.0[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m


In [64]:
import wandb

In [65]:
# Open a W&B run in the "clip+dmil" project; the training loop logs to it.
# NOTE(review): the summary printed under this cell (accuracy 0.59) does not
# match this notebook's final results and presumably belongs to a previously
# active run.
wandb.init(project="clip+dmil")

0,1
accuracy,▁▆█▇▇▇████
f1,▁▇█▇█▇█▇██
loss_train,█▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
precision,▁▆█▇▇▆▇▇▇█
recall,▁▆█▇▇▇████

0,1
accuracy,0.59
f1,0.58137
loss_train,0.00034
precision,0.61196
recall,0.59856


In [66]:

# Adam over all model parameters with base learning rate 1e-3 (decayed by `scheduler`).
opt = optim.Adam(params=model.parameters(), lr=1e-3)

In [67]:
# Multiply the learning rate by 0.1 once the scheduler has been stepped 45 and
# then 90 times (the training loop steps it once per epoch).
scheduler = optim.lr_scheduler.MultiStepLR(optimizer=opt, milestones=[45, 90], gamma=0.1)

In [68]:
# Move parameters and buffers to `device`; nn.Module.to returns the module itself, hence the repr below.
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (ca): ChannelAttention(
        (avg_pool): AdaptiveAvgPool2d(output_size=1)
        (max_pool): AdaptiveMaxPool2d(output_size=1)
        (fc1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (relu1): ReLU()
        (fc2

In [69]:
from matplotlib import cm
import matplotlib.pyplot as plt
def tsne_proj(batch_emb, batch_y):
    """Scatter-plot a 2-D t-SNE projection of a batch of embeddings, coloured by label.

    Args:
        batch_emb: per-sample embeddings, presumably shaped (N, D). Either a torch
            tensor (any device) or array-like.
        batch_y: length-N class labels. Either a torch tensor (any device) or array-like.

    Shows the figure with ``plt.show()`` and returns None.
    """
    # Bring both inputs to host NumPy arrays; a CUDA label tensor cannot index NumPy data.
    emb = batch_emb.detach().cpu().numpy() if torch.is_tensor(batch_emb) else np.asarray(batch_emb)
    labels = batch_y.detach().cpu().numpy() if torch.is_tensor(batch_y) else np.asarray(batch_y)

    # sklearn requires perplexity < n_samples; the default (30) fails on small batches.
    n_samples = emb.shape[0]
    tsne = TSNE(n_components=2, perplexity=min(30.0, max(n_samples - 1, 1)), verbose=1)
    projected = tsne.fit_transform(emb)

    present = np.unique(labels)
    # Indexing 'tab10' by raw label maps every label >= 9 to the same colour.
    # Use tab10 when it has enough entries; otherwise sample a continuous
    # colormap once per label present in the batch.
    # (matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap remains.)
    if len(present) <= 10:
        colours = plt.get_cmap('tab10').colors[:len(present)]
    else:
        cmap = plt.get_cmap('nipy_spectral')
        colours = [cmap(i / (len(present) - 1)) for i in range(len(present))]

    fig, ax = plt.subplots(figsize=(15, 15))
    for lab, colour in zip(present, colours):
        mask = labels == lab
        ax.scatter(projected[mask, 0], projected[mask, 1], color=colour, label=str(lab), alpha=0.5)
    ax.legend(fontsize='large', markerscale=2)
    plt.show()

In [70]:
from sklearn.metrics import classification_report

In [71]:
def eval_model(model, dl_sub_test, device=None):
    """Run ``model`` over a loader and collect argmax predictions and true labels.

    Args:
        model: classifier whose output is assumed to be per-class scores shaped
            (batch, n_classes); predictions are the argmax over dim 1.
        dl_sub_test: iterable of ``(x, y)`` batches, e.g. a DataLoader.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter, or CPU if the model has no parameters.

    Returns:
        ``(all_preds, all_labels)``: two 1-D tensors on ``device``. Both are
        empty (dtype long) if the loader yields no batches.

    The model's train/eval mode on entry is restored on exit, even if a batch
    raises.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    all_preds, all_labels = [], []
    try:
        with torch.no_grad():
            for x, y in dl_sub_test:
                x = x.to(device)
                y = y.to(device)
                logits = model(x)
                all_preds.append(logits.argmax(dim=1))
                all_labels.append(y)
    finally:
        # Restore, rather than force, the caller's mode.
        model.train(was_training)

    if not all_preds:
        # torch.cat([]) raises; return empty results instead.
        empty = torch.empty(0, dtype=torch.long, device=device)
        return empty, empty.clone()
    return torch.cat(all_preds), torch.cat(all_labels)

In [72]:
# Baseline predictions on the 300-sample test subset, taken before the training cell (75) runs.
pred,labels = eval_model(model,dl_sub_test)

In [73]:
from sklearn import metrics

In [74]:
# Macro-averaged metrics for the baseline predictions above.
# Labels and predictions are copied to the CPU once for sklearn.
# zero_division=0 is the value sklearn already substitutes (with a warning) for
# classes that have no predicted or no true samples. Passing it explicitly keeps
# the same numbers and drops the `_warn_prf` noise.
y_true, y_pred = labels.cpu(), pred.cpu()
f1 = metrics.f1_score(y_true, y_pred, average='macro', zero_division=0)
precision = metrics.precision_score(y_true, y_pred, average='macro', zero_division=0)
recall = metrics.recall_score(y_true, y_pred, average='macro', zero_division=0)
accuracy = metrics.accuracy_score(y_true, y_pred)



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [75]:
# Train for N_EPOCHS with cross-entropy. Every EVAL_EVERY epochs the model is
# scored on the 300-sample test subset (`dl_sub_test`). Each epoch's mean
# training loss is appended to `loss_train` and logged to W&B.
# NOTE: the LR milestones (45, 90) in `scheduler` assume N_EPOCHS=100.
N_EPOCHS = 100
EVAL_EVERY = 10

model.to(device)
loss_func.to(device)
loss_train = []
for epoch in range(N_EPOCHS):
    model.train()
    epoch_losses = []
    for data, labels in tqdm(dl_train, desc=f"epoch {epoch}"):
        data = data.to(device)
        labels = labels.to(device)
        opt.zero_grad()
        pred = model(data)
        loss = loss_func(pred, labels)
        loss.backward()
        epoch_losses.append(loss.item())
        opt.step()
    scheduler.step()

    log_row = {}
    if epoch % EVAL_EVERY == 0:
        # Distinct names so the evaluation does not overwrite the batch variables above.
        sub_pred, sub_labels = eval_model(model, dl_sub_test)
        y_true, y_pred = sub_labels.cpu(), sub_pred.cpu()
        # zero_division=0 matches sklearn's default fallback, minus the warning spam.
        log_row.update(
            precision=metrics.precision_score(y_true, y_pred, average='macro', zero_division=0),
            recall=metrics.recall_score(y_true, y_pred, average='macro', zero_division=0),
            f1=metrics.f1_score(y_true, y_pred, average='macro', zero_division=0),
            accuracy=metrics.accuracy_score(y_true, y_pred),
        )
    loss_train.append(np.mean(epoch_losses))
    print(f"loss {loss_train[-1]}")
    log_row['loss_train'] = loss_train[-1]
    # A single log call per epoch keeps all of that epoch's values in one W&B row.
    wandb.log(log_row, step=epoch)

100%|██████████| 391/391 [00:16<00:00, 23.67it/s]
  _warn_prf(average, modifier, msg_start, len(result))


loss 3.568660565959218


100%|██████████| 391/391 [00:16<00:00, 24.06it/s]


loss 2.4020654865542945


100%|██████████| 391/391 [00:16<00:00, 23.83it/s]


loss 1.8282806525754807


100%|██████████| 391/391 [00:16<00:00, 23.75it/s]


loss 1.4581588450295235


100%|██████████| 391/391 [00:16<00:00, 23.78it/s]


loss 1.1785588737034127


100%|██████████| 391/391 [00:16<00:00, 23.05it/s]


loss 0.9363632220441424


100%|██████████| 391/391 [00:16<00:00, 23.01it/s]


loss 0.6912650209863472


100%|██████████| 391/391 [00:17<00:00, 22.84it/s]


loss 0.4872148690549919


100%|██████████| 391/391 [00:16<00:00, 23.06it/s]


loss 0.3055298051337147


100%|██████████| 391/391 [00:17<00:00, 22.91it/s]


loss 0.1966261644192669


100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


loss 0.14391837332902663


100%|██████████| 391/391 [00:16<00:00, 23.01it/s]


loss 0.12587034539855502


100%|██████████| 391/391 [00:16<00:00, 23.01it/s]


loss 0.09581386598536883


100%|██████████| 391/391 [00:17<00:00, 22.90it/s]


loss 0.09406608606562437


100%|██████████| 391/391 [00:17<00:00, 22.63it/s]


loss 0.08738058819280714


100%|██████████| 391/391 [00:17<00:00, 22.92it/s]


loss 0.0927869243037594


100%|██████████| 391/391 [00:17<00:00, 22.86it/s]


loss 0.07077688477752382


100%|██████████| 391/391 [00:16<00:00, 23.00it/s]


loss 0.06377080012389157


100%|██████████| 391/391 [00:16<00:00, 23.01it/s]


loss 0.062082323614898544


100%|██████████| 391/391 [00:16<00:00, 23.02it/s]


loss 0.049125740483231706


100%|██████████| 391/391 [00:16<00:00, 23.03it/s]


loss 0.04812006255054413


100%|██████████| 391/391 [00:17<00:00, 22.77it/s]


loss 0.06680742537131167


100%|██████████| 391/391 [00:16<00:00, 23.05it/s]


loss 0.050783618450965115


100%|██████████| 391/391 [00:16<00:00, 23.01it/s]


loss 0.0410077676267537


100%|██████████| 391/391 [00:16<00:00, 23.08it/s]


loss 0.037397687452013995


100%|██████████| 391/391 [00:16<00:00, 23.05it/s]


loss 0.04393505151657497


100%|██████████| 391/391 [00:17<00:00, 22.89it/s]


loss 0.04548772183053977


100%|██████████| 391/391 [00:17<00:00, 22.83it/s]


loss 0.033671332458319984


100%|██████████| 391/391 [00:16<00:00, 23.04it/s]


loss 0.023483879948177796


100%|██████████| 391/391 [00:17<00:00, 22.98it/s]


loss 0.049222882029593296


100%|██████████| 391/391 [00:16<00:00, 23.01it/s]


loss 0.034715530729335746


100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


loss 0.01930413347052034


100%|██████████| 391/391 [00:17<00:00, 22.87it/s]


loss 0.023987615642392688


100%|██████████| 391/391 [00:17<00:00, 22.93it/s]


loss 0.04843903108276046


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


loss 0.03366203777625433


100%|██████████| 391/391 [00:17<00:00, 22.98it/s]


loss 0.020371860002472406


100%|██████████| 391/391 [00:16<00:00, 23.09it/s]


loss 0.016972800621779544


100%|██████████| 391/391 [00:17<00:00, 22.84it/s]


loss 0.020530897842915466


100%|██████████| 391/391 [00:16<00:00, 23.05it/s]


loss 0.015427839721950805


100%|██████████| 391/391 [00:16<00:00, 23.00it/s]


loss 0.041443968413855


100%|██████████| 391/391 [00:16<00:00, 23.04it/s]


loss 0.027502655881740477


100%|██████████| 391/391 [00:16<00:00, 23.08it/s]


loss 0.01961502062973788


100%|██████████| 391/391 [00:16<00:00, 23.07it/s]


loss 0.018268376191341987


100%|██████████| 391/391 [00:17<00:00, 22.79it/s]


loss 0.02118684638701284


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


loss 0.023723137568232255


100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


loss 0.014593762740709573


100%|██████████| 391/391 [00:17<00:00, 23.00it/s]


loss 0.006164901230521167


100%|██████████| 391/391 [00:16<00:00, 23.02it/s]


loss 0.004308426093868078


100%|██████████| 391/391 [00:17<00:00, 22.98it/s]


loss 0.003405772395791662


100%|██████████| 391/391 [00:17<00:00, 22.93it/s]


loss 0.0027007067708841636


100%|██████████| 391/391 [00:17<00:00, 22.78it/s]
  _warn_prf(average, modifier, msg_start, len(result))


loss 0.0023472137788496435


100%|██████████| 391/391 [00:17<00:00, 22.98it/s]


loss 0.0018835831128353791


100%|██████████| 391/391 [00:17<00:00, 23.00it/s]


loss 0.001810732534226111


100%|██████████| 391/391 [00:17<00:00, 22.85it/s]


loss 0.0016130763146614828


100%|██████████| 391/391 [00:17<00:00, 22.99it/s]


loss 0.0014663738882406722


100%|██████████| 391/391 [00:17<00:00, 22.98it/s]


loss 0.001427437123039242


100%|██████████| 391/391 [00:17<00:00, 22.71it/s]


loss 0.001356918129104111


100%|██████████| 391/391 [00:17<00:00, 22.89it/s]


loss 0.001101463505238428


100%|██████████| 391/391 [00:16<00:00, 23.02it/s]


loss 0.001360349662271757


100%|██████████| 391/391 [00:17<00:00, 22.93it/s]


loss 0.0011649115154084386


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]
  _warn_prf(average, modifier, msg_start, len(result))


loss 0.0010041763236814077


100%|██████████| 391/391 [00:17<00:00, 22.70it/s]


loss 0.0009433397416384471


100%|██████████| 391/391 [00:16<00:00, 23.00it/s]


loss 0.001089082404640932


100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


loss 0.0009467902323322655


100%|██████████| 391/391 [00:17<00:00, 22.78it/s]


loss 0.0012065145125057485


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


loss 0.0008661783052064404


100%|██████████| 391/391 [00:17<00:00, 22.72it/s]


loss 0.0006482906280914048


100%|██████████| 391/391 [00:17<00:00, 22.88it/s]


loss 0.0007396837004714662


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


loss 0.0007257903585351272


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


loss 0.0007963967928633122


100%|██████████| 391/391 [00:17<00:00, 22.86it/s]


loss 0.0006891657817935216


100%|██████████| 391/391 [00:17<00:00, 22.97it/s]


loss 0.000558114643363506


100%|██████████| 391/391 [00:17<00:00, 22.75it/s]


loss 0.000501317435198108


100%|██████████| 391/391 [00:17<00:00, 22.93it/s]


loss 0.0005853643251100348


100%|██████████| 391/391 [00:17<00:00, 22.97it/s]


loss 0.0005418699680269449


100%|██████████| 391/391 [00:17<00:00, 22.94it/s]


loss 0.0006959800252157515


100%|██████████| 391/391 [00:17<00:00, 22.90it/s]


loss 0.000519693784616769


100%|██████████| 391/391 [00:17<00:00, 22.78it/s]


loss 0.0005387465607131059


100%|██████████| 391/391 [00:17<00:00, 22.77it/s]


loss 0.0006894892267992615


100%|██████████| 391/391 [00:17<00:00, 22.95it/s]


loss 0.0005100340694230755


100%|██████████| 391/391 [00:17<00:00, 22.90it/s]
  _warn_prf(average, modifier, msg_start, len(result))


loss 0.0003408014255812835


100%|██████████| 391/391 [00:17<00:00, 22.89it/s]


loss 0.0004662373354277753


100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


loss 0.0008722621491160634


100%|██████████| 391/391 [00:16<00:00, 23.03it/s]


loss 0.00028366782400618477


100%|██████████| 391/391 [00:17<00:00, 22.78it/s]


loss 0.0004015527691062325


100%|██████████| 391/391 [00:17<00:00, 22.81it/s]


loss 0.0005529725643535368


100%|██████████| 391/391 [00:17<00:00, 22.59it/s]


loss 0.00037697195887903247


100%|██████████| 391/391 [00:17<00:00, 22.61it/s]


loss 0.0005086646037399331


100%|██████████| 391/391 [00:17<00:00, 22.65it/s]


loss 0.0003079691246534182


100%|██████████| 391/391 [00:17<00:00, 22.53it/s]


loss 0.0004894381586161247


100%|██████████| 391/391 [00:17<00:00, 22.39it/s]


loss 0.0003772863322170972


100%|██████████| 391/391 [00:17<00:00, 22.74it/s]


loss 0.0002687504069366873


100%|██████████| 391/391 [00:17<00:00, 22.71it/s]


loss 0.00033642052497566334


100%|██████████| 391/391 [00:17<00:00, 22.96it/s]


loss 0.00028492217346088894


100%|██████████| 391/391 [00:16<00:00, 23.17it/s]


loss 0.00016225199471912622


100%|██████████| 391/391 [00:16<00:00, 23.63it/s]


loss 0.00032890894775116004


100%|██████████| 391/391 [00:16<00:00, 23.90it/s]


loss 0.0003169662148748172


100%|██████████| 391/391 [00:16<00:00, 24.06it/s]


loss 0.0002885019965575906


100%|██████████| 391/391 [00:16<00:00, 23.91it/s]


loss 0.0002656542985389717


100%|██████████| 391/391 [00:16<00:00, 23.77it/s]

loss 0.0003544157526190349





In [76]:
# Final evaluation on the complete first-task test split (`dl_test`), not the
# 300-sample subset used during training.
pred, labels = eval_model(model, dl_test)
# Copy to the CPU once for sklearn. zero_division=0 is sklearn's default
# fallback for classes with no predicted/true samples, made explicit to avoid warnings.
y_true, y_pred = labels.cpu(), pred.cpu()
f1 = metrics.f1_score(y_true, y_pred, average='macro', zero_division=0)
precision = metrics.precision_score(y_true, y_pred, average='macro', zero_division=0)
recall = metrics.recall_score(y_true, y_pred, average='macro', zero_division=0)
accuracy = metrics.accuracy_score(y_true, y_pred)
# NOTE(review): these full-test metrics go under the same keys as the subset
# metrics logged during training, so one W&B chart mixes both. Consider a
# distinct prefix such as 'test/'. `epoch` is the last value left by the training loop.
wandb.log({'precision': precision, 'recall': recall, 'f1': f1, 'accuracy': accuracy}, step=epoch)

In [77]:
# Print the final full-test metrics, one "<name> <value>" pair per line.
for metric_name, metric_value in (
    ("precision", precision),
    ("recall", recall),
    ("f1", f1),
    ("accuracy", accuracy),
):
    print(f"{metric_name} {metric_value}")

precision 0.6973167424506321
recall 0.6958
f1 0.6944480520473264
accuracy 0.6958


In [39]:
# Save the model weights, the class permutation and the per-epoch training losses in one file.
# NOTE(review): this cell's execution count (39) is lower than the training cell's (75),
# so the checkpoint on disk may predate this training run. Re-run it after training
# to be sure it holds the trained weights.
torch.save({'rn18':model.state_dict(),'class_order':class_order,'loss_train':loss_train},'rn18_pass_100epochs_cross_entropy.pth')