In [None]:
!pip install torch torchvision datasets tqdm scikit-learn




In [None]:
from datasets import load_dataset

# Leave-one-domain-out setup: three PACS domains for training, "sketch" held out.
source_domains = ["art_painting", "cartoon", "photo"]
target_domain = "sketch"

# Fetch PACS from the Hugging Face Hub and display its structure.
dataset = load_dataset("flwrlabs/pacs")
print(dataset)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'domain', 'label'],
        num_rows: 9991
    })
})


In [None]:
# === Imports ===
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
import torch

# === Load PACS Dataset ===

# === Domain split: three source domains, one held-out target ===
target_domain = "sketch"
source_domains = ["art_painting", "cartoon", "photo"]

# === Preprocessing pipeline ===
# Resize to 224x224, convert to a tensor, then normalize each channel with the
# mean/std conventionally used for ImageNet-pretrained torchvision models.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# === Helper: keep only the rows of one domain ===
def filter_by_domain(ds, domain_name):
    """Return the rows of `ds` whose "domain" field equals `domain_name`.

    Args:
        ds: Object exposing `.filter(predicate)`; the predicate is called with
            one row (a mapping with a "domain" key) at a time.
        domain_name: Domain value to keep.

    Returns:
        Whatever `ds.filter` returns for the matching rows.
    """
    def _belongs_to_domain(row):
        return row["domain"] == domain_name

    return ds.filter(_belongs_to_domain)

# === Per-domain subsets of the single "train" split ===
# One filtered dataset per source domain (in `source_domains` order) ...
source_splits = [
    filter_by_domain(dataset["train"], domain_name)
    for domain_name in source_domains
]
# ... and one for the held-out target domain.
target_split = filter_by_domain(dataset["train"], target_domain)

# === Custom PyTorch dataset wrapper ===
class PACSDataset(Dataset):
    """Map-style PyTorch dataset over one split of the PACS data.

    Each item is an `(image, label)` pair built from the row's "image" and
    "label" fields; the image is converted to RGB before the optional
    transform is applied.

    Args:
        hf_dataset: Indexable, sized collection whose rows are mappings with an
            "image" entry (an object exposing `.convert("RGB")`) and a "label"
            entry.
        transform: Optional callable applied to the RGB image.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once. Indexing `self.dataset[idx]` separately for the
        # image and the label reads the whole row twice (for Hugging Face image
        # columns that presumably also means decoding the image twice).
        row = self.dataset[idx]
        img = row["image"].convert("RGB")
        label = row["label"]
        # Explicit None check so any supplied callable is honoured, even one
        # that happens to be falsy (e.g. an empty container-style module).
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# === Wrap each split for PyTorch, then pool the source domains ===
target_dataset = PACSDataset(target_split, transform=transform)
train_datasets = [PACSDataset(split, transform=transform) for split in source_splits]
train_dataset = ConcatDataset(train_datasets)

# Training batches come from the pooled sources in shuffled order; the target
# loader keeps dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)
target_loader = DataLoader(target_dataset, shuffle=False, batch_size=32)

# === Report split sizes as a sanity check ===
print("Number of images per source domain:")
for domain, ds in zip(source_domains, source_splits):
    print(f"  {domain}: {len(ds)}")

print(f"\nTarget domain ({target_domain}): {len(target_split)}")
print(f"\nTotal training images (all sources combined): {len(train_dataset)}")


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Number of images per source domain:
  art_painting: 2048
  cartoon: 2344
  photo: 1670

Target domain (sketch): 3929

Total training images (all sources combined): 6062


In [None]:
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from tqdm import tqdm
# Source domains the baseline is trained on (they are pooled into `train_loader`).
# The name was previously misspelled, which silently created a stray global.
source_domains = ["art_painting", "cartoon", "photo"]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Pretrained ResNet-50 whose final layer is replaced by a 7-way classifier.
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument — confirm against the installed version before switching.
model = models.resnet50(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 7)  # 7 classes in PACS
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=3e-4)
num_epochs = 5

# Baseline: minimize cross-entropy over the pooled source batches.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}")





Using device: cuda




Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 200MB/s]
Epoch 1/5: 100%|██████████| 190/190 [01:18<00:00,  2.43it/s]


Epoch 1, Loss: 0.4697


Epoch 2/5: 100%|██████████| 190/190 [01:18<00:00,  2.44it/s]


Epoch 2, Loss: 0.2525


Epoch 3/5: 100%|██████████| 190/190 [01:17<00:00,  2.44it/s]


Epoch 3, Loss: 0.1932


Epoch 4/5: 100%|██████████| 190/190 [01:18<00:00,  2.42it/s]


Epoch 4, Loss: 0.1133


Epoch 5/5: 100%|██████████| 190/190 [01:17<00:00,  2.44it/s]

Epoch 5, Loss: 0.1269





In [None]:
# Domains to report on: the held-out target and the three training sources.
target_domain = "sketch"
source_domains = ["art_painting", "cartoon", "photo"]

def evaluate(domain_name):
    """Print and return the model's top-1 accuracy (in percent) on one PACS domain.

    Relies on the notebook-level `model`, `dataset`, `transform`, `device` and
    `PACSDataset`. Switches `model` to eval mode and leaves it there, so callers
    that continue training must call `model.train()` again.

    Args:
        domain_name: Value of the "domain" column to evaluate on.

    Returns:
        Accuracy in percent, as a float.

    Raises:
        ValueError: If no row of `dataset["train"]` belongs to `domain_name`.
    """
    model.eval()
    # Filter dataset by domain name
    domain_data = dataset["train"].filter(lambda x: x["domain"] == domain_name)
    if len(domain_data) == 0:
        # Without this guard a misspelled domain name only surfaces as a
        # ZeroDivisionError in the accuracy computation below.
        raise ValueError(f"No samples found for domain {domain_name!r}")
    loader = DataLoader(PACSDataset(domain_data, transform=transform),
                        batch_size=32, shuffle=False)
    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            # Predicted class = index of the largest logit.
            preds = model(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    acc = 100 * correct / total
    print(f"{domain_name} Accuracy: {acc:.2f}%")
    return acc

# Accuracy on each training domain (in `source_domains` order), then on the
# held-out target domain.
accs = [evaluate(domain_name) for domain_name in source_domains]
target_acc = evaluate(target_domain)

print("\nAverage Source Accuracy:", sum(accs)/len(accs))
print("Target (Unseen Domain) Accuracy:", target_acc)


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

art_painting Accuracy: 95.51%


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

cartoon Accuracy: 93.77%


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

photo Accuracy: 98.44%


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

sketch Accuracy: 68.16%

Average Source Accuracy: 95.90741911015853
Target (Unseen Domain) Accuracy: 68.15983710867906


In [None]:
# ==== Setup: make per-domain datasets & loaders ====
from torch.utils.data import DataLoader

# Source domains to train on and the domain held out for evaluation.
domains = ["art_painting", "cartoon", "photo"]
target_domain = "sketch"
batch_size = 32

# One shuffled loader per source domain, all sharing `transform`; incomplete
# final batches are dropped.
domain_loaders = {
    name: DataLoader(
        PACSDataset(
            dataset["train"].filter(lambda row: row["domain"] == name),
            transform=transform,
        ),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
    for name in domains
}

# Held-out target domain, read in dataset order.
target_ds = dataset["train"].filter(lambda row: row["domain"] == target_domain)
target_loader = DataLoader(
    PACSDataset(target_ds, transform=transform),
    batch_size=batch_size,
    shuffle=False,
)

# Steps per epoch = batch count of the largest source domain; the training loop
# restarts the shorter loaders when they run out.
max_batches = max(map(len, domain_loaders.values()))

# ==== IRM helpers ====
def irm_penalty_from_loss(loss, dummy_w):
    """Return the summed squared gradient of `loss` with respect to `dummy_w`.

    The gradient is taken with `create_graph=True`, so the returned penalty stays
    connected to the autograd graph and can itself be backpropagated.

    Args:
        loss: Scalar tensor computed from `dummy_w`.
        dummy_w: Tensor with `requires_grad=True` that `loss` depends on.

    Returns:
        Scalar tensor: sum over the elements of (d loss / d dummy_w) ** 2.
    """
    (grad_wrt_dummy,) = torch.autograd.grad(loss, dummy_w, create_graph=True)
    return grad_wrt_dummy.pow(2).sum()

# ==== Training hyperparams ====
warmup_epochs = 3   # epochs trained on the plain summed loss, before the penalty is added
total_epochs = 30
lambda_irm = 1.0    # start with 1.0 (try 0.1, 1.0, 5.0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: `model`, `optimizer` and `criterion` are not created in this cell; it
# reuses whatever earlier cells left in the kernel and keeps training that model.

best_target = -1.0

for epoch in range(total_epochs):
    model.train()
    epoch_loss = 0.0
    epoch_penalty = 0.0
    n_steps = 0

    # Fresh iterator per source domain at the start of every epoch.
    iters = {d: iter(domain_loaders[d]) for d in domains}

    for step in range(max_batches):
        # Per-step sums over the source domains.
        total_loss = 0.0
        total_pen = 0.0

        for d in domains:
            try:
                imgs, labels = next(iters[d])
            except StopIteration:
                # This domain has fewer batches than the largest one: start it over.
                iters[d] = iter(domain_loaders[d])
                imgs, labels = next(iters[d])

            imgs, labels = imgs.to(device), labels.to(device)

            # Scalar multiplier fixed at 1.0; the penalty is the squared gradient
            # of this domain's loss with respect to it.
            dummy_w = torch.tensor(1.0, requires_grad=True, device=device)

            outputs = model(imgs) * dummy_w
            loss_e = criterion(outputs, labels)
            pen_e = irm_penalty_from_loss(loss_e, dummy_w)

            total_loss += loss_e
            total_pen += pen_e

        # Objective for this step: summed losses, plus the weighted penalty once
        # warm-up is over. The penalty is still computed (and logged) during warm-up.
        if epoch < warmup_epochs:
            step_obj = total_loss
        else:
            step_obj = total_loss + lambda_irm * total_pen

        optimizer.zero_grad()
        step_obj.backward()
        optimizer.step()

        epoch_loss += total_loss.item()
        epoch_penalty += total_pen.item()
        n_steps += 1

    avg_loss = epoch_loss / n_steps / len(domains)   # avg per-domain loss
    avg_pen = epoch_penalty / n_steps / len(domains)
    print(f"Epoch {epoch+1}/{total_epochs}  avg_loss {avg_loss:.4f}  avg_penalty {avg_pen:.6f}")

    # Evaluate every second epoch using `evaluate` from the earlier cell
    # (it switches the model to eval mode; `model.train()` above switches it back).
    if (epoch+1) % 2 == 0:
        src_accs = [evaluate(d) for d in domains]
        tgt_acc = evaluate(target_domain)
        avg_src = sum(src_accs) / len(src_accs)
        print(f"Eval --> Avg source acc: {avg_src:.2f}  Target acc: {tgt_acc:.2f}")

        # NOTE(review): the checkpoint is selected on target-domain accuracy, so
        # the held-out domain influences model selection — confirm this is intended.
        if tgt_acc > best_target:
            best_target = tgt_acc
            torch.save(model.state_dict(), "best_irm_perdomain.pt")
            print("Saved best model by target acc.")


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Epoch 1/30  avg_loss 0.7881  avg_penalty 0.021335
Epoch 2/30  avg_loss 0.6584  avg_penalty 0.006975
art_painting Accuracy: 92.53%
cartoon Accuracy: 87.16%
photo Accuracy: 97.78%
sketch Accuracy: 72.16%
Eval --> Avg source acc: 92.49  Target acc: 72.16
Saved best model by target acc.
Epoch 3/30  avg_loss 0.6236  avg_penalty 0.005948
Epoch 4/30  avg_loss 0.5875  avg_penalty 0.005540
art_painting Accuracy: 94.73%
cartoon Accuracy: 94.54%
photo Accuracy: 98.98%
sketch Accuracy: 71.70%
Eval --> Avg source acc: 96.08  Target acc: 71.70
Epoch 5/30  avg_loss 0.5742  avg_penalty 0.004448
Epoch 6/30  avg_loss 0.5612  avg_penalty 0.003616
art_painting Accuracy: 96.29%
cartoon Accuracy: 96.08%
photo Accuracy: 99.10%
sketch Accuracy: 78.19%
Eval --> Avg source acc: 97.16  Target acc: 78.19
Saved best model by target acc.
Epoch 7/30  avg_loss 0.5477  avg_penalty 0.003217
Epoch 8/30  avg_loss 0.5343  avg_penalty 0.002638
art_painting Accuracy: 96.29%
cartoon Accuracy: 97.23%
photo Accuracy: 99.28%
sk

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
import numpy as np
from datasets import load_dataset
from tqdm import tqdm

# ---- Load PACS Dataset ----
dataset = load_dataset("flwrlabs/pacs")
print(dataset)

# ---- Settings ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
domains = ["art_painting", "cartoon", "photo"]
target_domain = "sketch"
batch_size = 32
total_epochs = 30
warmup_epochs = 5
group_eta = 0.01  # step size of the group-weight update; deliberately small

# ---- Strong Data Augmentation ----
# Random flip / rotation / colour jitter / grayscale, followed by tensor
# conversion and per-channel normalization.
_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomGrayscale(p=0.3),  # intended to help with the sketch domain
]
transform = transforms.Compose(
    [transforms.Resize((224, 224))]
    + _augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
)
del _augmentations  # scratch list only; keep the notebook namespace unchanged

# ---- PyTorch Dataset Wrapper ----
class PACSDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over one split of the PACS data.

    Each item is an `(image, label)` pair built from the row's "image" and
    "label" fields; the image is converted to RGB before the optional
    transform is applied.

    Args:
        hf_dataset: Indexable, sized collection whose rows are mappings with an
            "image" entry (an object exposing `.convert("RGB")`) and a "label"
            entry.
        transform: Optional callable applied to the RGB image.
    """

    def __init__(self, hf_dataset, transform=None):
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Fetch the row once. Indexing `self.dataset[idx]` separately for the
        # image and the label reads the whole row twice (for Hugging Face image
        # columns that presumably also means decoding the image twice).
        row = self.dataset[idx]
        img = row["image"].convert("RGB")
        label = row["label"]
        # Explicit None check so any supplied callable is honoured.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# ---- Prepare loaders ----
# One shuffled, augmented loader per source domain; incomplete final batches are dropped.
domain_loaders = {}
for d in domains:
    ds = dataset["train"].filter(lambda x: x["domain"] == d)
    domain_loaders[d] = DataLoader(
        PACSDataset(ds, transform=transform),
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=2,
        pin_memory=True
    )

# Deterministic preprocessing for held-out data: same resize and normalization
# as `transform`, but without its random augmentations, so target batches are
# identical from run to run.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

target_ds = dataset["train"].filter(lambda x: x["domain"] == target_domain)
target_loader = DataLoader(PACSDataset(target_ds, transform=eval_transform),
                          batch_size=batch_size, shuffle=False)

# Steps per epoch = batch count of the largest source domain.
max_batches = max(len(loader) for loader in domain_loaders.values())

# ---- Model setup with regularization ----
num_classes = 7  # PACS has 7 classes
# NOTE(review): recent torchvision releases deprecate `pretrained=True` in favour
# of the `weights=` argument — confirm against the installed version.
model = models.resnet50(pretrained=True)
# Replace the classifier head: dropout followed by a linear layer over the 7 classes.
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, num_classes)
)
model = model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
# Cosine learning-rate schedule spanning the whole run (stepped once per epoch).
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_epochs)

# ---- Initialize group weights (log space for stability) ----
# All zeros in log space, i.e. uniform weights after normalization.
log_weights = np.zeros(len(domains), dtype=np.float32)
domain_idx = {d: i for i, d in enumerate(domains)}

def get_normalized_weights(log_w):
    """Turn log-space weights into a probability vector (softmax).

    The largest entry is subtracted before exponentiating so the exponentials
    cannot overflow; a tiny constant in the denominator guards the division.

    Args:
        log_w: Array-like of log weights.

    Returns:
        NumPy array of the same length whose entries sum to (approximately) 1.
    """
    shifted_exp = np.exp(log_w - np.max(log_w))
    normalizer = np.sum(shifted_exp) + 1e-12
    return shifted_exp / normalizer

# ---- Evaluation helpers ----
def eval_domain_acc(domain_name):
    """Return the model's top-1 accuracy (in percent) on one PACS domain.

    Uses the notebook-level `model`, `dataset`, `device`, `batch_size` and
    `PACSDataset`. Images go through a deterministic resize + normalize pipeline
    instead of the notebook-level `transform`: that pipeline applies random
    flips, rotations and colour changes, which would make the reported accuracy
    differ from call to call. Leaves `model` in eval mode.

    Args:
        domain_name: Value of the "domain" column to evaluate on.

    Returns:
        Accuracy in percent, as a float.

    Raises:
        ValueError: If no row of `dataset["train"]` belongs to `domain_name`.
    """
    # Same resize and normalization as the training pipeline, no random ops.
    eval_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    model.eval()
    ds = dataset["train"].filter(lambda x: x["domain"] == domain_name)
    if len(ds) == 0:
        # Otherwise an unknown domain only surfaces as a ZeroDivisionError below.
        raise ValueError(f"No samples found for domain {domain_name!r}")
    loader = DataLoader(PACSDataset(ds, transform=eval_transform),
                       batch_size=batch_size, shuffle=False)
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            # Predicted class = index of the largest logit.
            preds = model(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total

def eval_all_domains():
    """Evaluate every source domain, then the target domain.

    Returns:
        A pair `(source_accuracies, target_accuracy)`: the list follows the
        order of the notebook-level `domains`; both hold percentages as
        returned by `eval_domain_acc`.
    """
    source_accuracies = list(map(eval_domain_acc, domains))
    target_accuracy = eval_domain_acc(target_domain)
    return source_accuracies, target_accuracy

# ---- Training loop ----
best_tgt = -1.0
best_worst_src = -1.0
# Per-epoch metrics; every list gains one entry per completed epoch.
history = {
    "epoch": [],
    "avg_source": [],
    "worst_source": [],
    "target": [],
    "weights": [],
    "domain_losses": []
}

print("\nTraining Configuration:")
print(f"  Warmup Epochs: {warmup_epochs}")
print(f"  Total Epochs: {total_epochs}")
print(f"  Group DRO η: {group_eta}")
# NOTE: the next two values are literals; keep them in sync with the optimizer setup.
print("  Learning Rate: 1e-4")
print("  Weight Decay: 1e-3")
print("-" * 70)

for epoch in range(total_epochs):
    model.train()
    is_warmup = (epoch < warmup_epochs)

    # Fresh iterator per source domain at the start of every epoch.
    iters = {d: iter(domain_loaders[d]) for d in domains}
    epoch_loss = 0.0
    epoch_domain_losses = [0.0] * len(domains)
    epoch_steps = 0

    progress_bar = tqdm(range(max_batches), desc=f"Epoch {epoch+1}/{total_epochs}")

    for step in progress_bar:
        losses = []

        # Collect one batch loss from each domain
        for i, d in enumerate(domains):
            try:
                imgs, labels = next(iters[d])
            except StopIteration:
                # This domain has fewer batches than the largest one: start it over.
                iters[d] = iter(domain_loaders[d])
                imgs, labels = next(iters[d])

            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss_e = criterion(outputs, labels)
            losses.append(loss_e)
            epoch_domain_losses[i] += loss_e.item()

        # Compute the step objective
        if is_warmup:
            # Uniform weighting during warmup
            step_loss = sum(losses) / len(losses)
        else:
            # Weight each domain's loss by its current normalized group weight.
            # The weights enter as constants, so gradients flow only into the model.
            p = get_normalized_weights(log_weights)
            p_tensor = torch.tensor(p, dtype=torch.float32, device=device)
            stacked = torch.stack(losses)
            step_loss = torch.dot(p_tensor, stacked)

        # Optimization step (gradient norm clipped to 1.0)
        optimizer.zero_grad()
        step_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Group-weight update, applied after the model step. This is pure NumPy,
        # so no autograd context manager is needed.
        if not is_warmup:
            loss_vals = np.array([l.item() for l in losses], dtype=np.float32)
            # Exponentiated-gradient ascent in log space:
            # log_w_new = log_w_old + η * loss, so higher-loss domains gain weight.
            log_weights += group_eta * loss_vals

            # Small correction that raises low-probability domains more than
            # high-probability ones, nudging the weights back towards uniform.
            p_current = get_normalized_weights(log_weights)
            log_weights -= 0.001 * np.log(p_current + 1e-12)

        epoch_loss += step_loss.item()
        epoch_steps += 1

        # Update progress bar
        if not is_warmup:
            p_display = get_normalized_weights(log_weights)
            progress_bar.set_postfix({
                'loss': f'{step_loss.item():.4f}',
                'p_art': f'{p_display[0]:.2f}',
                'p_car': f'{p_display[1]:.2f}',
                'p_pho': f'{p_display[2]:.2f}'
            })

    scheduler.step()

    # Average losses over the steps of this epoch
    avg_epoch_loss = epoch_loss / epoch_steps
    avg_domain_losses = [l / epoch_steps for l in epoch_domain_losses]

    # Evaluation (switches the model to eval mode; `model.train()` above resets it)
    src_accs, tgt_acc = eval_all_domains()
    avg_src = sum(src_accs) / len(src_accs)
    worst_src = min(src_accs)

    # Weights to report: uniform during warmup, otherwise the current group weights
    if is_warmup:
        p = np.ones(len(domains)) / len(domains)
    else:
        p = get_normalized_weights(log_weights)

    # Store history
    history["epoch"].append(epoch + 1)
    history["avg_source"].append(avg_src)
    history["worst_source"].append(worst_src)
    history["target"].append(tgt_acc)
    history["weights"].append(p.copy())
    history["domain_losses"].append(avg_domain_losses)

    # Print results
    print(f"\nEpoch {epoch+1}/{total_epochs}")
    print(f"  Loss: {avg_epoch_loss:.4f}")
    print(f"  Source Accs: {[f'{a:.2f}%' for a in src_accs]}")
    print(f"  Avg Source: {avg_src:.2f}%  |  Worst Source: {worst_src:.2f}%")
    print(f"  Target: {tgt_acc:.2f}%")
    print(f"  Weights: {dict(zip(domains, [f'{p[i]:.3f}' for i in range(len(domains))]))}")
    print(f"  Domain Losses: {[f'{l:.4f}' for l in avg_domain_losses]}")

    # Save best models
    # NOTE(review): the first checkpoint is selected on target-domain accuracy, so
    # the held-out domain influences model selection — confirm this is intended.
    saved = False
    if tgt_acc > best_tgt:
        best_tgt = tgt_acc
        torch.save(model.state_dict(), "best_groupdro_target.pt")
        print(f"  ✅ Best target accuracy: {best_tgt:.2f}%")
        saved = True

    if worst_src > best_worst_src:
        best_worst_src = worst_src
        torch.save(model.state_dict(), "best_groupdro_worst.pt")
        # The checkpoint is always written; only the message is skipped when the
        # target-accuracy message was already printed for this epoch.
        if not saved:
            print(f"  ✅ Best worst-source accuracy: {best_worst_src:.2f}%")

    print("-" * 70)

# ---- Final Summary ----
print("\n" + "=" * 70)
print("TRAINING COMPLETE")
print("=" * 70)
print(f"Best Target Accuracy: {best_tgt:.2f}%")
print(f"Best Worst-Source Accuracy: {best_worst_src:.2f}%")
print("\nFinal weights evolution:")
# Show the recorded weights for every fifth epoch.
for i in range(0, len(history["epoch"]), 5):
    e = history["epoch"][i]
    w = history["weights"][i]
    print(f"  Epoch {e:2d}: {dict(zip(domains, [f'{w[j]:.3f}' for j in range(len(domains))]))}")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image', 'domain', 'label'],
        num_rows: 9991
    })
})


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]



Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 213MB/s]



Training Configuration:
  Warmup Epochs: 5
  Total Epochs: 30
  Group DRO η: 0.01
  Learning Rate: 1e-4
  Weight Decay: 1e-3
----------------------------------------------------------------------


Epoch 1/30: 100%|██████████| 73/73 [01:11<00:00,  1.02it/s]


Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]

Filter:   0%|          | 0/9991 [00:00<?, ? examples/s]


Epoch 1/30
  Loss: 0.8523
  Source Accs: ['93.07%', '92.96%', '98.02%']
  Avg Source: 94.68%  |  Worst Source: 92.96%
  Target: 70.20%
  Weights: {'art_painting': '0.333', 'cartoon': '0.333', 'photo': '0.333'}
  Domain Losses: ['0.9068', '0.9338', '0.7162']
  ✅ Best target accuracy: 70.20%
----------------------------------------------------------------------


Epoch 2/30: 100%|██████████| 73/73 [01:11<00:00,  1.02it/s]



Epoch 2/30
  Loss: 0.6023
  Source Accs: ['94.73%', '95.56%', '98.74%']
  Avg Source: 96.34%  |  Worst Source: 94.73%
  Target: 76.46%
  Weights: {'art_painting': '0.333', 'cartoon': '0.333', 'photo': '0.333'}
  Domain Losses: ['0.6374', '0.6313', '0.5381']
  ✅ Best target accuracy: 76.46%
----------------------------------------------------------------------


Epoch 3/30: 100%|██████████| 73/73 [01:11<00:00,  1.03it/s]



Epoch 3/30
  Loss: 0.5647
  Source Accs: ['97.41%', '96.25%', '99.34%']
  Avg Source: 97.67%  |  Worst Source: 96.25%
  Target: 75.62%
  Weights: {'art_painting': '0.333', 'cartoon': '0.333', 'photo': '0.333'}
  Domain Losses: ['0.5943', '0.5803', '0.5195']
  ✅ Best worst-source accuracy: 96.25%
----------------------------------------------------------------------


Epoch 4/30: 100%|██████████| 73/73 [01:10<00:00,  1.03it/s]



Epoch 4/30
  Loss: 0.5494
  Source Accs: ['96.58%', '96.93%', '99.40%']
  Avg Source: 97.64%  |  Worst Source: 96.58%
  Target: 79.38%
  Weights: {'art_painting': '0.333', 'cartoon': '0.333', 'photo': '0.333'}
  Domain Losses: ['0.5775', '0.5669', '0.5039']
  ✅ Best target accuracy: 79.38%
----------------------------------------------------------------------


Epoch 5/30: 100%|██████████| 73/73 [01:10<00:00,  1.04it/s]



Epoch 5/30
  Loss: 0.5340
  Source Accs: ['98.00%', '97.95%', '99.58%']
  Avg Source: 98.51%  |  Worst Source: 97.95%
  Target: 71.06%
  Weights: {'art_painting': '0.333', 'cartoon': '0.333', 'photo': '0.333'}
  Domain Losses: ['0.5498', '0.5560', '0.4961']
  ✅ Best worst-source accuracy: 97.95%
----------------------------------------------------------------------


Epoch 6/30: 100%|██████████| 73/73 [01:10<00:00,  1.03it/s, loss=0.5173, p_art=0.34, p_car=0.33, p_pho=0.33]



Epoch 6/30
  Loss: 0.5151
  Source Accs: ['97.56%', '96.54%', '99.16%']
  Avg Source: 97.75%  |  Worst Source: 96.54%
  Target: 76.18%
  Weights: {'art_painting': '0.338', 'cartoon': '0.335', 'photo': '0.327'}
  Domain Losses: ['0.5356', '0.5211', '0.4881']
----------------------------------------------------------------------


Epoch 7/30: 100%|██████████| 73/73 [01:10<00:00,  1.03it/s, loss=0.5520, p_art=0.34, p_car=0.34, p_pho=0.32]



Epoch 7/30
  Loss: 0.5161
  Source Accs: ['97.80%', '97.01%', '99.58%']
  Avg Source: 98.13%  |  Worst Source: 97.01%
  Target: 77.68%
  Weights: {'art_painting': '0.341', 'cartoon': '0.337', 'photo': '0.322'}
  Domain Losses: ['0.5299', '0.5255', '0.4920']
----------------------------------------------------------------------


Epoch 8/30: 100%|██████████| 73/73 [01:10<00:00,  1.03it/s, loss=0.5127, p_art=0.34, p_car=0.34, p_pho=0.32]



Epoch 8/30
  Loss: 0.5066
  Source Accs: ['98.44%', '98.42%', '99.46%']
  Avg Source: 98.77%  |  Worst Source: 98.42%
  Target: 75.64%
  Weights: {'art_painting': '0.344', 'cartoon': '0.339', 'photo': '0.317'}
  Domain Losses: ['0.5221', '0.5153', '0.4808']
  ✅ Best worst-source accuracy: 98.42%
----------------------------------------------------------------------


Epoch 9/30: 100%|██████████| 73/73 [01:11<00:00,  1.02it/s, loss=0.5114, p_art=0.35, p_car=0.34, p_pho=0.31]



Epoch 9/30
  Loss: 0.4991
  Source Accs: ['98.49%', '98.46%', '99.46%']
  Avg Source: 98.80%  |  Worst Source: 98.46%
  Target: 76.46%
  Weights: {'art_painting': '0.348', 'cartoon': '0.338', 'photo': '0.314'}
  Domain Losses: ['0.5161', '0.4971', '0.4827']
  ✅ Best worst-source accuracy: 98.46%
----------------------------------------------------------------------


Epoch 10/30: 100%|██████████| 73/73 [01:10<00:00,  1.03it/s, loss=0.4775, p_art=0.35, p_car=0.34, p_pho=0.31]



Epoch 10/30
  Loss: 0.4921
  Source Accs: ['98.68%', '98.63%', '99.76%']
  Avg Source: 99.03%  |  Worst Source: 98.63%
  Target: 79.05%
  Weights: {'art_painting': '0.348', 'cartoon': '0.339', 'photo': '0.313'}
  Domain Losses: ['0.4989', '0.4980', '0.4782']
  ✅ Best worst-source accuracy: 98.63%
----------------------------------------------------------------------


Epoch 11/30: 100%|██████████| 73/73 [01:10<00:00,  1.03it/s, loss=0.5005, p_art=0.35, p_car=0.34, p_pho=0.31]



Epoch 11/30
  Loss: 0.4865
  Source Accs: ['98.78%', '98.68%', '99.64%']
  Avg Source: 99.03%  |  Worst Source: 98.68%
  Target: 82.64%
  Weights: {'art_painting': '0.349', 'cartoon': '0.340', 'photo': '0.311'}
  Domain Losses: ['0.4934', '0.4937', '0.4711']
  ✅ Best target accuracy: 82.64%
----------------------------------------------------------------------


Epoch 12/30:  73%|███████▎  | 53/73 [00:52<00:19,  1.02it/s, loss=0.4650, p_art=0.35, p_car=0.34, p_pho=0.31]


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models
from datasets import load_dataset
from tqdm import tqdm
import copy

# ===========================
# SAM Optimizer Implementation
# ===========================
class SAM(torch.optim.Optimizer):
    """
    Sharpness-Aware Minimization (SAM) Optimizer

    Paper: "Sharpness-Aware Minimization for Efficiently Improving Generalization"
    https://arxiv.org/abs/2010.01412

    SAM seeks parameters that lie in neighborhoods having uniformly low loss,
    leading to flatter minima and better generalization.

    Usage per batch: backward pass, `first_step()`, second backward pass at the
    perturbed weights, `second_step()`.

    Args:
        params: Model parameters
        base_optimizer: Base optimizer class (e.g., torch.optim.SGD); it is
            instantiated here with `**kwargs`
        rho: Neighborhood size for perturbation (default: 0.05)
        adaptive: Use adaptive SAM (ASAM) - scale perturbation per parameter
        **kwargs: Forwarded to `base_optimizer` (lr, momentum, ...)
    """
    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        assert rho >= 0.0, f"Invalid rho: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        # Both optimizers share the same param_groups list, so changes made by a
        # scheduler attached to the base optimizer are visible here too.
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """
        First step: Compute adversarial perturbation and move to worst-case point
        Maximizes loss: θ_adv = θ + ε where ε = ρ * grad / ||grad||

        Parameters without a gradient are left untouched. If `zero_grad` is
        true, gradients are cleared afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue

                # Save original parameters so second_step() can restore them
                self.state[p]["old_p"] = p.data.clone()

                # Compute perturbation
                if group["adaptive"]:
                    # ASAM: scale by squared parameter magnitude
                    e_w = torch.pow(p, 2) * p.grad * scale.to(p)
                else:
                    # Standard SAM
                    e_w = p.grad * scale.to(p)

                # Move to adversarial point: θ_adv = θ + ε
                p.add_(e_w)

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """
        Second step: Update parameters based on gradient at adversarial point
        Returns to original point and applies actual update

        Expects `first_step()` to have been called for every parameter that
        currently has a gradient.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Return to original point
                p.data = self.state[p]["old_p"]

        # Apply base optimizer update with gradient from adversarial point
        self.base_optimizer.step()

        if zero_grad:
            self.zero_grad()

    def step(self, closure=None):
        """
        Single optimization step (not supported by this implementation)
        Use first_step() and second_step() instead

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("SAM requires calling first_step() and second_step() separately")

    def _grad_norm(self):
        """Compute the L2 norm of the gradient across all parameters.

        With `adaptive` enabled, each gradient is weighted element-wise by the
        absolute parameter value first. Returns a zero scalar tensor when no
        parameter has a gradient.
        """
        # Gather every per-parameter norm on one device before combining them
        shared_device = self.param_groups[0]["params"][0].device
        per_param_norms = [
            ((torch.abs(p) if group.get("adaptive", False) else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not per_param_norms:
            # torch.stack() rejects an empty list; without gradients the norm is zero
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(per_param_norms), p=2)

    def load_state_dict(self, state_dict):
        """Load optimizer state and re-link the base optimizer's param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups

    def zero_grad(self, set_to_none: bool = False):
        """Clear gradients through the base optimizer."""
        self.base_optimizer.zero_grad(set_to_none)


# ===========================
# Dataset Setup
# ===========================
print("Loading PACS dataset...")
dataset = load_dataset("flwrlabs/pacs")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}\n")

# Three source domains for training; "sketch" is held out.
target_domain = "sketch"
source_domains = ["art_painting", "cartoon", "photo"]

# Strong data augmentation for domain generalization: random flip, rotation,
# colour jitter, grayscale and (with probability 0.3) a Gaussian blur, followed
# by tensor conversion and per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.RandomGrayscale(p=0.3),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# PyTorch Dataset wrapper
class PACSDataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over an indexable collection of PACS rows.

    Each row must provide an ``"image"`` entry with a ``convert`` method
    (presumably a PIL image; confirm against the dataset features) and a
    ``"label"`` entry.
    """

    def __init__(self, hf_dataset, transform=None):
        """
        Args:
            hf_dataset: sized, integer-indexable collection of row mappings.
            transform: optional callable applied to the RGB-converted image.
        """
        self.dataset = hf_dataset
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is converted to RGB and, when a transform was given,
        passed through it.
        """
        # Look the row up once; indexing the wrapped dataset separately for the
        # image and the label would materialise the same row twice per sample.
        row = self.dataset[idx]
        img = row["image"].convert("RGB")
        label = row["label"]
        if self.transform:
            img = self.transform(img)
        return img, label

# Create per-domain loaders
# One shuffled training loader per source domain, using the augmenting `transform`.
domain_loaders = {}
for d in source_domains:
    ds = dataset["train"].filter(lambda x: x["domain"] == d)
    domain_loaders[d] = DataLoader(
        PACSDataset(ds, transform=transform),
        batch_size=32,
        shuffle=True,
        drop_last=True,  # keep every per-domain batch the same size
        num_workers=2,
        pin_memory=True
    )

# Deterministic preprocessing for the held-out domain: same resize and
# normalisation as training, but without the random augmentations, so that
# results computed from this loader are reproducible.
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

target_ds = dataset["train"].filter(lambda x: x["domain"] == target_domain)
target_loader = DataLoader(
    PACSDataset(target_ds, transform=eval_transform),
    batch_size=32,
    shuffle=False
)

# Steps per epoch follow the largest source domain; shorter loaders are
# restarted by the training loop when they run out.
max_batches = max(len(loader) for loader in domain_loaders.values())

# ===========================
# Model Setup
# ===========================
print("Setting up model...")
num_classes = 7
# ImageNet-pretrained ResNet-50. IMAGENET1K_V1 is the checkpoint that the
# deprecated `pretrained=True` flag selected, so the initial weights are unchanged.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
# Replace the classification head with dropout plus a linear layer over num_classes.
model.fc = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.fc.in_features, num_classes)
)
model = model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# ===========================
# SAM Optimizer Setup
# ===========================
# Base optimizer configuration
base_optimizer = torch.optim.SGD  # SAM works best with SGD
optimizer = SAM(
    model.parameters(),
    base_optimizer,
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
    rho=0.05,  # Perturbation radius (key SAM hyperparameter)
    adaptive=False  # Set True for ASAM variant
)

# Length of the run in epochs. The cosine schedule is stepped once per epoch,
# so its period must equal the number of epochs actually trained; otherwise
# the learning rate only travels part of the way down the cosine curve.
num_epochs = 15

# The schedule is attached to the wrapped base optimizer.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer.base_optimizer, T_max=num_epochs)

# ===========================
# Evaluation Function
# ===========================
def eval_domain_acc(domain_name, eval_transform=None):
    """Return the model's top-1 accuracy (in percent) on one PACS domain.

    The global ``model`` is switched to eval mode and is left in eval mode
    when the function returns.

    Args:
        domain_name: value of the ``"domain"`` column to evaluate on.
        eval_transform: optional preprocessing callable. Defaults to a
            deterministic resize + tensor conversion + normalisation pipeline,
            so accuracy is not perturbed by the random training augmentations
            and repeated evaluations of the same weights agree.

    Returns:
        float: ``100 * correct / total`` over every sample of the domain.
    """
    if eval_transform is None:
        eval_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    model.eval()
    ds = dataset["train"].filter(lambda x: x["domain"] == domain_name)
    loader = DataLoader(
        PACSDataset(ds, transform=eval_transform),
        batch_size=32,
        shuffle=False
    )
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs).argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100.0 * correct / total

def eval_all_domains():
    """Evaluate every source domain, then the target domain.

    Returns:
        tuple: ``(source_accuracies, target_accuracy)`` where the list follows
        the order of ``source_domains``.
    """
    per_source = []
    for name in source_domains:
        per_source.append(eval_domain_acc(name))
    held_out = eval_domain_acc(target_domain)
    return per_source, held_out

# ===========================
# SAM Training Loop
# ===========================
# Number of epochs for this run; defined before the banner so the printed
# configuration always matches what the loop actually does.
# NOTE(review): the cosine scheduler's T_max should equal this value.
num_epochs = 15

print(f"\nTraining Configuration:")
print(f"  Method: Sharpness-Aware Minimization (SAM)")
print(f"  Epochs: {num_epochs}")
print(f"  Base Optimizer: SGD")
print(f"  Learning Rate: 0.01")
print(f"  Weight Decay: 5e-4")
print(f"  SAM rho (ρ): 0.05")
print(f"  Label Smoothing: 0.1")
print("-" * 70)

# Best accuracies seen so far; -1.0 guarantees the first epoch is saved.
best_tgt = -1.0
best_worst_src = -1.0
# Per-epoch metrics, one entry appended to every list after each epoch.
history = {"epoch": [], "avg_source": [], "worst_source": [], "target": [], "train_loss": []}

for epoch in range(num_epochs):
    model.train()

    # Create iterators for each domain
    domain_iters = {d: iter(domain_loaders[d]) for d in source_domains}

    epoch_loss = 0.0
    epoch_correct = 0
    epoch_total = 0

    progress_bar = tqdm(range(max_batches), desc=f"Epoch {epoch+1}/{num_epochs}")

    for step in progress_bar:
        # Collect batch from each domain
        all_imgs = []
        all_labels = []

        for d in source_domains:
            try:
                imgs, labels = next(domain_iters[d])
            except StopIteration:
                # Smaller domains run out before max_batches; restart them so
                # every step still contains one batch from each domain.
                domain_iters[d] = iter(domain_loaders[d])
                imgs, labels = next(domain_iters[d])

            all_imgs.append(imgs)
            all_labels.append(labels)

        # Concatenate all domain batches
        imgs = torch.cat(all_imgs, dim=0).to(device)
        labels = torch.cat(all_labels, dim=0).to(device)

        # ===========================
        # SAM Two-Step Update
        # ===========================

        # Forward pass on the current mixed-domain batch; called once per SAM step.
        def closure():
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            return loss, outputs

        # Step 1: Compute gradient and move to adversarial point
        optimizer.zero_grad()
        loss, outputs = closure()
        loss.backward()
        optimizer.first_step(zero_grad=True)

        # Step 2: Compute gradient at adversarial point and update
        # NOTE(review): the model is in train mode for both passes, so the
        # BatchNorm running statistics are updated twice per step, the second
        # time at the perturbed weights. Consider freezing them for this pass.
        loss_adv, _ = closure()
        loss_adv.backward()
        optimizer.second_step(zero_grad=True)

        # Track metrics (from the first, unperturbed forward pass)
        with torch.no_grad():
            preds = outputs.argmax(dim=1)
            epoch_correct += (preds == labels).sum().item()
            epoch_total += labels.size(0)

        epoch_loss += loss.item()

        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.0 * epoch_correct / epoch_total:.2f}%'
        })

    # Learning rate scheduling (once per epoch)
    scheduler.step()

    # Calculate metrics
    avg_epoch_loss = epoch_loss / max_batches
    train_acc = 100.0 * epoch_correct / epoch_total

    # Evaluation
    src_accs, tgt_acc = eval_all_domains()
    avg_src = sum(src_accs) / len(src_accs)
    worst_src = min(src_accs)

    # Store history
    history["epoch"].append(epoch + 1)
    history["avg_source"].append(avg_src)
    history["worst_source"].append(worst_src)
    history["target"].append(tgt_acc)
    history["train_loss"].append(avg_epoch_loss)

    # Print results
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {avg_epoch_loss:.4f}  |  Train Acc: {train_acc:.2f}%")
    print(f"  Source Accs: {[f'{a:.2f}%' for a in src_accs]}")
    print(f"  Avg Source: {avg_src:.2f}%  |  Worst Source: {worst_src:.2f}%")
    print(f"  Target: {tgt_acc:.2f}%")

    # Save best models
    # NOTE(review): "best_sam_target.pt" is selected using accuracy on the
    # held-out target domain, so the reported best target accuracy is an
    # optimistic estimate; "best_sam_worst.pt" uses source domains only.
    saved = False
    if tgt_acc > best_tgt:
        best_tgt = tgt_acc
        torch.save(model.state_dict(), "best_sam_target.pt")
        print(f"  ✅ Best target accuracy: {best_tgt:.2f}%")
        saved = True

    if worst_src > best_worst_src:
        best_worst_src = worst_src
        torch.save(model.state_dict(), "best_sam_worst.pt")
        # Only announce this when the target message was not already printed.
        if not saved:
            print(f"  ✅ Best worst-source accuracy: {best_worst_src:.2f}%")

    print("-" * 70)

# ===========================
# Final Results
# ===========================
# Horizontal rules shared by the report sections below.
heavy_rule = "=" * 70
light_rule = "-" * 70

print("\n" + heavy_rule)
print("TRAINING COMPLETE - SAM RESULTS")
print(heavy_rule)
print(f"Best Target Accuracy: {best_tgt:.2f}%")
print(f"Best Worst-Source Accuracy: {best_worst_src:.2f}%")
print("\nTarget accuracy evolution:")
# Report every fifth recorded epoch, starting with the first one.
every_fifth = zip(history["epoch"][::5], history["target"][::5], history["avg_source"][::5])
for epoch_no, tgt_at_epoch, src_at_epoch in every_fifth:
    print(f"  Epoch {epoch_no:2d}: Target={tgt_at_epoch:.2f}%  Avg Source={src_at_epoch:.2f}%")

print("\n" + heavy_rule)
print("COMPARISON WITH BASELINES")
print(heavy_rule)
print("Method          | Target Acc | Notes")
print(light_rule)
# The baseline rows are fixed text; only the SAM row reflects this run.
print(f"ERM Baseline    | ~64%       | Your original result")
print(f"IRM             | 82.44%     | Best overall (invariant risk)")
print(f"Group DRO       | 78.21%     | Worst-case optimization")
print(f"SAM (this run)  | {best_tgt:.2f}%     | Flat minima seeking")
print(heavy_rule)

# Re-evaluate the checkpoint saved at the best target accuracy.
print("\nLoading best model for final verification...")
best_state = torch.load("best_sam_target.pt")
model.load_state_dict(best_state)
final_src_accs, final_tgt_acc = eval_all_domains()
print(f"\nFinal verification (best model):")
print(f"  Source: {[f'{a:.2f}%' for a in final_src_accs]}")
print(f"  Target: {final_tgt_acc:.2f}%")
print(f"  Generalization Gap: {sum(final_src_accs)/len(final_src_accs) - final_tgt_acc:.2f}%")

Loading PACS dataset...
Using device: cuda

Setting up model...

Training Configuration:
  Method: Sharpness-Aware Minimization (SAM)
  Epochs: 30
  Base Optimizer: SGD
  Learning Rate: 0.01
  Weight Decay: 5e-4
  SAM rho (ρ): 0.05
  Label Smoothing: 0.1
----------------------------------------------------------------------


Epoch 1/15: 100%|██████████| 73/73 [02:26<00:00,  2.01s/it, loss=0.6838, acc=79.31%]



Epoch 1/15
  Train Loss: 0.9287  |  Train Acc: 79.31%
  Source Accs: ['93.07%', '90.70%', '98.44%']
  Avg Source: 94.07%  |  Worst Source: 90.70%
  Target: 69.43%
  ✅ Best target accuracy: 69.43%
----------------------------------------------------------------------


Epoch 2/15: 100%|██████████| 73/73 [02:25<00:00,  1.99s/it, loss=0.5702, acc=94.81%]



Epoch 2/15
  Train Loss: 0.6127  |  Train Acc: 94.81%
  Source Accs: ['96.39%', '96.54%', '99.40%']
  Avg Source: 97.44%  |  Worst Source: 96.39%
  Target: 79.54%
  ✅ Best target accuracy: 79.54%
----------------------------------------------------------------------


Epoch 3/15: 100%|██████████| 73/73 [02:25<00:00,  1.99s/it, loss=0.5480, acc=96.83%]



Epoch 3/15
  Train Loss: 0.5600  |  Train Acc: 96.83%
  Source Accs: ['98.05%', '97.18%', '99.46%']
  Avg Source: 98.23%  |  Worst Source: 97.18%
  Target: 78.24%
  ✅ Best worst-source accuracy: 97.18%
----------------------------------------------------------------------


Epoch 4/15: 100%|██████████| 73/73 [02:26<00:00,  2.00s/it, loss=0.5321, acc=97.99%]



Epoch 4/15
  Train Loss: 0.5343  |  Train Acc: 97.99%
  Source Accs: ['98.68%', '98.76%', '99.70%']
  Avg Source: 99.05%  |  Worst Source: 98.68%
  Target: 80.25%
  ✅ Best target accuracy: 80.25%
----------------------------------------------------------------------


Epoch 5/15: 100%|██████████| 73/73 [02:25<00:00,  1.99s/it, loss=0.4966, acc=98.73%]



Epoch 5/15
  Train Loss: 0.5163  |  Train Acc: 98.73%
  Source Accs: ['98.88%', '99.40%', '99.76%']
  Avg Source: 99.35%  |  Worst Source: 98.88%
  Target: 81.60%
  ✅ Best target accuracy: 81.60%
----------------------------------------------------------------------


Epoch 6/15: 100%|██████████| 73/73 [02:25<00:00,  1.99s/it, loss=0.5111, acc=99.14%]



Epoch 6/15
  Train Loss: 0.5041  |  Train Acc: 99.14%
  Source Accs: ['99.22%', '99.36%', '99.94%']
  Avg Source: 99.51%  |  Worst Source: 99.22%
  Target: 82.36%
  ✅ Best target accuracy: 82.36%
----------------------------------------------------------------------


Epoch 7/15: 100%|██████████| 73/73 [02:26<00:00,  2.00s/it, loss=0.5019, acc=99.27%]



Epoch 7/15
  Train Loss: 0.4997  |  Train Acc: 99.27%
  Source Accs: ['99.61%', '99.74%', '100.00%']
  Avg Source: 99.78%  |  Worst Source: 99.61%
  Target: 80.99%
  ✅ Best worst-source accuracy: 99.61%
----------------------------------------------------------------------


Epoch 8/15:  66%|██████▌   | 48/73 [01:36<00:50,  2.02s/it, loss=0.4862, acc=99.57%]


KeyboardInterrupt: 