In [1]:
import torch
import torch.nn as nn
from domainbed.algorithms import IRM
from domainbed import algorithms, networks
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

# Pick the best available accelerator, falling back to the CPU.
# `torch.backends.mps.is_available()` is used for the Apple-silicon check because
# `torch.mps.is_available()` only exists in recent PyTorch releases and raises
# AttributeError on older installs whenever CUDA is absent.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [2]:
# Per-channel mean / std passed to Normalize below.
# NOTE(review): presumably chosen to match an ImageNet-pretrained backbone — confirm.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Deterministic preprocessing (no augmentation): resize, tensor conversion, normalization.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std),
])

# One sub-folder per domain under a single root.
pacs_root = "data/pacs_data/pacs_data"
art_dir = f"{pacs_root}/art_painting"
cartoon_dir = f"{pacs_root}/cartoon"
photo_dir = f"{pacs_root}/photo"
sketch_dir = f"{pacs_root}/sketch"


def make_loader(dataset, batch_size=64, shuffle=True, num_workers=4):
    """Build a DataLoader with the settings shared by every domain in this notebook.

    Args:
        dataset: Dataset to draw batches from.
        batch_size: Number of samples per batch.
        shuffle: Whether to reshuffle the data on every pass.
        num_workers: Number of worker processes used for loading.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only helps host-to-CUDA copies; requesting it
        # without CUDA just produces a warning.
        pin_memory=torch.cuda.is_available(),
    )


art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)

pacs_classes = sketch_dataset.classes

# Every domain must expose the same class list in the same order; otherwise the
# integer labels of different domains would not refer to the same classes.
for _domain_name, _domain_dataset in (
    ("art_painting", art_dataset),
    ("cartoon", cartoon_dataset),
    ("photo", photo_dataset),
):
    assert _domain_dataset.classes == pacs_classes, (
        f"class folders of '{_domain_name}' differ from 'sketch'"
    )

art_loader = make_loader(art_dataset)
cartoon_loader = make_loader(cartoon_dataset)
photo_loader = make_loader(photo_dataset)
sketch_loader = make_loader(sketch_dataset)  # This is also the test domain loader

# All three source domains pooled into one dataset / loader.
source_dataset = ConcatDataset([art_dataset, cartoon_dataset, photo_dataset])
source_loader = make_loader(source_dataset)

# Training environments: one loader per source domain (sketch is excluded).
envs = [art_loader, cartoon_loader, photo_loader]

In [3]:
def evaluate(model, loader, device):
    """Return the top-1 accuracy, in percent, of ``model`` over ``loader``.

    The model is switched to evaluation mode for the duration of the call and
    its previous train/eval mode is restored afterwards, so calling this
    between training steps does not leave the model stuck in eval mode.

    Args:
        model: Module exposing ``eval()``, ``train(mode)``, a ``training`` flag
            and ``predict(x)`` that returns per-class scores of shape
            ``(batch, num_classes)``.
        loader: Iterable yielding ``(x, y)`` pairs where ``y`` holds integer
            class indices, one per sample.
        device: Device that each batch is moved to before prediction.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` if ``loader`` yields no
        samples (instead of raising ``ZeroDivisionError``).
    """
    was_training = model.training
    model.eval()
    correct = total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                preds = model.predict(x)
                correct += (preds.argmax(dim=1) == y).sum().item()
                total += y.size(0)
    finally:
        # Restore whichever mode the caller had the model in.
        model.train(was_training)

    if total == 0:
        return 0.0
    return 100 * correct / total

In [4]:


# --- Hyperparameters ---
# Passed unchanged to algorithms.IRM below; the meaning of each key is defined
# by DomainBed, not by this notebook.
# NOTE(review): 'irm_penalty_anneal_iters' is 500 while the training cell runs
# only 50 updates. If DomainBed applies 'irm_lambda' only after that many
# updates, the full penalty weight is never reached — confirm.
hparams = {
    'lr': 5e-5,
    'weight_decay': 0.0,
    'irm_lambda': 1e2,
    'irm_penalty_anneal_iters': 500,

    'nonlinear_classifier': 0,
    'resnet18': 0,
    'resnet_dropout': 0.0,
    'freeze_bn': 1,

    'mlp_width': 1024,
    'mlp_depth': 3,
    'mlp_dropout': 0.1,
    'vit': 0,
    'dinov2': 0,
    'vit_dropout': 0.0,
    'vit_attn_tune': 0,
}

# --- Initialize IRM ---
# The class and domain counts are derived from the data cell rather than
# hard-coded, so they cannot drift if the dataset or the list of training
# environments changes. The input shape matches the Resize((224, 224)) transform.
irm = algorithms.IRM(
    input_shape=(3, 224, 224),
    num_classes=len(pacs_classes),
    num_domains=len(envs),
    hparams=hparams
).to(device)




In [5]:
# Make sure the model is in training mode: an earlier run of an evaluation cell
# may have left it in eval mode, and this loop never switched it back.
irm.train()

# One live iterator per training environment; each is recreated when exhausted.
env_iters = [iter(loader) for loader in envs]

# NOTE: despite the name, one "epoch" here is a single call to irm.update on one
# minibatch per environment, not a full pass over the data.
num_epochs = 50
for epoch in range(num_epochs):
    minibatches = []

    for i in range(len(envs)):
        try:
            x, y = next(env_iters[i])
        except StopIteration:
            # Restart iterator when one env runs out
            env_iters[i] = iter(envs[i])
            x, y = next(env_iters[i])

        minibatches.append((x.to(device), y.to(device)))

    metrics = irm.update(minibatches)
    print(f"Epoch {epoch+1}: Loss {metrics['loss']:.4f} | Penalty {metrics['penalty']:.4f}")

Epoch 1: Loss 1.9791 | Penalty 0.0034
Epoch 2: Loss 1.8127 | Penalty 0.0088
Epoch 3: Loss 1.6558 | Penalty 0.0166
Epoch 4: Loss 1.6116 | Penalty 0.0478
Epoch 5: Loss 1.3973 | Penalty 0.0763


KeyboardInterrupt: 

In [None]:


# Accuracy on the sketch domain, which is not among the training environments.
sketch_acc = evaluate(model=irm, loader=sketch_loader, device=device)
print(
    f"Sketch Accuracy: {sketch_acc:.2f}%"
)


Sketch Accuracy: 55.23%


In [None]:
# Accuracy on each source domain. These loaders are the same ones used as
# training environments, so the numbers are measured on training images and
# are not held-out estimates.
source_results = {}
for domain_label, domain_loader in (
    ("Art", art_loader),
    ("Cartoon", cartoon_loader),
    ("Photo", photo_loader),
):
    source_results[domain_label] = evaluate(irm, domain_loader, device)
    print(f"{domain_label} Accuracy: {source_results[domain_label]:.2f}%")

# Keep the individual names available for any later cells.
art_accuracy = source_results["Art"]
cartoon_accuracy = source_results["Cartoon"]
photo_accuracy = source_results["Photo"]

# Accuracy over the three source domains pooled together.
source_accuracy = evaluate(irm, source_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 92.63%
Cartoon Accuracy: 91.42%
Photo Accuracy: 97.31%

All Source Domains Accuracy: 93.45%
