## IRMv1 with ResNet18 (Adam)

#### Source Domains: PACS Art Painting, Cartoon, Photo
#### Target Domain: PACS Sketch

In [1]:
import itertools
import os
import random

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader, ConcatDataset, random_split
from torchvision import datasets, transforms

from domainbed.algorithms import IRM
from domainbed import algorithms, networks


# Pick the compute device: CUDA first, then Apple-silicon MPS, then CPU.
# `torch.backends.mps.is_available()` is the documented MPS check (present since
# PyTorch 1.12 and returning False on non-Apple builds), unlike `torch.mps.is_available`.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [2]:
def fix_random_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Also exports ``PYTHONHASHSEED``. Setting it at runtime only affects Python
    processes started afterwards, not the hash seed of the already-running kernel.
    Prints a one-line confirmation.
    """
    for seed_python_side in (random.seed, np.random.seed):
        seed_python_side(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_torch_side in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_torch_side(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducible kernels.
    torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False
    print(f"Fixed random seed: {seed}")


fix_random_seed(42)


# For deterministic DataLoader behavior
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy and ``random`` inside a worker.

    The new seed is derived from the worker's torch seed, which PyTorch assigns per
    worker (see the PyTorch reproducibility notes). ``worker_id`` is unused.
    """
    per_worker_seed = torch.initial_seed() % 2**32
    np.random.seed(per_worker_seed)
    random.seed(per_worker_seed)

Fixed random seed: 42


In [3]:
# IRMv1 hyperparameters for DomainBed, grouped by what they configure.
# NOTE(review): the training cell reports 25 steps per epoch, so the penalty anneal
# point (75 steps) falls exactly at the start of epoch 4, which is where the recorded
# accuracy collapses. lr=1e-3 is also high for fine-tuning a ResNet with Adam.
# Revisit both before trusting these results.
hparams = dict(
    # Optimiser and IRM penalty schedule
    lr=0.001,
    weight_decay=0,
    irm_lambda=50,
    irm_penalty_anneal_iters=75,

    # Backbone (ResNet-18, per the notebook title)
    nonlinear_classifier=0,
    resnet18=1,
    resnet_dropout=0.0,
    freeze_bn=0,

    # MLP / ViT / DINOv2 options: presumably ignored for a ResNet backbone but kept
    # so DomainBed finds every key it may look up (TODO confirm)
    mlp_width=1024,
    mlp_depth=3,
    mlp_dropout=0.1,
    vit=0,
    dinov2=0,
    vit_dropout=0.0,
    vit_attn_tune=0,
)


# --- Initialize IRM ---
irm = algorithms.IRM(
    input_shape=(3, 224, 224),  # matches the Resize((224, 224)) transform in the data cell
    num_classes=7,              # the 7 PACS categories
    num_domains=3,              # source domains: art_painting, cartoon, photo
    hparams=hparams,
)
irm = irm.to(device)




In [4]:
# --- Constants ---
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
BATCH_SIZE = 64
NUM_WORKERS = 0
TRAIN_RATIO = 0.95
VAL_RATIO = 0.05  # informational only: split sizes are computed from TRAIN_RATIO
# Pinned host memory only speeds up host->CUDA copies; on MPS/CPU PyTorch merely warns.
PIN_MEMORY = device == "cuda"

# --- Transform ---
# Deterministic preprocessing (no augmentation), normalised with ImageNet statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# --- Directories ---
art_dir = "data/pacs_data/pacs_data/art_painting"
cartoon_dir = "data/pacs_data/pacs_data/cartoon"
photo_dir = "data/pacs_data/pacs_data/photo"
sketch_dir = "data/pacs_data/pacs_data/sketch"

# --- Datasets ---
art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)  # target domain

# ImageFolder derives label indices from the sorted sub-folder names, so all four
# domains must expose the same class folders or their labels silently disagree.
assert (
    art_dataset.classes == cartoon_dataset.classes == photo_dataset.classes == sketch_dataset.classes
), "PACS domains do not share the same class folders"

# --- Seed setup for reproducibility ---
# A single generator drives the train/val splits and every loader created below.
g = torch.Generator()
g.manual_seed(42)


# --- Split each train domain into train/val ---
def split_dataset(dataset, ratio=TRAIN_RATIO, generator=g):
    """Randomly split ``dataset`` into ``(train, val)`` subsets.

    ``train`` receives ``int(ratio * len(dataset))`` samples and ``val`` the rest.
    """
    total_size = len(dataset)
    train_size = int(ratio * total_size)
    val_size = total_size - train_size
    return random_split(dataset, [train_size, val_size], generator=generator)


def make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide settings above.

    ``seed_worker`` is the one defined in the seeding cell earlier in the notebook.
    It reseeds NumPy and ``random`` in each worker process, so it only runs when
    NUM_WORKERS > 0.
    """
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        worker_init_fn=seed_worker,
        generator=g,
    )


art_train, art_val = split_dataset(art_dataset)
cartoon_train, cartoon_val = split_dataset(cartoon_dataset)
photo_train, photo_val = split_dataset(photo_dataset)

# --- Loaders for each train domain ---
art_train_loader = make_loader(art_train, shuffle=True)
cartoon_train_loader = make_loader(cartoon_train, shuffle=True)
photo_train_loader = make_loader(photo_train, shuffle=True)

# --- Loaders for each val domain ---
art_val_loader = make_loader(art_val, shuffle=False)
cartoon_val_loader = make_loader(cartoon_val, shuffle=False)
photo_val_loader = make_loader(photo_val, shuffle=False)

# --- Combined train and val datasets for all source domains ---
combined_train_dataset = ConcatDataset([art_train, cartoon_train, photo_train])
combined_val_dataset = ConcatDataset([art_val, cartoon_val, photo_val])

combined_train_loader = make_loader(combined_train_dataset, shuffle=True)
combined_val_loader = make_loader(combined_val_dataset, shuffle=False)

# --- Sketch loader (target domain) ---
sketch_loader = make_loader(sketch_dataset, shuffle=False)

# IRM training environments: one loader per source domain, in a fixed order.
envs = [art_train_loader, cartoon_train_loader, photo_train_loader]

# --- Debug info ---
print("Samples per domain:")
print(f"  Art total: {len(art_dataset)} | Train: {len(art_train)} | Val: {len(art_val)}")
print(f"  Cartoon total: {len(cartoon_dataset)} | Train: {len(cartoon_train)} | Val: {len(cartoon_val)}")
print(f"  Photo total: {len(photo_dataset)} | Train: {len(photo_train)} | Val: {len(photo_val)}")
print(f"  Sketch total (target): {len(sketch_dataset)}")
print()
print(f"Combined Train Samples: {len(combined_train_dataset)}")
print(f"Combined Val Samples:   {len(combined_val_dataset)}")

# --- Classes ---
pacs_classes = sketch_dataset.classes

Samples per domain:
  Art total: 2048 | Train: 1945 | Val: 103
  Cartoon total: 2344 | Train: 2226 | Val: 118
  Photo total: 1670 | Train: 1586 | Val: 84
  Sketch total (target): 3929

Combined Train Samples: 5757
Combined Val Samples:   305


In [5]:
def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage (0-100).

    Args:
        model: module exposing ``predict(x)``, which returns class scores along dim 1.
            It is switched to eval mode for the pass, and its previous train/eval
            mode is restored afterwards. This matters when evaluating mid-training,
            so BatchNorm/Dropout are not left in eval mode.
        loader: iterable of ``(x, y)`` batches, ``y`` holding integer class labels.
        device: device each batch is moved to before prediction.

    Returns:
        Accuracy in percent, or ``float("nan")`` if ``loader`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                preds = model.predict(x)
                correct += (torch.argmax(preds, dim=1) == y).sum().item()
                total += y.size(0)
    finally:
        model.train(was_training)
    if total == 0:
        # An empty split would otherwise raise ZeroDivisionError.
        return float("nan")
    return 100 * correct / total

In [6]:
num_epochs = 9
# The three source loaders are consumed in lockstep (one batch per domain per step),
# so an epoch lasts as long as the shortest loader. Surplus batches of the longer
# domains are not visited that epoch; they are reshuffled for the next one.
num_batches = min(len(loader) for loader in envs)
print(f"Batches per epoch (shortest source loader): {num_batches}")

# NOTE(review): in the recorded run accuracy collapses from epoch 4 onwards, i.e. right
# after step irm_penalty_anneal_iters = 75 = 3 epochs x 25 steps, presumably when
# DomainBed switches the penalty weight to irm_lambda. See the hparams cell.

for epoch in range(num_epochs):
    irm.network.train()

    epoch_loss = 0.0
    epoch_penalty = 0.0
    epoch_avg_acc = 0.0
    # worst_acc is a fraction in [0, 1]; keep the lowest value seen at any step.
    epoch_worst_acc = 1.0

    # zip() draws one (x, y) batch per domain in `envs` order. islice stops after
    # num_batches steps, so no loader is asked for a batch past the shortest one.
    for domain_batches in itertools.islice(zip(*envs), num_batches):
        minibatches = [(x.to(device), y.to(device)) for x, y in domain_batches]

        # The metric keys read below must be returned by this checkout's IRM.update.
        metrics = irm.update(minibatches)

        epoch_loss += metrics['loss']
        epoch_penalty += metrics['penalty']
        epoch_avg_acc += metrics['avg_acc']
        epoch_worst_acc = min(epoch_worst_acc, metrics['worst_acc'])

    # Loss, penalty and average accuracy are reported as means over the epoch's steps.
    epoch_loss /= num_batches
    epoch_penalty /= num_batches
    epoch_avg_acc /= num_batches

    print(
        f"Epoch {epoch+1:02d}: "
        f"Loss {epoch_loss:.4f} | "
        f"Penalty {epoch_penalty:.4f} | "
        f"AvgAcc {epoch_avg_acc*100:.2f}% | "
        f"WorstAcc {epoch_worst_acc*100:.2f}%"
    )


25
Epoch 01: Loss 0.5788 | Penalty 0.0156 | AvgAcc 79.81% | WorstAcc 9.38%
Epoch 02: Loss 0.2259 | Penalty 0.0013 | AvgAcc 92.49% | WorstAcc 79.69%
Epoch 03: Loss 0.0919 | Penalty 0.0019 | AvgAcc 97.33% | WorstAcc 89.06%
Epoch 04: Loss 1.6118 | Penalty 0.0186 | AvgAcc 77.99% | WorstAcc 43.75%
Epoch 05: Loss 1.7837 | Penalty 0.0087 | AvgAcc 54.44% | WorstAcc 17.19%
Epoch 06: Loss 2.5697 | Penalty 0.0197 | AvgAcc 40.79% | WorstAcc 18.75%
Epoch 07: Loss 1.6821 | Penalty 0.0020 | AvgAcc 36.63% | WorstAcc 17.19%
Epoch 08: Loss 1.7365 | Penalty 0.0033 | AvgAcc 36.16% | WorstAcc 18.75%
Epoch 09: Loss 1.7429 | Penalty 0.0025 | AvgAcc 34.61% | WorstAcc 12.50%


### Evaluation on Source Domains (held-out 5% validation splits)

In [7]:
# Accuracy on each source domain's held-out validation split, then on all three pooled.
# Evaluation order is kept fixed: art, cartoon, photo, combined.
per_domain_val = [
    ("Art", art_val_loader),
    ("Cartoon", cartoon_val_loader),
    ("Photo", photo_val_loader),
]
source_val_accuracies = {}
for domain_name, val_loader in per_domain_val:
    source_val_accuracies[domain_name] = evaluate(irm, val_loader, device)
    print(f"{domain_name} Accuracy: {source_val_accuracies[domain_name]:.2f}%")

art_accuracy = source_val_accuracies["Art"]
cartoon_accuracy = source_val_accuracies["Cartoon"]
photo_accuracy = source_val_accuracies["Photo"]

source_accuracy = evaluate(irm, combined_val_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 42.72%
Cartoon Accuracy: 35.59%
Photo Accuracy: 57.14%

All Source Domains Accuracy: 43.93%


### Evaluation on Target Domain (Sketch)

In [8]:
# Out-of-distribution accuracy on the unseen target domain (every Sketch image, no split).
sketch_accuracy = evaluate(model=irm, loader=sketch_loader, device=device)
print(f"Sketch Accuracy: {sketch_accuracy:.2f}%")

Sketch Accuracy: 22.14%
