### Imports

In [1]:
import sys
sys.dont_write_bytecode = True

import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm
import warnings
import logging
from datetime import datetime
warnings.filterwarnings('ignore')

from model import get_model
from config import CFG
from dataset import PACSDataset
from plot import plot_domainwise_accuracy
from transform import get_transforms
from runner import run_baseline, run_lodo

# Seed PyTorch's and NumPy's global RNGs from the one configured seed value.
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(CFG["system"]["seed"])

# The device comes straight from the config; report it together with the
# installed PyTorch version.
device = CFG["system"]["device"]
print(f"Device: {device}", f"PyTorch: {torch.__version__}", sep="\n")


Device: cuda
PyTorch: 2.7.1+cu128


### Data Loading

In [2]:
# Build the transform pair at 224px with augmentation and ImageNet
# normalisation both switched off.
train_transform, test_transform = get_transforms(
    img_size=224,
    augment=False,
    use_imagenet_norm=False,
)

# NOTE(review): only `train_transform` is handed to the dataset here;
# `test_transform` is not used in this cell -- confirm that is intended.
pacs = PACSDataset(
    batch_size=CFG["train"]["batch_size"],
    transform=train_transform,
    data_root=CFG["datasets"]["PACS"]["root"],
)

print("\nData loaders ready!")


Data loaders ready!


### Logging

In [3]:
dataset_name = "PACS"
base_dir = os.path.join(os.getcwd(), dataset_name)
subdirs = ["logs", "checkpoints", "plots"]

# Create <cwd>/PACS/{logs,checkpoints,plots}; exist_ok makes re-runs harmless.
for sub in subdirs:
    os.makedirs(os.path.join(base_dir, sub), exist_ok=True)

# One log file per execution of this cell, named by its start time.
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file = os.path.join(base_dir, "logs", f"train_{timestamp}.log")

# force=True removes and closes any handlers already attached to the root
# logger before installing the ones below. Without it, basicConfig silently
# does nothing when the root logger is already configured (for example when
# this cell is re-run in the same kernel), so the newly named log file would
# never receive any records.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ],
    force=True,
)

# Named logger; it has no handlers of its own and relies on the root
# configuration above.
logger = logging.getLogger(f"{dataset_name}_logger")

logger.info(f"Initialized experiment directories for {dataset_name}")
logger.info(f"Logs: {os.path.join(base_dir, 'logs')}")
logger.info(f"Checkpoints: {os.path.join(base_dir, 'checkpoints')}")
logger.info(f"Plots: {os.path.join(base_dir, 'plots')}")

2025-10-05 21:53:29,859 | INFO | Initialized experiment directories for PACS
2025-10-05 21:53:29,859 | INFO | Logs: d:\Haseeb\SPROJ\GRQO\Vit-GRQO\resnet\PACS\logs
2025-10-05 21:53:29,859 | INFO | Checkpoints: d:\Haseeb\SPROJ\GRQO\Vit-GRQO\resnet\PACS\checkpoints
2025-10-05 21:53:29,859 | INFO | Plots: d:\Haseeb\SPROJ\GRQO\Vit-GRQO\resnet\PACS\plots


### Setup

In [4]:
domains = CFG["datasets"]["PACS"]["domains"]

# For every domain, request one loader with train=True and one with
# train=False from the dataset wrapper, keyed by domain name.
loaders = {
    d: {
        "train": pacs.get_dataloader(d, train=True),
        "val": pacs.get_dataloader(d, train=False),
    }
    for d in domains
}

# Output locations under the experiment directory; created if missing.
ckpt_root = os.path.join(base_dir, "checkpoints")
log_dir = os.path.join(base_dir, "logs")
plots_dir = os.path.join(base_dir, "plots")
for _out_dir in (ckpt_root, log_dir, plots_dir):
    os.makedirs(_out_dir, exist_ok=True)


def model_factory(cfg, dataset_key):
    """Build a model from ``cfg``.

    ``dataset_key`` is accepted but not used: the dataset passed to
    ``get_model`` is always "PACS".
    """
    return get_model(cfg, dataset="PACS")


def optimizer_fn(model):
    """Return an AdamW optimizer over ``model.parameters()``.

    Learning rate and weight decay are read from the global ``CFG["train"]``
    at call time; weight decay falls back to 0.01 when the key is absent.
    """
    return optim.AdamW(
        model.parameters(),
        lr=CFG["train"]["lr"],
        weight_decay=CFG["train"].get("weight_decay", 0.01),
    )


device = CFG["system"]["device"]
epochs = CFG["train"]["epochs"]


{
  "lodo_results": {
    "art_painting": 0.8341463414634146,
    "cartoon": 0.7974413646055437,
    "photo": 0.9580838323353293,
    "sketch": 0.6017811704834606
  },
  "timestamp": "20251004_020611"
}

### Leave One Domain Out

In [6]:
# Leave-one-domain-out run on PACS using the model/optimizer factories and
# loaders prepared in the Setup cell.
# `lodo_results` is indexed by domain name further down (Plot cell); the other
# two return values are presumably the mean accuracy and a summary artefact
# -- confirm against runner.run_lodo.
lodo_results, lodo_mean, lodo_summary = run_lodo(
    # what to train
    dataset_key="PACS",
    CFG=CFG,
    model_fn=model_factory,
    optimizer_fn=optimizer_fn,
    # data
    domains=domains,
    loaders=loaders,
    # execution
    device=device,
    epochs=epochs,
    # outputs
    ckpt_root=ckpt_root,
    log_dir=log_dir,
    logger=logger,
)

2025-10-05 21:38:59,990 | INFO | 
=== LODO: Leaving out domain 'art_painting' ===
Evaluating: 100%|██████████| 7/7 [00:11<00:00,  1.57s/it]
2025-10-05 21:39:27,405 | INFO | [art_painting] Epoch 1/5 | Train - Loss: 0.7952, Cls: 0.7840, GRQO: 0.0111, Acc: 0.7386 | Val - Loss: 1.1947, Cls: 1.1929, GRQO: 0.0018, Acc: 0.6512
2025-10-05 21:39:27,471 | INFO | [art_painting] New best val acc: 0.6512
Evaluating: 100%|██████████| 7/7 [00:11<00:00,  1.58s/it]
2025-10-05 21:39:54,605 | INFO | [art_painting] Epoch 2/5 | Train - Loss: 0.1377, Cls: 0.1352, GRQO: 0.0025, Acc: 0.9533 | Val - Loss: 1.0999, Cls: 1.0988, GRQO: 0.0011, Acc: 0.6927
2025-10-05 21:39:54,671 | INFO | [art_painting] New best val acc: 0.6927
Evaluating: 100%|██████████| 7/7 [00:11<00:00,  1.61s/it]
2025-10-05 21:40:22,005 | INFO | [art_painting] Epoch 3/5 | Train - Loss: 0.0330, Cls: 0.0309, GRQO: 0.0021, Acc: 0.9912 | Val - Loss: 0.7521, Cls: 0.7514, GRQO: 0.0008, Acc: 0.7976
2025-10-05 21:40:22,078 | INFO | [art_painting] New 

### Baseline

In [6]:
# Baseline run over the same domains and loaders as the LODO cell, using the
# backbone named below.
model_name = "resnet18"

# `baseline_results` is indexed by domain name in the Plot cell; the second
# value is presumably the mean accuracy -- confirm against runner.run_baseline.
baseline_results, baseline_mean = run_baseline(
    # what to train
    model_name=model_name,
    dataset_key="PACS",
    CFG=CFG,
    optimizer_fn=optimizer_fn,
    # data
    domains=domains,
    loaders=loaders,
    # execution
    device=device,
    epochs=CFG["train"]["epochs"],
    logger=logger,
)


2025-10-05 21:14:51,274 | INFO | Initializing ResNet baseline: resnet18
2025-10-05 21:14:51,363 | INFO | 
=== Baseline LODO: Leaving out domain 'art_painting' ===
2025-10-05 21:15:18,034 | INFO | [art_painting] Epoch 1/5 | Train - Loss: 0.5796, Acc: 0.8195 | Val Acc: 0.6902
2025-10-05 21:15:44,760 | INFO | [art_painting] Epoch 2/5 | Train - Loss: 0.1104, Acc: 0.9720 | Val Acc: 0.6415
2025-10-05 21:16:11,293 | INFO | [art_painting] Epoch 3/5 | Train - Loss: 0.0267, Acc: 0.9972 | Val Acc: 0.7341
2025-10-05 21:16:38,164 | INFO | [art_painting] Epoch 4/5 | Train - Loss: 0.0097, Acc: 0.9997 | Val Acc: 0.7195
2025-10-05 21:17:05,047 | INFO | [art_painting] Epoch 5/5 | Train - Loss: 0.0051, Acc: 1.0000 | Val Acc: 0.7390
2025-10-05 21:17:05,047 | INFO | [art_painting] Best Val Acc: 0.7390
------------------------------------------------------------
2025-10-05 21:17:05,047 | INFO | Initializing ResNet baseline: resnet18
2025-10-05 21:17:05,142 | INFO | 
=== Baseline LODO: Leaving out domain 'ca

### Plot

In [None]:
# Per-domain accuracies, both listed in the order of `domains` so the two
# series line up in the plot.
baseline_vals = [baseline_results[d] for d in domains]
grqo_vals = [lodo_results[d] for d in domains]

# Timestamped file name so earlier comparison plots are not overwritten.
# Uses the `datetime` class imported in the Imports cell instead of an inline
# __import__ call.
plot_path = os.path.join(
    plots_dir,
    f"comparison_{datetime.now().strftime('%Y%m%d_%H%M%S')}.png",
)
plot_domainwise_accuracy(domains, baseline_vals, grqo_vals, save_path=plot_path)
logger.info(f"Saved comparison plot to {plot_path}")

### Ablation1

In [None]:
# Cell: run_ablation_reduced
# Grid sweep over GRQO hyper-parameters. Each combination runs a full
# leave-one-domain-out pass via run_lodo, and the accumulated results are
# rewritten to one JSON file after every run.
from itertools import product
import copy
import json
import time
import os

# Reduced grid: 3 head counts x 3 layer counts x 4 token counts = 36 runs
heads_list = [4, 6, 8]        # num_heads
layers_list = [2, 4, 6]          # num_layers
tokens_list = [16, 24, 32, 48]  # num_tokens / topk

results = []
timestamp = time.strftime("%Y%m%d_%H%M%S")
ablation_file = os.path.join(log_dir, f"ablation_results_{timestamp}.json")

# Iterate over combinations
for heads, layers, tokens in product(heads_list, layers_list, tokens_list):
    # Deep copy so each run's overrides never leak into the shared CFG.
    cfg_run = copy.deepcopy(CFG)
    cfg_run["grqo"]["num_heads"] = heads
    cfg_run["grqo"]["num_layers"] = layers
    cfg_run["grqo"]["num_tokens"] = tokens
    cfg_run["grqo"]["topk"] = tokens  # topk is tied to the token count

    run_name = f"h{heads}_l{layers}_t{tokens}_{timestamp}"
    logger.info(f"ABlation START {run_name}")

    try:
        # Bound to a loop-local name on purpose: assigning to `lodo_results`
        # here would overwrite the main LODO results that the Plot cell reads.
        # NOTE(review): every run receives the same ckpt_root; if run_lodo
        # names checkpoints by domain only, runs overwrite one another --
        # confirm against runner.run_lodo.
        run_results, mean_acc, summary_path = run_lodo(
            model_fn=model_factory,
            CFG=cfg_run,
            logger=logger,
            dataset_key="PACS",
            domains=domains,
            loaders=loaders,
            optimizer_fn=optimizer_fn,
            device=device,
            ckpt_root=ckpt_root,
            log_dir=log_dir,
            epochs=epochs
        )

        results.append({
            "run": run_name,
            "heads": heads,
            "layers": layers,
            "tokens": tokens,
            "mean_acc": mean_acc,
            "lodo_results": run_results,
            "summary": summary_path
        })
        logger.info(f"ABlation DONE {run_name} mean_acc={mean_acc:.4f}")

    except Exception as e:
        # Record the failure and continue with the next combination; the
        # traceback goes to the log through logger.exception.
        results.append({
            "run": run_name,
            "heads": heads,
            "layers": layers,
            "tokens": tokens,
            "error": str(e)
        })
        logger.exception(f"ABlation FAIL {run_name}")

    # Save intermediate results to JSON after every run. This write is
    # outside the try block, so default=str is used as a fallback: a value
    # json cannot encode natively is written as its str() instead of raising
    # and aborting the rest of the sweep.
    with open(ablation_file, "w") as f:
        json.dump(results, f, indent=2, default=str)

logger.info(f"Ablation finished | Results saved to {ablation_file}")


2025-10-05 21:53:33,306 | INFO | ABlation START h4_l2_t16_20251005_215333
2025-10-05 21:53:33,492 | INFO | 
=== LODO: Leaving out domain 'art_painting' ===
Evaluating: 100%|██████████| 4/4 [00:12<00:00,  3.03s/it]
2025-10-05 21:54:02,424 | INFO | [art_painting] Epoch 1/5 | Train - Loss: 0.8215, Cls: 0.8096, GRQO: 0.0120, Acc: 0.7302 | Val - Loss: 0.9730, Cls: 0.9687, GRQO: 0.0042, Acc: 0.7024
2025-10-05 21:54:02,489 | INFO | [art_painting] New best val acc: 0.7024
Evaluating: 100%|██████████| 4/4 [00:11<00:00,  2.95s/it]
2025-10-05 21:54:30,536 | INFO | [art_painting] Epoch 2/5 | Train - Loss: 0.1867, Cls: 0.1828, GRQO: 0.0039, Acc: 0.9394 | Val - Loss: 0.9185, Cls: 0.9164, GRQO: 0.0021, Acc: 0.7122
2025-10-05 21:54:30,587 | INFO | [art_painting] New best val acc: 0.7122
Evaluating: 100%|██████████| 4/4 [00:11<00:00,  2.88s/it]
2025-10-05 21:54:58,175 | INFO | [art_painting] Epoch 3/5 | Train - Loss: 0.0503, Cls: 0.0474, GRQO: 0.0030, Acc: 0.9876 | Val - Loss: 1.0418, Cls: 1.0406, GRQO