In [1]:
import os

import torch
import torchvision
from torchvision import datasets
from torch.utils.data import DataLoader, ConcatDataset

# Seed torch's global RNG before anything random happens (fresh classifier heads,
# shuffled training order, DataLoader worker seeds) so Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Pick the compute backend: CUDA GPU first, then Apple-silicon MPS, then CPU.
# torch.backends.mps.is_available() is the documented MPS availability check.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [2]:
def prep_fc(model, num_classes=7):
    """Freeze a pretrained classifier and give it a fresh, trainable output head.

    Every existing parameter gets ``requires_grad = False``. ``model.fc`` is then
    replaced by a new ``torch.nn.Linear`` with the same input width and
    ``num_classes`` outputs. Its parameters are trainable because they are created
    after the freeze loop. The model is modified in place and also returned.

    Args:
        model: a network exposing its final layer as ``model.fc`` (e.g. a
            torchvision ResNet).
        num_classes: number of output logits. Defaults to 7, the number of PACS
            categories (dog, elephant, giraffe, guitar, horse, house, person).

    Returns:
        The same ``model`` object.

    Note:
        ``requires_grad = False`` only stops gradient updates. BatchNorm layers
        still update their running statistics whenever the model is in train mode.
    """
    for param in model.parameters():
        param.requires_grad = False
    # The head must match the dataset's label count; extra logits are never a
    # training target but can still win the argmax at evaluation time.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model

def _frozen_resnet50_with_optimizer():
    """Build an ImageNet-pretrained ResNet-50 prepared by ``prep_fc`` plus an Adam
    optimizer that only sees the new ``fc`` head's parameters."""
    net = prep_fc(torchvision.models.resnet50(weights='IMAGENET1K_V1'))
    head_optimizer = torch.optim.Adam(net.fc.parameters(), lr=0.001)
    return net, head_optimizer


# One independent model per training regime: photo only, then leave-one-domain-out
# (the suffix names the domain that is excluded from training).
model_1_photo_only, optim_m1 = _frozen_resnet50_with_optimizer()
model_2_no_art, optim_m2 = _frozen_resnet50_with_optimizer()
model_3_no_cartoon, optim_m3 = _frozen_resnet50_with_optimizer()
model_4_no_photo, optim_m4 = _frozen_resnet50_with_optimizer()
model_5_no_sketch, optim_m5 = _frozen_resnet50_with_optimizer()

# Shared loss for training every head.
criterion = torch.nn.CrossEntropyLoss()

In [3]:
# Standard ImageNet per-channel mean and standard deviation, the normalisation
# expected by the IMAGENET1K_V1 pretrained weights.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std)
])

# Root of the extracted PACS data. Each domain is an ImageFolder directory
# (one sub-folder per class).
PACS_ROOT = "data/pacs_data/pacs_data"
art_dir = f"{PACS_ROOT}/art_painting"
cartoon_dir = f"{PACS_ROOT}/cartoon"
photo_dir = f"{PACS_ROOT}/photo"
sketch_dir = f"{PACS_ROOT}/sketch"

photo_dataset = datasets.ImageFolder(root=photo_dir, transform=transform)
art_dataset = datasets.ImageFolder(root=art_dir, transform=transform)
cartoon_dataset = datasets.ImageFolder(root=cartoon_dir, transform=transform)
sketch_dataset = datasets.ImageFolder(root=sketch_dir, transform=transform)

# Domains are mixed with ConcatDataset and every evaluation uses photo's class
# list, so all domains must map the same class names to the same integer labels.
for _domain_dataset in (art_dataset, cartoon_dataset, sketch_dataset):
    assert _domain_dataset.class_to_idx == photo_dataset.class_to_idx, (
        f"class mapping of {_domain_dataset.root} differs from {photo_dataset.root}"
    )

def get_loader(dataset, shuffle, batch_size=64, num_workers=4, pin_memory=True):
    """Wrap ``dataset`` in a DataLoader with the notebook's standard settings.

    Args:
        dataset: a map-style dataset (e.g. an ImageFolder or a ConcatDataset).
        shuffle: reshuffle every epoch; True for training, False for evaluation.
        batch_size: samples per batch.
        num_workers: worker processes used for loading; 0 loads in the main process.
        pin_memory: page-lock host batches. This speeds up host-to-GPU copies and
            only helps when training on CUDA.

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

# Model 1: train on photos only; evaluate on the three non-photo domains combined.
train_1_loader = get_loader(photo_dataset, shuffle=True)
test_1_loader = get_loader(ConcatDataset([art_dataset, cartoon_dataset, sketch_dataset]), shuffle=False)


def _leave_one_out(train_domains, held_out):
    """Return (shuffled loader over the concatenated training domains,
    unshuffled loader over the held-out domain)."""
    train_loader = get_loader(ConcatDataset(train_domains), shuffle=True)
    test_loader = get_loader(held_out, shuffle=False)
    return train_loader, test_loader


# Models 2-5: leave-one-domain-out; train on three domains, test on the fourth.
train_2_loader, test_2_loader = _leave_one_out([photo_dataset, cartoon_dataset, sketch_dataset], art_dataset)
train_3_loader, test_3_loader = _leave_one_out([photo_dataset, art_dataset, sketch_dataset], cartoon_dataset)
train_4_loader, test_4_loader = _leave_one_out([sketch_dataset, art_dataset, cartoon_dataset], photo_dataset)
train_5_loader, test_5_loader = _leave_one_out([photo_dataset, art_dataset, cartoon_dataset], sketch_dataset)

# Class names in label-index order. This assumes every domain uses the same class
# folders as photo.
PACS_classes = photo_dataset.classes

In [4]:
def train_epoch(model, loader, criterion, optimizer, device, freeze_bn=False):
    """Run one training epoch over ``loader`` and return its average loss and accuracy.

    Args:
        model: the network to train; moved to ``device`` and put in train mode.
        loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss function. This assumes it returns the batch-mean loss (the
            CrossEntropyLoss default), since it is re-weighted by batch size below.
        optimizer: optimizer over the parameters that should be updated.
        device: device string/object to move the model and each batch to.
        freeze_bn: if True, keep every BatchNorm layer in eval mode so its running
            statistics are not updated. Freezing parameters with
            ``requires_grad=False`` does not stop those updates in train mode.

    Returns:
        ``(avg_loss, accuracy)``: mean per-sample loss and accuracy in percent.
        Both come from each batch's forward pass, i.e. before that batch's optimizer
        step. Both are ``float('nan')`` if ``loader`` yields no samples.
    """
    model.to(device)
    model.train()
    if freeze_bn:
        batch_norm_types = (
            torch.nn.BatchNorm1d,
            torch.nn.BatchNorm2d,
            torch.nn.BatchNorm3d,
            torch.nn.SyncBatchNorm,
        )
        for module in model.modules():
            if isinstance(module, batch_norm_types):
                module.eval()

    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Undo the batch-mean so the epoch average is per sample, not per batch.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        return float("nan"), float("nan")

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [5]:
def evaluate(model, test_loader, loss_fn, classes, device):
    """Compute overall and per-class top-1 accuracy of ``model`` on ``test_loader``.

    Args:
        model: classifier whose output has one logit per class along dim 1.
        test_loader: iterable of ``(images, labels)`` batches with integer labels
            in ``range(len(classes))``.
        loss_fn: accepted for call-site compatibility; not used for the returned
            metrics.
        classes: class names indexed by integer label (e.g. ``ImageFolder.classes``).
        device: device to run the model on.

    Returns:
        ``(accuracy, class_accuracies)``:
        - ``accuracy`` is the overall accuracy in percent.
        - ``class_accuracies`` maps each class name (in ``classes`` order) to its
          accuracy in percent.
        A class with no samples gets ``float('nan')``. ``accuracy`` is
        ``float('nan')`` when the loader is empty.
    """
    model.to(device)
    model.eval()
    num_classes = len(classes)
    # Tallies live on the CPU so they can be updated with bincount whatever the
    # model's device is.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for images, labels in test_loader:
            # Call the module (not .forward) so registered hooks run.
            logits = model(images.to(device))
            preds = torch.argmax(logits, dim=1).cpu()
            labels = labels.cpu()
            class_total += torch.bincount(labels, minlength=num_classes)
            class_correct += torch.bincount(labels[preds == labels], minlength=num_classes)

    correct = int(class_correct.sum())
    total = int(class_total.sum())
    accuracy = correct / total * 100 if total else float("nan")

    class_accuracies = {
        classname: (100.0 * hits / seen if seen else float("nan"))
        for classname, hits, seen in zip(classes, class_correct.tolist(), class_total.tolist())
    }
    return accuracy, class_accuracies

In [6]:
# Only the new linear heads are trained (prep_fc froze the backbones and each
# optimizer covers only its model's fc parameters).
NUM_EPOCHS = 3

models = [model_1_photo_only, model_2_no_art, model_3_no_cartoon, model_4_no_photo, model_5_no_sketch]
train_loaders = [train_1_loader, train_2_loader, train_3_loader, train_4_loader, train_5_loader]
optimizers = [optim_m1, optim_m2, optim_m3, optim_m4, optim_m5]

for model_number, (model, train_loader, optimizer) in enumerate(zip(models, train_loaders, optimizers), start=1):
    print(f"Model {model_number}")
    for _ in range(NUM_EPOCHS):
        _, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Train Accuracy: {train_acc}")
    

Model 1
Train Accuracy: 81.91616766467065
Train Accuracy: 98.44311377245509
Train Accuracy: 98.62275449101796
Model 2
Train Accuracy: 73.25947375047211
Train Accuracy: 84.19992446179026
Train Accuracy: 85.1567417852197
Model 3
Train Accuracy: 71.49208840068
Train Accuracy: 83.03910030077155
Train Accuracy: 84.922191709167
Model 4
Train Accuracy: 67.50390578055521
Train Accuracy: 79.06501622401154
Train Accuracy: 81.15611104434564
Model 5
Train Accuracy: 76.6413724843286
Train Accuracy: 88.83206862421643
Train Accuracy: 90.01979544704717


In [7]:
# Evaluate model 1 (trained on photos only) separately on each unseen domain.
test_1_loaders = [test_2_loader, test_3_loader, test_5_loader]
test_sets = ["art", "cartoon", "sketch"]
for domain_name, domain_loader in zip(test_sets, test_1_loaders):
    overall_acc, per_class_acc = evaluate(model_1_photo_only, domain_loader, criterion, PACS_classes, device)
    print(f"\nResNet (PACS photo trained) on PACS {domain_name} set: {overall_acc:.2f}%")
    print("Per-class accuracy:")
    for class_name, class_acc in per_class_acc.items():
        print(f"    {class_name:10s}: {class_acc:.2f}%")



ResNet (PACS photo trained) on PACS art set: 64.94%
Per-class accuracy:
    dog       : 79.68%
    elephant  : 64.71%
    giraffe   : 45.26%
    guitar    : 79.89%
    horse     : 74.63%
    house     : 90.85%
    person    : 37.64%

ResNet (PACS photo trained) on PACS cartoon set: 34.90%
Per-class accuracy:
    dog       : 41.39%
    elephant  : 11.16%
    giraffe   : 19.65%
    guitar    : 100.00%
    horse     : 30.86%
    house     : 70.83%
    person    : 24.44%

ResNet (PACS photo trained) on PACS sketch set: 30.36%
Per-class accuracy:
    dog       : 5.44%
    elephant  : 0.00%
    giraffe   : 66.00%
    guitar    : 89.64%
    horse     : 0.74%
    house     : 1.25%
    person    : 63.75%


In [8]:
# Evaluate models 2-5, each on the single domain it never saw during training.
test_loaders = [test_2_loader, test_3_loader, test_4_loader, test_5_loader]
test_sets = ["art", "cartoon", "photo", "sketch"]
for held_out_model, held_out_loader, held_out_name in zip(models[1:], test_loaders, test_sets):
    overall_acc, per_class_acc = evaluate(held_out_model, held_out_loader, criterion, PACS_classes, device)
    print(f"\nResNet (non-{held_out_name}) on PACS {held_out_name} set: {overall_acc:.2f}%")
    print("Per-class accuracy:")
    for class_name, class_acc in per_class_acc.items():
        print(f"    {class_name:10s}: {class_acc:.2f}%")



ResNet (non-art) on PACS art set: 59.62%
Per-class accuracy:
    dog       : 59.10%
    elephant  : 41.18%
    giraffe   : 67.02%
    guitar    : 57.61%
    horse     : 60.20%
    house     : 74.24%
    person    : 56.79%

ResNet (non-cartoon) on PACS cartoon set: 50.85%
Per-class accuracy:
    dog       : 67.35%
    elephant  : 19.04%
    giraffe   : 65.90%
    guitar    : 98.52%
    horse     : 28.70%
    house     : 85.42%
    person    : 35.31%

ResNet (non-photo) on PACS photo set: 89.46%
Per-class accuracy:
    dog       : 98.41%
    elephant  : 79.21%
    giraffe   : 84.07%
    guitar    : 94.62%
    horse     : 55.28%
    house     : 100.00%
    person    : 99.31%

ResNet (non-sketch) on PACS sketch set: 46.27%
Per-class accuracy:
    dog       : 22.54%
    elephant  : 40.00%
    giraffe   : 51.53%
    guitar    : 80.92%
    horse     : 36.27%
    house     : 30.00%
    person    : 92.50%


In [9]:
# Save only the weights (state_dict). To reload, rebuild the same architecture
# (resnet50 + prep_fc) and call load_state_dict on it.
# torch.save does not create directories, so make sure the target exists.
os.makedirs("models", exist_ok=True)

checkpoints = {
    "models/resnet50-pacs-1-photo-only.pth": model_1_photo_only,
    "models/resnet50-pacs-2-no-art.pth": model_2_no_art,
    "models/resnet50-pacs-3-no-cartoon.pth": model_3_no_cartoon,
    "models/resnet50-pacs-4-no-photo.pth": model_4_no_photo,
    # Was mis-numbered "4" in the original, colliding with model 4's numbering.
    "models/resnet50-pacs-5-no-sketch.pth": model_5_no_sketch,
}
for checkpoint_path, trained_model in checkpoints.items():
    torch.save(trained_model.state_dict(), checkpoint_path)