### Imports

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
from dataset import PACSDataset
from vit_grqo import ViTGRQO, grqo_loss_from_gradients
from encoder_decoder_vit import VisualDecoder, MultiheadAttn, DecoderAttn
from Visual_query_heads import QueryLosses, GRQO

### CONFIG 

In [None]:
import os
from datasets import load_dataset
from PIL import Image

# --- Dataset layout -------------------------------------------------------
DATA_ROOT = "../../../pacs_data/pacs_data"
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]
CLASSES = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
NUM_CLASSES = len(CLASSES)  # 7 classes; derived so it cannot drift from CLASSES

# --- Optimisation schedule ------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 5
LR = 1e-4

# --- GRQO-related values --------------------------------------------------
# ALPHA, BETA and TAU are passed to the GRQO constructor in the model cell.
# TOPK is not referenced anywhere else in the visible notebook.
TOPK = 24
ALPHA = 2.0
BETA = 0.5
TAU = 1e-3

# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dataset = load_dataset("flwrlabs/pacs", split="train")

# os.makedirs(DATA_ROOT, exist_ok=True)
# for domain in DOMAINS:
#     for cls in CLASSES:
#         os.makedirs(f"{DATA_ROOT}/{domain}/{cls}", exist_ok=True)

# for i, example in enumerate(dataset):
#     domain = example["domain"]  
#     label_idx = example["label"]  
#     label = CLASSES[label_idx]

#     if domain not in DOMAINS:
#         raise ValueError(f"Unexpected domain: {domain}. Expected one of {DOMAINS}")
#     if label not in CLASSES:
#         raise ValueError(f"Unexpected label: {label}. Expected one of {CLASSES}")
    
#     image = example["image"]
#     image.save(f"{DATA_ROOT}/{domain}/{label}/image_{i}.jpg")


### Data

In [None]:
# The backbone is loaded with torchvision's IMAGENET1K_V1 weights, which were
# trained on inputs normalised with the ImageNet channel statistics below.
# Feeding it images normalised with mean=std=0.5 shifts the input distribution
# away from what the pretrained filters expect, which hurts most when the
# backbone is kept frozen.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
])
pacs_data = PACSDataset(DATA_ROOT, DOMAINS, TRANSFORM, BATCH_SIZE)

# Default split: hold out one domain for validation, train on the other three.
# The LODO cells further down rebuild these loaders for every held-out domain.
ALL_DOMAINS = ['photo', 'art_painting', 'cartoon', 'sketch']
LEAVE_OUT = 'sketch'
TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
VAL_DOMAIN = LEAVE_OUT

# get_dataloader returns one loader per domain; their underlying datasets are
# concatenated so a single shuffled loader mixes all training domains.
train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation loader
val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

### Model 

In [None]:
from transformers import ViTModel, AutoFeatureExtractor
import torch
import torch.nn as nn

#  resnet18 Backbone ---
# ImageNet-pretrained ResNet-18 with its last two children (the global average
# pool and the fc classifier) dropped, so it returns a spatial feature map.
model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
backbone = torch.nn.Sequential(*list(model.children())[:-2])

# Probe the backbone once with a dummy image to read off its channel width.
# The probe is run in eval mode: torch.no_grad() only disables autograd, it
# does NOT stop BatchNorm layers from updating their running statistics, so a
# train-mode pass over random noise would perturb the pretrained statistics.
dummy = torch.randn(1,3,224,224)
was_training = backbone.training
backbone.eval()
with torch.no_grad():
    dummy_out = backbone(dummy)
backbone.train(was_training)  # restore whatever mode the backbone was in
_,HIDDEN_DIM,_,_ = dummy_out.shape



# GRQO Hyperparameters ---
# Every value below is forwarded by keyword to the GRQO constructor.
# ALPHA, BETA, TAU and NUM_CLASSES repeat the values set in the config cell.
DECODER_DIM = 192
NUM_HEADS = 6
NUM_LAYERS = 3
NUM_TOKENS = 32
DROPOUT = 0.1
DDROPOUT = 0.1
TEMPERATURE = 0.1
ALPHA = 2.0
BETA = 0.5
TAU = 1e-3
LAMBDA_GRQO = 1.0
TEACHER_EMA = 0.99
NUM_CLASSES = 7

# GRQO Decoder ---
# Collect the constructor arguments in one mapping, then build the module.
grqo_kwargs = dict(
    Hidden_dim=DECODER_DIM,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    num_tokens=NUM_TOKENS,
    ddropout=DDROPOUT,
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES,
    temperature=TEMPERATURE,
    alpha=ALPHA,
    beta=BETA,
    tau=TAU,
    lambda_grqo=LAMBDA_GRQO,
    teacher_ema=TEACHER_EMA,
    reward_proxy="taylor",
)
grqo_model = GRQO(**grqo_kwargs)

# Full Model ---
class resnetGRQO(nn.Module):
    """Convolutional backbone followed by a linear projection and a GRQO head.

    The backbone's feature map is flattened into one token per spatial
    location, each token is projected to ``proj_dim`` channels, and the token
    sequence is handed to ``grqo_model`` together with the labels.

    Args:
        backbone: module mapping an image batch to a 4-D feature map whose
            second dimension has size ``HD``.
        grqo_model: module called as ``grqo_model(tokens, labels)``; its
            return value is returned unchanged by ``forward``.
        HD: channel width of the backbone's feature map.
        proj_dim: width of the projected tokens. Defaults to 192, the value
            that was previously hard-coded; it has to match the hidden size
            the GRQO module was built with.
    """

    def __init__(self, backbone, grqo_model, HD, proj_dim=192):
        super().__init__()
        self.backbone = backbone
        self.grqo = grqo_model
        self.projection_head = nn.Linear(HD, proj_dim)

    def forward(self, x, labels=None):
        """Run backbone -> token projection -> GRQO and return GRQO's output.

        Raises:
            ValueError: if the backbone output is not a 4-D tensor.
        """
        outputs = self.backbone(x)
        if outputs.dim() != 4:
            raise ValueError(
                f"backbone must return a 4-D feature map, got shape {tuple(outputs.shape)}"
            )
        # (B, C, H, W) -> (B, H*W, C): one token per spatial location.
        outputs = outputs.flatten(2).transpose(1,2)
        tokens = self.projection_head(outputs)
        return self.grqo(tokens, labels)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assemble the full network and move it to the target device. Until this point
# `model` referred to the bare torchvision ResNet-18, so an optimizer built
# from `model.parameters()` would never see the projection head or the GRQO
# parameters.
model = resnetGRQO(backbone, grqo_model, HIDDEN_DIM).to(device)

# Optimizer ---
# Materialise the trainable parameters as a list (a `filter` object can only be
# iterated once).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)

print(f"Model initialized on {device}")
print(f"Hidden dim: {HIDDEN_DIM}, Num tokens: {NUM_TOKENS}")

### Training

In [None]:
def train_epoch(model, train_loader, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    The model is called as ``model(images, labels)`` and must return a mapping
    with the keys ``'loss'``, ``'cls_loss'``, ``'grqo_loss'`` and ``'preds'``.
    ``'loss'`` is back-propagated and the optimizer is stepped once per batch.
    A progress line is printed every 50 batches.

    Returns:
        Tuple ``(avg_loss, avg_cls_loss, avg_grqo_loss, accuracy)``, where the
        losses are sample-weighted means over the epoch.
    """
    model.train()

    # Sample-weighted running sums, keyed like the model's output dict.
    running = {'loss': 0.0, 'cls_loss': 0.0, 'grqo_loss': 0.0}
    n_correct = 0
    n_seen = 0

    for batch_idx, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Forward pass, then pull out everything needed for logging.
        output = model(images, labels)
        loss = output['loss']
        cls_loss = output['cls_loss']
        grqo_loss = output['grqo_loss']
        preds = output['preds']

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so a short last batch is not
        # over-represented in the epoch average.
        n_in_batch = images.size(0)
        running['loss'] += loss.item() * n_in_batch
        running['cls_loss'] += cls_loss.item() * n_in_batch
        running['grqo_loss'] += grqo_loss.item() * n_in_batch
        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)

        if batch_idx % 50 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}, '
                  f'Cls: {cls_loss.item():.4f}, GRQO: {grqo_loss.item():.4f}')

    return (
        running['loss'] / n_seen,
        running['cls_loss'] / n_seen,
        running['grqo_loss'] / n_seen,
        n_correct / n_seen,
    )

### Evaluation

In [None]:
def validate(model, val_loader, device):
    """Return the top-1 accuracy of ``model`` over ``val_loader``.

    The model is switched to eval mode and called as ``model(images, labels)``;
    only the ``'preds'`` entry of its output is used.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Autograd is explicitly kept on for the forward pass: per the original
        # author's note, GRQO needs gradients even during validation.
        with torch.set_grad_enabled(True):
            preds = model(images, labels)['preds']

        n_correct += (preds == labels).sum().item()
        n_seen += labels.size(0)

    return n_correct / n_seen

### Fine-tuning (leave-one-domain-out)

In [None]:
# Leave-one-domain-out (LODO) fine-tuning with a trainable backbone.
# For every domain: train on the other three, validate on the held-out one.
ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Create data loaders
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Build a fresh model for this split. Wrapping the notebook-level
    # `backbone` / `grqo_model` objects again would carry the weights learned
    # on earlier splits (which trained on the domain held out now) into this
    # one, so the held-out accuracy would no longer be out-of-domain.
    split_resnet = torchvision.models.resnet18(weights='IMAGENET1K_V1')
    split_backbone = nn.Sequential(*list(split_resnet.children())[:-2])
    split_grqo = GRQO(
        Hidden_dim=DECODER_DIM,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_tokens=NUM_TOKENS,
        ddropout=DDROPOUT,
        num_layers=NUM_LAYERS,
        num_classes=NUM_CLASSES,
        temperature=TEMPERATURE,
        alpha=ALPHA,
        beta=BETA,
        tau=TAU,
        lambda_grqo=LAMBDA_GRQO,
        teacher_ema=TEACHER_EMA,
        reward_proxy="taylor"
    )
    # DEVICE is also where train_epoch/validate send the batches.
    model = resnetGRQO(split_backbone, split_grqo, HIDDEN_DIM).to(DEVICE)

    FREEZE = False
    if FREEZE:
        for param in model.backbone.parameters():
            param.requires_grad = False
        print("backbone frozen")
    else:
        print("backbone trainable")

    # Ensure GRQO is trainable
    for param in model.grqo.parameters():
        param.requires_grad = True

    # The optimizer has to be created AFTER the model (and after freezing):
    # an optimizer only updates the parameter objects it was constructed with,
    # so reusing one from an earlier cell would leave this split's projection
    # head and GRQO weights untouched.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR, weight_decay=0.01
    )

    # NOTE: the best epoch is picked on the held-out domain itself, so the
    # reported number is an optimistic estimate of out-of-domain accuracy.
    best_val_acc = 0.0

    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

        # Train
        train_loss, train_cls, train_grqo, train_acc = train_epoch(model, train_loader, optimizer, DEVICE)

        # Validate
        val_acc = validate(model, val_loader, DEVICE)

        print(f"Train - Loss: {train_loss:.4f}, Cls: {train_cls:.4f}, "
              f"GRQO: {train_grqo:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc

    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")

# Results Summary
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average':15}: {avg_lodo:.4f}")
print("="*50)

In [None]:
# Leave-one-domain-out (LODO) training with a frozen backbone: only the
# projection head and the GRQO module are optimised.
ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Create data loaders
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Reset model for this split. The backbone and the GRQO module are rebuilt
    # from scratch: re-wrapping the notebook-level `backbone` / `grqo_model`
    # objects would keep the weights trained on earlier splits, which have
    # already seen the domain that is held out now.
    split_resnet = torchvision.models.resnet18(weights='IMAGENET1K_V1')
    split_backbone = nn.Sequential(*list(split_resnet.children())[:-2])
    split_grqo = GRQO(
        Hidden_dim=DECODER_DIM,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_tokens=NUM_TOKENS,
        ddropout=DDROPOUT,
        num_layers=NUM_LAYERS,
        num_classes=NUM_CLASSES,
        temperature=TEMPERATURE,
        alpha=ALPHA,
        beta=BETA,
        tau=TAU,
        lambda_grqo=LAMBDA_GRQO,
        teacher_ema=TEACHER_EMA,
        reward_proxy="taylor"
    )
    # DEVICE is also where train_epoch/validate send the batches.
    model = resnetGRQO(split_backbone, split_grqo, HIDDEN_DIM).to(DEVICE)

    FREEZE = True
    if FREEZE:
        # NOTE: this only stops gradient updates. train_epoch() still calls
        # model.train(), so BatchNorm layers inside the backbone keep updating
        # their running statistics during training.
        for param in model.backbone.parameters():
            param.requires_grad = False
        print("backbone frozen")
    else:
        print("backbone trainable")

    # Built after freezing so that only trainable parameters are registered.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR, weight_decay=0.01
    )

    # NOTE: the best epoch is picked on the held-out domain itself, so the
    # reported number is an optimistic estimate of out-of-domain accuracy.
    best_val_acc = 0.0

    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

        # Train
        train_loss, train_cls, train_grqo, train_acc = train_epoch(model, train_loader, optimizer, DEVICE)

        # Validate
        val_acc = validate(model, val_loader, DEVICE)

        print(f"Train - Loss: {train_loss:.4f}, Cls: {train_cls:.4f}, "
              f"GRQO: {train_grqo:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc

    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")

# Results Summary
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average':15}: {avg_lodo:.4f}")
print("="*50)

In [None]:
print("\n=== Baseline: Train & Validate on All Domains ===")

ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
# Train loader with all domains
all_train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in ALL_DOMAINS]
train_loader_all = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in all_train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation loaders (per domain and all combined)
val_loaders_per_domain = {d: pacs_data.get_dataloader(domain=d, train=False) for d in ALL_DOMAINS}
all_val_datasets = [pacs_data.get_dataloader(domain=d, train=False) for d in ALL_DOMAINS]
val_loader_all = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in all_val_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=False
)


def baseline_accuracy(net, loader):
    """Top-1 accuracy of the plain classifier ``net`` over ``loader``.

    ``net`` is called as ``net(images)`` and must return a logits tensor.
    Returns 0.0 when the loader yields no samples.
    """
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            preds = net(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total > 0 else 0.0


# Model (no freezing)
# torchvision's ResNet ships with a 1000-way ImageNet classifier, so swap in a
# head with one output per PACS class before fine-tuning.
baseline_model = torchvision.models.resnet18(weights='IMAGENET1K_V1')
baseline_model.fc = nn.Linear(baseline_model.fc.in_features, NUM_CLASSES)
baseline_model = baseline_model.to(DEVICE)

optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=LR, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0
best_val_per_domain = {}

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # ---- Train ----
    baseline_model.train()
    running_loss, running_corrects, running_samples = 0.0, 0, 0

    for images, labels in train_loader_all:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        # torchvision models return the logits tensor directly; there is no
        # `.logits` attribute as on Hugging Face model outputs.
        logits = baseline_model(images)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        running_loss += loss.item() * images.size(0)
        running_corrects += (preds == labels).sum().item()
        running_samples += labels.size(0)

    train_loss = running_loss / running_samples
    train_acc = running_corrects / running_samples

    # ---- Validation (all domains combined) ----
    val_acc = baseline_accuracy(baseline_model, val_loader_all)

    # ---- Validation (per domain) ----
    per_domain_accs = {
        domain: baseline_accuracy(baseline_model, loader)
        for domain, loader in val_loaders_per_domain.items()
    }

    # Print results
    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"Val Acc (All Domains Combined): {val_acc:.4f}")
    for domain, acc in per_domain_accs.items():
        print(f"  {domain:15}: {acc:.4f}")

    # Track best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_per_domain = per_domain_accs.copy()

# ---------------- Results Summary ----------------
# `lodo_results` comes from the most recently executed LODO cell above.
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average LODO':15}: {avg_lodo:.4f}")
print(f"{'All-domains baseline':15}: {best_val_acc:.4f}")
for domain, acc in best_val_per_domain.items():
    print(f"  {domain:15}: {acc:.4f}")
print("="*50)


In [None]:
import matplotlib.pyplot as plt

# Accuracies to plot. These are hard-coded values, not read from the
# `lodo_results` computed above.
results = {
    "photo": 0.9581,
    "art_painting": 0.7293,
    "cartoon": 0.7186,
    "sketch": 0.3461,
    "All-domains baseline": 0.9010
}

# Tick labels: wrap the long baseline label onto two lines and show
# underscores as spaces for the domain names.
labels = [
    "All-domains\nbaseline" if "All-domains" in name else name.replace("_", " ")
    for name in results
]
values = list(results.values())

# One bar per entry; the last (grey) colour marks the baseline.
fig, ax = plt.subplots(figsize=(9, 5))
x = range(len(labels))
colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2", "#888888"]
bars = ax.bar(x, values, color=colors, edgecolor="black", linewidth=0.7)

# Title, axis label, limits and ticks
ax.set_title("LODO Results by Domain", fontsize=13, weight="bold")
ax.set_ylabel("Accuracy", fontsize=11)
ax.set_ylim(0, 1.05)
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=10, rotation=0, ha="center")

# Dashed horizontal grid lines, drawn underneath the bars
ax.set_axisbelow(True)
ax.yaxis.grid(True, linestyle="--", alpha=0.6)

# Print each value just above its bar, centred horizontally
for bar, val in zip(bars, values):
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.02,
        f"{val:.3f}",
        ha="center",
        va="bottom",
        fontsize=9,
        fontweight="semibold"
    )

# Hide the top and right frame lines
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)

plt.tight_layout()
# Save optionally:
# fig.savefig("lodo_results_bar.png", dpi=300)
plt.show()
