### Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from dataset import PACSDataset
from vit_grqo import ViTGRQO, grqo_loss_from_gradients
from encoder_decoder_vit import VisualDecoder, MultiheadAttn, DecoderAttn
from Visual_query_heads import QueryLosses, GRQO

### CONFIG 

In [2]:
import os
from datasets import load_dataset
from PIL import Image

# Your constants
DATA_ROOT = "../../../pacs_data/pacs_data"
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]
CLASSES = ["dog", "elephant", "giraffe", "guitar", "horse", "house", "person"]
BATCH_SIZE = 32
NUM_CLASSES = 7
NUM_EPOCHS = 5
LR = 1e-4
TOPK = 24
ALPHA = 2.0
BETA = 0.5
TAU = 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# dataset = load_dataset("flwrlabs/pacs", split="train")

# os.makedirs(DATA_ROOT, exist_ok=True)
# for domain in DOMAINS:
#     for cls in CLASSES:
#         os.makedirs(f"{DATA_ROOT}/{domain}/{cls}", exist_ok=True)

# for i, example in enumerate(dataset):
#     domain = example["domain"]  
#     label_idx = example["label"]  
#     label = CLASSES[label_idx]

#     if domain not in DOMAINS:
#         raise ValueError(f"Unexpected domain: {domain}. Expected one of {DOMAINS}")
#     if label not in CLASSES:
#         raise ValueError(f"Unexpected label: {label}. Expected one of {CLASSES}")
    
#     image = example["image"]
#     image.save(f"{DATA_ROOT}/{domain}/{label}/image_{i}.jpg")


### Data

In [3]:
TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
pacs_data = PACSDataset(DATA_ROOT, DOMAINS, TRANSFORM, BATCH_SIZE)
ALL_DOMAINS = ['photo', 'art_painting', 'cartoon', 'sketch']
LEAVE_OUT = 'sketch'  
TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
VAL_DOMAIN = LEAVE_OUT

train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation loader
val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

### Model 

In [4]:
from transformers import ViTModel, AutoFeatureExtractor
import torch
import torch.nn as nn

# --- 1) ViT-Tiny Backbone from Hugging Face ---
vit_encoder = ViTModel.from_pretrained("WinKawaks/vit-tiny-patch16-224")
HIDDEN_DIM = vit_encoder.config.hidden_size  # 192 for ViT-Tiny

# --- 2) GRQO Hyperparameters ---
NUM_HEADS = 6
DROPOUT = 0.1
NUM_LAYERS = 3
DDROPOUT = 0.1
NUM_TOKENS = 32
TEMPERATURE = 0.1
ALPHA = 2.0
BETA = 0.5
TAU = 1e-3
LAMBDA_GRQO = 1.0
TEACHER_EMA = 0.99
NUM_CLASSES = 7

# --- 3) GRQO Decoder ---
grqo_model = GRQO(
    Hidden_dim=HIDDEN_DIM,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    num_tokens=NUM_TOKENS,
    ddropout=DDROPOUT,
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES,
    temperature=TEMPERATURE,
    alpha=ALPHA,
    beta=BETA,
    tau=TAU,
    lambda_grqo=LAMBDA_GRQO,
    teacher_ema=TEACHER_EMA,
    reward_proxy="taylor"
)

# --- 4) Full ViTGRQO Model ---
class ViTGRQO(nn.Module):
    def __init__(self, vit_encoder, grqo_model):
        super().__init__()
        self.vit = vit_encoder
        self.grqo = grqo_model

    def forward(self, x, labels=None):
        # Get patch embeddings from HF ViT (exclude CLS token)
        outputs = self.vit(pixel_values=x, output_hidden_states=True)
        patch_tokens = outputs.last_hidden_state[:, 1:, :]  # [B, N, D] skip CLS token
        
        # Pass to GRQO
        if labels is not None:
            return self.grqo(patch_tokens, labels)
        else:
            dummy_labels = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
            grqo_out = self.grqo(patch_tokens, dummy_labels)
            return {
                'img_logits': grqo_out['img_logits'],
                'preds': grqo_out['preds']
            }

# --- 5) Initialize model ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViTGRQO(vit_encoder, grqo_model).to(device)

# --- 6) Freeze ViT backbone if desired ---
FREEZE_VIT = False
if FREEZE_VIT:
    for param in model.vit.parameters():
        param.requires_grad = False
    print("ViT backbone frozen")
else:
    print("ViT backbone trainable")

# Ensure GRQO is trainable
for param in model.grqo.parameters():
    param.requires_grad = True

# --- 7) Optimizer ---
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)

print(f"Model initialized on {device}")
print(f"Hidden dim: {HIDDEN_DIM}, Num tokens: {NUM_TOKENS}")

Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViT backbone trainable
Model initialized on cuda
Hidden dim: 192, Num tokens: 32


### training

In [5]:
def train_epoch(model, train_loader, optimizer, device):
    model.train()
    total_loss = 0.0
    total_cls_loss = 0.0
    total_grqo_loss = 0.0
    correct = 0
    total_samples = 0
    
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        
        # Forward pass
        output = model(images, labels)
        
        # Extract losses
        loss = output['loss']
        cls_loss = output['cls_loss']
        grqo_loss = output['grqo_loss']
        preds = output['preds']
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        total_loss += loss.item() * images.size(0)
        total_cls_loss += cls_loss.item() * images.size(0)
        total_grqo_loss += grqo_loss.item() * images.size(0)
        correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        
        if batch_idx % 50 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}, '
                  f'Cls: {cls_loss.item():.4f}, GRQO: {grqo_loss.item():.4f}')
    
    avg_loss = total_loss / total_samples
    avg_cls_loss = total_cls_loss / total_samples
    avg_grqo_loss = total_grqo_loss / total_samples
    accuracy = correct / total_samples
    
    return avg_loss, avg_cls_loss, avg_grqo_loss, accuracy

### evaluation

In [6]:
def validate(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0

    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        
        # GRQO needs gradients even during validation
        with torch.set_grad_enabled(True):
            output = model(images, labels)
            preds = output['preds']
        
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    
    accuracy = correct / total
    return accuracy

### finetuning 

In [9]:
ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Create data loaders
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    
    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Reset model for this split
    model = ViTGRQO(vit_encoder, grqo_model).to(DEVICE)
    
    if FREEZE_VIT:
        for param in model.vit.parameters():
            param.requires_grad = False
    
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=LR, weight_decay=0.01
    )

    best_val_acc = 0.0
    
    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
        
        # Train
        train_loss, train_cls, train_grqo, train_acc = train_epoch(model, train_loader, optimizer, DEVICE)
        
        # Validate  
        val_acc = validate(model, val_loader, DEVICE)
        
        print(f"Train - Loss: {train_loss:.4f}, Cls: {train_cls:.4f}, "
              f"GRQO: {train_grqo:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    
    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")

# Cell 8: Results Summary
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average':15}: {avg_lodo:.4f}")
print("="*50)


=== LODO: Leaving out domain 'sketch' ===



Epoch 1/5
Batch 0, Loss: 1.9996, Cls: 1.9305, GRQO: 0.0691
Batch 50, Loss: 1.0056, Cls: 0.9952, GRQO: 0.0103
Batch 100, Loss: 0.5724, Cls: 0.5644, GRQO: 0.0080
Batch 150, Loss: 0.4117, Cls: 0.4049, GRQO: 0.0067
Train - Loss: 0.9679, Cls: 0.9554, GRQO: 0.0126, Acc: 0.7237
Val Acc (sketch): 0.3550

Epoch 2/5
Batch 0, Loss: 0.6012, Cls: 0.5945, GRQO: 0.0067
Batch 50, Loss: 0.2280, Cls: 0.2226, GRQO: 0.0053
Batch 100, Loss: 0.2176, Cls: 0.2124, GRQO: 0.0051
Batch 150, Loss: 0.1684, Cls: 0.1635, GRQO: 0.0049
Train - Loss: 0.2815, Cls: 0.2760, GRQO: 0.0054, Acc: 0.9130
Val Acc (sketch): 0.4084

Epoch 3/5
Batch 0, Loss: 0.0988, Cls: 0.0939, GRQO: 0.0050
Batch 50, Loss: 0.1454, Cls: 0.1411, GRQO: 0.0043
Batch 100, Loss: 0.2461, Cls: 0.2415, GRQO: 0.0045
Batch 150, Loss: 0.0480, Cls: 0.0441, GRQO: 0.0039
Train - Loss: 0.1665, Cls: 0.1623, GRQO: 0.0042, Acc: 0.9511
Val Acc (sketch): 0.3893

Epoch 4/5
Batch 0, Loss: 0.0905, Cls: 0.0866, GRQO: 0.0040
Batch 50, Loss: 0.1168, Cls: 0.1126, GRQO: 0.0

In [None]:
ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Create data loaders
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    
    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Reset model for this split
    model = ViTGRQO(vit_encoder, grqo_model).to(DEVICE)
    
    if FREEZE_VIT:
        for param in model.vit.parameters():
            param.requires_grad = False
    
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()), 
        lr=LR, weight_decay=0.01
    )

    best_val_acc = 0.0
    
    # Training loop
    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
        
        # Train
        train_loss, train_cls, train_grqo, train_acc = train_epoch(model, train_loader, optimizer, DEVICE)
        
        # Validate  
        val_acc = validate(model, val_loader, DEVICE)
        
        print(f"Train - Loss: {train_loss:.4f}, Cls: {train_cls:.4f}, "
              f"GRQO: {train_grqo:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
    
    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")

# Cell 8: Results Summary
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average':15}: {avg_lodo:.4f}")
print("="*50)


=== LODO: Leaving out domain 'sketch' ===

Epoch 1/5
Batch 0, Loss: 2.0397, Cls: 1.9501, GRQO: 0.0896
Batch 50, Loss: 0.4309, Cls: 0.4234, GRQO: 0.0075
Batch 100, Loss: 0.4242, Cls: 0.4194, GRQO: 0.0047
Batch 150, Loss: 0.3062, Cls: 0.3018, GRQO: 0.0044
Train - Loss: 0.6380, Cls: 0.6268, GRQO: 0.0111, Acc: 0.8051
Val Acc (sketch): 0.5153

Epoch 2/5
Batch 0, Loss: 0.2020, Cls: 0.1980, GRQO: 0.0040
Batch 50, Loss: 0.0992, Cls: 0.0957, GRQO: 0.0035
Batch 100, Loss: 0.0685, Cls: 0.0654, GRQO: 0.0031
Batch 150, Loss: 0.0799, Cls: 0.0770, GRQO: 0.0029
Train - Loss: 0.1189, Cls: 0.1154, GRQO: 0.0035, Acc: 0.9660
Val Acc (sketch): 0.4962

Epoch 3/5
Batch 0, Loss: 0.1174, Cls: 0.1143, GRQO: 0.0031
Batch 50, Loss: 0.0283, Cls: 0.0257, GRQO: 0.0027
Batch 100, Loss: 0.0125, Cls: 0.0100, GRQO: 0.0025
Batch 150, Loss: 0.0126, Cls: 0.0102, GRQO: 0.0023
Train - Loss: 0.0517, Cls: 0.0490, GRQO: 0.0026, Acc: 0.9854
Val Acc (sketch): 0.6718

Epoch 4/5
Batch 0, Loss: 0.0226, Cls: 0.0201, GRQO: 0.0024
Bat

In [7]:
import torch
import torch.nn as nn
from transformers import ViTForImageClassification

ALL_DOMAINS = ['sketch', 'photo', 'art_painting', 'cartoon']
lodo_results = {}

# ---------------- LODO Experiments ----------------
for LEAVE_OUT in ALL_DOMAINS:
    print(f"\n=== LODO: Leaving out domain '{LEAVE_OUT}' ===")
    TRAIN_DOMAINS = [d for d in ALL_DOMAINS if d != LEAVE_OUT]
    VAL_DOMAIN = LEAVE_OUT

    # Train loader
    train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in TRAIN_DOMAINS]
    train_loader = torch.utils.data.DataLoader(
        torch.utils.data.ConcatDataset([d.dataset for d in train_datasets]),
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    # Validation loader
    val_loader = pacs_data.get_dataloader(domain=VAL_DOMAIN, train=False)

    # Model (no freezing)
    model = ViTForImageClassification.from_pretrained(
        "WinKawaks/vit-tiny-patch16-224",
        num_labels=NUM_CLASSES,
        ignore_mismatched_sizes=True  # <--- important fix
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    best_val_acc = 0.0

    for epoch in range(NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

        # ---- Train ----
        model.train()
        running_loss, running_corrects, running_samples = 0.0, 0, 0

        for images, labels in train_loader:
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            outputs = model(images)
            logits = outputs.logits
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            preds = logits.argmax(dim=1)
            running_loss += loss.item() * images.size(0)
            running_corrects += (preds == labels).sum().item()
            running_samples += labels.size(0)

        train_loss = running_loss / running_samples
        train_acc = running_corrects / running_samples

        # ---- Validation ----
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                preds = outputs.logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        val_acc = correct / total

        print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
        print(f"Val Acc ({VAL_DOMAIN}): {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc

    lodo_results[VAL_DOMAIN] = best_val_acc
    print(f"Best Val Acc for {VAL_DOMAIN}: {best_val_acc:.4f}")


# ---------------- All-Domains Baseline ----------------
print("\n=== Baseline: Train & Validate on All Domains ===")

# Train loader with all domains
all_train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in ALL_DOMAINS]
train_loader_all = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in all_train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation loader with all domains
all_val_datasets = [pacs_data.get_dataloader(domain=d, train=False) for d in ALL_DOMAINS]
val_loader_all = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in all_val_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=False
)

# Model (no freezing)
model = ViTForImageClassification.from_pretrained(
    "WinKawaks/vit-tiny-patch16-224",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True 
).to(DEVICE)

optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=LR, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # ---- Train ----
    baseline_model.train()
    running_loss, running_corrects, running_samples = 0.0, 0, 0

    for images, labels in train_loader_all:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        outputs = baseline_model(images)
        logits = outputs.logits
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        running_loss += loss.item() * images.size(0)
        running_corrects += (preds == labels).sum().item()
        running_samples += labels.size(0)

    train_loss = running_loss / running_samples
    train_acc = running_corrects / running_samples

    # ---- Validation ----
    baseline_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader_all:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = baseline_model(images)
            preds = outputs.logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    val_acc = correct / total

    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"Val Acc (All Domains): {val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc


# ---------------- Results Summary ----------------
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average LODO':15}: {avg_lodo:.4f}")
print(f"{'All-domains baseline':15}: {best_val_acc:.4f}")
print("="*50)



=== LODO: Leaving out domain 'sketch' ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Train - Loss: 0.3371, Acc: 0.8824
Val Acc (sketch): 0.5522

Epoch 2/5
Train - Loss: 0.0538, Acc: 0.9843
Val Acc (sketch): 0.4326

Epoch 3/5
Train - Loss: 0.0298, Acc: 0.9913
Val Acc (sketch): 0.5267

Epoch 4/5
Train - Loss: 0.0272, Acc: 0.9911
Val Acc (sketch): 0.5674

Epoch 5/5
Train - Loss: 0.0189, Acc: 0.9934
Val Acc (sketch): 0.4898
Best Val Acc for sketch: 0.5674

=== LODO: Leaving out domain 'photo' ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Train - Loss: 0.4923, Acc: 0.8202
Val Acc (photo): 0.9671

Epoch 2/5
Train - Loss: 0.1340, Acc: 0.9530
Val Acc (photo): 0.9701

Epoch 3/5
Train - Loss: 0.0638, Acc: 0.9794
Val Acc (photo): 0.9611

Epoch 4/5
Train - Loss: 0.0537, Acc: 0.9805
Val Acc (photo): 0.9641

Epoch 5/5
Train - Loss: 0.0392, Acc: 0.9859
Val Acc (photo): 0.9641
Best Val Acc for photo: 0.9701

=== LODO: Leaving out domain 'art_painting' ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Train - Loss: 0.4586, Acc: 0.8336
Val Acc (art_painting): 0.7902

Epoch 2/5
Train - Loss: 0.1261, Acc: 0.9569
Val Acc (art_painting): 0.7683

Epoch 3/5
Train - Loss: 0.0563, Acc: 0.9781
Val Acc (art_painting): 0.7098

Epoch 4/5
Train - Loss: 0.0543, Acc: 0.9806
Val Acc (art_painting): 0.6927

Epoch 5/5
Train - Loss: 0.0507, Acc: 0.9832
Val Acc (art_painting): 0.7707
Best Val Acc for art_painting: 0.7902

=== LODO: Leaving out domain 'cartoon' ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Train - Loss: 0.5261, Acc: 0.8060
Val Acc (cartoon): 0.6930

Epoch 2/5
Train - Loss: 0.1368, Acc: 0.9534
Val Acc (cartoon): 0.7783

Epoch 3/5
Train - Loss: 0.0692, Acc: 0.9761
Val Acc (cartoon): 0.7420

Epoch 4/5
Train - Loss: 0.0493, Acc: 0.9827
Val Acc (cartoon): 0.7399

Epoch 5/5
Train - Loss: 0.0551, Acc: 0.9822
Val Acc (cartoon): 0.7058
Best Val Acc for cartoon: 0.7783

=== Baseline: Train & Validate on All Domains ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


NameError: name 'baseline_model' is not defined

In [9]:
print("\n=== Baseline: Train & Validate on All Domains ===")

# Train loader with all domains
all_train_datasets = [pacs_data.get_dataloader(domain=d, train=True) for d in ALL_DOMAINS]
train_loader_all = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in all_train_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=True
)

# Validation loaders (per domain and all combined)
val_loaders_per_domain = {d: pacs_data.get_dataloader(domain=d, train=False) for d in ALL_DOMAINS}
all_val_datasets = [pacs_data.get_dataloader(domain=d, train=False) for d in ALL_DOMAINS]
val_loader_all = torch.utils.data.DataLoader(
    torch.utils.data.ConcatDataset([d.dataset for d in all_val_datasets]),
    batch_size=BATCH_SIZE,
    shuffle=False
)

# Model (no freezing)
baseline_model = ViTForImageClassification.from_pretrained(
    "WinKawaks/vit-tiny-patch16-224",
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True
).to(DEVICE)

optimizer = torch.optim.AdamW(baseline_model.parameters(), lr=LR, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0
best_val_per_domain = {}

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")

    # ---- Train ----
    baseline_model.train()
    running_loss, running_corrects, running_samples = 0.0, 0, 0

    for images, labels in train_loader_all:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        outputs = baseline_model(images)
        logits = outputs.logits
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        preds = logits.argmax(dim=1)
        running_loss += loss.item() * images.size(0)
        running_corrects += (preds == labels).sum().item()
        running_samples += labels.size(0)

    train_loss = running_loss / running_samples
    train_acc = running_corrects / running_samples

    # ---- Validation (all domains combined) ----
    baseline_model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in val_loader_all:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            outputs = baseline_model(images)
            preds = outputs.logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct / total

    # ---- Validation (per domain) ----
    per_domain_accs = {}
    baseline_model.eval()
    with torch.no_grad():
        for domain, loader in val_loaders_per_domain.items():
            correct, total = 0, 0
            for images, labels in loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = baseline_model(images)
                preds = outputs.logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
            per_domain_accs[domain] = correct / total if total > 0 else 0.0

    # Print results
    print(f"Train - Loss: {train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"Val Acc (All Domains Combined): {val_acc:.4f}")
    for domain, acc in per_domain_accs.items():
        print(f"  {domain:15}: {acc:.4f}")

    # Track best
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_val_per_domain = per_domain_accs.copy()

# ---------------- Results Summary ----------------
print("\n" + "="*50)
print("LODO RESULTS SUMMARY")
print("="*50)
for domain, acc in lodo_results.items():
    print(f"{domain:15}: {acc:.4f}")

avg_lodo = sum(lodo_results.values()) / len(lodo_results)
print(f"{'Average LODO':15}: {avg_lodo:.4f}")
print(f"{'All-domains baseline':15}: {best_val_acc:.4f}")
for domain, acc in best_val_per_domain.items():
    print(f"  {domain:15}: {acc:.4f}")
print("="*50)



=== Baseline: Train & Validate on All Domains ===


Some weights of ViTForImageClassification were not initialized from the model checkpoint at WinKawaks/vit-tiny-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([7]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 192]) in the checkpoint and torch.Size([7, 192]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Epoch 1/5
Train - Loss: 0.4178, Acc: 0.8486
Val Acc (All Domains Combined): 0.9145
  sketch         : 0.8766
  photo          : 0.9820
  art_painting   : 0.9268
  cartoon        : 0.9190

Epoch 2/5
Train - Loss: 0.1145, Acc: 0.9592
Val Acc (All Domains Combined): 0.9050
  sketch         : 0.8448
  photo          : 0.9820
  art_painting   : 0.9220
  cartoon        : 0.9360

Epoch 3/5
Train - Loss: 0.0513, Acc: 0.9810
Val Acc (All Domains Combined): 0.9195
  sketch         : 0.8995
  photo          : 0.9611
  art_painting   : 0.9024
  cartoon        : 0.9382

Epoch 4/5
Train - Loss: 0.0517, Acc: 0.9835
Val Acc (All Domains Combined): 0.9190
  sketch         : 0.8893
  photo          : 0.9820
  art_painting   : 0.8951
  cartoon        : 0.9446

Epoch 5/5
Train - Loss: 0.0416, Acc: 0.9862
Val Acc (All Domains Combined): 0.9190
  sketch         : 0.9148
  photo          : 0.9581
  art_painting   : 0.8585
  cartoon        : 0.9510

LODO RESULTS SUMMARY
sketch         : 0.5674
photo         

In [None]:
import matplotlib.pyplot as plt

# Data
results = {
    "photo": 0.9581,
    "art_painting": 0.7293,
    "cartoon": 0.7186,
    "sketch": 0.3461,
    "All-domains baseline": 0.9010
}

# Prepare labels & values; wrap long label onto two lines for neatness
labels = []
values = []
for k, v in results.items():
    if "All-domains" in k:
        labels.append("All-domains\nbaseline")   # wrap long label
    else:
        labels.append(k.replace("_", " "))       # nicer display for underscores
    values.append(v)

# Plot
fig, ax = plt.subplots(figsize=(9, 5))
x = range(len(labels))

# Draw bars; make baseline visually distinct
colors = ["#4C72B0", "#55A868", "#C44E52", "#8172B2", "#888888"]
bars = ax.bar(x, values, color=colors, edgecolor="black", linewidth=0.7)

# Axis and ticks
ax.set_ylim(0, 1.05)
ax.set_ylabel("Accuracy", fontsize=11)
ax.set_title("LODO Results by Domain", fontsize=13, weight="bold")
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=10, rotation=0, ha="center")

# Add horizontal grid lines for readability (below bars)
ax.yaxis.grid(True, linestyle="--", alpha=0.6)
ax.set_axisbelow(True)

# Annotate values above bars with consistent alignment
for bar, val in zip(bars, values):
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        height + 0.02,
        f"{val:.3f}",
        ha="center",
        va="bottom",
        fontsize=9,
        fontweight="semibold"
    )

# Tidy up spines
for spine in ("top", "right"):
    ax.spines[spine].set_visible(False)

plt.tight_layout()
# Save optionally:
# fig.savefig("lodo_results_bar.png", dpi=300)
plt.show()
