## ERM Baseline with ResNet18 (SGD)

#### Source Domains: PACS Art Painting, Cartoon, Photo
#### Target Domain: PACS Sketch

In [1]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset, random_split
from torch.optim.lr_scheduler import LambdaLR
import random
import numpy as np
import os


# Pick the compute device in order of preference: CUDA, then Apple MPS, then CPU.
# The MPS check goes through `torch.backends.mps`, the long-standing availability
# API; `torch.mps.is_available()` is only present in recent PyTorch releases.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [2]:
def fix_random_seed(seed=42):
    """Seed every RNG this notebook touches and switch cuDNN to deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy, and PyTorch
            (CPU and CUDA generators). Also written to ``PYTHONHASHSEED``.

    Side effects:
        Mutates global RNG state, ``os.environ`` and the cuDNN backend flags,
        then prints a confirmation line.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Host-side generators: Python, NumPy, PyTorch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA generators: the current device first, then every device.
    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)

    # Deterministic kernels, and no autotuner picking algorithms at run time.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"Fixed random seed: {seed}")


fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` from the current PyTorch seed.

    Intended for use as a ``DataLoader`` ``worker_init_fn``.

    Note: a later cell defines ``seed_worker`` again, which replaces this binding.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)  # fold into the 32-bit range
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

Fixed random seed: 42


In [3]:
# --- Hyperparameters ---
base_lr = 0.01
weight_decay = 1e-4
warmup_epochs = 3
num_epochs = 9

# --- Model: ResNet18 with IMAGENET1K_V1 weights; the final layer is replaced
# by a fresh 7-output linear classifier. ---
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=7)

# --- Optimizer: SGD with momentum over all model parameters. ---
optim = torch.optim.SGD(
    params=model.parameters(),
    lr=base_lr,
    momentum=0.9,
    weight_decay=weight_decay,
)

def lr_lambda(current_epoch):
    """Return the learning-rate multiplier for ``current_epoch``.

    The schedule is a linear warmup followed by cosine decay. It reads the
    module-level ``warmup_epochs`` and ``num_epochs``.

    Args:
        current_epoch: Zero-based epoch counter, as passed in by ``LambdaLR``.

    Returns:
        A plain Python ``float`` in ``[0, 1]``:
        ``(current_epoch + 1) / warmup_epochs`` during warmup, then
        ``0.5 * (1 + cos(pi * progress))`` where ``progress`` runs from 0 at the
        end of warmup to 1 at ``num_epochs``.
    """
    # Local import keeps this cell self-contained without touching the import cell.
    import math

    if current_epoch < warmup_epochs:
        # Linear warmup: 1/warmup_epochs, 2/warmup_epochs, ..., 1
        return float(current_epoch + 1) / float(max(1, warmup_epochs))

    # Cosine annealing from 1 -> 0. Progress is clamped to 1 so the multiplier
    # stays at 0 instead of climbing again if stepping continues past num_epochs.
    decay_epochs = float(max(1, num_epochs - warmup_epochs))
    progress = min(1.0, (current_epoch - warmup_epochs) / decay_epochs)

    # math.cos keeps the result a Python float; returning a torch tensor here
    # would turn the optimizer's `lr` into a tensor as well.
    return 0.5 * (1.0 + math.cos(math.pi * progress))

# Scheduler: scales the optimizer's base LR by lr_lambda(epoch) on every step().
scheduler = LambdaLR(optimizer=optim, lr_lambda=lr_lambda)

# Loss: multi-class cross-entropy on the raw model outputs.
criterion = torch.nn.CrossEntropyLoss()

In [4]:
# --- Constants ---
BATCH_SIZE = 64
NUM_WORKERS = 0
TRAIN_RATIO = 0.95
VAL_RATIO = 0.05
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# --- Transform ---
# Resize to 224x224, convert to a tensor, then normalise with the statistics above.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# --- Directories ---
art_dir = "data/pacs_data/pacs_data/art_painting"
cartoon_dir = "data/pacs_data/pacs_data/cartoon"
photo_dir = "data/pacs_data/pacs_data/photo"
sketch_dir = "data/pacs_data/pacs_data/sketch"

# --- Datasets ---
# One ImageFolder per domain, all sharing the same transform.
# Sketch is the held-out target domain.
art_dataset, cartoon_dataset, photo_dataset, sketch_dataset = (
    datasets.ImageFolder(root=domain_dir, transform=transform)
    for domain_dir in (art_dir, cartoon_dir, photo_dir, sketch_dir)
)

# --- Seed setup for reproducibility ---
# Dedicated generator shared by the dataset splits and the DataLoaders below.
g = torch.Generator().manual_seed(42)

def seed_worker(worker_id):
    """Seed NumPy and Python's ``random`` inside a DataLoader worker.

    Used as ``worker_init_fn`` for the loaders in this notebook. PyTorch already
    gives each worker its own torch seed, so the generators that need explicit
    seeding are NumPy's and Python's; both are derived from that torch seed.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used.
    """
    # NumPy accepts only 32-bit seeds, so fold the 64-bit torch seed down.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

# --- Split each train domain into train/val ---
def split_dataset(dataset, ratio=TRAIN_RATIO, generator=g):
    """Randomly partition ``dataset`` into a train subset and a validation subset.

    Args:
        dataset: Any sized dataset accepted by ``random_split``.
        ratio: Fraction of samples assigned to the train subset; the train
            size is ``int(ratio * len(dataset))`` (truncated).
        generator: ``torch.Generator`` that drives the random permutation.

    Returns:
        The ``[train_subset, val_subset]`` pair produced by ``random_split``;
        the validation subset receives every sample not assigned to train.
    """
    n_total = len(dataset)
    n_train = int(ratio * n_total)
    lengths = [n_train, n_total - n_train]
    return random_split(dataset, lengths, generator=generator)

# Split order (art, cartoon, photo) is kept because every split draws from the
# same shared generator.
(art_train, art_val), (cartoon_train, cartoon_val), (photo_train, photo_val) = [
    split_dataset(domain_dataset)
    for domain_dataset in (art_dataset, cartoon_dataset, photo_dataset)
]


def _build_loader(dataset, shuffle):
    """Create a DataLoader with this notebook's common settings.

    Args:
        dataset: Dataset to iterate over.
        shuffle: Whether to reshuffle samples every epoch.

    Returns:
        A ``DataLoader`` using ``BATCH_SIZE``, ``NUM_WORKERS``, pinned memory,
        ``seed_worker`` as the worker initialiser and the shared generator ``g``.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        worker_init_fn=seed_worker,
        generator=g,
    )


# --- Loaders for each val domain ---
art_val_loader = _build_loader(art_val, shuffle=False)
cartoon_val_loader = _build_loader(cartoon_val, shuffle=False)
photo_val_loader = _build_loader(photo_val, shuffle=False)

# --- Combined train and val datasets for all source domains ---
combined_train_dataset = ConcatDataset([art_train, cartoon_train, photo_train])
combined_val_dataset = ConcatDataset([art_val, cartoon_val, photo_val])

# Only the training loader shuffles.
combined_train_loader = _build_loader(combined_train_dataset, shuffle=True)
combined_val_loader = _build_loader(combined_val_dataset, shuffle=False)

# --- Sketch loader (target domain) ---
sketch_loader = _build_loader(sketch_dataset, shuffle=False)

# --- Debug info ---
# Report full / train / val sizes for each source domain, then the target and
# combined totals.
print("Samples per domain:")
for domain_label, full_set, train_part, val_part in (
    ("Art", art_dataset, art_train, art_val),
    ("Cartoon", cartoon_dataset, cartoon_train, cartoon_val),
    ("Photo", photo_dataset, photo_train, photo_val),
):
    print(f"  {domain_label} total: {len(full_set)} | Train: {len(train_part)} | Val: {len(val_part)}")
print(f"  Sketch total (target): {len(sketch_dataset)}")
print()
print(f"Combined Train Samples: {len(combined_train_dataset)}")
print(f"Combined Val Samples:   {len(combined_val_dataset)}")

# --- Classes ---
# Class names as discovered by ImageFolder in the sketch directory.
pacs_classes = sketch_dataset.classes

Samples per domain:
  Art total: 2048 | Train: 1945 | Val: 103
  Cartoon total: 2344 | Train: 2226 | Val: 118
  Photo total: 1670 | Train: 1586 | Val: 84
  Sketch total (target): 3929

Combined Train Samples: 5757
Combined Val Samples:   305


In [5]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network to train; moved to ``device`` and put in train mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the model and every batch are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-batch loss weighted by batch size and
        divided by the number of samples seen, and the percentage of samples
        whose highest-scoring output matches the label.
    """
    model.to(device)
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_inputs, batch_labels in loader:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch average is per sample
        # (assumes the criterion returns a batch mean).
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        n_seen += batch_labels.size(0)
        n_correct += (logits.max(1).indices == batch_labels).sum().item()

    return loss_sum / n_seen, 100. * n_correct / n_seen

In [6]:
def evaluate(model, test_loader, device):
    """Compute top-1 accuracy of ``model`` over ``test_loader``.

    Args:
        model: Network to evaluate; moved to ``device`` and put in eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device the model and every batch are moved to.

    Returns:
        Accuracy as a percentage (0-100) of samples whose arg-max prediction
        equals the label.
    """
    model.to(device)
    model.eval()

    num_correct, num_seen = 0, 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for batch in test_loader:
            images, labels = (item.to(device) for item in batch)
            predicted = model(images).argmax(dim=1)
            num_correct += (predicted == labels).sum().item()
            num_seen += labels.size(0)

    return num_correct / num_seen * 100

In [7]:
# Train for num_epochs passes over the combined source-domain training set.
for epoch in range(num_epochs):
    epoch_stats = train_epoch(model, combined_train_loader, criterion, optim, device)
    train_loss, train_acc = epoch_stats

    # The schedule advances once per epoch. The LR read back below is the
    # post-step value, i.e. the rate the optimizer will use for the next epoch.
    scheduler.step()
    reported_lr = scheduler.get_last_lr()[0]

    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | LR: {reported_lr:.5f}")


Epoch 1/9 | Loss: 0.5595 | Train Acc: 81.36% | LR: 0.00667
Epoch 2/9 | Loss: 0.1159 | Train Acc: 96.54% | LR: 0.01000
Epoch 3/9 | Loss: 0.0648 | Train Acc: 98.12% | LR: 0.01000
Epoch 4/9 | Loss: 0.0462 | Train Acc: 98.66% | LR: 0.00933
Epoch 5/9 | Loss: 0.0205 | Train Acc: 99.37% | LR: 0.00750
Epoch 6/9 | Loss: 0.0041 | Train Acc: 99.97% | LR: 0.00500
Epoch 7/9 | Loss: 0.0020 | Train Acc: 100.00% | LR: 0.00250
Epoch 8/9 | Loss: 0.0010 | Train Acc: 100.00% | LR: 0.00067
Epoch 9/9 | Loss: 0.0009 | Train Acc: 100.00% | LR: 0.00000


### Evaluation on Source Domains

In [8]:
# Accuracy on each source domain's held-out validation split, in a fixed order.
source_accuracies = {}
for domain_label, domain_loader in (
    ("Art", art_val_loader),
    ("Cartoon", cartoon_val_loader),
    ("Photo", photo_val_loader),
):
    source_accuracies[domain_label] = evaluate(model, domain_loader, device)
    print(f"{domain_label} Accuracy: {source_accuracies[domain_label]:.2f}%")
    if domain_label != "Photo":
        print()

# Keep the individual result names available to later cells.
art_accuracy = source_accuracies["Art"]
cartoon_accuracy = source_accuracies["Cartoon"]
photo_accuracy = source_accuracies["Photo"]

# Accuracy over the three validation splits pooled together.
source_accuracy = evaluate(model, combined_val_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 94.17%
Cartoon Accuracy: 96.61%
Photo Accuracy: 98.81%

All Source Domains Accuracy: 96.39%


### Evaluation on Target Domain

In [9]:
# Accuracy on the full Sketch dataset, which was never used for training.
sketch_accuracy = evaluate(model=model, test_loader=sketch_loader, device=device)
print(f"Sketch Accuracy: {sketch_accuracy:.2f}%")

Sketch Accuracy: 70.40%
