### Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import timm
from model import LatenViTtiny
import os
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from PIL import Image
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


### Configuration

In [2]:

class Config:
    MODEL_NAME   = 'tiny_vit_21m_224.dist_in22k_ft_in1k'
    NUM_CLASSES  = 7      
    NREPEAT      = 2
    stage = 2
    
    BATCH_SIZE   = 128
    NUM_EPOCHS   = 5
    LEARNING_RATE= 1e-4
    DEVICE       = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    DATA_ROOT    = "../../../pacs_data/pacs_data"
    DOMAINS      = ["art_painting", "cartoon", "photo", "sketch"]
    
    TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])


### PACS Dataset Class

In [3]:
class PACSDataset(Dataset):
    def __init__(self, root_dir, domain, transform=None):
        self.root_dir    = os.path.join(root_dir, domain)
        self.transform   = transform
        self.classes     = sorted(os.listdir(self.root_dir))
        self.class_to_idx= {cls_name: i for i, cls_name in enumerate(self.classes)}
        self.images      = []
        self.labels      = []
        
        for cls_name in self.classes:
            cls_dir = os.path.join(self.root_dir, cls_name)
            for img_name in os.listdir(cls_dir):
                self.images.append(os.path.join(cls_dir, img_name))
                self.labels.append(self.class_to_idx[cls_name])
                
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_path = self.images[idx]
        image    = Image.open(img_path).convert('RGB')
        label    = self.labels[idx]
        
        if self.transform:
            image = self.transform(image)
        return image, label


### Model Setup

In [4]:
def setup_model():
    base_model = timm.create_model(Config.MODEL_NAME, pretrained=True)
    model = LatenViTtiny(
        model     = base_model,
        nrepeat   = Config.NREPEAT,
        stage = Config.stage
    )
    return model.to(Config.DEVICE)

def setup_baseline_model():
    base_model = timm.create_model(Config.MODEL_NAME, pretrained=True)
    return base_model.to(Config.DEVICE)


### Training Function

In [5]:
def train_epoch(model, dataloader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct      = 0
    total        = 0
    
    for images, labels in dataloader:
        images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss    = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted  = outputs.max(1)
        total        += labels.size(0)
        correct      += predicted.eq(labels).sum().item()
        
    epoch_loss = running_loss / len(dataloader)
    epoch_acc  = 100.0 * correct / total
    return epoch_loss, epoch_acc


### Training Function

In [6]:
@torch.no_grad()
def evaluate(model, dataloader):
    model.eval()
    correct = 0
    total   = 0
    
    for images, labels in dataloader:
        images, labels = images.to(Config.DEVICE), labels.to(Config.DEVICE)
        outputs        = model(images)
        
        _, predicted = outputs.max(1)
        total       += labels.size(0)
        correct     += predicted.eq(labels).sum().item()
        
    return 100.0 * correct / total


### Baseline -CotFormer

In [None]:

train_loaders = []
test_loaders  = []
for domain in Config.DOMAINS:
    ds_train = PACSDataset(Config.DATA_ROOT, domain, Config.TRANSFORM)
    ds_test  = PACSDataset(Config.DATA_ROOT, domain, Config.TRANSFORM)
    train_loaders.append(DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True))
    test_loaders .append(DataLoader(ds_test,  batch_size=Config.BATCH_SIZE, shuffle=False))


full_train_ds = ConcatDataset([dl.dataset for dl in train_loaders])
full_test_ds  = ConcatDataset([dl.dataset  for dl in test_loaders ])
full_train_loader = DataLoader(full_train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)
full_test_loader  = DataLoader(full_test_ds,  batch_size=Config.BATCH_SIZE, shuffle=False)

model     = setup_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

for epoch in range(1, Config.NUM_EPOCHS + 1):
    loss, acc = train_epoch(model, full_train_loader, criterion, optimizer)
    print(f"[Epoch {epoch}/{Config.NUM_EPOCHS}] Train Loss: {loss:.4f}, Train Acc: {acc:.2f}%")

test_acc = evaluate(model, full_test_loader)
print(f"Baseline (all domains) Test Accuracy: {test_acc:.2f}%")

print("\n[Per-Domain Evaluation]")
domain_accuracies = OrderedDict()
for domain, loader in zip(Config.DOMAINS, test_loaders):
    acc = evaluate(model, loader)
    domain_accuracies[domain] = acc
    print(f"  {domain:>12}: {acc:.2f}%")

[Epoch 1/5] Train Loss: 3.7127, Train Acc: 12.86%
[Epoch 2/5] Train Loss: 2.5703, Train Acc: 17.26%


In [None]:

transform = Config.TRANSFORM  

lodo_results = OrderedDict()

for test_domain in Config.DOMAINS:
    print(f"\n=== LODO: Held-Out Domain = {test_domain} ===")
    train_loaders = []
    for d in Config.DOMAINS:
        if d == test_domain:
            continue
        ds_train = PACSDataset(Config.DATA_ROOT, d, transform)
        train_loaders.append(DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True))

    train_ds = ConcatDataset([dl.dataset for dl in train_loaders])
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)

    ds_test    = PACSDataset(Config.DATA_ROOT, test_domain, transform)
    test_loader = DataLoader(ds_test, batch_size=Config.BATCH_SIZE, shuffle=False)

    model     = setup_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

    for epoch in range(1, Config.NUM_EPOCHS + 1):
        loss, acc = train_epoch(model, train_loader, criterion, optimizer)
        print(f"[Epoch {epoch}/{Config.NUM_EPOCHS}] Train Loss: {loss:.4f}, Train Acc: {acc:.2f}%")

    test_acc = evaluate(model, test_loader)
    lodo_results[test_domain] = test_acc
    print(f"--> Test Accuracy on {test_domain}: {test_acc:.2f}%")

print("\n=== LODO Summary ===")
for domain, acc in lodo_results.items():
    print(f"{domain:>14}: {acc:.2f}%")


### Baseline Vanilla

In [7]:

train_loaders = []
test_loaders  = []
for domain in Config.DOMAINS:
    ds_train = PACSDataset(Config.DATA_ROOT, domain, Config.TRANSFORM)
    ds_test  = PACSDataset(Config.DATA_ROOT, domain, Config.TRANSFORM)
    train_loaders.append(DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True))
    test_loaders .append(DataLoader(ds_test,  batch_size=Config.BATCH_SIZE, shuffle=False))


full_train_ds = ConcatDataset([dl.dataset for dl in train_loaders])
full_test_ds  = ConcatDataset([dl.dataset  for dl in test_loaders ])
full_train_loader = DataLoader(full_train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)
full_test_loader  = DataLoader(full_test_ds,  batch_size=Config.BATCH_SIZE, shuffle=False)

model     = setup_baseline_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

for epoch in range(1, Config.NUM_EPOCHS + 1):
    loss, acc = train_epoch(model, full_train_loader, criterion, optimizer)
    print(f"[Epoch {epoch}/{Config.NUM_EPOCHS}] Train Loss: {loss:.4f}, Train Acc: {acc:.2f}%")

test_acc = evaluate(model, full_test_loader)
print(f"Baseline (all domains) Test Accuracy: {test_acc:.2f}%")

print("\n[Per-Domain Evaluation]")
domain_accuracies = OrderedDict()
for domain, loader in zip(Config.DOMAINS, test_loaders):
    acc = evaluate(model, loader)
    domain_accuracies[domain] = acc
    print(f"  {domain:>12}: {acc:.2f}%")

[Epoch 1/5] Train Loss: 3.0058, Train Acc: 46.13%
[Epoch 2/5] Train Loss: 0.5827, Train Acc: 87.52%
[Epoch 3/5] Train Loss: 0.2492, Train Acc: 93.63%
[Epoch 4/5] Train Loss: 0.1634, Train Acc: 95.99%
[Epoch 5/5] Train Loss: 0.1144, Train Acc: 96.85%
Baseline (all domains) Test Accuracy: 99.46%

[Per-Domain Evaluation]
  art_painting: 99.90%
       cartoon: 99.49%
         photo: 99.94%
        sketch: 99.01%


In [8]:

transform = Config.TRANSFORM  

lodo_results = OrderedDict()

for test_domain in Config.DOMAINS:
    print(f"\n=== LODO: Held-Out Domain = {test_domain} ===")
    train_loaders = []
    for d in Config.DOMAINS:
        if d == test_domain:
            continue
        ds_train = PACSDataset(Config.DATA_ROOT, d, transform)
        train_loaders.append(DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True))

    train_ds = ConcatDataset([dl.dataset for dl in train_loaders])
    train_loader = DataLoader(train_ds, batch_size=Config.BATCH_SIZE, shuffle=True)

    ds_test    = PACSDataset(Config.DATA_ROOT, test_domain, transform)
    test_loader = DataLoader(ds_test, batch_size=Config.BATCH_SIZE, shuffle=False)

    model     = setup_baseline_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=Config.LEARNING_RATE)

    for epoch in range(1, Config.NUM_EPOCHS + 1):
        loss, acc = train_epoch(model, train_loader, criterion, optimizer)
        print(f"[Epoch {epoch}/{Config.NUM_EPOCHS}] Train Loss: {loss:.4f}, Train Acc: {acc:.2f}%")

    test_acc = evaluate(model, test_loader)
    lodo_results[test_domain] = test_acc
    print(f"--> Test Accuracy on {test_domain}: {test_acc:.2f}%")

print("\n=== LODO Summary ===")
for domain, acc in lodo_results.items():
    print(f"{domain:>14}: {acc:.2f}%")



=== LODO: Held-Out Domain = art_painting ===
[Epoch 1/5] Train Loss: 3.2942, Train Acc: 41.70%
[Epoch 2/5] Train Loss: 0.7274, Train Acc: 84.53%
[Epoch 3/5] Train Loss: 0.3054, Train Acc: 92.60%
[Epoch 4/5] Train Loss: 0.2077, Train Acc: 94.86%
[Epoch 5/5] Train Loss: 0.1318, Train Acc: 96.44%
--> Test Accuracy on art_painting: 85.50%

=== LODO: Held-Out Domain = cartoon ===
[Epoch 1/5] Train Loss: 3.4373, Train Acc: 40.77%
[Epoch 2/5] Train Loss: 0.7122, Train Acc: 85.04%
[Epoch 3/5] Train Loss: 0.2976, Train Acc: 92.73%
[Epoch 4/5] Train Loss: 0.1706, Train Acc: 95.49%
[Epoch 5/5] Train Loss: 0.1006, Train Acc: 97.08%
--> Test Accuracy on cartoon: 78.50%

=== LODO: Held-Out Domain = photo ===
[Epoch 1/5] Train Loss: 3.4844, Train Acc: 38.53%
[Epoch 2/5] Train Loss: 0.9777, Train Acc: 80.94%
[Epoch 3/5] Train Loss: 0.4085, Train Acc: 90.43%
[Epoch 4/5] Train Loss: 0.2471, Train Acc: 93.47%
[Epoch 5/5] Train Loss: 0.2597, Train Acc: 95.42%
--> Test Accuracy on photo: 97.66%

=== LODO: