### Imports

In [1]:
import sys
# Set before the project-local imports below so model/config/dataset/... do
# not write __pycache__ bytecode.
sys.dont_write_bytecode = True

import json
import logging
import os
import random
import warnings
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# NOTE(review): this silences *every* warning (including torch/transformers
# deprecations); consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

from model import get_model
from config import CFG
from dataset import PACSDataset
from plot import plot_domainwise_accuracy
from transform import get_transforms
from runner import run_baseline, run_lodo

# Seed every RNG the pipeline may touch so runs are reproducible.
# torch.manual_seed seeds the CPU generator and all CUDA devices.
SEED = CFG["system"]["seed"]
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = CFG["system"]["device"]
print(f"Device: {device}")
print(f"PyTorch: {torch.__version__}")


Device: cuda
PyTorch: 2.7.1+cu128


### Data Loading

In [2]:
# Build the image transforms and the PACS dataset wrapper; the Setup cell
# later asks it for per-domain train/val dataloaders.
# NOTE(review): only `train_transform` is handed to PACSDataset, so
# `test_transform` is unused here — confirm whether evaluation loaders should
# receive it (harmless only if both transforms are identical when augment=False).
transform_kwargs = {"img_size": 224, "augment": False, "use_imagenet_norm": False}
train_transform, test_transform = get_transforms(**transform_kwargs)

pacs_root = CFG["datasets"]["PACS"]["root"]
pacs_batch_size = CFG["train"]["batch_size"]
pacs = PACSDataset(data_root=pacs_root, transform=train_transform, batch_size=pacs_batch_size)

print("\nData loaders ready!")


Data loaders ready!


### Logging

In [3]:
# Create the experiment layout <cwd>/<dataset>/{logs,checkpoints,plots} and
# send log records both to a timestamped file and to the notebook output.
dataset_name = "PACS"
base_dir = os.path.join(os.getcwd(), dataset_name)
subdirs = ["logs", "checkpoints", "plots"]

for sub in subdirs:
    os.makedirs(os.path.join(base_dir, sub), exist_ok=True)

timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_file = os.path.join(base_dir, "logs", f"train_{timestamp}.log")

# force=True: basicConfig is a silent no-op once the root logger already has
# handlers, so without it re-running this cell keeps logging to the *previous*
# file. force also closes and removes those stale handlers.
# encoding="utf-8" avoids locale-dependent encodings (e.g. cp1252 on Windows).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[
        logging.FileHandler(log_file, encoding="utf-8"),
        logging.StreamHandler(),
    ],
    force=True,
)

# Named child logger; its records propagate to the root handlers configured above.
logger = logging.getLogger(f"{dataset_name}_logger")

logger.info(f"Initialized experiment directories for {dataset_name}")
logger.info(f"Logs: {os.path.join(base_dir, 'logs')}")
logger.info(f"Checkpoints: {os.path.join(base_dir, 'checkpoints')}")
logger.info(f"Plots: {os.path.join(base_dir, 'plots')}")

2025-10-06 00:02:22,397 | INFO | Initialized experiment directories for PACS
2025-10-06 00:02:22,397 | INFO | Logs: d:\Haseeb\SPROJ\GRQO\Vit-GRQO\vit-tiny\PACS\logs
2025-10-06 00:02:22,397 | INFO | Checkpoints: d:\Haseeb\SPROJ\GRQO\Vit-GRQO\vit-tiny\PACS\checkpoints
2025-10-06 00:02:22,397 | INFO | Plots: d:\Haseeb\SPROJ\GRQO\Vit-GRQO\vit-tiny\PACS\plots


### Setup

In [4]:
# Per-domain train/val dataloaders plus the factories handed to run_lodo /
# run_baseline (presumably invoked to build a fresh model/optimizer per run).
domains = CFG["datasets"]["PACS"]["domains"]
loaders = {
    d: {
        "train": pacs.get_dataloader(d, train=True),
        "val": pacs.get_dataloader(d, train=False),
    }
    for d in domains
}

ckpt_root = os.path.join(base_dir, "checkpoints")
log_dir = os.path.join(base_dir, "logs")
plots_dir = os.path.join(base_dir, "plots")
for _dir in (ckpt_root, log_dir, plots_dir):
    os.makedirs(_dir, exist_ok=True)


def model_factory(cfg, dataset_key):
    """Build the GRQO model described by `cfg` for dataset `dataset_key`."""
    # Honour dataset_key instead of hard-coding "PACS", so the factory builds
    # the model for whichever dataset the runner asks for.
    return get_model(cfg, dataset=dataset_key)


def optimizer_fn(model):
    """Return an AdamW optimizer over all parameters of `model`.

    lr is read from CFG["train"]["lr"]; weight_decay from
    CFG["train"]["weight_decay"], defaulting to 0.01 when the key is absent.
    """
    return optim.AdamW(
        model.parameters(),
        lr=CFG["train"]["lr"],
        weight_decay=CFG["train"].get("weight_decay", 0.01),
    )


device = CFG["system"]["device"]
epochs = CFG["train"]["epochs"]


{
  "lodo_results": {
    "art_painting": 0.8341463414634146,
    "cartoon": 0.7974413646055437,
    "photo": 0.9580838323353293,
    "sketch": 0.6017811704834606
  },
  "timestamp": "20251004_020611"
}

### Leave-One-Domain-Out (LODO)

In [5]:
# GRQO model under leave-one-domain-out evaluation: each PACS domain is held
# out in turn (see run_lodo). `lodo_results` is indexed per domain by the Plot
# cell; `lodo_mean` / `lodo_summary` are presumably the mean accuracy and a
# summary artifact — confirm against run_lodo.
lodo_kwargs = dict(
    model_fn=model_factory,
    CFG=CFG,
    logger=logger,
    dataset_key="PACS",
    domains=domains,
    loaders=loaders,
    optimizer_fn=optimizer_fn,
    device=device,
    ckpt_root=ckpt_root,
    log_dir=log_dir,
    epochs=epochs,
)
lodo_results, lodo_mean, lodo_summary = run_lodo(**lodo_kwargs)

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-10-04 04:41:12,803 | INFO | 
=== LODO: Leaving out domain 'art_painting' ===
Evaluating: 100%|██████████| 4/4 [00:11<00:00,  2.82s/it]
2025-10-04 04:41:40,219 | INFO | [art_painting] Epoch 1/5 | Train - Loss: 1.3446, Cls: 1.3272, GRQO: 0.0173, Acc: 0.5464 | Val - Loss: 1.0757, Cls: 1.0643, GRQO: 0.0114, Acc: 0.6561
2025-10-04 04:41:40,259 | INFO | [art_painting] New best val acc: 0.6561
Evaluating: 100%|██████████| 4/4 [00:11<00:00,  2.80s/it]
2025-10-04 04:42:07,319 | INFO | [art_painting] Epoch 2/5 | Train - Loss: 0.4557, Cls: 0.4497, GRQO: 0.0060, Acc: 0.8484 | Val - Loss: 0.7272, Cls: 0.7218, GRQO: 0.0054, Acc: 0.7610
2025-10-04 04:42:07,370 | INFO | [art_painting] New best val acc: 0.7610
Eva

### Baseline

In [6]:
# Baseline: the plain pretrained ViT classifier, run through run_baseline over
# the same domains/loaders/optimizer as the GRQO LODO run so that the
# per-domain accuracies compared in the Plot cell are like-for-like.
model_name = "WinKawaks/vit-tiny-patch16-224"
baseline_results, baseline_mean = run_baseline(
    model_name=model_name,
    CFG=CFG,
    logger=logger,
    dataset_key="PACS",
    domains=domains,
    loaders=loaders,
    optimizer_fn=optimizer_fn,
    device=device,
    # Reuse the shared `epochs` (as the GRQO run does) so both models always
    # get the same training budget, even if `epochs` is overridden in Setup.
    epochs=epochs,
)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-10-04 04:51:58,600 | INFO | 
=== Baseline LODO: Leaving out domain 'art_painting' ===
2025-10-04 04:52:35,525 | INFO | [art_painting] Epoch 1/5 | Train - Loss: 0.6406, Acc: 0.7664 | Val Acc: 0.7927
2025-10-04 04:53:17,746 | INFO | [art_painting] Epoch 2/5 | Train - Loss: 0.1495, Acc: 0.9495 | Val Acc: 0.7951
2025-10-04 04:53:49,089 | INFO | [art_painting] Epoch 3/5 | Train - Loss: 0.0405, Acc: 0.9899 | Val Acc: 0.8195
2025-10-04 04:54:25,701 | INFO | [art_p

### Plot

In [None]:
# Side-by-side per-domain held-out accuracy: ViT baseline vs. GRQO.
# `lodo_results` must be the main GRQO LODO results from the cell above.
baseline_vals = [baseline_results[d] for d in domains]
grqo_vals = [lodo_results[d] for d in domains]

# `datetime` is the class imported in the first cell; no need for __import__.
plot_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
plot_path = os.path.join(plots_dir, f"comparison_{plot_stamp}.png")
plot_domainwise_accuracy(domains, baseline_vals, grqo_vals, save_path=plot_path)
logger.info(f"Saved comparison plot to {plot_path}")

### Ablation 1: GRQO heads / layers / tokens grid

In [5]:
# Cell: run_ablation_reduced
# GRQO hyperparameter ablation: every (heads, layers, tokens) combination is
# evaluated with the full LODO protocol. Results are re-saved to JSON after
# each run, so an interrupt or crash loses at most the run in progress.
from itertools import product
import copy
import time

heads_list = [4, 6, 8]          # num_heads
layers_list = [2, 4, 6]         # num_layers
tokens_list = [16, 24, 32, 48]  # num_tokens (also used as topk)
grid = list(product(heads_list, layers_list, tokens_list))  # 3 * 3 * 4 = 36 runs

results = []
timestamp = time.strftime("%Y%m%d_%H%M%S")
ablation_file = os.path.join(log_dir, f"ablation_results_{timestamp}.json")


def _dump_results(records, path):
    """Atomically write `records` to `path` as indented JSON.

    Writes a sibling temp file and os.replace()s it over `path`, so an
    interrupt mid-write never leaves a truncated results file. default=str
    stringifies values json cannot encode natively (e.g. pathlib paths,
    numpy float32) instead of aborting the dump.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(records, f, indent=2, default=str)
    os.replace(tmp_path, path)


logger.info(f"Ablation grid: {len(grid)} runs")
for heads, layers, tokens in grid:
    cfg_run = copy.deepcopy(CFG)
    cfg_run["grqo"]["num_heads"] = heads
    cfg_run["grqo"]["num_layers"] = layers
    cfg_run["grqo"]["num_tokens"] = tokens
    cfg_run["grqo"]["topk"] = tokens

    run_name = f"h{heads}_l{layers}_t{tokens}_{timestamp}"
    logger.info(f"Ablation START {run_name}")
    record = {"run": run_name, "heads": heads, "layers": layers, "tokens": tokens}

    try:
        # Run-local name on purpose: assigning to the global `lodo_results`
        # here would clobber the main GRQO results read by the Plot cell.
        run_results, mean_acc, summary_path = run_lodo(
            model_fn=model_factory,
            CFG=cfg_run,
            logger=logger,
            dataset_key="PACS",
            domains=domains,
            loaders=loaders,
            optimizer_fn=optimizer_fn,
            device=device,
            ckpt_root=ckpt_root,
            log_dir=log_dir,
            epochs=epochs
        )
        record.update(mean_acc=mean_acc, lodo_results=run_results, summary=summary_path)
        logger.info(f"Ablation DONE {run_name} mean_acc={mean_acc:.4f}")

    except Exception as e:
        # Deliberate best-effort sweep: record the failure and continue with
        # the next configuration.
        record["error"] = str(e)
        logger.exception(f"Ablation FAIL {run_name}")

    results.append(record)
    _dump_results(results, ablation_file)

logger.info(f"Ablation finished | Results saved to {ablation_file}")


2025-10-04 05:09:25,124 | INFO | ABlation START h4_l2_t16_20251004_050925
Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2025-10-04 05:09:27,355 | INFO | 
=== LODO: Leaving out domain 'art_painting' ===
Evaluating: 100%|██████████| 4/4 [00:11<00:00,  2.82s/it]
2025-10-04 05:09:56,655 | INFO | [art_painting] Epoch 1/5 | Train - Loss: 1.2713, Cls: 1.2511, GRQO: 0.0202, Acc: 0.5840 | Val - Loss: 1.0274, Cls: 1.0149, GRQO: 0.0125, Acc: 0.6537
2025-10-04 05:09:56,691 | INFO | [art_painting] New best val acc: 0.6537
Evaluating: 100%|██████████| 4/4 [00:11<00:00,  2.79s/it]
2025-10-04 05:10:23,537 | INFO | [art_painting] Epoch 2/5 | Train - Loss: 0.4615, Cls: 0.4545, GRQO: 0.0070, Acc: 0.8464 | Val - Loss: 0.7127, Cls: 0.7065, GRQO: 0.0063, Acc: 0.7537
20