# Imports

In [1]:
import numpy as np
import torch
from torchvision import transforms as T
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

In [2]:
from pacs import get_data_loaders

# Config

## Regarding Dataset

In [3]:
NUM_CLASSES = 7
CLASSES = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
DOMAINS = {"art_painting", "cartoon", "photo", "sketch"}

## Hyperparameters

In [4]:
EPOCHS = 10
BATCH_SIZE = 128
LEARNING_RATE = 0.001 # Standard value for Adam: 0.001
REGULARIZATION = 0 # Standard value for Adam: 0
BETAS = (0.9, 0.999) # Standard values for Adam: (0.9, 0.999)

## Image Normalization

In [5]:
# Values for pretrained ResNet, might need adjusting
image_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Device

In [6]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

Device: cuda


# Training

In [7]:
writer = SummaryWriter()
%load_ext tensorboard
%tensorboard --logdir ./runs

Reusing TensorBoard on port 6006 (pid 29728), started 0:05:26 ago. (Use '!kill 29728' to kill it.)

## Training Loop

In [8]:
class AverageMeter:
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [9]:
def accuracy(target, output):
    batch_size = target.shape[0]
    _, pred = torch.max(output, dim=-1)
    correct = pred.eq(target).sum()
    return correct.item() / batch_size

In [10]:
def train(epoch, data_loader, model, optimizer) -> tuple[float, float]:
    """train one epoch"""
    model.train()
    losses = AverageMeter()
    accs = AverageMeter()

    for i, (data, target) in enumerate(tqdm(data_loader, desc="Instance")):
        step = (epoch - 1) * len(data_loader) + i + 1
        data = data.to(device)
        target = target.to(device)

        out = model(data)
        loss = F.cross_entropy(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = accuracy(target, out)
        losses.update(loss.item(), out.shape[0])
        accs.update(acc, out.shape[0])

        writer.add_scalar('Loss/train', loss.item(), step)
        writer.add_scalar('Accuracy/train', acc, step)

    return losses.avg, accs.avg

## Evaluation

In [11]:
def evaluate(data_loader, model, phase="val") -> tuple[float, float]:
    model.eval()
    losses = AverageMeter()
    accs = AverageMeter()

    with torch.no_grad():
        for data, target in tqdm(data_loader, desc=f"Evaluate ({phase})"):
            data = data.to(device)
            target = target.to(device)

            out = model(data)
            loss = F.cross_entropy(out, target)

            acc = accuracy(target, out)
            losses.update(loss.item(), out.shape[0])
            accs.update(acc, out.shape[0])
    
    return losses.avg, accs.avg

In [None]:
results = []

for target_domain in tqdm(DOMAINS, desc="Target Domain"):
    model = models.resnet50()
    model.fc = nn.Linear(512, NUM_CLASSES)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=BETAS, weight_decay=REGULARIZATION)

    train_loader, test_loader = get_data_loaders(target_domain, train_batch_size=BATCH_SIZE, transform=image_transform)

    best_acc = 0
    for epoch in tqdm(range(1, EPOCHS + 1), desc=f"Epoch ({target_domain})"):
        train_loss, train_acc = train(epoch, train_loader, model, optimizer)
        test_loss, test_acc = evaluate(test_loader, model)

        writer.add_scalar(f"Loss/test_{target_domain}", test_loss, epoch)
        writer.add_scalar(f"Accuracy/test_{target_domain}", test_acc, epoch)

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), f"models/resnet50/best_{target_domain}.pt")

    model.load_state_dict(torch.load(f"models/resnet50/best_{target_domain}.pt"))
    _, acc = evaluate(test_loader, model, phase="final")

    results.append(acc)

avg_acc = np.mean(results)
worst_case_acc = np.min(results)

print(f"Average Accuracy: {avg_acc}\nWorst-case Performance: {worst_case_acc}")

Target Domain:   0%|          | 0/4 [00:00<?, ?it/s]

Epoch (sketch):   0%|          | 0/10 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Instance:   0%|          | 0/48 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/31 [00:00<?, ?it/s]

Evaluate (final):   0%|          | 0/31 [00:00<?, ?it/s]

Epoch (art_painting):   0%|          | 0/10 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Instance:   0%|          | 0/63 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/16 [00:00<?, ?it/s]

Evaluate (final):   0%|          | 0/16 [00:00<?, ?it/s]

Epoch (cartoon):   0%|          | 0/10 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Instance:   0%|          | 0/60 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/19 [00:00<?, ?it/s]

Evaluate (final):   0%|          | 0/19 [00:00<?, ?it/s]

Epoch (photo):   0%|          | 0/10 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Instance:   0%|          | 0/66 [00:00<?, ?it/s]

Evaluate (val):   0%|          | 0/14 [00:00<?, ?it/s]

Evaluate (final):   0%|          | 0/14 [00:00<?, ?it/s]

Average Accuracy: 0.36756580184835974
Worst-case Performance: 0.2744140625
