In [1]:
import torch
from domainbed.algorithms import IRM
from domainbed import algorithms, networks
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

device = "cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu"
print(device)

cuda


In [None]:
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std)
])

art_dir = "../data/pacs_data/pacs_data/art_painting"
cartoon_dir = "../data/pacs_data/pacs_data/cartoon"
photo_dir = "../data/pacs_data/pacs_data/photo"
sketch_dir = "../data/pacs_data/pacs_data/sketch"

art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)

art_loader = DataLoader(
    art_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

cartoon_loader = DataLoader(
    cartoon_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

photo_loader = DataLoader(
    photo_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

sketch_loader = DataLoader( # This is also the test domain loader
    sketch_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

source_dataset = ConcatDataset([art_dataset, cartoon_dataset, photo_dataset])

source_loader = DataLoader(
    source_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

pacs_classes = sketch_dataset.classes

envs = [art_loader, cartoon_loader, photo_loader]

In [3]:
def evaluate(model, loader, device):
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model.predict(x)
            correct += (torch.argmax(preds, dim=1) == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

In [4]:


# --- Hyperparameters ---
hparams = {
    'lr': 5e-5,
    'weight_decay': 0.0,
    'irm_lambda': 1e2,
    'irm_penalty_anneal_iters': 500,

    'nonlinear_classifier': 0,
    'resnet18': 0,
    'resnet50_augmix': 0,
    'resnet_dropout': 0.0,
    'freeze_bn': 1,

    'mlp_width': 1024,
    'mlp_depth': 3,
    'mlp_dropout': 0.1,
    'vit': 0,
    'dinov2': 0,
    'vit_dropout': 0.0,
    'vit_attn_tune': 0,
}

# --- Initialize IRM ---
irm = algorithms.IRM(
    input_shape=(3, 224, 224),
    num_classes=7,
    num_domains=3,
    hparams=hparams
).to(device)




In [5]:
env_iters = [iter(loader) for loader in envs]

num_epochs = 50
for epoch in range(num_epochs):
    minibatches = []

    for i, env_iter in enumerate(env_iters):
        try:
            x, y = next(env_iter)
        except StopIteration:
            # Restart iterator when one env runs out
            env_iters[i] = iter(envs[i])
            x, y = next(env_iters[i])

        minibatches.append((x.to(device), y.to(device)))

    metrics = irm.update(minibatches)
    print(f"Epoch {epoch+1}: Loss {metrics['loss']:.4f} | Penalty {metrics['penalty']:.4f}")

Epoch 1: Loss 2.0211 | Penalty 0.0135
Epoch 2: Loss 1.8529 | Penalty 0.0019
Epoch 3: Loss 1.8536 | Penalty 0.0083
Epoch 4: Loss 1.7267 | Penalty 0.0267
Epoch 5: Loss 1.5446 | Penalty 0.0495
Epoch 6: Loss 1.5711 | Penalty 0.0406
Epoch 7: Loss 1.3225 | Penalty 0.0780
Epoch 8: Loss 1.2488 | Penalty 0.0927
Epoch 9: Loss 1.1006 | Penalty 0.0886
Epoch 10: Loss 1.0906 | Penalty 0.0105
Epoch 11: Loss 0.9228 | Penalty 0.0225
Epoch 12: Loss 0.9930 | Penalty 0.0029
Epoch 13: Loss 0.7422 | Penalty 0.0005
Epoch 14: Loss 0.7960 | Penalty 0.0153
Epoch 15: Loss 0.6843 | Penalty 0.0262
Epoch 16: Loss 0.6545 | Penalty 0.0083
Epoch 17: Loss 0.4529 | Penalty 0.0461
Epoch 18: Loss 0.5245 | Penalty -0.0165
Epoch 19: Loss 0.6504 | Penalty 0.0905
Epoch 20: Loss 0.3709 | Penalty -0.0166
Epoch 21: Loss 0.3590 | Penalty -0.0005
Epoch 22: Loss 0.5656 | Penalty 0.0390
Epoch 23: Loss 0.4181 | Penalty 0.0016
Epoch 24: Loss 0.3185 | Penalty 0.0107
Epoch 25: Loss 0.3388 | Penalty -0.0040
Epoch 26: Loss 0.4006 | Penalt

In [6]:


sketch_acc = evaluate(irm, sketch_loader, device)
print(f"Sketch Accuracy: {sketch_acc:.2f}%")


Sketch Accuracy: 65.92%


In [7]:
art_accuracy = evaluate(irm, art_loader, device)
print(f"Art Accuracy: {art_accuracy:.2f}%")

cartoon_accuracy = evaluate(irm, cartoon_loader, device)
print(f"Cartoon Accuracy: {cartoon_accuracy:.2f}%")

photo_accuracy = evaluate(irm, photo_loader, device)
print(f"Photo Accuracy: {photo_accuracy:.2f}%")

source_accuracy = evaluate(irm, source_loader, device)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 94.58%
Cartoon Accuracy: 92.58%
Photo Accuracy: 98.56%

All Source Domains Accuracy: 94.90%
