# Imports

In [None]:
import random
from pathlib import Path

import numpy as np
import torch

# Release cached CUDA memory left over from a previous run in this kernel.
# Guarded because torch.cuda.ipc_collect() initializes CUDA and raises on
# machines without a usable GPU (the device cell also supports MPS/CPU).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()

from torchvision import transforms as T
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append("..")

from data import pacs

# Config

## Dataset: PACS classes and domains

In [2]:
CLASSES = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
NUM_CLASSES = len(CLASSES)  # derived so it cannot drift from CLASSES
# A tuple (not a set) so iteration order is fixed: iterating a set of strings
# follows hash order, which Python randomizes per process unless PYTHONHASHSEED
# is pinned, making the per-domain run order non-reproducible.
DOMAINS = ("art_painting", "cartoon", "photo", "sketch")

## Hyperparameters

In [3]:
# Training schedule
EPOCHS = 10
BATCH_SIZE = 18

# Adam settings; these match the optimizer's library defaults
# (lr=1e-3, betas=(0.9, 0.999), weight_decay=0).
LEARNING_RATE = 1e-3
REGULARIZATION = 0  # passed to Adam as weight_decay
BETAS = (0.9, 0.999)

## Image Normalization

In [4]:
# Reference only (not executed): the standard ImageNet normalization used with
# pretrained ResNet weights. The training loop instead builds its transform per
# target domain from pacs.get_nomalization_stats(target_domain).
# image_transform = T.Compose([
#     T.Resize(256),
#     T.CenterCrop(224),
#     T.ToTensor(),
#     T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

## Device

In [5]:
# Prefer an NVIDIA GPU, then Apple-silicon MPS, and fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Device: {device}")

Device: cuda


## Set the Seed for Reproducibility

In [6]:
# One seed drives every RNG the notebook may touch: Python's `random`,
# NumPy's legacy global generator, and torch. torch.manual_seed seeds the CPU
# generator and every CUDA device, so no separate torch.cuda call is needed.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# cuDNN: force deterministic kernels and disable autotuning, whose algorithm
# choice can differ between runs. These flags are inert without CUDA.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Training

In [7]:
# TensorBoard logging. SummaryWriter() with no log_dir writes to
# ./runs/<datetime>_<hostname>, which is under the directory the embedded
# TensorBoard below is pointed at. `writer` is used as a global by train()
# and the training loop.
writer = SummaryWriter()
%load_ext tensorboard
%tensorboard --logdir ./runs

Launching TensorBoard...

## Training Loop

In [8]:
class AverageMeter:
    """Running, count-weighted mean of a scalar metric (e.g. loss or accuracy).

    ``update(val, n)`` treats ``val`` as the mean over ``n`` samples, so ``avg``
    is the per-sample mean over everything seen since the last ``reset()``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all accumulated values."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Fold ``val`` (a mean over ``n`` samples) into the running mean.

        ``n`` defaults to 1 for single observations. An update with ``n == 0``
        carries no samples and leaves the statistics unchanged, instead of
        dividing by zero on a fresh meter.
        """
        if not n:
            return
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [9]:
def accuracy(target, output):
    """Top-1 accuracy of ``output`` scores against integer class labels ``target``.

    Args:
        target: class indices, one per sample along the first dimension.
        output: per-class scores/logits; the prediction is the argmax over the
            last dimension (first maximal index on ties).

    Returns:
        Fraction of correct predictions as a Python float in [0, 1];
        0.0 for an empty batch instead of raising ZeroDivisionError.
    """
    batch_size = target.shape[0]
    if batch_size == 0:
        return 0.0
    pred = output.argmax(dim=-1)
    correct = (pred == target).sum().item()
    return correct / batch_size

In [None]:
def train(epoch, target, data_loader, model, optimizer) -> tuple[float, float]:
    """Train ``model`` for one epoch and log per-step metrics to TensorBoard.

    Args:
        epoch: 1-based epoch index, used to compute a global TensorBoard step.
        target: name of the target domain; only used in the TensorBoard tags
            (e.g. ``Loss/train_sketch``).
        data_loader: yields ``(images, domains, labels)`` batches.
        model: network to optimise; switched to training mode here.
        optimizer: optimizer over ``model``'s trainable parameters.

    Returns:
        (mean loss, mean accuracy) over the epoch, weighted by batch size.
    """
    model.train()
    losses = AverageMeter()
    accs = AverageMeter()
    # Keep the domain name separate from the per-batch labels. Reusing the name
    # `target` for the labels would make the TensorBoard tag the label tensor's
    # repr, i.e. a brand-new tag on every batch.
    domain_name = target

    for i, (data, _domain, labels) in enumerate(data_loader):
        step = (epoch - 1) * len(data_loader) + i + 1
        data = data.to(device)
        labels = labels.to(device)

        out = model(data)
        loss = F.cross_entropy(out, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = out.shape[0]
        loss_value = loss.item()
        acc = accuracy(labels, out)
        losses.update(loss_value, batch_size)
        accs.update(acc, batch_size)

        writer.add_scalar(f"Loss/train_{domain_name}", loss_value, step)
        writer.add_scalar(f"Accuracy/train_{domain_name}", acc, step)

    return losses.avg, accs.avg

## Evaluation

In [None]:
def evaluate(data_loader, model, phase="val") -> tuple[float, float]:
    """Score ``model`` on every batch of ``data_loader`` with gradients disabled.

    ``phase`` is accepted so call sites can label the evaluation (e.g. "final"),
    but it does not affect the computation.

    Returns:
        (mean loss, mean accuracy), each weighted by batch size.
    """
    model.eval()
    loss_meter, acc_meter = AverageMeter(), AverageMeter()

    with torch.no_grad():
        for images, _domain, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            n = logits.shape[0]
            loss_meter.update(F.cross_entropy(logits, labels).item(), n)
            acc_meter.update(accuracy(labels, logits), n)

    return loss_meter.avg, acc_meter.avg

In [12]:
class CustomHeader(nn.Module):
    """Classification head: Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear.

    Used to replace ResNet-50's ``fc`` layer in the training loop.

    Args:
        in_features: size of the incoming feature vector (2048 where it is
            attached to ResNet-50 below).
        num_classes: number of output logits.
        dropout_p: dropout probability applied after the hidden activation.
        hidden_features: width of the hidden layer; defaults to 512, the
            previously hard-coded size.
    """

    def __init__(self, in_features, num_classes, dropout_p=0.5, hidden_features=512):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.bn1 = nn.BatchNorm1d(hidden_features)
        self.dropout = nn.Dropout(dropout_p)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Map a (batch, in_features) tensor to (batch, num_classes) logits.

        NOTE: in training mode BatchNorm1d raises for a batch of size 1, so the
        training loader must not yield single-sample batches.
        """
        x = self.fc1(x)
        x = F.relu(self.bn1(x))
        x = self.dropout(x)
        return self.fc2(x)

In [13]:
CHECKPOINT_DIR = Path("models/resnet50")


def _build_model() -> nn.Module:
    """ImageNet-pretrained ResNet-50 with a CustomHeader; only layer4 and the head train."""
    # `weights=` replaces the deprecated `pretrained=True` (same IMAGENET1K_V1 weights).
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    model.fc = CustomHeader(model.fc.in_features, NUM_CLASSES, dropout_p=0.5)
    model.requires_grad_(False)
    model.layer4.requires_grad_(True)
    model.fc.requires_grad_(True)
    return model.to(device)


def _build_transform(target_domain) -> T.Compose:
    """Resize/crop to 224x224 and normalize with stats computed excluding `target_domain`."""
    img_mean, img_std = pacs.get_nomalization_stats(target_domain)
    print(f"Normalization values excluding domain {target_domain}:\n\tmean: {img_mean}\n\tstd: {img_std}")
    return T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=img_mean, std=img_std)
    ])


CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)  # torch.save does not create directories
results = []

# sorted(): fixed run order (and hence RNG consumption) even if DOMAINS is a set.
for target_domain in tqdm(sorted(DOMAINS), desc="Target Domain"):
    model = _build_model()
    # Hand Adam only the trainable parameters (layer4 + head).
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=LEARNING_RATE, betas=BETAS, weight_decay=REGULARIZATION,
    )

    image_transform = _build_transform(target_domain)
    train_loader, test_loader = pacs.get_data_loaders(target_domain, train_batch_size=BATCH_SIZE, transform=image_transform)

    checkpoint = CHECKPOINT_DIR / f"best_{target_domain}.pt"
    # Start below any reachable accuracy so epoch 1 always writes a checkpoint;
    # starting at 0 could reload a stale file from an older run (or crash) if
    # accuracy never exceeded 0.
    best_acc = -1.0
    for epoch in tqdm(range(1, EPOCHS + 1), desc=f"Epoch ({target_domain})"):
        train_loss, train_acc = train(epoch, target_domain, train_loader, model, optimizer)
        test_loss, test_acc = evaluate(test_loader, model)

        writer.add_scalar(f"Loss/train_epoch_{target_domain}", train_loss, epoch)
        writer.add_scalar(f"Accuracy/train_epoch_{target_domain}", train_acc, epoch)
        writer.add_scalar(f"Loss/test_{target_domain}", test_loss, epoch)
        writer.add_scalar(f"Accuracy/test_{target_domain}", test_acc, epoch)

        # NOTE(review): the best epoch is chosen on test_loader for the target
        # domain (presumably the held-out domain - confirm in data.pacs). That is
        # oracle model selection and makes the reported accuracy optimistic;
        # consider selecting on a validation split of the source domains.
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), checkpoint)

    model.load_state_dict(torch.load(checkpoint, map_location=device))
    _, acc = evaluate(test_loader, model, phase="final")

    results.append(acc)

writer.flush()

avg_acc = np.mean(results)
worst_case_acc = np.min(results)

print(f"Average Accuracy: {avg_acc}\nWorst-case Performance: {worst_case_acc}")

Target Domain:   0%|          | 0/4 [00:00<?, ?it/s]



Normalization values excluding domain art_painting:
	mean: tensor([0.8185, 0.8058, 0.7828])
	std: tensor([0.1691, 0.1735, 0.1904])


Epoch (art_painting):   0%|          | 0/10 [00:00<?, ?it/s]

Normalization values excluding domain cartoon:
	mean: tensor([0.7512, 0.7332, 0.7101])
	std: tensor([0.1835, 0.1886, 0.2020])


Epoch (cartoon):   0%|          | 0/10 [00:00<?, ?it/s]

Normalization values excluding domain photo:
	mean: tensor([0.8158, 0.7974, 0.7717])
	std: tensor([0.1672, 0.1741, 0.1914])


Epoch (photo):   0%|          | 0/10 [00:00<?, ?it/s]

Normalization values excluding domain sketch:
	mean: tensor([0.6399, 0.6076, 0.5603])
	std: tensor([0.1787, 0.1796, 0.1951])


Epoch (sketch):   0%|          | 0/10 [00:00<?, ?it/s]

Average Accuracy: 0.7706195276449936
Worst-case Performance: 0.6151692542631713


In [None]:
# Project repository. Stored as a string: a bare URL in a code cell is a
# SyntaxError that breaks Restart & Run All.
REPO_URL = "https://github.com/hosthans/domain_generalization.git"