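## ERM with ResNet-50 — No-Adaptation Baseline

#### Source Domains: PACS Art Painting, Cartoon, Photo
#### Target Domain: PACS Sketch

An ImageNet-pretrained ResNet-50 is used as a frozen feature extractor. Only a new 7-way linear classification head is trained, with empirical risk minimization (ERM) on the pooled source domains. The model is then evaluated directly on the held-out Sketch domain, with no domain adaptation. In the recorded run the model reaches 93.39% on the source (training) images but only 55.76% on Sketch. That gap is the domain shift this baseline is meant to expose.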

In [1]:
import torch
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

# Seed PyTorch's default generator so the randomly initialised classifier head
# and the shuffled training order are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer CUDA, then Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check
# (the torch.mps.is_available alias is not present in every PyTorch release).
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [2]:
# Number of PACS categories; must equal len(pacs_classes) from the data cell.
NUM_CLASSES = 7
# Learning rate for the linear head (the only trainable parameters).
LEARNING_RATE = 5e-5

# ImageNet-pretrained ResNet-50 used as a frozen feature extractor (linear probe).
model = torchvision.models.resnet50(weights='IMAGENET1K_V1')

for param in model.parameters():
    param.requires_grad = False
# The replacement head is created *after* freezing, so its parameters keep
# requires_grad=True and are the only ones the optimizer updates.
# NOTE(review): requires_grad=False does not freeze BatchNorm running statistics;
# train_epoch calls model.train(), so the backbone's BN running mean/var are
# still updated from source batches. Put the backbone in eval() if a pure
# linear probe is intended.
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)

optim = torch.optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)

criterion = torch.nn.CrossEntropyLoss()

In [None]:
# ImageNet normalisation statistics, matching the pretrained ResNet-50 weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Deterministic preprocessing only (no augmentation) for this baseline.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std)
])

PACS_ROOT = "../data/pacs_data/pacs_data"
BATCH_SIZE = 32
NUM_WORKERS = 2
# Pinned host memory exists to speed up host-to-CUDA copies; only request it on
# CUDA (elsewhere it brings no established benefit and PyTorch may warn).
PIN_MEMORY = device == "cuda"

art_dir = f"{PACS_ROOT}/art_painting"
cartoon_dir = f"{PACS_ROOT}/cartoon"
photo_dir = f"{PACS_ROOT}/photo"
sketch_dir = f"{PACS_ROOT}/sketch"

art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)

# ImageFolder assigns label indices per root directory. Concatenating domains,
# and reusing sketch's class list for all of them, is only valid if every domain
# maps the same class names to the same indices.
for domain_name, domain_dataset in [
    ("art_painting", art_dataset),
    ("cartoon", cartoon_dataset),
    ("photo", photo_dataset),
]:
    assert domain_dataset.class_to_idx == sketch_dataset.class_to_idx, (
        f"class folders in {domain_name} do not match sketch: "
        f"{domain_dataset.class_to_idx} vs {sketch_dataset.class_to_idx}"
    )


def make_loader(dataset, shuffle):
    """Build a DataLoader for ``dataset`` using the shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
    )


# Per-domain loaders are only used for evaluation, where order is irrelevant.
art_loader = make_loader(art_dataset, shuffle=False)
cartoon_loader = make_loader(cartoon_dataset, shuffle=False)
photo_loader = make_loader(photo_dataset, shuffle=False)
sketch_loader = make_loader(sketch_dataset, shuffle=False)  # held-out target (test) domain

# Pooled source domains used for ERM training; shuffled so batches mix domains.
source_dataset = ConcatDataset([art_dataset, cartoon_dataset, photo_dataset])
source_loader = make_loader(source_dataset, shuffle=True)

pacs_classes = sketch_dataset.classes

In [4]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch over ``loader`` and return ``(avg_loss, accuracy)``.

    Args:
        model: Module to train; it is moved to ``device`` and put in train mode.
        loader: Iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        criterion: Loss called as ``criterion(outputs, labels)``. Its value is
            assumed to be the batch *mean* (the default for CrossEntropyLoss).
        optimizer: Optimizer over the trainable parameters of ``model``.
        device: Device that the model and every batch are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy)``: the per-sample mean loss over the epoch
        and the top-1 training accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.to(device)
    # train() also switches BatchNorm layers to training mode, so their running
    # statistics are updated from these batches even if their weights are frozen.
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Re-weight the batch-mean loss by batch size so the epoch average is
        # per-sample even when the last batch is smaller.
        running_loss += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [5]:
def evaluate(model, test_loader, device):
    """Return the top-1 accuracy (percent) of ``model`` over ``test_loader``.

    The model is moved to ``device`` and left in eval mode. Gradients are
    disabled during inference.

    Args:
        model: Classifier returning per-class scores with classes on dim 1.
        test_loader: Iterable of ``(images, labels)`` batches.
        device: Device that the model and every batch are moved to.

    Returns:
        Accuracy as a percentage in ``[0, 100]`` over every sample yielded.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            # Call the module rather than .forward() so registered hooks run.
            logits = model(images)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("evaluate: test_loader yielded no samples")
    accuracy = correct / total * 100
    return accuracy

In [6]:
# ERM: fit the classification head on the pooled source domains.
num_epochs = 50
train_acc = 0  # defined up front so it exists even when num_epochs is 0
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(
        model, source_loader, criterion, optim, device
    )
    print("Epoch {}.".format(epoch + 1))
    print("Loss:  {:.4f} | Train Accuracy:  {:.2f}%".format(train_loss, train_acc))

Epoch 1.
Loss:  1.6311 | Train Accuracy:  47.05%
Epoch 2.
Loss:  1.1899 | Train Accuracy:  76.43%
Epoch 3.
Loss:  0.9413 | Train Accuracy:  82.45%
Epoch 4.
Loss:  0.7865 | Train Accuracy:  84.53%
Epoch 5.
Loss:  0.6930 | Train Accuracy:  85.73%
Epoch 6.
Loss:  0.6313 | Train Accuracy:  86.03%
Epoch 7.
Loss:  0.5707 | Train Accuracy:  87.33%
Epoch 8.
Loss:  0.5284 | Train Accuracy:  87.41%
Epoch 9.
Loss:  0.4976 | Train Accuracy:  88.06%
Epoch 10.
Loss:  0.4753 | Train Accuracy:  88.34%
Epoch 11.
Loss:  0.4553 | Train Accuracy:  88.12%
Epoch 12.
Loss:  0.4336 | Train Accuracy:  88.88%
Epoch 13.
Loss:  0.4127 | Train Accuracy:  89.26%
Epoch 14.
Loss:  0.4068 | Train Accuracy:  88.98%
Epoch 15.
Loss:  0.3941 | Train Accuracy:  89.15%
Epoch 16.
Loss:  0.3860 | Train Accuracy:  88.98%
Epoch 17.
Loss:  0.3727 | Train Accuracy:  89.49%
Epoch 18.
Loss:  0.3610 | Train Accuracy:  89.97%
Epoch 19.
Loss:  0.3522 | Train Accuracy:  90.22%
Epoch 20.
Loss:  0.3483 | Train Accuracy:  90.22%
Epoch 21.

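### Evaluation on Source Domains

These are the same images the classifier head was trained on. The numbers below therefore measure fit to the training data, not held-out performance.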

In [7]:
# Per-domain accuracy on the three source (training) domains.
art_accuracy = evaluate(model, art_loader, device)
print(f"Art Accuracy: {art_accuracy:.2f}%")

cartoon_accuracy = evaluate(model, cartoon_loader, device)
print(f"Cartoon Accuracy: {cartoon_accuracy:.2f}%")

photo_accuracy = evaluate(model, photo_loader, device)
print(f"Photo Accuracy: {photo_accuracy:.2f}%")

# The pooled source set is exactly ConcatDataset([art, cartoon, photo]), so its
# accuracy is the size-weighted mean of the per-domain accuracies. Computing it
# this way avoids a fourth full inference pass over every source image.
source_sizes = (len(art_dataset), len(cartoon_dataset), len(photo_dataset))
domain_accuracies = (art_accuracy, cartoon_accuracy, photo_accuracy)
source_accuracy = (
    sum(acc * size for acc, size in zip(domain_accuracies, source_sizes))
    / sum(source_sizes)
)
print(f"\nAll Source Domains Accuracy: {source_accuracy:.2f}%")

Art Accuracy: 91.99%
Cartoon Accuracy: 90.83%
Photo Accuracy: 98.68%

All Source Domains Accuracy: 93.39%


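### Evaluation on Target (Test) Domain

Sketch images are never seen during training. The drop relative to the source accuracies above measures the effect of the domain shift.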

In [8]:
# Held-out target domain: evaluated with no adaptation.
sketch_accuracy = evaluate(model=model, test_loader=sketch_loader, device=device)
print("Sketch Accuracy: {:.2f}%".format(sketch_accuracy))

Sketch Accuracy: 55.76%
