In [1]:
from typing import Any, OrderedDict, List, Dict
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.models.resnet import resnet50
from torchvision.transforms import v2 as T
import matplotlib.pyplot as plt
import json
from torchmetrics import Accuracy
from lightly.transforms.utils import IMAGENET_NORMALIZE
from PIL import Image

## Load Data

In [2]:
class ImageDataset(Dataset):
    def __init__(self, set_map: List[Dict], transform=None) -> None:
        ''' Each item in set_map is expected to contain:
                img_path: Full path to image,
                label: Label corresponding to image at img_path
        '''

        self.set_map = set_map
        self.transform=transform

    def __len__(self):
        return len(self.set_map)
    
    def __getitem__(self, index):   
        sample = self.set_map[index]

        # image = read_image(sample['img_path'])
        image = Image.open(sample['img_path'])

        if self.transform:
            image = self.transform(image)

        return dict(image=image, **sample)

In [3]:
with open('./data/train.json', 'r') as file:
    train_set_map = json.load(file)

In [4]:
with open('./data/test.json', 'r') as file:
    test_set_map = json.load(file)

In [5]:
label_map = {
    'DGG': 0,
    'PH': 1,
    'EH': 2
}

domain_map = {
    'cartoon': 0,
    'art_painting': 1,
    'photo': 2,
}

In [6]:
for elem in train_set_map:
    elem['label'] = label_map[elem['label']]
for elem in test_set_map:
    elem['label'] = label_map[elem['label']]

In [7]:
transform = T.Compose([
    T.Resize(96),
    T.ToTensor(),
    T.Normalize(
        mean=IMAGENET_NORMALIZE["mean"],
        std=IMAGENET_NORMALIZE["std"],
    ),
])



In [8]:
train_set = ImageDataset(train_set_map, transform=transform)
test_set = ImageDataset(test_set_map, transform=transform)

In [9]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=4, persistent_workers=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4, persistent_workers=True, pin_memory=True)

## Validate Data

In [10]:
stats = {}
for elem in train_set_map:
    domain = elem['domain']
    label = elem['label']

    if domain not in stats.keys() :
        stats[domain] = {}

    if label not in stats[domain].keys():
        stats[domain][label] = 1
    else:
        stats[domain][label] += 1

In [11]:
stats

{'art_painting': {0: 771, 2: 77, 1: 77},
 'photo': {1: 641, 2: 64, 0: 64},
 'cartoon': {2: 696, 1: 70, 0: 70}}

In [12]:
stats = {}
for elem in test_set_map:
    domain = elem['domain']
    label = elem['label']

    if domain not in stats.keys() :
        stats[domain] = {}

    if label not in stats[domain].keys():
        stats[domain][label] = 1
    else:
        stats[domain][label] += 1

In [13]:
stats

{'cartoon': {2: 85, 1: 72, 0: 82},
 'art_painting': {1: 75, 0: 77, 2: 41},
 'photo': {1: 71, 2: 46, 0: 58}}

## Train Models

In [14]:
import pytorch_lightning as L

In [15]:
from typing import Any


class LinearProbe(L.LightningModule):
    def __init__(self, backbone, emb_dim, num_classes, lr=1e-2) -> None:
        super().__init__()
        
        self.backbone: nn.Module = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.linear_head: nn.Module = nn.Linear(emb_dim, num_classes)
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.linear_head(x)

        return x

    def training_step(self, batch, bacth_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        loss = self.criterion(y, t)

        print(loss.item())

        return loss
    
    def test_step(self, batch, batch_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        acc = self.test_accuracy(y, t)

        self.log('accuracy', acc, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.linear_head.parameters(), lr=self.lr)

In [16]:
def get_backbone_from_ckpt(ckpt_path: str) -> torch.nn.Module:
    state_dict = torch.load(ckpt_path)["state_dict"]
    state_dict = OrderedDict([
        (".".join(name.split(".")[1:]), param) for name, param in state_dict.items() if name.startswith("backbone")
    ])

    return state_dict

In [17]:
# Baseline Model
model_bl  = resnet50()
weights_bl = get_backbone_from_ckpt("./r50_bt.ckpt")
model_bl.load_state_dict(weights_bl, strict=False)
model_bl.fc = torch.nn.Identity()
model_bl = model_bl.cuda()

In [18]:
# MixStyle Model
model_ms  = resnet50()
weights_ms = get_backbone_from_ckpt("./r50_bt_ms.ckpt")
model_ms.load_state_dict(weights_ms, strict=False)
model_ms.fc = torch.nn.Identity()
model_ms = model_ms.cuda()

In [19]:
baseline_module = LinearProbe(model_bl, emb_dim=2048, num_classes=3)

In [20]:
mixstyle_module = LinearProbe(model_ms, emb_dim=2048, num_classes=3)

In [21]:
trainer = L.Trainer(
    max_epochs=100
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [22]:
trainer.fit(baseline_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Linear             | 6.1 K 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
6.1 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.057    Total estimated model params size (MB)
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

1.175296425819397
0.31649571657180786
0.3100240230560303
0.4458526074886322
0.37306973338127136
0.23230485618114471
0.25868722796440125
0.12730392813682556
0.2822354733943939
0.24386008083820343
0.14292922616004944
0.12474647909402847
0.18661458790302277
0.2298240065574646
0.28033334016799927
0.090495266020298
0.2454240918159485
0.24327369034290314
0.31870877742767334
0.18903718888759613
0.07631593942642212
0.10033032298088074
0.06573393195867538
0.18748188018798828
0.1052197515964508
0.10693855583667755
0.16194495558738708
0.1376449465751648
0.1421748548746109
0.13850520551204681
0.13192619383335114
0.05953967571258545
0.03717106208205223
0.1296527087688446
0.07481472939252853
0.055209096521139145
0.07285809516906738
0.07221223413944244
0.08966465294361115
0.07240042090415955
0.06396527588367462
0.05232544615864754
0.05549689382314682
0.061534661799669266
0.043995797634124756
0.06096421927213669
0.043296556919813156
0.08755286037921906
0.03782903775572777
0.049575433135032654
0.045216

`Trainer.fit` stopped: `max_epochs=100` reached.


0.005249543581157923


In [23]:
trainer.test(baseline_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing: |          | 0/? [00:00<?, ?it/s]

/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 256. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 226. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'accuracy': 0.999209463596344}]

In [24]:
trainer.test(baseline_module, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 95. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'accuracy': 0.8698517084121704}]

In [25]:
trainer = L.Trainer(
    max_epochs=100
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [26]:
trainer.fit(mixstyle_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Linear             | 6.1 K 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
6.1 K     Trainable params
23.5 M    Non-trainable params
23.5 M    Total params
94.057    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

1.0774705410003662
0.3226809799671173
0.29349106550216675
0.21751800179481506
0.13583745062351227
0.27827849984169006
0.325274795293808
0.2226797640323639
0.32344603538513184
0.16000759601593018
0.09119223803281784
0.25491365790367126
0.18937762081623077
0.16831061244010925
0.29702064394950867
0.1677790880203247
0.22914467751979828
0.20766809582710266
0.15974678099155426
0.19694367051124573
0.16692513227462769
0.1833813637495041
0.0689651295542717
0.07936429232358932
0.1593993902206421
0.12070804834365845
0.14030073583126068
0.12539277970790863
0.12420614063739777
0.15066400170326233
0.15351879596710205
0.13898831605911255
0.05832604691386223
0.05888712406158447
0.10706485062837601
0.09885002672672272
0.08950735628604889
0.10407847911119461
0.14243493974208832
0.07994271069765091
0.05500555783510208
0.05675293132662773
0.050482481718063354
0.07092445343732834
0.08632172644138336
0.11938898265361786
0.08508981764316559
0.08793926984071732
0.07573039829730988
0.08525687456130981
0.082909

`Trainer.fit` stopped: `max_epochs=100` reached.


0.002428445965051651


In [27]:
trainer.test(mixstyle_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 1.0}]

In [28]:
trainer.test(mixstyle_module, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.8830313086509705}]