# Imports

In [1]:
import numpy as np
import torch
from torchvision import transforms as T
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter

import sys
sys.path.append("..")

from data import pacs
import models.resnet_ms as resnet_ms

# Config

## Regarding Dataset

In [2]:
NUM_CLASSES = 7
CLASSES = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]

## Hyperparameters

In [3]:
EPOCHS = 25
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
REGULARIZATION = 1e-4
MOMENTUM = 0.9
MODEL = resnet_ms.resnet50_fc512_ms12_a0d1
USE_PRETRAINED = True
OPTIMIZER = optim.SGD
OPTIMIZER_KWARGS = {
    "lr": LEARNING_RATE,
    "weight_decay": REGULARIZATION,
    "momentum": MOMENTUM
}
SCHEDULER = optim.lr_scheduler.CosineAnnealingLR # optim.lr_scheduler.ReduceLROnPlateau
SCHEDULER_KWARGS = {"T_max": EPOCHS} # {"mode": "min", "patience": 5}
EARLY_STOPPING_PATIENCE = 5
EARLY_STOPPING_DELTA = 1e-5
AUGMENTATIONS = ()
NUM_SEEDS = 3

## Image Normalization

In [4]:
# Values for pretrained ResNet
pretrained_image_transform = T.Compose([
    *AUGMENTATIONS,
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

## Device

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

Device: cuda


## Abstract model building, optimizer and scheduler

In [6]:
build_model = lambda: MODEL(NUM_CLASSES, loss='softmax', pretrained=USE_PRETRAINED)
build_optimizer = lambda model: OPTIMIZER(model.parameters(), **OPTIMIZER_KWARGS)
build_scheduler = lambda optimizer: SCHEDULER(optimizer, **SCHEDULER_KWARGS)

# Set seed for reproducibility

In [7]:
# seed = 42
# torch.manual_seed(42)
# if device == torch.device("cuda"):
#     torch.cuda.manual_seed(seed)
#     torch.backends.cudnn.deterministic = True

# Training

In [8]:
writer = SummaryWriter()
%load_ext tensorboard
%tensorboard --logdir ./runs

Reusing TensorBoard on port 6006 (pid 25156), started 4 days, 21:33:29 ago. (Use '!kill 25156' to kill it.)

## Training Loop

In [9]:
class AverageMeter:
    def __init__(self):
        self.reset()
    
    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0
    
    def update(self, val, n):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [10]:
def accuracy(target, output):
    batch_size = target.shape[0]
    _, pred = torch.max(output, dim=-1)
    correct = pred.eq(target).sum()
    return correct.item() / batch_size

In [11]:
def train(epoch: int,
        target_domain: str,
        data_loader:torch.utils.data.DataLoader,
        model: nn.Module,
        optimizer: optim.Optimizer
        ) -> tuple[float, float]:
    """train one epoch"""
    model.train()
    losses = AverageMeter()
    accs = AverageMeter()

    for i, (data, target) in enumerate(data_loader):
        step = (epoch - 1) * len(data_loader) + i + 1
        data = data.to(device)
        target = target.to(device)

        out = model(data)
        loss = F.cross_entropy(out, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = accuracy(target, out)
        losses.update(loss.item(), out.shape[0])
        accs.update(acc, out.shape[0])

        writer.add_scalar(f'Loss/Train/target={target_domain}', loss.item(), step)
        writer.add_scalar(f'Accuracy/Train/target={target_domain}', acc, step)

    return losses.avg, accs.avg

## Evaluation

In [12]:
def evaluate(data_loader: torch.utils.data.DataLoader, model: nn.Module, phase="val") -> tuple[float, float]:
    model.eval()

    losses = AverageMeter()
    accs = AverageMeter()

    with torch.no_grad():
        for data, target in data_loader:
            data = data.to(device)
            target = target.to(device)

            out = model(data)

            # The implementation returns only the feature vector rather than the classification logits.
            # To compare the labels, we therefore must apply the classification layer manually:
            out = model.classifier(out)

            loss = F.cross_entropy(out, target)
            acc = accuracy(target, out)

            losses.update(loss.item(), out.shape[0])
            accs.update(acc, out.shape[0])
    
    return losses.avg, accs.avg

## Training Loop

In [14]:
all_results = {d: [] for d in DOMAINS}
all_results['avgs'] = []
all_results['worst'] = []

for _ in tqdm(range(NUM_SEEDS), desc="Seeds"):
    results = {}

    for target_domain in tqdm(DOMAINS, desc="Target Domain"):
        model = build_model()
        model = model.to(device)

        optimizer = build_optimizer(model)
        scheduler = build_scheduler(optimizer)

        if not USE_PRETRAINED:
            img_mean, img_std = pacs.get_normalization_stats(target_domain)
            print(f"Normalization values excluding domain {target_domain}:\n\tmean: {img_mean}\n\tstd: {img_std}")
            image_transform = T.Compose([
                *AUGMENTATIONS,
                T.Resize(256),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(mean=img_mean, std=img_std)
            ])
        else:
            image_transform = pretrained_image_transform

        train_loader, test_loader, val_loader = pacs.get_data_loaders(target_domain,
                                                                      train_batch_size=BATCH_SIZE,
                                                                      split="threefold",
                                                                      transform=image_transform,
                                                                      shuffle_test=True,
                                                                      drop_last=True
                                                                     )

        best_loss = float('inf')
        patience_counter = 0
        for epoch in tqdm(range(1, EPOCHS + 1), desc=f"Epoch ({target_domain})"):
            train_loss, train_acc = train(epoch, target_domain, train_loader, model, optimizer)
            val_loss, val_acc = evaluate(val_loader, model)

            writer.add_scalar(f"Loss/Val/target={target_domain}", val_loss, epoch)
            writer.add_scalar(f"Accuracy/Val/target={target_domain}", val_acc, epoch)

            scheduler.step() # scheduler.step(test_loss)

            if best_loss - val_loss < EARLY_STOPPING_DELTA and (patience_counter := patience_counter+1) > EARLY_STOPPING_PATIENCE:
                break

            if val_loss < best_loss:
                best_acc = val_acc
                torch.save(model.state_dict(), f"../checkpoints/mixstyle/best_{target_domain}.pt")

        model.load_state_dict(torch.load(f"../checkpoints/mixstyle/best_{target_domain}.pt"))
        _, acc = evaluate(test_loader, model, phase="final")

        results[target_domain] = acc

    avg_acc = np.mean([*results.values()])
    worst_case_acc = np.min([*results.values()])

    for d in DOMAINS:
        all_results[d].append(results[d])
    all_results['avgs'].append(avg_acc)
    all_results['worst'].append(worst_case_acc)

print("Average Accuracy:\n" +
      "{}".format("".join(f"\t{d}: {np.mean(all_results[d]):.4f}, std: {np.std(all_results[d]):.4f}\n" for d in DOMAINS)) +
      f"\ttotal: {np.mean(all_results['avgs']):.4f}, std: {np.std(all_results['avgs']):.4f}\n"
      "Worst-case Accuracy:\n" +
      "{}".format("".join(f"\t{d}: {np.min(all_results[d]):.4f}\n" for d in DOMAINS)) +
      f"\ttotal: {np.mean(all_results['worst']):.4f}, std: {np.std(all_results['worst']):.4f}"
)

Seeds:   0%|          | 0/3 [00:00<?, ?it/s]

Target Domain:   0%|          | 0/4 [00:00<?, ?it/s]

Insert MixStyle after the following layers: ['layer1', 'layer2']


TypeError: 'set' object is not subscriptable

In [None]:
print("Average Accuracy:\n" +
      "{}".format("".join(f"\t{d}: {np.mean(all_results[d]):.4f}, std: {np.std(all_results[d]):.4f}\n" for d in DOMAINS)) +
      f"\ttotal: {np.mean(all_results['avgs']):.4f}, std: {np.std(all_results['avgs']):.4f}\n"
      "Worst-case Accuracy:\n" +
      "{}".format("".join(f"\t{d}: {np.min(all_results[d]):.4f}\n" for d in DOMAINS)) +
      f"\ttotal: {np.mean(all_results['worst']):.4f}, std: {np.std(all_results['worst']):.4f}"
      )

Average Accuracy:
	art_painting: 0.8877, std: 0.0063
	cartoon: 0.7674, std: 0.0121
	photo: 0.9818, std: 0.0007
	sketch: 0.7550, std: 0.0136
	total: 0.8480, std: 0.0005
Worst-case Accuracy:
	art_painting: 0.8823
	cartoon: 0.7526
	photo: 0.9808
	sketch: 0.7439
	total: 0.7478, std: 0.0036
