### Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import timm
from model import LatenViTSmall
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image


### Configuration

In [None]:

class Config:
    """Hyper-parameters, paths and preprocessing shared by the cells below."""

    # --- Backbone / LatenViTSmall settings ---------------------------------
    MODEL_NAME    = 'vit_small_patch16_224.augreg_in21k'
    NUM_CLASSES   = 7       # number of target classes for the classifier
    # The next four values are forwarded to LatenViTSmall(nrepeat=, start=,
    # end=, num_layers=); their exact meaning lives in model.py — TODO confirm.
    NREPEAT       = 3
    START_BLOCK   = 8
    END_BLOCK     = 10
    NUM_LAYERS    = 12

    # --- Optimisation -------------------------------------------------------
    BATCH_SIZE    = 32
    NUM_EPOCHS    = 5
    LEARNING_RATE = 1e-4
    DEVICE        = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Data ---------------------------------------------------------------
    # Relative paths resolve against the current working directory; set
    # PACS_DATA_ROOT to point elsewhere without editing the notebook.
    DATA_ROOT     = os.environ.get("PACS_DATA_ROOT", "../../../../pacs_data/pacs_data")
    # Each entry must be a subdirectory of DATA_ROOT (PACSDataset joins them).
    DOMAINS       = ["art_painting", "cartoon", "photo", "sketch"]

    # Resize to 224x224, convert to a [0, 1] tensor, then map to [-1, 1].
    # Images are RGB (PACSDataset converts them), so give one mean/std per
    # channel as Normalize documents instead of relying on broadcasting.
    TRANSFORM = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])


### PACS Dataset Class

In [None]:
class PACSDataset(Dataset):
    """Image-classification dataset for a single PACS domain.

    Expects the layout ``<root_dir>/<domain>/<class_name>/<image file>``.
    Class indices are assigned from the sorted class-directory names, so
    domains with the same class folders share the same label mapping.

    Args:
        root_dir: directory containing one subdirectory per domain.
        domain: name of the domain subdirectory to load.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, root_dir, domain, transform=None):
        self.root_dir  = os.path.join(root_dir, domain)
        self.transform = transform
        # Only visible subdirectories are classes; a stray file such as
        # .DS_Store would otherwise be taken as a class and shift the labels.
        self.classes = sorted(
            entry
            for entry in os.listdir(self.root_dir)
            if not entry.startswith(".")
            and os.path.isdir(os.path.join(self.root_dir, entry))
        )
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images = []
        self.labels = []

        for cls_name in self.classes:
            cls_dir = os.path.join(self.root_dir, cls_name)
            # os.listdir order is arbitrary; sort for a reproducible sample order.
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                # Skip hidden files and nested directories, which are not images.
                if img_name.startswith(".") or not os.path.isfile(img_path):
                    continue
                self.images.append(img_path)
                self.labels.append(self.class_to_idx[cls_name])

    def __len__(self):
        """Return the number of images found for this domain."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``; image is transformed if a transform is set."""
        img_path = self.images[idx]
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy that stays valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)
        return image, label


### Model Setup

In [None]:
def setup_model():
    """Build a LatenViTSmall around a pretrained timm backbone and move it to Config.DEVICE.

    The backbone is created with a ``Config.NUM_CLASSES``-way classification
    head. Without ``num_classes``, timm keeps the checkpoint's own head
    (presumably the ImageNet-21k label space for ``*.augreg_in21k`` — confirm),
    so the network would emit far more logits than there are dataset classes.

    NOTE(review): assumes LatenViTSmall reuses ``base_model``'s head rather
    than building its own; confirm in model.py.

    Returns:
        The wrapped model, already on ``Config.DEVICE``.
    """
    base_model = timm.create_model(
        Config.MODEL_NAME,
        pretrained=True,
        num_classes=Config.NUM_CLASSES,
    )
    model = LatenViTSmall(
        model     = base_model,
        nrepeat   = Config.NREPEAT,
        start     = Config.START_BLOCK,
        end       = Config.END_BLOCK,
        num_layers= Config.NUM_LAYERS
    )
    return model.to(Config.DEVICE)


### Training Function

In [None]:
def train_epoch(model, dataloader, criterion, optimizer, device=None):
    """Run one optimisation pass over ``dataloader`` and return training metrics.

    Args:
        model: network to train; switched to training mode here.
        dataloader: iterable of ``(images, labels)`` tensor batches.
        criterion: loss function. Assumed to return the *mean* loss over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to; defaults to ``Config.DEVICE``.

    Returns:
        ``(epoch_loss, epoch_acc)``: loss averaged over samples and top-1
        accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = Config.DEVICE
    model.train()
    loss_sum = 0.0
    correct  = 0
    total    = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss    = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch's mean loss by its size so a short final batch
        # does not count as much as a full one in the epoch average.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct  += predicted.eq(labels).sum().item()
        total    += batch_size

    if total == 0:
        raise ValueError("train_epoch: dataloader yielded no samples")
    epoch_loss = loss_sum / total
    epoch_acc  = 100.0 * correct / total
    return epoch_loss, epoch_acc


### Evaluation Function

In [None]:
@torch.no_grad()
def evaluate(model, dataloader, device=None):
    """Return the top-1 accuracy (in percent) of ``model`` on ``dataloader``.

    Runs without gradient tracking and switches the model to eval mode; the
    model is left in eval mode on return.

    Args:
        model: network to evaluate.
        dataloader: iterable of ``(images, labels)`` tensor batches.
        device: device each batch is moved to; defaults to ``Config.DEVICE``.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    if device is None:
        device = Config.DEVICE
    model.eval()
    correct = 0
    total   = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs        = model(images)

        _, predicted = outputs.max(1)
        total       += labels.size(0)
        correct     += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("evaluate: dataloader yielded no samples")
    return 100.0 * correct / total


### Main

In [None]:
# Model, loss and optimizer.
model = setup_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

# Single-source domain generalisation: fit on one domain, test on the rest.
SOURCE_DOMAIN = "photo"

train_ds = PACSDataset(Config.DATA_ROOT, SOURCE_DOMAIN, Config.TRANSFORM)
train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)

# One unshuffled loader per held-out domain, keyed by domain name.
val_loaders = {}
for domain in Config.DOMAINS:
    if domain != SOURCE_DOMAIN:
        ds = PACSDataset(Config.DATA_ROOT, domain, Config.TRANSFORM)
        val_loaders[domain] = DataLoader(ds, batch_size=Config.BATCH_SIZE, shuffle=False)

for epoch in range(1, Config.NUM_EPOCHS + 1):
    loss, acc = train_epoch(model, train_loader, criterion, optimizer)
    print(
        f"[Epoch {epoch}/{Config.NUM_EPOCHS}] "
        f"Train Loss: {loss:.4f}, Train Acc: {acc:.2f}%"
    )

    # Report accuracy on every held-out domain after each epoch.
    for domain, loader in val_loaders.items():
        val_acc = evaluate(model, loader)
        print(f"  {domain:>14} Acc: {val_acc:.2f}%")
    print("-" * 60)