In [1]:
import os
import collections
import random
import json

from data_modules.domainnet_dataset import ImageDataset
from data_modules.domainnet_metadata import DOMAIN_NET_DOMAINS, DOMAIN_NET_CLASSES, DOMAIN_NET_DIVISIONS

from torch.utils.data import DataLoader
from torchvision.transforms import v2 as T

from yasin_utils.image import imagenet_normalize

In [2]:
division_map = {}
for k, v in DOMAIN_NET_DIVISIONS.items():
    for i in v:
        division_map[i] = k

In [3]:
# root = '/data/domainnet_v1.0'
# set_map = []
# for domain in DOMAIN_NET_DOMAINS:
#     try:
#         labels = os.listdir(os.path.join(root, domain))
#     except:
#         raise Exception(f'{domain} directory not found.')
#     for label in labels:
#         for image in os.listdir(os.path.join(root, domain, label)):
#             if label in division_map:
#                 set_map.append(
#                     dict(
#                         img_path=os.path.join(root, domain, label, image),
#                         label=division_map[label],
#                         domain=domain
#                         )
#                     )

In [4]:
# random.shuffle(set_map)

In [5]:
# train_set_map, test_set_map = set_map[:int(0.9*len(set_map))], set_map[int(0.9*len(set_map)):]

In [6]:
with open('/home/yasin/repos/dispatch_smol/pretraining/data_modules/train_set_map_domainnet.json') as f:
    train_set_map_ = json.load(f)
with open('/home/yasin/repos/dispatch_smol/pretraining/data_modules/test_set_map_domainnet.json') as f:
    test_set_map_ = json.load(f)

In [7]:
train_set_map = []
for sample in train_set_map_:
    if sample['label'] in division_map:
        new_sample = {**sample}
        new_sample['label'] = division_map[sample['label']]
        train_set_map.append(new_sample)

In [8]:
test_set_map = []
for sample in test_set_map_:
    if sample['label'] in division_map:
        new_sample = {**sample}
        new_sample['label'] = division_map[sample['label']]
        test_set_map.append(new_sample)

In [9]:
len(train_set_map), len(test_set_map)

(241485, 26659)

In [10]:
dict(collections.Counter([(d['label'], d['domain']) for d in train_set_map]))

{('tool', 'quickdraw'): 12600,
 ('tool', 'painting'): 4618,
 ('electricity', 'real'): 10563,
 ('cloth', 'infograph'): 3370,
 ('tool', 'real'): 11645,
 ('mammal', 'quickdraw'): 11309,
 ('cloth', 'real'): 10397,
 ('tool', 'clipart'): 3441,
 ('mammal', 'sketch'): 4644,
 ('electricity', 'painting'): 2669,
 ('furniture', 'painting'): 4472,
 ('cloth', 'painting'): 3192,
 ('furniture', 'sketch'): 6760,
 ('building', 'quickdraw'): 9479,
 ('furniture', 'real'): 15417,
 ('building', 'real'): 10874,
 ('mammal', 'clipart'): 3096,
 ('furniture', 'quickdraw'): 15745,
 ('mammal', 'infograph'): 3212,
 ('mammal', 'real'): 13996,
 ('electricity', 'infograph'): 3744,
 ('electricity', 'quickdraw'): 11284,
 ('cloth', 'sketch'): 4374,
 ('building', 'painting'): 4610,
 ('mammal', 'painting'): 8114,
 ('electricity', 'clipart'): 2803,
 ('cloth', 'clipart'): 3275,
 ('building', 'sketch'): 4241,
 ('furniture', 'clipart'): 5199,
 ('tool', 'sketch'): 4373,
 ('cloth', 'quickdraw'): 10285,
 ('building', 'clipart'): 

In [11]:
# balance_map = {
#     'furniture': {'main_domain': 'clipart', 'target': 5_253},
#     'cloth': {'main_domain': 'real', 'target': 10_377},
#     'electricity': {'main_domain': 'infograph', 'target': 3_745},
#     'building': {'main_domain': 'painting', 'target': 4_570},
#     'mammal': {'main_domain': 'quickdraw', 'target': 11_202},
#     'tool': {'main_domain': 'sketch', 'target': 4_401},
# }

In [12]:
balance_map = {
    'furniture': {'main_domain': 'clipart', 'target': 1_000},
    'cloth': {'main_domain': 'real', 'target': 1_000},
    'electricity': {'main_domain': 'infograph', 'target': 1_000},
    'building': {'main_domain': 'painting', 'target': 1_000},
    'mammal': {'main_domain': 'quickdraw', 'target': 1_000},
    'tool': {'main_domain': 'sketch', 'target': 1_000},
}

In [13]:
# balance_map = {
#     'furniture': {'main_domain': 'clipart', 'target': 1_000},
#     'cloth': {'main_domain': 'real', 'target': 1_000},
#     'electricity': {'main_domain': 'infograph', 'target': 1_000},
#     'building': {'main_domain': 'painting', 'target': 1_000},
#     'mammal': {'main_domain': 'quickdraw', 'target': 1_000},
#     'tool': {'main_domain': 'sketch', 'target': 1_000},
# }

In [14]:
# balance_map = {
#     'furniture': {'main_domain': 'clipart', 'target': 100},
#     'cloth': {'main_domain': 'real', 'target': 100},
#     'electricity': {'main_domain': 'infograph', 'target': 100},
#     'building': {'main_domain': 'painting', 'target': 100},
#     'mammal': {'main_domain': 'quickdraw', 'target': 100},
#     'tool': {'main_domain': 'sketch', 'target': 100},
# }

In [15]:
def prune(set_map, balance_map):
    pruned_set_map = []
    
    domains = set([i['domain'] for i in set_map])
    counts = {
        l: {
            d: 0 for d in domains
        } for l in balance_map
    }
    targets = {
        l: {
            d: balance_map[l]['target'] if d == balance_map[l]['main_domain'] else int(balance_map[l]['target']*0.01) for d in domains
        } for l in balance_map
    }

    for i in set_map:
        label, domain = i['label'], i['domain']

        if counts[label][domain] < targets[label][domain]:
            pruned_set_map.append(i)

        counts[label][domain] += 1

    return pruned_set_map

In [16]:
train_set_map = prune(train_set_map, balance_map)

In [17]:
dict(collections.Counter([(d['label'], d['domain']) for d in train_set_map]))

{('tool', 'quickdraw'): 10,
 ('tool', 'painting'): 10,
 ('electricity', 'real'): 10,
 ('cloth', 'infograph'): 10,
 ('tool', 'real'): 10,
 ('mammal', 'quickdraw'): 1000,
 ('cloth', 'real'): 1000,
 ('tool', 'clipart'): 10,
 ('mammal', 'sketch'): 10,
 ('electricity', 'painting'): 10,
 ('furniture', 'painting'): 10,
 ('cloth', 'painting'): 10,
 ('furniture', 'sketch'): 10,
 ('building', 'quickdraw'): 10,
 ('furniture', 'real'): 10,
 ('building', 'real'): 10,
 ('mammal', 'clipart'): 10,
 ('furniture', 'quickdraw'): 10,
 ('mammal', 'infograph'): 10,
 ('mammal', 'real'): 10,
 ('electricity', 'infograph'): 1000,
 ('electricity', 'quickdraw'): 10,
 ('cloth', 'sketch'): 10,
 ('building', 'painting'): 1000,
 ('mammal', 'painting'): 10,
 ('electricity', 'clipart'): 10,
 ('cloth', 'clipart'): 10,
 ('building', 'sketch'): 10,
 ('furniture', 'clipart'): 1000,
 ('tool', 'sketch'): 1000,
 ('cloth', 'quickdraw'): 10,
 ('building', 'clipart'): 10,
 ('building', 'infograph'): 10,
 ('furniture', 'infograph

In [18]:
class UnbalancedDomainNetDataset(ImageDataset):
    def __init__(self, set_map, transform=None) -> None:

        super().__init__(set_map, transform)
        self.class_map = {
            'furniture': 0,
            'cloth': 1,
            'electricity': 2,
            'building': 3,
            'mammal': 4,
            'tool': 5,
        }
        self.domain_map = {
            'clipart': 0,
            'real': 1,
            'infograph': 2,
            'painting': 3,
            'quickdraw': 4,
            'sketch': 5,
        }

    def __getitem__(self, index):
        item = super().__getitem__(index)
        item['label'] = self.class_map[item['label']]
        item['domain'] = self.domain_map[item['domain']]

        return item

In [19]:
train_transform = T.Compose([
    T.RandomResizedCrop(128),
    T.RandomHorizontalFlip(),
    # T.Resize((128,128)),
    T.ToTensor(),
    imagenet_normalize,
])
val_transform = T.Compose([
    T.Resize(156),
    T.CenterCrop(128),
    # T.Resize((128,128)),
    T.ToTensor(),
    imagenet_normalize
])



In [20]:
train_set = UnbalancedDomainNetDataset(set_map=train_set_map, transform=train_transform)
test_set = UnbalancedDomainNetDataset(set_map=test_set_map, transform=val_transform)

In [21]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True, drop_last=True)
val_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True)

In [22]:
len(train_set), len(test_set)

(6300, 26659)

## Train Models

In [23]:
from typing import OrderedDict
import torch
from torch import nn
from torchvision.models.resnet import resnet50, ResNet50_Weights
import pytorch_lightning as L
from torchmetrics import Accuracy

In [24]:
class LinearProbe(L.LightningModule):
    def __init__(self, backbone, emb_dim, num_classes, lr=1e-3) -> None:
        super().__init__()
        
        self.backbone: nn.Module = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.backbone.eval()

        # self.linear_head: nn.Module = nn.Linear(emb_dim, num_classes)
        self.linear_head: nn.Module = nn.Sequential(
            nn.Linear(emb_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.Linear(512, num_classes)
        ) 
        
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.linear_head(x)

        return x

    def training_step(self, batch, bacth_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        loss = self.criterion(y, t)

        print(loss.item())

        return loss
    
    def test_step(self, batch, batch_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        acc = self.test_accuracy(y, t)

        self.log('accuracy', acc, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.linear_head.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 20)

        return [optimizer], [scheduler]

In [25]:
def get_backbone_from_ckpt(ckpt_path: str) -> torch.nn.Module:
    state_dict = torch.load(ckpt_path)["state_dict"]
    state_dict = OrderedDict([
        (".".join(name.split(".")[1:]), param) for name, param in state_dict.items() if name.startswith("backbone")
    ])

    return state_dict

In [26]:
# Baseline Model
model_bl  = resnet50(weights=ResNet50_Weights.DEFAULT)
# weights_bl = get_backbone_from_ckpt("/home/yasin/Downloads/final_baseline.ckpt")
# model_bl.load_state_dict(weights_bl, strict=False)
model_bl.fc = torch.nn.Identity()
model_bl = model_bl.cuda()

In [27]:
# MixStyle Model
model_ms  = resnet50()
weights_ms = get_backbone_from_ckpt("/home/yasin/Downloads/final_mixstyle.ckpt")
model_ms.load_state_dict(weights_ms, strict=False)
model_ms.fc = torch.nn.Identity()
model_ms = model_ms.cuda()

In [28]:
baseline_module = LinearProbe(model_bl, emb_dim=2048, num_classes=6, lr=1e-3)

In [29]:
mixstyle_module = LinearProbe(model_ms, emb_dim=2048, num_classes=6, lr=1e-3)

In [30]:
trainer = L.Trainer(
    max_epochs=20
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [31]:
trainer.fit(baseline_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Sequential         | 2.6 M 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
2.6 M     Trainable params
23.5 M    Non-trainable params
26.1 M    Total params
104.536   Total estimated model params size (MB)
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (24) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

1.8243143558502197
1.476125717163086
1.0754354000091553
0.8993470668792725
0.7957234382629395
0.617152988910675
0.7471188902854919
0.7708296179771423
0.9365196228027344
0.6361665725708008
0.6245900392532349
0.6026021242141724
0.6494312882423401
0.45400765538215637
0.5014895796775818
0.667331337928772
0.5263014435768127
0.6133747696876526
0.4492204785346985
0.41073137521743774
0.534043550491333
0.615673303604126
0.3925480842590332
0.4419724941253662
0.4501001536846161
0.42576003074645996
0.5080721974372864
0.4747084975242615
0.3985210061073303
0.4329240620136261
0.5123441815376282
0.4530315399169922
0.5329002141952515
0.43316569924354553
0.4037638306617737
0.3545137047767639
0.5766234993934631
0.38909751176834106
0.4978066682815552
0.44544073939323425
0.4476734399795532
0.4516713321208954
0.4189143776893616
0.3120996952056885
0.4625685214996338
0.3722013831138611
0.5079835653305054
0.40774640440940857
0.4167310297489166
0.3454476296901703
0.36670780181884766
0.45628637075424194
0.364743

`Trainer.fit` stopped: `max_epochs=20` reached.


In [32]:
trainer.test(baseline_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing: |          | 0/? [00:00<?, ?it/s]

/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 256. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'accuracy': 0.9627278447151184}]

In [33]:
trainer.test(baseline_module, dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 35. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'accuracy': 0.4554184377193451}]

In [34]:
trainer = L.Trainer(
    max_epochs=20
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [35]:
trainer.fit(mixstyle_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Sequential         | 2.6 M 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
2.6 M     Trainable params
23.5 M    Non-trainable params
26.1 M    Total params
104.536   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

1.8211008310317993
1.396600604057312
0.9942277669906616
0.508196234703064
0.48753172159194946
0.7164696455001831
0.4693939983844757
0.7465087175369263
0.628608226776123
0.49176180362701416
0.6297223567962646
0.3639090359210968
0.42854562401771545
0.4240899384021759
0.3046439588069916
0.6265892386436462
0.43210944533348083
0.3792787790298462
0.47830790281295776
0.451437383890152
0.48723673820495605
0.4722122251987457
0.4607580602169037
0.31540998816490173
0.28745168447494507
0.2748413681983948
0.4079006016254425
0.27696913480758667
0.29561448097229004
0.30237677693367004
0.27563002705574036
0.3120799958705902
0.20630724728107452
0.292596697807312
0.3027072548866272
0.2741285264492035
0.23733307421207428
0.25165048241615295
0.36302125453948975
0.2211320847272873
0.2894200384616852
0.34269896149635315
0.2811563014984131
0.2058606594800949
0.25917983055114746
0.2532876133918762
0.27679896354675293
0.252267450094223
0.2217949777841568
0.24562206864356995
0.18824906647205353
0.18014027178287

`Trainer.fit` stopped: `max_epochs=20` reached.


In [36]:
trainer.test(mixstyle_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.9830729365348816}]

In [37]:
trainer.test(mixstyle_module, dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.5080085396766663}]

| # Samples | Baseline | MixStyle |
| --------- | -------- | -------- |
| All       | 0.629    | 0.642    |
| 1_000     | 0.569    | 0.576    |
| 500       | 0.520    | 0.538    |
| 100       | 0.451    | 0.463    |