### Imports

In [2]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTForImageClassification
import sys, os
sys.path.append(os.path.abspath(".."))  
from dataset import PACSDataset
from vit_grqo import ViTGRQO, grqo_loss_from_gradients
from encoder_decoder_vit import VisualDecoder, MultiheadAttn, DecoderAttn
from Visual_query_heads import QueryLosses, GRQO

### CONFIG 

In [3]:
import os
from datasets import load_dataset
from PIL import Image

# Your constants
DATA_ROOT = r"D:\Haseeb\SPROJ\PACS ViT\pacs_data\pacs_data"
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]
CLASSES = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
BATCH_SIZE = 32
NUM_CLASSES = 7
NUM_EPOCHS = 5
LR = 1e-4
TOPK = 24
ALPHA = 2.0
BETA = 0.5
TAU = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dataset = load_dataset("flwrlabs/pacs", split="train")

# os.makedirs(DATA_ROOT, exist_ok=True)
# for domain in DOMAINS:
#     for cls in CLASSES:
#         os.makedirs(f"{DATA_ROOT}/{domain}/{cls}", exist_ok=True)

# for i, example in enumerate(dataset):
#     domain = example["domain"]  
#     label_idx = example["label"]  
#     label = CLASSES[label_idx]

#     if domain not in DOMAINS:
#         raise ValueError(f"Unexpected domain: {domain}. Expected one of {DOMAINS}")
#     if label not in CLASSES:
#         raise ValueError(f"Unexpected label: {label}. Expected one of {CLASSES}")
    
#     image = example["image"]
#     image.save(f"{DATA_ROOT}/{domain}/{label}/image_{i}.jpg")


### Data

In [4]:
TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
pacs_data = PACSDataset(DATA_ROOT, DOMAINS, TRANSFORM, BATCH_SIZE)
ALL_DOMAINS = ['photo', 'art_painting', 'cartoon', 'sketch']
LEAVE_OUT = 'sketch'  
TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
VAL_DOMAIN = LEAVE_OUT

train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation loader
val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

### Model 

In [5]:
from transformers import ViTModel, AutoFeatureExtractor
import torch
import torch.nn as nn

# ViT-Small Backbone from Hugging Face ---
vit_encoder = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
HIDDEN_DIM = vit_encoder.config.hidden_size  # 384

# GRQO Hyperparameters ---
NUM_HEADS = 6
DROPOUT = 0.1
NUM_LAYERS = 3
DDROPOUT = 0.1
NUM_TOKENS = 32
TEMPERATURE = 0.1
ALPHA = 2.0
BETA = 0.5
TAU = 1e-3
LAMBDA_GRQO = 1.0
TEACHER_EMA = 0.99
NUM_CLASSES = 7

# GRQO Decoder ---
grqo_model = GRQO(
    Hidden_dim=HIDDEN_DIM,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    num_tokens=NUM_TOKENS,
    ddropout=DDROPOUT,
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES,
    temperature=TEMPERATURE,
    alpha=ALPHA,
    beta=BETA,
    tau=TAU,
    lambda_grqo=LAMBDA_GRQO,
    teacher_ema=TEACHER_EMA,
    reward_proxy="taylor"
)

# Full ViTGRQO Model ---
class ViTGRQO(nn.Module):
    def __init__(self, vit_encoder, grqo_model):
        super().__init__()
        self.vit = vit_encoder
        self.grqo = grqo_model

    def forward(self, x, labels=None):
        # Get patch embeddings from HF ViT (exclude CLS token)
        outputs = self.vit(pixel_values=x, output_hidden_states=True)
        patch_tokens = outputs.last_hidden_state[:, 1:, :]  # [B, N, D] skip CLS token
        return self.grqo(patch_tokens, labels)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Model initialized on {device}")
print(f"Hidden dim: {HIDDEN_DIM}, Num tokens: {NUM_TOKENS}")

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model initialized on cuda
Hidden dim: 384, Num tokens: 32


### training

In [6]:
def train_epoch(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_grqo_loss = 0.0
    correct = 0
    total_samples = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        output = model(images, labels)
        
        # Extract losses
        loss = output['loss']
        cls_loss = output['cls_loss']
        grqo_loss = output['grqo_loss']
        preds = output['preds']
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        total_loss += loss.item() * images.size(0)
        total_cls_loss += cls_loss.item() * images.size(0)
        total_grqo_loss += grqo_loss.item() * images.size(0)
        correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        
        if batch_idx % 50 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}, '
                  f'Cls: {cls_loss.item():.4f}, GRQO: {grqo_loss.item():.4f}')
    
    avg_loss = total_loss / total_samples
    avg_cls_loss = total_cls_loss / total_samples
    avg_grqo_loss = total_grqo_loss / total_samples
    accuracy = correct / total_samples
    
    return avg_loss, avg_cls_loss, avg_grqo_loss, accuracy

### evaluation

In [7]:
def validate(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0

    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        
        # GRQO needs gradients even during validation
        with torch.set_grad_enabled(True):
            output = model(images, labels)
            preds = output['preds']
        
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    accuracy = correct / total
    return accuracy

### finetuning 

In [8]:
ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Create data loaders
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    
    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Reset model for this split
    model = ViTGRQO(vit_encoder, grqo_model).to(DEVICE)
    FREEZE_VIT=False
    if FREEZE_VIT:
        for param in model.vit.parameters():
            param.requires_grad = False
    
        # Optimizer ---
    trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)
    best_val_acc = 0.0
    
    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
        
        # Train
        train_loss, train_cls, train_grqo, train_acc = train_epoch(model, train_loader, optimizer, DEVICE)
        
        # Validate  
        val_acc = validate(model, val_loader, DEVICE)
        
        print(f"Train - Loss: {train_loss:.4f}, Cls: {train_cls:.4f}, "
              f"GRQO: {train_grqo:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    
    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")

# Cell 8: Results Summary
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average':15}: {avg_lodo:.4f}")
print("="*50)


=== LODO: Leaving out domain 'sketch' ===

Epoch 1/5
Batch 0, Loss: 2.0721, Cls: 2.0182, GRQO: 0.0539
Batch 50, Loss: 0.1220, Cls: 0.1178, GRQO: 0.0042
Batch 100, Loss: 0.1783, Cls: 0.1748, GRQO: 0.0035
Batch 150, Loss: 0.1105, Cls: 0.1079, GRQO: 0.0026
Train - Loss: 0.3089, Cls: 0.3027, GRQO: 0.0061, Acc: 0.9033
Val Acc (sketch): 0.5763

Epoch 2/5
Batch 0, Loss: 0.0794, Cls: 0.0763, GRQO: 0.0030
Batch 50, Loss: 0.0141, Cls: 0.0118, GRQO: 0.0023
Batch 100, Loss: 0.0885, Cls: 0.0863, GRQO: 0.0022
Batch 150, Loss: 0.0059, Cls: 0.0040, GRQO: 0.0019
Train - Loss: 0.0640, Cls: 0.0616, GRQO: 0.0023, Acc: 0.9816
Val Acc (sketch): 0.5611

Epoch 3/5
Batch 0, Loss: 0.1164, Cls: 0.1145, GRQO: 0.0019
Batch 50, Loss: 0.0093, Cls: 0.0074, GRQO: 0.0019
Batch 100, Loss: 0.0059, Cls: 0.0040, GRQO: 0.0018
Batch 150, Loss: 0.1432, Cls: 0.1414, GRQO: 0.0017
Train - Loss: 0.0619, Cls: 0.0601, GRQO: 0.0018, Acc: 0.9837
Val Acc (sketch): 0.4758

Epoch 4/5
Batch 0, Loss: 0.1011, Cls: 0.0993, GRQO: 0.0018
Bat

In [12]:
ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Create data loaders
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    
    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)
    vit_encoder = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
    # Reset model for this split
    grqo_model = GRQO(
        Hidden_dim=HIDDEN_DIM,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        num_tokens=NUM_TOKENS,
        ddropout=DDROPOUT,
        num_layers=NUM_LAYERS,
        num_classes=NUM_CLASSES,
        temperature=TEMPERATURE,
        alpha=ALPHA,
        beta=BETA,
        tau=TAU,
        lambda_grqo=LAMBDA_GRQO,
        teacher_ema=TEACHER_EMA,
        reward_proxy="taylor"
    )
    model = ViTGRQO(vit_encoder, grqo_model).to(DEVICE)
    FREEZE_VIT=True
    if FREEZE_VIT:
        for param in model.vit.parameters():
            param.requires_grad = False
    
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=LR, weight_decay=0.01
    )

    best_val_acc = 0.0
    
    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
        
        # Train
        train_loss, train_cls, train_grqo, train_acc = train_epoch(model, train_loader, optimizer, DEVICE)
        
        # Validate  
        val_acc = validate(model, val_loader, DEVICE)
        
        print(f"Train - Loss: {train_loss:.4f}, Cls: {train_cls:.4f}, "
              f"GRQO: {train_grqo:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    
    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")

# Cell 8: Results Summary
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average':15}: {avg_lodo:.4f}")
print("="*50)


=== LODO: Leaving out domain 'sketch' ===


Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Batch 0, Loss: 1.9759, Cls: 1.9270, GRQO: 0.0489
Batch 50, Loss: 0.3641, Cls: 0.3533, GRQO: 0.0109
Batch 100, Loss: 0.1038, Cls: 0.0993, GRQO: 0.0045
Batch 150, Loss: 0.2622, Cls: 0.2580, GRQO: 0.0042
Train - Loss: 0.4285, Cls: 0.4200, GRQO: 0.0085, Acc: 0.8765
Val Acc (sketch): 0.3969

Epoch 2/5
Batch 0, Loss: 0.1284, Cls: 0.1247, GRQO: 0.0037
Batch 50, Loss: 0.0844, Cls: 0.0809, GRQO: 0.0034
Batch 100, Loss: 0.1651, Cls: 0.1614, GRQO: 0.0037
Batch 150, Loss: 0.0296, Cls: 0.0264, GRQO: 0.0032
Train - Loss: 0.1216, Cls: 0.1180, GRQO: 0.0036, Acc: 0.9618
Val Acc (sketch): 0.3842

Epoch 3/5
Batch 0, Loss: 0.0492, Cls: 0.0461, GRQO: 0.0030
Batch 50, Loss: 0.1007, Cls: 0.0979, GRQO: 0.0029
Batch 100, Loss: 0.1257, Cls: 0.1226, GRQO: 0.0031
Batch 150, Loss: 0.0267, Cls: 0.0239, GRQO: 0.0029
Train - Loss: 0.0601, Cls: 0.0573, GRQO: 0.0029, Acc: 0.9806
Val Acc (sketch): 0.3728

Epoch 4/5
Batch 0, Loss: 0.0850, Cls: 0.0823, GRQO: 0.0027
Batch 50, Loss: 0.0187, Cls: 0.0162, GRQO: 0.0

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Batch 0, Loss: 1.9707, Cls: 1.8991, GRQO: 0.0715
Batch 50, Loss: 0.6209, Cls: 0.6138, GRQO: 0.0072
Batch 100, Loss: 0.5557, Cls: 0.5491, GRQO: 0.0065
Batch 150, Loss: 0.4079, Cls: 0.4026, GRQO: 0.0053
Batch 200, Loss: 0.5976, Cls: 0.5930, GRQO: 0.0046
Train - Loss: 0.6423, Cls: 0.6347, GRQO: 0.0076, Acc: 0.7806
Val Acc (photo): 0.9910

Epoch 2/5
Batch 0, Loss: 0.3265, Cls: 0.3221, GRQO: 0.0044
Batch 50, Loss: 0.1633, Cls: 0.1594, GRQO: 0.0039
Batch 100, Loss: 0.0954, Cls: 0.0919, GRQO: 0.0035
Batch 150, Loss: 0.1649, Cls: 0.1613, GRQO: 0.0036
Batch 200, Loss: 0.3426, Cls: 0.3395, GRQO: 0.0031
Train - Loss: 0.2804, Cls: 0.2767, GRQO: 0.0037, Acc: 0.9023
Val Acc (photo): 0.9850

Epoch 3/5
Batch 0, Loss: 0.1816, Cls: 0.1782, GRQO: 0.0034
Batch 50, Loss: 0.2601, Cls: 0.2571, GRQO: 0.0030
Batch 100, Loss: 0.3011, Cls: 0.2980, GRQO: 0.0031
Batch 150, Loss: 0.0538, Cls: 0.0504, GRQO: 0.0034
Batch 200, Loss: 0.3761, Cls: 0.3731, GRQO: 0.0030
Train - Loss: 0.1826, Cls: 0.1794, GRQO: 

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Batch 0, Loss: 2.0875, Cls: 1.9822, GRQO: 0.1053
Batch 50, Loss: 0.6611, Cls: 0.6532, GRQO: 0.0079
Batch 100, Loss: 0.2723, Cls: 0.2659, GRQO: 0.0064
Batch 150, Loss: 0.4535, Cls: 0.4482, GRQO: 0.0053
Train - Loss: 0.5997, Cls: 0.5912, GRQO: 0.0086, Acc: 0.7965
Val Acc (art_painting): 0.8341

Epoch 2/5
Batch 0, Loss: 0.1724, Cls: 0.1679, GRQO: 0.0045
Batch 50, Loss: 0.4125, Cls: 0.4087, GRQO: 0.0038
Batch 100, Loss: 0.2649, Cls: 0.2609, GRQO: 0.0040
Batch 150, Loss: 0.2352, Cls: 0.2317, GRQO: 0.0035
Train - Loss: 0.2568, Cls: 0.2530, GRQO: 0.0038, Acc: 0.9067
Val Acc (art_painting): 0.8512

Epoch 3/5
Batch 0, Loss: 0.0605, Cls: 0.0573, GRQO: 0.0032
Batch 50, Loss: 0.1783, Cls: 0.1752, GRQO: 0.0031
Batch 100, Loss: 0.2864, Cls: 0.2832, GRQO: 0.0032
Batch 150, Loss: 0.2197, Cls: 0.2164, GRQO: 0.0032
Train - Loss: 0.1690, Cls: 0.1658, GRQO: 0.0032, Acc: 0.9396
Val Acc (art_painting): 0.8610

Epoch 4/5
Batch 0, Loss: 0.0986, Cls: 0.0957, GRQO: 0.0028
Batch 50, Loss: 0.1632, Cls:

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Batch 0, Loss: 1.9982, Cls: 1.9123, GRQO: 0.0859
Batch 50, Loss: 0.4416, Cls: 0.4340, GRQO: 0.0076
Batch 100, Loss: 0.2336, Cls: 0.2284, GRQO: 0.0052
Batch 150, Loss: 0.3091, Cls: 0.3050, GRQO: 0.0041
Train - Loss: 0.5776, Cls: 0.5701, GRQO: 0.0076, Acc: 0.8058
Val Acc (cartoon): 0.6503

Epoch 2/5
Batch 0, Loss: 0.3899, Cls: 0.3856, GRQO: 0.0043
Batch 50, Loss: 0.3326, Cls: 0.3290, GRQO: 0.0036
Batch 100, Loss: 0.1211, Cls: 0.1177, GRQO: 0.0034
Batch 150, Loss: 0.0658, Cls: 0.0625, GRQO: 0.0032
Train - Loss: 0.2294, Cls: 0.2260, GRQO: 0.0034, Acc: 0.9174
Val Acc (cartoon): 0.7420

Epoch 3/5
Batch 0, Loss: 0.2090, Cls: 0.2058, GRQO: 0.0032
Batch 50, Loss: 0.1360, Cls: 0.1328, GRQO: 0.0032
Batch 100, Loss: 0.1734, Cls: 0.1708, GRQO: 0.0026
Batch 150, Loss: 0.1257, Cls: 0.1231, GRQO: 0.0026
Train - Loss: 0.1449, Cls: 0.1421, GRQO: 0.0028, Acc: 0.9501
Val Acc (cartoon): 0.7356

Epoch 4/5
Batch 0, Loss: 0.1344, Cls: 0.1318, GRQO: 0.0026
Batch 50, Loss: 0.1200, Cls: 0.1174, GRQO: 

In [14]:
print("\n=== Leave-One-Domain-Out (LODO) Training & Validation ===")

lodo_results = {}  # store per-domain results

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out {LEAVE_OUT.upper()} for validation ===")
    
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Train loader (concat all train domains)
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    # Validation loader (only leave-out domain)
    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Model (reset for each LODO run)
    model = ViTForImageClassification.from_pretrained(
        "WinKawaks/vit-small-patch16-224",
        num_labels=NUM_CLASSES,
        ignore_mismatched_sizes=True
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS} (Leave out {VAL_DOMAIN})")

        # ---- Train ----
        model.train()
        running_loss, running_corrects, running_samples = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = model(images)
            logits = outputs.logits
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = logits.argmax(dim=1)
            running_loss += loss.item() * images.size(0)
            running_corrects += (preds == labels).sum().item()
            running_samples += labels.size(0)

        train_loss = running_loss / running_samples
        train_acc = running_corrects / running_samples

        # ---- Validation on held-out domain ----
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                preds = outputs.logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        val_acc = correct / total if total > 0 else 0.0

        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")

        # Track best
        best_val_acc = max(best_val_acc, val_acc)

    # Save best for this LODO run
    lodo_results[VAL_DOMAIN] = best_val_acc

# ---------------- Results Summary ----------------
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average LODO':15}: {avg_lodo:.4f}")
print("="*50)



=== Leave-One-Domain-Out (LODO) Training & Validation ===

=== LODO: Leaving out SKETCH for validation ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([7, 384]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5 (Leave out sketch)
Train - Loss: 0.2360, Acc: 0.9227
Val Acc (sketch): 0.4402

Epoch 2/5 (Leave out sketch)
Train - Loss: 0.0203, Acc: 0.9953
Val Acc (sketch): 0.5115

Epoch 3/5 (Leave out sketch)
Train - Loss: 0.0131, Acc: 0.9961
Val Acc (sketch): 0.5076

Epoch 4/5 (Leave out sketch)
Train - Loss: 0.0500, Acc: 0.9854
Val Acc (sketch): 0.5102

Epoch 5/5 (Leave out sketch)
Train - Loss: 0.0344, Acc: 0.9897
Val Acc (sketch): 0.6285

=== LODO: Leaving out PHOTO for validation ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([7, 384]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5 (Leave out photo)
Train - Loss: 0.3381, Acc: 0.8800
Val Acc (photo): 0.9731

Epoch 2/5 (Leave out photo)
Train - Loss: 0.0767, Acc: 0.9731
Val Acc (photo): 0.9731

Epoch 3/5 (Leave out photo)
Train - Loss: 0.0467, Acc: 0.9833
Val Acc (photo): 0.9701

Epoch 4/5 (Leave out photo)
Train - Loss: 0.0414, Acc: 0.9857
Val Acc (photo): 0.9611

Epoch 5/5 (Leave out photo)
Train - Loss: 0.0402, Acc: 0.9872
Val Acc (photo): 0.9341

=== LODO: Leaving out ART_PAINTING for validation ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([7, 384]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5 (Leave out art_painting)
Train - Loss: 0.3617, Acc: 0.8664
Val Acc (art_painting): 0.8561

Epoch 2/5 (Leave out art_painting)
Train - Loss: 0.0735, Acc: 0.9753
Val Acc (art_painting): 0.8634

Epoch 3/5 (Leave out art_painting)
Train - Loss: 0.0303, Acc: 0.9898
Val Acc (art_painting): 0.8707

Epoch 4/5 (Leave out art_painting)
Train - Loss: 0.0276, Acc: 0.9899
Val Acc (art_painting): 0.7878

Epoch 5/5 (Leave out art_painting)
Train - Loss: 0.0571, Acc: 0.9778
Val Acc (art_painting): 0.8439

=== LODO: Leaving out CARTOON for validation ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 384]) in the checkpoint and torch.Size([7, 384]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5 (Leave out cartoon)
Train - Loss: 0.3680, Acc: 0.8622
Val Acc (cartoon): 0.7996

Epoch 2/5 (Leave out cartoon)
Train - Loss: 0.0893, Acc: 0.9684
Val Acc (cartoon): 0.7910

Epoch 3/5 (Leave out cartoon)
Train - Loss: 0.0418, Acc: 0.9851
Val Acc (cartoon): 0.7996

Epoch 4/5 (Leave out cartoon)
Train - Loss: 0.0204, Acc: 0.9935
Val Acc (cartoon): 0.7910

Epoch 5/5 (Leave out cartoon)
Train - Loss: 0.0634, Acc: 0.9804
Val Acc (cartoon): 0.7591

LODO RESULTS SUMMARY
sketch         : 0.6285
photo          : 0.9731
art_painting   : 0.8707
cartoon        : 0.7996
Average LODO   : 0.8180


In [None]:
import matplotlib.pyplot as plt

# Data
results = {
    "photo": 0.9581,
    "art_painting": 0.7293,
    "cartoon": 0.7186,
    "sketch": 0.3461,
    "All-domains baseline": 0.9010
}

# Prepare labels & values; wrap long label onto two lines for neatness
labels = []
values = []
for k, v in results.items():
    if "All-domains" in k:
        labels.append("All-domains\nbaseline")   # wrap long label
    else:
        labels.append(k.replace("_", " "))       # nicer display for underscores
    values.append(v)

# Plot
fig, ax = plt.subplots(figsize=(9, 5))
x = range(len(labels))

# Draw bars; make baseline visually distinct
colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2", "#888888"]
bars = ax.bar(x, values, color=colors, edgecolor="black", linewidth=0.7)

# Axis and ticks
ax.set_ylim(0, 1.05)
ax.set_ylabel("Accuracy", fontsize=11)
ax.set_title("LODO Results by Domain", fontsize=13, weight="bold")
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=10, rotation=0, ha="center")

# Add horizontal grid lines for readability (below bars)
ax.yaxis.grid(True, linestyle="--", alpha=0.6)
ax.set_axisbelow(True)

# Annotate values above bars with consistent alignment
for bar, val in zip(bars, values):
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        height + 0.02,
        f"{val:.3f}",
        ha="center",
        va="bottom",
        fontsize=9,
        fontweight="semibold"
    )

# Tidy up spines
for spine in ("top", "right"):
    ax.spines[spine].set_visible(False)

plt.tight_layout()
# Save optionally:
# fig.savefig("lodo_results_bar.png", dpi=300)
plt.show()
