In [10]:
import os

import torch
import torchvision
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import datasets

# Backend preference: CUDA first, then Apple MPS, then plain CPU.
# The MPS probe goes through torch.backends.mps, the documented availability
# check; torch.mps.is_available() is missing from older PyTorch releases and
# raises AttributeError there.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [11]:
def prep_fc(model, num_classes=7):
    """Freeze a model's existing parameters and give it a new trainable ``fc`` head.

    Args:
        model: A module exposing a linear ``fc`` attribute (its ``in_features``
            is read to size the replacement layer).
        num_classes: Output width of the new ``fc`` layer. Defaults to 7, the
            value this notebook previously hard-coded.

    Returns:
        The same ``model`` object, modified in place.
    """
    for param in model.parameters():
        param.requires_grad = False
    # Built after the freeze loop, so the new layer keeps the default
    # requires_grad=True and is the only part that receives gradients.
    model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
    return model

def _new_linear_probe():
    """Return a (model, optimizer) pair: an ImageNet-weighted ResNet-50 passed
    through ``prep_fc`` and an Adam optimizer over its ``fc`` parameters only."""
    net = prep_fc(torchvision.models.resnet50(weights='IMAGENET1K_V1'))
    head_optimizer = torch.optim.Adam(net.fc.parameters(), lr=0.001)
    return net, head_optimizer


# One independent model/optimizer pair per experiment. Pairs are built in the
# same order as before so the random initialisation of each new head is drawn
# in the same sequence.
model_1_photo_only, optim_m1 = _new_linear_probe()
model_2_no_art, optim_m2 = _new_linear_probe()
model_3_no_cartoon, optim_m3 = _new_linear_probe()
model_4_no_photo, optim_m4 = _new_linear_probe()
model_5_no_sketch, optim_m5 = _new_linear_probe()

# Shared loss for every model.
criterion = torch.nn.CrossEntropyLoss()

In [12]:
# Per-channel mean and standard deviation handed to Normalize below.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: fixed 224x224 resize, tensor
# conversion, then channel-wise normalisation.
_preprocessing_steps = [
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(imagenet_mean, imagenet_std),
]
transform = torchvision.transforms.Compose(_preprocessing_steps)

# One image-folder root per PACS domain.
art_dir = "data/pacs_data/pacs_data/art_painting"
cartoon_dir = "data/pacs_data/pacs_data/cartoon"
photo_dir = "data/pacs_data/pacs_data/photo"
sketch_dir = "data/pacs_data/pacs_data/sketch"

# All four domains share the same transform; built in photo, art, cartoon,
# sketch order.
photo_dataset, art_dataset, cartoon_dataset, sketch_dataset = (
    datasets.ImageFolder(root=domain_dir, transform=transform)
    for domain_dir in (photo_dir, art_dir, cartoon_dir, sketch_dir)
)

def get_loader(dataset, shuffle, batch_size=64, num_workers=4, pin_memory=True):
    """Wrap ``dataset`` in a ``DataLoader`` using this notebook's batching settings.

    Args:
        dataset: Any dataset accepted by ``torch.utils.data.DataLoader``.
        shuffle: Whether the loader reshuffles the data each epoch.
        batch_size: Samples per batch (default 64).
        num_workers: Worker processes used for loading (default 4).
        pin_memory: Forwarded to ``DataLoader`` (default True).

    Returns:
        The configured ``DataLoader``.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

# ConcatDataset passes each member's integer labels through unchanged, and
# PACS_classes below reuses the photo domain's class list for every domain.
# Both are only valid if all four folders map class names to the same indices,
# so fail loudly here instead of silently training on mixed-up labels.
assert all(
    domain.class_to_idx == photo_dataset.class_to_idx
    for domain in (art_dataset, cartoon_dataset, sketch_dataset)
), "PACS domain folders disagree on their class-to-index mapping"

# Experiment 1: train on photo only, test on the other three domains combined.
train_1_loader = get_loader(photo_dataset, True)
test_1_loader = get_loader(ConcatDataset([art_dataset, cartoon_dataset, sketch_dataset]), False)

# Experiments 2-5: train on three domains, test on the one left out.
train_2_loader = get_loader(ConcatDataset([photo_dataset, cartoon_dataset, sketch_dataset]), True)
test_2_loader = get_loader(art_dataset, False)

train_3_loader = get_loader(ConcatDataset([photo_dataset, art_dataset, sketch_dataset]), True)
test_3_loader = get_loader(cartoon_dataset, False)

train_4_loader = get_loader(ConcatDataset([sketch_dataset, art_dataset, cartoon_dataset]), True)
test_4_loader = get_loader(photo_dataset, False)

train_5_loader = get_loader(ConcatDataset([photo_dataset, art_dataset, cartoon_dataset]), True)
test_5_loader = get_loader(sketch_dataset, False)

# Class names in label-index order, taken from the photo domain.
PACS_classes = photo_dataset.classes

In [13]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and report loss and accuracy.

    Args:
        model: Module to train; moved to ``device`` and put in training mode.
        loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the model and each batch are moved to.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample average of the batch losses
        and the accuracy in percent. Both are measured during training, so
        each batch is scored with the weights in effect before its update.

    Raises:
        ValueError: If ``loader`` yields no samples (previously this surfaced
            as a bare ZeroDivisionError).

    NOTE(review): ``model.train()`` also puts BatchNorm/Dropout layers of a
    frozen backbone into training mode, so BatchNorm running statistics keep
    updating even when the parameters have ``requires_grad=False``. Confirm
    this is intended for the linear-probe setup.
    """
    model.to(device)
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a short final batch does not
        # skew the epoch average (assumes criterion returns a batch mean).
        running_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    if total == 0:
        raise ValueError("train_epoch received a loader that yielded no samples")

    avg_loss = running_loss / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

In [14]:
def evaluate(model, test_loader, loss_fn, classes, device):
    """Measure overall and per-class accuracy of ``model`` on ``test_loader``.

    Args:
        model: Module to evaluate; moved to ``device`` and put in eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches, where
            labels are integer class indices into ``classes``.
        loss_fn: Unused. Kept so existing call sites keep working; the loss
            was previously computed and thrown away.
        classes: Sequence of class names, indexed by label.
        device: Device the model and each batch are moved to.

    Returns:
        ``(accuracy, class_accuracies)``: overall accuracy in percent, and a
        dict mapping each class name to its accuracy in percent. A class with
        no samples in ``test_loader`` maps to NaN instead of raising
        ZeroDivisionError.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    model.to(device)
    model.eval()
    num_classes = len(classes)
    # Counters live on the CPU so the tallying works the same on any device.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            # Call the module itself (not .forward) so registered hooks run.
            preds = torch.argmax(model(images), dim=1)

            # One transfer per batch, then vectorised per-class tallies,
            # instead of a Python loop calling .item() on every sample.
            labels_cpu = labels.cpu()
            hits_cpu = (preds == labels).cpu()
            class_total += torch.bincount(labels_cpu, minlength=num_classes)
            class_correct += torch.bincount(labels_cpu[hits_cpu], minlength=num_classes)

    total = int(class_total.sum())
    if total == 0:
        raise ValueError("evaluate received a loader that yielded no samples")
    accuracy = int(class_correct.sum()) / total * 100

    class_accuracies = {}
    for classname, n_correct, n_seen in zip(classes, class_correct.tolist(), class_total.tolist()):
        # Accuracy is undefined for a class that never appeared.
        class_accuracies[classname] = 100.0 * n_correct / n_seen if n_seen else float("nan")

    return accuracy, class_accuracies

In [15]:
# Parallel lists: position k holds the model, training loader and optimizer of
# experiment k+1.
models = [model_1_photo_only, model_2_no_art, model_3_no_cartoon, model_4_no_photo, model_5_no_sketch]
train_loaders = [train_1_loader, train_2_loader, train_3_loader, train_4_loader, train_5_loader]
optimizers = [optim_m1, optim_m2, optim_m3, optim_m4, optim_m5]

# Train every model for three epochs on its own loader, printing the training
# accuracy after each epoch.
for model_number, (model, train_loader, optimizer) in enumerate(
    zip(models, train_loaders, optimizers), start=1
):
    print(f"Model {model_number}")
    for _epoch in range(3):
        _epoch_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Train Accuracy: {train_acc}")
    

Model 1
Train Accuracy: 84.91017964071857
Train Accuracy: 98.62275449101796
Train Accuracy: 98.8622754491018
Model 2
Train Accuracy: 73.0454488228629
Train Accuracy: 84.52725670401611
Train Accuracy: 85.09379327709934
Model 3
Train Accuracy: 71.57055054269648
Train Accuracy: 82.31986399895383
Train Accuracy: 85.07911599319995
Model 4
Train Accuracy: 69.3065737291191
Train Accuracy: 80.1946881384449
Train Accuracy: 80.55522172815768
Model 5
Train Accuracy: 77.89508413064995
Train Accuracy: 88.74958759485318
Train Accuracy: 90.67964368195315


In [16]:
# Score the photo-trained model on each of the other three domains in turn.
test_1_loaders = [test_2_loader, test_3_loader, test_5_loader]
test_sets = ["art", "cartoon", "sketch"]
for domain_name, loader in zip(test_sets, test_1_loaders):
    test_accuracy, test_class_accuracies = evaluate(
        model_1_photo_only, loader, criterion, PACS_classes, device
    )
    print(f"\nResNet (PACS photo trained) on PACS {domain_name} set: {test_accuracy:.2f}%")
    print("Per-class accuracy:")
    for classname in test_class_accuracies:
        acc = test_class_accuracies[classname]
        print(f"    {classname:10s}: {acc:.2f}%")



ResNet (PACS photo trained) on PACS art set: 60.89%
Per-class accuracy:
    dog       : 87.34%
    elephant  : 52.55%
    giraffe   : 43.86%
    guitar    : 84.24%
    horse     : 61.19%
    house     : 76.61%
    person    : 34.08%

ResNet (PACS photo trained) on PACS cartoon set: 27.30%
Per-class accuracy:
    dog       : 46.53%
    elephant  : 3.72%
    giraffe   : 22.83%
    guitar    : 100.00%
    horse     : 17.28%
    house     : 46.53%
    person    : 9.38%

ResNet (PACS photo trained) on PACS sketch set: 33.80%
Per-class accuracy:
    dog       : 20.98%
    elephant  : 0.00%
    giraffe   : 74.77%
    guitar    : 93.59%
    horse     : 0.00%
    house     : 0.00%
    person    : 21.25%


In [17]:
# Models 2-5 are each scored on the single domain excluded from their training
# loader; the three sequences below line up position by position.
test_loaders = [test_2_loader, test_3_loader, test_4_loader, test_5_loader]
test_sets = ["art", "cartoon", "photo", "sketch"]
for model, held_out_loader, held_out_name in zip(models[1:], test_loaders, test_sets):
    test_accuracy, test_class_accuracies = evaluate(
        model, held_out_loader, criterion, PACS_classes, device
    )
    print(f"\nResNet (non-{held_out_name}) on PACS {held_out_name} set: {test_accuracy:.2f}%")
    print("Per-class accuracy:")
    for classname in test_class_accuracies:
        acc = test_class_accuracies[classname]
        print(f"    {classname:10s}: {acc:.2f}%")



ResNet (non-art) on PACS art set: 63.77%
Per-class accuracy:
    dog       : 50.13%
    elephant  : 55.29%
    giraffe   : 74.39%
    guitar    : 58.15%
    horse     : 53.73%
    house     : 70.85%
    person    : 75.50%

ResNet (non-cartoon) on PACS cartoon set: 55.08%
Per-class accuracy:
    dog       : 58.87%
    elephant  : 23.85%
    giraffe   : 86.71%
    guitar    : 97.04%
    horse     : 33.95%
    house     : 86.46%
    person    : 40.25%

ResNet (non-photo) on PACS photo set: 90.90%
Per-class accuracy:
    dog       : 89.95%
    elephant  : 91.58%
    giraffe   : 80.77%
    guitar    : 95.70%
    horse     : 71.86%
    house     : 100.00%
    person    : 96.06%

ResNet (non-sketch) on PACS sketch set: 52.28%
Per-class accuracy:
    dog       : 20.73%
    elephant  : 52.43%
    giraffe   : 84.20%
    guitar    : 87.50%
    horse     : 26.59%
    house     : 60.00%
    person    : 46.88%


In [18]:
# Persist each model's state_dict. Create the target directory first so a fresh
# checkout does not fail here after the training cells have already run.
os.makedirs("models", exist_ok=True)
torch.save(model_1_photo_only.state_dict(), "models/resnet50-pacs-1-photo-only.pth")
torch.save(model_2_no_art.state_dict(), "models/resnet50-pacs-2-no-art.pth")
torch.save(model_3_no_cartoon.state_dict(), "models/resnet50-pacs-3-no-cartoon.pth")
torch.save(model_4_no_photo.state_dict(), "models/resnet50-pacs-4-no-photo.pth")
# Index corrected from 4 to 5 to match the model's number. Earlier runs wrote
# this checkpoint as resnet50-pacs-4-no-sketch.pth, so anything that loads the
# old filename needs updating.
torch.save(model_5_no_sketch.state_dict(), "models/resnet50-pacs-5-no-sketch.pth")