# Imports and Configs

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import datasets, transforms
from transformers import ViTImageProcessor, ViTForImageClassification  # Key: Processor for proper ViT prep!
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import random

# ---- Experiment configuration: every tunable knob lives here ----
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

SEED = 42
BATCH_SIZE = 24
NUM_EPOCHS = 5
NUM_CLASSES = 7  # number of object categories in PACS
DATA_ROOT = "../../../pacs_data/pacs_data"  # relative path to the PACS root; edit for your machine
DOMAINS = ["art_painting", "cartoon", "photo", "sketch"]  # one subfolder per domain under DATA_ROOT

# Hugging Face checkpoint IDs keyed by ViT size.
# "base" is Google's ImageNet-21k checkpoint (per its name).
# NOTE(review): confirm the pretraining data of the WinKawaks small/tiny conversions.
MODELS = dict(
    base="google/vit-base-patch16-224-in21k",
    small="WinKawaks/vit-small-patch16-224",
    tiny="WinKawaks/vit-tiny-patch16-224",
)

# Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.
# NOTE: this does not force deterministic cuDNN/cuBLAS kernels; see
# torch.use_deterministic_algorithms if bit-exact reruns are required.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Dataset

In [None]:
class PACSDataset:
    """Builds per-domain PACS DataLoaders that yield ViT-ready batches.

    Each domain folder under ``data_root`` is read with torchvision's
    ``ImageFolder`` (one subfolder per class), split 80/20 stratified by
    label, and preprocessed batch-by-batch with ``processor`` at collate time.
    """

    def __init__(self, data_root: str, domains: list, processor: ViTImageProcessor):
        """
        Initializes the PACS multi-domain dataset.
        - data_root: Base path containing one folder per domain.
        - domains: List of domain names (e.g., ['photo', 'sketch']).
        - processor: ViTImageProcessor used to resize and normalize images; the
          normalization mean/std come from the processor's own config.
        """
        self.data_root = data_root
        self.domains = domains
        self.processor = processor  # applied per batch in _collate

    def _collate(self, batch):
        """Turn a list of ``(PIL image, class index)`` pairs into a model batch.

        Implemented as a method rather than a closure inside ``get_dataloader``
        so DataLoader workers started with the ``spawn``/``forkserver`` start
        methods can pickle it (local functions are not picklable).

        Returns:
            dict with ``pixel_values`` (tensor produced by the processor) and
            ``labels`` (int64 tensor of class indices).
        """
        images, labels = zip(*batch)
        inputs = self.processor(list(images), return_tensors="pt")
        return {
            "pixel_values": inputs["pixel_values"],
            "labels": torch.tensor(labels, dtype=torch.long),
        }

    def get_dataloader(
        self,
        domain: str,
        train: bool = True,
        batch_size: int = BATCH_SIZE,
        num_workers: int = 4,
    ) -> DataLoader:
        """
        Creates a DataLoader for one domain's train or validation split.
        - domain: Name of the folder under data_root to load.
        - train: True -> shuffled 80% train split; False -> unshuffled 20% val split.
        - batch_size: Samples per batch.
        - num_workers: Loader worker processes (0 loads in the main process, which
          avoids pickling entirely, e.g. in notebooks on Windows/macOS).
        The split is stratified by class label with random_state=SEED, so the
        train and val loaders for the same domain cover disjoint index sets.
        """
        full_dataset = datasets.ImageFolder(
            root=os.path.join(self.data_root, domain),
            transform=None,  # images stay PIL; preprocessing happens in _collate
        )

        # ImageFolder.targets already holds one label per sample, in index order.
        train_idx, val_idx = train_test_split(
            list(range(len(full_dataset))),
            test_size=0.2,
            stratify=full_dataset.targets,
            random_state=SEED,
        )
        subset = Subset(full_dataset, train_idx if train else val_idx)

        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=train,
            num_workers=num_workers,
            collate_fn=self._collate,
            # Pinned memory only helps host->GPU copies; skip it (and its warning) on CPU.
            pin_memory=torch.cuda.is_available(),
        )