In [1]:
import os
import collections
import random

from data_modules.domainnet_dataset import ImageDataset
from data_modules.domainnet_metadata import DOMAIN_NET_DOMAINS, DOMAIN_NET_CLASSES, DOMAIN_NET_DIVISIONS

from torch.utils.data import DataLoader
from torchvision.transforms import v2 as T

from yasin_utils.image import imagenet_normalize

In [2]:
division_map = {}
for k, v in DOMAIN_NET_DIVISIONS.items():
    for i in v:
        division_map[i] = k

In [3]:
root = '/data/domainnet_v1.0'
set_map = []
for domain in DOMAIN_NET_DOMAINS:
    try:
        labels = os.listdir(os.path.join(root, domain))
    except:
        raise Exception(f'{domain} directory not found.')
    for label in labels:
        for image in os.listdir(os.path.join(root, domain, label)):
            if label in division_map:
                set_map.append(
                    dict(
                        img_path=os.path.join(root, domain, label, image),
                        label=division_map[label],
                        domain=domain
                        )
                    )

In [4]:
random.shuffle(set_map)

In [5]:
train_set_map, test_set_map = set_map[:int(0.9*len(set_map))], set_map[int(0.9*len(set_map)):]

In [6]:
len(set_map), len(train_set_map), len(test_set_map)

(268144, 241329, 26815)

In [7]:
dict(collections.Counter([(d['label'], d['domain']) for d in train_set_map]))

{('building', 'infograph'): 2858,
 ('furniture', 'clipart'): 5254,
 ('building', 'real'): 10853,
 ('tool', 'real'): 11637,
 ('furniture', 'sketch'): 6757,
 ('building', 'sketch'): 4267,
 ('tool', 'clipart'): 3442,
 ('furniture', 'real'): 15426,
 ('electricity', 'quickdraw'): 11240,
 ('mammal', 'painting'): 8096,
 ('building', 'clipart'): 2441,
 ('mammal', 'quickdraw'): 11264,
 ('cloth', 'real'): 10344,
 ('cloth', 'quickdraw'): 10288,
 ('building', 'quickdraw'): 9441,
 ('tool', 'quickdraw'): 12647,
 ('building', 'painting'): 4596,
 ('furniture', 'quickdraw'): 15749,
 ('electricity', 'infograph'): 3719,
 ('mammal', 'real'): 13981,
 ('furniture', 'infograph'): 5840,
 ('electricity', 'sketch'): 3672,
 ('electricity', 'real'): 10547,
 ('tool', 'sketch'): 4394,
 ('furniture', 'painting'): 4524,
 ('mammal', 'clipart'): 3076,
 ('mammal', 'sketch'): 4626,
 ('cloth', 'clipart'): 3257,
 ('electricity', 'painting'): 2666,
 ('cloth', 'sketch'): 4377,
 ('tool', 'infograph'): 2801,
 ('electricity', '

In [8]:
# balance_map = {
#     'furniture': {'main_domain': 'clipart', 'target': 5_253},
#     'cloth': {'main_domain': 'real', 'target': 10_377},
#     'electricity': {'main_domain': 'infograph', 'target': 3_745},
#     'building': {'main_domain': 'painting', 'target': 4_570},
#     'mammal': {'main_domain': 'quickdraw', 'target': 11_202},
#     'tool': {'main_domain': 'sketch', 'target': 4_401},
# }

In [9]:
balance_map = {
    'furniture': {'main_domain': 'clipart', 'target': 3_745},
    'cloth': {'main_domain': 'real', 'target': 3_745},
    'electricity': {'main_domain': 'infograph', 'target': 3_745},
    'building': {'main_domain': 'painting', 'target': 3_745},
    'mammal': {'main_domain': 'quickdraw', 'target': 3_745},
    'tool': {'main_domain': 'sketch', 'target': 3_745},
}

In [10]:
# balance_map = {
#     'furniture': {'main_domain': 'clipart', 'target': 1_000},
#     'cloth': {'main_domain': 'real', 'target': 1_000},
#     'electricity': {'main_domain': 'infograph', 'target': 1_000},
#     'building': {'main_domain': 'painting', 'target': 1_000},
#     'mammal': {'main_domain': 'quickdraw', 'target': 1_000},
#     'tool': {'main_domain': 'sketch', 'target': 1_000},
# }

In [11]:
# balance_map = {
#     'furniture': {'main_domain': 'clipart', 'target': 100},
#     'cloth': {'main_domain': 'real', 'target': 100},
#     'electricity': {'main_domain': 'infograph', 'target': 100},
#     'building': {'main_domain': 'painting', 'target': 100},
#     'mammal': {'main_domain': 'quickdraw', 'target': 100},
#     'tool': {'main_domain': 'sketch', 'target': 100},
# }

In [12]:
def prune(set_map, balance_map):
    pruned_set_map = []
    
    domains = set([i['domain'] for i in set_map])
    counts = {
        l: {
            d: 0 for d in domains
        } for l in balance_map
    }
    targets = {
        l: {
            d: balance_map[l]['target'] if d == balance_map[l]['main_domain'] else int(balance_map[l]['target']*1.0) for d in domains
        } for l in balance_map
    }

    for i in set_map:
        label, domain = i['label'], i['domain']

        if counts[label][domain] < targets[label][domain]:
            pruned_set_map.append(i)

        counts[label][domain] += 1

    return pruned_set_map

In [13]:
train_set_map = prune(train_set_map, balance_map)

In [14]:
dict(collections.Counter([(d['label'], d['domain']) for d in train_set_map]))

{('building', 'infograph'): 2858,
 ('furniture', 'clipart'): 3745,
 ('building', 'real'): 3745,
 ('tool', 'real'): 3745,
 ('furniture', 'sketch'): 3745,
 ('building', 'sketch'): 3745,
 ('tool', 'clipart'): 3442,
 ('furniture', 'real'): 3745,
 ('electricity', 'quickdraw'): 3745,
 ('mammal', 'painting'): 3745,
 ('building', 'clipart'): 2441,
 ('mammal', 'quickdraw'): 3745,
 ('cloth', 'real'): 3745,
 ('cloth', 'quickdraw'): 3745,
 ('building', 'quickdraw'): 3745,
 ('tool', 'quickdraw'): 3745,
 ('building', 'painting'): 3745,
 ('furniture', 'quickdraw'): 3745,
 ('electricity', 'infograph'): 3719,
 ('mammal', 'real'): 3745,
 ('furniture', 'infograph'): 3745,
 ('electricity', 'sketch'): 3672,
 ('electricity', 'real'): 3745,
 ('tool', 'sketch'): 3745,
 ('furniture', 'painting'): 3745,
 ('mammal', 'clipart'): 3076,
 ('mammal', 'sketch'): 3745,
 ('cloth', 'clipart'): 3257,
 ('electricity', 'painting'): 2666,
 ('cloth', 'sketch'): 3745,
 ('tool', 'infograph'): 2801,
 ('electricity', 'clipart'): 

In [15]:
class UnbalancedDomainNetDataset(ImageDataset):
    def __init__(self, root: str, set_map, transform=None) -> None:

        super().__init__(set_map, transform)
        self.class_map = {
            'furniture': 0,
            'cloth': 1,
            'electricity': 2,
            'building': 3,
            'mammal': 4,
            'tool': 5,
        }
        self.domain_map = {
            'clipart': 0,
            'real': 1,
            'infograph': 2,
            'painting': 3,
            'quickdraw': 4,
            'sketch': 5,
        }

    def __getitem__(self, index):
        item = super().__getitem__(index)
        item['label'] = self.class_map[item['label']]
        item['domain'] = self.domain_map[item['domain']]

        return item

In [16]:
train_transform = T.Compose([
    T.RandomResizedCrop(128),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    imagenet_normalize,
])
val_transform = T.Compose([
    T.Resize(156),
    T.CenterCrop(128),
    T.ToTensor(),
    imagenet_normalize
])



In [17]:
train_set = UnbalancedDomainNetDataset(root=root, set_map=train_set_map, transform=train_transform)
test_set = UnbalancedDomainNetDataset(root=root, set_map=test_set_map, transform=val_transform)

In [18]:
train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True, drop_last=True)
val_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=8, pin_memory=True, persistent_workers=True)

In [19]:
len(train_set), len(test_set)

(126704, 26815)

## Train Models

In [20]:
from typing import OrderedDict
import torch
from torch import nn
from torchvision.models.resnet import resnet50
import pytorch_lightning as L
from torchmetrics import Accuracy

In [21]:
class LinearProbe(L.LightningModule):
    def __init__(self, backbone, emb_dim, num_classes, lr=1e-3) -> None:
        super().__init__()
        
        self.backbone: nn.Module = backbone
        for param in self.backbone.parameters():
            param.requires_grad = False

        self.backbone.eval()

        # self.linear_head: nn.Module = nn.Linear(emb_dim, num_classes)
        self.linear_head: nn.Module = nn.Sequential(
            nn.Linear(emb_dim, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.Linear(512, num_classes)
        ) 
        
        self.criterion = nn.CrossEntropyLoss()
        self.lr = lr

        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

    def forward(self, x):
        x = self.backbone(x)
        x = self.linear_head(x)

        return x

    def training_step(self, batch, bacth_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        loss = self.criterion(y, t)

        print(loss.item())

        return loss
    
    def test_step(self, batch, batch_idx):
        X = batch['image']
        t = batch['label']

        y = self.forward(X)
        acc = self.test_accuracy(y, t)

        self.log('accuracy', acc, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.linear_head.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 20)

        return [optimizer], [scheduler]

In [22]:
def get_backbone_from_ckpt(ckpt_path: str) -> torch.nn.Module:
    state_dict = torch.load(ckpt_path)["state_dict"]
    state_dict = OrderedDict([
        (".".join(name.split(".")[1:]), param) for name, param in state_dict.items() if name.startswith("backbone")
    ])

    return state_dict

In [23]:
# Baseline Model
model_bl  = resnet50()
weights_bl = get_backbone_from_ckpt("/home/yasin/Downloads/r50_domainnet_baseline.ckpt")
model_bl.load_state_dict(weights_bl, strict=False)
model_bl.fc = torch.nn.Identity()
model_bl = model_bl.cuda()

In [24]:
# MixStyle Model
model_ms  = resnet50()
weights_ms = get_backbone_from_ckpt("/home/yasin/Downloads/r50_domainnet_mixstyle_no_ms_head_pretrain.ckpt")
model_ms.load_state_dict(weights_ms, strict=False)
model_ms.fc = torch.nn.Identity()
model_ms = model_ms.cuda()

In [25]:
baseline_module = LinearProbe(model_bl, emb_dim=2048, num_classes=6, lr=1e-3)

In [26]:
mixstyle_module = LinearProbe(model_ms, emb_dim=2048, num_classes=6, lr=1e-3)

In [27]:
trainer = L.Trainer(
    max_epochs=20
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [28]:
trainer.fit(baseline_module, train_loader)

Missing logger folder: /home/yasin/repos/dispatch_smol/pretraining/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Sequential         | 2.6 M 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
2.6 M     Trainable params
23.5 M    Non-trainable params
26.1 M    Total params
104.536   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

1.7920619249343872
1.807389497756958
1.5932101011276245
1.4589028358459473
1.3797285556793213
1.3428058624267578
1.2084718942642212
1.174642562866211
0.9916455745697021
1.2584822177886963
1.1719002723693848
0.9935769438743591
0.9723197221755981
0.994493305683136
0.9968547821044922
0.8600085973739624
0.8794234991073608
1.0377098321914673
0.9433909058570862
0.9999668598175049
0.9939106702804565
0.9613649249076843
0.9508869647979736
0.8805478811264038
0.7819434404373169
0.873111367225647
0.825409471988678
1.0162822008132935
0.8017672300338745
0.8362581133842468
0.974381685256958
0.9086064696311951
1.0175597667694092
0.8914541602134705
0.9087833166122437
1.0054274797439575
0.9638128280639648
0.9213946461677551
0.9171159267425537
0.9757367372512817
0.8514981269836426
1.0457773208618164
0.927155077457428
0.9685695171356201
0.7257282137870789
0.8045173287391663
0.935858428478241
0.8334203958511353
0.8921822905540466
0.8682836890220642
0.7475568652153015
0.9003480672836304
0.781627893447876
0.

`Trainer.fit` stopped: `max_epochs=20` reached.


In [29]:
trainer.test(baseline_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:492: Your `test_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test dataloaders.


Testing: |          | 0/? [00:00<?, ?it/s]

/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 256. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'accuracy': 0.9035851955413818}]

In [30]:
trainer.test(baseline_module, dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

/home/yasin/miniforge3/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:77: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 191. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


[{'accuracy': 0.8068618178367615}]

In [31]:
trainer = L.Trainer(
    max_epochs=20
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [32]:
trainer.fit(mixstyle_module, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type               | Params
-----------------------------------------------------
0 | backbone      | ResNet             | 23.5 M
1 | linear_head   | Sequential         | 2.6 M 
2 | criterion     | CrossEntropyLoss   | 0     
3 | test_accuracy | MulticlassAccuracy | 0     
-----------------------------------------------------
2.6 M     Trainable params
23.5 M    Non-trainable params
26.1 M    Total params
104.536   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

1.7974072694778442
1.764238953590393
1.5517719984054565
1.4760671854019165
1.3861379623413086
1.2627485990524292
1.104216456413269
1.0936604738235474
1.2206825017929077
1.1315655708312988
1.041304588317871
1.039202332496643
1.0557713508605957
1.0498424768447876
0.9763909578323364
1.023460865020752
1.0509527921676636
1.0110623836517334
0.946755051612854
0.8813489675521851
1.084365963935852
0.812236487865448
0.9769463539123535
0.8664440512657166
0.8428932428359985
0.9534685611724854
0.8931517601013184
0.9254287481307983
0.9597480893135071
0.9244791865348816
0.9064436554908752
1.001157522201538
0.9892169833183289
0.9371392130851746
0.9434017539024353
0.9521726369857788
0.8320316076278687
1.0072280168533325
0.8111764192581177
1.001930832862854
0.815632164478302
0.8751972913742065
0.8775754570960999
0.8802378177642822
0.8866392970085144
0.8725622296333313
0.8303690552711487
0.9840664267539978
0.9813448190689087
0.8982937932014465
0.8914728164672852
1.0366677045822144
0.966301679611206
0.924

`Trainer.fit` stopped: `max_epochs=20` reached.


In [33]:
trainer.test(mixstyle_module, dataloaders=train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.8799421191215515}]

In [34]:
trainer.test(mixstyle_module, dataloaders=val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'accuracy': 0.8093977570533752}]

| # Samples | Baseline | MixStyle |
| --------- | -------- | -------- |
| All       | 0.629    | 0.642    |
| 1_000     | 0.569    | 0.576    |
| 500       | 0.520    | 0.538    |
| 100       | 0.451    | 0.463    |