### Imports

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import PACSDataset
from vit_grqo import ViTGRQO, grqo_loss_from_gradients

### Config

In [None]:
# ---- Dataset ----
DATA_ROOT = "../../../pacs_data/pacs_data"
# Domain names handed to PACSDataset.
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]
NUM_CLASSES = 7  # passed to ViTGRQO as num_classes

# ---- Optimisation ----
BATCH_SIZE = 24
NUM_EPOCHS = 5
LR = 1e-4

# ---- GRQO hyper-parameters ----
# TOPK is given to both ViTGRQO and grqo_loss_from_gradients;
# ALPHA / BETA / TAU are forwarded only to grqo_loss_from_gradients.
TOPK = 16
ALPHA = 1.0
BETA = 0.5
TAU = 1e-3

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Data

In [None]:
# Preprocessing shared by train and validation: fixed 224x224 resize, then
# per-channel normalisation with the standard ImageNet mean/std values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

pacs_data = PACSDataset(DATA_ROOT, DOMAINS, transform, BATCH_SIZE)

# Leave-one-domain-out (LODO) split. ALL_DOMAINS is derived from DOMAINS so the
# split can never drift from the domain list PACSDataset was built with.
ALL_DOMAINS = list(DOMAINS)
LEAVE_OUT = 'sketch'
if LEAVE_OUT not in ALL_DOMAINS:
    # Without this guard a typo would make TRAIN_DOMAINS == ALL_DOMAINS,
    # i.e. the model would silently train on every domain.
    raise ValueError(f"LEAVE_OUT={LEAVE_OUT!r} is not one of {ALL_DOMAINS}")
TRAIN_DOMAINS = [domain for domain in ALL_DOMAINS if domain != LEAVE_OUT]
VAL_DOMAIN = LEAVE_OUT

# NOTE: despite its name, `train_datasets` holds one DataLoader per training
# domain (as returned by get_dataloader). Only their `.dataset` attributes are
# reused, concatenated into a single shuffled training loader.
train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
train_loader = DataLoader(
    torch.utils.data.ConcatDataset([loader.dataset for loader in train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Validation loader on the held-out domain.
val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

### Model 

In [None]:
# Plain cross-entropy on the classifier logits.
criterion = nn.CrossEntropyLoss()

# Build the ViT-GRQO classifier and move it to DEVICE
# (nn.Module.to moves the module in place and returns the same object).
model = ViTGRQO(num_classes=NUM_CLASSES, topk=TOPK).to(DEVICE)

# AdamW over every model parameter, with weight decay fixed at 0.05.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=0.05)

### Fine-tuning with leave-one-domain-out validation

In [None]:
# Weight given to the current batch's mean token distribution when updating
# the teacher (the previous teacher keeps 1 - TEACHER_EMA_WEIGHT).
TEACHER_EMA_WEIGHT = 0.1


def _update_teacher_probs(model, token_probs, weight=TEACHER_EMA_WEIGHT):
    """Fold the batch-mean of ``token_probs`` into ``model.teacher_probs``.

    If the stored teacher has a different number of elements than the batch
    mean (e.g. on the first step), it is replaced by the batch mean. Otherwise
    it becomes ``(1 - weight) * teacher + weight * batch_mean``.
    Runs under ``torch.no_grad()`` so the teacher carries no autograd history.
    """
    with torch.no_grad():
        # Mean over dim 0, presumably the batch dimension -- TODO confirm
        # against ViTGRQO's output layout.
        batch_mean = token_probs.mean(dim=0, keepdim=True)
        if model.teacher_probs.numel() != batch_mean.numel():
            model.teacher_probs = batch_mean
        else:
            model.teacher_probs = (1 - weight) * model.teacher_probs + weight * batch_mean


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Train for one pass over ``loader`` and return the mean per-sample loss.

    The loss is cross-entropy plus the GRQO term. The GRQO term receives the
    teacher distribution as it was *before* this step. The teacher is updated
    after each optimizer step. Returns NaN if the loader's dataset is empty.
    """
    model.train()
    running_loss = 0.0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        logits, patch_tokens, token_probs = model(images)

        cls_loss = criterion(logits, labels)
        grqo_loss = grqo_loss_from_gradients(
            logits, patch_tokens, token_probs, labels,
            teacher_probs=model.teacher_probs,
            alpha=ALPHA, beta=BETA, topk=TOPK, tau=TAU,
        )
        loss = cls_loss + grqo_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _update_teacher_probs(model, token_probs)

        # Weight by batch size so the epoch average is per sample, which stays
        # correct when the last batch is smaller.
        running_loss += loss.item() * images.size(0)

    n_samples = len(loader.dataset)
    return running_loss / n_samples if n_samples else float("nan")


def _evaluate_accuracy(model, loader, device):
    """Return top-1 accuracy of ``model`` on ``loader`` (NaN if it yields no samples)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits, _, _ = model(images)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else float("nan")


for epoch in range(NUM_EPOCHS):
    epoch_loss = _train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
    val_acc = _evaluate_accuracy(model, val_loader, DEVICE)
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Loss: {epoch_loss:.4f}, Val Acc (LODO on {VAL_DOMAIN}): {val_acc:.4f}")
