# PACS Domain Generalization with Vision Transformer (ViT)

## Setup and Imports

In [1]:
import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset
from sklearn.model_selection import train_test_split
# Updated import to include the base ViTModel
from transformers import ViTForImageClassification, ViTFeatureExtractor, ViTModel
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import random

ModuleNotFoundError: No module named 'transformers'

## Configuration and Seeds

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {DEVICE}")

# Experiment hyper-parameters.
SEED = 42          # shared by every RNG below and by the train/val split
BATCH_SIZE = 24
NUM_EPOCHS = 5
NUM_CLASSES = 7

# Dataset location; each entry of DOMAINS is a sub-directory of DATA_ROOT.
DATA_ROOT = "../../../pacs_data/pacs_data"
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]

# Hugging Face Hub checkpoint identifiers, keyed by model size.
MODELS = {
    "base": "google/vit-base-patch16-224-in21k",
    "small": "WinKawaks/vit-small-patch16-224",
    "tiny": "WinKawaks/vit-tiny-patch16-224",
}

# Seed PyTorch, NumPy and the stdlib RNG with the same value for repeatability.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

## Dataset Wrapper Class

In [None]:
class PACSDataset:
    """Build per-domain data loaders with a stratified 80/20 train/val split.

    Each domain is expected to be an ``ImageFolder``-style directory located
    at ``<data_root>/<domain>``.
    """

    def __init__(self, data_root, domains, transform):
        """Store the dataset root, the list of domain names and the transform."""
        self.data_root = data_root
        self.domains = domains
        self.transform = transform

    def get_dataloader(self, domain, train=True):
        """Return a loader over the train (80%) or validation (20%) part of *domain*.

        The split is stratified by class label and seeded with the module-level
        ``SEED``, so repeated calls for the same domain yield the same partition.
        The loader shuffles only when ``train`` is truthy and uses the
        module-level ``BATCH_SIZE``.
        """
        domain_dir = os.path.join(self.data_root, domain)
        dataset = datasets.ImageFolder(domain_dir, transform=self.transform)

        # Split sample indices rather than samples, keeping class proportions.
        all_indices = list(range(len(dataset)))
        labels = list(dataset.targets)
        train_idx, val_idx = train_test_split(
            all_indices,
            test_size=0.2,
            stratify=labels,
            random_state=SEED,
        )

        if train:
            chosen = train_idx
        else:
            chosen = val_idx

        return DataLoader(Subset(dataset, chosen), batch_size=BATCH_SIZE, shuffle=train)

## Vision Transformer Wrapper Class

In [None]:
class ViTModel(nn.Module):
    """Pretrained ViT backbone followed by a dropout-regularised MLP classifier.

    The backbone is loaded from the checkpoint registered under ``model_size``
    in the module-level ``MODELS`` mapping. The feature vector at sequence
    position 0 (the [CLS] token) is passed through a three-layer MLP head that
    produces ``num_classes`` logits.

    Args:
        num_classes: Number of output classes (size of the final linear layer).
        model_size: Key into ``MODELS`` selecting the pretrained checkpoint.

    Raises:
        KeyError: If ``model_size`` is not a key of ``MODELS``.
    """

    def __init__(self, num_classes, model_size="base"):
        super().__init__()

        # This wrapper has the same name as ``transformers.ViTModel`` and
        # therefore rebinds that module-level name to itself. Referring to
        # ``ViTModel.from_pretrained`` here would look the method up on this
        # nn.Module subclass and fail, so the Hugging Face backbone class is
        # imported locally under a distinct alias.
        from transformers import ViTModel as HFViTModel

        # Headless backbone (no classification layer of its own).
        self.model = HFViTModel.from_pretrained(MODELS[model_size])

        # Width of the backbone's token embeddings (e.g. 768 for ViT-Base).
        hidden_size = self.model.config.hidden_size

        # Custom head: two hidden layers with ReLU + dropout, then the logits.
        # The dropout layers are what make Monte Carlo dropout possible later.
        self.classifier_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_size, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Input batch handed straight to the backbone
                (assumed to be preprocessed pixel values — confirm with caller).

        Returns:
            Tensor of logits with one row per input and ``num_classes`` columns.
        """
        outputs = self.model(x)
        # Sequence position 0 holds the [CLS] token representation.
        cls_token_features = outputs.last_hidden_state[:, 0, :]
        return self.classifier_head(cls_token_features)

## Trainer Class

In [None]:
class Trainer:
    """Bundle a model with its optimizer and loss for training and evaluation.

    Evaluation uses Monte Carlo dropout: the dropout layers inside
    ``model.classifier_head`` stay active and the final prediction for each
    sample is the majority vote over several stochastic forward passes.

    Args:
        model: Module to train; it is moved to the module-level ``DEVICE``.
            ``evaluate`` additionally requires a ``classifier_head`` attribute.
        optimizer: Optimizer already constructed over ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
    """

    def __init__(self, model, optimizer, criterion):
        # Remember the device once so batches are always moved to where the
        # model was placed.
        self.device = DEVICE
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.criterion = criterion

    def train(self, dataloader):
        """Run one training epoch and return the mean per-batch loss.

        Args:
            dataloader: Iterable yielding ``(inputs, labels)`` batches.

        Returns:
            Average of ``loss.item()`` over all batches.

        Raises:
            ValueError: If ``dataloader`` yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), labels)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError("Cannot train on an empty dataloader.")
        # Counting batches (instead of len(dataloader)) also supports
        # iterables that do not define __len__.
        return total_loss / num_batches

    def evaluate(self, dataloader, num_passes=20):
        """Return the accuracy of majority-voted Monte Carlo dropout predictions.

        Args:
            dataloader: Iterable yielding ``(inputs, labels)`` batches.
            num_passes: Number of stochastic forward passes per batch.

        Returns:
            Fraction of samples whose majority-vote prediction equals the label.

        Raises:
            ValueError: If ``num_passes`` is smaller than 1 or ``dataloader``
                yields no samples.
        """
        if num_passes < 1:
            raise ValueError("num_passes must be at least 1.")

        self.model.eval()

        # Re-activate only the head's dropout layers so that each forward pass
        # samples a different dropout mask while the rest stays in eval mode.
        dropout_layers = [
            module
            for module in self.model.classifier_head.modules()
            if isinstance(module, nn.Dropout)
        ]
        for layer in dropout_layers:
            layer.train()

        total_correct = 0
        total = 0
        try:
            with torch.no_grad():
                for inputs, labels in dataloader:
                    inputs, labels = inputs.to(self.device), labels.to(self.device)

                    # One row of predicted class ids per stochastic pass.
                    votes = torch.stack(
                        [
                            torch.argmax(self.model(inputs), dim=1)
                            for _ in range(num_passes)
                        ]
                    )
                    # Majority vote across passes for every sample.
                    final_preds = torch.mode(votes, dim=0).values

                    total_correct += (final_preds == labels).sum().item()
                    total += labels.size(0)
        finally:
            # Do not leave the model half in train mode once evaluation ends.
            for layer in dropout_layers:
                layer.eval()

        if total == 0:
            raise ValueError("Cannot evaluate on an empty dataloader.")
        return total_correct / total

# Stochastic Ensemble ViT

## Leave-One-Domain-Out (LODO) Training

In [None]:
# Preprocessing shared by every domain: resize to 224x224, convert to a
# tensor, then normalise with mean 0.5 / std 0.5 (the single-element tuples
# are applied to every channel).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Held-out test accuracy, keyed by the name of the left-out domain.
results_base = {}

for test_domain in DOMAINS:
    print(f"\nTesting on domain: {test_domain}")
    train_domains = [name for name in DOMAINS if name != test_domain]

    # Per-domain loaders: train/val splits of the source domains plus the
    # validation split of the held-out domain.
    dataset = PACSDataset(DATA_ROOT, DOMAINS, transform)
    train_loaders = [
        dataset.get_dataloader(name, train=True) for name in train_domains
    ]
    val_loaders = [
        dataset.get_dataloader(name, train=False) for name in train_domains
    ]
    # NOTE(review): train=False selects only the 20% split of the held-out
    # domain, so the reported test accuracy does not cover the whole domain —
    # confirm this is intended.
    test_loader = dataset.get_dataloader(test_domain, train=False)

    # Merge the source-domain subsets into single train / validation sets.
    train_ds = ConcatDataset([loader.dataset for loader in train_loaders])
    val_ds = ConcatDataset([loader.dataset for loader in val_loaders])

    train_loader = DataLoader(train_ds, shuffle=True, batch_size=BATCH_SIZE)
    val_loader = DataLoader(val_ds, shuffle=False, batch_size=BATCH_SIZE)

    # A fresh model per held-out domain so no weights carry over between runs.
    model_base = ViTModel(NUM_CLASSES, model_size="base")
    optimizer = optim.Adam(model_base.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(model_base, optimizer, criterion)

    # Train on the source domains, reporting source-domain validation accuracy.
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
        train_loss_base = trainer.train(train_loader)
        val_acc_base = trainer.evaluate(val_loader)
        print(f"Train Loss: {train_loss_base:.4f} | Val Acc: {val_acc_base:.4f}")

    # Evaluate on the unseen domain and record the result.
    test_acc_base = trainer.evaluate(test_loader)
    results_base[test_domain] = test_acc_base
    print(f"Test Accuracy on {test_domain}: {test_acc_base:.4f}")