In [1]:
from typing import Any, OrderedDict, List, Dict
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import resnet50
from torchvision.transforms import v2 as T
import matplotlib.pyplot as plt
import json
from torchmetrics import Accuracy
from lightly.transforms.utils import IMAGENET_NORMALIZE
from PIL import Image
from camelyon_ds import Camelyon17Dataset

## Load Data

In [2]:
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=IMAGENET_NORMALIZE["mean"],
        std=IMAGENET_NORMALIZE["std"],
    ),
])



In [3]:
train_set = Camelyon17Dataset('/data', train=True, transform=transform)
test_set = Camelyon17Dataset('/data', train=False, transform=transform)

len(train_set), len(test_set)

(22000, 2000)

In [4]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4, persistent_workers=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4, persistent_workers=True, pin_memory=True)

## Train Models

In [5]:
import pytorch_lightning as L

In [6]:
from typing import Any


class LinearProbe(L.LightningModule):
    def __init__(self, backbone, emb_dim, num_classes, lr=1e-2) -> None:
        super().__init__()
        
        self.backbone: nn.Module = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.linear_head: nn.Module = nn.Linear(emb_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.linear_head(x)

        return x

    def training_step(self, batch, bacth_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        loss = self.criterion(y, t)

        print(loss.item())

        return loss
    
    def test_step(self, batch, batch_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        acc = self.test_accuracy(y, t)

        self.log('accuracy', acc, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.linear_head.parameters(), lr=self.lr)

In [7]:
def get_backbone_from_ckpt(ckpt_path: str) -> torch.nn.Module:
    state_dict = torch.load(ckpt_path)["state_dict"]
    state_dict = OrderedDict([
        (".".join(name.split(".")[1:]), param) for name, param in state_dict.items() if name.startswith("backbone")
    ])

    return state_dict

In [8]:
# Baseline Model
model_bl  = resnet50()
weights_bl = get_backbone_from_ckpt("./data/r50_camelyon_bl.ckpt")
model_bl.load_state_dict(weights_bl, strict=False)
model_bl.fc = torch.nn.Identity()
model_bl = model_bl.cuda()

In [9]:
# DANN Model
model_ms  = resnet50()
weights_ms = get_backbone_from_ckpt("./data/r50_camelyon_dann.ckpt")
model_ms.load_state_dict(weights_ms, strict=False)
model_ms.fc = torch.nn.Identity()
model_ms = model_ms.cuda()

In [10]:
baseline_module = LinearProbe(model_bl, emb_dim=2048, num_classes=2, lr=3e-3)

In [11]:
mixstyle_module = LinearProbe(model_ms, emb_dim=2048, num_classes=2, lr=3e-3)

In [12]:
trainer = L.Trainer(
    max_epochs=20
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
trainer.fit(baseline_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Linear             | 4.1 K 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
4.1 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.049    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

0.7748351097106934
0.3930361866950989
0.36561936140060425
0.1539621651172638
0.13090857863426208
0.1741265058517456
0.12321457266807556
0.09465347230434418
0.06610023975372314
0.0831676796078682
0.06008481606841087
0.07234650105237961
0.07587406039237976
0.12214092910289764
0.06472671031951904
0.06907260417938232
0.06879724562168121
0.08123808354139328
0.08457714319229126
0.05592609569430351
0.07577507197856903
0.1053270474076271
0.01542244665324688
0.08638010174036026
0.07204083353281021
0.051907751709222794
0.06806445121765137
0.06119644641876221
0.05219480022788048
0.09589998424053192
0.08551240712404251
0.018202565610408783
0.06939946860074997
0.08605039864778519
0.04719291627407074
0.03927775099873543
0.05886506289243698
0.08311272412538528
0.10904814302921295
0.043525949120521545
0.06388027220964432
0.028819402679800987
0.03666999563574791
0.03820370137691498
0.04177453741431236
0.03935424983501434
0.043311819434165955
0.06282440572977066
0.051638808101415634
0.03690973296761513


`Trainer.fit` stopped: `max_epochs=20` reached.


In [14]:
trainer.test(baseline_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.9968181848526001}]

In [15]:
trainer.test(baseline_module, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.968000054359436}]

In [16]:
trainer = L.Trainer(
    max_epochs=20
)

GPU available: True (cuda), used: True


TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
trainer.fit(mixstyle_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Linear             | 4.1 K 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
4.1 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.049    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

0.7968271374702454
5.9435858726501465
3.664652109146118
0.3239632844924927
2.627224922180176
1.973006010055542
0.27939552068710327
0.2512572407722473
0.7290131449699402
0.9662607908248901
0.8116952776908875
0.25051185488700867
0.19135038554668427
0.1785527616739273
0.16405411064624786
0.2068192958831787
0.3124607801437378
0.2758455276489258
0.1924758106470108
0.110450878739357
0.0657539963722229
0.094569131731987
0.09154082089662552
0.17134159803390503
0.10733956098556519
0.19085811078548431
0.19814352691173553
0.058545760810375214
0.0803857371211052
0.05228205770254135
0.22349348664283752
0.046831805258989334
0.09646009653806686
0.10329664498567581
0.09272825717926025
0.14019916951656342
0.1402517557144165
0.26475685834884644
0.1256134957075119
0.15644535422325134
0.15390998125076294
0.2732483148574829
0.1186642050743103
0.160844624042511
0.1635025292634964
0.17944683134555817
0.0681806206703186
0.0859154462814331
0.07882576435804367
0.176054447889328
0.18240563571453094
0.10918153822

`Trainer.fit` stopped: `max_epochs=20` reached.


In [18]:
trainer.test(mixstyle_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.993363618850708}]

In [19]:
trainer.test(mixstyle_module, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.9675000309944153}]