# TASK 1





# Source Only Baseline

One-time setup

In [None]:
# === Fix Arrow/Datasets ABI mismatch (run once) ===
%pip -q uninstall -y pyarrow apache-beam
%pip -q install -U "pyarrow>=18.0.0,<19.0.0" "datasets>=2.19.0"

# 🔁 Auto-restart the Colab kernel so the new PyArrow is picked up cleanly
import os, IPython
print("Restarting runtime to finalize PyArrow install...")
IPython.display.clear_output(wait=True)
os.kill(os.getpid(), 9)


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.1/40.1 MB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25hRestarting runtime to finalize PyArrow install...


In [None]:
# === Cell A: Fix deps, auth, mount Drive, config ===
# 1) Ensure HF + I/O stack is compatible with your current Colab (transformers/gradio/diffusers)
!pip -q uninstall -y huggingface-hub -q
!pip -q install -U "huggingface_hub>=0.34.0,<1.0" "datasets>=2.19.0" "fsspec==2025.3.0" "gcsfs==2025.3.0"

# 2) Colab auth first (prevents "credential propagation was unsuccessful")
from google.colab import auth
auth.authenticate_user()

# 3) Mount Drive (retry-safe)
from google.colab import drive
try:
    drive.mount('/content/drive', force_remount=True)
except Exception as e:
    # Fallback: unmount then remount once
    try:
        drive.flush_and_unmount()
    except:
        pass
    drive.mount('/content/drive', force_remount=True)

# 4) Env + config
import os, random, json, time, math
from pathlib import Path
import numpy as np
import torch

# Persist HF cache to Drive (no re-downloads later)
os.environ["HF_HOME"] = "/content/drive/MyDrive/hf_cache"
os.environ["HF_DATASETS_CACHE"] = "/content/drive/MyDrive/hf_cache/datasets"
os.environ["TRANSFORMERS_CACHE"] = "/content/drive/MyDrive/hf_cache/transformers"

def set_all_seeds(seed=1337):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
set_all_seeds(1337)

class Cfg:
    # HF PACS domains: 'photo', 'sketch', 'art_painting', 'cartoon'
    SOURCE_DOMAIN = "photo"
    TARGET_DOMAIN = "sketch"
    NUM_CLASSES = 7
    IMG_SIZE = 224
    BATCH_SIZE = 64
    NUM_WORKERS = 4
    EPOCHS = 20
    LR = 0.003
    WD = 1e-4
    MOMENTUM = 0.9
    LABEL_SMOOTH = 0.0

    # ✅ All Task 1 artifacts -> Drive
    SAVE_ROOT = "/content/drive/MyDrive/DG_PACS/Task1"
    EXP_NAME = "T1.1_SourceOnly_PACS_photo2sketch_ResNet50"
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

Cfg = Cfg()

def make_out_dirs(exp_name=None):
    out_dir = Path(Cfg.SAVE_ROOT) / (exp_name or Cfg.EXP_NAME)
    (out_dir / "figs").mkdir(parents=True, exist_ok=True)
    (out_dir / "ckpts").mkdir(parents=True, exist_ok=True)
    return out_dir

OUT_DIR = make_out_dirs()
print(f"Device: {Cfg.DEVICE}")
print(f"Outputs -> {OUT_DIR}")
print("HF cache ->", os.environ["HF_DATASETS_CACHE"])


Mounted at /content/drive
Device: cuda
Outputs -> /content/drive/MyDrive/DG_PACS/Task1/T1.1_SourceOnly_PACS_photo2sketch_ResNet50
HF cache -> /content/drive/MyDrive/hf_cache/datasets


Data Pipeline

In [None]:
# === Cell B: PACS via Hugging Face + loaders ===
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from collections import Counter

# 1) Load PACS (single split with image/domain/label)
ds_all = load_dataset("flwrlabs/pacs")  # cached in Drive due to env vars above
ds_train = ds_all["train"]

# Inspect domains and labels
domain_counts = Counter(ds_train["domain"])
label_names = ds_train.features["label"].names if hasattr(ds_train.features["label"], "names") else None
print("Domains:", domain_counts)
print("Label names:", label_names)

# 2) Transforms
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

train_tf = transforms.Compose([
    transforms.Resize((Cfg.IMG_SIZE, Cfg.IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

test_tf = transforms.Compose([
    transforms.Resize((Cfg.IMG_SIZE, Cfg.IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# 3) Domain filters
def _subset_by_domain(hf_ds, domain_name):
    return hf_ds.filter(lambda x: x["domain"] == domain_name)

src_hf = _subset_by_domain(ds_train, Cfg.SOURCE_DOMAIN)
tgt_hf = _subset_by_domain(ds_train, Cfg.TARGET_DOMAIN)

# 4) Torch Dataset wrapper
class HFDataset(Dataset):
    def __init__(self, hf_split, transform):
        self.ds = hf_split
        self.transform = transform
    def __len__(self): return self.ds.num_rows
    def __getitem__(self, idx):
        item = self.ds[idx]
        img = item["image"].convert("RGB")
        y = int(item["label"])
        return self.transform(img), y

src_train_ds = HFDataset(src_hf, transform=train_tf)
src_test_ds  = HFDataset(src_hf, transform=test_tf)   # source "test" = same domain, no aug
tgt_test_ds  = HFDataset(tgt_hf, transform=test_tf)   # target eval

loaders = {
    "src_train": DataLoader(src_train_ds, batch_size=Cfg.BATCH_SIZE, shuffle=True,
                            num_workers=Cfg.NUM_WORKERS, pin_memory=True),
    "src_test":  DataLoader(src_test_ds,  batch_size=Cfg.BATCH_SIZE, shuffle=False,
                            num_workers=Cfg.NUM_WORKERS, pin_memory=True),
    "tgt_test":  DataLoader(tgt_test_ds,  batch_size=Cfg.BATCH_SIZE, shuffle=False,
                            num_workers=Cfg.NUM_WORKERS, pin_memory=True),
}

IDX2CLASS = {i: name for i, name in enumerate(label_names)} if label_names else {i: f"class_{i}" for i in range(Cfg.NUM_CLASSES)}
print(f"Classes: {len(IDX2CLASS)} -> {IDX2CLASS}")
print(f"Source: {Cfg.SOURCE_DOMAIN} | Target: {Cfg.TARGET_DOMAIN}")


Domains: Counter({'sketch': 3929, 'cartoon': 2344, 'art_painting': 2048, 'photo': 1670})
Label names: ['dog', 'elephant', 'giraffe', 'guitar', 'horse', 'house', 'person']
Classes: 7 -> {0: 'dog', 1: 'elephant', 2: 'giraffe', 3: 'guitar', 4: 'horse', 5: 'house', 6: 'person'}
Source: photo | Target: sketch




Model & training utilities



In [None]:
# === Cell C: Model & utilities ===
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# ---- helpers ----
def save_json(obj, path):
    with open(path, "w") as f:
        json.dump(obj, f, indent=2)

class LabelSmoothingCE(nn.Module):
    def __init__(self, eps=0.0):
        super().__init__()
        self.eps = eps
        self.log_softmax = nn.LogSoftmax(dim=1)
    def forward(self, logits, target):
        if self.eps <= 1e-8:
            return F.cross_entropy(logits, target)
        n = logits.size(1)
        log_probs = self.log_softmax(logits)
        with torch.no_grad():
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.eps / (n - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1 - self.eps)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=1))

class ResNet50Classifier(nn.Module):
    def __init__(self, num_classes=7, pretrained=True):
        super().__init__()
        weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = models.resnet50(weights=weights)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)
    def forward(self, x):
        return self.backbone(x)

@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    all_preds, all_labels = [], []
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        pred = logits.argmax(1)
        correct += (pred == y).sum().item()
        total += y.numel()
        all_preds.append(pred.cpu().numpy())
        all_labels.append(y.cpu().numpy())
    all_preds = np.concatenate(all_preds) if len(all_preds) else np.array([])
    all_labels = np.concatenate(all_labels) if len(all_labels) else np.array([])
    acc = 100.0 * correct / total if total > 0 else 0.0
    return acc, all_preds, all_labels

def per_class_stats(preds, labels, idx2class):
    if preds.size == 0:
        return np.zeros((len(idx2class), len(idx2class)), dtype=int), {idx2class[i]: 0.0 for i in range(len(idx2class))}
    cm = confusion_matrix(labels, preds, labels=list(range(len(idx2class))))
    per_cls_acc = (cm.diagonal() / cm.sum(axis=1).clip(min=1)) * 100.0
    per_cls_dict = {idx2class[i]: float(per_cls_acc[i]) for i in range(len(per_cls_acc))}
    return cm, per_cls_dict

def plot_confusion(cm, idx2class, title, savepath):
    fig = plt.figure(figsize=(6,5))
    plt.imshow(cm, interpolation='nearest')
    plt.title(title); plt.colorbar()
    ticks = np.arange(len(idx2class))
    labels = [idx2class[i] for i in range(len(idx2class))]
    plt.xticks(ticks, labels, rotation=45, ha='right'); plt.yticks(ticks, labels)
    plt.tight_layout(); plt.ylabel('True'); plt.xlabel('Predicted')
    plt.savefig(savepath, bbox_inches="tight", dpi=160); plt.close(fig)


Train Source-Only (ERM on source), then evaluate on source & target

In [None]:
# === Cell D: Train ERM on source-only & track metrics ===
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch

device = torch.device(Cfg.DEVICE)
model = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=True).to(device)

criterion = LabelSmoothingCE(eps=Cfg.LABEL_SMOOTH)
optimizer = SGD(model.parameters(), lr=Cfg.LR, momentum=Cfg.MOMENTUM, weight_decay=Cfg.WD, nesterov=True)
scheduler = CosineAnnealingLR(optimizer, T_max=Cfg.EPOCHS)

best_tgt = -1.0
history = {"epoch": [], "src_acc": [], "tgt_acc": [], "lr": []}

for epoch in range(1, Cfg.EPOCHS + 1):
    model.train()
    running = 0.0
    nseen = 0
    for x, y in loaders["src_train"]:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        running += loss.item() * y.size(0)
        nseen += y.size(0)

    scheduler.step()

    # quick eval each epoch
    src_acc, _, _ = evaluate(model, loaders["src_test"], device)
    tgt_acc, _, _ = evaluate(model, loaders["tgt_test"], device)

    history["epoch"].append(epoch)
    history["src_acc"].append(src_acc)
    history["tgt_acc"].append(tgt_acc)
    history["lr"].append(optimizer.param_groups[0]["lr"])

    print(f"[{epoch:03d}/{Cfg.EPOCHS}] loss={running/max(nseen,1):.4f} | src={src_acc:.2f} | tgt={tgt_acc:.2f} | lr={history['lr'][-1]:.5f}")

    # save best-by-target checkpoint
    if tgt_acc > best_tgt:
        best_tgt = tgt_acc
        torch.save(model.state_dict(), OUT_DIR / "ckpts" / "best_by_target.pt")

# Save training curves & history to Drive
save_json(history, OUT_DIR / "history.json")

plt.figure()
plt.plot(history["epoch"], history["src_acc"], label="Source Acc")
plt.plot(history["epoch"], history["tgt_acc"], label="Target Acc")
plt.xlabel("Epoch"); plt.ylabel("Accuracy (%)"); plt.legend(); plt.title("T1.1 Source-Only: Acc vs Epoch")
plt.savefig(OUT_DIR / "figs" / "acc_curve.png", bbox_inches="tight", dpi=160); plt.close()

print(f"Saved: ckpt -> {OUT_DIR/'ckpts'/'best_by_target.pt'} | history -> {OUT_DIR/'history.json'} | curve -> {OUT_DIR/'figs'/'acc_curve.png'}")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 190MB/s]


[001/20] loss=1.1498 | src=97.72 | tgt=22.55 | lr=0.00298
[002/20] loss=0.1965 | src=99.10 | tgt=16.37 | lr=0.00293
[003/20] loss=0.0620 | src=99.82 | tgt=16.90 | lr=0.00284
[004/20] loss=0.0376 | src=100.00 | tgt=17.05 | lr=0.00271
[005/20] loss=0.0254 | src=100.00 | tgt=16.39 | lr=0.00256
[006/20] loss=0.0202 | src=100.00 | tgt=16.26 | lr=0.00238
[007/20] loss=0.0176 | src=100.00 | tgt=16.29 | lr=0.00218
[008/20] loss=0.0142 | src=100.00 | tgt=16.49 | lr=0.00196
[009/20] loss=0.0120 | src=100.00 | tgt=16.70 | lr=0.00173
[010/20] loss=0.0105 | src=100.00 | tgt=16.34 | lr=0.00150
[011/20] loss=0.0072 | src=100.00 | tgt=16.31 | lr=0.00127
[012/20] loss=0.0071 | src=100.00 | tgt=16.29 | lr=0.00104
[013/20] loss=0.0077 | src=100.00 | tgt=16.54 | lr=0.00082
[014/20] loss=0.0070 | src=100.00 | tgt=16.93 | lr=0.00062
[015/20] loss=0.0091 | src=100.00 | tgt=16.34 | lr=0.00044
[016/20] loss=0.0066 | src=100.00 | tgt=16.37 | lr=0.00029
[017/20] loss=0.0089 | src=100.00 | tgt=16.87 | lr=0.00016


Final evaluation, per-class stats, confusion matrices, summary

In [None]:
# === Cell E: Final evaluation, per-class stats, confusion matrices, summary (saved to Drive) ===
import json
import numpy as np
import torch

# 1) Load best checkpoint and eval on source & target
device = torch.device(Cfg.DEVICE)
model = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
ckpt_path = OUT_DIR / "ckpts" / "best_by_target.pt"
model.load_state_dict(torch.load(ckpt_path, map_location=device))

src_acc, src_preds, src_labels = evaluate(model, loaders["src_test"], device)
tgt_acc, tgt_preds, tgt_labels = evaluate(model, loaders["tgt_test"], device)

# 2) Per-class + confusion matrices
src_cm, src_percls = per_class_stats(src_preds, src_labels, IDX2CLASS)
tgt_cm, tgt_percls = per_class_stats(tgt_preds, tgt_labels, IDX2CLASS)

plot_confusion(src_cm, IDX2CLASS, f"Source ({Cfg.SOURCE_DOMAIN}) Confusion", OUT_DIR / "figs" / "cm_source.png")
plot_confusion(tgt_cm, IDX2CLASS, f"Target ({Cfg.TARGET_DOMAIN}) Confusion", OUT_DIR / "figs" / "cm_target.png")

# 3) Aggregate metrics
avg_domain_acc = float((src_acc + tgt_acc) / 2.0)
worst_group_acc = float(min(src_acc, tgt_acc))

summary = {
    "exp_name": Cfg.EXP_NAME,
    "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
    "metrics": {
        "source_acc": float(src_acc),
        "target_acc": float(tgt_acc),
        "avg_domain_acc": avg_domain_acc,
        "worst_group_acc": worst_group_acc
    },
    "per_class_source": src_percls,
    "per_class_target": tgt_percls,
    "artifacts": {
        "ckpt_best_by_target": str(ckpt_path),
        "history_json": str(OUT_DIR / "history.json"),
        "acc_curve_png": str(OUT_DIR / "figs" / "acc_curve.png"),
        "cm_source_png": str(OUT_DIR / "figs" / "cm_source.png"),
        "cm_target_png": str(OUT_DIR / "figs" / "cm_target.png"),
        "summary_json": str(OUT_DIR / "summary.json"),
    },
    "notes": "Source-only baseline (ERM on source). Same backbone will be reused for Task 1 methods."
}

# 4) Save JSON and print a compact recap
with open(OUT_DIR / "summary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps({
    "source_acc": summary["metrics"]["source_acc"],
    "target_acc": summary["metrics"]["target_acc"],
    "avg_domain_acc": summary["metrics"]["avg_domain_acc"],
    "worst_group_acc": summary["metrics"]["worst_group_acc"],
}, indent=2))

print(f"Saved summary and figures to: {OUT_DIR}")


{
  "source_acc": 97.72455089820359,
  "target_acc": 22.55026724357343,
  "avg_domain_acc": 60.13740907088851,
  "worst_group_acc": 22.55026724357343
}
Saved summary and figures to: /content/drive/MyDrive/DG_PACS/Task1/T1.1_SourceOnly_PACS_photo2sketch_ResNet50


# Task 1.2: Domain Alignment starting with DANN (adversarial alignment)

DANN model pieces (GRL, feature extractor, domain head) + target train loader

Train DANN (source CE + adversarial domain loss) and log curves

In [None]:
# === Cell F (memory-lite): DANN utils & tgt-train loader ===
import torch
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader

# --- knobs you can tweak quickly ---
Cfg.MIXED_PREC = True          # amp on/off
Cfg.ACCUM_STEPS = 2            # gradient accumulation to emulate bigger batch
Cfg.BATCH_SIZE = max(16, Cfg.BATCH_SIZE // 2)  # halve per-step batch (effective restored via ACCUM_STEPS)
Cfg.NUM_WORKERS = 2            # fewer loader workers to save RAM
PERSISTENT_WORKERS = False

# Rebuild target train loader with new batch size/workers if needed
tgt_train_ds = HFDataset(tgt_hf, transform=train_tf)
tgt_train_loader = DataLoader(
    tgt_train_ds, batch_size=Cfg.BATCH_SIZE, shuffle=True,
    num_workers=Cfg.NUM_WORKERS, pin_memory=True, persistent_workers=PERSISTENT_WORKERS
)

# Feature extractor that can return just features (avoid unnecessary class logits on target)
class ResNet_Feature(nn.Module):
    def __init__(self, num_classes=7, backbone="resnet50", pretrained=True):
        super().__init__()
        self.backbone_name = backbone
        if backbone == "resnet50":
            weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
            net = models.resnet50(weights=weights)
            feat_dim = 2048
        elif backbone == "resnet18":
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            net = models.resnet18(weights=weights)
            feat_dim = 512
        else:
            raise ValueError("Unsupported backbone")

        self.features = nn.Sequential(
            net.conv1, net.bn1, net.relu, net.maxpool,
            net.layer1, net.layer2, net.layer3, net.layer4, net.avgpool
        )
        self.feat_dim = feat_dim
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x, return_feat=False, class_head=True):
        f = self.features(x)
        f = torch.flatten(f, 1)
        if class_head:
            logits = self.classifier(f)
            return (logits, f) if return_feat else logits
        else:
            return f  # features only

# GRL
class GradReverse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambd): ctx.lambd = lambd; return x.view_as(x)
    @staticmethod
    def backward(ctx, g): return g.neg() * ctx.lambd, None
def grad_reverse(x, lambd=1.0): return GradReverse.apply(x, lambd)

# Domain head
class DomainDiscriminator(nn.Module):
    def __init__(self, in_dim, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(True), nn.Dropout(0.2),
            nn.Linear(hidden, 1)
        )
    def forward(self, feat): return self.net(feat).squeeze(1)

# DANN container
class DANN(nn.Module):
    def __init__(self, num_classes=7, backbone="resnet50", pretrained=True):
        super().__init__()
        self.backbone = ResNet_Feature(num_classes, backbone=backbone, pretrained=pretrained)
        self.domain_disc = DomainDiscriminator(self.backbone.feat_dim)

    def forward_domain(self, feat, grl_lambda=1.0):
        return self.domain_disc(grad_reverse(feat, grl_lambda))


In [None]:
# === Cell G (memory-lite): Train DANN with concat batches + AMP + grad accumulation ===
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR
from itertools import cycle
import numpy as np, torch, json, matplotlib.pyplot as plt

device = torch.device(Cfg.DEVICE)
model = DANN(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=True).to(device)

clf_criterion = LabelSmoothingCE(eps=Cfg.LABEL_SMOOTH)
dom_criterion = nn.BCEWithLogitsLoss()
optimizer = SGD(model.parameters(), lr=Cfg.LR, momentum=Cfg.MOMENTUM, weight_decay=Cfg.WD, nesterov=True)
scheduler = CosineAnnealingLR(optimizer, T_max=Cfg.EPOCHS)

scaler = torch.cuda.amp.GradScaler(enabled=Cfg.MIXED_PREC)

# λ schedule
def dann_lambda(p): return 2.0 / (1.0 + np.exp(-10 * p)) - 1.0

# New experiment dir (keeps 1.2 artifacts separate)
Cfg.EXP_NAME = "T1.2_DANN_PACS_photo2sketch_ResNet50_memlite"
OUT_DIR = make_out_dirs()

history = {"epoch": [], "src_acc": [], "tgt_acc": [], "dom_acc": [], "dom_loss": [], "lambda": [], "lr": []}
best_tgt = -1.0

src_iter = cycle(loaders["src_train"])
tgt_iter = cycle(tgt_train_loader)
batches_per_epoch = max(len(loaders["src_train"]), len(tgt_train_loader))

accum_steps = max(1, int(Cfg.ACCUM_STEPS))

for epoch in range(1, Cfg.EPOCHS+1):
    model.train()
    dom_correct = 0; dom_total = 0; dom_loss_sum = 0.0

    optimizer.zero_grad(set_to_none=True)

    for step in range(batches_per_epoch):
        p = ((epoch - 1) * batches_per_epoch + step) / (Cfg.EPOCHS * batches_per_epoch)
        lam = float(dann_lambda(p))

        xs, ys = next(src_iter)
        xt, _  = next(tgt_iter)
        # concat once → single forward
        x = torch.cat([xs, xt], 0).to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=Cfg.MIXED_PREC):
            # forward once; request features; only compute class head for source half
            feat_all = model.backbone(x, return_feat=False, class_head=False)  # [B_s+B_t, D]
            bs = xs.size(0)
            feat_s, feat_t = feat_all[:bs], feat_all[bs:]

            # classification on source: run small head on source features only
            logits_s = model.backbone.classifier(feat_s)
            loss_ce = clf_criterion(logits_s, ys)

            # domain loss on source+target (labels: 1 for source, 0 for target)
            dom_logits = torch.cat([model.forward_domain(feat_s, lam),
                                    model.forward_domain(feat_t, lam)], dim=0)
            dom_labels = torch.cat([torch.ones(bs, device=device),
                                    torch.zeros(feat_t.size(0), device=device)], dim=0)
            loss_dom = dom_criterion(dom_logits, dom_labels)

            loss = (loss_ce + loss_dom) / accum_steps  # normalize for accumulation

        scaler.scale(loss).backward()

        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # quick domain stats (no grad)
        with torch.no_grad():
            pred_dom = (torch.sigmoid(dom_logits) > 0.5).long()
            dom_correct += (pred_dom == dom_labels.long()).sum().item()
            dom_total += dom_labels.numel()
            dom_loss_sum += float(loss_dom.item())  # already reduced

        # free references ASAP
        del x, xs, xt, ys, feat_all, feat_s, feat_t, logits_s, dom_logits, dom_labels

    # catch leftover grads if batches_per_epoch not divisible by accum_steps
    if (batches_per_epoch % accum_steps) != 0:
        scaler.step(optimizer); scaler.update(); optimizer.zero_grad(set_to_none=True)

    scheduler.step()
    # Light CUDA GC between epochs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Eval (classification only) using the backbone classifier
    src_acc, _, _ = evaluate(model.backbone, loaders["src_test"], device)
    tgt_acc, _, _ = evaluate(model.backbone, loaders["tgt_test"], device)

    dom_acc = 100.0 * dom_correct / max(dom_total, 1)
    history["epoch"].append(epoch)
    history["src_acc"].append(float(src_acc))
    history["tgt_acc"].append(float(tgt_acc))
    history["dom_acc"].append(float(dom_acc))
    history["dom_loss"].append(float(dom_loss_sum / max(batches_per_epoch,1)))
    history["lambda"].append(lam)
    history["lr"].append(optimizer.param_groups[0]["lr"])

    print(f"[DANN-mem {epoch:03d}/{Cfg.EPOCHS}] src={src_acc:.2f} | tgt={tgt_acc:.2f} | dom_acc={dom_acc:.1f} | λ={lam:.3f}")

    if tgt_acc > best_tgt:
        best_tgt = tgt_acc
        torch.save(model.state_dict(), OUT_DIR / "ckpts" / "best_dann.pt")
        torch.save(model.backbone.state_dict(), OUT_DIR / "ckpts" / "best_dann_backbone.pt")

# Save curves
save_json(history, OUT_DIR / "history_dann.json")
plt.figure(); plt.plot(history["epoch"], history["src_acc"], label="Source"); plt.plot(history["epoch"], history["tgt_acc"], label="Target")
plt.xlabel("Epoch"); plt.ylabel("Acc (%)"); plt.legend(); plt.title("DANN (mem-lite): Acc"); plt.savefig(OUT_DIR / "figs" / "dann_acc_curve.png", bbox_inches="tight", dpi=160); plt.close()
plt.figure(); plt.plot(history["epoch"], history["dom_acc"], label="Domain Acc"); plt.xlabel("Epoch"); plt.ylabel("Acc (%)"); plt.legend(); plt.title("DANN (mem-lite): Domain Acc")
plt.savefig(OUT_DIR / "figs" / "dann_domain_acc_curve.png", bbox_inches="tight", dpi=160); plt.close()
print(f"Saved DANN (mem-lite) artifacts to {OUT_DIR}")


  scaler = torch.cuda.amp.GradScaler(enabled=Cfg.MIXED_PREC)
  with torch.cuda.amp.autocast(enabled=Cfg.MIXED_PREC):


[DANN-mem 001/20] src=99.58 | tgt=17.74 | dom_acc=85.9 | λ=0.243
[DANN-mem 002/20] src=99.82 | tgt=33.32 | dom_acc=98.8 | λ=0.461
[DANN-mem 003/20] src=99.82 | tgt=15.50 | dom_acc=91.3 | λ=0.634
[DANN-mem 004/20] src=99.88 | tgt=38.30 | dom_acc=65.7 | λ=0.761
[DANN-mem 005/20] src=99.94 | tgt=49.94 | dom_acc=78.4 | λ=0.848
[DANN-mem 006/20] src=99.76 | tgt=61.03 | dom_acc=74.0 | λ=0.905
[DANN-mem 007/20] src=99.34 | tgt=29.65 | dom_acc=68.5 | λ=0.941
[DANN-mem 008/20] src=99.76 | tgt=28.58 | dom_acc=73.6 | λ=0.964
[DANN-mem 009/20] src=99.64 | tgt=60.91 | dom_acc=66.1 | λ=0.978
[DANN-mem 010/20] src=99.82 | tgt=23.52 | dom_acc=64.1 | λ=0.987
[DANN-mem 011/20] src=100.00 | tgt=61.01 | dom_acc=68.0 | λ=0.992
[DANN-mem 012/20] src=100.00 | tgt=46.12 | dom_acc=58.8 | λ=0.995
[DANN-mem 013/20] src=99.94 | tgt=60.19 | dom_acc=74.2 | λ=0.997
[DANN-mem 014/20] src=99.94 | tgt=54.36 | dom_acc=68.7 | λ=0.998
[DANN-mem 015/20] src=99.94 | tgt=57.60 | dom_acc=67.9 | λ=0.999
[DANN-mem 016/20] src=9

Final evaluation + domain-shift proxy (A-distance) + summary

In [None]:
# === Cell H (mem-lite): DANN final eval + A-distance proxy + summaries ===
import json, numpy as np, torch
from itertools import islice
from torch.utils.data import DataLoader

device = torch.device(Cfg.DEVICE)

# --- rebuild the *mem-lite* DANN you trained ---
model = DANN(num_classes=Cfg.NUM_CLASSES, backbone=getattr(Cfg, "BACKBONE", "resnet50"), pretrained=False).to(device)
state = torch.load(OUT_DIR / "ckpts" / "best_dann.pt", map_location=device)
model.load_state_dict(state)
model.eval()

# --- lighter eval loaders (smaller batch, no workers/pin) ---
def make_small_loader(big_loader, bs=32):
    return DataLoader(
        big_loader.dataset, batch_size=min(bs, getattr(Cfg, "BATCH_SIZE", 64)),
        shuffle=False, num_workers=0, pin_memory=False
    )
src_test_small = make_small_loader(loaders["src_test"], bs=32)
tgt_test_small = make_small_loader(loaders["tgt_test"], bs=32)

@torch.no_grad()
def eval_acc_cm(backbone, loader, n_classes):
    """Streamed accuracy + confusion (no big arrays)."""
    cm = np.zeros((n_classes, n_classes), dtype=np.int64)
    total = correct = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):
            # forward with classifier head (backbone is ResNet_Feature)
            f = backbone(x, return_feat=False, class_head=False)
            logits = backbone.classifier(f)
            preds = logits.argmax(1)
        correct += (preds == y).sum().item()
        total += y.numel()
        # update confusion
        for t, p in zip(y.view(-1).tolist(), preds.view(-1).tolist()):
            cm[t, p] += 1
    acc = 100.0 * correct / max(total, 1)
    # per-class accuracy
    percls = {}
    for i in range(n_classes):
        denom = cm[i].sum()
        percls[list(IDX2CLASS.values())[i]] = (100.0 * cm[i, i] / denom) if denom > 0 else 0.0
    return acc, cm, percls

# --- classification eval (source/target) ---
src_acc, src_cm, src_percls = eval_acc_cm(model.backbone, src_test_small, Cfg.NUM_CLASSES)
tgt_acc, tgt_cm, tgt_percls = eval_acc_cm(model.backbone, tgt_test_small, Cfg.NUM_CLASSES)

# save confusion figures
plot_confusion(src_cm, IDX2CLASS, f"DANN Source ({Cfg.SOURCE_DOMAIN})", OUT_DIR / "figs" / "dann_cm_source.png")
plot_confusion(tgt_cm, IDX2CLASS, f"DANN Target ({Cfg.TARGET_DOMAIN})", OUT_DIR / "figs" / "dann_cm_target.png")

# --- domain proxy (single concatenated forward; capped batches) ---
@torch.no_grad()
def domain_accuracy_fast(model, src_loader, tgt_loader, max_batches=30):
    correct = total = 0
    it_src, it_tgt = iter(src_loader), iter(tgt_loader)
    for _ in range(max_batches):
        try:
            xs, _ = next(it_src); xt, _ = next(it_tgt)
        except StopIteration:
            break
        bs = xs.size(0)
        x = torch.cat([xs, xt], 0).to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):
            # one backbone pass for features only
            f_all = model.backbone(x, return_feat=False, class_head=False)
            f_s, f_t = f_all[:bs], f_all[bs:]
            logits = torch.cat([model.forward_domain(f_s, grl_lambda=0.0),
                                model.forward_domain(f_t, grl_lambda=0.0)], 0)
        labels = torch.cat([torch.ones(bs, device=device),
                            torch.zeros(f_t.size(0), device=device)], 0)
        pred = (torch.sigmoid(logits) > 0.5).long()
        correct += (pred == labels.long()).sum().item()
        total += labels.numel()
    return 100.0 * correct / max(total, 1)

dom_acc_eval = domain_accuracy_fast(model, src_test_small, tgt_test_small, max_batches=30)
dom_err = 1.0 - dom_acc_eval / 100.0
a_distance_hat = float(2.0 * (1.0 - 2.0 * dom_err))  # 2(1-2ε)

avg_domain_acc = float((src_acc + tgt_acc) / 2.0)
worst_group_acc = float(min(src_acc, tgt_acc))

summary = {
    "exp_name": Cfg.EXP_NAME,
    "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
    "method": "DANN (mem-lite eval)",
    "metrics": {
        "source_acc": float(src_acc),
        "target_acc": float(tgt_acc),
        "avg_domain_acc": avg_domain_acc,
        "worst_group_acc": worst_group_acc,
        "domain_classifier_acc_eval": float(dom_acc_eval),
        "a_distance_hat": a_distance_hat
    },
    "per_class_source": src_percls,
    "per_class_target": tgt_percls,
    "artifacts": {
        "ckpt_dann": str(OUT_DIR / "ckpts" / "best_dann.pt"),
        "ckpt_dann_backbone": str(OUT_DIR / "ckpts" / "best_dann_backbone.pt"),
        "history": str(OUT_DIR / "history_dann.json"),
        "acc_curve": str(OUT_DIR / "figs" / "dann_acc_curve.png"),
        "domain_acc_curve": str(OUT_DIR / "figs" / "dann_domain_acc_curve.png"),
        "cm_source": str(OUT_DIR / "figs" / "dann_cm_source.png"),
        "cm_target": str(OUT_DIR / "figs" / "dann_cm_target.png"),
        "summary_json": str(OUT_DIR / "summary_dann.json"),
    },
    "notes": "DANN with GRL; memory-lite eval: small loaders, streamed confusion, single-pass domain proxy."
}

with open(OUT_DIR / "summary_dann.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps(summary["metrics"], indent=2))
print(f"Saved DANN summary & figures to: {OUT_DIR}")


  with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):
  with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):


{
  "source_acc": 99.88023952095809,
  "target_acc": 62.17867141766353,
  "avg_domain_acc": 81.0294554693108,
  "worst_group_acc": 62.17867141766353,
  "domain_classifier_acc_eval": 58.645833333333336,
  "a_distance_hat": 0.3458333333333332
}
Saved DANN summary & figures to: /content/drive/MyDrive/DG_PACS/Task1/T1.2_DANN_PACS_photo2sketch_ResNet50_memlite


DAN (MMD) utilities + training

In [None]:
# === Cell I (mem-lite): DAN utils + training ===
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, json, gc, matplotlib.pyplot as plt
from itertools import cycle
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device(Cfg.DEVICE)

# ---- Tiny bottleneck + classifier (reuse backbone features) ----
class DAN_Head(nn.Module):
    def __init__(self, in_dim, num_classes, hid=256, p=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(in_dim, hid),
            nn.BatchNorm1d(hid),
            nn.ReLU(inplace=True),
            nn.Dropout(p),
        )
        self.classifier = nn.Linear(hid, num_classes)
    def forward(self, f, return_feat=False):
        z = self.bottleneck(f)
        logits = self.classifier(z)
        return (logits, z) if return_feat else logits

# ---- Unbiased multi-kernel RBF MMD (memory-lite) ----
def _pairwise_sq_dists(x, y):
    xx = (x @ x.t())
    yy = (y @ y.t())
    xy = (x @ y.t())
    x_sq = xx.diag().unsqueeze(1)
    y_sq = yy.diag().unsqueeze(0)
    dist_xx = x_sq + x_sq.t() - 2*xx
    dist_yy = y_sq + y_sq.t() - 2*yy
    dist_xy = x_sq + y_sq - 2*xy
    return dist_xx, dist_yy, dist_xy

def mmd_unbiased(x, y, sigmas=(2., 5., 10.)):
    dist_xx, dist_yy, dist_xy = _pairwise_sq_dists(x, y)
    n, m = x.size(0), y.size(0)
    Kxx = Kyy = Kxy = 0.0
    for s in sigmas:
        g = 1.0 / (2.0 * s * s)
        Kxx = Kxx + torch.exp(-g * dist_xx)
        Kyy = Kyy + torch.exp(-g * dist_yy)
        Kxy = Kxy + torch.exp(-g * dist_xy)
    Kxx = (Kxx.sum() - Kxx.diag().sum()) / max(n*(n-1), 1)
    Kyy = (Kyy.sum() - Kyy.diag().sum()) / max(m*(m-1), 1)
    Kxy = Kxy.mean()
    return Kxx + Kyy - 2.0 * Kxy

# ---- Build model (reuse feature backbone defined earlier) ----
Cfg.EXP_NAME = "T1.3_DAN_PACS_%s2%s_ResNet50_memlite" % (Cfg.SOURCE_DOMAIN, Cfg.TARGET_DOMAIN)
OUT_DIR = make_out_dirs()

dan_backbone = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=True).to(device)
feat_dim = dan_backbone.feat_dim
dan_head = DAN_Head(in_dim=feat_dim, num_classes=Cfg.NUM_CLASSES, hid=256, p=0.2).to(device)

# Optim & sched (only train head + backbone; backbone is full but you can optionally freeze more below)
for i, p in enumerate(dan_backbone.parameters()):
    p.requires_grad = True  # set False for first N params if you want extra savings

params = list(dan_backbone.parameters()) + list(dan_head.parameters())
optimizer = SGD(params, lr=Cfg.LR, momentum=Cfg.MOMENTUM, weight_decay=Cfg.WD, nesterov=True)
scheduler = CosineAnnealingLR(optimizer, T_max=Cfg.EPOCHS)
clf_criterion = LabelSmoothingCE(eps=Cfg.LABEL_SMOOTH)

scaler = torch.cuda.amp.GradScaler(enabled=getattr(Cfg, "MIXED_PREC", True))

# Use existing loaders: loaders["src_train"], tgt_train_loader from Cell F
src_iter = cycle(loaders["src_train"])
tgt_iter = cycle(tgt_train_loader)
batches_per_epoch = max(len(loaders["src_train"]), len(tgt_train_loader))
accum_steps = max(1, int(getattr(Cfg, "ACCUM_STEPS", 1)))

history_dan = {"epoch": [], "src_acc": [], "tgt_acc": [], "mmd": [], "lr": []}
best_tgt = -1.0

for epoch in range(1, Cfg.EPOCHS+1):
    dan_backbone.train(); dan_head.train()
    optimizer.zero_grad(set_to_none=True)

    mmd_running = 0.0
    for step in range(batches_per_epoch):
        xs, ys = next(src_iter)
        xt, _  = next(tgt_iter)
        x = torch.cat([xs, xt], 0).to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)
        bs = xs.size(0)

        with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):
            # one forward for features
            f_all = dan_backbone(x, return_feat=False, class_head=False)  # [B_s+B_t, D]
            fs, ft = f_all[:bs], f_all[bs:]

            # classification only on source
            logits_s, zs = dan_head(fs, return_feat=True)
            zt = dan_head.bottleneck(ft)  # bottleneck for target (no cls)
            loss_ce = clf_criterion(logits_s, ys)
            loss_mmd = mmd_unbiased(zs, zt)
            loss = (loss_ce + 0.5 * loss_mmd) / accum_steps  # λ=0.5 by default

        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:
            nn.utils.clip_grad_norm_(params, max_norm=5.0)
            scaler.step(optimizer); scaler.update()
            optimizer.zero_grad(set_to_none=True)

        mmd_running += float(loss_mmd.detach())

        # free ASAP
        del x, xs, xt, ys, f_all, fs, ft, logits_s, zs, zt

    if (batches_per_epoch % accum_steps) != 0:
        nn.utils.clip_grad_norm_(params, max_norm=5.0)
        scaler.step(optimizer); scaler.update()
        optimizer.zero_grad(set_to_none=True)

    scheduler.step()
    if torch.cuda.is_available(): torch.cuda.empty_cache(); gc.collect()

    # eval using your small streamed routine from Cell H pattern
    @torch.no_grad()
    def _eval_acc(backbone, head, loader):
        backbone.eval(); head.eval()
        total = correct = 0
        for x, y in loader:
            x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):
                f = backbone(x, return_feat=False, class_head=False)
                logits = head(f)
                pred = logits.argmax(1)
            correct += (pred == y).sum().item(); total += y.numel()
        return 100.0 * correct / max(total, 1)

    src_acc = _eval_acc(dan_backbone, dan_head, loaders["src_test"])
    tgt_acc = _eval_acc(dan_backbone, dan_head, loaders["tgt_test"])

    history_dan["epoch"].append(epoch)
    history_dan["src_acc"].append(float(src_acc))
    history_dan["tgt_acc"].append(float(tgt_acc))
    history_dan["mmd"].append(float(mmd_running / max(batches_per_epoch,1)))
    history_dan["lr"].append(optimizer.param_groups[0]["lr"])

    print(f"[DAN {epoch:03d}/{Cfg.EPOCHS}] src={src_acc:.2f} | tgt={tgt_acc:.2f} | mmd≈{history_dan['mmd'][-1]:.4f}")

    if tgt_acc > best_tgt:
        best_tgt = tgt_acc
        torch.save(dan_backbone.state_dict(), OUT_DIR / "ckpts" / "best_dan_backbone.pt")
        torch.save(dan_head.state_dict(),      OUT_DIR / "ckpts" / "best_dan_head.pt")

# save curves
save_json(history_dan, OUT_DIR / "history_dan.json")
plt.figure(); plt.plot(history_dan["epoch"], history_dan["src_acc"], label="Source"); plt.plot(history_dan["epoch"], history_dan["tgt_acc"], label="Target")
plt.xlabel("Epoch"); plt.ylabel("Acc (%)"); plt.legend(); plt.title("DAN (mem-lite): Acc"); plt.savefig(OUT_DIR / "figs" / "dan_acc_curve.png", bbox_inches="tight", dpi=160); plt.close()
plt.figure(); plt.plot(history_dan["epoch"], history_dan["mmd"], label="MMD"); plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.title("DAN: MMD"); plt.savefig(OUT_DIR / "figs" / "dan_mmd_curve.png", bbox_inches="tight", dpi=160); plt.close()
print(f"Saved DAN artifacts to {OUT_DIR}")


  scaler = torch.cuda.amp.GradScaler(enabled=getattr(Cfg, "MIXED_PREC", True))
  with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):
  with torch.cuda.amp.autocast(enabled=getattr(Cfg, "MIXED_PREC", True)):


[DAN 001/20] src=14.19 | tgt=15.40 | mmd≈-0.0114
[DAN 002/20] src=14.67 | tgt=15.78 | mmd≈-0.0118
[DAN 003/20] src=14.55 | tgt=15.30 | mmd≈-0.0115
[DAN 004/20] src=14.97 | tgt=15.81 | mmd≈-0.0119
[DAN 005/20] src=14.43 | tgt=15.45 | mmd≈-0.0114
[DAN 006/20] src=15.39 | tgt=15.68 | mmd≈-0.0117
[DAN 007/20] src=15.21 | tgt=15.22 | mmd≈-0.0114
[DAN 008/20] src=15.15 | tgt=15.65 | mmd≈-0.0120
[DAN 009/20] src=15.51 | tgt=15.93 | mmd≈-0.0119
[DAN 010/20] src=15.03 | tgt=15.42 | mmd≈-0.0114
[DAN 011/20] src=15.75 | tgt=15.81 | mmd≈-0.0120
[DAN 012/20] src=15.51 | tgt=15.30 | mmd≈-0.0115
[DAN 013/20] src=15.45 | tgt=15.75 | mmd≈-0.0120
[DAN 014/20] src=15.39 | tgt=15.63 | mmd≈-0.0115
[DAN 015/20] src=15.93 | tgt=15.73 | mmd≈-0.0120
[DAN 016/20] src=16.05 | tgt=15.37 | mmd≈-0.0113
[DAN 017/20] src=15.81 | tgt=15.65 | mmd≈-0.0118
[DAN 018/20] src=15.87 | tgt=15.91 | mmd≈-0.0118
[DAN 019/20] src=15.75 | tgt=15.40 | mmd≈-0.0114
[DAN 020/20] src=16.11 | tgt=15.88 | mmd≈-0.0118
Saved DAN artifacts 

DAN final evaluation + summary

In [None]:
# === Cell J (mem-lite): DAN final eval + rare-class F1 + proxy distance + summary ===
import numpy as np, json, torch
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
import torch.nn as nn

device = torch.device(Cfg.DEVICE)

# rebuild DAN modules (match Cell I head exactly)
dan_backbone = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
dan_backbone.load_state_dict(torch.load(OUT_DIR / "ckpts" / "best_dan_backbone.pt", map_location=device))

class DAN_Head(nn.Module):
    def __init__(self, in_dim, num_classes, hid=256, p=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(in_dim, hid),
            nn.BatchNorm1d(hid),
            nn.ReLU(inplace=True),
            nn.Dropout(p),
        )
        self.classifier = nn.Linear(hid, num_classes)
    def forward(self, f):
        return self.classifier(self.bottleneck(f))

dan_head = DAN_Head(dan_backbone.feat_dim, Cfg.NUM_CLASSES).to(device)
dan_head.load_state_dict(torch.load(OUT_DIR / "ckpts" / "best_dan_head.pt", map_location=device))
dan_backbone.eval(); dan_head.eval()

# small loaders for eval
def make_small_loader(big_loader, bs=32):
    return DataLoader(big_loader.dataset, batch_size=min(bs, getattr(Cfg, "BATCH_SIZE", 64)),
                      shuffle=False, num_workers=0, pin_memory=False)
src_small = make_small_loader(loaders["src_test"], bs=32)
tgt_small = make_small_loader(loaders["tgt_test"], bs=32)

@torch.no_grad()
def eval_acc_cm(backbone, head, loader, idx2class):
    cm = np.zeros((len(idx2class), len(idx2class)), dtype=np.int64)
    total = correct = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f = backbone(x, return_feat=False, class_head=False)
            logits = head(f); preds = logits.argmax(1)
        correct += (preds == y).sum().item(); total += y.numel()
        for t, p in zip(y.tolist(), preds.tolist()): cm[t, p] += 1
    acc = 100.0 * correct / max(total, 1)
    percls = {}
    for i in range(len(idx2class)):
        denom = cm[i].sum()
        percls[idx2class[i]] = (100.0 * cm[i, i] / denom) if denom > 0 else 0.0
    return acc, cm, percls

src_acc, src_cm, src_percls = eval_acc_cm(dan_backbone, dan_head, src_small, IDX2CLASS)
tgt_acc, tgt_cm, tgt_percls = eval_acc_cm(dan_backbone, dan_head, tgt_small, IDX2CLASS)

# rare-3 classes (stream counts to stay memory-light)
counts = np.zeros(Cfg.NUM_CLASSES, dtype=np.int64)
for _, yb in tgt_small:
    y_np = yb.numpy() if hasattr(yb, "numpy") else np.array(yb)
    for t in y_np: counts[t] += 1
r3 = np.argsort(counts)[:3].tolist()

@torch.no_grad()
def collect_preds(backbone, head, loader):
    ys = []; yh = []
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f = backbone(x, return_feat=False, class_head=False)
            logits = head(f); preds = logits.argmax(1)
        ys.append(y.numpy()); yh.append(preds.cpu().numpy())
    return np.concatenate(ys), np.concatenate(yh)

y_true, y_pred = collect_preds(dan_backbone, dan_head, tgt_small)
rare3_f1 = float(f1_score(y_true, y_pred, labels=r3, average="macro"))

# ---- proxy domain distance (dtype-safe: ensure FP32) ----
@torch.no_grad()
def collect_feats(backbone, loader, max_batches=40):
    F = []
    it = iter(loader)
    for _ in range(max_batches):
        try: x, _ = next(it)
        except StopIteration: break
        x = x.to(device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f = backbone(x, return_feat=False, class_head=False)
        F.append(f.detach().cpu())
    if len(F) == 0:
        return torch.zeros(0, dan_backbone.feat_dim, dtype=torch.float32)
    return torch.cat(F).float()  # <- force FP32 to avoid Half/Float mismatch

Xs = collect_feats(dan_backbone, src_small, max_batches=40)
Xt = collect_feats(dan_backbone, tgt_small, max_batches=40)

X  = torch.cat([Xs, Xt], 0).to(device).float()  # FP32
y  = torch.cat([torch.ones(Xs.size(0)), torch.zeros(Xt.size(0))], 0).to(device).float()

lin = nn.Linear(dan_backbone.feat_dim, 1).to(device).float()
opt = torch.optim.SGD(lin.parameters(), lr=0.05)
bce = nn.BCEWithLogitsLoss()

lin.train()
for _ in range(120):
    idx = torch.randperm(X.size(0), device=device)[:256]
    xb, yb = X[idx], y[idx]  # already float32
    opt.zero_grad(set_to_none=True)
    # keep probe training in FP32 (no autocast)
    loss = bce(lin(xb).squeeze(1), yb)
    loss.backward()
    opt.step()

lin.eval()
with torch.no_grad():
    pred = (torch.sigmoid(lin(X).squeeze(1)) > 0.5).float()
    dom_err = 1.0 - (pred == y).float().mean().item()
a_distance_hat = float(2.0 * (1.0 - 2.0 * dom_err))

# save plots
plot_confusion(src_cm, IDX2CLASS, f"DAN Source ({Cfg.SOURCE_DOMAIN})", OUT_DIR / "figs" / "dan_cm_source.png")
plot_confusion(tgt_cm, IDX2CLASS, f"DAN Target ({Cfg.TARGET_DOMAIN})", OUT_DIR / "figs" / "dan_cm_target.png")

avg_domain_acc = float((src_acc + tgt_acc) / 2.0)
worst_group_acc = float(min(src_acc, tgt_acc))

summary = {
    "exp_name": Cfg.EXP_NAME,
    "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
    "method": "DAN (mem-lite)",
    "metrics": {
        "source_acc": float(src_acc),
        "target_acc": float(tgt_acc),
        "avg_domain_acc": avg_domain_acc,
        "worst_group_acc": worst_group_acc,
        "rare3": r3,
        "rare3_f1": rare3_f1,
        "a_distance_hat": a_distance_hat
    },
    "artifacts": {
        "ckpt_dan_backbone": str(OUT_DIR / "ckpts" / "best_dan_backbone.pt"),
        "ckpt_dan_head": str(OUT_DIR / "ckpts" / "best_dan_head.pt"),
        "history": str(OUT_DIR / "history_dan.json"),
        "acc_curve": str(OUT_DIR / "figs" / "dan_acc_curve.png"),
        "mmd_curve": str(OUT_DIR / "figs" / "dan_mmd_curve.png"),
        "cm_source": str(OUT_DIR / "figs" / "dan_cm_source.png"),
        "cm_target": str(OUT_DIR / "figs" / "dan_cm_target.png"),
        "summary_json": str(OUT_DIR / "summary_dan.json")
    }
}

with open(OUT_DIR / "summary_dan.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps(summary["metrics"], indent=2))
print(f"Saved DAN summary & figures to: {OUT_DIR}")


{
  "source_acc": 15.449101796407186,
  "target_acc": 15.932807330109442,
  "avg_domain_acc": 15.690954563258314,
  "worst_group_acc": 15.449101796407186,
  "rare3": [
    5,
    6,
    3
  ],
  "rare3_f1": 0.07937603542595735,
  "a_distance_hat": 2.0
}
Saved DAN summary & figures to: /content/drive/MyDrive/DG_PACS/Task1/T1.3_DAN_PACS_photo2sketch_ResNet50_memlite


CDAN utilities + training

In [None]:
# === Cell K (mem-lite): CDAN utils + training — FIXED ===
import torch, torch.nn as nn, torch.nn.functional as F
import numpy as np, json, gc, matplotlib.pyplot as plt
from itertools import cycle
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

device = torch.device(Cfg.DEVICE)

class CDAN_Head(nn.Module):
    def __init__(self, in_dim, num_classes, hid=256, p=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(in_dim, hid),
            nn.BatchNorm1d(hid),
            nn.ReLU(inplace=True),
            nn.Dropout(p),
        )
        self.classifier = nn.Linear(hid, num_classes)
    def forward(self, f, return_feat=False, return_prob=False):
        z = self.bottleneck(f)
        logits = self.classifier(z)
        if return_prob:
            probs = F.softmax(logits, dim=1)
            return logits, z, probs
        return (logits, z) if return_feat else logits

class Compress(nn.Module):
    def __init__(self, feat_dim, num_classes, out_dim=512):
        super().__init__()
        self.lin = nn.Linear(feat_dim * num_classes, out_dim)
    def forward(self, z, p):
        op = torch.bmm(p.unsqueeze(2), z.unsqueeze(1))  # [B,C,F]
        op = op.view(op.size(0), -1)                    # [B, C*F]
        return self.lin(op)

class SmallDomainDisc(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256), nn.ReLU(True),
            nn.Linear(256, 128), nn.ReLU(True),
            nn.Linear(128, 1)
        )
    def forward(self, x):
        return self.net(x).squeeze(1)  # -> [B]

def grl(x, lambd=1.0): return grad_reverse(x, lambd)

Cfg.EXP_NAME = f"T1.4_CDAN_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50_memlite"
OUT_DIR = make_out_dirs()

cdan_backbone = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=True).to(device)
feat_dim = cdan_backbone.feat_dim
cdan_head = CDAN_Head(feat_dim, Cfg.NUM_CLASSES, hid=256, p=0.2).to(device)
compress  = Compress(256, Cfg.NUM_CLASSES, out_dim=512).to(device)
d_disc    = SmallDomainDisc(512).to(device)

params_main = list(cdan_backbone.parameters()) + list(cdan_head.parameters()) + list(compress.parameters())
opt_main = SGD(params_main, lr=Cfg.LR, momentum=Cfg.MOMENTUM, weight_decay=Cfg.WD, nesterov=True)
opt_d    = SGD(d_disc.parameters(), lr=Cfg.LR, momentum=Cfg.MOMENTUM, weight_decay=Cfg.WD)
sched_main = CosineAnnealingLR(opt_main, T_max=Cfg.EPOCHS)

clf_criterion = LabelSmoothingCE(eps=Cfg.LABEL_SMOOTH)
bce = nn.BCEWithLogitsLoss()

# ✅ New AMP scaler API
scaler = torch.amp.GradScaler('cuda', enabled=getattr(Cfg, "MIXED_PREC", True))

src_iter = cycle(loaders["src_train"])
tgt_iter = cycle(tgt_train_loader)
batches_per_epoch = max(len(loaders["src_train"]), len(tgt_train_loader))
accum_steps = max(1, int(getattr(Cfg, "ACCUM_STEPS", 1)))

def entropy_weight(p):
    ent = -(p.clamp_min(1e-6) * p.clamp_min(1e-6).log()).sum(1)
    ent = ent / np.log(Cfg.NUM_CLASSES)
    return (1.0 + torch.exp(-ent)).detach()

history_cdan = {"epoch": [], "src_acc": [], "tgt_acc": [], "dom_loss": [], "lr": []}
best_tgt = -1.0

for epoch in range(1, Cfg.EPOCHS+1):
    cdan_backbone.train(); cdan_head.train(); compress.train(); d_disc.train()
    opt_main.zero_grad(set_to_none=True); opt_d.zero_grad(set_to_none=True)

    dom_loss_sum = 0.0

    for step in range(batches_per_epoch):
        xs, ys = next(src_iter)
        xt, _  = next(tgt_iter)
        x = torch.cat([xs, xt], 0).to(device, non_blocking=True)
        ys = ys.to(device, non_blocking=True)
        bs = xs.size(0)

        # ---- main (cls + adv via GRL) ----
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f_all = cdan_backbone(x, return_feat=False, class_head=False)
            fs, ft = f_all[:bs], f_all[bs:]

            logits_s, zs, ps = cdan_head(fs, return_feat=True, return_prob=True)
            loss_ce = clf_criterion(logits_s, ys)

            _, zt, pt = cdan_head(ft, return_feat=True, return_prob=True)

            cond_s = compress(zs, ps)
            cond_t = compress(zt, pt)
            cond   = torch.cat([cond_s, cond_t], 0)

            dom_logits = d_disc(grl(cond, 1.0))  # already [B]; no .squeeze(1)
            dom_labels = torch.cat([torch.ones(bs, device=device),
                                    torch.zeros(cond_t.size(0), device=device)], 0)
            w = torch.cat([entropy_weight(ps), entropy_weight(pt)], 0)
            loss_adv = (bce(dom_logits, dom_labels) * w).mean()

            loss = (loss_ce + 0.5 * loss_adv) / accum_steps

        scaler.scale(loss).backward()
        if (step + 1) % accum_steps == 0:
            nn.utils.clip_grad_norm_(params_main, max_norm=5.0)
            scaler.step(opt_main); scaler.update()
            opt_main.zero_grad(set_to_none=True)

        # ---- update D on detached cond ----
        with torch.no_grad():
            cond_s_d = cond_s.detach(); cond_t_d = cond_t.detach()
            dom_in = torch.cat([cond_s_d, cond_t_d], 0)
            dom_y  = torch.cat([torch.ones(bs, device=device),
                                torch.zeros(cond_t_d.size(0), device=device)], 0)

        opt_d.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            dom_logits_d = d_disc(dom_in)         # shape [B]; ✅ no .squeeze(1)
            loss_d = bce(dom_logits_d, dom_y)     # dom_y is [B]
        scaler.scale(loss_d).backward()
        nn.utils.clip_grad_norm_(d_disc.parameters(), max_norm=5.0)
        scaler.step(opt_d); scaler.update()

        dom_loss_sum += float(loss_d.detach())

        del x, xs, xt, ys, f_all, fs, ft, logits_s, zs, ps, zt, pt, cond_s, cond_t, cond, dom_logits, dom_labels, w, dom_logits_d, dom_y

    sched_main.step()
    if torch.cuda.is_available(): torch.cuda.empty_cache(); gc.collect()

    @torch.no_grad()
    def _eval_acc(backbone, head, loader):
        backbone.eval(); head.eval()
        total = correct = 0
        for x, y in loader:
            x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
                f = backbone(x, return_feat=False, class_head=False)
                logits = head(f)
                pred = logits.argmax(1)
            correct += (pred == y).sum().item(); total += y.numel()
        return 100.0 * correct / max(total, 1)

    src_acc = _eval_acc(cdan_backbone, cdan_head, loaders["src_test"])
    tgt_acc = _eval_acc(cdan_backbone, cdan_head, loaders["tgt_test"])

    history_cdan["epoch"].append(epoch)
    history_cdan["src_acc"].append(float(src_acc))
    history_cdan["tgt_acc"].append(float(tgt_acc))
    history_cdan["dom_loss"].append(float(dom_loss_sum / max(batches_per_epoch,1)))
    history_cdan["lr"].append(opt_main.param_groups[0]["lr"])

    print(f"[CDAN {epoch:03d}/{Cfg.EPOCHS}] src={src_acc:.2f} | tgt={tgt_acc:.2f} | dom_loss≈{history_cdan['dom_loss'][-1]:.4f}")

    if tgt_acc > best_tgt:
        best_tgt = tgt_acc
        torch.save(cdan_backbone.state_dict(), OUT_DIR / "ckpts" / "best_cdan_backbone.pt")
        torch.save(cdan_head.state_dict(),      OUT_DIR / "ckpts" / "best_cdan_head.pt")
        torch.save(compress.state_dict(),       OUT_DIR / "ckpts" / "best_cdan_compress.pt")
        torch.save(d_disc.state_dict(),         OUT_DIR / "ckpts" / "best_cdan_ddisc.pt")

save_json(history_cdan, OUT_DIR / "history_cdan.json")
plt.figure(); plt.plot(history_cdan["epoch"], history_cdan["src_acc"], label="Source"); plt.plot(history_cdan["epoch"], history_cdan["tgt_acc"], label="Target")
plt.xlabel("Epoch"); plt.ylabel("Acc (%)"); plt.legend(); plt.title("CDAN (mem-lite): Acc")
plt.savefig(OUT_DIR / "figs" / "cdan_acc_curve.png", bbox_inches="tight", dpi=160); plt.close()

plt.figure(); plt.plot(history_cdan["epoch"], history_cdan["dom_loss"], label="Domain Loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.title("CDAN: Domain Loss")
plt.savefig(OUT_DIR / "figs" / "cdan_dom_loss_curve.png", bbox_inches="tight", dpi=160); plt.close()

print(f"Saved CDAN artifacts to {OUT_DIR}")


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 238MB/s]


[CDAN 001/20] src=14.91 | tgt=15.65 | dom_loss≈0.6945
[CDAN 002/20] src=14.73 | tgt=16.09 | dom_loss≈0.6945
[CDAN 003/20] src=15.39 | tgt=15.83 | dom_loss≈0.6944
[CDAN 004/20] src=14.85 | tgt=15.93 | dom_loss≈0.6944
[CDAN 005/20] src=15.63 | tgt=15.96 | dom_loss≈0.6944
[CDAN 006/20] src=14.91 | tgt=16.06 | dom_loss≈0.6943
[CDAN 007/20] src=15.69 | tgt=16.06 | dom_loss≈0.6942
[CDAN 008/20] src=15.09 | tgt=15.78 | dom_loss≈0.6941
[CDAN 009/20] src=15.39 | tgt=16.39 | dom_loss≈0.6940
[CDAN 010/20] src=15.45 | tgt=15.60 | dom_loss≈0.6939
[CDAN 011/20] src=15.15 | tgt=15.93 | dom_loss≈0.6938
[CDAN 012/20] src=15.87 | tgt=15.86 | dom_loss≈0.6938
[CDAN 013/20] src=15.09 | tgt=16.14 | dom_loss≈0.6937
[CDAN 014/20] src=16.11 | tgt=15.88 | dom_loss≈0.6936
[CDAN 015/20] src=15.15 | tgt=16.01 | dom_loss≈0.6935
[CDAN 016/20] src=15.75 | tgt=16.01 | dom_loss≈0.6934
[CDAN 017/20] src=15.45 | tgt=15.65 | dom_loss≈0.6934
[CDAN 018/20] src=15.63 | tgt=16.47 | dom_loss≈0.6933
[CDAN 019/20] src=15.87 | tg

CDAN final evaluation + A-distance proxy + summary

In [None]:
# === Cell L (mem-lite): CDAN final eval + rare-class F1 + proxy distance + summary ===
import numpy as np, json, torch
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
import torch.nn as nn

device = torch.device(Cfg.DEVICE)

# --- rebuild CDAN modules EXACTLY like Cell K ---
class CDAN_Head(nn.Module):
    def __init__(self, in_dim, num_classes, hid=256, p=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(in_dim, hid),
            nn.BatchNorm1d(hid),
            nn.ReLU(inplace=True),
            nn.Dropout(p),
        )
        self.classifier = nn.Linear(hid, num_classes)
    def forward(self, f):
        z = self.bottleneck(f)
        return self.classifier(z)

cdan_backbone = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
cdan_backbone.load_state_dict(torch.load(OUT_DIR / "ckpts" / "best_cdan_backbone.pt", map_location=device))
cdan_head = CDAN_Head(cdan_backbone.feat_dim, Cfg.NUM_CLASSES).to(device)
cdan_head.load_state_dict(torch.load(OUT_DIR / "ckpts" / "best_cdan_head.pt", map_location=device))
cdan_backbone.eval(); cdan_head.eval()

# --- small loaders for eval ---
def make_small_loader(big_loader, bs=32):
    return DataLoader(big_loader.dataset, batch_size=min(bs, getattr(Cfg, "BATCH_SIZE", 64)),
                      shuffle=False, num_workers=0, pin_memory=False)
src_small = make_small_loader(loaders["src_test"], bs=32)
tgt_small = make_small_loader(loaders["tgt_test"], bs=32)

@torch.no_grad()
def eval_acc_cm(backbone, head, loader, idx2class):
    cm = np.zeros((len(idx2class), len(idx2class)), dtype=np.int64)
    total = correct = 0
    for x, y in loader:
        x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f = backbone(x, return_feat=False, class_head=False)
            logits = head(f); preds = logits.argmax(1)
        correct += (preds == y).sum().item(); total += y.numel()
        for t, p in zip(y.tolist(), preds.tolist()):
            cm[t, p] += 1
    acc = 100.0 * correct / max(total, 1)
    percls = {}
    for i in range(len(idx2class)):
        denom = cm[i].sum()
        percls[idx2class[i]] = (100.0 * cm[i, i] / denom) if denom > 0 else 0.0
    return acc, cm, percls

# --- classification eval ---
src_acc, src_cm, src_percls = eval_acc_cm(cdan_backbone, cdan_head, src_small, IDX2CLASS)
tgt_acc, tgt_cm, tgt_percls = eval_acc_cm(cdan_backbone, cdan_head, tgt_small, IDX2CLASS)

# --- rare-3 F1 (stream counts) ---
counts = np.zeros(Cfg.NUM_CLASSES, dtype=np.int64)
for _, yb in tgt_small:
    y_np = yb.numpy() if hasattr(yb, "numpy") else np.array(yb)
    for t in y_np: counts[t] += 1
r3 = np.argsort(counts)[:3].tolist()

@torch.no_grad()
def collect_preds(backbone, head, loader):
    ys = []; yh = []
    for x, y in loader:
        x = x.to(device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f = backbone(x, return_feat=False, class_head=False)
            logits = head(f); preds = logits.argmax(1)
        ys.append(y.numpy()); yh.append(preds.cpu().numpy())
    return np.concatenate(ys), np.concatenate(yh)

y_true, y_pred = collect_preds(cdan_backbone, cdan_head, tgt_small)
rare3_f1 = float(f1_score(y_true, y_pred, labels=r3, average="macro"))

# --- proxy domain distance (dtype-safe FP32) ---
@torch.no_grad()
def collect_feats(backbone, loader, max_batches=40):
    F = []
    it = iter(loader)
    for _ in range(max_batches):
        try: x, _ = next(it)
        except StopIteration: break
        x = x.to(device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=getattr(Cfg, "MIXED_PREC", True)):
            f = backbone(x, return_feat=False, class_head=False)
        F.append(f.detach().cpu())
    if len(F) == 0:
        return torch.zeros(0, cdan_backbone.feat_dim, dtype=torch.float32)
    return torch.cat(F).float()  # ensure FP32

Xs = collect_feats(cdan_backbone, src_small, max_batches=40)
Xt = collect_feats(cdan_backbone, tgt_small, max_batches=40)
X  = torch.cat([Xs, Xt], 0).to(device).float()
y  = torch.cat([torch.ones(Xs.size(0)), torch.zeros(Xt.size(0))], 0).to(device).float()

lin = nn.Linear(cdan_backbone.feat_dim, 1).to(device).float()
opt = torch.optim.SGD(lin.parameters(), lr=0.05)
bce = nn.BCEWithLogitsLoss()

lin.train()
for _ in range(120):
    idx = torch.randperm(X.size(0), device=device)[:256]
    xb, yb = X[idx], y[idx]
    opt.zero_grad(set_to_none=True)
    # FP32 probe (no autocast)
    loss = bce(lin(xb).squeeze(1), yb)
    loss.backward()
    opt.step()

lin.eval()
with torch.no_grad():
    pred = (torch.sigmoid(lin(X).squeeze(1)) > 0.5).float()
    dom_err = 1.0 - (pred == y).float().mean().item()
a_distance_hat = float(2.0 * (1.0 - 2.0 * dom_err))

# --- save plots ---
plot_confusion(src_cm, IDX2CLASS, f"CDAN Source ({Cfg.SOURCE_DOMAIN})", OUT_DIR / "figs" / "cdan_cm_source.png")
plot_confusion(tgt_cm, IDX2CLASS, f"CDAN Target ({Cfg.TARGET_DOMAIN})", OUT_DIR / "figs" / "cdan_cm_target.png")

avg_domain_acc = float((src_acc + tgt_acc) / 2.0)
worst_group_acc = float(min(src_acc, tgt_acc))

summary = {
    "exp_name": Cfg.EXP_NAME,
    "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
    "method": "CDAN (mem-lite)",
    "metrics": {
        "source_acc": float(src_acc),
        "target_acc": float(tgt_acc),
        "avg_domain_acc": avg_domain_acc,
        "worst_group_acc": worst_group_acc,
        "rare3": r3,
        "rare3_f1": rare3_f1,
        "a_distance_hat": a_distance_hat
    },
    "artifacts": {
        "ckpt_cdan_backbone": str(OUT_DIR / "ckpts" / "best_cdan_backbone.pt"),
        "ckpt_cdan_head": str(OUT_DIR / "ckpts" / "best_cdan_head.pt"),
        "history": str(OUT_DIR / "history_cdan.json"),
        "acc_curve": str(OUT_DIR / "figs" / "cdan_acc_curve.png"),
        "dom_loss_curve": str(OUT_DIR / "figs" / "cdan_dom_loss_curve.png"),
        "cm_source": str(OUT_DIR / "figs" / "cdan_cm_source.png"),
        "cm_target": str(OUT_DIR / "figs" / "cdan_cm_target.png"),
        "summary_json": str(OUT_DIR / "summary_cdan.json")
    }
}

with open(OUT_DIR / "summary_cdan.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps(summary["metrics"], indent=2))
print(f"Saved CDAN summary & figures to: {OUT_DIR}")


In [None]:
# === Cell M: Build-eval helpers + load all Task-1 checkpoints ===
import json, torch, numpy as np
from pathlib import Path
from torch.utils.data import DataLoader
device = torch.device(Cfg.DEVICE)

SAVE_ROOT = Path(Cfg.SAVE_ROOT)

# Known exp names from your run:
EXP_ERM  = "T1.1_SourceOnly_PACS_photo2sketch_ResNet50"
EXP_DANN = "T1.2_DANN_PACS_photo2sketch_ResNet50_memlite"
EXP_DAN  = "T1.3_DAN_PACS_photo2sketch_ResNet50_memlite"
EXP_CDAN = "T1.4_CDAN_PACS_photo2sketch_ResNet50_memlite"

# --- Minimal feature backbones matching your code ---
class ResNet50Classifier(nn.Module):
    def __init__(self, num_classes=Cfg.NUM_CLASSES, pretrained=False):
        super().__init__()
        from torchvision import models
        weights = models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        self.backbone = models.resnet50(weights=weights)
        in_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Linear(in_features, num_classes)
    def forward(self, x): return self.backbone(x)

# From your DANN/DAN/CDAN code (assumed available in session):
# - DANN(backbone="resnet50") with .backbone and .forward_domain
# - ResNet_Feature(backbone="resnet50") exposing .feat_dim and returning feats when requested

@torch.no_grad()
def eval_clf(model, loader):
    model.eval()
    total = correct = 0
    for x,y in loader:
        x,y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        logits = model(x)
        pred = logits.argmax(1)
        correct += (pred==y).sum().item(); total += y.numel()
    return 100.0 * correct / max(total,1)

def make_small_loader(big_loader, bs=32):
    return DataLoader(big_loader.dataset, batch_size=min(bs, Cfg.BATCH_SIZE),
                      shuffle=False, num_workers=0, pin_memory=False)

small_src = make_small_loader(loaders["src_test"], 32)
small_tgt = make_small_loader(loaders["tgt_test"], 32)

# --- Rebuilders that return a callable "classifier(x)->logits" for uniform eval ---
def build_erm():
    path = SAVE_ROOT/EXP_ERM/"ckpts"/"best_by_target.pt"
    m = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
    m.load_state_dict(torch.load(path, map_location=device)); m.eval()
    return m

def build_dann_backbone():
    path = SAVE_ROOT/EXP_DANN/"ckpts"/"best_dann_backbone.pt"
    b = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
    b.load_state_dict(torch.load(path, map_location=device)); b.eval()
    return b

class DAN_Head(nn.Module):
    def __init__(self, in_dim, num_classes, hid=256, p=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(in_dim, hid), nn.BatchNorm1d(hid), nn.ReLU(True), nn.Dropout(p)
        )
        self.classifier = nn.Linear(hid, num_classes)
    def forward(self, f): return self.classifier(self.bottleneck(f))

def build_dan():
    b = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
    h = DAN_Head(b.feat_dim, Cfg.NUM_CLASSES).to(device)
    b.load_state_dict(torch.load(SAVE_ROOT/EXP_DAN/"ckpts"/"best_dan_backbone.pt", map_location=device))
    h.load_state_dict(torch.load(SAVE_ROOT/EXP_DAN/"ckpts"/"best_dan_head.pt", map_location=device))
    b.eval(); h.eval()
    class C(nn.Module):
        def __init__(self,b,h): super().__init__(); self.b,self.h=b,h
        def forward(self,x):
            f = self.b(x, return_feat=False, class_head=False)
            return self.h(f)
    return C(b,h).to(device)

class CDAN_Head(nn.Module):
    def __init__(self, in_dim, num_classes, hid=256, p=0.2):
        super().__init__()
        self.bottleneck = nn.Sequential(
            nn.Linear(in_dim, hid), nn.BatchNorm1d(hid), nn.ReLU(True), nn.Dropout(p)
        )
        self.classifier = nn.Linear(hid, num_classes)
    def forward(self, f): return self.classifier(self.bottleneck(f))

def build_cdan():
    b = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
    h = CDAN_Head(b.feat_dim, Cfg.NUM_CLASSES).to(device)
    b.load_state_dict(torch.load(SAVE_ROOT/EXP_CDAN/"ckpts"/"best_cdan_backbone.pt", map_location=device))
    h.load_state_dict(torch.load(SAVE_ROOT/EXP_CDAN/"ckpts"/"best_cdan_head.pt", map_location=device))
    b.eval(); h.eval()
    class C(nn.Module):
        def __init__(self,b,h): super().__init__(); self.b,self.h=b,h
        def forward(self,x):
            f = self.b(x, return_feat=False, class_head=False)
            return self.h(f)
    return C(b,h).to(device)

MODELS = {
    "ERM": build_erm,
    "DANN": lambda: nn.Sequential(build_dann_backbone(), nn.Identity()),  # eval via .backbone classifier below in your H cell
    "DAN": build_dan,
    "CDAN": build_cdan,
}

print("Helpers ready. You can now build/eval: ERM, DAN, CDAN. (DANN uses backbone+classifier as in your H cell.)")


Helpers ready. You can now build/eval: ERM, DAN, CDAN. (DANN uses backbone+classifier as in your H cell.)


Self-Training on Target (pseudo-labels) — memory-lite

In [None]:
# === Cell N: Self-training on target (pseudo-labeling, verbose & robust) ===
import json, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict, Counter

import torch
from torch.utils.data import Dataset, DataLoader
from torch.amp import GradScaler, autocast

# ---- config/paths ----
device = torch.device(Cfg.DEVICE)
SAVE_ROOT = Path(Cfg.SAVE_ROOT)
EXP_SELF  = f"T1.5_SelfTrain_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50"
OUT_SELF  = make_out_dirs(EXP_SELF)  # creates .../T1.5_SelfTrain_.../{figs,ckpts}

# ---- teacher (ERM) ----
def _build_teacher_from_ckpt():
    # robust: try build_erm() if defined, else rebuild from checkpoint directly
    try:
        m = build_erm().to(device).eval()
        print("[SelfTrain] Teacher: build_erm() loaded.")
        return m
    except Exception as e:
        print("[SelfTrain] build_erm() unavailable, loading ERM checkpoint directly:", e)
        m = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
        ckpt = SAVE_ROOT / "T1.1_SourceOnly_PACS_photo2sketch_ResNet50" / "ckpts" / "best_by_target.pt"
        m.load_state_dict(torch.load(ckpt, map_location=device))
        m.eval()
        return m

teacher = _build_teacher_from_ckpt()

# ---- collect pseudo-labels (with optional tiny TTA) ----
def collect_pseudos(dataset, tta=False):
    confs, yhat, idxs = [], [], []
    teacher.eval()
    with torch.no_grad():
        for i in range(len(dataset)):
            x, _ = dataset[i]
            x = x.unsqueeze(0).to(device, non_blocking=True)
            if not tta:
                prob = teacher(x).softmax(1).squeeze(0)
            else:
                probs = []
                for flip in [False, True]:
                    xb = torch.flip(x, dims=[3]) if flip else x
                    probs.append(teacher(xb).softmax(1))
                prob = torch.stack(probs, 0).mean(0).squeeze(0)
            c, y = prob.max(0)
            confs.append(float(c)); yhat.append(int(y)); idxs.append(i)
    return np.array(idxs), np.array(yhat, dtype=np.int64), np.array(confs, dtype=np.float32)

tgt_ds = loaders["tgt_test"].dataset
idx_all, yhat_all, conf_all = collect_pseudos(tgt_ds, tta=False)

BASE_THRESH = 0.80
MIN_KEEP    = 500    # target minimum pseudo-labeled samples
TOP_P       = 0.08   # fallback top-p% overall
TOP_K_PER_C = 60     # fallback top-K per class

keep = conf_all >= BASE_THRESH
kept_idx = idx_all[keep]; kept_y = yhat_all[keep]; kept_conf = conf_all[keep]

if kept_idx.size < MIN_KEEP:
    for th in [0.75, 0.70, 0.65, 0.60, 0.55]:
        keep = conf_all >= th
        if keep.sum() >= MIN_KEEP:
            kept_idx, kept_y, kept_conf = idx_all[keep], yhat_all[keep], conf_all[keep]
            print(f"[SelfTrain] Adapted threshold -> {th:.2f} (kept {keep.sum()} samples)")
            break

if kept_idx.size < MIN_KEEP:
    k = max(int(TOP_P * len(conf_all)), MIN_KEEP // 2)
    top = np.argsort(-conf_all)[:k]
    kept_idx, kept_y, kept_conf = idx_all[top], yhat_all[top], conf_all[top]
    print(f"[SelfTrain] Fallback top-{TOP_P*100:.1f}% overall -> kept {k} samples")

if kept_idx.size < MIN_KEEP:
    by_cls = defaultdict(list)
    for i, y, c in zip(idx_all, yhat_all, conf_all):
        by_cls[int(y)].append((c, i))
    sel_idx, sel_y = [], []
    for cls in range(Cfg.NUM_CLASSES):
        pairs = sorted(by_cls.get(cls, []), key=lambda t: -t[0])[:TOP_K_PER_C]
        sel_idx.extend([i for _, i in pairs])
        sel_y.extend([cls]*len(pairs))
    if len(sel_idx) > kept_idx.size:
        kept_idx = np.array(sel_idx, dtype=np.int64)
        kept_y   = np.array(sel_y, dtype=np.int64)
        print(f"[SelfTrain] Fallback per-class top-{TOP_K_PER_C} -> kept {len(sel_idx)} samples")

print(f"[SelfTrain] Pseudo-labeled target: {kept_idx.size} / {len(tgt_ds)} kept (≥{BASE_THRESH:.2f} or adaptive)")

# class histogram for transparency
if kept_idx.size > 0:
    cls_hist = Counter(kept_y.tolist())
    print("[SelfTrain] Pseudo-label class histogram (kept):")
    for c in range(Cfg.NUM_CLASSES):
        print(f"  {IDX2CLASS[c]:>14}: {cls_hist.get(c, 0)}")
else:
    print("[SelfTrain] Pseudo-label set is empty after adaptation.")

# === guard: if still zero, SKIP self-training cleanly ===
if kept_idx.size == 0:
    summary = {
        "exp_name": EXP_SELF,
        "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
        "method": "Self-Training (skipped)",
        "metrics": {"target_acc": float('nan'), "kept_fraction": 0.0, "conf_thresh": BASE_THRESH},
        "artifacts": {}
    }
    with open(OUT_SELF / "summary_selftrain.json","w") as f:
        json.dump(summary, f, indent=2)
    print(f"[SelfTrain] Skipped. Summary saved -> {OUT_SELF/'summary_selftrain.json'}")
else:
    # ---- build pseudo-labeled loader ----
    class PseudoTarget(Dataset):
        def __init__(self, base_ds, keep_idx, labels):
            self.base, self.keep, self.labels = base_ds, list(map(int, keep_idx)), list(map(int, labels))
        def __len__(self): return len(self.keep)
        def __getitem__(self, k):
            i = self.keep[k]; x, _ = self.base[i]
            return x, int(self.labels[k])

    pl_loader = DataLoader(
        PseudoTarget(tgt_ds, kept_idx, kept_y),
        batch_size=min(64, Cfg.BATCH_SIZE),
        shuffle=True, num_workers=0, pin_memory=False
    )

    # ---- student: start from ERM; freeze all but final FC ----
    student = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
    erm_ckpt = SAVE_ROOT / "T1.1_SourceOnly_PACS_photo2sketch_ResNet50" / "ckpts" / "best_by_target.pt"
    student.load_state_dict(torch.load(erm_ckpt, map_location=device))
    for n,p in student.named_parameters():
        p.requires_grad = ("backbone.fc" in n)

    LR_ST, WD_ST, EPOCHS_ST = 1e-3, 1e-4, 5
    optim  = torch.optim.SGD(filter(lambda p: p.requires_grad, student.parameters()),
                             lr=LR_ST, momentum=0.9, weight_decay=WD_ST, nesterov=True)
    scaler = GradScaler('cuda', enabled=torch.cuda.is_available())
    crit   = torch.nn.CrossEntropyLoss()

    print(f"[SelfTrain] Starting FC-only fine-tune for {EPOCHS_ST} epochs on {len(pl_loader.dataset)} samples")
    print(f"[SelfTrain] Optim: SGD(lr={LR_ST}, wd={WD_ST}); Batch={pl_loader.batch_size}")

    history_st = {"epoch": [], "tgt_acc": []}
    best_tgt = -1.0

    for ep in range(1, EPOCHS_ST+1):
        student.train()
        running = 0.0
        seen = 0
        for step, (xb, yb) in enumerate(pl_loader, start=1):
            xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)
            optim.zero_grad(set_to_none=True)
            with autocast('cuda', enabled=torch.cuda.is_available()):
                logits = student(xb)
                loss = crit(logits, yb)
            scaler.scale(loss).backward()
            scaler.step(optim); scaler.update()

            running += loss.item() * yb.size(0)
            seen += yb.size(0)
            if step % 10 == 0 or step == len(pl_loader):
                avg_loss = running / max(seen, 1)
                print(f"[ST {ep:02d}/{EPOCHS_ST}] step {step:04d}/{len(pl_loader):04d} | avg_loss={avg_loss:.4f}")

        # end-of-epoch eval
        tgt_acc, _, _ = evaluate(student, loaders["tgt_test"], device)
        history_st["epoch"].append(ep); history_st["tgt_acc"].append(float(tgt_acc))
        print(f"[ST {ep:02d}/{EPOCHS_ST}] Target Acc = {tgt_acc:.2f}%")

        # save best
        (OUT_SELF / "ckpts").mkdir(parents=True, exist_ok=True)
        if tgt_acc > best_tgt:
            best_tgt = tgt_acc
            torch.save(student.state_dict(), OUT_SELF / "ckpts" / "best_selftrain.pt")
            print(f"[ST] ✔ Saved new best checkpoint @ {best_tgt:.2f}%")

    # ---- curves + artifacts + summary ----
    save_json(history_st, OUT_SELF / "history_selftrain.json")
    (OUT_SELF / "figs").mkdir(parents=True, exist_ok=True)
    plt.figure(); plt.plot(history_st["epoch"], history_st["tgt_acc"], marker="o")
    plt.xlabel("Epoch"); plt.ylabel("Target Acc (%)")
    plt.title("Self-Training on Target (FC-only)")
    plt.savefig(OUT_SELF / "figs" / "selftrain_tgt_curve.png", bbox_inches="tight", dpi=160); plt.close()

    # reproducibility for the pseudo set
    np.save(OUT_SELF / "pseudo_idx.npy", kept_idx)
    np.save(OUT_SELF / "pseudo_y.npy",   kept_y)

    summary = {
        "exp_name": EXP_SELF,
        "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
        "method": "Self-Training (pseudo-labels; FC-only)",
        "metrics": {
            "target_acc": float(best_tgt),
            "kept_fraction": float(len(kept_idx) / max(1, len(tgt_ds))),
            "conf_thresh": BASE_THRESH
        },
        "artifacts": {
            "ckpt_selftrain": str(OUT_SELF / "ckpts" / "best_selftrain.pt"),
            "history":        str(OUT_SELF / "history_selftrain.json"),
            "curve":          str(OUT_SELF / "figs" / "selftrain_tgt_curve.png"),
            "pseudo_idx":     str(OUT_SELF / "pseudo_idx.npy"),
            "pseudo_y":       str(OUT_SELF / "pseudo_y.npy")
        }
    }
    with open(OUT_SELF / "summary_selftrain.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(f"[SelfTrain] Finished. Best target acc: {best_tgt:.2f}%. Artifacts -> {OUT_SELF}")


[SelfTrain] Teacher: build_erm() loaded.
[SelfTrain] Fallback top-8.0% overall -> kept 314 samples
[SelfTrain] Pseudo-labeled target: 314 / 3929 kept (≥0.80 or adaptive)
[SelfTrain] Pseudo-label class histogram (kept):
             dog: 0
        elephant: 0
         giraffe: 3
          guitar: 279
           horse: 14
           house: 0
          person: 18
[SelfTrain] Starting FC-only fine-tune for 5 epochs on 314 samples
[SelfTrain] Optim: SGD(lr=0.001, wd=0.0001); Batch=32
[ST 01/5] step 0010/0010 | avg_loss=1.5788




[ST 01/5] Target Acc = 22.09%
[ST] ✔ Saved new best checkpoint @ 22.09%
[ST 02/5] step 0010/0010 | avg_loss=0.8491
[ST 02/5] Target Acc = 15.47%
[ST 03/5] step 0010/0010 | avg_loss=0.5659
[ST 03/5] Target Acc = 15.47%
[ST 04/5] step 0010/0010 | avg_loss=0.4966
[ST 04/5] Target Acc = 15.47%
[ST 05/5] step 0010/0010 | avg_loss=0.4855
[ST 05/5] Target Acc = 15.47%
[SelfTrain] Finished. Best target acc: 22.09%. Artifacts -> /content/drive/MyDrive/DG_PACS/Task1/T1.5_SelfTrain_PACS_photo2sketch_ResNet50


In [None]:
# === Cell A: Self-Training (Primary, ERM teacher; simple & in-scope) ===
import json, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict, Counter
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device(Cfg.DEVICE)
SAVE_ROOT = Path(Cfg.SAVE_ROOT)
EXP_SELF  = f"T1.5_SelfTrain_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50_PRIMARY"
OUT_SELF  = (SAVE_ROOT / EXP_SELF); (OUT_SELF / "ckpts").mkdir(parents=True, exist_ok=True); (OUT_SELF / "figs").mkdir(parents=True, exist_ok=True)

# --- Teacher = ERM (strict per brief) ---
teacher = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
erm_ckpt = SAVE_ROOT / "T1.1_SourceOnly_PACS_photo2sketch_ResNet50" / "ckpts" / "best_by_target.pt"
teacher.load_state_dict(torch.load(erm_ckpt, map_location=device)); teacher.eval()
for p in teacher.parameters(): p.requires_grad = False

# --- Pseudo-labels on target (with tiny hflip TTA) ---
tgt_ds = loaders["tgt_test"].dataset
@torch.no_grad()
def collect_pseudos(model, dataset, tta=True):
    idxs, yhat, confs = [], [], []
    model.eval()
    for i in range(len(dataset)):
        x, _ = dataset[i]
        x = x.unsqueeze(0).to(device, non_blocking=True)
        if not tta:
            prob = model(x).softmax(1).squeeze(0)
        else:
            p1 = model(x).softmax(1)
            p2 = model(torch.flip(x, dims=[3])).softmax(1)
            prob = ((p1 + p2) / 2).squeeze(0)
        c, y = prob.max(0)
        idxs.append(i); yhat.append(int(y)); confs.append(float(c))
    return np.array(idxs), np.array(yhat, np.int64), np.array(confs, np.float32)

idx_all, yhat_all, conf_all = collect_pseudos(teacher, tgt_ds, tta=True)

# --- Adaptive + class-balanced selection (simple) ---
BASE_TAU = 0.70
PER_CLASS_K = 120
MIN_TOTAL = 800

def select_adaptive_balanced(idxs, yhat, conf, num_classes):
    tau = BASE_TAU
    for _ in range(5):
        mask = conf >= tau
        by_cls = defaultdict(list)
        for i, y, c in zip(idxs[mask], yhat[mask], conf[mask]):
            by_cls[int(y)].append((c, i))
        kept_idx, kept_y, kept_conf = [], [], []
        for cls in range(num_classes):
            pairs = sorted(by_cls.get(cls, []), key=lambda t: -t[0])[:PER_CLASS_K]
            for c, i in pairs:
                kept_idx.append(i); kept_y.append(cls); kept_conf.append(c)
        if len(kept_idx) >= MIN_TOTAL:
            break
        tau = max(0.50, tau - 0.05)
    print(f"[SelfTrain PRIMARY] tau≈{tau:.2f} | kept={len(kept_idx)} | per-class:", dict(Counter(kept_y)))
    return np.array(kept_idx, np.int64), np.array(kept_y, np.int64), np.array(kept_conf, np.float32), float(tau)

kept_idx, kept_y, kept_conf, used_tau = select_adaptive_balanced(idx_all, yhat_all, conf_all, Cfg.NUM_CLASSES)

class PseudoTarget(Dataset):
    def __init__(self, base_ds, keep_idx, labels):
        self.base, self.keep, self.labels = base_ds, list(map(int, keep_idx)), list(map(int, labels))
    def __len__(self): return len(self.keep)
    def __getitem__(self, k):
        i = self.keep[k]; x, _ = self.base[i]
        return x, int(self.labels[k])

pl_loader = DataLoader(PseudoTarget(tgt_ds, kept_idx, kept_y),
                       batch_size=min(64, Cfg.BATCH_SIZE), shuffle=True, num_workers=0, pin_memory=False)

# --- Student init from ERM; 2 phases: FC-only → unfreeze layer4 (tiny LR) ---
def build_student():
    m = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
    m.load_state_dict(torch.load(erm_ckpt, map_location=device))
    return m

def set_trainable(student, phase):
    for p in student.parameters(): p.requires_grad = False
    if phase >= 1:
        for n,p in student.named_parameters():
            if "backbone.fc" in n: p.requires_grad = True
    if phase >= 2:
        for n,p in student.named_parameters():
            if "backbone.layer4" in n or "backbone.fc" in n: p.requires_grad = True

LR_FC, LR_L4, WD, MOM = 1e-3, 3e-4, 1e-4, 0.9
E1, E2 = 3, 2
student = build_student()

history = {"epoch": [], "tgt_acc": [], "phase": []}
best_state = None; best_tgt = -1.0

def eval_src_tgt(model):
    src_acc, _, _ = evaluate(model, loaders["src_test"], device)
    tgt_acc, _, _ = evaluate(model, loaders["tgt_test"], device)
    return float(src_acc), float(tgt_acc)

# Phase 1: FC-only
set_trainable(student, phase=1)
opt = torch.optim.SGD([{"params":[p for n,p in student.named_parameters() if p.requires_grad and "backbone.fc" in n], "lr": LR_FC}],
                      momentum=MOM, weight_decay=WD, nesterov=True)
crit = nn.CrossEntropyLoss()
for ep in range(1, E1+1):
    student.train()
    for xb, yb in pl_loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad(set_to_none=True)
        loss = crit(student(xb), yb)
        loss.backward(); nn.utils.clip_grad_norm_(student.parameters(), 5.0); opt.step()
    _, tgt_acc = eval_src_tgt(student)
    history["epoch"].append(len(history["epoch"])+1); history["tgt_acc"].append(tgt_acc); history["phase"].append(1)
    print(f"[PRIMARY P1 E{ep}/{E1}] tgt_acc={tgt_acc:.2f}%")
    if tgt_acc > best_tgt: best_tgt = tgt_acc; best_state = student.state_dict().copy()

# Phase 2: unfreeze layer4
set_trainable(student, phase=2)
opt = torch.optim.SGD(
    [
        {"params":[p for n,p in student.named_parameters() if p.requires_grad and "backbone.fc" in n], "lr": LR_FC},
        {"params":[p for n,p in student.named_parameters() if p.requires_grad and "backbone.layer4" in n], "lr": LR_L4},
    ],
    momentum=MOM, weight_decay=WD, nesterov=True
)
for ep in range(1, E2+1):
    student.train()
    for xb, yb in pl_loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad(set_to_none=True)
        loss = crit(student(xb), yb)
        loss.backward(); nn.utils.clip_grad_norm_(student.parameters(), 5.0); opt.step()
    _, tgt_acc = eval_src_tgt(student)
    history["epoch"].append(len(history["epoch"])+1); history["tgt_acc"].append(tgt_acc); history["phase"].append(2)
    print(f"[PRIMARY P2 E{ep}/{E2}] tgt_acc={tgt_acc:.2f}%")
    if tgt_acc > best_tgt: best_tgt = tgt_acc; best_state = student.state_dict().copy()

# Save best checkpoint and compute final metrics
if best_state is not None: student.load_state_dict(best_state)
src_acc, tgt_acc = eval_src_tgt(student)
avg_domain_acc = float((src_acc + tgt_acc) / 2.0)
worst_group_acc = float(min(src_acc, tgt_acc))
kept_fraction = float(len(kept_idx) / max(1, len(tgt_ds)))

torch.save(student.state_dict(), OUT_SELF / "ckpts" / "best_selftrain_primary.pt")
with open(OUT_SELF / "history_selftrain_primary.json", "w") as f:
    json.dump({"phase": history["phase"], "epoch": history["epoch"], "tgt_acc": history["tgt_acc"]}, f, indent=2)

plt.figure(); plt.plot(history["epoch"], history["tgt_acc"], marker="o")
plt.xlabel("Epoch (P1+P2)"); plt.ylabel("Target Acc (%)"); plt.title("Self-Training (Primary, ERM teacher)")
plt.grid(True, alpha=0.3)
plt.savefig(OUT_SELF / "figs" / "selftrain_primary_curve.png", bbox_inches="tight", dpi=160); plt.close()

summary = {
    "exp_name": EXP_SELF,
    "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
    "method": "Self-Training (Primary, ERM teacher)",
    "selection": {"tau_used": used_tau, "per_class_k": PER_CLASS_K, "min_total": MIN_TOTAL,
                  "kept_fraction": kept_fraction, "kept_count": int(len(kept_idx))},
    "metrics": {
        "source_acc": src_acc,
        "target_acc": tgt_acc,
        "avg_domain_acc": avg_domain_acc,
        "worst_group_acc": worst_group_acc
    },
    "artifacts": {
        "ckpt_best": str(OUT_SELF / "ckpts" / "best_selftrain_primary.pt"),
        "history": str(OUT_SELF / "history_selftrain_primary.json"),
        "curve": str(OUT_SELF / "figs" / "selftrain_primary_curve.png")
    },
    "notes": "ERM teacher; simple adaptive + per-class selection; FC warm-up → tiny layer4 unfreeze."
}
with open(OUT_SELF / "summary_selftrain_primary.json", "w") as f:
    json.dump(summary, f, indent=2)

print(json.dumps(summary["metrics"], indent=2))
print(f"[PRIMARY] Saved artifacts -> {OUT_SELF}")


[SelfTrain PRIMARY] tau≈0.50 | kept=4 | per-class: {3: 4}
[PRIMARY P1 E1/3] tgt_acc=21.35%
[PRIMARY P1 E2/3] tgt_acc=23.92%
[PRIMARY P1 E3/3] tgt_acc=25.30%
[PRIMARY P2 E1/2] tgt_acc=27.36%
[PRIMARY P2 E2/2] tgt_acc=28.48%
{
  "source_acc": 95.02994011976048,
  "target_acc": 28.480529396793077,
  "avg_domain_acc": 61.75523475827678,
  "worst_group_acc": 28.480529396793077
}
[PRIMARY] Saved artifacts -> /content/drive/MyDrive/DG_PACS/Task1/T1.5_SelfTrain_PACS_photo2sketch_ResNet50_PRIMARY


In [None]:
# === Cell B: Self-Training (Ablation, DANN teacher; robust selection + early stopping) ===
import json, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict, Counter
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device(Cfg.DEVICE)
SAVE_ROOT = Path(Cfg.SAVE_ROOT)
EXP_ABL  = f"T1.5_SelfTrain_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50_ABL_DANNteacher"
OUT_ABL  = (SAVE_ROOT / EXP_ABL); (OUT_ABL / "ckpts").mkdir(parents=True, exist_ok=True); (OUT_ABL / "figs").mkdir(parents=True, exist_ok=True)

# --- Teacher = DANN backbone (features + fc) transplanted into ResNet50Classifier ---
teacher = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
dann_bb = SAVE_ROOT / "T1.2_DANN_PACS_photo2sketch_ResNet50_memlite" / "ckpts" / "best_dann_backbone.pt"
if not dann_bb.exists():
    raise FileNotFoundError("DANN backbone not found; run Task 1.2 DANN first.")

# Load DANN backbone module to read weights (includes classifier layer)
feat = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
feat.load_state_dict(torch.load(dann_bb, map_location=device))
# Transfer conv/bn/layers and fc weights into the classifier model
base = teacher.backbone
# features: [0]conv1, [1]bn1, [2]relu, [3]maxpool, [4]layer1, [5]layer2, [6]layer3, [7]layer4, [8]avgpool
base.conv1.load_state_dict(feat.features[0].state_dict())
base.bn1.load_state_dict(feat.features[1].state_dict())
base.layer1.load_state_dict(feat.features[4].state_dict())
base.layer2.load_state_dict(feat.features[5].state_dict())
base.layer3.load_state_dict(feat.features[6].state_dict())
base.layer4.load_state_dict(feat.features[7].state_dict())
# fc
with torch.no_grad():
    base.fc.weight.copy_(feat.classifier.weight)
    base.fc.bias.copy_(feat.classifier.bias)

teacher.eval();  [p.requires_grad_(False) for p in teacher.parameters()]
print("[SelfTrain ABL] DANN teacher loaded (features + fc).")

# --- Pseudo-labels on target (with hflip TTA) ---
tgt_ds = loaders["tgt_test"].dataset
@torch.no_grad()
def collect_pseudos(model, dataset, tta=True):
    idxs, yhat, confs = [], [], []
    model.eval()
    for i in range(len(dataset)):
        x, _ = dataset[i]
        x = x.unsqueeze(0).to(device, non_blocking=True)
        if not tta:
            prob = model(x).softmax(1).squeeze(0)
        else:
            p1 = model(x).softmax(1)
            p2 = model(torch.flip(x, dims=[3])).softmax(1)
            prob = ((p1 + p2) / 2).squeeze(0)
        c, y = prob.max(0)
        idxs.append(i); yhat.append(int(y)); confs.append(float(c))
    return np.array(idxs), np.array(yhat, np.int64), np.array(confs, np.float32)

idx_all, yhat_all, conf_all = collect_pseudos(teacher, tgt_ds, tta=True)

# --- Adaptive + class-balanced selection, with robust fallbacks ---
BASE_TAU = 0.70
PER_CLASS_K = 120
MIN_TOTAL = 800
TOP_P = 0.20            # 20% overall fallback
MIN_PER_CLASS = 80      # ensure some coverage per class in fallback

def select_adaptive_balanced_robust(idxs, yhat, conf, num_classes):
    # 1) Try thresholding and per-class top-K
    tau = BASE_TAU
    kept_idx, kept_y, kept_conf = [], [], []
    for _ in range(5):
        mask = conf >= tau
        by_cls = defaultdict(list)
        for i, y, c in zip(idxs[mask], yhat[mask], conf[mask]):
            by_cls[int(y)].append((c, i))
        cur_idx, cur_y, cur_conf = [], [], []
        for cls in range(num_classes):
            pairs = sorted(by_cls.get(cls, []), key=lambda t: -t[0])[:PER_CLASS_K]
            for c, i in pairs: cur_idx.append(i); cur_y.append(cls); cur_conf.append(c)
        if len(cur_idx) >= MIN_TOTAL:
            kept_idx, kept_y, kept_conf = cur_idx, cur_y, cur_conf
            print(f"[SelfTrain ABL] tau≈{tau:.2f} | kept={len(kept_idx)} | per-class:", dict(Counter(cur_y)))
            return np.array(kept_idx, np.int64), np.array(kept_y, np.int64), np.array(cur_conf, np.float32), float(tau)
        tau = max(0.50, tau - 0.05)

    # 2) Fallback: top-p overall by confidence
    k = max(int(TOP_P * len(conf)), MIN_TOTAL)
    top = np.argsort(-conf)[:k]
    kept_idx, kept_y, kept_conf = idxs[top], yhat[top], conf[top]
    print(f"[SelfTrain ABL] Fallback top-{int(TOP_P*100)}% overall -> kept={len(kept_idx)}")

    # 3) Ensure per-class coverage by backfilling classes below MIN_PER_CLASS
    counts = Counter(kept_y.tolist())
    if any(counts[c] < MIN_PER_CLASS for c in range(num_classes)):
        remaining = np.setdiff1d(np.arange(len(conf)), top, assume_unique=False)
        pool_by_cls = defaultdict(list)
        for j in remaining:
            pool_by_cls[int(yhat[j])].append((conf[j], idxs[j]))
        kept_idx = kept_idx.tolist(); kept_y = kept_y.tolist(); kept_conf = kept_conf.tolist()
        for cls in range(num_classes):
            need = max(0, MIN_PER_CLASS - counts.get(cls, 0))
            if need > 0 and len(pool_by_cls[cls]) > 0:
                extra = sorted(pool_by_cls[cls], key=lambda t: -t[0])[:need]
                for c,i in extra:
                    kept_idx.append(i); kept_y.append(cls); kept_conf.append(c)
        kept_idx = np.array(kept_idx, np.int64); kept_y = np.array(kept_y, np.int64); kept_conf = np.array(kept_conf, np.float32)
        print(f"[SelfTrain ABL] Backfilled per-class to ensure coverage | per-class:", dict(Counter(kept_y.tolist())))

    # 4) If still zero (extremely unlikely), return empty and caller will skip
    if len(kept_idx) == 0:
        print("[SelfTrain ABL] Pseudo set is EMPTY after all fallbacks.")
        return kept_idx, kept_y, kept_conf, float(tau)

    return kept_idx, kept_y, kept_conf, float(tau)

kept_idx, kept_y, kept_conf, used_tau = select_adaptive_balanced_robust(idx_all, yhat_all, conf_all, Cfg.NUM_CLASSES)

# --- Guard: handle empty set cleanly (skip training, save summary) ---
if len(kept_idx) == 0:
    summary = {
        "exp_name": EXP_ABL,
        "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
        "method": "Self-Training (Ablation, DANN teacher)",
        "selection": {"tau_used": used_tau, "per_class_k": PER_CLASS_K, "min_total": MIN_TOTAL,
                      "kept_fraction": 0.0, "kept_count": 0},
        "metrics": {
            "source_acc": float('nan'),
            "target_acc": float('nan'),
            "avg_domain_acc": float('nan'),
            "worst_group_acc": float('nan')
        },
        "artifacts": {}
    }
    with open(OUT_ABL / "summary_selftrain_ablation.json", "w") as f:
        json.dump(summary, f, indent=2)
    print("[ABLATION] Skipped training: no pseudo-labels could be selected. Summary saved.")
else:
    class PseudoTarget(Dataset):
        def __init__(self, base_ds, keep_idx, labels):
            self.base, self.keep, self.labels = base_ds, list(map(int, keep_idx)), list(map(int, labels))
        def __len__(self): return len(self.keep)
        def __getitem__(self, k):
            i = self.keep[k]; x, _ = self.base[i]
            return x, int(self.labels[k])

    pl_loader = DataLoader(PseudoTarget(tgt_ds, kept_idx, kept_y),
                           batch_size=min(64, Cfg.BATCH_SIZE), shuffle=True, num_workers=0, pin_memory=False)

    # --- Student = ERM init; 2-phase schedule with early stopping (same as primary style) ---
    def build_student():
        m = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
        m.load_state_dict(torch.load(SAVE_ROOT / "T1.1_SourceOnly_PACS_photo2sketch_ResNet50" / "ckpts" / "best_by_target.pt",
                                     map_location=device))
        return m

    def set_trainable(student, phase):
        for p in student.parameters(): p.requires_grad = False
        if phase >= 1:
            for n,p in student.named_parameters():
                if "backbone.fc" in n: p.requires_grad = True
        if phase >= 2:
            for n,p in student.named_parameters():
                if "backbone.layer4" in n or "backbone.fc" in n: p.requires_grad = True

    LR_FC, LR_L4, WD, MOM = 1e-3, 3e-4, 1e-4, 0.9
    student = build_student()

    history = {"epoch": [], "tgt_acc": [], "phase": []}
    best_state = None; best_tgt = -1.0

    def eval_src_tgt(model):
        src_acc, _, _ = evaluate(model, loaders["src_test"], device)
        tgt_acc, _, _ = evaluate(model, loaders["tgt_test"], device)
        return float(src_acc), float(tgt_acc)

    # ---- Phase 1: FC-only (up to 6 epochs, patience=2) ----
    set_trainable(student, phase=1)
    opt = torch.optim.SGD(
        [{"params":[p for n,p in student.named_parameters() if p.requires_grad and "backbone.fc" in n], "lr": LR_FC}],
        momentum=MOM, weight_decay=WD, nesterov=True
    )
    crit = nn.CrossEntropyLoss()
    E1, PATIENCE1 = 6, 2
    best_tgt_phase = -1.0; no_improve = 0

    for ep in range(1, E1+1):
        student.train()
        for xb, yb in pl_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            loss = crit(student(xb), yb)
            loss.backward(); nn.utils.clip_grad_norm_(student.parameters(), 5.0); opt.step()
        _, tgt_acc = eval_src_tgt(student)
        history["epoch"].append(len(history["epoch"])+1); history["tgt_acc"].append(tgt_acc); history["phase"].append(1)
        print(f"[ABL P1 E{ep}/{E1}] tgt_acc={tgt_acc:.2f}%")
        if tgt_acc > best_tgt_phase:
            best_tgt_phase = tgt_acc; no_improve = 0
            if tgt_acc > best_tgt: best_tgt = tgt_acc; best_state = student.state_dict().copy()
        else:
            no_improve += 1
            if no_improve >= PATIENCE1:
                print("[ABL] Early stop Phase 1."); break

    # ---- Phase 2: unfreeze layer4 (up to 10 epochs, patience=3) ----
    set_trainable(student, phase=2)
    opt = torch.optim.SGD(
        [
            {"params":[p for n,p in student.named_parameters() if p.requires_grad and "backbone.fc" in n], "lr": LR_FC},
            {"params":[p for n,p in student.named_parameters() if p.requires_grad and "backbone.layer4" in n], "lr": LR_L4},
        ],
        momentum=MOM, weight_decay=WD, nesterov=True
    )
    E2, PATIENCE2 = 10, 3
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=E2)
    best_tgt_phase = -1.0; no_improve = 0

    for ep in range(1, E2+1):
        student.train()
        for xb, yb in pl_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            loss = crit(student(xb), yb)
            loss.backward(); nn.utils.clip_grad_norm_(student.parameters(), 5.0); opt.step()
        sched.step()
        _, tgt_acc = eval_src_tgt(student)
        history["epoch"].append(len(history["epoch"])+1); history["tgt_acc"].append(tgt_acc); history["phase"].append(2)
        print(f"[ABL P2 E{ep}/{E2}] tgt_acc={tgt_acc:.2f}%")
        if tgt_acc > best_tgt_phase:
            best_tgt_phase = tgt_acc; no_improve = 0
            if tgt_acc > best_tgt: best_tgt = tgt_acc; best_state = student.state_dict().copy()
        else:
            no_improve += 1
            if no_improve >= PATIENCE2:
                print("[ABL] Early stop Phase 2."); break

    # ---- Save best + metrics (consistent with other methods) ----
    if best_state is not None: student.load_state_dict(best_state)
    src_acc, tgt_acc = eval_src_tgt(student)
    avg_domain_acc = float((src_acc + tgt_acc) / 2.0)
    worst_group_acc = float(min(src_acc, tgt_acc))
    kept_fraction = float(len(kept_idx) / max(1, len(tgt_ds)))

    torch.save(student.state_dict(), OUT_ABL / "ckpts" / "best_selftrain_ablation_dannT.pt")
    with open(OUT_ABL / "history_selftrain_ablation.json", "w") as f:
        json.dump({"phase": history["phase"], "epoch": history["epoch"], "tgt_acc": history["tgt_acc"]}, f, indent=2)

    plt.figure(); plt.plot(history["epoch"], history["tgt_acc"], marker="o")
    plt.xlabel("Epoch (P1+P2)"); plt.ylabel("Target Acc (%)"); plt.title("Self-Training (Ablation, DANN teacher)")
    plt.grid(True, alpha=0.3)
    plt.savefig(OUT_ABL / "figs" / "selftrain_ablation_curve.png", bbox_inches="tight", dpi=160); plt.close()

    summary = {
        "exp_name": EXP_ABL,
        "domains": {"source": Cfg.SOURCE_DOMAIN, "target": Cfg.TARGET_DOMAIN},
        "method": "Self-Training (Ablation, DANN teacher)",
        "selection": {"tau_used": used_tau, "per_class_k": PER_CLASS_K, "min_total": MIN_TOTAL,
                      "kept_fraction": kept_fraction, "kept_count": int(len(kept_idx))},
        "metrics": {
            "source_acc": src_acc,
            "target_acc": tgt_acc,
            "avg_domain_acc": avg_domain_acc,
            "worst_group_acc": worst_group_acc
        },
        "artifacts": {
            "ckpt_best": str(OUT_ABL / "ckpts" / "best_selftrain_ablation_dannT.pt"),
            "history": str(OUT_ABL / "history_selftrain_ablation.json"),
            "curve": str(OUT_ABL / "figs" / "selftrain_ablation_curve.png")
        },
        "notes": "Ablation only: DANN teacher (features+fc); robust selection; FC warm-up → tiny layer4 unfreeze with early stopping."
    }
    with open(OUT_ABL / "summary_selftrain_ablation.json", "w") as f:
        json.dump(summary, f, indent=2)

    print(json.dumps(summary["metrics"], indent=2))
    print(f"[ABLATION] Saved artifacts -> {OUT_ABL}")


[SelfTrain ABL] DANN teacher loaded (features + fc).
[SelfTrain ABL] Fallback top-20% overall -> kept=800
[SelfTrain ABL] Backfilled per-class to ensure coverage | per-class: {3: 348, 4: 155, 0: 101, 1: 188, 2: 80, 5: 76, 6: 10}
[ABL P1 E1/6] tgt_acc=29.96%
[ABL P1 E2/6] tgt_acc=36.73%
[ABL P1 E3/6] tgt_acc=44.57%
[ABL P1 E4/6] tgt_acc=47.98%
[ABL P1 E5/6] tgt_acc=51.49%
[ABL P1 E6/6] tgt_acc=52.23%
[ABL P2 E1/10] tgt_acc=56.30%
[ABL P2 E2/10] tgt_acc=58.51%
[ABL P2 E3/10] tgt_acc=59.66%
[ABL P2 E4/10] tgt_acc=60.09%
[ABL P2 E5/10] tgt_acc=60.37%
[ABL P2 E6/10] tgt_acc=61.42%
[ABL P2 E7/10] tgt_acc=61.90%
[ABL P2 E8/10] tgt_acc=61.11%
[ABL P2 E9/10] tgt_acc=61.90%
[ABL P2 E10/10] tgt_acc=61.42%
[ABL] Early stop Phase 2.
{
  "source_acc": 11.497005988023952,
  "target_acc": 61.41511835072537,
  "avg_domain_acc": 36.45606216937466,
  "worst_group_acc": 11.497005988023952
}
[ABLATION] Saved artifacts -> /content/drive/MyDrive/DG_PACS/Task1/T1.5_SelfTrain_PACS_photo2sketch_ResNet50_ABL_DAN

t-SNE feature plots (source vs target) for each method

In [None]:
# === Cell O: t-SNE embeddings per method (source vs target) ===
import numpy as np, torch, matplotlib.pyplot as plt, json
from sklearn.manifold import TSNE

def collect_feats_logits(model, loader, max_n=800):
    model.eval()
    Fs, Ys, Ds = [], [], []  # feats/logits inference → use penultimate logits as proxy if needed
    n = 0
    with torch.no_grad():
        for x,y in loader:
            x = x.to(device, non_blocking=True)
            z = model(x)                 # logits; we can t-SNE logits (class-separable space)
            Fs.append(z.detach().cpu().float())
            Ys.append(y.numpy()); D = np.full(len(y), fill_value=0, dtype=np.int64)  # placeholder
            n += len(y)
            if n >= max_n: break
    return torch.cat(Fs).numpy(), np.concatenate(Ys)

def tsne_plot(Fs, labels, title, savepath):
    ts = TSNE(n_components=2, init='pca', learning_rate='auto', perplexity=30)
    emb = ts.fit_transform(Fs)
    plt.figure(figsize=(6,5))
    sc = plt.scatter(emb[:,0], emb[:,1], c=labels, s=8)
    plt.title(title); plt.tight_layout()
    plt.savefig(savepath, bbox_inches="tight", dpi=160); plt.close()

METHOD_BUILDERS = {
    "ERM": build_erm,
    "DAN": build_dan,
    "CDAN": build_cdan,
}
# DANN uses backbone+classifier head (already available in your H cell); reuse backbone.classifier path
def build_dann_clf():
    b = build_dann_backbone()
    class C(nn.Module):
        def __init__(self,b): super().__init__(); self.b=b
        def forward(self,x):
            f = self.b(x, return_feat=False, class_head=False)
            return self.b.classifier(f)
    return C(b).to(device).eval()

METHOD_BUILDERS["DANN"] = build_dann_clf

OUT_TSNE = Path(Cfg.SAVE_ROOT)/"T1_TSNE"
(OUT_TSNE).mkdir(parents=True, exist_ok=True)

for name, builder in METHOD_BUILDERS.items():
    clf = builder()
    Fs_src, y_src = collect_feats_logits(clf, loaders["src_test"], max_n=800)
    Fs_tgt, y_tgt = collect_feats_logits(clf, loaders["tgt_test"], max_n=800)
    tsne_plot(np.vstack([Fs_src, Fs_tgt]),
              np.concatenate([y_src, y_tgt]),
              f"t-SNE logits: {name} (src+tg)",
              OUT_TSNE/f"tsne_{name}_logits_src_tgt.png")
    print(f"Saved t-SNE for {name} -> {OUT_TSNE/f'tsne_{name}_logits_src_tgt.png'}")


Saved t-SNE for ERM -> /content/drive/MyDrive/DG_PACS/Task1/T1_TSNE/tsne_ERM_logits_src_tgt.png




Saved t-SNE for DAN -> /content/drive/MyDrive/DG_PACS/Task1/T1_TSNE/tsne_DAN_logits_src_tgt.png




Saved t-SNE for CDAN -> /content/drive/MyDrive/DG_PACS/Task1/T1_TSNE/tsne_CDAN_logits_src_tgt.png




Saved t-SNE for DANN -> /content/drive/MyDrive/DG_PACS/Task1/T1_TSNE/tsne_DANN_logits_src_tgt.png


In [None]:
# === Cell P: Label/Concept shift & rare-class stress tests ===
import numpy as np, torch, json
from torch.utils.data import Subset, DataLoader
from collections import Counter

def downsample_classes(base_loader, keep_ratio_by_class, max_per_cls=None):
    ds = base_loader.dataset
    idx_by_cls = {}
    for i in range(len(ds)):
        _, y = ds[i]; idx_by_cls.setdefault(int(y), []).append(i)
    new_idx = []
    for c, idxs in idx_by_cls.items():
        k = int(len(idxs) * keep_ratio_by_class.get(c, 1.0))
        if max_per_cls: k = min(k, max_per_cls)
        new_idx.extend(idxs[:max(k,1)])
    return DataLoader(Subset(ds, new_idx), batch_size=base_loader.batch_size, shuffle=False,
                      num_workers=0, pin_memory=False)

def eval_all_methods(loader, note):
    results = {}
    # ERM, DAN, CDAN, DANN head
    models = {
        "ERM": build_erm(),
        "DAN": build_dan(),
        "CDAN": build_cdan(),
        "DANN": build_dann_clf(),
    }
    for name, m in models.items():
        acc, _, _ = evaluate(m, loader, device)
        results[name] = float(acc)
    return results

# 1) Label shift: e.g., shrink 3 classes on target to 20% prevalence
rare_classes = [0, 1, 2]  # pick any three; or choose programmatically by counts
keep_ratio = {c: 0.2 for c in rare_classes}
tgt_shift_loader = downsample_classes(loaders["tgt_test"], keep_ratio)
res_labelshift = eval_all_methods(tgt_shift_loader, "label_shift")

# 2) Rare-class stress: keep only 10% of the rarest class on target
#    (find rarest class first)
counts = Counter()
for _, y in loaders["tgt_test"]:
    for v in y.tolist(): counts[int(v)] += 1
rarest = min(counts.keys(), key=lambda c: counts[c])
keep_ratio2 = {rarest: 0.1}
tgt_rare_loader = downsample_classes(loaders["tgt_test"], keep_ratio2)
res_rarestress = eval_all_methods(tgt_rare_loader, "rare_stress")

# Save JSON roll-up
OUT_SHIFT = Path(Cfg.SAVE_ROOT)/"T1_ShiftStress"
OUT_SHIFT.mkdir(parents=True, exist_ok=True)
with open(OUT_SHIFT/"label_shift_results.json","w") as f: json.dump(res_labelshift, f, indent=2)
with open(OUT_SHIFT/"rare_stress_results.json","w") as f: json.dump(res_rarestress, f, indent=2)

print("Label-shift target acc (%):", res_labelshift)
print("Rare-class stress target acc (%):", res_rarestress)
print(f"Saved shift stress results under {OUT_SHIFT}")




Label-shift target acc (%): {'ERM': 41.351606805293, 'DAN': 12.145557655954631, 'CDAN': 22.77882797731569, 'DANN': 73.20415879017013}
Rare-class stress target acc (%): {'ERM': 22.971221156339123, 'DAN': 17.81177080632616, 'CDAN': 19.471091521908217, 'DANN': 65.67280269639616}
Saved shift stress results under /content/drive/MyDrive/DG_PACS/Task1/T1_ShiftStress


In [None]:
# === Cell Q: Consolidated Task-1 summary (table + plot) ===
import json, torch, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd

rows = []

def try_summary(path_json, fallback_eval=None, name=None):
    if path_json.exists():
        j = json.loads(path_json.read_text())
        src = j["metrics"].get("source_acc", np.nan)
        tgt = j["metrics"].get("target_acc", np.nan)
        avg = j["metrics"].get("avg_domain_acc", np.nan)
        worst = j["metrics"].get("worst_group_acc", np.nan)
        rows.append({"method": name, "source_acc": src, "target_acc": tgt, "avg": avg, "worst": worst})
    elif fallback_eval is not None:
        m = fallback_eval()
        src,_,_ = evaluate(m, loaders["src_test"], device)
        tgt,_,_ = evaluate(m, loaders["tgt_test"], device)
        rows.append({"method": name, "source_acc": float(src), "target_acc": float(tgt),
                     "avg": float((src+tgt)/2), "worst": float(min(src,tgt))})

# ERM
try_summary(SAVE_ROOT/EXP_ERM/"summary.json", fallback_eval=build_erm, name="ERM")
# DANN
try_summary(SAVE_ROOT/EXP_DANN/"summary_dann.json", fallback_eval=build_dann_clf, name="DANN")
# DAN
try_summary(SAVE_ROOT/EXP_DAN/"summary_dan.json", fallback_eval=build_dan, name="DAN")
# CDAN
try_summary(SAVE_ROOT/EXP_CDAN/"summary_cdan.json", fallback_eval=build_cdan, name="CDAN")
# Self-Training
try_summary(SAVE_ROOT/(f"T1.5_SelfTrain_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50")/"summary_selftrain.json",
            fallback_eval=None, name="SelfTrain")

df = pd.DataFrame(rows).sort_values("target_acc", ascending=False)
OUT_SUMMARY = Path(Cfg.SAVE_ROOT)/"T1_Summary"
OUT_SUMMARY.mkdir(parents=True, exist_ok=True)
df.to_csv(OUT_SUMMARY/"task1_summary.csv", index=False)
with open(OUT_SUMMARY/"task1_summary.json","w") as f: json.dump(rows, f, indent=2)

# Quick bar plot
plt.figure(figsize=(7,4))
plt.bar(df["method"], df["target_acc"])
plt.ylabel("Target Accuracy (%)"); plt.title("Task 1: Target Acc by Method")
plt.savefig(OUT_SUMMARY/"task1_target_acc_bar.png", bbox_inches="tight", dpi=160); plt.close()

print(df)
print(f"Saved consolidated summary to: {OUT_SUMMARY}")


      method  source_acc  target_acc        avg      worst
1       DANN  100.000000   65.894630  82.947315  65.894630
0        ERM   97.724551   22.550267  60.137409  22.550267
4  SelfTrain         NaN   22.092135        NaN        NaN
3       CDAN   10.838323   19.292441  15.065382  10.838323
2        DAN    8.562874   17.816238  13.189556   8.562874
Saved consolidated summary to: /content/drive/MyDrive/DG_PACS/Task1/T1_Summary


In [None]:
# === Cell R: Self-Training vs DANN comparison + brief printed analysis ===
import json, numpy as np
from pathlib import Path

SAVE_ROOT = Path(Cfg.SAVE_ROOT)

def _safe_load_json(p):
    if Path(p).exists():
        with open(p, "r") as f: return json.load(f)
    return None

J_ST   = _safe_load_json(SAVE_ROOT / f"T1.5_SelfTrain_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50" / "summary_selftrain.json")
J_DANN = _safe_load_json(SAVE_ROOT / "T1.2_DANN_PACS_photo2sketch_ResNet50_memlite" / "summary_dann.json")

def _pull_tgt(j):
    if not j: return np.nan
    return float(j.get("metrics", {}).get("target_acc", np.nan))

tgt_st   = _pull_tgt(J_ST)
tgt_dann = _pull_tgt(J_DANN)

print("=== Self-Training vs DANN (Target Accuracy, %) ===")
print(f"Self-Training: {tgt_st:.2f}")
print(f"DANN:         {tgt_dann:.2f}")
if np.isfinite(tgt_st) and np.isfinite(tgt_dann):
    diff = tgt_st - tgt_dann
    trend = "higher" if diff > 0 else ("lower" if diff < 0 else "equal")
    print(f"Δ(Self-Train − DANN) = {diff:+.2f} points ({trend}).")

print("\n--- Analysis (brief) ---")
print("Self-training can win when the source-only model already produces moderately reliable pseudo-labels on the target.")
print("Even if noisy, the correctly-labeled target samples pull the decision boundary toward target structure.")
print("DANN aligns feature distributions globally; if alignment is imperfect or class-conditional alignment is weak,")
print("a simple pseudo-label fine-tune (even FC-only) can match or beat it on target accuracy.")
print("Pitfalls: confirmation bias (reinforcing wrong pseudo-labels), class imbalance amplification, and threshold choice.")
print("Mitigations: higher confidence threshold, class-balanced sampling, temperature/soft labels, or EMA/consistency.")


=== Self-Training vs DANN (Target Accuracy, %) ===
Self-Training: 22.09
DANN:         65.89
Δ(Self-Train − DANN) = -43.80 points (lower).

--- Analysis (brief) ---
Self-training can win when the source-only model already produces moderately reliable pseudo-labels on the target.
Even if noisy, the correctly-labeled target samples pull the decision boundary toward target structure.
DANN aligns feature distributions globally; if alignment is imperfect or class-conditional alignment is weak,
a simple pseudo-label fine-tune (even FC-only) can match or beat it on target accuracy.
Pitfalls: confirmation bias (reinforcing wrong pseudo-labels), class imbalance amplification, and threshold choice.
Mitigations: higher confidence threshold, class-balanced sampling, temperature/soft labels, or EMA/consistency.


In [None]:
# === Cell S: Label-shift/rare-class visuals (confusions + heatmaps) ===
import numpy as np, matplotlib.pyplot as plt, json
from collections import Counter
from pathlib import Path
from torch.utils.data import Subset, DataLoader

OUT_SHIFT = Path(Cfg.SAVE_ROOT) / "T1_ShiftStress"
OUT_SHIFT.mkdir(parents=True, exist_ok=True)

# --- Reuse builders from earlier cells (ERM/DAN/DANN/CDAN) ---
methods = {
    "ERM": build_erm,
    "DANN": lambda: build_dann_clf(),
    "DAN": build_dan,
    "CDAN": build_cdan,
}

# Helper: build a shifted loader and also return its class histogram
def build_shifted_loader(base_loader, keep_ratio_by_class):
    ds = base_loader.dataset
    idx_by_cls = {}
    for i in range(len(ds)):
        _, y = ds[i]
        idx_by_cls.setdefault(int(y), []).append(i)
    new_idx = []
    for c, idxs in idx_by_cls.items():
        k = int(len(idxs) * keep_ratio_by_class.get(c, 1.0))
        new_idx.extend(idxs[:max(k,1)])
    loader = DataLoader(Subset(ds, new_idx), batch_size=base_loader.batch_size,
                        shuffle=False, num_workers=0, pin_memory=False)
    # histogram
    hist = Counter()
    for _, y in loader:
        for t in y.tolist(): hist[int(t)] += 1
    return loader, hist

# Choose label-shift setting (same as Cell P, but we’ll plot now)
# pick 3 rare classes programmatically from the unshifted target
base_counts = Counter()
for _, y in loaders["tgt_test"]:
    for v in y.tolist(): base_counts[int(v)] += 1
rare3 = [c for c,_ in base_counts.most_common()][-3:]
keep_ratio = {c: 0.2 for c in rare3}  # 20% for those classes

tgt_shift_loader, shift_counts = build_shifted_loader(loaders["tgt_test"], keep_ratio)

# --- distribution heatmap: source vs target vs shifted target ---
def counts_to_vec(counter, n):
    return np.array([counter.get(i,0) for i in range(n)], dtype=np.float32)

# source counts (use your src_test loader as source reference)
src_counts = Counter()
for _, y in loaders["src_test"]:
    for v in y.tolist(): src_counts[int(v)] += 1

tgt_counts = base_counts
shf_counts = shift_counts

mat = np.stack([
    counts_to_vec(src_counts, Cfg.NUM_CLASSES),
    counts_to_vec(tgt_counts, Cfg.NUM_CLASSES),
    counts_to_vec(shf_counts, Cfg.NUM_CLASSES),
], axis=0)

plt.figure(figsize=(8,3.2))
plt.imshow(mat, aspect="auto")
plt.yticks([0,1,2], ["Source", "Target", "Target (Shifted)"])
plt.xticks(range(Cfg.NUM_CLASSES), [IDX2CLASS[i] for i in range(Cfg.NUM_CLASSES)], rotation=45, ha="right")
plt.colorbar(label="Count")
plt.title("Class Distribution Heatmap")
plt.tight_layout()
plt.savefig(OUT_SHIFT / "class_distribution_heatmap.png", dpi=160)
plt.close()
print(f"Saved: {OUT_SHIFT/'class_distribution_heatmap.png'}")

# --- per-method confusion matrices on shifted target + per-class accuracy heatmap ---
def confusion_and_perclass(clf, loader, idx2class):
    from sklearn.metrics import confusion_matrix
    import torch
    cm = np.zeros((len(idx2class), len(idx2class)), dtype=np.int64)
    total = correct = 0
    true_all, pred_all = [], []
    with torch.no_grad():
        clf.eval()
        for x, y in loader:
            x = x.to(device); y = y.to(device)
            logits = clf(x); p = logits.argmax(1)
            true_all.extend(y.tolist()); pred_all.extend(p.tolist())
    cm = confusion_matrix(true_all, pred_all, labels=list(range(len(idx2class))))
    per_cls = (cm.diagonal() / np.clip(cm.sum(1), 1, None)) * 100.0
    return cm, per_cls

percls_mat = []
for name, builder in methods.items():
    clf = builder().to(device).eval()
    # confusion on shifted target:
    cm, percls = confusion_and_perclass(clf, tgt_shift_loader, IDX2CLASS)

    # save cm figure:
    plt.figure(figsize=(6,5))
    plt.imshow(cm, interpolation='nearest')
    plt.title(f"{name}: Confusion (Shifted Target)"); plt.colorbar()
    ticks = np.arange(len(IDX2CLASS))
    labels = [IDX2CLASS[i] for i in range(len(IDX2CLASS))]
    plt.xticks(ticks, labels, rotation=45, ha='right'); plt.yticks(ticks, labels)
    plt.tight_layout(); plt.ylabel('True'); plt.xlabel('Predicted')
    fname = OUT_SHIFT / f"{name.lower()}_cm_shifted.png"
    plt.savefig(fname, bbox_inches="tight", dpi=160); plt.close()
    print(f"Saved: {fname}")

    percls_mat.append(percls)

percls_mat = np.vstack(percls_mat)  # [num_methods, num_classes]

plt.figure(figsize=(8,3.6))
plt.imshow(percls_mat, aspect="auto", vmin=0, vmax=100)
plt.yticks(range(len(methods)), list(methods.keys()))
plt.xticks(range(Cfg.NUM_CLASSES), [IDX2CLASS[i] for i in range(Cfg.NUM_CLASSES)], rotation=45, ha="right")
plt.colorbar(label="Per-class Accuracy (%)")
plt.title("Per-class Accuracy under Label Shift (Shifted Target)")
plt.tight_layout()
plt.savefig(OUT_SHIFT / "perclass_accuracy_shift_heatmap.png", dpi=160)
plt.close()
print(f"Saved: {OUT_SHIFT/'perclass_accuracy_shift_heatmap.png'}")


NameError: name 'build_erm' is not defined

[Q1] Unshadowed torchvision.models (removed a dict named 'models').


NameError: name 'models' is not defined

In [None]:
# === Cell CS1: Build stress loaders (label-shift & rare-class) and print distributions ===
import numpy as np
from torch.utils.data import Subset, DataLoader
from collections import Counter
from pathlib import Path

# Helper: downsample classes on an existing loader's dataset
def downsample_classes(base_loader, keep_ratio_by_class, max_per_cls=None):
    ds = base_loader.dataset
    idx_by_cls = {}
    for i in range(len(ds)):
        _, y = ds[i]
        idx_by_cls.setdefault(int(y), []).append(i)
    new_idx = []
    for c, idxs in idx_by_cls.items():
        k = int(len(idxs) * keep_ratio_by_class.get(c, 1.0))
        if max_per_cls is not None: k = min(k, max_per_cls)
        k = max(k, 1)  # keep at least 1 if class exists
        new_idx.extend(idxs[:k])
    return DataLoader(
        Subset(ds, new_idx),
        batch_size=loaders["tgt_test"].batch_size,
        shuffle=False, num_workers=0, pin_memory=False
    )

# 1) Label shift: shrink three classes on target to 20%
three = [0, 1, 2]  # you can swap these to other IDs if you prefer
keep_ratio_lblshift = {c: 0.2 for c in three}
tgt_shift_loader = downsample_classes(loaders["tgt_test"], keep_ratio_lblshift)

# 2) Rare-class: keep only 10% of the rarest class on target
counts = Counter()
for _, yb in loaders["tgt_test"]:
    for v in yb.tolist(): counts[int(v)] += 1
rarest_cls = min(counts.keys(), key=lambda c: counts[c])
tgt_rare_loader = downsample_classes(loaders["tgt_test"], {rarest_cls: 0.1})

def count_classes(loader):
    c = Counter()
    for _, yb in loader:
        for v in yb.tolist(): c[int(v)] += 1
    return c

print("[CS1] Target original counts:", dict(count_classes(loaders["tgt_test"])))
print("[CS1] Label-shift counts     :", dict(count_classes(tgt_shift_loader)))
print("[CS1] Rare-class counts      :", dict(count_classes(tgt_rare_loader)))




[CS1] Target original counts: {0: 772, 1: 740, 2: 753, 3: 608, 4: 816, 5: 80, 6: 160}
[CS1] Label-shift counts     : {0: 154, 1: 148, 2: 150, 3: 608, 4: 816, 5: 80, 6: 160}
[CS1] Rare-class counts      : {0: 772, 1: 740, 2: 753, 3: 608, 4: 816, 5: 8, 6: 160}


In [None]:
# === Patch: (re)bind torchvision.models to the global name `models` ===
import torchvision.models as models
print("[Patch] `models` is now bound to torchvision.models (e.g., resnet50 available).")


[Patch] `models` is now bound to torchvision.models (e.g., resnet50 available).


In [None]:
# === Cell CS2: Evaluate under concept-shift loaders with macro stats + plots ===
import torchvision.models as models  # keep `models` available for ResNet50Classifier
import json, numpy as np, matplotlib.pyplot as plt
from sklearn.metrics import f1_score, confusion_matrix
from pathlib import Path
import torch
import torch.nn as nn
import torchvision.models as tv_models  # just in case

OUT_SHIFT = Path(Cfg.SAVE_ROOT) / "T1_ShiftStress"
(OUT_SHIFT / "figs").mkdir(parents=True, exist_ok=True)
device = torch.device(Cfg.DEVICE)

# --- Names for classes ---
IDX2CLASS = {i: name for i, name in enumerate(label_names)} if 'label_names' in globals() and label_names else {i: f"class_{i}" for i in range(Cfg.NUM_CLASSES)}

# --- Builders that use your existing checkpoints; DAN/CDAN are optional ---
def build_ERM_clf():
    m = ResNet50Classifier(num_classes=Cfg.NUM_CLASSES, pretrained=False).to(device)
    ck = Path(Cfg.SAVE_ROOT) / "T1.1_SourceOnly_PACS_photo2sketch_ResNet50" / "ckpts" / "best_by_target.pt"
    m.load_state_dict(torch.load(ck, map_location=device)); m.eval()
    return m

def build_DANN_clf_optional():
    bb = Path(Cfg.SAVE_ROOT) / "T1.2_DANN_PACS_photo2sketch_ResNet50_memlite" / "ckpts" / "best_dann_backbone.pt"
    if not bb.exists(): return None
    feat = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
    feat.load_state_dict(torch.load(bb, map_location=device)); feat.eval()
    class Wrap(nn.Module):
        def __init__(self, feat): super().__init__(); self.backbone=feat
        def forward(self, x):
            f = self.backbone(x, return_feat=False, class_head=False)
            return self.backbone.classifier(f)
    return Wrap(feat).to(device).eval()

def build_DAN_clf_optional():
    bdir = Path(Cfg.SAVE_ROOT) / "T1.3_DAN_PACS_photo2sketch_ResNet50_memlite" / "ckpts"
    if not (bdir / "best_dan_backbone.pt").exists() or not (bdir / "best_dan_head.pt").exists():
        return None
    feat = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
    feat.load_state_dict(torch.load(bdir/"best_dan_backbone.pt", map_location=device)); feat.eval()
    class DAN_Head(nn.Module):
        def __init__(self, in_dim, num_classes, hid=256, p=0.2):
            super().__init__()
            self.bottleneck = nn.Sequential(
                nn.Linear(in_dim, hid), nn.BatchNorm1d(hid), nn.ReLU(inplace=True), nn.Dropout(p),
            )
            self.classifier = nn.Linear(hid, num_classes)
        def forward(self, f): z = self.bottleneck(f); return self.classifier(z)
    head = DAN_Head(feat.feat_dim, Cfg.NUM_CLASSES).to(device)
    head.load_state_dict(torch.load(bdir/"best_dan_head.pt", map_location=device)); head.eval()
    class Wrap(nn.Module):
        def __init__(self, feat, head): super().__init__(); self.backbone=feat; self.head=head
        def forward(self, x):
            f = self.backbone(x, return_feat=False, class_head=False)
            return self.head(f)
    return Wrap(feat, head).to(device).eval()

def build_CDAN_clf_optional():
    cdir = Path(Cfg.SAVE_ROOT) / f"T1.4_CDAN_PACS_{Cfg.SOURCE_DOMAIN}2{Cfg.TARGET_DOMAIN}_ResNet50_memlite" / "ckpts"
    if not (cdir / "best_cdan_backbone.pt").exists() or not (cdir / "best_cdan_head.pt").exists():
        return None
    feat = ResNet_Feature(num_classes=Cfg.NUM_CLASSES, backbone="resnet50", pretrained=False).to(device)
    feat.load_state_dict(torch.load(cdir/"best_cdan_backbone.pt", map_location=device)); feat.eval()
    class CDAN_Head(nn.Module):
        def __init__(self, in_dim, num_classes, hid=256, p=0.2):
            super().__init__()
            self.bottleneck = nn.Sequential(
                nn.Linear(in_dim, hid), nn.BatchNorm1d(hid), nn.ReLU(inplace=True), nn.Dropout(p),
            )
            self.classifier = nn.Linear(hid, num_classes)
        def forward(self, f): z = self.bottleneck(f); return self.classifier(z)
    head = CDAN_Head(feat.feat_dim, Cfg.NUM_CLASSES).to(device)
    head.load_state_dict(torch.load(cdir/"best_cdan_head.pt", map_location=device)); head.eval()
    class Wrap(nn.Module):
        def __init__(self, feat, head): super().__init__(); self.backbone=feat; self.head=head
        def forward(self, x):
            f = self.backbone(x, return_feat=False, class_head=False)
            return self.head(f)
    return Wrap(feat, head).to(device).eval()

# Registry (conditional for DAN/CDAN)
MODEL_REG = {"ERM": build_ERM_clf()}
maybe = build_DANN_clf_optional()
if maybe is not None: MODEL_REG["DANN"] = maybe
maybe = build_DAN_clf_optional()
if maybe is not None: MODEL_REG["DAN"] = maybe
maybe = build_CDAN_clf_optional()
if maybe is not None: MODEL_REG["CDAN"] = maybe

print("[CS2] Evaluating methods:", list(MODEL_REG.keys()))

@torch.no_grad()
def eval_full(model, loader, num_classes):
    model.eval()
    ys, yh = [], []
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        preds = logits.argmax(1)
        ys.append(yb.cpu().numpy()); yh.append(preds.cpu().numpy())
    ys = np.concatenate(ys) if ys else np.zeros(0, dtype=int)
    yh = np.concatenate(yh) if yh else np.zeros(0, dtype=int)

    acc = 100.0 * (yh == ys).mean() if ys.size else 0.0
    cm = confusion_matrix(ys, yh, labels=list(range(num_classes)))
    per_cls = {IDX2CLASS[i]: (100.0 * cm[i, i] / cm[i].sum() if cm[i].sum() > 0 else 0.0)
               for i in range(num_classes)}
    # Macro metrics
    per_acc = [(cm[i, i] / cm[i].sum()) if cm[i].sum() > 0 else 0.0 for i in range(num_classes)]
    macro_acc = 100.0 * float(np.mean(per_acc)) if per_acc else 0.0
    macro_f1  = 100.0 * f1_score(ys, yh, average="macro", labels=list(range(num_classes)), zero_division=0)
    return acc, macro_acc, macro_f1, per_cls, cm

def plot_cm(cm, title, savepath):
    plt.figure(figsize=(6,5))
    plt.imshow(cm, interpolation='nearest')
    plt.title(title); plt.colorbar()
    ticks = np.arange(Cfg.NUM_CLASSES)
    plt.xticks(ticks, [IDX2CLASS[i] for i in ticks], rotation=45, ha='right')
    plt.yticks(ticks, [IDX2CLASS[i] for i in ticks])
    plt.ylabel("True"); plt.xlabel("Predicted")
    plt.tight_layout(); plt.savefig(savepath, dpi=160); plt.close()

def class_hist(loader):
    c = np.zeros(Cfg.NUM_CLASSES, dtype=int)
    for _, yb in loader:
        for v in yb.tolist(): c[int(v)] += 1
    return c

def eval_pack(tag, loader):
    pack = {}
    for name, m in MODEL_REG.items():
        acc, macro_acc, macro_f1, per_cls, cm = eval_full(m, loader, Cfg.NUM_CLASSES)
        pack[name] = {
            "acc_overall": float(acc),
            "macro_acc": float(macro_acc),
            "macro_f1": float(macro_f1),
            "per_class_acc": {k: float(v) for k, v in per_cls.items()},
        }
        plot_cm(cm, f"{name} — {tag}", OUT_SHIFT / "figs" / f"cm_{tag}_{name}.png")
    with open(OUT_SHIFT / f"{tag}_macro_metrics.json", "w") as f:
        json.dump(pack, f, indent=2)
    print(f"[{tag}] saved macro/per-class metrics and confusion matrices.")

# Run on your *shift* loaders built in CS1
eval_pack("label_shift", tgt_shift_loader)
eval_pack("rare_stress", tgt_rare_loader)

# Heatmap: Source vs. label-shift target class counts
src_counts = class_hist(loaders["src_test"])
shift_counts = class_hist(tgt_shift_loader)
plt.figure(figsize=(6.6,3.6))
plt.imshow(np.vstack([src_counts, shift_counts]), aspect="auto")
plt.yticks([0,1], ["Source", "Target (label-shift)"])
plt.xticks(range(Cfg.NUM_CLASSES), [IDX2CLASS[i] for i in range(Cfg.NUM_CLASSES)], rotation=45, ha='right')
plt.title("Class distribution heatmap — Source vs Target (label-shift)")
plt.colorbar(label="count"); plt.tight_layout()
plt.savefig(OUT_SHIFT / "figs" / "label_shift_heatmap.png", dpi=160); plt.close()
print(f"[CS2] Saved heatmap -> {OUT_SHIFT/'figs'/'label_shift_heatmap.png'}")


[CS2] Evaluating methods: ['ERM', 'DANN', 'DAN', 'CDAN']
[label_shift] saved macro/per-class metrics and confusion matrices.
[rare_stress] saved macro/per-class metrics and confusion matrices.




[CS2] Saved heatmap -> /content/drive/MyDrive/DG_PACS/Task1/T1_ShiftStress/figs/label_shift_heatmap.png
