In [9]:
import torch
import torch.nn as nn
from domainbed.algorithms import IRM
from domainbed import algorithms, networks
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset, random_split
import random
import numpy as np
import os

device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"

In [10]:
def fix_random_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Fixed random seed: {seed}")

fix_random_seed(42)

# For deterministic DataLoader behavior
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

Fixed random seed: 42


In [11]:
# --- Hyperparameters ---
hparams = {
    'lr': 0.001,
    'weight_decay': 0.0,
    'groupdro_eta': 1e-2,

    'nonlinear_classifier': 0,
    'resnet18': 1,
    'resnet_dropout': 0.0,
    'freeze_bn': 0,

    'mlp_width': 1024,
    'mlp_depth': 3,
    'mlp_dropout': 0.1,
    'vit': 0,
    'dinov2': 0,
    'vit_dropout': 0.0,
    'vit_attn_tune': 0,
}


groupDRO = algorithms.GroupDRO(
    input_shape=(3, 224, 224),
    num_classes=7,
    num_domains=3,
    hparams=hparams
).to(device)


In [12]:
# --- Constants ---
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
BATCH_SIZE = 64
NUM_WORKERS = 0
TRAIN_RATIO = 0.95
VAL_RATIO = 0.05

# --- Transform ---
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# --- Directories ---
art_dir = "data/pacs_data/pacs_data/art_painting"
cartoon_dir = "data/pacs_data/pacs_data/cartoon"
photo_dir = "data/pacs_data/pacs_data/photo"
sketch_dir = "data/pacs_data/pacs_data/sketch"

# --- Datasets ---
art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)  # target domain

# --- Seed setup for reproducibility ---
g = torch.Generator()
g.manual_seed(42)

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    torch.manual_seed(worker_seed)
    torch.cuda.manual_seed_all(worker_seed)

# --- Split each train domain into train/val ---
def split_dataset(dataset, ratio=TRAIN_RATIO, generator=g):
    total_size = len(dataset)
    train_size = int(ratio * total_size)
    val_size = total_size - train_size
    return random_split(dataset, [train_size, val_size], generator=generator)

art_train, art_val = split_dataset(art_dataset)
cartoon_train, cartoon_val = split_dataset(cartoon_dataset)
photo_train, photo_val = split_dataset(photo_dataset)

# --- Loaders for each train domain ---
art_train_loader = DataLoader(
    art_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

cartoon_train_loader = DataLoader(
    cartoon_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

photo_train_loader = DataLoader(
    photo_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)


# --- Loaders for each val domain ---
art_val_loader = DataLoader(
    art_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

cartoon_val_loader = DataLoader(
    cartoon_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

photo_val_loader = DataLoader(
    photo_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

# --- Combined train and val datasets for all source domains ---
combined_train_dataset = ConcatDataset([art_train, cartoon_train, photo_train])
combined_val_dataset = ConcatDataset([art_val, cartoon_val, photo_val])

combined_train_loader = DataLoader(
    combined_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

combined_val_loader = DataLoader(
    combined_val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

# --- Sketch loader (target domain) ---
sketch_loader = DataLoader(
    sketch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    worker_init_fn=seed_worker,
    generator=g
)

envs = [art_train_loader, cartoon_train_loader, photo_train_loader]

# --- Debug info ---
print("Samples per domain:")
print(f"  Art total: {len(art_dataset)} | Train: {len(art_train)} | Val: {len(art_val)}")
print(f"  Cartoon total: {len(cartoon_dataset)} | Train: {len(cartoon_train)} | Val: {len(cartoon_val)}")
print(f"  Photo total: {len(photo_dataset)} | Train: {len(photo_train)} | Val: {len(photo_val)}")
print(f"  Sketch total (target): {len(sketch_dataset)}")
print()
print(f"Combined Train Samples: {len(combined_train_dataset)}")
print(f"Combined Val Samples:   {len(combined_val_dataset)}")

# --- Classes ---
pacs_classes = sketch_dataset.classes

envs = [art_train_loader, cartoon_train_loader, photo_train_loader]

Samples per domain:
  Art total: 2048 | Train: 1945 | Val: 103
  Cartoon total: 2344 | Train: 2226 | Val: 118
  Photo total: 1670 | Train: 1586 | Val: 84
  Sketch total (target): 3929

Combined Train Samples: 5757
Combined Val Samples:   305


In [13]:
def evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model.predict(x)
            correct += (torch.argmax(preds, dim=1) == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

In [14]:
env_iters = [iter(loader) for loader in envs]

num_batches = min(len(loader) for loader in envs)

num_epochs = 9

for epoch in range(num_epochs):
    groupDRO.train()
    for batch_idx in range(num_batches):
        minibatches = []
        for i, env_loader in enumerate(envs):
            try:
                x, y = next(env_iters[i])
            except StopIteration:
                env_iters[i] = iter(env_loader)
                x, y = next(env_iters[i])
            minibatches.append((x.to(device), y.to(device)))

        metrics = groupDRO.update(minibatches)
    print(f"Epoch {epoch+1}: Loss {metrics['loss']:.4f} | Avg Train Acc: {metrics['avg_acc']:.2f}% | Worst Group Acc: {metrics['worst_acc']:.2f}%")

Epoch 1: Loss 0.3086 | Avg Train Acc: 89.96% | Worst Group Acc: 79.69%
Epoch 2: Loss 0.1079 | Avg Train Acc: 95.54% | Worst Group Acc: 93.75%
Epoch 3: Loss 0.0327 | Avg Train Acc: 99.33% | Worst Group Acc: 98.00%
Epoch 4: Loss 0.0836 | Avg Train Acc: 97.25% | Worst Group Acc: 96.88%
Epoch 5: Loss 0.0684 | Avg Train Acc: 98.29% | Worst Group Acc: 98.00%
Epoch 6: Loss 0.0414 | Avg Train Acc: 98.81% | Worst Group Acc: 98.00%
Epoch 7: Loss 0.0428 | Avg Train Acc: 98.96% | Worst Group Acc: 96.88%
Epoch 8: Loss 0.0502 | Avg Train Acc: 97.48% | Worst Group Acc: 94.00%
Epoch 9: Loss 0.0588 | Avg Train Acc: 98.96% | Worst Group Acc: 96.88%


In [15]:
art_accuracy = evaluate(groupDRO, art_val_loader, device)
print(f"Art Accuracy: {art_accuracy:.2f}%")

cartoon_accuracy = evaluate(groupDRO, cartoon_val_loader, device)
print(f"Cartoon Accuracy: {cartoon_accuracy:.2f}%")

photo_accuracy = evaluate(groupDRO, photo_val_loader, device)
print(f"Photo Accuracy: {photo_accuracy:.2f}%")

source_accuracy = evaluate(groupDRO, combined_val_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 82.52%
Cartoon Accuracy: 91.53%
Photo Accuracy: 94.05%

All Source Domains Accuracy: 89.18%


In [16]:
sketch_accuracy = evaluate(groupDRO, sketch_loader, device)
print(f"Sketch Accuracy: {sketch_accuracy:.2f}%")

Sketch Accuracy: 65.72%
