In [None]:
# =========================
# SECTION 1: SETUP & DATA
# =========================
import os, math, random, json, time, copy
from collections import defaultdict, Counter
import numpy as np
import pandas as pd
from PIL import Image
import shutil, pathlib, itertools
from pathlib import Path

from datasets import load_dataset
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset, Sampler
from torchvision import transforms, models

import cv2  # for Canny / Laplacian; pip install opencv-python-headless if needed
import matplotlib.pyplot as plt

# ---- Repro ----
SEED = 0
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# --- Collate that preserves meta as list-of-dicts ---
from torch.utils.data._utils.collate import default_collate
def collate_keep_meta(batch):
    xs, ys, metas = zip(*batch)
    return default_collate(xs), default_collate(ys), list(metas)

# --- A robust round-robin balanced batcher that cycles loaders ---
class BalancedDomainBatcher:
    def __init__(self, loaders_by_domain):
        self.loaders = loaders_by_domain
        self.domains = list(loaders_by_domain.keys())

    def __iter__(self):
        import torch
        iters = {d: iter(dl) for d, dl in self.loaders.items()}
        while True:  # infinite stream; your outer loop controls steps
            parts = []
            for d in self.domains:
                try:
                    b = next(iters[d])
                except StopIteration:
                    # re-create that domain's iterator and pull one batch
                    iters[d] = iter(self.loaders[d])
                    b = next(iters[d])
                parts.append(b)

            xs = torch.cat([p[0] for p in parts], dim=0)
            ys = torch.cat([p[1] for p in parts], dim=0)

            # Merge meta (list-of-dicts from our collate)
            metas = []
            for p in parts:
                m = p[2]
                if isinstance(m, dict):  # in case a default collate slipped through
                    # convert dict-of-lists -> list-of-dicts
                    keys = list(m.keys())
                    n = len(m[keys[0]])
                    for i in range(n):
                        metas.append({k: m[k][i] for k in keys})
                else:
                    metas.extend(m)
            yield xs, ys, metas



In [None]:
# ===== Drive + saving helpers =====
try:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_MOUNTED = True
except Exception as e:
    print("Colab Drive mount not available, saving locally under ./outputs")
    DRIVE_MOUNTED = False

OUTPUT_ROOT = "/content/drive/MyDrive/DG_PACS" if DRIVE_MOUNTED else "./outputs"
os.makedirs(OUTPUT_ROOT, exist_ok=True)

def ensure_dir(path):
    os.makedirs(path, exist_ok=True)
    return path

def outdir_for(method_name, target_domain):
    d = ensure_dir(os.path.join(OUTPUT_ROOT, method_name, target_domain))
    return d

def save_best_checkpoint(method_name, target_domain, backbone, head, best_ckpt, log_list):
    """
    Saves:
      - model checkpoint (backbone+head state_dict)
      - training log CSV (epochs with worst_group, avg, risk_var, etc.)
      - a plot: worst_group(Class×Domain) over epochs (source-val)
    """
    d = outdir_for(method_name, target_domain)

    # 1) save model
    ckpt_path = os.path.join(d, "best_model.pt")
    torch.save({
        "epoch": best_ckpt["epoch"],
        "backbone": backbone.state_dict(),
        "head": head.state_dict(),
        "val_metrics_cd": best_ckpt.get("val_metrics_cd", {}),
        "risk_var": best_ckpt.get("risk_var", None),
        "irm_pen": best_ckpt.get("irm_pen", None)
    }, ckpt_path)

    # 2) save log CSV
    log_df = pd.DataFrame(log_list)
    log_path = os.path.join(d, "train_log.csv")
    log_df.to_csv(log_path, index=False)

    # 3) plot worst-group curve and save
    if "worst_group_cd" in log_df.columns:
        plt.figure(figsize=(6,4))
        plt.plot(log_df["epoch"], log_df["worst_group_cd"], lw=2)
        plt.xlabel("Epoch"); plt.ylabel("Worst-Group Acc (CD) on source-val")
        plt.title(f"{method_name} — worst-group(CD) vs epoch — target={target_domain}")
        plt.grid(True); plt.tight_layout()
        plot_path = os.path.join(d, "worst_group_curve.png")
        plt.savefig(plot_path, dpi=150)
        plt.close()

    # 4) optional: IRM penalty curve
    if "irm_pen" in log_df.columns:
        plt.figure(figsize=(6,4))
        plt.plot(log_df["epoch"], log_df["irm_pen"], lw=2)
        plt.xlabel("Epoch"); plt.ylabel("IRM penalty (val)")
        plt.title(f"IRM penalty vs epoch — target={target_domain}")
        plt.grid(True); plt.tight_layout()
        plt.savefig(os.path.join(d, "irm_penalty_curve.png"), dpi=150)
        plt.close()

      # In save_best_checkpoint(...):
    if "worst_group_cds" in log_df.columns:
        plt.figure(figsize=(6,4))
        plt.plot(log_df["epoch"], log_df["worst_group_cds"], lw=2, color="tab:orange")
        plt.xlabel("Epoch"); plt.ylabel("Worst-Group Acc (CDS) on source-val")
        plt.title(f"CDS worst-group vs epoch — target={target_domain}")
        plt.grid(True); plt.tight_layout()
        plt.savefig(os.path.join(d, "worst_group_cds_curve.png"), dpi=150)
        plt.close()


    print(f"[saved] model/logs/plots -> {d}")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# ==== Get PACS only (Colab-safe) ====

# 0) Paths
DATA_DIR  = Path("/content/domainbed_data")
DRIVE_DIR = Path("/content/drive/MyDrive/DG_PACS")
PACS_DIR  = DATA_DIR / "PACS"
DATA_DIR.mkdir(parents=True, exist_ok=True)
DRIVE_DIR.mkdir(parents=True, exist_ok=True)

# 1) Minimal deps (ignore requirements.txt)
!pip -q install datasets pillow tqdm wilds gdown || true

# 2) Load PACS from Hugging Face and write to DomainBed layout
# Hugging Face mirror of PACS (domains: Art_painting, Cartoon, Photo, Sketch)
# Ref: flwrlabs/pacs dataset card
ds_all = load_dataset("flwrlabs/pacs")  # splits may include 'train', 'test', 'validation' (varies)

# Clean any partial previous attempt (optional)
if PACS_DIR.exists() and any(PACS_DIR.iterdir()):
    print("PACS already exists at", PACS_DIR, "- keeping it.")
else:
    for split in ds_all.keys():
        ds = ds_all[split]
        # Try to get human-readable names for labels/domains if provided
        label_feature  = ds.features.get("label", None)
        domain_feature = ds.features.get("domain", None)
        label_names  = (label_feature.names  if hasattr(label_feature, "names")  else None)
        domain_names = (domain_feature.names if hasattr(domain_feature, "names") else None)

        for i, row in enumerate(tqdm(ds, desc=f"Writing {split}")):
            img = row["image"]
            # Row fields typically: image (PIL), label (int), domain (int or str)
            lbl = row.get("label")
            dom = row.get("domain")

            # Resolve names
            if isinstance(lbl, int) and label_names:  cls_name = label_names[lbl]
            else:                                     cls_name = str(lbl)

            if isinstance(dom, int) and domain_names: dom_name = domain_names[dom]
            else:                                     dom_name = str(dom)

            # Normalize domain names to PACS canonical names
            canon = {"art_painting":"Art_painting", "art painting":"Art_painting",
                     "cartoon":"Cartoon", "photo":"Photo", "sketch":"Sketch"}
            dom_name = canon.get(str(dom_name).lower().replace("-", " ").replace("_", " "), dom_name)

            # Some cards use title case already; ensure directories exist
            out_dir = PACS_DIR / dom_name / cls_name
            out_dir.mkdir(parents=True, exist_ok=True)
            out_path = out_dir / f"{split}_{i:06d}.jpg"

            if isinstance(img, Image.Image):
                img.save(out_path, format="JPEG")
            else:
                Image.fromarray(img).save(out_path, format="JPEG")

    print("✅ PACS written to:", PACS_DIR)

# 3) (Optional) Symlink PACS into your Drive folder so your existing code sees it
pacs_dst = DRIVE_DIR / "PACS"
if not pacs_dst.exists():
    try:
        pacs_dst.symlink_to(PACS_DIR, target_is_directory=True)
        print("Symlinked PACS ->", pacs_dst)
    except Exception:
        # Fallback: copy if symlinks not permitted
        if not pacs_dst.exists():
            shutil.copytree(PACS_DIR, pacs_dst)
            print("Copied PACS ->", pacs_dst)

# 4) (Optional cleanup) TerraIncognita is huge; remove to save space
ti = DATA_DIR / "terra_incognita"
if ti.exists():
    print("You have TerraIncognita at", ti, " (~6.5GB). Remove it to free space? Set DO_REMOVE=True.")
    DO_REMOVE = False
    if DO_REMOVE:
        shutil.rmtree(ti)
        print("Removed TerraIncognita.")

print("DATA_DIR contents:", os.listdir(DATA_DIR))
print("DRIVE_DIR contents:", os.listdir(DRIVE_DIR))




[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/126.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m126.2/126.2 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/78.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.8/78.8 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/191M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/9991 [00:00<?, ? examples/s]

Writing train: 100%|██████████| 9991/9991 [00:21<00:00, 472.07it/s]


✅ PACS written to: /content/domainbed_data/PACS
DATA_DIR contents: ['PACS']
DRIVE_DIR contents: ['ERM', 'PACS', 'IRMv1', 'IRM_ERMonSketch_overall.png', 'ERM_2', 'IRM_vs_ERM_Sketch_CDS_top.png', 'IRM_vs_ERM_Sketch_CDS_bottom.png', 'IRM_vs_ERM_Sketch_CD_top.png', 'IRM_vs_ERM_Sketch_CD_bottom.png', 'IRM_vs_ERM_Sketch_stylebin_delta.png', 'IRM_vs_ERM_Sketch_CDS_deltas.csv', 'GroupDRO', 'Methods_on_Sketch_overall.png', 'SAM', 'Sketch_ERM2_vs_SAM_overall.png', 'SourceVal_domains_ERM2_vs_SAM.png', 'SAM_minus_ERM2_Sketch_CDS_deltas.csv', 'SAM_vs_ERM2_Sketch_CDS_top.png', 'SAM_vs_ERM2_Sketch_CDS_bottom.png', 'flatness_curve_ERM_2_Sketch.csv', 'flatness_curve_SAM_Sketch.csv', 'Flatness_source-val_Sketch_ERM2_vs_SAM.png', 'Flatness_target-test_Sketch_ERM2_vs_SAM.png', 'FB_ratio_GroupDRO_Sketch.png', 'FB_ratio_ERM_2_Sketch.png', 'CDS_DropHeatmap_ERM_2_Sketch.png', 'CDS_DropHeatmap_IRMv1_Sketch.png', 'CDS_DropHeatmap_GroupDRO_Sketch.png', 'CDS_DropHeatmap_SAM_Sketch.png', 'CDS_TopDrops_AllMethods_S

In [None]:

# ---- PACS paths ----
DATA_ROOT = "/content/drive/MyDrive/DG_PACS/PACS"  # change if needed
DOMAINS = ["Photo", "Art_painting", "Cartoon", "Sketch"]  # <-- match folder names exactly
# PACS classes commonly used (7 classes); your folder names should match:
# 'dog','elephant','giraffe','guitar','horse','house','person'
# Adjust if your folder names differ.

# ---- Transforms (same across methods) ----
IMG_SIZE = 224
train_tf = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

eval_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

# ---- PACS dataset ----
class PACS(Dataset):
    def __init__(self, root, domains, split_indices=None, transform=None, stylebin_map=None):
        """
        root: PACS root
        domains: list of domain names to include
        split_indices: optional list of (domain, class, idx) to pick exact files
        stylebin_map: dict mapping (domain, rel_path) -> stylebin ('low'|'mid'|'high') assigned beforehand
        """
        self.root = root
        self.transform = transform
        self.items = []  # list of dicts: {path, domain, cls_name, cls, rel_key}
        self.cls_to_idx = {}
        # discover classes from first domain present
        classes = None
        for d in domains:
            dpath = os.path.join(root, d)
            if os.path.isdir(dpath):
                classes = sorted([c for c in os.listdir(dpath) if os.path.isdir(os.path.join(dpath,c))])
                break
        assert classes is not None, "Could not find any domain folders in DATA_ROOT."
        self.cls_to_idx = {c:i for i,c in enumerate(classes)}

        if split_indices is None:
            for d in domains:
                dpath = os.path.join(root, d)
                for c in classes:
                    cpath = os.path.join(dpath, c)
                    if not os.path.isdir(cpath): continue
                    for fn in os.listdir(cpath):
                        if fn.lower().endswith((".jpg",".jpeg",".png",".bmp")):
                            rel_key = f"{d}/{c}/{fn}"
                            self.items.append({
                                "path": os.path.join(cpath, fn),
                                "domain": d,
                                "cls_name": c,
                                "cls": self.cls_to_idx[c],
                                "rel_key": rel_key
                            })
        else:
            # pick exact indices
            for (d,c,fn) in split_indices:
                rel_key = f"{d}/{c}/{fn}"
                self.items.append({
                    "path": os.path.join(root, d, c, fn),
                    "domain": d,
                    "cls_name": c,
                    "cls": self.cls_to_idx[c],
                    "rel_key": rel_key
                })
        self.stylebin_map = stylebin_map or {}

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        it = self.items[idx]
        img = Image.open(it["path"]).convert("RGB")
        x = self.transform(img) if self.transform else transforms.ToTensor()(img)
        y = it["cls"]
        meta = {
            "domain": it["domain"],
            "cls_name": it["cls_name"],
            "rel_key": it["rel_key"],
            "stylebin": self.stylebin_map.get(it["rel_key"], None),
        }
        return x, y, meta

# ---- Enumerate all files per domain/class for split bookkeeping ----
def list_files(root, domain):
    dpath = os.path.join(root, domain)
    classes = sorted([c for c in os.listdir(dpath) if os.path.isdir(os.path.join(dpath,c))])
    triplets = []
    for c in classes:
        for fn in os.listdir(os.path.join(dpath, c)):
            if fn.lower().endswith((".jpg",".jpeg",".png",".bmp")):
                triplets.append((domain, c, fn))
    return triplets, classes


NameError: name 'transforms' is not defined

In [None]:
# ---- StyleScore utilities (edge density, Laplacian variance, HSV saturation) ----
def style_score_image(path):
    im = cv2.imread(path)[:,:,::-1]  # to RGB
    im_s = cv2.resize(im, (256,256), interpolation=cv2.INTER_AREA)
    gray = cv2.cvtColor(im_s, cv2.COLOR_RGB2GRAY)

    # Edge density
    edges = cv2.Canny(gray, threshold1=100, threshold2=200)
    edge_density = float((edges>0).sum()) / edges.size

    # Laplacian variance (sharpness / texture)
    lap = cv2.Laplacian(gray, cv2.CV_64F)
    lap_var = float(lap.var())

    # Saturation mean
    hsv = cv2.cvtColor(im_s, cv2.COLOR_RGB2HSV)
    sat_mean = float(hsv[:,:,1].mean())/255.0

    return np.array([edge_density, lap_var, sat_mean], dtype=np.float32)

def compute_stylebins_per_domain(root, source_domains, train_indices):
    """
    From source-train only, compute per-domain feature mean/std for (edge_density, lap_var, sat_mean),
    then z-score and form StyleScore = mean of z-scores. Return:
      - stylebin_map for train items
      - thresholds: per-domain dict with (feat_mean, feat_std, q1, q2)
    """
    by_domain = defaultdict(list)  # domain -> list of (rel_key, f3)
    for (d,c,fn) in train_indices:
        rel_key = f"{d}/{c}/{fn}"
        path = os.path.join(root, d, c, fn)
        by_domain[d].append((rel_key, style_score_image(path)))

    stylebin_map = {}
    thresholds = {}
    for d, lst in by_domain.items():
        F = np.stack([f for _,f in lst], axis=0)             # Nx3 raw features
        mu = F.mean(0);  sd = F.std(0) + 1e-6
        Z  = (F - mu) / sd
        score = Z.mean(1)                                    # StyleScore (z-mean)
        q1, q2 = np.quantile(score, [0.33, 0.66])
        thresholds[d] = {"mu": mu.tolist(), "sd": sd.tolist(), "q1": float(q1), "q2": float(q2)}
        for (rel_key, _), s in zip(lst, score):
            sb = 'low' if s < q1 else ('mid' if s < q2 else 'high')
            stylebin_map[rel_key] = sb
    return stylebin_map, thresholds

def apply_stylebins_to_indices(root, indices, thresholds):
    """Assign StyleBin to any split using per-domain (mu, sd, q1, q2) learned from source-train.
       If a domain is missing (e.g., the target), compute its mu/sd and tertiles from the provided indices
       so analysis on test still has low/mid/high bins. This does NOT affect training/selection.
    """
    stylebin_map = {}
    # group indices by domain
    by_dom = defaultdict(list)
    for (d,c,fn) in indices:
        by_dom[d].append((d,c,fn))

    # build (mu, sd, q1, q2) for domains not in thresholds
    local_stats = {}
    for d, lst in by_dom.items():
        if d in thresholds:
            continue
        F = []
        for _,c,fn in lst:
            path = os.path.join(root, d, c, fn)
            F.append(style_score_image(path))
        F = np.stack(F, axis=0) if len(F) else np.zeros((1,3), np.float32)
        mu = F.mean(0); sd = F.std(0) + 1e-6
        Z  = (F - mu) / sd
        score = Z.mean(1)
        q1, q2 = np.quantile(score, [0.33, 0.66])
        local_stats[d] = {"mu": mu, "sd": sd, "q1": float(q1), "q2": float(q2)}

    # assign bins
    for (d,c,fn) in indices:
        rel_key = f"{d}/{c}/{fn}"
        path = os.path.join(root, d, c, fn)
        f3 = style_score_image(path)
        stats = thresholds.get(d, None)
        if stats is None:
            stats = local_stats[d]  # use local (test-only) stats
        mu = np.array(stats["mu"], dtype=np.float32)
        sd = np.array(stats["sd"], dtype=np.float32)
        z  = (f3 - mu) / sd
        score = z.mean()
        q1, q2 = stats["q1"], stats["q2"]
        sb = 'low' if score < q1 else ('mid' if score < q2 else 'high')
        stylebin_map[rel_key] = sb
    return stylebin_map


@torch.no_grad()
def stylebin_bars(model, loader, by_domain=True):
    """Return accuracy per (Domain, StyleBin) and optionally per (Class,Domain,StyleBin)."""
    backbone, head = model["backbone"], model["head"]
    backbone.eval(); head.eval()
    counts = defaultdict(lambda: [0,0])     # key -> [correct,total]
    counts_cds = defaultdict(lambda: [0,0]) # (class,domain,stylebin)

    for x,y,meta in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = forward_logits(backbone, head, x)
        pred = logits.argmax(1).cpu().numpy()
        ycpu = y.cpu().numpy()
        for i in range(len(y)):
            d  = meta[i]["domain"]
            sb = meta[i].get("stylebin", None)
            c  = meta[i]["cls_name"]
            if sb is None: continue
            counts[(d,sb)][1] += 1
            counts[(d,sb)][0] += int(pred[i]==ycpu[i])
            counts_cds[(c,d,sb)][1] += 1
            counts_cds[(c,d,sb)][0] += int(pred[i]==ycpu[i])

    acc_dom_sb  = {(d,sb): c/max(1,t) for (d,sb),(c,t) in counts.items()}
    acc_cds     = {(c,d,sb): c/max(1,t) for (c,d,sb),(c,t) in counts_cds.items()}
    # worst over CDS (for robustness diagnosis)
    worst_cds = min(acc_cds.values()) if acc_cds else 0.0
    return acc_dom_sb, acc_cds, worst_cds


# ---- LODO split with source-val from source only (IIDAccuracySelectionMethod) ----
def make_lodo_splits(root, target_domain, val_frac=0.1):
    source_domains = [d for d in DOMAINS if d != target_domain]
    # list all triplets for sources and target
    src_triplets = []
    classes_ref = None
    for d in source_domains:
        trips, classes = list_files(root, d)
        if classes_ref is None: classes_ref = classes
        src_triplets += trips
    tgt_triplets, _ = list_files(root, target_domain)

    # Stratified split by (domain, class)
    rng = np.random.RandomState(SEED)
    src_df = pd.DataFrame(src_triplets, columns=["domain","cls","fn"])
    train_idx, val_idx = [], []
    for (d,c), grp in src_df.groupby(["domain","cls"]):
        idxs = grp.index.to_list()
        rng.shuffle(idxs)
        n_val = max(1, int(len(idxs)*val_frac))
        val_ids = idxs[:n_val]
        train_ids = idxs[n_val:]
        for i in train_ids: train_idx.append(tuple(src_df.loc[i]))
        for i in val_ids:   val_idx.append(tuple(src_df.loc[i]))
    # Target full test set
    test_idx = tgt_triplets
    return train_idx, val_idx, test_idx, classes_ref, source_domains

# ---- Balanced per-domain batcher ----
class BalancedDomainBatcher:
    """Yields batches with equal-sized chunks from each source domain each step."""
    def __init__(self, loaders_by_domain):
        self.loaders = loaders_by_domain
        self.iters = {d: iter(dl) for d,dl in loaders_by_domain.items()}
        self.domains = list(loaders_by_domain.keys())

    def __iter__(self):
        while True:
            batch_x, batch_y, batch_meta = [], [], []
            ended = 0
            for d in self.domains:
                try:
                    x,y,m = next(self.iters[d])
                except StopIteration:
                    ended += 1
                    return
                batch_x.append(x); batch_y.append(y); batch_meta += m  # meta is list already
            yield torch.cat(batch_x,0), torch.cat(batch_y,0), batch_meta


In [None]:
# ---- Model: ResNet-18 backbone + linear head; BN frozen ----
def build_model(n_classes):
    res = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
    # Freeze BN running stats: set to eval; keep affine trainable (default)
    def freeze_bn(m):
        if isinstance(m, nn.BatchNorm2d):
            m.eval()
            m.track_running_stats = False
    res.apply(freeze_bn)
    # Split backbone & head
    in_feat = res.fc.in_features
    backbone = nn.Sequential(*list(res.children())[:-1])  # till pool (N,512,1,1)
    head = nn.Linear(in_feat, n_classes)
    return backbone, head

def forward_logits(backbone, head, x):
    feats = backbone(x).flatten(1)
    return head(feats), feats

In [None]:
# ---- Metrics helpers ----
@torch.no_grad()
def evaluate(model, loader, classes, group_key="CD"):
    backbone, head = model["backbone"], model["head"]
    backbone.eval(); head.eval()
    correct = 0; total = 0
    per_domain = defaultdict(lambda: [0,0])  # d -> [correct,total]
    per_group = defaultdict(lambda: [0,0])   # groupid -> [correct,total]
    for x,y,meta in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        logits, feats = forward_logits(backbone, head, x)
        pred = logits.argmax(1)
        correct += (pred==y).sum().item()
        total   += y.numel()
        for i in range(len(y)):
            d = meta[i]["domain"]
            per_domain[d][1]+=1
            per_domain[d][0]+= int(pred[i].item()==y[i].item())
            if group_key=="CD":
                gid = f"{meta[i]['cls_name']}|{d}"
            else:
                sb = meta[i].get("stylebin", None)
                gid = f"{meta[i]['cls_name']}|{d}|{sb}"
            per_group[gid][1]+=1
            per_group[gid][0]+= int(pred[i].item()==y[i].item())
    avg = correct/max(1,total)
    per_domain_acc = {d: c/max(1,t) for d,(c,t) in per_domain.items()}
    per_group_acc  = {g: c/max(1,t) for g,(c,t) in per_group.items()}
    worst_group = min(per_group_acc.values()) if per_group_acc else 0.0
    avg_minus_worst = (np.mean(list(per_group_acc.values()))-worst_group) if per_group_acc else 0.0
    return {
        "avg": avg,
        "per_domain": per_domain_acc,
        "per_group": per_group_acc,
        "worst_group": worst_group,
        "avg_minus_worst": avg_minus_worst
    }

def risk_variance_per_domain(model, loader):
    """Mean CE per domain + variance across domains (for source-val only)."""
    backbone, head = model["backbone"], model["head"]
    backbone.eval(); head.eval()
    losses = defaultdict(list)
    ce = nn.CrossEntropyLoss(reduction='none')
    with torch.no_grad():
        for x,y,meta in loader:
            x,y = x.to(DEVICE), y.to(DEVICE)
            logits, _ = forward_logits(backbone, head, x)
            l_all = ce(logits, y).detach().cpu().numpy()
            for i in range(len(y)):
                d = meta[i]["domain"]
                losses[d].append(l_all[i])
    means = {d: float(np.mean(v)) for d,v in losses.items() if len(v)>0}
    if len(means)>1:
        var = float(np.var(list(means.values())))
    else:
        var = 0.0
    return means, var

def irm_penalty_grad_norm(backbone, head, loader, wbar=None):
    """
    IRMv1 penalty: sum_e || grad_w CE_e |_{w=wbar} ||^2
    Evaluate on source-val loader that contains per-env structure (we build domain-balanced batches in train).
    """
    ce = nn.CrossEntropyLoss()
    backbone.eval(); head.eval()
    head_zero = copy.deepcopy(head)
    if wbar is not None:
        with torch.no_grad():
            head_zero.weight.copy_(wbar["weight"]); head_zero.bias.copy_(wbar["bias"])
    head_zero.requires_grad_(True)
    grads = defaultdict(lambda: torch.zeros_like(head_zero.weight))
    counts = defaultdict(int)
    for x,y,meta in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        x = x.detach(); y = y.detach()
        feats = backbone(x).flatten(1).detach()  # stop grad into backbone for penalty eval
        logits = feats @ head_zero.weight.T + head_zero.bias
        loss = ce(logits, y)
        # split by domain, accumulate gradient
        for d in set([m["domain"] for m in meta]):
            mask = torch.tensor([1 if m["domain"]==d else 0 for m in meta], dtype=torch.bool, device=DEVICE)
            if mask.sum()==0: continue
            l = ce(logits[mask], y[mask])
            gW, gB = torch.autograd.grad(l, [head_zero.weight, head_zero.bias], retain_graph=True)
            grads[d] = grads[d] + gW.detach()
            counts[d] += 1
    # norm squared sum
    s = 0.0
    for d,g in grads.items():
        s += float((g**2).sum().item())
    return s

# ========= CSV writers for per-target summaries =========

@torch.no_grad()
def collect_group_table(model, loader, group_key="CD"):
    """
    Returns a DataFrame with per-group counts and accuracy.
    group_key: "CD"  -> group = Class|Domain
               "CDS" -> group = Class|Domain|StyleBin
    """
    backbone, head = model["backbone"], model["head"]
    backbone.eval(); head.eval()
    rows = []
    for x,y,meta in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        logits, _ = forward_logits(backbone, head, x)
        pred = logits.argmax(1).cpu().numpy()
        ycpu = y.cpu().numpy()
        for i in range(len(y)):
            d   = meta[i]["domain"]
            cls = meta[i]["cls_name"]
            sb  = meta[i].get("stylebin", None)
            if group_key == "CD":
                gid = f"{cls}|{d}"
                rows.append((cls, d, None, int(pred[i]==ycpu[i])))
            else:
                gid = f"{cls}|{d}|{sb}"
                rows.append((cls, d, sb, int(pred[i]==ycpu[i])))
    # aggregate
    df = pd.DataFrame(rows, columns=["class","domain","stylebin","correct"])
    df["n"] = 1
    agg = df.groupby(["class","domain","stylebin"], dropna=False)[["correct","n"]].sum().reset_index()
    agg["acc"] = agg["correct"] / agg["n"].clip(lower=1)
    return agg


def save_target_csvs(method_name, target_domain, model, dl_test,
                     test_metrics_cd, test_metrics_cds,
                     acc_dom_sb, acc_cds):
    """
    Writes four CSVs into your Drive folder for this method/target:
      1) metrics.csv                -> overall & per-domain metrics, worst_group_cd/cds
      2) groups_cd.csv              -> Class×Domain table with counts & acc
      3) groups_cds.csv             -> Class×Domain×StyleBin table with counts & acc
      4) stylebin_domain.csv        -> (Domain, StyleBin) acc table
      (optional) stylebin_cds.csv   -> (Class, Domain, StyleBin) acc table
    """
    out_dir = outdir_for(method_name, target_domain)

    # 1) metrics.csv
    per_dom = test_metrics_cd["per_domain"]
    metrics_rows = [{"metric":"overall_avg", "value": test_metrics_cd["avg"]}]
    for d, a in per_dom.items():
        metrics_rows.append({"metric": f"acc_{d}", "value": a})
    metrics_rows.append({"metric": "worst_group_cd",  "value": test_metrics_cd["worst_group"]})
    metrics_rows.append({"metric": "worst_group_cds", "value": test_metrics_cds["worst_group"]})
    pd.DataFrame(metrics_rows).to_csv(os.path.join(out_dir, "metrics.csv"), index=False)

    # 2) groups_cd.csv (Class×Domain)
    df_cd = collect_group_table(model, dl_test, group_key="CD")
    df_cd.to_csv(os.path.join(out_dir, "groups_cd.csv"), index=False)

    # 3) groups_cds.csv (Class×Domain×StyleBin)
    df_cds = collect_group_table(model, dl_test, group_key="CDS")
    # enforce bin order if you like
    if "stylebin" in df_cds.columns:
        order = ["low","mid","high"]
        df_cds["stylebin"] = pd.Categorical(df_cds["stylebin"], categories=order, ordered=True)
        df_cds = df_cds.sort_values(["domain","class","stylebin"])
    df_cds.to_csv(os.path.join(out_dir, "groups_cds.csv"), index=False)

    # 4) stylebin_domain.csv ((Domain, StyleBin) -> acc)
    rows_dom_sb = [{"domain": d, "stylebin": sb, "acc": acc}
                   for (d,sb), acc in acc_dom_sb.items()]
    df_dom_sb = pd.DataFrame(rows_dom_sb)
    if not df_dom_sb.empty:
        order = ["low","mid","high"]
        df_dom_sb["stylebin"] = pd.Categorical(df_dom_sb["stylebin"], categories=order, ordered=True)
        df_dom_sb = df_dom_sb.sort_values(["domain","stylebin"])
    df_dom_sb.to_csv(os.path.join(out_dir, "stylebin_domain.csv"), index=False)

    # (optional) stylebin_cds.csv ((Class, Domain, StyleBin) -> acc)
    rows_cds = [{"class": c, "domain": d, "stylebin": sb, "acc": acc}
                for (c,d,sb), acc in acc_cds.items()]
    df_cds_acc = pd.DataFrame(rows_cds)
    if len(df_cds_acc):
        df_cds_acc["stylebin"] = pd.Categorical(df_cds_acc["stylebin"], categories=["low","mid","high"], ordered=True)
        df_cds_acc = df_cds_acc.sort_values(["domain","class","stylebin"])
        df_cds_acc.to_csv(os.path.join(out_dir, "stylebin_cds.csv"), index=False)

    print(f"[saved] per-target CSVs -> {out_dir}")


SECTION 2 — ERM (training + primary analysis)

What we’re doing here

1.   Train ERM with balanced per-domain batches on source-train
2. Select the checkpoint with max WorstGroup(Class×Domain) on source-val
3. Evaluate on held-out target (per-domain table, Avg, Worst-Target), and on source-val (group tables, worst-group, avg–worst gap)
4. Log risk variance on source-val (for comparison)





In [None]:
# =========================
# SECTION 2: ERM
# =========================
LR = 0.01
WD = 1e-4
MOM = 0.9
EPOCHS = 40
BATCH_PER_DOMAIN = 32  # effective batch = BATCH_PER_DOMAIN * #source_domains

def run_erm_for_target(target_domain):
    print(f"\n=== ERM :: target={target_domain} ===")
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, target_domain, val_frac=0.1)

    # StyleBin thresholds from source-train; then assign to val/test
    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_val  = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    style_map = {**style_map_train, **style_map_val, **style_map_test}


    # Build domain-specific train loaders (balanced batcher)
    loaders_by_domain = {}
    for d in source_domains:
        d_train = [(dd,cc,fn) for (dd,cc,fn) in train_idx if dd==d]
        ds = PACS(DATA_ROOT, [d], split_indices=d_train, transform=train_tf, stylebin_map=style_map)
        dl = DataLoader(ds, batch_size=BATCH_PER_DOMAIN, shuffle=True, num_workers=2,
                        pin_memory=True, drop_last=True, collate_fn=collate_keep_meta)
        loaders_by_domain[d] = dl

    # Print batch counts for sanity
    batch_sizes = {d: len(dl) for d, dl in loaders_by_domain.items()}
    print("[train] per-domain #batches:", batch_sizes)

    train_bal = BalancedDomainBatcher(loaders_by_domain)

    # source-val & target-test
    ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf, stylebin_map=style_map)
    ds_test = PACS(DATA_ROOT, [target_domain],  split_indices=test_idx, transform=eval_tf, stylebin_map=style_map)
    dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
    dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    # Steps per epoch: at least 1, bounded by the shortest domain (after drop_last)
    steps_per_epoch = min(len(dl) for dl in loaders_by_domain.values())
    if steps_per_epoch < 1:
        raise RuntimeError(
            f"0 steps/epoch: reduce BATCH_PER_DOMAIN (now {BATCH_PER_DOMAIN}) or set drop_last=False. "
            f"Per-domain batches: {batch_sizes}"
        )


    backbone, head = build_model(n_classes=len(classes))
    backbone, head = backbone.to(DEVICE), head.to(DEVICE)
    opt = torch.optim.SGD(list(backbone.parameters())+list(head.parameters()), lr=LR, momentum=MOM, weight_decay=WD)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS)

    ce = nn.CrossEntropyLoss()
    best_ckpt = None
    best_sel  = -1.0   # selection by worst-group (CD) on source-val
    log = []

    steps_per_epoch = min(len(dl) for dl in loaders_by_domain.values())
    for ep in range(EPOCHS):
        backbone.train(); head.train()
        it = iter(train_bal)
        for step in range(steps_per_epoch):
            x,y,meta = next(it)
            x,y = x.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            logits, feats = forward_logits(backbone, head, x)
            loss = ce(logits, y)
            loss.backward()
            opt.step()
        sched.step()

        # ---- selection & logging on source-val ----
        metrics_cd  = evaluate({"backbone":backbone, "head":head}, dl_val, classes, group_key="CD")
        metrics_cds = evaluate({"backbone":backbone, "head":head}, dl_val, classes, group_key="CDS")
        worst_cds   = metrics_cds["worst_group"]  # WorstGroup(Class×Domain×StyleBin)
        means, rvar = risk_variance_per_domain({"backbone":backbone, "head":head}, dl_val)

        sel = metrics_cd["worst_group"]  # primary criterion
        log.append({
            "epoch": ep,
            "worst_group_cd": sel,
            "worst_group_cds": worst_cds,   # NEW
            "avg": metrics_cd["avg"],
            "risk_var": rvar
        })
        if sel > best_sel:
            best_sel = sel
            best_ckpt = {
                "epoch": ep,
                "backbone": copy.deepcopy(backbone.state_dict()),
                "head": copy.deepcopy(head.state_dict()),
                "val_metrics_cd": metrics_cd,
                "risk_var": rvar}
            # (A) also store sd for returning if you want programmatic access
            best_backbone_sd = copy.deepcopy(backbone.state_dict())
            best_head_sd = copy.deepcopy(head.state_dict())

        if (ep+1)%5==0:
            print(f"[ERM][ep {ep+1:02d}] val worst_group(CD)={sel:.3f}  avg={metrics_cd['avg']:.3f}  rvar={rvar:.4f}")

    # ---- load best and evaluate on target ----
    backbone.load_state_dict(best_ckpt["backbone"]); head.load_state_dict(best_ckpt["head"])
    test_metrics_cd  = evaluate({"backbone":backbone, "head":head}, dl_test, classes, group_key="CD")
    test_metrics_cds = evaluate({"backbone":backbone, "head":head}, dl_test, classes, group_key="CDS")
    # (B) SAVE best checkpoint + log + curves to Drive
    save_best_checkpoint("ERM", target_domain, backbone, head, best_ckpt, log)

    # --- Style-bin robustness (per-domain bars) on TEST ---
    acc_dom_sb, acc_cds, worst_cds_test = stylebin_bars({"backbone":backbone, "head":head}, dl_test)

    # Save a per-domain bar chart
    out_dir = outdir_for("ERM", target_domain)  # ERM | IRMv1 | GroupDRO | ERM_SAM
    # Build a tidy df for plotting
    df_sb = pd.DataFrame([{"domain": d, "stylebin": sb, "acc": acc}
                          for (d,sb), acc in acc_dom_sb.items()])
    plt.figure(figsize=(7,4))

    order = ["low", "mid", "high"]
    df_sb["stylebin"] = pd.Categorical(df_sb["stylebin"], categories=order, ordered=True)
    for i, dom in enumerate(sorted(df_sb["domain"].unique())):
        sub = df_sb[df_sb["domain"]==dom].sort_values("stylebin")
        xs = np.arange(len(order))
        plt.bar(xs + i*0.28, sub["acc"].values, width=0.25, label=dom)
    plt.xticks(np.arange(len(order)) + 0.28, order)
    plt.ylabel("Accuracy"); plt.title(f"Style-bin robustness — target={target_domain}")
    plt.legend(); plt.tight_layout()
    plt.savefig(os.path.join(out_dir, "stylebin_robustness_test.png"), dpi=150)
    plt.close()

    # >>> NEW: save compact CSVs for this target <<<
    save_target_csvs(
        method_name="ERM",
        target_domain=target_domain,
        model={"backbone": backbone, "head": head},
        dl_test=dl_test,
        test_metrics_cd=test_metrics_cd,
        test_metrics_cds=test_metrics_cds,
        acc_dom_sb=acc_dom_sb,
        acc_cds=acc_cds
    )

    result = {
        "target": target_domain,
        "best_epoch": best_ckpt["epoch"],
        "val_worst_group_cd": best_ckpt["val_metrics_cd"]["worst_group"],
        "val_avg": best_ckpt["val_metrics_cd"]["avg"],
        "val_risk_var": best_ckpt["risk_var"],
        "test_avg": test_metrics_cd["avg"],
        "test_per_domain": test_metrics_cd["per_domain"],
        "test_worst_group_cd": test_metrics_cd["worst_group"],
        "test_worst_group_cds": test_metrics_cds["worst_group"],
        "log": log,
        "best_backbone_sd": best_backbone_sd,  # optional, handy later
        "best_head_sd": best_head_sd,
        "test_worst_group_cds": test_metrics_cds["worst_group"],
        "test_stylebin_dom_acc": acc_dom_sb,   # (domain,bin) -> acc
        "test_stylebin_cds_acc": acc_cds       # (class,domain,bin) -> acc
    }
    return result

# Run ERM for all 4 targets (single seed)
erm_results = []
for tgt in DOMAINS:
    erm_results.append(run_erm_for_target(tgt))

print("\n=== ERM summary (single seed) ===")
for r in erm_results:
    print(f"tgt={r['target']}: test_avg={r['test_avg']:.3f} worst_group_cd={r['test_worst_group_cd']:.3f}")



=== ERM :: target=Photo ===
[train] per-domain #batches: {'Art_painting': 57, 'Cartoon': 66, 'Sketch': 110}
[ERM][ep 05] val worst_group(CD)=0.062  avg=0.600  rvar=0.0375
[ERM][ep 10] val worst_group(CD)=0.125  avg=0.766  rvar=0.1553
[ERM][ep 15] val worst_group(CD)=0.114  avg=0.692  rvar=0.3810
[ERM][ep 20] val worst_group(CD)=0.000  avg=0.755  rvar=0.2188
[ERM][ep 25] val worst_group(CD)=0.188  avg=0.773  rvar=0.0120
[ERM][ep 30] val worst_group(CD)=0.188  avg=0.760  rvar=0.0153
[ERM][ep 35] val worst_group(CD)=0.188  avg=0.786  rvar=0.0343
[ERM][ep 40] val worst_group(CD)=0.188  avg=0.790  rvar=0.0357
[saved] model/logs/plots -> /content/drive/MyDrive/DG_PACS/ERM/Photo
[saved] per-target CSVs -> /content/drive/MyDrive/DG_PACS/ERM/Photo

=== ERM :: target=Art_painting ===
[train] per-domain #batches: {'Photo': 47, 'Cartoon': 66, 'Sketch': 110}
[ERM][ep 05] val worst_group(CD)=0.611  avg=0.910  rvar=0.0122
[ERM][ep 10] val worst_group(CD)=0.375  avg=0.803  rvar=0.0268
[ERM][ep 15] va

In [None]:
import os, glob, pandas as pd, numpy as np
import matplotlib.pyplot as plt

ROOT = "/content/drive/MyDrive/DG_PACS"
METHODS = ["ERM"]  # extend later: ["ERM","IRM","GroupDRO","ERM_SAM"]
TARGETS = ["Photo","Art_painting","Cartoon","Sketch"]

rows = []
cd_rows = []
cds_rows = []
sb_rows = []

for method in METHODS:
    for tgt in TARGETS:
        d = os.path.join(ROOT, method, tgt)
        mpath = os.path.join(d, "metrics.csv")
        if not os.path.isfile(mpath):
            continue
        M = pd.read_csv(mpath).set_index("metric")["value"].to_dict()
        rows.append({
            "method": method,
            "target": tgt,
            "avg": M.get("overall_avg", np.nan),
            "worst_cd": M.get("worst_group_cd", np.nan),
            "worst_cds": M.get("worst_group_cds", np.nan),
        })

        # per-class/domain on *target* domain (your groups_cd.csv)
        g_cd = os.path.join(d, "groups_cd.csv")
        if os.path.isfile(g_cd):
            G = pd.read_csv(g_cd)
            G["method"] = method; G["target"] = tgt
            cd_rows.append(G)

        # per-class/domain/stylebin on *target* domain
        g_cds = os.path.join(d, "groups_cds.csv")
        if os.path.isfile(g_cds):
            H = pd.read_csv(g_cds)
            H["method"] = method; H["target"] = tgt
            cds_rows.append(H)

        # domain×stylebin
        sbd = os.path.join(d, "stylebin_domain.csv")
        if os.path.isfile(sbd):
            S = pd.read_csv(sbd)
            S["method"] = method; S["target"] = tgt
            sb_rows.append(S)

summary = pd.DataFrame(rows).sort_values(["method","target"])
cd_df  = pd.concat(cd_rows,  ignore_index=True) if cd_rows  else pd.DataFrame()
cds_df = pd.concat(cds_rows, ignore_index=True) if cds_rows else pd.DataFrame()
sb_df  = pd.concat(sb_rows,  ignore_index=True) if sb_rows  else pd.DataFrame()

print("=== Overall by target ===")
print(summary)

# Identify worst CD and CDS groups per target/method
if not cd_df.empty:
    worst_cd = (cd_df.assign(group=lambda x: x["class"]+" @ "+x["domain"])
                      .sort_values("acc")
                      .groupby(["method","target"]).first()
                      .reset_index()[["method","target","group","acc"]])
    print("\n=== Worst CD group per target ===")
    print(worst_cd)

if not cds_df.empty:
    worst_cds = (cds_df.assign(group=lambda x: x["class"]+" @ "+x["domain"]+" / "+x["stylebin"].astype(str))
                       .sort_values("acc")
                       .groupby(["method","target"]).first()
                       .reset_index()[["method","target","group","acc"]])
    print("\n=== Worst CDS group per target ===")
    print(worst_cds)

# Optional small plots per target (avg vs worst_cd vs worst_cds)
for (method, tgt), T in summary.groupby(["method","target"]):
    plt.figure(figsize=(4,3))
    vals = [T["avg"].iloc[0], T["worst_cd"].iloc[0], T["worst_cds"].iloc[0]]
    plt.bar(["avg","worst_cd","worst_cds"], vals)
    plt.ylim(0,1); plt.title(f"{method} — {tgt}")
    plt.tight_layout()
    out = os.path.join(ROOT, method, tgt, "summary_bars.png")
    plt.savefig(out, dpi=140); plt.close()

# Optional: domain×stylebin bars on target
for (method, tgt), S in sb_df.groupby(["method","target"]):
    plt.figure(figsize=(5,3))
    S = S.sort_values("stylebin")
    plt.bar(S["stylebin"], S["acc"])
    plt.ylim(0,1); plt.title(f"{method} — stylebins on {tgt}")
    plt.tight_layout()
    out = os.path.join(ROOT, method, tgt, "stylebin_on_target.png")
    plt.savefig(out, dpi=140); plt.close()


=== Overall by target ===
  method        target       avg  worst_cd  worst_cds
1    ERM  Art_painting  0.638672  0.398664   0.240385
2    ERM       Cartoon  0.537543  0.195373   0.151724
0    ERM         Photo  0.839521  0.381910   0.287671
3    ERM        Sketch  0.671418  0.093750   0.083333

=== Worst CD group per target ===
  method        target                  group       acc
0    ERM  Art_painting  person @ Art_painting  0.398664
1    ERM       Cartoon          dog @ Cartoon  0.195373
2    ERM         Photo          horse @ Photo  0.381910
3    ERM        Sketch        person @ Sketch  0.093750

=== Worst CDS group per target ===
  method        target                      group       acc
0    ERM  Art_painting  dog @ Art_painting / high  0.240385
1    ERM       Cartoon       dog @ Cartoon / high  0.151724
2    ERM         Photo        horse @ Photo / mid  0.287671
3    ERM        Sketch      person @ Sketch / mid  0.083333


# **Section 2: IRMv1**

In [None]:
import os, copy, math, json, glob
from collections import defaultdict
import torch, torch.nn as nn, torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torchvision

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Keep meta as list-of-dicts everywhere
from torch.utils.data._utils.collate import default_collate
def collate_keep_meta(batch):
    xs, ys, metas = zip(*batch)
    return default_collate(xs), default_collate(ys), list(metas)

# --- Evaluate with optional CD/CDS group counts (needed for CSVs)
def evaluate(model, loader, classes, group_key="CD", return_groups=False):
    backbone, head = model["backbone"], model["head"]
    backbone.eval(); head.eval()

    total, correct = 0, 0
    per_domain_cnt = defaultdict(lambda: [0,0])
    cd_cnt  = defaultdict(lambda: [0,0])
    cds_cnt = defaultdict(lambda: [0,0])

    with torch.no_grad():
        for x, y, meta in loader:
            x = x.to(DEVICE); y = y.to(DEVICE)
            logits, _ = forward_logits(backbone, head, x)
            pred = logits.argmax(1)
            correct += (pred == y).sum().item()
            total   += y.numel()

            bs = y.size(0)
            for i in range(bs):
                m = meta[i] if isinstance(meta, list) else {k: meta[k][i] for k in meta}
                dname = m["domain"]
                cname = classes[y[i].item()]

                per_domain_cnt[dname][1]+=1
                per_domain_cnt[dname][0]+= int(pred[i].item()==y[i].item())

                cd_cnt[(cname, dname)][1]+=1
                cd_cnt[(cname, dname)][0]+= int(pred[i].item()==y[i].item())

                sb = m.get("stylebin", None)
                if sb is not None:
                    cds_cnt[(cname, dname, sb)][1]+=1
                    cds_cnt[(cname, dname, sb)][0]+= int(pred[i].item()==y[i].item())

    avg = correct / max(1, total)
    per_domain = {d: c/t for d,(c,t) in per_domain_cnt.items() if t>0}

    # worst-group defined by requested grouping
    pool = []
    if group_key == "CDS":
        pool = [(c,t) for (_, (c,t)) in cds_cnt.items()]
    else:
        pool = [(c,t) for (_, (c,t)) in cd_cnt.items()]
    group_accs = [c/t for (c,t) in pool if t>0]
    worst_group = min(group_accs) if group_accs else 0.0

    out = {"avg": avg, "per_domain": per_domain, "worst_group": worst_group}
    if return_groups:
        out["cd_counts"]  = dict(cd_cnt)
        out["cds_counts"] = dict(cds_cnt)
    return out

def build_model(n_classes):
    m = torchvision.models.resnet18(
        weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
    )
    in_dim = m.fc.in_features  # 512
    m.fc = nn.Identity()       # keep avgpool; outputs (B, 512)
    head = nn.Linear(in_dim, n_classes)
    return m, head

def forward_logits(backbone, head, x):
    feats = backbone(x)
    if feats.ndim > 2:                        # safety: if (B,C,H,W)
        feats = F.adaptive_avg_pool2d(feats, 1).view(feats.size(0), -1)
    logits = head(feats)
    return logits, feats


# --- CSV savers for source-val & target-test (CD/CDS, stylebins)
def save_source_val_csvs(method_name, target_domain, metrics_cd, metrics_cds):
    out_dir = outdir_for(method_name, target_domain)
    os.makedirs(out_dir, exist_ok=True)

    # per-domain (val on source domains)
    pd.DataFrame([
        {"domain": d, "acc": acc}
        for d, acc in sorted(metrics_cd["per_domain"].items())
    ]).to_csv(os.path.join(out_dir, "val_source_per_domain.csv"), index=False)

    # CD table
    rows = []
    for (c,d), (corr,tot) in metrics_cd["cd_counts"].items():
        if tot>0:
            rows.append({"class": c, "domain": d, "correct": corr, "total": tot, "acc": corr/tot})
    pd.DataFrame(rows).sort_values(["domain","class"]).to_csv(
        os.path.join(out_dir, "val_source_CD.csv"), index=False)

    # CDS table
    rows = []
    for (c,d,sb), (corr,tot) in metrics_cds["cds_counts"].items():
        if tot>0:
            rows.append({"class": c, "domain": d, "stylebin": sb, "correct": corr, "total": tot, "acc": corr/tot})
    if rows:
        pd.DataFrame(rows).sort_values(["domain","class","stylebin"]).to_csv(
            os.path.join(out_dir, "val_source_CDS.csv"), index=False)

def save_target_csvs(method_name, target_domain, model, dl_test, test_metrics_cd, test_metrics_cds, acc_dom_sb, acc_cds):
    d = outdir_for(method_name, target_domain)
    os.makedirs(d, exist_ok=True)

    # overall metrics file
    m = pd.DataFrame([
        {"metric": "overall_avg",     "value": float(test_metrics_cd["avg"])},
        {"metric": "worst_group_cd",  "value": float(test_metrics_cd["worst_group"])},
        {"metric": "worst_group_cds", "value": float(test_metrics_cds["worst_group"])},
    ])
    m.to_csv(os.path.join(d, "metrics.csv"), index=False)

    # CD groups on target
    rows = []
    for (c,dname), (corr,tot) in test_metrics_cd.get("cd_counts", {}).items():
        if tot>0:
            rows.append({"class": c, "domain": dname, "correct": corr, "total": tot, "acc": corr/tot})
    if rows:
        pd.DataFrame(rows).sort_values(["domain","class"]).to_csv(os.path.join(d, "groups_cd.csv"), index=False)

    # CDS groups on target
    rows = []
    for (c,dname,sb), (corr,tot) in test_metrics_cds.get("cds_counts", {}).items():
        if tot>0:
            rows.append({"class": c, "domain": dname, "stylebin": sb, "correct": corr, "total": tot, "acc": corr/tot})
    if rows:
        pd.DataFrame(rows).sort_values(["domain","class","stylebin"]).to_csv(os.path.join(d, "groups_cds.csv"), index=False)

    # stylebin per-domain on target (domain here is just the target)
    rows = []
    for (dname, sb), acc in acc_dom_sb.items():
        rows.append({"domain": dname, "stylebin": sb, "acc": acc})
    if rows:
        pd.DataFrame(rows).sort_values(["domain","stylebin"]).to_csv(os.path.join(d, "stylebin_domain.csv"), index=False)

    print(f"[saved] per-target CSVs -> {d}")

# --- tiny plotting util: per-domain curves from train_log.csv
def plot_val_per_domain_curves(method_name, target_domain):
    d = outdir_for(method_name, target_domain)
    log_path = os.path.join(d, "train_log.csv")
    if not os.path.isfile(log_path):
        return
    df = pd.read_csv(log_path)
    acc_cols = [c for c in df.columns if c.startswith("acc_")]  # acc_<Domain>
    if not acc_cols:
        return
    plt.figure(figsize=(6.5,4))
    for c in sorted(acc_cols):
        plt.plot(df["epoch"], df[c], label=c.replace("acc_",""))
    plt.xlabel("Epoch"); plt.ylabel("Val accuracy (per-domain)")
    plt.title(f"{method_name} — source-val per-domain — target={target_domain}")
    plt.grid(True); plt.legend(); plt.tight_layout()
    plt.savefig(os.path.join(d, "val_per_domain_curves.png"), dpi=150)
    plt.close()


In [None]:
# def irm_penalty_from_feats(feats, head, y, ce):
#     """
#     feats: [B, D]   (requires_grad=True from backbone)
#     head:  nn.Linear (we will use its *detached* weights for the penalty path)
#     y:     [B]
#     ce:    loss fn (e.g., CrossEntropyLoss with label smoothing, optional)

#     Implements IRMv1 penalty on a dummy scale 's' with head weights detached, so
#     gradients flow only to 'feats' (backbone), not head.
#     """
#     with torch.no_grad():
#         W = head.weight.clone()
#         b = head.bias.clone() if head.bias is not None else None

#     logits_pen = F.linear(feats, W.detach(), b.detach())
#     s = torch.tensor(1.0, device=feats.device, requires_grad=True)
#     loss = ce(s * logits_pen, y)
#     g = torch.autograd.grad(loss, s, create_graph=True)[0]
#     return g**2  # scalar

# penalty path (keep features-only gradient)
def irm_penalty_from_feats(feats, head, y, ce):
    if feats.ndim > 2:
        feats = F.adaptive_avg_pool2d(feats, 1).view(feats.size(0), -1)
    with torch.no_grad():
        W = head.weight.clone()
        b = head.bias.clone() if head.bias is not None else None
    logits_pen = F.linear(feats, W.detach(), b.detach())
    s = torch.tensor(1.0, device=feats.device, requires_grad=True)
    loss = ce(s * logits_pen, y)
    g = torch.autograd.grad(loss, s, create_graph=True)[0]
    return g**2


In [None]:
class DomainStepIterator:
    """Cycles per-domain loaders forever and yields a dict(domain -> (x,y,meta))."""
    def __init__(self, loaders_by_domain):
        self.loaders_by_domain = loaders_by_domain
        self.domains = list(loaders_by_domain.keys())
        self.iters = {d: iter(dl) for d,dl in loaders_by_domain.items()}

    def __iter__(self): return self

    def __next__(self):
        step = {}
        for d in self.domains:
            try:
                batch = next(self.iters[d])
            except StopIteration:
                self.iters[d] = iter(self.loaders_by_domain[d])
                batch = next(self.iters[d])
            step[d] = batch
        return step


In [None]:
# ====== IRMv1 config ======
IRM_LR = 0.005
IRM_WD = 1e-4
IRM_MOM = 0.9
IRM_EPOCHS = 40
IRM_BATCH_PER_DOMAIN = 32
IRM_WARMUP_EPOCHS = 5
IRM_LAMBDA_AFTER = 5.0
LABEL_SMOOTH = 0.05

def run_irmv1_for_target(target_domain="Sketch"):
    print(f"\n=== IRMv1 :: target={target_domain} ===")
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, target_domain, val_frac=0.1)

    # StyleBin thresholds from source-train; then assign to val/test
    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_val  = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    style_map = {**style_map_train, **style_map_val, **style_map_test}

    # --- Per-domain train loaders
    loaders_by_domain = {}
    for d in source_domains:
        d_train = [(dd,cc,fn) for (dd,cc,fn) in train_idx if dd==d]
        ds = PACS(DATA_ROOT, [d], split_indices=d_train, transform=train_tf, stylebin_map=style_map)
        dl = DataLoader(ds, batch_size=IRM_BATCH_PER_DOMAIN, shuffle=True, num_workers=2,
                        pin_memory=True, drop_last=True, collate_fn=collate_keep_meta)
        loaders_by_domain[d] = dl
    print("[train] per-domain #batches:", {d: len(dl) for d,dl in loaders_by_domain.items()})

    step_iter = DomainStepIterator(loaders_by_domain)

    # --- Val/Test loaders
    ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf, stylebin_map=style_map)
    ds_test = PACS(DATA_ROOT, [target_domain],  split_indices=test_idx, transform=eval_tf, stylebin_map=style_map)
    dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
    dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    # --- Model/opt/sched
    backbone, head = build_model(n_classes=len(classes))
    backbone, head = backbone.to(DEVICE), head.to(DEVICE)
    params = list(backbone.parameters()) + list(head.parameters())
    opt = torch.optim.SGD(params, lr=IRM_LR, momentum=IRM_MOM, weight_decay=IRM_WD)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=IRM_EPOCHS)
    ce = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

    best_ckpt, best_sel = None, -1.0
    log = []

    steps_per_epoch = min(len(dl) for dl in loaders_by_domain.values())
    if steps_per_epoch < 1:
        raise RuntimeError("0 steps/epoch — reduce IRM_BATCH_PER_DOMAIN or set drop_last=False")

    for ep in range(IRM_EPOCHS):
        backbone.train(); head.train()
        lmbd = 0.0 if ep < IRM_WARMUP_EPOCHS else IRM_LAMBDA_AFTER
        it = iter(step_iter)

        for _ in range(steps_per_epoch):
            opt.zero_grad()

            # Gather one mini-batch per source domain
            step = next(it)  # dict: d -> (x,y,meta)
            ce_terms = []
            pen_terms = []

            for d, (x,y,meta) in step.items():
                x = x.to(DEVICE); y = y.to(DEVICE)

                logits_ce, feats = forward_logits(backbone, head, x)   # ← use the safe path
                ce_terms.append(ce(logits_ce, y))

                # IRM penalty path (hits *features only*)
                pen = irm_penalty_from_feats(feats, head, y, ce)
                pen_terms.append(pen)

            loss_ce = torch.stack(ce_terms).mean()
            loss_pen = torch.stack(pen_terms).mean()
            loss = loss_ce + lmbd * loss_pen

            loss.backward()
            opt.step()

        sched.step()

        # ---- Validation on source domains
        metrics_cd  = evaluate({"backbone":backbone, "head":head}, dl_val, classes, group_key="CD",  return_groups=True)
        metrics_cds = evaluate({"backbone":backbone, "head":head}, dl_val, classes, group_key="CDS", return_groups=True)

        wc = float(metrics_cd["worst_group"])
        ws = float(metrics_cds["worst_group"])

        # coverage check for CDS (skip empty groups already in evaluate)
        cds_counts = metrics_cds.get("cds_counts", {})
        covered = sum(1 for _, (c,t) in cds_counts.items() if t > 0)
        total    = len(cds_counts)
        coverage_ok = (total > 0) and (covered / total >= 0.80)

        # choose selector
        if coverage_ok:
            # option A: weighted mix (safer)
            sel_now = 0.6 * wc + 0.4 * ws
            # option B (stricter on tails): geometric mean
            # sel_now = (wc * ws) ** 0.5
            # option C (very strict): min
            # sel_now = min(wc, ws)
        else:
            sel_now = wc  # fallback when CDS too sparse

        # EMA smoothing
        alpha = 0.8  # higher = smoother
        if ep == 0:
            sel_smoothed = sel_now
        else:
            sel_smoothed = alpha * sel_prev + (1 - alpha) * sel_now
        sel_prev = sel_smoothed

        # use sel_smoothed to pick best checkpoint
        sel = sel_smoothed

        # per-domain acc logging
        row = {
            "epoch": ep,
            "lambda": lmbd,
            "worst_group_cd": metrics_cd["worst_group"],
            "worst_group_cds": metrics_cds["worst_group"],
            "avg": metrics_cd["avg"],
        }
        # add per-domain (source-val) into the log row
        for dom, acc in metrics_cd["per_domain"].items():
            row[f"acc_{dom}"] = acc
        log.append(row)

        if sel > best_sel:
            best_sel = sel
            best_ckpt = {
                "epoch": ep,
                "backbone": copy.deepcopy(backbone.state_dict()),
                "head": copy.deepcopy(head.state_dict()),
                "val_metrics_cd": metrics_cd,
                "val_metrics_cds": metrics_cds
            }

        if (ep+1) % 5 == 0:
            pd_str = "  ".join([f"{k}={row.get(f'acc_{k}', float('nan')):.3f}" for k in sorted(metrics_cd["per_domain"].keys())])
            print(f"[IRMv1][ep {ep+1:02d}] λ={lmbd:.1f}  val worst_group(CD)={row['worst_group_cd']:.3f}  "
                  f"worst_group(CDS)={row['worst_group_cds']:.3f}  avg={row['avg']:.3f}  | {pd_str}")

    # ---- Load best, evaluate on target-test, save everything
    backbone.load_state_dict(best_ckpt["backbone"]); head.load_state_dict(best_ckpt["head"])

    test_metrics_cd  = evaluate({"backbone":backbone, "head":head}, dl_test, classes, group_key="CD",  return_groups=True)
    test_metrics_cds = evaluate({"backbone":backbone, "head":head}, dl_test, classes, group_key="CDS", return_groups=True)

    # You already have stylebin_bars(...) in ERM; reuse it here
    acc_dom_sb, acc_cds, worst_cds_test = stylebin_bars({"backbone":backbone, "head":head}, dl_test)

    # Save checkpoint + log + curves
    save_best_checkpoint("IRMv1", target_domain, backbone, head, best_ckpt, log)
    plot_val_per_domain_curves("IRMv1", target_domain)

    # Detailed CSVs
    save_source_val_csvs("IRMv1", target_domain, best_ckpt["val_metrics_cd"], best_ckpt["val_metrics_cds"])
    save_target_csvs("IRMv1", target_domain,
                     {"backbone":backbone,"head":head}, dl_test,
                     test_metrics_cd, test_metrics_cds,
                     acc_dom_sb, acc_cds)

    result = {
        "target": target_domain,
        "best_epoch": best_ckpt["epoch"],
        "val_worst_group_cd": best_ckpt["val_metrics_cd"]["worst_group"],
        "val_avg": best_ckpt["val_metrics_cd"]["avg"],
        "test_avg": test_metrics_cd["avg"],
        "test_worst_group_cd": test_metrics_cd["worst_group"],
        "test_worst_group_cds": test_metrics_cds["worst_group"],
    }
    return result

# === Run IRMv1 on Sketch as target ===
irm_result = run_irmv1_for_target("Sketch")
print("\n=== IRMv1 summary ===")
print(irm_result)



=== IRMv1 :: target=Sketch ===
[train] per-domain #batches: {'Photo': 47, 'Art_painting': 57, 'Cartoon': 66}
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:02<00:00, 22.1MB/s]


[IRMv1][ep 05] λ=0.0  val worst_group(CD)=0.844  worst_group(CDS)=0.750  avg=0.953  | Art_painting=0.915  Cartoon=0.957  Photo=0.994
[IRMv1][ep 10] λ=5.0  val worst_group(CD)=0.842  worst_group(CDS)=0.500  avg=0.956  | Art_painting=0.940  Cartoon=0.952  Photo=0.982
[IRMv1][ep 15] λ=5.0  val worst_group(CD)=0.850  worst_group(CDS)=0.500  avg=0.960  | Art_painting=0.925  Cartoon=0.970  Photo=0.988
[IRMv1][ep 20] λ=5.0  val worst_group(CD)=0.892  worst_group(CDS)=0.667  avg=0.963  | Art_painting=0.940  Cartoon=0.965  Photo=0.988
[IRMv1][ep 25] λ=5.0  val worst_group(CD)=0.865  worst_group(CDS)=0.500  avg=0.963  | Art_painting=0.935  Cartoon=0.970  Photo=0.988
[IRMv1][ep 30] λ=5.0  val worst_group(CD)=0.875  worst_group(CDS)=0.500  avg=0.963  | Art_painting=0.940  Cartoon=0.965  Photo=0.988
[IRMv1][ep 35] λ=5.0  val worst_group(CD)=0.842  worst_group(CDS)=0.500  avg=0.963  | Art_painting=0.945  Cartoon=0.965  Photo=0.982
[IRMv1][ep 40] λ=5.0  val worst_group(CD)=0.865  worst_group(CDS)=0.5

In [None]:
# === Run IRMv1 on cartoon as target ===
irm_result = run_irmv1_for_target("Cartoon")
print("\n=== IRMv1 summary ===")
print(irm_result)


=== IRMv1 :: target=Cartoon ===
[train] per-domain #batches: {'Photo': 47, 'Art_painting': 57, 'Sketch': 110}
[IRMv1][ep 05] λ=0.0  val worst_group(CD)=0.562  worst_group(CDS)=0.333  avg=0.914  | Art_painting=0.876  Photo=0.988  Sketch=0.903
[IRMv1][ep 10] λ=5.0  val worst_group(CD)=0.800  worst_group(CDS)=0.700  avg=0.946  | Art_painting=0.910  Photo=0.963  Sketch=0.957
[IRMv1][ep 15] λ=5.0  val worst_group(CD)=0.812  worst_group(CDS)=0.667  avg=0.938  | Art_painting=0.920  Photo=0.970  Sketch=0.934
[IRMv1][ep 20] λ=5.0  val worst_group(CD)=0.800  worst_group(CDS)=0.600  avg=0.946  | Art_painting=0.920  Photo=0.957  Sketch=0.954
[IRMv1][ep 25] λ=5.0  val worst_group(CD)=0.812  worst_group(CDS)=0.600  avg=0.950  | Art_painting=0.915  Photo=0.988  Sketch=0.951
[IRMv1][ep 30] λ=5.0  val worst_group(CD)=0.812  worst_group(CDS)=0.600  avg=0.962  | Art_painting=0.930  Photo=0.994  Sketch=0.964
[IRMv1][ep 35] λ=5.0  val worst_group(CD)=0.812  worst_group(CDS)=0.600  avg=0.956  | Art_paintin

Aggregation + insightful plots/tables (ERM vs IRMv1, target=Sketch)

In [None]:
ROOT = OUTPUT_ROOT  # e.g., "/content/drive/MyDrive/DG_PACS"
METHODS = ["ERM", "IRMv1"]
TARGET  = "Sketch"

def load_metrics(method, target):
    d = outdir_for(method, target)
    m = pd.read_csv(os.path.join(d, "metrics.csv")).set_index("metric")["value"].to_dict()
    return {"method":method, "target":target,
            "avg":m.get("overall_avg", np.nan),
            "worst_cd":m.get("worst_group_cd", np.nan),
            "worst_cds":m.get("worst_group_cds", np.nan)}

def load_df_if(path):
    return pd.read_csv(path) if os.path.isfile(path) else pd.DataFrame()

summary_rows = []
cd_rows, cds_rows, sb_rows = [], [], []

for method in METHODS:
    d = outdir_for(method, TARGET)
    # overall
    if os.path.isfile(os.path.join(d, "metrics.csv")):
        summary_rows.append(load_metrics(method, TARGET))
    # target CD
    G = load_df_if(os.path.join(d, "groups_cd.csv"))
    if not G.empty:
        G["method"] = method; G["target"] = TARGET; cd_rows.append(G)
    # target CDS
    H = load_df_if(os.path.join(d, "groups_cds.csv"))
    if not H.empty:
        H["method"] = method; H["target"] = TARGET; cds_rows.append(H)
    # stylebin on target
    S = load_df_if(os.path.join(d, "stylebin_domain.csv"))
    if not S.empty:
        S["method"] = method; S["target"] = TARGET; sb_rows.append(S)

summary = pd.DataFrame(summary_rows).sort_values("method")
cd_df  = pd.concat(cd_rows,  ignore_index=True) if cd_rows  else pd.DataFrame()
cds_df = pd.concat(cds_rows, ignore_index=True) if cds_rows else pd.DataFrame()
sb_df  = pd.concat(sb_rows,  ignore_index=True) if sb_rows  else pd.DataFrame()

print("=== Overall (target=Sketch) ===")
display(summary)

# Worst CD/CDS groups
if not cd_df.empty:
    worst_cd = (cd_df.assign(group=lambda x: x["class"]+" @ "+x["domain"])
                      .sort_values(["method","acc"])
                      .groupby("method").head(3))  # top-3 worst per method
    print("\n=== Worst CD groups (top-3 per method) ===")
    display(worst_cd[["method","group","acc","correct","total"]])

if not cds_df.empty:
    worst_cds = (cds_df.assign(group=lambda x: x["class"]+" @ "+x["domain"]+" / "+x["stylebin"].astype(str))
                       .sort_values(["method","acc"])
                       .groupby("method").head(5))  # top-5 worst per method
    print("\n=== Worst CDS groups (top-5 per method) ===")
    display(worst_cds[["method","group","acc","correct","total"]])

# Plots: avg vs worst_cd vs worst_cds (side-by-side)
plt.figure(figsize=(5,3))
x = np.arange(len(METHODS))
w = 0.25
plt.bar(x- w, summary["avg"],       width=w, label="avg")
plt.bar(x+0.0, summary["worst_cd"], width=w, label="worst_cd")
plt.bar(x+ w, summary["worst_cds"], width=w, label="worst_cds")
plt.xticks(x, summary["method"]); plt.ylim(0,1)
plt.title(f"Sketch target — overall vs tails")
plt.legend(); plt.tight_layout()
plt.savefig(os.path.join(ROOT, "IRM_ERMonSketch_overall.png"), dpi=160); plt.close()

# Plot: style-bin bars for each method on Sketch
for method, S in sb_df.groupby("method"):
    plt.figure(figsize=(4.2,3))
    S = S.sort_values("stylebin")
    plt.bar(S["stylebin"], S["acc"])
    plt.ylim(0,1); plt.title(f"{method} — style-bins on Sketch")
    plt.tight_layout()
    plt.savefig(os.path.join(outdir_for(method, TARGET), "stylebins_on_target.png"), dpi=150); plt.close()

print("Saved comparison figures to:", ROOT)


=== Overall (target=Sketch) ===


Unnamed: 0,method,target,avg,worst_cd,worst_cds
0,ERM,Sketch,0.671418,0.09375,0.083333
1,IRMv1,Sketch,0.782133,0.428756,0.391705



=== Worst CD groups (top-3 per method) ===


Unnamed: 0,method,group,acc,correct,total
6,ERM,person @ Sketch,0.09375,15,
0,ERM,dog @ Sketch,0.409326,316,
4,ERM,horse @ Sketch,0.583333,476,
7,IRMv1,dog @ Sketch,0.428756,331,772.0
13,IRMv1,person @ Sketch,0.65,104,160.0
9,IRMv1,giraffe @ Sketch,0.75166,566,753.0



=== Worst CDS groups (top-5 per method) ===


Unnamed: 0,method,group,acc,correct,total
19,ERM,person @ Sketch / mid,0.083333,3,
20,ERM,person @ Sketch / high,0.092784,9,
18,ERM,person @ Sketch / low,0.111111,3,
2,ERM,dog @ Sketch / high,0.373272,81,
1,ERM,dog @ Sketch / mid,0.407713,148,
21,IRMv1,dog @ Sketch / high,0.391705,85,217.0
23,IRMv1,dog @ Sketch / mid,0.435262,158,363.0
22,IRMv1,dog @ Sketch / low,0.458333,88,192.0
37,IRMv1,house @ Sketch / low,0.5,1,2.0
41,IRMv1,person @ Sketch / mid,0.638889,23,36.0


Saved comparison figures to: /content/drive/MyDrive/DG_PACS


Since it turns out i didnt train the ERM initially very fairly, with no label smoothing and smart selection criteria, it resulted in poor accuracy so we gonna try ERM again

In [None]:
# ===== ERM_2 (Sketch-only) =====
import copy, numpy as np, torch, torch.nn as nn
from torch.utils.data import DataLoader

ERM2_LR = 0.01
ERM2_WD = 1e-4
ERM2_MOM = 0.9
ERM2_EPOCHS = 40
ERM2_BATCH_PER_DOMAIN = 32
LABEL_SMOOTH = 0.05
TARGET_FIXED = "Sketch"

# Balanced batcher that cycles domain loaders (never StopIteration)
class BalancedDomainBatcher:
    def __init__(self, loaders_by_domain):
        self.loaders = loaders_by_domain
        self.domains = list(loaders_by_domain.keys())
    def __iter__(self):
        import torch
        iters = {d: iter(dl) for d, dl in self.loaders.items()}
        while True:
            parts = []
            for d in self.domains:
                try:
                    b = next(iters[d])
                except StopIteration:
                    iters[d] = iter(self.loaders[d])
                    b = next(iters[d])
                parts.append(b)
            xs = torch.cat([p[0] for p in parts], dim=0)
            ys = torch.cat([p[1] for p in parts], dim=0)
            metas = []
            for p in parts:
                m = p[2]
                if isinstance(m, dict):
                    keys = list(m.keys()); n = len(m[keys[0]])
                    for i in range(n):
                        metas.append({k: m[k][i] for k in keys})
                else:
                    metas.extend(m)
            yield xs, ys, metas

def run_erm2_for_target(target_domain=TARGET_FIXED):
    print(f"\n=== ERM_2 :: target={target_domain} ===")
    # --- splits & style bins
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, target_domain, val_frac=0.1)

    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_val  = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    style_map = {**style_map_train, **style_map_val, **style_map_test}

    # --- per-domain train loaders
    loaders_by_domain = {}
    for d in source_domains:
        d_train = [(dd,cc,fn) for (dd,cc,fn) in train_idx if dd==d]
        ds = PACS(DATA_ROOT, [d], split_indices=d_train, transform=train_tf, stylebin_map=style_map)
        dl = DataLoader(ds, batch_size=ERM2_BATCH_PER_DOMAIN, shuffle=True, num_workers=2,
                        pin_memory=True, drop_last=True, collate_fn=collate_keep_meta)
        loaders_by_domain[d] = dl
    print("[train] per-domain #batches:", {d: len(dl) for d,dl in loaders_by_domain.items()})

    train_bal = BalancedDomainBatcher(loaders_by_domain)

    # --- val/test
    ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf, stylebin_map=style_map)
    ds_test = PACS(DATA_ROOT, [target_domain],  split_indices=test_idx, transform=eval_tf, stylebin_map=style_map)
    dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
    dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    # --- model/opt/sched
    backbone, head = build_model(n_classes=len(classes))    # pretrained ResNet18 + Linear
    backbone, head = backbone.to(DEVICE), head.to(DEVICE)
    opt = torch.optim.SGD(list(backbone.parameters())+list(head.parameters()),
                          lr=ERM2_LR, momentum=ERM2_MOM, weight_decay=ERM2_WD)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=ERM2_EPOCHS)
    ce = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

    best_ckpt, best_sel = None, -1.0
    log = []
    sel_prev = None  # EMA state for selector

    steps_per_epoch = min(len(dl) for dl in loaders_by_domain.values())
    if steps_per_epoch < 1:
        raise RuntimeError("0 steps/epoch — reduce ERM2_BATCH_PER_DOMAIN or set drop_last=False")

    for ep in range(ERM2_EPOCHS):
        backbone.train(); head.train()
        it = iter(train_bal)
        for _ in range(steps_per_epoch):
            x,y,meta = next(it)
            x,y = x.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            logits, feats = forward_logits(backbone, head, x)
            loss = ce(logits, y)
            loss.backward()
            opt.step()
        sched.step()

        # ---- source-val evaluation
        metrics_cd  = evaluate({"backbone":backbone,"head":head}, dl_val, classes, group_key="CD",  return_groups=True)
        metrics_cds = evaluate({"backbone":backbone,"head":head}, dl_val, classes, group_key="CDS", return_groups=True)

        wc = float(metrics_cd["worst_group"])
        ws = float(metrics_cds["worst_group"])
        cds_counts = metrics_cds.get("cds_counts", {})
        covered = sum(1 for _, (c,t) in cds_counts.items() if t>0)
        total   = len(cds_counts)
        coverage_ok = (total > 0) and (covered / total >= 0.80)

        sel_now = 0.6 * wc + 0.4 * ws if coverage_ok else wc
        alpha = 0.8
        sel_smoothed = sel_now if sel_prev is None else (alpha*sel_prev + (1-alpha)*sel_now)
        sel_prev = sel_smoothed
        sel = sel_smoothed

        # log row (with per-domain)
        row = {
            "epoch": ep,
            "worst_group_cd": wc,
            "worst_group_cds": ws,
            "avg": metrics_cd["avg"],
            "cds_coverage_ok": int(coverage_ok),
            "sel_now": sel_now,
            "sel_smoothed": sel_smoothed,
        }
        for dom, acc in metrics_cd["per_domain"].items():
            row[f"acc_{dom}"] = acc
        log.append(row)

        if sel > best_sel:
            best_sel = sel
            best_ckpt = {
                "epoch": ep,
                "backbone": copy.deepcopy(backbone.state_dict()),
                "head": copy.deepcopy(head.state_dict()),
                "val_metrics_cd": metrics_cd,
                "val_metrics_cds": metrics_cds,
            }

        if (ep+1)%5==0:
            pd_str = "  ".join([f"{d}={row.get(f'acc_{d}', float('nan')):.3f}" for d in sorted(source_domains)])
            print(f"[ERM_2][ep {ep+1:02d}] val worst_cd={wc:.3f}  worst_cds={ws:.3f}  avg={row['avg']:.3f}  | {pd_str}")

    # ---- load best, test on Sketch, save everything under ERM_2
    backbone.load_state_dict(best_ckpt["backbone"]); head.load_state_dict(best_ckpt["head"])

    test_metrics_cd  = evaluate({"backbone":backbone,"head":head}, dl_test, classes, group_key="CD",  return_groups=True)
    test_metrics_cds = evaluate({"backbone":backbone,"head":head}, dl_test, classes, group_key="CDS", return_groups=True)

    acc_dom_sb, acc_cds, worst_cds_test = stylebin_bars({"backbone":backbone,"head":head}, dl_test)

    # Save everything with method_name="ERM_2"
    save_best_checkpoint("ERM_2", target_domain, backbone, head, best_ckpt, log)
    plot_val_per_domain_curves("ERM_2", target_domain)

    # source-val detailed CSVs (best epoch)
    save_source_val_csvs("ERM_2", target_domain, best_ckpt["val_metrics_cd"], best_ckpt["val_metrics_cds"])

    # target-test CSVs
    save_target_csvs("ERM_2", target_domain,
                     {"backbone":backbone,"head":head}, dl_test,
                     test_metrics_cd, test_metrics_cds,
                     acc_dom_sb, acc_cds)

    result = {
        "target": target_domain,
        "best_epoch": best_ckpt["epoch"],
        "val_worst_group_cd": best_ckpt["val_metrics_cd"]["worst_group"],
        "val_avg": best_ckpt["val_metrics_cd"]["avg"],
        "test_avg": test_metrics_cd["avg"],
        "test_worst_group_cd": test_metrics_cd["worst_group"],
        "test_worst_group_cds": test_metrics_cds["worst_group"],
    }
    return result

# === Run ===
erm2_result = run_erm2_for_target("Sketch")
print("\n=== ERM_2 summary ===")
print(erm2_result)



=== ERM_2 :: target=Sketch ===
[train] per-domain #batches: {'Photo': 47, 'Art_painting': 57, 'Cartoon': 66}
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 176MB/s]


[ERM_2][ep 05] val worst_cd=0.719  worst_cds=0.500  avg=0.941  | Art_painting=0.925  Cartoon=0.926  Photo=0.982
[ERM_2][ep 10] val worst_cd=0.850  worst_cds=0.667  avg=0.963  | Art_painting=0.945  Cartoon=0.961  Photo=0.988
[ERM_2][ep 15] val worst_cd=0.800  worst_cds=0.500  avg=0.965  | Art_painting=0.940  Cartoon=0.974  Photo=0.982
[ERM_2][ep 20] val worst_cd=0.842  worst_cds=0.500  avg=0.970  | Art_painting=0.960  Cartoon=0.978  Photo=0.970
[ERM_2][ep 25] val worst_cd=0.842  worst_cds=0.500  avg=0.966  | Art_painting=0.955  Cartoon=0.965  Photo=0.982
[ERM_2][ep 30] val worst_cd=0.842  worst_cds=0.500  avg=0.970  | Art_painting=0.955  Cartoon=0.974  Photo=0.982
[ERM_2][ep 35] val worst_cd=0.842  worst_cds=0.500  avg=0.968  | Art_painting=0.955  Cartoon=0.974  Photo=0.976
[ERM_2][ep 40] val worst_cd=0.842  worst_cds=0.500  avg=0.971  | Art_painting=0.960  Cartoon=0.974  Photo=0.982
[saved] model/logs/plots -> /content/drive/MyDrive/DG_PACS/ERM_2/Sketch
[saved] per-target CSVs -> /cont

Analysis of ERM_2

In [None]:

ROOT = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
ERM_TAG = "ERM_2"
IRM_TAG = "IRMv1"

def _load(path):
    return pd.read_csv(path) if os.path.isfile(path) else pd.DataFrame()

def _ensure_dir(p):
    os.makedirs(p, exist_ok=True)
    return p

def _barh_deltas(df, title, outpath, k=10):
    if df.empty:
        print(f"[skip] No data for {title}")
        return
    topk = df.sort_values("delta", ascending=False).head(k)
    botk = df.sort_values("delta", ascending=True).head(k)
    # top movers
    plt.figure(figsize=(8, 4))
    plt.barh(topk["group"], topk["delta"])
    plt.axvline(0, color="k", linewidth=1)
    plt.title(f"Top Δ (IRM − ERM) — {title}")
    plt.tight_layout(); plt.savefig(outpath.replace(".png","_top.png"), dpi=150); plt.close()
    # bottom movers
    plt.figure(figsize=(8, 4))
    plt.barh(botk["group"], botk["delta"])
    plt.axvline(0, color="k", linewidth=1)
    plt.title(f"Bottom Δ (IRM − ERM) — {title}")
    plt.tight_layout(); plt.savefig(outpath.replace(".png","_bottom.png"), dpi=150); plt.close()

def _merge_delta(A, B, key, name_maker):
    if A.empty or B.empty:
        return pd.DataFrame()
    M = A.merge(B, on=key, suffixes=("_ERM","_IRM"))
    # normalize column names
    # Accept either 'acc' columns or already-suffixed
    if "acc_ERM" not in M.columns and "acc" in A.columns:
        M = M.rename(columns={"acc_x":"acc_ERM","acc_y":"acc_IRM"})
    # compute delta
    M["delta"] = M["acc_IRM"] - M["acc_ERM"]
    M["group"] = name_maker(M)
    return M

# --- Paths
p_erm = os.path.join(ROOT, ERM_TAG, TARGET)
p_irm = os.path.join(ROOT, IRM_TAG, TARGET)
out_root = ROOT  # extra combined figs
out_erm = _ensure_dir(p_erm)
out_irm = _ensure_dir(p_irm)

# --- Load CSVs
erm_cds = _load(os.path.join(p_erm, "groups_cds.csv"))
irm_cds = _load(os.path.join(p_irm, "groups_cds.csv"))
erm_cd  = _load(os.path.join(p_erm, "groups_cd.csv"))
irm_cd  = _load(os.path.join(p_irm, "groups_cd.csv"))
erm_sb  = _load(os.path.join(p_erm, "stylebin_domain.csv"))
irm_sb  = _load(os.path.join(p_irm, "stylebin_domain.csv"))

# --- Δ for CDS (class×domain×stylebin)
delta_cds = _merge_delta(
    erm_cds, irm_cds,
    key=["class","domain","stylebin"],
    name_maker=lambda M: M["class"] + " @ " + M["domain"] + " / " + M["stylebin"].astype(str)
)
if not delta_cds.empty:
    # print dog@Sketch/high
    q = (delta_cds["class"].eq("dog") & delta_cds["domain"].eq("Sketch") & delta_cds["stylebin"].eq("high"))
    if q.any():
        row = delta_cds[q].iloc[0]
        print(f"dog @ Sketch / high — ERM={row['acc_ERM']:.3f}  IRM={row['acc_IRM']:.3f}  Δ={row['delta']:.3f}")
    _barh_deltas(delta_cds, "CDS groups (Sketch target)", os.path.join(out_root, "IRM_vs_ERM_Sketch_CDS.png"))

# --- Optional: Δ for CD (class×domain) — target domain rows only
if not erm_cd.empty and not irm_cd.empty:
    erm_cd_t = erm_cd[erm_cd["domain"].eq(TARGET)].copy()
    irm_cd_t = irm_cd[irm_cd["domain"].eq(TARGET)].copy()
    delta_cd = _merge_delta(
        erm_cd_t, irm_cd_t,
        key=["class","domain"],
        name_maker=lambda M: M["class"] + " @ " + M["domain"]
    )
    _barh_deltas(delta_cd, "CD groups (Sketch target)", os.path.join(out_root, "IRM_vs_ERM_Sketch_CD.png"))

# --- Optional: Δ for domain×stylebin aggregate (still on target domain)
# Average over classes within each stylebin for Sketch; then IRM−ERM
def _agg_sb(df):
    if df.empty: return df
    # when evaluating target, 'domain' column is the target; aggregate by stylebin only
    return (df[df["domain"].eq(TARGET)]
              .groupby(["stylebin"], as_index=False)["acc"].mean()
              .rename(columns={"acc":"acc"}))

if not erm_cds.empty and not irm_cds.empty:
    erm_sb_agg = _agg_sb(erm_cds.rename(columns={"acc":"acc"}))
    irm_sb_agg = _agg_sb(irm_cds.rename(columns={"acc":"acc"}))
    sb_merge = erm_sb_agg.merge(irm_sb_agg, on=["stylebin"], suffixes=("_ERM","_IRM"))
    if not sb_merge.empty:
        sb_merge["delta"] = sb_merge["acc_IRM"] - sb_merge["acc_ERM"]
        plt.figure(figsize=(4.5,3))
        plt.bar(sb_merge["stylebin"], sb_merge["delta"])
        plt.axhline(0, color="k", linewidth=1)
        plt.title("Δ by stylebin (Sketch target)")
        plt.tight_layout()
        plt.savefig(os.path.join(out_root, "IRM_vs_ERM_Sketch_stylebin_delta.png"), dpi=150); plt.close()

# --- Save a compact CSV with all CDS deltas (handy for tables)
if not delta_cds.empty:
    delta_cds.sort_values("delta", ascending=True).to_csv(
        os.path.join(out_root, "IRM_vs_ERM_Sketch_CDS_deltas.csv"), index=False
    )

print("Δ-analysis saved under:", out_root)


dog @ Sketch / high — ERM=0.442  IRM=0.392  Δ=-0.051
Δ-analysis saved under: /content/drive/MyDrive/DG_PACS


# GROUP-DRO

In [None]:
# ===== GroupDRO config =====
GDRO_LR = 0.005
GDRO_WD = 1e-4
GDRO_MOM = 0.9
GDRO_EPOCHS = 40
GDRO_BATCH_PER_DOMAIN = 32
GDRO_EG_STEP = 0.01    # exponentiated gradient step (eta)
LABEL_SMOOTH = 0.05

TARGET_FIXED = "Sketch"  # as agreed

def softmax_normalize(log_w):
    # for numerical stability
    m = torch.max(log_w)
    w = torch.exp(log_w - m)
    return w / (w.sum() + 1e-12)

def weights_entropy(w):
    # H(w) = -∑ w_i log w_i
    w_ = w.clamp_min(1e-12)
    return float((-w_ * w_.log()).sum().item())

def weights_eff_num(w):
    # effective #domains: 1 / sum w_i^2  (∈[1, K])
    return float(1.0 / (w.pow(2).sum().item() + 1e-12))


In [None]:
class DomainStepIterator:
    """Cycles each domain loader; yields a dict: domain -> (x,y,meta) per step."""
    def __init__(self, loaders_by_domain):
        self.loaders_by_domain = loaders_by_domain
        self.domains = list(loaders_by_domain.keys())
        self.iters = {d: iter(dl) for d,dl in loaders_by_domain.items()}
    def __iter__(self): return self
    def __next__(self):
        step = {}
        for d in self.domains:
            try:
                batch = next(self.iters[d])
            except StopIteration:
                self.iters[d] = iter(self.loaders_by_domain[d])
                batch = next(self.iters[d])
            step[d] = batch
        return step


In [None]:
class RobustSelectorEMA:
    def __init__(self, alpha=0.8, cds_cov_thresh=0.80, mix=(0.6, 0.4)):
        self.alpha = alpha
        self.cov_thresh = cds_cov_thresh
        self.w_cd, self.w_cds = mix
        self.prev = None

    def __call__(self, metrics_cd, metrics_cds, epoch):
        wc  = float(metrics_cd["worst_group"])
        wcs = float(metrics_cds["worst_group"])
        cds_counts = metrics_cds.get("cds_counts", {})
        covered = sum(1 for _, (c,t) in cds_counts.items() if t>0)
        coverage_ok = (len(cds_counts) > 0) and (covered / len(cds_counts) >= self.cov_thresh)

        sel_now = self.w_cd*wc + self.w_cds*wcs if coverage_ok else wc
        if self.prev is None:
            self.prev = sel_now
        else:
            self.prev = self.alpha * self.prev + (1-self.alpha)*sel_now
        return self.prev


In [None]:
def _plot_groupdro_weights(method_name, target_domain, source_domains):
    d = outdir_for(method_name, target_domain)
    log_path = os.path.join(d, "train_log.csv")
    if not os.path.isfile(log_path):
        return
    df = pd.read_csv(log_path)
    w_cols = [f"w_{dom}" for dom in source_domains if f"w_{dom}" in df.columns]
    if not w_cols: return
    plt.figure(figsize=(6.5,4))
    for c in w_cols:
        plt.plot(df["epoch"], df[c], label=c.replace("w_",""))
    plt.xlabel("Epoch"); plt.ylabel("Domain weight w_d")
    plt.title(f"{method_name} — EG weights over epochs — target={target_domain}")
    plt.grid(True); plt.legend(); plt.tight_layout()
    plt.savefig(os.path.join(d, "domain_weights_curves.png"), dpi=150)
    plt.close()

    # Also plot Keff & entropy trends
    for col, title, ylabel in [
        ("weights_eff_num", "Effective #domains", "K_eff"),
        ("weights_entropy", "Weight entropy", "H(w)")
    ]:
        if col in df.columns:
            plt.figure(figsize=(5.5,3.5))
            plt.plot(df["epoch"], df[col])
            plt.xlabel("Epoch"); plt.ylabel(ylabel)
            plt.title(f"{method_name} — {title} — target={target_domain}")
            plt.grid(True); plt.tight_layout()
            plt.savefig(os.path.join(d, f"{col}_curve.png"), dpi=150)
            plt.close()


In [None]:
def run_groupdro_for_target(target_domain=TARGET_FIXED):
    print(f"\n=== GroupDRO :: target={target_domain} ===")
    # --- splits & style bins
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, target_domain, val_frac=0.1)

    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_val  = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    style_map = {**style_map_train, **style_map_val, **style_map_test}

    # --- per-domain train loaders
    loaders_by_domain = {}
    for d in source_domains:
        d_train = [(dd,cc,fn) for (dd,cc,fn) in train_idx if dd==d]
        ds = PACS(DATA_ROOT, [d], split_indices=d_train, transform=train_tf, stylebin_map=style_map)
        dl = DataLoader(ds, batch_size=GDRO_BATCH_PER_DOMAIN, shuffle=True, num_workers=2,
                        pin_memory=True, drop_last=True, collate_fn=collate_keep_meta)
        loaders_by_domain[d] = dl
    print("[train] per-domain #batches:", {d: len(dl) for d,dl in loaders_by_domain.items()})

    step_iter = DomainStepIterator(loaders_by_domain)

    # --- val/test loaders
    ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf, stylebin_map=style_map)
    ds_test = PACS(DATA_ROOT, [target_domain],  split_indices=test_idx, transform=eval_tf, stylebin_map=style_map)
    dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
    dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    # --- model/opt
    backbone, head = build_model(n_classes=len(classes))
    backbone, head = backbone.to(DEVICE), head.to(DEVICE)
    opt = torch.optim.SGD(list(backbone.parameters())+list(head.parameters()),
                          lr=GDRO_LR, momentum=GDRO_MOM, weight_decay=GDRO_WD)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=GDRO_EPOCHS)
    ce = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

    # --- exponentiated-gradient weights over domains
    K = len(source_domains)
    dom2idx = {d:i for i,d in enumerate(source_domains)}
    log_w = torch.zeros(K, device=DEVICE)          # start uniform in log-space
    selector = RobustSelectorEMA(alpha=0.8, cds_cov_thresh=0.80, mix=(0.6,0.4))

    best_sel, best_ckpt = -1.0, None
    log = []

    steps_per_epoch = min(len(dl) for dl in loaders_by_domain.values())
    if steps_per_epoch < 1:
        raise RuntimeError("0 steps/epoch — reduce GDRO_BATCH_PER_DOMAIN or set drop_last=False")

    for ep in range(GDRO_EPOCHS):
        backbone.train(); head.train()
        it = iter(step_iter)

        for _ in range(steps_per_epoch):
            opt.zero_grad()
            step = next(it)  # dict: domain -> (x,y,meta)

            # compute per-domain losses on *current* step
            losses = torch.zeros(K, device=DEVICE)
            for d, (x,y,meta) in step.items():
                x, y = x.to(DEVICE), y.to(DEVICE)
                logits, feats = forward_logits(backbone, head, x)
                losses[dom2idx[d]] = ce(logits, y)

            # EG weight update (higher loss → higher weight)
            log_w = log_w + GDRO_EG_STEP * losses.detach()
            w = softmax_normalize(log_w)  # simplex weights

            # weighted loss
            loss = (w * losses).sum()
            loss.backward()
            opt.step()

        sched.step()

        # ---- source-val eval (per-domain, CD & CDS)
        metrics_cd  = evaluate({"backbone":backbone,"head":head}, dl_val, classes, group_key="CD",  return_groups=True)
        metrics_cds = evaluate({"backbone":backbone,"head":head}, dl_val, classes, group_key="CDS", return_groups=True)

        # selection using robust mix (CD+CDS with coverage+EMA)
        sel = selector(metrics_cd, metrics_cds, ep)

        # compute GroupDRO-specific stats for the log row
        # - worst source-val domain accuracy
        if metrics_cd["per_domain"]:
            worst_src_acc = min(metrics_cd["per_domain"].values())
        else:
            worst_src_acc = float('nan')

        # track weight entropy & effective #domains
        w_cpu = softmax_normalize(log_w.detach()).detach().cpu()
        H = weights_entropy(w_cpu)
        K_eff = weights_eff_num(w_cpu)

        # log row
        row = {
            "epoch": ep,
            "avg": metrics_cd["avg"],
            "worst_group_cd": metrics_cd["worst_group"],
            "worst_group_cds": metrics_cds["worst_group"],
            "worst_src_domain_acc": worst_src_acc,
            "weights_entropy": H,
            "weights_eff_num": K_eff,
        }
        # append per-domain val accuracies
        for dom, acc in metrics_cd["per_domain"].items():
            row[f"acc_{dom}"] = acc
        # also store current domain weights and per-step is expensive; we store *val-time* snapshot
        for dom in source_domains:
            row[f"w_{dom}"] = float(w_cpu[dom2idx[dom]].item())

        log.append(row)

        # keep best by selector
        if sel > best_sel:
            best_sel = sel
            best_ckpt = {
                "epoch": ep,
                "backbone": copy.deepcopy(backbone.state_dict()),
                "head": copy.deepcopy(head.state_dict()),
                "val_metrics_cd": metrics_cd,
                "val_metrics_cds": metrics_cds,
                "weights": {dom: float(w_cpu[dom2idx[dom]]) for dom in source_domains},
            }

        if (ep+1) % 5 == 0:
            w_str = " ".join([f"{d}:{row[f'w_{d}']:.2f}" for d in source_domains])
            pd_str = "  ".join([f"{d}={row.get(f'acc_{d}', float('nan')):.3f}" for d in sorted(source_domains)])
            print(f"[GroupDRO][ep {ep+1:02d}] val avg={row['avg']:.3f}  "
                  f"worst_cd={row['worst_group_cd']:.3f}  worst_cds={row['worst_group_cds']:.3f}  "
                  f"worst_src_dom_acc={worst_src_acc:.3f}  Keff={K_eff:.2f}  | {pd_str}")

    # load best, evaluate on target-test, save everything
    backbone.load_state_dict(best_ckpt["backbone"]); head.load_state_dict(best_ckpt["head"])

    test_metrics_cd  = evaluate({"backbone":backbone,"head":head}, dl_test, classes, group_key="CD",  return_groups=True)
    test_metrics_cds = evaluate({"backbone":backbone,"head":head}, dl_test, classes, group_key="CDS", return_groups=True)

    acc_dom_sb, acc_cds, worst_cds_test = stylebin_bars({"backbone":backbone,"head":head}, dl_test)

    # Save checkpoint + log
    save_best_checkpoint("GroupDRO", target_domain, backbone, head, best_ckpt, log)

    # Save detailed CSVs (source-val best epoch + target-test)
    save_source_val_csvs("GroupDRO", target_domain, best_ckpt["val_metrics_cd"], best_ckpt["val_metrics_cds"])
    save_target_csvs("GroupDRO", target_domain,
                     {"backbone":backbone,"head":head}, dl_test,
                     test_metrics_cd, test_metrics_cds,
                     acc_dom_sb, acc_cds)

    # Curves: per-domain source-val
    plot_val_per_domain_curves("GroupDRO", target_domain)

    # Extra: plot domain weights over epochs
    _plot_groupdro_weights("GroupDRO", target_domain, source_domains)

    result = {
        "target": target_domain,
        "best_epoch": best_ckpt["epoch"],
        "val_worst_group_cd": best_ckpt["val_metrics_cd"]["worst_group"],
        "val_avg": best_ckpt["val_metrics_cd"]["avg"],
        "test_avg": test_metrics_cd["avg"],
        "test_worst_group_cd": test_metrics_cd["worst_group"],
        "test_worst_group_cds": test_metrics_cds["worst_group"],
    }
    return result

In [None]:
gdro_result = run_groupdro_for_target("Sketch")
print("\n=== GroupDRO summary ===")
print(gdro_result)



=== GroupDRO :: target=Sketch ===
[train] per-domain #batches: {'Photo': 47, 'Art_painting': 57, 'Cartoon': 66}
[GroupDRO][ep 05] val avg=0.936  worst_cd=0.838  worst_cds=0.500  worst_src_dom_acc=0.900  Keff=2.91  | Art_painting=0.900  Cartoon=0.930  Photo=0.988
[GroupDRO][ep 10] val avg=0.945  worst_cd=0.838  worst_cds=0.500  worst_src_dom_acc=0.920  Keff=2.85  | Art_painting=0.920  Cartoon=0.939  Photo=0.982
[GroupDRO][ep 15] val avg=0.946  worst_cd=0.737  worst_cds=0.500  worst_src_dom_acc=0.930  Keff=2.82  | Art_painting=0.930  Cartoon=0.943  Photo=0.970
[GroupDRO][ep 20] val avg=0.951  worst_cd=0.838  worst_cds=0.500  worst_src_dom_acc=0.925  Keff=2.79  | Art_painting=0.925  Cartoon=0.952  Photo=0.982
[GroupDRO][ep 25] val avg=0.955  worst_cd=0.838  worst_cds=0.500  worst_src_dom_acc=0.920  Keff=2.77  | Art_painting=0.920  Cartoon=0.965  Photo=0.982
[GroupDRO][ep 30] val avg=0.951  worst_cd=0.842  worst_cds=0.500  worst_src_dom_acc=0.920  Keff=2.75  | Art_painting=0.920  Cartoon=

Final aggregation & visuals (ERM vs IRMv1 vs GroupDRO, target=Sketch)

In [None]:
ROOT = OUTPUT_ROOT
METHODS = ["ERM", "IRMv1", "GroupDRO"]
TARGET  = "Sketch"

def load_metrics(method, target):
    d = outdir_for(method, target)
    mpath = os.path.join(d, "metrics.csv")
    if not os.path.isfile(mpath): return None
    m = pd.read_csv(mpath).set_index("metric")["value"].to_dict()
    return {"method":method, "target":target,
            "avg":m.get("overall_avg", np.nan),
            "worst_cd":m.get("worst_group_cd", np.nan),
            "worst_cds":m.get("worst_group_cds", np.nan)}

def load_df_if(path):
    return pd.read_csv(path) if os.path.isfile(path) else pd.DataFrame()

summary_rows = []
cd_rows, cds_rows, sb_rows = [], [], []

for method in METHODS:
    d = outdir_for(method, TARGET)
    M = load_metrics(method, TARGET)
    if M: summary_rows.append(M)

    G = load_df_if(os.path.join(d, "groups_cd.csv"))
    if not G.empty:
        G["method"] = method; G["target"] = TARGET; cd_rows.append(G)

    H = load_df_if(os.path.join(d, "groups_cds.csv"))
    if not H.empty:
        H["method"] = method; H["target"] = TARGET; cds_rows.append(H)

    S = load_df_if(os.path.join(d, "stylebin_domain.csv"))
    if not S.empty:
        S["method"] = method; S["target"] = TARGET; sb_rows.append(S)

summary = pd.DataFrame(summary_rows).sort_values("method")
cd_df  = pd.concat(cd_rows,  ignore_index=True) if cd_rows  else pd.DataFrame()
cds_df = pd.concat(cds_rows, ignore_index=True) if cds_rows else pd.DataFrame()
sb_df  = pd.concat(sb_rows,  ignore_index=True) if sb_rows  else pd.DataFrame()

print("=== Overall (target=Sketch) ===")
display(summary)

# Worst CD/CDS groups tables (for report)
if not cd_df.empty:
    worst_cd = (cd_df.assign(group=lambda x: x["class"]+" @ "+x["domain"])
                      .sort_values(["method","acc"])
                      .groupby("method").head(3))
    print("\n=== Worst CD groups (top-3 per method) ===")
    display(worst_cd[["method","group","acc","correct","total"]])

if not cds_df.empty:
    worst_cds = (cds_df.assign(group=lambda x: x["class"]+" @ "+x["domain"]+" / "+x["stylebin"].astype(str))
                       .sort_values(["method","acc"])
                       .groupby("method").head(5))
    print("\n=== Worst CDS groups (top-5 per method) ===")
    display(worst_cds[["method","group","acc","correct","total"]])

# Bars: avg vs worst_cd vs worst_cds
plt.figure(figsize=(6,3.5))
x = np.arange(len(summary))
w = 0.25
plt.bar(x- w, summary["avg"],       width=w, label="avg")
plt.bar(x+0.0, summary["worst_cd"], width=w, label="worst_cd")
plt.bar(x+ w, summary["worst_cds"], width=w, label="worst_cds")
plt.xticks(x, summary["method"]); plt.ylim(0,1)
plt.title("Sketch target — overall vs tails")
plt.legend(); plt.tight_layout()
plt.savefig(os.path.join(ROOT, "Methods_on_Sketch_overall.png"), dpi=160); plt.close()

# Style-bin bars for each method on Sketch
for method, S in sb_df.groupby("method"):
    plt.figure(figsize=(4.2,3))
    S = S.sort_values("stylebin")
    plt.bar(S["stylebin"], S["acc"])
    plt.ylim(0,1); plt.title(f"{method} — style-bins on Sketch")
    plt.tight_layout()
    plt.savefig(os.path.join(outdir_for(method, TARGET), "stylebins_on_target.png"), dpi=150); plt.close()

print("Saved comparison figures to:", ROOT)


=== Overall (target=Sketch) ===


Unnamed: 0,method,target,avg,worst_cd,worst_cds
0,ERM,Sketch,0.671418,0.09375,0.083333
2,GroupDRO,Sketch,0.800458,0.625648,0.5
1,IRMv1,Sketch,0.782133,0.428756,0.391705



=== Worst CD groups (top-3 per method) ===


Unnamed: 0,method,group,acc,correct,total
6,ERM,person @ Sketch,0.09375,15,
0,ERM,dog @ Sketch,0.409326,316,
4,ERM,horse @ Sketch,0.583333,476,
14,GroupDRO,dog @ Sketch,0.625648,483,772.0
16,GroupDRO,giraffe @ Sketch,0.706507,532,753.0
20,GroupDRO,person @ Sketch,0.76875,123,160.0
7,IRMv1,dog @ Sketch,0.428756,331,772.0
13,IRMv1,person @ Sketch,0.65,104,160.0
9,IRMv1,giraffe @ Sketch,0.75166,566,753.0



=== Worst CDS groups (top-5 per method) ===


Unnamed: 0,method,group,acc,correct,total
19,ERM,person @ Sketch / mid,0.083333,3,
20,ERM,person @ Sketch / high,0.092784,9,
18,ERM,person @ Sketch / low,0.111111,3,
2,ERM,dog @ Sketch / high,0.373272,81,
1,ERM,dog @ Sketch / mid,0.407713,148,
58,GroupDRO,house @ Sketch / low,0.5,1,2.0
42,GroupDRO,dog @ Sketch / high,0.589862,128,217.0
43,GroupDRO,dog @ Sketch / low,0.635417,122,192.0
50,GroupDRO,giraffe @ Sketch / mid,0.64,144,225.0
44,GroupDRO,dog @ Sketch / mid,0.641873,233,363.0


Saved comparison figures to: /content/drive/MyDrive/DG_PACS


# SAM with ERM

In [None]:
# ===== SAM optimizer (SGD base) =====
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (two-step) on top of a base optimizer (SGD here)."""
    def __init__(self, params, base_optimizer_cls, rho=0.05, **kwargs):
        assert rho > 0.0
        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)
        self.base_optimizer = base_optimizer_cls(self.param_groups, **kwargs)
        self.rho = rho

    @torch.no_grad()
    def first_step(self):
        # compute ||g||_2 over all params
        grad_norm = torch.norm(
            torch.stack([
                p.grad.norm(p=2) for group in self.param_groups for p in group['params']
                if p.grad is not None
            ]), p=2
        )
        self._eps_list = []  # store perturbations for second step
        scale = self.rho / (grad_norm + 1e-12)

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    self._eps_list.append(None);
                    continue
                e = p.grad * scale
                p.add_(e)           # w <- w + eps
                self._eps_list.append(e)

        # important: do not step the base optimizer here
        for group in self.param_groups:
            group["_sam_grad_norm"] = float(grad_norm.item())

    @torch.no_grad()
    def second_step(self):
        # restore weights and take a step with gradients from perturbed weights
        idx = 0
        for group in self.param_groups:
            for p in group['params']:
                e = self._eps_list[idx]; idx += 1
                if e is not None:
                    p.sub_(e)       # w <- w - eps (restore)
        self.base_optimizer.step()

    def zero_grad(self):
        self.base_optimizer.zero_grad()


In [None]:
# ===== SAM config (ERM_2-style) =====
SAM_LR = 0.01          # same as ERM_2 now
WARMUP_EPOCHS = 3      # tiny warm-up for stability
SAM_WD = 1e-4
SAM_MOM = 0.9
SAM_EPOCHS = 40
SAM_BATCH_PER_DOMAIN = 32
SAM_RHO = 0.05
LABEL_SMOOTH = 0.05
TARGET_FIXED = "Sketch"

class BalancedDomainBatcher:
    def __init__(self, loaders_by_domain):
        self.loaders = loaders_by_domain
        self.domains = list(loaders_by_domain.keys())
    def __iter__(self):
        iters = {d: iter(dl) for d, dl in self.loaders.items()}
        while True:
            parts = []
            for d in self.domains:
                try:
                    b = next(iters[d])
                except StopIteration:
                    iters[d] = iter(self.loaders[d])
                    b = next(iters[d])
                parts.append(b)
            xs = torch.cat([p[0] for p in parts], dim=0)
            ys = torch.cat([p[1] for p in parts], dim=0)
            metas = []
            for p in parts:
                m = p[2]
                if isinstance(m, dict):
                    keys = list(m.keys()); n = len(m[keys[0]])
                    for i in range(n):
                        metas.append({k: m[k][i] for k in keys})
                else:
                    metas.extend(m)
            yield xs, ys, metas

def run_sam_for_target(target_domain=TARGET_FIXED):
    print(f"\n=== SAM :: target={target_domain} ===")
    # --- splits & style bins
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, target_domain, val_frac=0.1)
    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_val  = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    style_map = {**style_map_train, **style_map_val, **style_map_test}

    # --- per-domain train loaders
    loaders_by_domain = {}
    for d in source_domains:
        d_train = [(dd,cc,fn) for (dd,cc,fn) in train_idx if dd==d]
        ds = PACS(DATA_ROOT, [d], split_indices=d_train, transform=train_tf, stylebin_map=style_map)
        dl = DataLoader(ds, batch_size=SAM_BATCH_PER_DOMAIN, shuffle=True, num_workers=2,
                        pin_memory=True, drop_last=True, collate_fn=collate_keep_meta)
        loaders_by_domain[d] = dl
    print("[train] per-domain #batches:", {d: len(dl) for d,dl in loaders_by_domain.items()})
    train_bal = BalancedDomainBatcher(loaders_by_domain)

    # --- val/test
    ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf, stylebin_map=style_map)
    ds_test = PACS(DATA_ROOT, [target_domain],  split_indices=test_idx, transform=eval_tf, stylebin_map=style_map)
    dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
    dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    # --- model/opt/sched
    backbone, head = build_model(n_classes=len(classes))
    backbone, head = backbone.to(DEVICE), head.to(DEVICE)

    # SAM wraps SGD
    sam_opt = SAM(list(backbone.parameters())+list(head.parameters()),
                  base_optimizer_cls=torch.optim.SGD,
                  rho=SAM_RHO, lr=SAM_LR, momentum=SAM_MOM, weight_decay=SAM_WD)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(sam_opt.base_optimizer, T_max=SAM_EPOCHS)
    def _set_lr(optimizer, lr):
        for pg in optimizer.param_groups:
            pg['lr'] = lr

        # --- loss & steps/epoch BEFORE training
    ce = nn.CrossEntropyLoss(label_smoothing=LABEL_SMOOTH)

    steps_per_epoch = min(len(dl) for dl in loaders_by_domain.values())
    if steps_per_epoch < 1:
        raise RuntimeError("0 steps/epoch — reduce SAM_BATCH_PER_DOMAIN or set drop_last=False")

    best_ckpt, best_sel, log = None, -1.0, []
    sel_prev = None

    for ep in range(SAM_EPOCHS):
        # ---- LR warm-up (linear) ----
        if ep < WARMUP_EPOCHS:
            warm_lr = SAM_LR * float(ep + 1) / float(WARMUP_EPOCHS)
            _set_lr(sam_opt.base_optimizer, warm_lr)
        else:
            _set_lr(sam_opt.base_optimizer, SAM_LR)

        backbone.train(); head.train()
        it = iter(train_bal)

        # epoch-level SAM diagnostics
        train_clean_losses, train_pert_losses, grad_norms = [], [], []

        for _ in range(steps_per_epoch):
            x,y,meta = next(it)
            x,y = x.to(DEVICE), y.to(DEVICE)

            # ---- step 1: clean grad
            sam_opt.zero_grad()
            logits, feats = forward_logits(backbone, head, x)
            loss_clean = ce(logits, y)
            loss_clean.backward()
            sam_opt.first_step()

            # grad norm (from SAM wrapper, if available)
            gn = 0.0
            for g in [pg.get("_sam_grad_norm", None) for pg in sam_opt.param_groups]:
                if g is not None: gn = g
            grad_norms.append(gn)

            # ---- step 2: perturbed grad
            logits_p, _ = forward_logits(backbone, head, x)
            loss_pert = ce(logits_p, y)
            # OPTIONAL: extra stability
            torch.nn.utils.clip_grad_norm_(
                list(backbone.parameters()) + list(head.parameters()), max_norm=5.0
            )
            loss_pert.backward()
            sam_opt.second_step()

            train_clean_losses.append(float(loss_clean.item()))
            train_pert_losses.append(float(loss_pert.item()))

        # advance cosine scheduler once per epoch
        sched.step()

        # ---- source-val eval
        metrics_cd  = evaluate({"backbone":backbone,"head":head}, dl_val, classes, group_key="CD",  return_groups=True)
        metrics_cds = evaluate({"backbone":backbone,"head":head}, dl_val, classes, group_key="CDS", return_groups=True)

        wc = float(metrics_cd["worst_group"])
        ws = float(metrics_cds["worst_group"])
        cds_counts = metrics_cds.get("cds_counts", {})
        covered = sum(1 for _, (c,t) in cds_counts.items() if t>0)
        total   = len(cds_counts)
        coverage_ok = (total > 0) and (covered / total >= 0.80)

        sel_now = 0.6 * wc + 0.4 * ws if coverage_ok else wc
        alpha = 0.8
        sel_smoothed = sel_now if sel_prev is None else (alpha*sel_prev + (1-alpha)*sel_now)
        sel_prev = sel_smoothed
        sel = sel_smoothed

        # SAM diagnostics (epoch means)
        sam_clean = float(np.mean(train_clean_losses))
        sam_pert  = float(np.mean(train_pert_losses))
        sam_gap   = float(sam_pert - sam_clean)
        gradn     = float(np.mean(grad_norms)) if len(grad_norms) else float("nan")

        # log row (per-domain val accs + SAM metrics)
        row = {
            "epoch": ep,
            "avg": metrics_cd["avg"],
            "worst_group_cd": wc,
            "worst_group_cds": ws,
            "cds_coverage_ok": int(coverage_ok),
            "sel_now": sel_now,
            "sel_smoothed": sel_smoothed,
            # SAM-specific:
            "sam_train_loss_clean": sam_clean,
            "sam_train_loss_pert": sam_pert,
            "sam_gap": sam_gap,
            "sam_grad_norm": gradn,
        }
        for dom, acc in metrics_cd["per_domain"].items():
            row[f"acc_{dom}"] = acc
        log.append(row)

        if sel > best_sel:
            best_sel = sel
            best_ckpt = {
                "epoch": ep,
                "backbone": copy.deepcopy(backbone.state_dict()),
                "head": copy.deepcopy(head.state_dict()),
                "val_metrics_cd": metrics_cd,
                "val_metrics_cds": metrics_cds,
                "sam_gap": sam_gap,
                "sam_grad_norm": gradn,
            }

        if (ep+1) % 5 == 0:
            pd_str = "  ".join([f"{d}={row.get(f'acc_{d}', float('nan')):.3f}" for d in sorted(source_domains)])
            print(f"[SAM][ep {ep+1:02d}] val worst_cd={wc:.3f}  worst_cds={ws:.3f}  "
                  f"avg={row['avg']:.3f}  gap={sam_gap:.4f}  | {pd_str}")


    # ---- load best, evaluate on target-test, save artifacts under SAM/
    backbone.load_state_dict(best_ckpt["backbone"]); head.load_state_dict(best_ckpt["head"])

    test_metrics_cd  = evaluate({"backbone":backbone,"head":head}, dl_test, classes, group_key="CD",  return_groups=True)
    test_metrics_cds = evaluate({"backbone":backbone,"head":head}, dl_test, classes, group_key="CDS", return_groups=True)
    acc_dom_sb, acc_cds, worst_cds_test = stylebin_bars({"backbone":backbone,"head":head}, dl_test)

    # Save checkpoint + log
    save_best_checkpoint("SAM", target_domain, backbone, head, best_ckpt, log)

    # Plot curves (per-domain + SAM gap)
    plot_val_per_domain_curves("SAM", target_domain)
    _plot_sam_gap_curve("SAM", target_domain)

    # Detailed CSVs
    save_source_val_csvs("SAM", target_domain, best_ckpt["val_metrics_cd"], best_ckpt["val_metrics_cds"])
    save_target_csvs("SAM", target_domain,
                     {"backbone":backbone,"head":head}, dl_test,
                     test_metrics_cd, test_metrics_cds,
                     acc_dom_sb, acc_cds)

    result = {
        "target": target_domain,
        "best_epoch": best_ckpt["epoch"],
        "val_worst_group_cd": best_ckpt["val_metrics_cd"]["worst_group"],
        "val_avg": best_ckpt["val_metrics_cd"]["avg"],
        "test_avg": test_metrics_cd["avg"],
        "test_worst_group_cd": test_metrics_cd["worst_group"],
        "test_worst_group_cds": test_metrics_cds["worst_group"],
    }
    return result


In [None]:
def _plot_sam_gap_curve(method_name, target_domain):
    d = outdir_for(method_name, target_domain)
    log_path = os.path.join(d, "train_log.csv")
    if not os.path.isfile(log_path): return
    df = pd.read_csv(log_path)
    if not {"sam_train_loss_clean","sam_train_loss_pert","sam_gap"}.issubset(df.columns): return
    plt.figure(figsize=(6.5,4))
    plt.plot(df["epoch"], df["sam_train_loss_clean"], label="clean loss")
    plt.plot(df["epoch"], df["sam_train_loss_pert"],  label="perturbed loss")
    plt.plot(df["epoch"], df["sam_gap"],              label="gap (pert-clean)")
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.title(f"SAM losses & gap — target={target_domain}")
    plt.grid(True); plt.legend(); plt.tight_layout()
    plt.savefig(os.path.join(d, "sam_losses_gap.png"), dpi=150)
    plt.close()


In [None]:
sam_result = run_sam_for_target("Sketch")
print("\n=== SAM summary ===")
print(sam_result)



=== SAM :: target=Sketch ===
[train] per-domain #batches: {'Photo': 47, 'Art_painting': 57, 'Cartoon': 66}
[SAM][ep 05] val worst_cd=0.800  worst_cds=0.500  avg=0.948  gap=0.1009  | Art_painting=0.920  Cartoon=0.939  Photo=0.994
[SAM][ep 10] val worst_cd=0.842  worst_cds=0.500  avg=0.956  gap=0.0643  | Art_painting=0.935  Cartoon=0.965  Photo=0.970
[SAM][ep 15] val worst_cd=0.789  worst_cds=0.500  avg=0.965  gap=0.0450  | Art_painting=0.950  Cartoon=0.970  Photo=0.976
[SAM][ep 20] val worst_cd=0.842  worst_cds=0.500  avg=0.958  gap=0.0380  | Art_painting=0.930  Cartoon=0.970  Photo=0.976
[SAM][ep 25] val worst_cd=0.850  worst_cds=0.500  avg=0.966  gap=0.0345  | Art_painting=0.945  Cartoon=0.974  Photo=0.982
[SAM][ep 30] val worst_cd=0.842  worst_cds=0.500  avg=0.968  gap=0.0281  | Art_painting=0.945  Cartoon=0.978  Photo=0.982
[SAM][ep 35] val worst_cd=0.842  worst_cds=0.500  avg=0.965  gap=0.0271  | Art_painting=0.935  Cartoon=0.983  Photo=0.976
[SAM][ep 40] val worst_cd=0.880  worst

Analysis of SAM

In [None]:
ROOT = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
ERM_TAG = "ERM_2"
SAM_TAG = "SAM"

def outp(*xs): print(*xs)

def read_metrics(method):
    p = os.path.join(ROOT, method, TARGET, "metrics.csv")
    if not os.path.isfile(p): return None
    m = pd.read_csv(p).set_index("metric")["value"].to_dict()
    return {"method": method, "avg": m.get("overall_avg", np.nan),
            "worst_cd": m.get("worst_group_cd", np.nan),
            "worst_cds": m.get("worst_group_cds", np.nan)}

def read_df(method, name):
    p = os.path.join(ROOT, method, TARGET, name)
    return pd.read_csv(p) if os.path.isfile(p) else pd.DataFrame()

def savefig(path):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.tight_layout(); plt.savefig(path, dpi=150); plt.close()


In [None]:
summary = pd.DataFrame([read_metrics(ERM_TAG), read_metrics(SAM_TAG)]).set_index("method")
outp("=== Sketch target: headline metrics ===")
display(summary)

# Bar plot (avg vs worsts)
plt.figure(figsize=(5.5,3.3))
x = np.arange(len(summary))
w = 0.28
plt.bar(x-w, summary["avg"], width=w, label="avg")
plt.bar(x+0.0, summary["worst_cd"], width=w, label="worst_cd")
plt.bar(x+w, summary["worst_cds"], width=w, label="worst_cds")
plt.xticks(x, summary.index); plt.ylim(0,1); plt.ylabel("Accuracy")
plt.title("Sketch — ERM_2 vs SAM (avg & tails)")
plt.legend(); savefig(os.path.join(ROOT, "Sketch_ERM2_vs_SAM_overall.png"))


=== Sketch target: headline metrics ===


Unnamed: 0_level_0,avg,worst_cd,worst_cds
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ERM_2,0.768389,0.485751,0.442396
SAM,0.798167,0.518135,0.488479


In [None]:
erm_sv = read_df(ERM_TAG, "val_source_per_domain.csv").rename(columns={"acc":"acc_ERM"})
sam_sv = read_df(SAM_TAG, "val_source_per_domain.csv").rename(columns={"acc":"acc_SAM"})
sv = erm_sv.merge(sam_sv, on="domain", how="outer").sort_values("domain")
outp("=== Source-val per-domain (best epoch selected by your selector) ===")
display(sv)

# Side-by-side bars
plt.figure(figsize=(6,3.3))
x = np.arange(len(sv))
plt.bar(x-0.18, sv["acc_ERM"], width=0.36, label="ERM_2")
plt.bar(x+0.18, sv["acc_SAM"], width=0.36, label="SAM")
plt.xticks(x, sv["domain"]); plt.ylim(0,1); plt.ylabel("Acc")
plt.title("Source-val per-domain — ERM_2 vs SAM")
plt.legend(); savefig(os.path.join(ROOT, "SourceVal_domains_ERM2_vs_SAM.png"))


=== Source-val per-domain (best epoch selected by your selector) ===


Unnamed: 0,domain,acc_ERM,acc_SAM
0,Art_painting,0.955224,0.945274
1,Cartoon,0.973913,0.982609
2,Photo,0.981707,0.981707


In [None]:
erm_cds = read_df(ERM_TAG, "groups_cds.csv").rename(columns={"acc":"acc_ERM"})
sam_cds = read_df(SAM_TAG, "groups_cds.csv").rename(columns={"acc":"acc_SAM"})
key = ["class","domain","stylebin"]
delta_cds = erm_cds.merge(sam_cds, on=key, how="inner")
if not delta_cds.empty:
    delta_cds["delta"] = delta_cds["acc_SAM"] - delta_cds["acc_ERM"]
    delta_cds["group"] = delta_cds["class"] + " @ " + delta_cds["domain"] + " / " + delta_cds["stylebin"].astype(str)
    delta_cds.sort_values("delta").to_csv(os.path.join(ROOT, "SAM_minus_ERM2_Sketch_CDS_deltas.csv"), index=False)
    outp("Saved Δ table:", os.path.join(ROOT, "SAM_minus_ERM2_Sketch_CDS_deltas.csv"))

    # top/bottom movers
    topk = delta_cds.sort_values("delta", ascending=False).head(10)
    botk = delta_cds.sort_values("delta", ascending=True).head(10)
    plt.figure(figsize=(8,4)); plt.barh(topk["group"], topk["delta"]); plt.axvline(0,color="k")
    plt.title("Top Δ (SAM − ERM_2) by CDS group — Sketch"); savefig(os.path.join(ROOT,"SAM_vs_ERM2_Sketch_CDS_top.png"))
    plt.figure(figsize=(8,4)); plt.barh(botk["group"], botk["delta"]); plt.axvline(0,color="k")
    plt.title("Bottom Δ (SAM − ERM_2) by CDS group — Sketch"); savefig(os.path.join(ROOT,"SAM_vs_ERM2_Sketch_CDS_bottom.png"))

    # (Optional) specific group example if you want to cite it:
    q = delta_cds.query("class=='dog' and domain=='Sketch' and stylebin=='high'")
    if len(q):
        row = q.iloc[0]
        outp(f"dog @ Sketch / high — ERM_2={row['acc_ERM']:.3f}  SAM={row['acc_SAM']:.3f}  Δ={row['delta']:.3f}")


Saved Δ table: /content/drive/MyDrive/DG_PACS/SAM_minus_ERM2_Sketch_CDS_deltas.csv


SyntaxError: Python keyword not valid identifier in numexpr query (<unknown>, line 1)

Flatness probe: loss vs parameter perturbation

In [None]:
# ---- Build eval loaders (reuse same split code as training) ----
train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, TARGET, val_frac=0.1)
# Rebuild style bins (thresholds derived on sources as before)
style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
style_map_val  = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)

ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf, stylebin_map={**style_map_train,**style_map_val})
ds_test = PACS(DATA_ROOT, [TARGET],       split_indices=test_idx, transform=eval_tf, stylebin_map={**style_map_train,**style_map_test})
dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

# ---- Load best checkpoints ----
def load_model_from_ckpt(method):
    ckpt = torch.load(os.path.join(ROOT, method, TARGET, "best_model.pt"), map_location="cpu")
    backbone, head = build_model(n_classes=len(classes))
    backbone.load_state_dict(ckpt["backbone"]); head.load_state_dict(ckpt["head"])
    backbone, head = backbone.to(DEVICE), head.to(DEVICE)
    backbone.eval(); head.eval()
    return backbone, head

erm_backbone, erm_head = load_model_from_ckpt(ERM_TAG)
sam_backbone, sam_head = load_model_from_ckpt(SAM_TAG)

ce = nn.CrossEntropyLoss()

# ---- Parameter vector helpers ----
def pack_params(modules):
    vec = []
    for m in modules:
        for p in m.parameters():
            vec.append(p.detach().view(-1))
    return torch.cat(vec)

def add_inplace(modules, delta_vec):
    # add delta_vec split across parameters (in-place)
    idx = 0
    with torch.no_grad():
        for m in modules:
            for p in m.parameters():
                n = p.numel()
                p.add_(delta_vec[idx:idx+n].view_as(p))
                idx += n

def eval_loss(backbone, head, loader, max_batches=10):
    backbone.eval(); head.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for b, (x,y,meta) in enumerate(loader):
            x = x.to(DEVICE); y = y.to(DEVICE)
            logits, _ = forward_logits(backbone, head, x)
            loss = ce(logits, y).item()
            total += loss; n += 1
            if b+1 >= max_batches: break   # keep it light
    return total / max(1,n)

def flatness_curve(backbone, head, loader, epsilons, max_batches=10, seed=0):
    torch.manual_seed(seed)
    # random direction v with ||v||_2 = 1 over (backbone+head)
    theta = pack_params([backbone, head]).to(DEVICE)
    v = torch.randn_like(theta); v = v / (v.norm(p=2) + 1e-12)
    losses = []
    for eps in epsilons:
        # perturb
        add_inplace([backbone, head], eps * v)
        # measure
        L = eval_loss(backbone, head, loader, max_batches=max_batches)
        losses.append(L)
        # restore
        add_inplace([backbone, head], -eps * v)
    return np.array(losses)

# ---- Sweep epsilons and plot (source-val & target-test) ----
eps = np.linspace(-0.2, 0.2, 13)  # magnitude sweep; adjust if needed
for name, (bb, hh) in {"ERM_2":(erm_backbone,erm_head), "SAM":(sam_backbone,sam_head)}.items():
    # clone modules so perturbations don't affect the other plots
    bb_v = copy.deepcopy(bb).to(DEVICE).eval()
    hh_v = copy.deepcopy(hh).to(DEVICE).eval()
    losses_val  = flatness_curve(bb_v, hh_v, dl_val,  eps, max_batches=12, seed=0)
    bb_t = copy.deepcopy(bb).to(DEVICE).eval()
    hh_t = copy.deepcopy(hh).to(DEVICE).eval()
    losses_test = flatness_curve(bb_t, hh_t, dl_test, eps, max_batches=12, seed=0)

    # save curves
    dfv = pd.DataFrame({"epsilon": eps, "loss": losses_val, "split":"source-val", "method":name})
    dft = pd.DataFrame({"epsilon": eps, "loss": losses_test,"split":"target-test","method":name})
    out = pd.concat([dfv,dft], ignore_index=True)
    out.to_csv(os.path.join(ROOT, f"flatness_curve_{name}_{TARGET}.csv"), index=False)

# Plot together
def plot_flatness(split):
    plt.figure(figsize=(6,3.5))
    for name in ["ERM_2","SAM"]:
        df = pd.read_csv(os.path.join(ROOT, f"flatness_curve_{name}_{TARGET}.csv"))
        sub = df[df["split"].eq(split)]
        plt.plot(sub["epsilon"], sub["loss"], label=name)
    plt.xlabel("Parameter perturbation ε (unit L2 dir)")
    plt.ylabel("Cross-entropy loss")
    plt.title(f"Flatness probe — {split} — {TARGET}")
    plt.grid(True); plt.legend()
    savefig(os.path.join(ROOT, f"Flatness_{split}_{TARGET}_ERM2_vs_SAM.png"))

plot_flatness("source-val")
plot_flatness("target-test")
outp("Saved flatness figures to project root.")


Saved flatness figures to project root.


# GENERAL ANALYSIS of all methods

In [None]:
# Define the new folder path
save_dir = os.path.join("DG_PACS", "General_Analysis")

# Create it if it doesn't exist
os.makedirs(save_dir, exist_ok=True)

# Redirect all save paths to this folder
def save_path(filename):
    """Return full path inside General_Analysis folder."""
    return os.path.join(save_dir, filename)

Master Comparision Target values only


In [None]:
import os, pandas as pd, numpy as np, matplotlib.pyplot as plt

ROOT = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHODS = ["ERM_2", "IRMv1", "GroupDRO", "SAM"]

def load_metrics(method):
    p = os.path.join(ROOT, method, TARGET, "metrics.csv")
    if not os.path.isfile(p): return None
    m = pd.read_csv(p).set_index("metric")["value"].to_dict()
    row = {"method": method,
           "avg": m.get("overall_avg", np.nan),
           "worst_cd": m.get("worst_group_cd", np.nan),
           "worst_cds": m.get("worst_group_cds", np.nan),
           "irm_penalty_final": m.get("irm_penalty_final", np.nan)}
    return row

summary = pd.DataFrame([load_metrics(m) for m in METHODS if load_metrics(m) is not None])
# Use reindex to ensure order and handle missing methods safely
summary = summary.set_index("method").reindex(METHODS)
print("=== Sketch target — overall comparison ===")
display(summary)

# Bars
plt.figure(figsize=(6,3.5))
x = np.arange(len(summary))
w = 0.26
plt.bar(x-w,   summary["avg"],       width=w, label="avg")
plt.bar(x+0.0, summary["worst_cd"],  width=w, label="worst_cd")
plt.bar(x+w,   summary["worst_cds"], width=w, label="worst_cds")
plt.ylim(0,1); plt.xticks(x, summary.index); plt.ylabel("Accuracy")
plt.title("Sketch — ERM_2 vs IRMv1 vs GroupDRO vs SAM")
plt.legend(); plt.tight_layout()
plt.savefig(save_path("AllMethods_Sketch_overall.png"))
plt.close()

=== Sketch target — overall comparison ===


Unnamed: 0_level_0,avg,worst_cd,worst_cds,irm_penalty_final
method,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ERM_2,0.768389,0.485751,0.442396,
IRMv1,0.782133,0.428756,0.391705,
GroupDRO,0.800458,0.625648,0.5,
SAM,0.798167,0.518135,0.488479,


Matser table, val vs test

In [None]:
ROOT = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHODS = ["ERM_2","IRMv1","GroupDRO","SAM"]

def load_test_metrics(method):
    m = pd.read_csv(os.path.join(ROOT, method, TARGET, "metrics.csv")).set_index("metric")["value"]
    return {"method": method,
            "split": "target-test",
            "avg": m.get("overall_avg", np.nan),
            "worst_cd": m.get("worst_group_cd", np.nan),
            "worst_cds": m.get("worst_group_cds", np.nan)}

def load_source_val_metrics(method):
    # build from saved source-val CSVs at best epoch
    p = os.path.join(ROOT, method, TARGET)
    cd = pd.read_csv(os.path.join(p, "val_source_CD.csv"))      # columns: class,domain,correct,total,acc
    cds = pd.read_csv(os.path.join(p, "val_source_CDS.csv"))    # columns: class,domain,stylebin,correct,total,acc
    # overall avg on source-val (macro over CD groups weighted by counts)
    avg = (cd["correct"].sum() / cd["total"].sum()) if cd["total"].sum()>0 else np.nan
    worst_cd  = cd["acc"].min() if len(cd) else np.nan
    worst_cds = cds["acc"].min() if len(cds) else np.nan
    return {"method": method, "split": "source-val", "avg": avg, "worst_cd": worst_cd, "worst_cds": worst_cds}

rows = []
for m in METHODS:
    if os.path.exists(os.path.join(ROOT,m,TARGET,"metrics.csv")):
        rows.append(load_test_metrics(m))
    if os.path.exists(os.path.join(ROOT,m,TARGET,"val_source_CD.csv")):
        rows.append(load_source_val_metrics(m))

both = pd.DataFrame(rows)
print("=== Source-val vs Target-test (Sketch) ===")
display(both.pivot(index="method", columns="split", values=["avg","worst_cd","worst_cds"]))

# Optional: small bar chart comparing splits for each metric
for metric in ["avg","worst_cd","worst_cds"]:
    sub = both.pivot(index="method", columns="split", values=metric).reindex(METHODS)
    sub.plot(kind="bar", figsize=(6,3.5))
    plt.ylim(0,1); plt.title(f"{metric} — source-val vs target-test (Sketch)")
    plt.tight_layout()
    plt.savefig(save_path(f"Sketch_{metric}_val_vs_test.png"))
    plt.close()


=== Source-val vs Target-test (Sketch) ===


Unnamed: 0_level_0,avg,avg,worst_cd,worst_cd,worst_cds,worst_cds
split,source-val,target-test,source-val,target-test,source-val,target-test
method,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
ERM_2,0.969748,0.768389,0.891892,0.485751,0.666667,0.442396
GroupDRO,0.952941,0.800458,0.842105,0.625648,0.5,0.5
IRMv1,0.963025,0.782133,0.891892,0.428756,0.666667,0.391705
SAM,0.969748,0.798167,0.88,0.518135,0.666667,0.488479


In [None]:
import os, pandas as pd

ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"                   # fixed target
METHOD = "GroupDRO"                 # change as needed (ERM_2, IRMv1, SAM, ...)

def _load(method, fname):
    p = os.path.join(ROOT, method, TARGET, fname)
    return pd.read_csv(p) if os.path.isfile(p) else pd.DataFrame()

def summarize_counts(method=METHOD):
    out_dir = os.path.join(ROOT, method, TARGET)
    os.makedirs(out_dir, exist_ok=True)

    # ---------- TARGET-TEST ----------
    cds_t = _load(method, "groups_cds.csv")    # cols: class, domain, stylebin, correct, total, acc
    cd_t  = _load(method, "groups_cd.csv")     # cols: class, domain, correct, total, acc

    if not cds_t.empty:
        # per-group counts (CDS)
        counts_cds_t = cds_t[["class","domain","stylebin","total"]].sort_values("total")
        counts_cds_t.to_csv(save_path("counts_cds_target.csv"), index=False)

        # quick summaries
        per_class_cds   = counts_cds_t.groupby("class", as_index=False)["total"].sum().sort_values("total")
        per_domain_cds  = counts_cds_t.groupby("domain", as_index=False)["total"].sum().sort_values("total")
        per_style_cds   = counts_cds_t.groupby("stylebin", as_index=False)["total"].sum().sort_values("total")

        # a pivot you can eyeball (classes × stylebin) for the TARGET domain
        pivot_cds = counts_cds_t.pivot_table(index="class", columns="stylebin", values="total", aggfunc="sum").fillna(0).astype(int)

        print(f"\n[TARGET-TEST | {method}] CDS: per-group counts (ascending)")
        display(counts_cds_t.head(10))
        print("\n[TARGET-TEST] CDS: totals per class")
        display(per_class_cds)
        print("\n[TARGET-TEST] CDS: totals per stylebin")
        display(per_style_cds)
        print("\n[TARGET-TEST] CDS: class × stylebin count pivot (Sketch)")
        display(pivot_cds)

        # save summaries
        per_class_cds.to_csv(save_path("counts_cds_target_per_class.csv"), index=False)
        per_domain_cds.to_csv(save_path("counts_cds_target_per_domain.csv"), index=False)
        per_style_cds.to_csv(save_path("counts_cds_target_per_stylebin.csv"), index=False)
        pivot_cds.to_csv(save_path("counts_cds_target_pivot_class_stylebin.csv"))

        # quantiles to pick a low-support cutoff
        qs = counts_cds_t["total"].quantile([0.1, 0.25, 0.5]).round(1)
        print("\n[TARGET-TEST] CDS count quantiles (10/25/50%):")
        print(qs.to_dict())

    if not cd_t.empty:
        counts_cd_t = cd_t[["class","domain","total"]].sort_values("total")
        counts_cd_t.to_csv(save_path("counts_cd_target.csv"), index=False)

        per_class_cd  = counts_cd_t.groupby("class", as_index=False)["total"].sum().sort_values("total")
        per_domain_cd = counts_cd_t.groupby("domain", as_index=False)["total"].sum().sort_values("total")
        pivot_cd = counts_cd_t.pivot_table(index="class", columns="domain", values="total", aggfunc="sum").fillna(0).astype(int)

        print(f"\n[TARGET-TEST | {method}] CD: per-group counts (ascending)")
        display(counts_cd_t.head(10))
        print("\n[TARGET-TEST] CD: class × domain count pivot")
        display(pivot_cd)

        per_class_cd.to_csv(save_path("counts_cd_target_per_class.csv"), index=False)
        per_domain_cd.to_csv(save_path("counts_cd_target_per_domain.csv"), index=False)
        pivot_cd.to_csv(save_path("counts_cd_target_pivot_class_domain.csv"))

    # ---------- SOURCE-VAL (ALL SOURCES) ----------
    cds_v = _load(method, "val_source_CDS.csv")  # class, domain, stylebin, correct, total, acc
    cd_v  = _load(method, "val_source_CD.csv")   # class, domain, correct, total, acc

    if not cds_v.empty:
        # per-group counts (CDS) on source-val
        counts_cds_v = cds_v[["class","domain","stylebin","total"]].sort_values("total")
        counts_cds_v.to_csv(save_path("counts_cds_sourceval.csv"), index=False)

        # aggregate across source domains to (class, stylebin)
        agg_val_cds = (cds_v.groupby(["class","stylebin"], as_index=False)["total"].sum()
                          .sort_values("total"))
        agg_val_cds.to_csv(save_path("counts_cds_sourceval_by_class_style.csv"), index=False)

        print(f"\n[SOURCE-VAL | {method}] CDS: per-group counts (ascending)")
        display(counts_cds_v.head(10))
        print("\n[SOURCE-VAL] CDS: totals aggregated over sources (class × stylebin)")
        display(agg_val_cds.head(10))

    if not cd_v.empty:
        counts_cd_v = cd_v[["class","domain","total"]].sort_values("total")
        counts_cd_v.to_csv(save_path("counts_cd_sourceval.csv"), index=False)

        # aggregates
        per_class_cd_v  = counts_cd_v.groupby("class", as_index=False)["total"].sum().sort_values("total")
        per_domain_cd_v = counts_cd_v.groupby("domain", as_index=False)["total"].sum().sort_values("total")

        print(f"\n[SOURCE-VAL | {method}] CD: per-group counts (ascending)")
        display(counts_cd_v.head(10))
        print("\n[SOURCE-VAL] CD: totals per class across sources")
        display(per_class_cd_v)

        per_class_cd_v.to_csv(save_path("counts_cd_sourceval_per_class.csv"), index=False)
        per_domain_cd_v.to_csv(save_path("counts_cd_sourceval_per_domain.csv"), index=False)

# Run (change METHOD to see other methods)
summarize_counts("GroupDRO")
# summarize_counts("ERM_2")
# summarize_counts("IRMv1")
# summarize_counts("SAM")



[TARGET-TEST | GroupDRO] CDS: per-group counts (ascending)


Unnamed: 0,class,domain,stylebin,total
16,house,Sketch,low,2
17,house,Sketch,mid,12
19,person,Sketch,low,27
20,person,Sketch,mid,36
15,house,Sketch,high,66
9,guitar,Sketch,high,79
18,person,Sketch,high,97
11,guitar,Sketch,mid,122
4,elephant,Sketch,low,144
1,dog,Sketch,low,192



[TARGET-TEST] CDS: totals per class


Unnamed: 0,class,total
5,house,80
6,person,160
3,guitar,608
1,elephant,740
2,giraffe,753
0,dog,772
4,horse,816



[TARGET-TEST] CDS: totals per stylebin


Unnamed: 0,stylebin,total
2,mid,1296
1,low,1297
0,high,1336



[TARGET-TEST] CDS: class × stylebin count pivot (Sketch)


stylebin,high,low,mid
class,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
dog,217,192,363
elephant,317,144,279
giraffe,201,327,225
guitar,79,407,122
horse,359,198,259
house,66,2,12
person,97,27,36



[TARGET-TEST] CDS count quantiles (10/25/50%):
{0.1: 27.0, 0.25: 79.0, 0.5: 198.0}

[TARGET-TEST | GroupDRO] CD: per-group counts (ascending)


Unnamed: 0,class,domain,total
5,house,Sketch,80
6,person,Sketch,160
3,guitar,Sketch,608
1,elephant,Sketch,740
2,giraffe,Sketch,753
0,dog,Sketch,772
4,horse,Sketch,816



[TARGET-TEST] CD: class × domain count pivot


domain,Sketch
class,Unnamed: 1_level_1
dog,772
elephant,740
giraffe,753
guitar,608
horse,816
house,80
person,160



[SOURCE-VAL | GroupDRO] CDS: per-group counts (ascending)


Unnamed: 0,class,domain,stylebin,total
16,house,Art_painting,low,1
12,horse,Art_painting,high,2
10,guitar,Art_painting,low,2
30,guitar,Cartoon,high,3
58,house,Photo,low,3
37,house,Cartoon,low,4
32,guitar,Cartoon,mid,4
49,giraffe,Photo,low,4
53,guitar,Photo,mid,4
42,dog,Photo,high,5



[SOURCE-VAL] CDS: totals aggregated over sources (class × stylebin)


Unnamed: 0,class,stylebin,total
16,house,low,8
10,guitar,low,15
11,guitar,mid,15
9,guitar,high,19
12,horse,high,20
3,elephant,high,22
8,giraffe,mid,23
13,horse,low,24
17,house,mid,26
6,giraffe,high,27



[SOURCE-VAL | GroupDRO] CD: per-group counts (ascending)


Unnamed: 0,class,domain,total
10,guitar,Cartoon,13
3,guitar,Art_painting,18
14,dog,Photo,18
16,giraffe,Photo,18
17,guitar,Photo,18
18,horse,Photo,19
15,elephant,Photo,20
4,horse,Art_painting,20
1,elephant,Art_painting,25
19,house,Photo,28



[SOURCE-VAL] CD: totals per class across sources


Unnamed: 0,class,total
3,guitar,49
4,horse,71
2,giraffe,80
5,house,85
1,elephant,90
0,dog,93
6,person,127


Confusion Matrix

In [None]:
import numpy as np, matplotlib.pyplot as plt, seaborn as sns
from sklearn.metrics import confusion_matrix

def predict_all(model, loader, n_classes):
    y_true, y_pred = [], []
    backbone, head = model["backbone"], model["head"]
    backbone.eval(); head.eval()
    with torch.no_grad():
        for x,y,meta in loader:
            x = x.to(DEVICE)
            logits, _ = forward_logits(backbone, head, x)
            pred = logits.argmax(1).cpu().numpy()
            y_true.append(y.numpy()); y_pred.append(pred)
    y_true = np.concatenate(y_true); y_pred = np.concatenate(y_pred)
    return y_true, y_pred

def plot_conf_mat(method, class_names):
    # rebuild test loader as in your runs
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, "Sketch", val_frac=0.1)
    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    ds_test = PACS(DATA_ROOT, ["Sketch"], split_indices=test_idx, transform=eval_tf, stylebin_map={**style_map_train, **style_map_test})
    dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    # load checkpoint
    ckpt = torch.load(os.path.join(ROOT, method, "Sketch", "best_model.pt"), map_location=DEVICE)
    bb, hd = build_model(n_classes=len(classes))
    bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
    bb, hd = bb.to(DEVICE), hd.to(DEVICE)

    y_true, y_pred = predict_all({"backbone":bb,"head":hd}, dl_test, len(classes))
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))
    cmn = cm / cm.sum(axis=1, keepdims=True).clip(min=1)
    plt.figure(figsize=(5.5,4.8))
    sns.heatmap(cmn, annot=False, cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.xlabel("Pred"); plt.ylabel("True"); plt.title(f"Confusion (Sketch) — {method}")
    plt.tight_layout()
    plt.savefig(os.path.join(ROOT, f"Confusion_{method}_Sketch.png"), dpi=160); plt.close()

# call for your four methods; provide your class list (7 PACS classes)
plot_conf_mat("ERM_2", classes)
plot_conf_mat("IRMv1", classes)
plot_conf_mat("GroupDRO", classes)
plot_conf_mat("SAM", classes)


Calibration on Target: ECE + reliability curves

In [None]:
def ece_and_reliability(method):
    # loader same as above
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, "Sketch", val_frac=0.1)
    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_test = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)
    ds_test = PACS(DATA_ROOT, ["Sketch"], split_indices=test_idx, transform=eval_tf, stylebin_map={**style_map_train, **style_map_test})
    dl_test = DataLoader(ds_test, batch_size=256, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

    ckpt = torch.load(os.path.join(ROOT, method, "Sketch", "best_model.pt"), map_location=DEVICE)
    bb, hd = build_model(n_classes=len(classes)); bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
    bb, hd = bb.to(DEVICE), hd.to(DEVICE); bb.eval(); hd.eval()

    probs_all, y_all = [], []
    with torch.no_grad():
        for x,y,meta in dl_test:
            x = x.to(DEVICE)
            logits, _ = forward_logits(bb, hd, x)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            probs_all.append(probs.max(axis=1))
            y_all.append((logits.argmax(1).cpu().numpy() == y.numpy()).astype(np.float32))
    conf = np.concatenate(probs_all); acc = np.concatenate(y_all)

    # reliability bins
    B=15
    bins = np.linspace(0,1,B+1)
    inds = np.digitize(conf, bins) - 1
    acc_bin, conf_bin, counts = [], [], []
    ece = 0.0
    for b in range(B):
        m = inds==b
        if m.sum()==0:
            acc_bin.append(np.nan); conf_bin.append((bins[b]+bins[b+1])/2); counts.append(0); continue
        acc_b = acc[m].mean(); conf_b = conf[m].mean()
        acc_bin.append(acc_b); conf_bin.append(conf_b); counts.append(m.sum())
        ece += (m.sum()/len(conf))*abs(acc_b - conf_b)

    # plot
    plt.figure(figsize=(4.5,3.5))
    plt.plot([0,1],[0,1], 'k--', lw=1)
    plt.plot(conf_bin, acc_bin, marker='o')
    plt.title(f"Reliability (Sketch) — {method}\nECE={ece:.3f}")
    plt.xlabel("Confidence"); plt.ylabel("Accuracy"); plt.ylim(0,1); plt.xlim(0,1)
    plt.grid(True, alpha=0.3); plt.tight_layout()
    plt.savefig(os.path.join(ROOT, f"Reliability_{method}_Sketch.png"), dpi=150); plt.close()
    return ece

ece_rows = []
for m in ["ERM_2","IRMv1","GroupDRO","SAM"]:
    e = ece_and_reliability(m)
    ece_rows.append({"method":m, "ECE_target":e})
pd.DataFrame(ece_rows).to_csv(os.path.join(ROOT, "ECE_on_Sketch.csv"), index=False)
print("Saved: ECE_on_Sketch.csv")


apparent mastery on sources that collapses on the target

Aggregate source-val accuracy by (class, stylebin) across all source domains (weighted by counts).
Compare to target-test (Sketch) accuracy for the same (class, stylebin).
Compute drop = source-val acc − target-test acc.
Rank the top-5 biggest drops per method. These are your most suspicious CDS groups.

One barh chart per method: “Top-5 CDS drops (val→test)”.

In [None]:
import os, pandas as pd, numpy as np, matplotlib.pyplot as pl

ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHODS = ["ERM_2","IRMv1","GroupDRO","SAM"]

MIN_VAL_COUNT  = 20   # skip tiny groups on source-val
MIN_TEST_COUNT = 20   # skip tiny groups on target-test

def _load_csv(method, name):
    p = os.path.join(ROOT, method, TARGET, name)
    return pd.read_csv(p) if os.path.isfile(p) else pd.DataFrame()

def compute_cds_drop_table(method):
    # source-val CDS (all source domains), columns: class, domain, stylebin, correct, total, acc
    val_cds = _load_csv(method, "val_source_CDS.csv")
    # target-test CDS (Sketch only), columns: class, domain, stylebin, correct, total, acc
    test_cds = _load_csv(method, "groups_cds.csv")

    if val_cds.empty or test_cds.empty:
        return pd.DataFrame()

    # Aggregate source-val across domains → per (class, stylebin)
    val_agg = (val_cds.groupby(["class","stylebin"], as_index=False)
                      .agg(correct=("correct","sum"), total=("total","sum")))
    val_agg["acc_val"] = val_agg["correct"] / val_agg["total"].clip(lower=1)

    # Target-test: restrict to the target domain and aggregate per (class, stylebin)
    test_t = test_cds[test_cds["domain"].eq(TARGET)].copy()
    test_agg = (test_t.groupby(["class","stylebin"], as_index=False)
                      .agg(correct=("correct","sum"), total=("total","sum")))
    test_agg["acc_test"] = test_agg["correct"] / test_agg["total"].clip(lower=1)

    # Merge and compute drop
    M = val_agg.merge(test_agg, on=["class","stylebin"], how="inner", suffixes=("_val","_test"))
    # Filter out low-count groups on either side
    M = M[(M["total_val"]  >= MIN_VAL_COUNT) & (M["total_test"] >= MIN_TEST_COUNT)].copy()
    if M.empty:
        return M
    M["drop"] = M["acc_val"] - M["acc_test"]
    M["group"] = M["class"] + " / " + M["stylebin"].astype(str)

    # attach method and target for bookkeeping
    M["method"] = method
    M["target"] = TARGET
    return M.sort_values("drop", ascending=False)

def plot_top_drops(df, method, k=5):
    if df.empty:
        print(f"[{method}] No CDS drops to plot (after count filters).")
        return
    topk = df.head(k).copy()
    plt.figure(figsize=(7,3.6))
    plt.barh(topk["group"], topk["drop"])
    plt.gca().invert_yaxis()
    for i,(val, tst) in enumerate(zip(topk["acc_val"], topk["acc_test"])):
        plt.text(topk["drop"].iloc[i]*0.5, i, f"{val:.2f}→{tst:.2f}", va='center', ha='center', fontsize=9)
    plt.xlabel("Accuracy drop (source-val → target-test)")
    plt.title(f"Top-{k} CDS drops — {method} (target={TARGET})")
    plt.tight_layout()
    outp = os.path.join(ROOT, f"TopDrops_{method}_{TARGET}.png")
    plt.savefig(outp, dpi=160); plt.close()
    print("Saved:", outp)

# Run for all methods, save tables and plots
all_rows = []
for m in METHODS:
    T = compute_cds_drop_table(m)
    if not T.empty:
        all_rows.append(T)
        # per-method CSV
        out_csv = os.path.join(ROOT, f"CDS_Drops_{m}_{TARGET}.csv")
        T.to_csv(out_csv, index=False)
        print("Saved:", out_csv)
        plot_top_drops(T, m, k=5)

# Combined view across methods (optional)
if all_rows:
    ALL = pd.concat(all_rows, ignore_index=True)
    # Heatmap: classes × stylebin for each method (drop)
    for m in METHODS:
        sub = ALL[ALL["method"].eq(m)]
        if sub.empty:
            continue
        piv = sub.pivot_table(index="class", columns="stylebin", values="drop", aggfunc="mean")
        plt.figure(figsize=(5.2,3.8))
        im = plt.imshow(piv.values, aspect='auto')
        plt.xticks(range(len(piv.columns)), piv.columns)
        plt.yticks(range(len(piv.index)), piv.index)
        plt.colorbar(im, fraction=0.046, pad=0.04, label="drop (val→test)")
        plt.title(f"CDS drop heatmap — {m} (target={TARGET})")
        plt.tight_layout()
        outp = os.path.join(ROOT, f"CDS_DropHeatmap_{m}_{TARGET}.png")
        plt.savefig(outp, dpi=160); plt.close()
        print("Saved:", outp)

    # Also save a combined CSV of just the top-5 per method
    tops = (ALL.sort_values(["method","drop"], ascending=[True, False])
              .groupby("method").head(5))
    tops.to_csv(os.path.join(ROOT, f"CDS_TopDrops_AllMethods_{TARGET}.csv"), index=False)
    print("Saved:", os.path.join(ROOT, f"CDS_TopDrops_AllMethods_{TARGET}.csv"))


Saved: /content/drive/MyDrive/DG_PACS/CDS_Drops_ERM_2_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/TopDrops_ERM_2_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_Drops_IRMv1_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/TopDrops_IRMv1_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_Drops_GroupDRO_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/TopDrops_GroupDRO_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_Drops_SAM_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/TopDrops_SAM_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DropHeatmap_ERM_2_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DropHeatmap_IRMv1_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DropHeatmap_GroupDRO_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DropHeatmap_SAM_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_TopDrops_AllMethods_Sketch.csv


for each CDS group (class, stylebin) on Sketch, how much each method reduces the ERM_2 drop:

For a method M, we compute:

drop_val→test = acc_val_source − acc_test_target

Δ-reduction vs ERM_2 = drop_ERM2 − drop_M

Positive = method reduced the drop (good).

Negative = method’s drop is worse than ERM_2 (bad).

It produces:

A tidy CSV with reductions for IRMv1, GroupDRO, SAM vs ERM_2

Top/bottom bar charts per method (biggest improvements / regressions)

An optional small heatmap of reductions (class × stylebin) per method

In [None]:
import os, pandas as pd, numpy as np, matplotlib.pyplot as plt

ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
BASE   = "ERM_2"
OTHERS = ["IRMv1","GroupDRO","SAM"]

MIN_VAL_COUNT  = 20   # filter tiny CDS groups on source-val
MIN_TEST_COUNT = 20   # filter tiny CDS groups on target-test

def _load_csv(method, name):
    p = os.path.join(ROOT, method, TARGET, name)
    return pd.read_csv(p) if os.path.isfile(p) else pd.DataFrame()

def cds_drop_table(method):
    """Compute per-(class, stylebin) drop: source-val acc minus target-test acc, with count filters."""
    val_cds = _load_csv(method, "val_source_CDS.csv")  # class, domain, stylebin, correct, total, acc
    test_cds = _load_csv(method, "groups_cds.csv")     # class, domain, stylebin, correct, total, acc

    if val_cds.empty or test_cds.empty:
        return pd.DataFrame()

    # aggregate source-val across *source* domains → (class, stylebin)
    val_agg = (val_cds.groupby(["class","stylebin"], as_index=False)
                      .agg(correct=("correct","sum"), total=("total","sum")))
    val_agg["acc_val"] = val_agg["correct"] / val_agg["total"].clip(lower=1)

    # target-test: restrict to Sketch and aggregate to (class, stylebin)
    test_t = test_cds[test_cds["domain"].eq(TARGET)].copy()
    test_agg = (test_t.groupby(["class","stylebin"], as_index=False)
                      .agg(correct=("correct","sum"), total=("total","sum")))
    test_agg["acc_test"] = test_agg["correct"] / test_agg["total"].clip(lower=1)

    M = val_agg.merge(test_agg, on=["class","stylebin"], how="inner", suffixes=("_val","_test"))
    M = M[(M["total_val"] >= MIN_VAL_COUNT) & (M["total_test"] >= MIN_TEST_COUNT)].copy()
    if M.empty:
        return M

    M["drop"]  = M["acc_val"] - M["acc_test"]
    M["group"] = M["class"] + " / " + M["stylebin"].astype(str)
    M["method"] = method
    M["target"] = TARGET
    return M

def plot_reductions(df_red, method, k=10):
    """Barh: top/bottom Δ-reductions vs ERM_2 for a method."""
    if df_red.empty:
        print(f"[{method}] No reductions to plot.")
        return
    topk = df_red.sort_values("delta_reduction", ascending=False).head(k)
    botk = df_red.sort_values("delta_reduction", ascending=True).head(k)

    # Top improvements
    plt.figure(figsize=(8,4))
    plt.barh(topk["group"], topk["delta_reduction"])
    plt.axvline(0, color="k", lw=1)
    plt.title(f"Top Δ-reductions vs {BASE} (val→test drop) — {method}")
    plt.xlabel("Reduction of drop (positive = better than ERM_2)")
    plt.tight_layout()
    fn = os.path.join(ROOT, f"CDS_DeltaReduction_{method}_top.png")
    plt.savefig(fn, dpi=160); plt.close(); print("Saved:", fn)

    # Biggest regressions
    plt.figure(figsize=(8,4))
    plt.barh(botk["group"], botk["delta_reduction"])
    plt.axvline(0, color="k", lw=1)
    plt.title(f"Worst Δ-reductions vs {BASE} (val→test drop) — {method}")
    plt.xlabel("Reduction of drop (positive = better than ERM_2)")
    plt.tight_layout()
    fn = os.path.join(ROOT, f"CDS_DeltaReduction_{method}_bottom.png")
    plt.savefig(fn, dpi=160); plt.close(); print("Saved:", fn)

def heatmap_reduction(df_red, method):
    if df_red.empty:
        return
    piv = df_red.pivot_table(index="class", columns="stylebin", values="delta_reduction", aggfunc="mean")
    plt.figure(figsize=(5.2,3.8))
    im = plt.imshow(piv.values, aspect='auto')
    plt.xticks(range(len(piv.columns)), piv.columns)
    plt.yticks(range(len(piv.index)), piv.index)
    plt.colorbar(im, fraction=0.046, pad=0.04, label=f"Δ-reduction vs {BASE}")
    plt.title(f"CDS Δ-reduction heatmap — {method} (target={TARGET})")
    plt.tight_layout()
    fn = os.path.join(ROOT, f"CDS_DeltaReductionHeatmap_{method}_{TARGET}.png")
    plt.savefig(fn, dpi=160); plt.close(); print("Saved:", fn)

# --- Compute baseline drops (ERM_2) ---
base_df = cds_drop_table(BASE)
if base_df.empty:
    raise RuntimeError(f"Missing or empty drop table for {BASE}. Ensure val_source_CDS.csv and groups_cds.csv exist.")

# Save baseline drops for reference
base_out = os.path.join(ROOT, f"CDS_Drops_{BASE}_{TARGET}.csv")
base_df.to_csv(base_out, index=False); print("Saved:", base_out)

# --- Compare each method to ERM_2 ---
all_reductions = []
for method in OTHERS:
    cur = cds_drop_table(method)
    if cur.empty:
        print(f"[skip] No CDS drops for {method}")
        continue
    M = base_df[["class","stylebin","drop","acc_val","acc_test","total_val","total_test"]].merge(
        cur[["class","stylebin","drop","acc_val","acc_test","total_val","total_test"]],
        on=["class","stylebin"], how="inner", suffixes=(f"_{BASE}", f"_{method}")
    )
    if M.empty:
        print(f"[skip] No overlapping CDS groups after filters for {method}")
        continue
    # Δ-reduction vs ERM_2 (positive = improved vs ERM_2)
    M["delta_reduction"] = M[f"drop_{BASE}"] - M[f"drop_{method}"]
    M["group"] = M["class"] + " / " + M["stylebin"].astype(str)
    M["method"] = method
    M["target"] = TARGET
    all_reductions.append(M)

    # Save per-method CSV + plots
    out_csv = os.path.join(ROOT, f"CDS_DeltaReduction_{method}_vs_{BASE}_{TARGET}.csv")
    M.sort_values("delta_reduction", ascending=False).to_csv(out_csv, index=False)
    print("Saved:", out_csv)
    plot_reductions(M, method, k=10)
    heatmap_reduction(M, method)

# --- Combined summary table (top-5 per method) ---
if all_reductions:
    ALL = pd.concat(all_reductions, ignore_index=True)
    tops = (ALL.sort_values(["method","delta_reduction"], ascending=[True, False])
              .groupby("method").head(5))
    out_csv = os.path.join(ROOT, f"CDS_TopDeltaReductions_AllMethods_vs_{BASE}_{TARGET}.csv")
    tops.to_csv(out_csv, index=False); print("Saved:", out_csv)


Saved: /content/drive/MyDrive/DG_PACS/CDS_Drops_ERM_2_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_IRMv1_vs_ERM_2_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_IRMv1_top.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_IRMv1_bottom.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReductionHeatmap_IRMv1_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_GroupDRO_vs_ERM_2_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_GroupDRO_top.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_GroupDRO_bottom.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReductionHeatmap_GroupDRO_Sketch.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_SAM_vs_ERM_2_Sketch.csv
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_SAM_top.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReduction_SAM_bottom.png
Saved: /content/drive/MyDrive/DG_PACS/CDS_DeltaReductionHeatmap_SAM_Sketch.png
Save

# Testing Feature Importance
C**1. Attribution & perturbation on Top-5 CDS drops**

A) GRadCAm and F/B ratio

In [None]:
import os, math, numpy as np, pandas as pd, matplotlib.pyplot as plt
import torch, torch.nn.functional as F
from torch.utils.data import DataLoader

# --- Grad-CAM hooks for last conv
def register_cam(model_backbone, target_layer_name="layer4"):  # resnet18 last block
    acts, grads = {}, {}
    layer = dict(model_backbone.named_modules())[target_layer_name]

    def fwd_hook(m, i, o): acts["A"] = o.detach()
    def bwd_hook(m, gin, gout): grads["G"] = gout[0].detach()
    h1 = layer.register_forward_hook(fwd_hook)
    h2 = layer.register_full_backward_hook(bwd_hook)
    return acts, grads, (h1, h2)

def grad_cam(backbone, head, x, y=None):
    backbone.eval(); head.eval()
    acts, grads, hs = register_cam(backbone, "layer4")
    with torch.enable_grad():
        x = x.requires_grad_(True)
        logits, _ = forward_logits(backbone, head, x)
        if y is None:
            y = logits.argmax(1)
        loss = F.cross_entropy(logits, y)
        loss.backward()
    A, G = acts["A"], grads["G"]   # [B,C,H,W]
    weights = G.mean(dim=(2,3), keepdim=True)  # [B,C,1,1]
    cam = (A * weights).sum(dim=1, keepdim=True)  # [B,1,H,W]
    cam = F.relu(cam)
    # upsample to input size
    cam = F.interpolate(cam, size=x.shape[-2:], mode="bilinear", align_corners=False)
    cam = cam.squeeze(1)  # [B,H,W]
    # cleanup hooks
    for h in hs: h.remove()
    return cam, logits

# --- Coarse foreground mask: centered ellipse (~object proxy)
def center_ellipse_mask(H, W, scale=0.6):
    yy, xx = torch.meshgrid(torch.linspace(-1,1,H), torch.linspace(-1,1,W), indexing="ij")
    r = (xx**2 + (yy*H/W)**2)  # aspect-corrected
    return (r <= (scale**2)).float()  # 1=FG, 0=BG

def fb_ratio_from_cam(cam, fg_mask):
    # cam: [H,W] non-negative
    s = cam.sum().item() + 1e-12
    w_fg = (cam * fg_mask).sum().item() / s
    w_bg = (cam * (1 - fg_mask)).sum().item() / s
    return w_fg / max(w_bg, 1e-8)

# --- Build a loader that yields only the selected CDS groups from TARGET test
def subset_loader_for_cds(method, cls_name, stylebin, batch_size=32):
    # Use your existing splits & style-map rebuild
    train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, TARGET, val_frac=0.1)
    style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
    style_map_test  = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)

    # filter test_idx for the requested CDS
    keep = []
    for (d,c,fn) in test_idx:
        if d != TARGET: continue
        rel = f"{d}/{c}/{fn}"
        sb = style_map_test.get(rel, None)
        if (c == cls_name) and (sb == stylebin):
            keep.append((d,c,fn))

    ds = PACS(DATA_ROOT, [TARGET], split_indices=keep, transform=eval_tf, stylebin_map={**style_map_train,**style_map_test})
    dl = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
    return dl, classes

# --- Run Grad-CAM & compute F/B ratios for each suspect CDS group
def fb_ratios_for_group(method, cls_name, stylebin):
    # load model
    ckpt = torch.load(os.path.join(ROOT, method, TARGET, "best_model.pt"), map_location=DEVICE)
    bb, hd = build_model(n_classes=7)  # PACS=7 classes; or len(classes) if you return it
    bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
    bb, hd = bb.to(DEVICE), hd.to(DEVICE)

    dl, classes = subset_loader_for_cds(method, cls_name, stylebin, batch_size=16)
    out = []
    for x, y, meta in dl:
        x = x.to(DEVICE)
        cam, logits = grad_cam(bb, hd, x, None)
        for i in range(x.size(0)):
            H,W = cam.shape[-2:]
            fg = center_ellipse_mask(H,W, scale=0.6).to(cam.device)
            ratio = fb_ratio_from_cam(cam[i], fg)
            row = {"class": cls_name, "stylebin": stylebin,
                   "pred": int(logits[i].argmax().item()),
                   "fb_ratio": float(ratio)}
            # keep domain/rel_key if present
            for k in ["domain","rel_key"]:
                if k in meta[i]: row[k] = meta[i][k]
            out.append(row)
    return pd.DataFrame(out)

In [None]:
# --- Pick which method to analyze and which CDS groups (use your top-drops CSV)
ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHOD = "GroupDRO"   # or "ERM_2", "IRMv1", "SAM"

topdrops_path = os.path.join(ROOT, f"CDS_Drops_{METHOD}_{TARGET}.csv")
topk = (pd.read_csv(topdrops_path)
          .sort_values("drop", ascending=False)
          .head(5)[["class","stylebin"]].drop_duplicates())

print("Top-5 CDS suspects for", METHOD)
display(topk)


rows=[]
for _, r in topk.iterrows():
    rows.append(fb_ratios_for_group(METHOD, r["class"], r["stylebin"]))
FB = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()
display(FB.head())

# Aggregate & plot
if not FB.empty:
    agg = (FB.groupby(["class","stylebin"], as_index=False)
             .agg(mean_fb=("fb_ratio","mean"), n=("fb_ratio","size")))
    display(agg.sort_values("mean_fb"))

    plt.figure(figsize=(6,3.5))
    lbls = agg["class"] + " / " + agg["stylebin"].astype(str)
    plt.barh(lbls, agg["mean_fb"])
    plt.xlabel("Foreground/Background attribution ratio ↑ (more FG-focused)")
    plt.title(f"Grad-CAM F/B ratio on CDS suspects — {METHOD} @ {TARGET}")
    plt.tight_layout()
    plt.savefig(os.path.join(ROOT, f"FB_ratio_{METHOD}_{TARGET}.png"), dpi=160); plt.close()

Top-5 CDS suspects for GroupDRO


Unnamed: 0,class,stylebin
0,horse,low
1,dog,low
2,dog,high
3,giraffe,mid
4,dog,mid


Unnamed: 0,class,stylebin,pred,fb_ratio,domain,rel_key
0,horse,low,3,0.055445,Sketch,Sketch/horse/train_009501.jpg
1,horse,low,4,0.185413,Sketch,Sketch/horse/train_009461.jpg
2,horse,low,4,0.000186,Sketch,Sketch/horse/train_009649.jpg
3,horse,low,4,0.033821,Sketch,Sketch/horse/train_009134.jpg
4,horse,low,4,0.003571,Sketch,Sketch/horse/train_009714.jpg


Unnamed: 0,class,stylebin,mean_fb,n
4,horse,low,0.210402,198
3,giraffe,mid,0.226521,225
0,dog,high,0.245186,217
1,dog,low,0.246348,192
2,dog,mid,0.38426,363


In [None]:
# --- Pick which method to analyze and which CDS groups (use your top-drops CSV)
ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHOD = "ERM_2"   # or "ERM_2", "IRMv1", "SAM"

topdrops_path = os.path.join(ROOT, f"CDS_Drops_{METHOD}_{TARGET}.csv")
topk = (pd.read_csv(topdrops_path)
          .sort_values("drop", ascending=False)
          .head(5)[["class","stylebin"]].drop_duplicates())

print("Top-5 CDS suspects for", METHOD)
display(topk)


rows=[]
for _, r in topk.iterrows():
    rows.append(fb_ratios_for_group(METHOD, r["class"], r["stylebin"]))
FB = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()
display(FB.head())

# Aggregate & plot
if not FB.empty:
    agg = (FB.groupby(["class","stylebin"], as_index=False)
             .agg(mean_fb=("fb_ratio","mean"), n=("fb_ratio","size")))
    display(agg.sort_values("mean_fb"))

    plt.figure(figsize=(6,3.5))
    lbls = agg["class"] + " / " + agg["stylebin"].astype(str)
    plt.barh(lbls, agg["mean_fb"])
    plt.xlabel("Foreground/Background attribution ratio ↑ (more FG-focused)")
    plt.title(f"Grad-CAM F/B ratio on CDS suspects — {METHOD} @ {TARGET}")
    plt.tight_layout()
    plt.savefig(os.path.join(ROOT, f"FB_ratio_{METHOD}_{TARGET}.png"), dpi=160); plt.close()

Top-5 CDS suspects for ERM_2


Unnamed: 0,class,stylebin
0,dog,high
1,dog,low
8,giraffe,mid
2,dog,mid
10,horse,low


Unnamed: 0,class,stylebin,pred,fb_ratio,domain,rel_key
0,dog,high,0,0.0,Sketch,Sketch/dog/train_006577.jpg
1,dog,high,1,0.247637,Sketch,Sketch/dog/train_006643.jpg
2,dog,high,4,0.100471,Sketch,Sketch/dog/train_006725.jpg
3,dog,high,1,0.990618,Sketch,Sketch/dog/train_006153.jpg
4,dog,high,0,0.0,Sketch,Sketch/dog/train_006222.jpg


Unnamed: 0,class,stylebin,mean_fb,n
3,giraffe,mid,0.225457,225
4,horse,low,0.264578,198
0,dog,high,0.364263,217
1,dog,low,0.384039,192
2,dog,mid,0.396674,363


In [None]:
# --- Pick which method to analyze and which CDS groups (use your top-drops CSV)
ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHOD = "IRMv1"   # or "ERM_2", "IRMv1", "SAM"

topdrops_path = os.path.join(ROOT, f"CDS_Drops_{METHOD}_{TARGET}.csv")
topk = (pd.read_csv(topdrops_path)
          .sort_values("drop", ascending=False)
          .head(5)[["class","stylebin"]].drop_duplicates())

print("Top-5 CDS suspects for", METHOD)
display(topk)


rows=[]
for _, r in topk.iterrows():
    rows.append(fb_ratios_for_group(METHOD, r["class"], r["stylebin"]))
FB = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()
display(FB.head())

# Aggregate & plot
if not FB.empty:
    agg = (FB.groupby(["class","stylebin"], as_index=False)
             .agg(mean_fb=("fb_ratio","mean"), n=("fb_ratio","size")))
    display(agg.sort_values("mean_fb"))

    plt.figure(figsize=(6,3.5))
    lbls = agg["class"] + " / " + agg["stylebin"].astype(str)
    plt.barh(lbls, agg["mean_fb"])
    plt.xlabel("Foreground/Background attribution ratio ↑ (more FG-focused)")
    plt.title(f"Grad-CAM F/B ratio on CDS suspects — {METHOD} @ {TARGET}")
    plt.tight_layout()
    plt.savefig(os.path.join(ROOT, f"FB_ratio_{METHOD}_{TARGET}.png"), dpi=160); plt.close()

In [None]:
# --- Pick which method to analyze and which CDS groups (use your top-drops CSV)
ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"
METHOD = "SAM"   # or "ERM_2", "IRMv1", "SAM"

topdrops_path = os.path.join(ROOT, f"CDS_Drops_{METHOD}_{TARGET}.csv")
topk = (pd.read_csv(topdrops_path)
          .sort_values("drop", ascending=False)
          .head(5)[["class","stylebin"]].drop_duplicates())

print("Top-5 CDS suspects for", METHOD)
display(topk)


rows=[]
for _, r in topk.iterrows():
    rows.append(fb_ratios_for_group(METHOD, r["class"], r["stylebin"]))
FB = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()
display(FB.head())

# Aggregate & plot
if not FB.empty:
    agg = (FB.groupby(["class","stylebin"], as_index=False)
             .agg(mean_fb=("fb_ratio","mean"), n=("fb_ratio","size")))
    display(agg.sort_values("mean_fb"))

    plt.figure(figsize=(6,3.5))
    lbls = agg["class"] + " / " + agg["stylebin"].astype(str)
    plt.barh(lbls, agg["mean_fb"])
    plt.xlabel("Foreground/Background attribution ratio ↑ (more FG-focused)")
    plt.title(f"Grad-CAM F/B ratio on CDS suspects — {METHOD} @ {TARGET}")
    plt.tight_layout()
    plt.savefig(os.path.join(ROOT, f"FB_ratio_{METHOD}_{TARGET}.png"), dpi=160); plt.close()

(B) Occlusion/blur tests (no masks needed)

Two stressors per image:

Object mask (center ellipse) → set to mean gray (object removed)

Background mask (inverse) → blur/gray background (object intact)
Compare accuracy drop from clean to perturbed

In [None]:
import cv2

def apply_occlusion(x_img, kind="object", strength=11):
    # x_img: torch tensor [3,H,W] in normalized space; denorm to uint8 first
    mean = torch.tensor([0.485,0.456,0.406], device=x_img.device).view(3,1,1)
    std  = torch.tensor([0.229,0.224,0.225], device=x_img.device).view(3,1,1)
    img = (x_img*std + mean).clamp(0,1).mul(255).permute(1,2,0).byte().cpu().numpy()  # HWC

    H,W = img.shape[:2]
    fg = center_ellipse_mask(H,W,0.6).cpu().numpy()
    bg = 1 - fg

    if kind=="object":
        # zero-out object (ellipse)
        occluded = img.copy()
        occluded[fg.astype(bool)] = img.mean(axis=(0,1)).astype(np.uint8)
    else:  # background
        occluded = cv2.GaussianBlur(img, (strength, strength), 0)
        occluded[fg.astype(bool)] = img[fg.astype(bool)]

    # back to normalized tensor
    t = torch.from_numpy(occluded).permute(2,0,1).float()/255.0
    t = (t - mean.cpu())/std.cpu()
    return t

def eval_with_perturb(backbone, head, loader, kind):
    backbone.eval(); head.eval()
    correct=total=0
    with torch.no_grad():
        for x,y,meta in loader:
            x_aug = torch.stack([apply_occlusion(xi.cpu(), kind=kind) for xi in x]).to(DEVICE)
            logits,_ = forward_logits(backbone, head, x_aug.to(DEVICE))
            pred = logits.argmax(1).cpu()
            correct += (pred==y).sum().item(); total += y.numel()
    return correct/max(1,total)

# Reuse subset loader from above (for each suspect CDS group)
ckpt = torch.load(os.path.join(ROOT, METHOD, TARGET, "best_model.pt"), map_location=DEVICE)
bb, hd = build_model(n_classes=7); bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
bb, hd = bb.to(DEVICE), hd.to(DEVICE)

rows=[]
for _, r in topk.iterrows():
    dl, classes = subset_loader_for_cds(METHOD, r["class"], r["stylebin"], batch_size=32)
    # clean
    base_acc = evaluate({"backbone":bb,"head":hd}, dl, classes=classes, group_key=None)["avg"]
    acc_obj  = eval_with_perturb(bb, hd, dl, kind="object")
    acc_bg   = eval_with_perturb(bb, hd, dl, kind="background")
    rows.append({"class": r["class"], "stylebin": r["stylebin"],
                 "clean": base_acc, "obj_mask": acc_obj, "bg_mask": acc_bg,
                 "d_obj": base_acc - acc_obj, "d_bg": base_acc - acc_bg})
pd.DataFrame(rows).to_csv(os.path.join(ROOT, f"Occlusion_{METHOD}_{TARGET}.csv"), index=False)
display(pd.DataFrame(rows))

Unnamed: 0,class,stylebin,clean,obj_mask,bg_mask,d_obj,d_bg
0,dog,high,0.442396,0.110599,0.502304,0.331797,-0.059908
1,dog,low,0.494792,0.020833,0.510417,0.473958,-0.015625
2,giraffe,mid,0.608889,0.408889,0.586667,0.2,0.022222
3,dog,mid,0.506887,0.060606,0.526171,0.446281,-0.019284
4,horse,low,0.646465,0.111111,0.540404,0.535354,0.106061


Interpretation: if masking background barely changes accuracy but masking object tanks it (good), the model is object-focused. If the reverse happens (big Δ under bg-mask, small Δ under obj-mask), it’s relying on background cues—classic shortcut.


# C3. Feature probes (domain predictability)
Train a linear probe on penultimate features Z to predict domain while stratifying by class (class-balanced batches). High domain accuracy ⇒ features encode nuisance info.

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

def extract_feats(model, loader):
    zs, ys_cls, ys_dom = [], [], []
    bb, hd = model["backbone"], model["head"]
    bb.eval(); hd.eval()
    with torch.no_grad():
        for x,y,meta in loader:
            x = x.to(DEVICE)
            _, feats = forward_logits(bb, hd, x)  # assume forward returns (logits, feats)
            zs.append(feats.cpu().numpy())
            ys_cls.append(y.numpy())
            ys_dom += [m["domain"] for m in meta]
    return np.concatenate(zs), np.concatenate(ys_cls), np.array(ys_dom)

# Build a *source-val* loader (all source domains) — same as in training
train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, TARGET, val_frac=0.1)
style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
style_map_val   = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx, transform=eval_tf, stylebin_map={**style_map_train,**style_map_val})
dl_val  = DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

# Load models for all methods and compare probe accuracies
def load_model(method):
    ckpt = torch.load(os.path.join(ROOT, method, TARGET, "best_model.pt"), map_location=DEVICE)
    bb, hd = build_model(n_classes=len(classes))
    bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
    return {"backbone": bb.to(DEVICE).eval(), "head": hd.to(DEVICE).eval()}

rows=[]
for m in ["ERM_2","IRMv1","GroupDRO","SAM"]:
    model = load_model(m)
    Z, Yc, Yd = extract_feats(model, dl_val)
    # Simple probe: predict domain (Photo/Art/Cartoon)
    clf = LogisticRegression(max_iter=200, multi_class="multinomial")
    clf.fit(Z, Yd)
    pred = clf.predict(Z)
    acc = accuracy_score(Yd, pred)
    rows.append({"method": m, "probe_domain_acc_on_val": acc})

df_probe = pd.DataFrame(rows)
display(df_probe)
df_probe.to_csv(os.path.join(ROOT, "Probes_Domain_on_Val.csv"), index=False)




Unnamed: 0,method,probe_domain_acc_on_val
0,ERM_2,1.0
1,IRMv1,1.0
2,GroupDRO,1.0
3,SAM,1.0


In [None]:
import numpy as np, pandas as pd, os, torch
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

ROOT = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"  # fixed target; we probe on source-val only

# --- build source-val loader exactly like in training ---
train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, TARGET, val_frac=0.1)
style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
style_map_val   = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)

ds_val = PACS(DATA_ROOT, source_domains, split_indices=val_idx, transform=eval_tf,
              stylebin_map={**style_map_train, **style_map_val})
dl_val = DataLoader(ds_val, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

def load_model(method):
    ckpt = torch.load(os.path.join(ROOT, method, TARGET, "best_model.pt"), map_location=DEVICE)
    bb, hd = build_model(n_classes=len(classes))
    bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
    return {"backbone": bb.to(DEVICE).eval(), "head": hd.to(DEVICE).eval()}

def extract_feats(model, loader):
    zs, y_dom, y_cls = [], [], []
    bb, hd = model["backbone"], model["head"]
    with torch.no_grad():
        for x, y, meta in loader:
            x = x.to(DEVICE)
            logits, feats = forward_logits(bb, hd, x)  # feats: penultimate
            zs.append(feats.cpu().numpy())
            y_dom += [m["domain"] for m in meta]       # strings (Photo/Art_painting/Cartoon)
            y_cls.append(y.numpy())
    Z = np.concatenate(zs, axis=0)
    y_cls = np.concatenate(y_cls, axis=0)
    # map domains to integers
    dom_to_id = {d:i for i,d in enumerate(sorted(set(y_dom)))}
    y_dom = np.array([dom_to_id[d] for d in y_dom], dtype=int)
    return Z, y_dom, y_cls, dom_to_id

def domain_probe_cv(Z, y_dom, y_cls, n_folds=5):
    # class-stratified CV: split by class labels
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=0)
    accs = []
    for tr, te in skf.split(Z, y_cls):
        clf = make_pipeline(StandardScaler(with_mean=True, with_std=True),
                            LogisticRegression(max_iter=1000, multi_class="multinomial", n_jobs=-1))
        clf.fit(Z[tr], y_dom[tr])
        pred = clf.predict(Z[te])
        accs.append(accuracy_score(y_dom[te], pred))
    return float(np.mean(accs)), float(np.std(accs))

rows = []
for method in ["ERM_2","IRMv1","GroupDRO","SAM"]:
    model = load_model(method)
    Z, y_dom, y_cls, dommap = extract_feats(model, dl_val)
    mean_acc, std_acc = domain_probe_cv(Z, y_dom, y_cls, n_folds=5)
    rows.append({"method": method, "probe_domain_acc_on_val_mean": mean_acc, "std": std_acc,
                 "n_samples": len(Z), "n_domains": len(dommap)})

probe_summary = pd.DataFrame(rows)
display(probe_summary)
probe_summary.to_csv(os.path.join(ROOT, "Probes_Domain_on_Val_CV.csv"), index=False)
print("Saved: Probes_Domain_on_Val_CV.csv")


In [None]:
try:
    from google.colab import drive
    drive.mount('/content/drive')
    DRIVE_MOUNTED = True
except Exception as e:
    print("Colab Drive mount not available, saving locally under ./outputs")
    DRIVE_MOUNTED = False

Mounted at /content/drive


In [None]:
import os, pandas as pd

ROOT="/content/drive/MyDrive/DG_PACS"; TARGET="Sketch"
def stylebin_delta(method):
    # assumes you already created CDS deltas vs ERM_2
    path = os.path.join(ROOT, f"CDS_DeltaReduction_{method}_vs_ERM_2_{TARGET}.csv")
    if not os.path.isfile(path): return None
    df = pd.read_csv(path)  # has columns: class, stylebin, delta_reduction, ...
    return df.groupby("stylebin")["delta_reduction"].mean().sort_index()

# CDS_DeltaReduction_GroupDRO_vs_ERM_2_Sketch
for m in ["IRMv1","GroupDRO","SAM"]:
    s = stylebin_delta(m)
    if s is not None:
        print(m, "Δ-reduction by stylebin on Sketch:\n", s, "\n")
    else:
        print(m, "Δ-reduction not available on Sketch")


IRMv1 Δ-reduction by stylebin on Sketch:
 stylebin
high   -0.002046
low    -0.029164
mid    -0.051684
Name: delta_reduction, dtype: float64 

GroupDRO Δ-reduction by stylebin on Sketch:
 stylebin
high    0.063881
low    -0.017716
mid     0.018446
Name: delta_reduction, dtype: float64 

SAM Δ-reduction by stylebin on Sketch:
 stylebin
high    0.012436
low     0.045693
mid     0.028053
Name: delta_reduction, dtype: float64 



In [None]:
import numpy as np, pandas as pd, os
from scipy.stats import wasserstein_distance as W1

# Build style-score lists per domain from the *fixed splits* you use
def style_scores_for(domain, split_idx, thresholds):
    # reuse your file->stylebin mapping functions, but also get raw StyleScore if stored
    # If you only have bins, you can approximate a histogram from bin counts.
    # Below assumes you stored per-image StyleScore (if not, adapt to bin counts).
    # Example stub where you re-run the StyleScore function on images in split_idx[domain].
    pass  # (fill in with your existing style-scoring helper)

# Quick proxy using bins (no raw scores): compare bin distributions
def style_bin_hist(domain, split_csv):
    # expects a CSV with (class, domain, stylebin, total) for that split
    t = split_csv[split_csv["domain"].eq(domain)]
    counts = t.groupby("stylebin")["total"].sum().reindex(["low","mid","high"]).fillna(0).values
    return counts / counts.sum()

# Load your source-val counts and target-test counts
val_cds = pd.read_csv(os.path.join(ROOT,"ERM_2",TARGET,"val_source_CDS.csv"))
tst_cds = pd.read_csv(os.path.join(ROOT,"ERM_2",TARGET,"groups_cds.csv"))

src_domains = sorted(val_cds["domain"].unique())

h_target = style_bin_hist(TARGET, tst_cds)
for d in src_domains:
    h_src = style_bin_hist(d, val_cds)
    # Use W1 over the ordinal bins (0,1,2) as a coarse distance
    xs = np.array([0,1,2], dtype=float)
    dist = W1(xs, xs, u_weights=h_src, v_weights=h_target)
    print(f"W1(stylebin) {d} → Sketch: {dist:.3f}")


W1(stylebin) Art_painting → Sketch: 0.033
W1(stylebin) Cartoon → Sketch: 0.055
W1(stylebin) Photo → Sketch: 0.019


In [None]:
import os, pandas as pd, numpy as np, matplotlib.pyplot as plt

ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"  # your fixed target
LOG_GD = os.path.join(ROOT, "GroupDRO", TARGET, "train_log.csv")

df = pd.read_csv(LOG_GD)

# --- sanity check: worst_src_domain_acc equals per-domain min ---
dom_cols = ["acc_Art_painting","acc_Cartoon","acc_Photo"]
df["min_domains"] = df[dom_cols].min(axis=1)
if "worst_src_domain_acc" in df.columns:
    diff = (df["worst_src_domain_acc"] - df["min_domains"]).abs().max()
    print(f"[sanity] max |worst_src_domain_acc - min(per-domain)| = {diff:.4g}")

# --- handy aggregates ---
df["gap_max_min"] = df[dom_cols].max(axis=1) - df[dom_cols].min(axis=1)
df["mean_domains"] = df[dom_cols].mean(axis=1)

ep = df["epoch"] if "epoch" in df.columns else np.arange(len(df))

# --- 1) Worst-source curve (what the assignment asks to report) ---
plt.figure(figsize=(6,3.4))
plt.plot(ep, df["worst_src_domain_acc"], lw=2, label="Worst source-domain acc")
plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.ylim(0,1)
plt.title("GroupDRO: worst source-domain accuracy (source-val)")
plt.grid(alpha=0.3); plt.tight_layout()
out1 = os.path.join(ROOT, "GroupDRO", TARGET, "worst_source_curve.png")
plt.savefig(out1, dpi=150); plt.close()
print("Saved:", out1)

# --- 2) Per-domain curves + domain gap ---
plt.figure(figsize=(7.2,3.8))
for c in dom_cols:
    plt.plot(ep, df[c], lw=1.8, label=c.replace("acc_",""))
plt.plot(ep, df["gap_max_min"], lw=1.6, ls="--", label="gap (max−min)", alpha=0.9)
plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.ylim(0,1)
plt.title("GroupDRO: per-domain source-val accuracy & gap")
plt.legend(ncol=2, fontsize=9); plt.grid(alpha=0.3); plt.tight_layout()
out2 = os.path.join(ROOT, "GroupDRO", TARGET, "per_domain_and_gap.png")
plt.savefig(out2, dpi=150); plt.close()
print("Saved:", out2)

# --- 3) Entropy & K_eff (effective #domains) ---
if {"weights_entropy","weights_eff_num"}.issubset(df.columns):
    fig, ax1 = plt.subplots(figsize=(6.6,3.4))
    ax1.plot(ep, df["weights_eff_num"], lw=2, label="K_eff (effective domains)")
    ax1.set_xlabel("Epoch"); ax1.set_ylabel("K_eff"); ax1.set_ylim(1,3.05)
    ax2 = ax1.twinx()
    ax2.plot(ep, df["weights_entropy"], lw=1.6, color="tab:orange", label="entropy")
    ax2.set_ylabel("weights entropy")
    ax1.set_title("GroupDRO: domain-weight balance over training")
    ax1.grid(alpha=0.3)
    # 2 legends
    h1, l1 = ax1.get_legend_handles_labels()
    h2, l2 = ax2.get_legend_handles_labels()
    ax1.legend(h1+h2, l1+l2, loc="lower right", fontsize=9)
    fig.tight_layout()
    out3 = os.path.join(ROOT, "GroupDRO", TARGET, "weights_balance.png")
    fig.savefig(out3, dpi=150); plt.close(fig)
    print("Saved:", out3)

# --- tiny summary to quote ---
start_worst = float(df["worst_src_domain_acc"].iloc[0])
end_worst   = float(df["worst_src_domain_acc"].iloc[-1])
best_ep     = int(df["worst_src_domain_acc"].idxmax()) if not df["worst_src_domain_acc"].isna().all() else None
best_val    = float(df["worst_src_domain_acc"].max())

end_gap     = float(df["gap_max_min"].iloc[-1])
start_gap   = float(df["gap_max_min"].iloc[0])

if "weights_eff_num" in df.columns:
    start_keff = float(df["weights_eff_num"].iloc[0]); end_keff = float(df["weights_eff_num"].iloc[-1])
    print(f"Worst-source acc: {start_worst:.3f} → {end_worst:.3f} (best {best_val:.3f} @ epoch {df['epoch'].iloc[best_ep] if best_ep is not None else '-'})")
    print(f"Domain gap (max−min): {start_gap:.3f} → {end_gap:.3f}")
    print(f"K_eff: {start_keff:.2f} → {end_keff:.2f}  (higher = more balanced)")
else:
    print(f"Worst-source acc: {start_worst:.3f} → {end_worst:.3f} (best {best_val:.3f})")
    print(f"Domain gap (max−min): {start_gap:.3f} → {end_gap:.3f}")


[sanity] max |worst_src_domain_acc - min(per-domain)| = 0
Saved: /content/drive/MyDrive/DG_PACS/GroupDRO/Sketch/worst_source_curve.png
Saved: /content/drive/MyDrive/DG_PACS/GroupDRO/Sketch/per_domain_and_gap.png
Saved: /content/drive/MyDrive/DG_PACS/GroupDRO/Sketch/weights_balance.png
Worst-source acc: 0.871 → 0.930 (best 0.935 @ epoch 17)
Domain gap (max−min): 0.111 → 0.051
K_eff: 2.98 → 2.73  (higher = more balanced)


In [None]:
# --- Flatness probe: loss vs parameter perturbation magnitude (ERM_2 vs SAM) ---
import os, copy, torch, numpy as np, matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.utils.data import DataLoader

ROOT   = "/content/drive/MyDrive/DG_PACS"
TARGET = "Sketch"    # fixed target
EPSILONS = [0.0, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3]  # tweak if you like
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Build source-val and target-test loaders exactly like elsewhere ---
train_idx, val_idx, test_idx, classes, source_domains = make_lodo_splits(DATA_ROOT, TARGET, val_frac=0.1)
style_map_train, thresholds = compute_stylebins_per_domain(DATA_ROOT, source_domains, train_idx)
style_map_val   = apply_stylebins_to_indices(DATA_ROOT, val_idx, thresholds)
style_map_test  = apply_stylebins_to_indices(DATA_ROOT, test_idx, thresholds)

ds_val  = PACS(DATA_ROOT, source_domains, split_indices=val_idx,  transform=eval_tf,
               stylebin_map={**style_map_train, **style_map_val})
ds_test = PACS(DATA_ROOT, [TARGET],      split_indices=test_idx, transform=eval_tf,
               stylebin_map={**style_map_train, **style_map_test})
dl_val  = DataLoader(ds_val,  batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)
dl_test = DataLoader(ds_test, batch_size=128, shuffle=False, num_workers=2, collate_fn=collate_keep_meta)

def load_model(method):
    ckpt = torch.load(os.path.join(ROOT, method, TARGET, "best_model.pt"), map_location=DEVICE)
    bb, hd = build_model(n_classes=len(classes))
    bb.load_state_dict(ckpt["backbone"]); hd.load_state_dict(ckpt["head"])
    return bb.to(DEVICE).eval(), hd.to(DEVICE).eval()

@torch.no_grad()
def ce_loss_on_loader(backbone, head, loader):
    tot_loss, tot_n = 0.0, 0
    for x,y,meta in loader:
        x,y = x.to(DEVICE), y.to(DEVICE)
        logits,_ = forward_logits(backbone, head, x)
        loss = F.cross_entropy(logits, y, reduction="sum")
        tot_loss += loss.item(); tot_n += y.numel()
    return tot_loss / max(1, tot_n)

def random_unit_direction_like(params):
    # flattened concatenated random direction normalized to unit L2, then reshaped per tensor
    vecs, flats = [], []
    for p in params:
        r = torch.randn_like(p)
        vecs.append(r)
        flats.append(r.view(-1))
    flat = torch.cat(flats)
    flat = flat / (flat.norm(p=2) + 1e-12)
    # redistribute normalized direction with same proportion as vecs
    out = []
    ofs = 0
    for v in vecs:
        n = v.numel()
        out.append(flat[ofs:ofs+n].view_as(v))
        ofs += n
    return out

def perturb_and_eval(method, loader, epsilons):
    bb, hd = load_model(method)
    # sample one random direction for (backbone + head) combined
    params = list(bb.parameters()) + list(hd.parameters())
    direction = random_unit_direction_like(params)
    # cache original weights
    originals = [p.detach().clone() for p in params]

    losses = []
    for eps in epsilons:
        # restore
        for p, w0 in zip(params, originals):
            p.copy_(w0)
        # apply perturbation
        for p, d in zip(params, direction):
            p.add_(eps * d)
        # measure loss
        loss = ce_loss_on_loader(bb, hd, loader)
        losses.append(loss)
    # restore clean weights at the end
    for p, w0 in zip(params, originals):
        p.copy_(w0)
    return losses

# --- Run probe on source-val and target-test
for split_name, loader in [("val", dl_val), ("test", dl_test)]:
    losses_erm = perturb_and_eval("ERM_2", loader, EPSILONS)
    losses_sam = perturb_and_eval("SAM",    loader, EPSILONS)

    plt.figure(figsize=(5.6,3.4))
    plt.plot(EPSILONS, losses_erm, marker="o", label="ERM_2")
    plt.plot(EPSILONS, losses_sam, marker="o", label="SAM")
    plt.xlabel("parameter perturbation magnitude (ε, L2 dir)"); plt.ylabel("CE loss")
    plt.title(f"Flatness probe: loss vs ε ({split_name}, target={TARGET})")
    plt.grid(alpha=0.3); plt.legend(); plt.tight_layout()
    outp = os.path.join(ROOT, f"flatness_loss_vs_eps_{split_name}_{TARGET}.png")
    plt.savefig(outp, dpi=150); plt.close()
    print(f"Saved: {outp}")

    # Also print the “sharpness” gap at a moderate ε
    mid_eps = EPSILONS[-2]  # e.g., 1e-3 or 2e-3 depending on your list
    i = EPSILONS.index(mid_eps)
    print(f"[{split_name}] loss@ε={mid_eps:g}  ERM_2={losses_erm[i]:.4f}  SAM={losses_sam[i]:.4f}")


NameError: name 'list_files' is not defined