## ERM Baseline with ResNet18 (Adam)

#### Source Domains: PACS Art/Painting, Cartoon, Photo
#### Target Domain: PACS Sketch

In [161]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset, random_split
from torch.optim.lr_scheduler import LambdaLR
import random
import numpy as np
import os


device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

In [162]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)
# !unzip /content/drive/MyDrive/PA2-Data/Archive.zip

In [163]:
def fix_random_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Fixed random seed: {seed}")

fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

Fixed random seed: 42


In [164]:
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')

model.fc = torch.nn.Linear(model.fc.in_features, 7)

model = torchvision.models.resnet18(weights='IMAGENET1K_V1')

for param in model.parameters():
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, 7)
optim = torch.optim.Adam(model.fc.parameters(), lr=0.008)

criterion = torch.nn.CrossEntropyLoss()

In [165]:
# --- Constants ---
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
BATCH_SIZE = 512
NUM_WORKERS = 2
TRAIN_RATIO = 0.95
VAL_RATIO = 0.05

# --- Transform ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# --- Directories ---
art_dir = "/content/PA2-Data/pacs_data/pacs_data/art_painting"
cartoon_dir = "/content/PA2-Data/pacs_data/pacs_data/cartoon"
photo_dir = "/content/PA2-Data/pacs_data/pacs_data/photo"
sketch_dir = "/content/PA2-Data/pacs_data/pacs_data/sketch"

# --- Datasets ---
art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)  # target domain

# --- Seed setup for reproducibility ---
g = torch.Generator()
g.manual_seed(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    torch.manual_seed(worker_seed)
    torch.cuda.manual_seed_all(worker_seed)

# --- Split each train domain into train/val ---
def split_dataset(dataset, ratio=TRAIN_RATIO, generator=g):
    total_size = len(dataset)
    train_size = int(ratio * total_size)
    val_size = total_size - train_size
    return random_split(dataset, [train_size, val_size], generator=generator)

art_train, art_val = split_dataset(art_dataset)
cartoon_train, cartoon_val = split_dataset(cartoon_dataset)
photo_train, photo_val = split_dataset(photo_dataset)

# --- Loaders for each val domain ---
art_val_loader = DataLoader(
    art_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

cartoon_val_loader = DataLoader(
    cartoon_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

photo_val_loader = DataLoader(
    photo_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

# --- Combined train and val datasets for all source domains ---
combined_train_dataset = ConcatDataset([art_train, cartoon_train, photo_train])
combined_val_dataset = ConcatDataset([art_val, cartoon_val, photo_val])

combined_train_loader = DataLoader(
    combined_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

combined_val_loader = DataLoader(
    combined_val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

# --- Sketch loader (target domain) ---
sketch_loader = DataLoader(
    sketch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

# --- Debug info ---
print("Samples per domain:")
print(f"  Art total: {len(art_dataset)} | Train: {len(art_train)} | Val: {len(art_val)}")
print(f"  Cartoon total: {len(cartoon_dataset)} | Train: {len(cartoon_train)} | Val: {len(cartoon_val)}")
print(f"  Photo total: {len(photo_dataset)} | Train: {len(photo_train)} | Val: {len(photo_val)}")
print(f"  Sketch total (target): {len(sketch_dataset)}")
print()
print(f"Combined Train Samples: {len(combined_train_dataset)}")
print(f"Combined Val Samples:   {len(combined_val_dataset)}")

# --- Classes ---
pacs_classes = sketch_dataset.classes

Samples per domain:
  Art total: 2048 | Train: 1945 | Val: 103
  Cartoon total: 2344 | Train: 2226 | Val: 118
  Photo total: 1670 | Train: 1586 | Val: 84
  Sketch total (target): 3929

Combined Train Samples: 5757
Combined Val Samples:   305


In [166]:
def train_epoch(model, loader, criterion, optimizer, device):
    model.to(device)
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [167]:
def evaluate(model, test_loader, device):
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)
            preds = torch.argmax(preds, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    accuracy = correct / total * 100
    return accuracy

In [168]:
num_epochs = 9
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, combined_train_loader, criterion, optim, device)
    print(f"Epoch {epoch+1}/{num_epochs} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")


Epoch 1/9 | Loss: 1.7000 | Train Acc: 44.54%
Epoch 2/9 | Loss: 0.6410 | Train Acc: 78.79%
Epoch 3/9 | Loss: 0.4630 | Train Acc: 84.68%
Epoch 4/9 | Loss: 0.3586 | Train Acc: 88.62%
Epoch 5/9 | Loss: 0.3210 | Train Acc: 89.77%
Epoch 6/9 | Loss: 0.3080 | Train Acc: 90.17%
Epoch 7/9 | Loss: 0.2935 | Train Acc: 90.29%
Epoch 8/9 | Loss: 0.2795 | Train Acc: 90.93%
Epoch 9/9 | Loss: 0.2674 | Train Acc: 91.40%


### Evaluation on Source Domains

In [169]:
art_accuracy = evaluate(model, art_val_loader, device)
print(f"Art Accuracy: {art_accuracy:.2f}%")

cartoon_accuracy = evaluate(model, cartoon_val_loader, device)
print(f"Cartoon Accuracy: {cartoon_accuracy:.2f}%")

photo_accuracy = evaluate(model, photo_val_loader, device)
print(f"Photo Accuracy: {photo_accuracy:.2f}%")

source_accuracy = evaluate(model, combined_val_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 83.50%
Cartoon Accuracy: 81.36%
Photo Accuracy: 98.81%

All Source Domains Accuracy: 86.89%


### Evaluation on Test Domain

In [170]:
sketch_accuracy = evaluate(model, sketch_loader, device)
print(f"Sketch Accuracy: {sketch_accuracy:.2f}%")

Sketch Accuracy: 41.66%
