## Imports

In [46]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset, random_split, TensorDataset
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage, CenterCrop, RandomResizedCrop
from torchvision.datasets import ImageFolder
from torchvision.models import alexnet, resnet18, inception_v3

from torchvision.models.alexnet import AlexNet_Weights
from torchvision.models.inception import Inception_V3_Weights
from torchvision.models.resnet import ResNet18_Weights

# Root directory (relative path) holding the image folders and the cached feature files.
data_path = "../data/"

## Data preparation

In [73]:
from download_rps import download_rps

# Image folders: the training split and the held-out test split (used here for validation).
train_data_path, val_data_path = (data_path + sub_dir for sub_dir in ("rps", "rps-test-set"))
download_rps(data_path)

# Display names used to label per-class metrics, indexed like the model's output columns.
# NOTE(review): assumed to match ImageFolder's label order (sorted folder names) — verify via train_data.classes.
class_names = ["paper", "rock", "scissors"]

rps folder already exists!
rps-test-set folder already exists!


In [48]:
# ImageNet channel statistics, matching the normalization the pretrained ResNet-18 weights were trained with.
normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Resize, then center-crop to the 224x224 input size, convert to a tensor and normalize.
transformer = Compose([Resize(256), CenterCrop(224), ToTensor(), normalizer])

train_data = ImageFolder(root=train_data_path, transform=transformer)
val_data = ImageFolder(root=val_data_path, transform=transformer)

# class_names is hard-coded earlier and used to label metrics by output index,
# so it must agree with the label indices ImageFolder derived from the folders.
assert train_data.classes == class_names, f"unexpected train classes: {train_data.classes}"
assert val_data.classes == class_names, f"unexpected val classes: {val_data.classes}"

# Builds a loader of each set (only the training loader is shuffled).
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(val_data, batch_size=16)

## Feature extraction

### Model configuration

In [49]:
def freeze_model(model):
    """Turn off gradient tracking for every parameter of `model`, in place.

    Returns None; the model object itself is modified.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)

# Seed torch's global RNG so the run is repeatable.
torch.manual_seed(42)

# Pretrained ResNet-18 used as a fixed feature extractor: replacing the
# classification head with Identity makes the forward pass return the
# penultimate-layer features, and freezing stops any weight from being trained.
backbone_weights = ResNet18_Weights.DEFAULT
resnet = resnet18(weights=backbone_weights)
resnet.fc = nn.Identity()
freeze_model(resnet)

### Preprocess data

In [50]:
def preprocess_dataset(model, dataset):
    """Run every batch from `dataset` through `model` and cache the outputs.

    Args:
        model: feature extractor; it is switched to eval mode once, before the pass.
        dataset: iterable of (x, y) batches, e.g. a DataLoader.

    Returns:
        TensorDataset of (features, labels): the model outputs of all batches
        concatenated along dim 0, and the matching labels concatenated with
        their original dtype. An empty iterable yields a TensorDataset of two
        empty tensors.
    """
    model.eval()
    feature_batches = []
    label_batches = []
    # The outputs are cached as plain data, so no autograd graph is needed
    # (and none is kept even if the model's parameters are not frozen).
    with torch.no_grad():
        for x, y in dataset:
            feature_batches.append(model(x))
            label_batches.append(y)

    if not feature_batches:
        return TensorDataset(torch.Tensor(), torch.Tensor())
    # Concatenate once at the end: growing a tensor with torch.cat inside the
    # loop re-copies everything accumulated so far on every batch.
    return TensorDataset(torch.cat(feature_batches), torch.cat(label_batches))

# Pass both splits through the frozen backbone once, so the head can be trained on cached features.
train_preproc, val_preproc = (
    preprocess_dataset(resnet, loader) for loader in (train_loader, val_loader)
)

### Save features

In [51]:
train_preproc_path = data_path + "train_preproc.pth"
val_preproc_path = data_path + "val_preproc.pth"

# Persist each split's (features, labels) tensor pair so later sessions can skip extraction.
for split_tensors, split_path in [
    (train_preproc.tensors, train_preproc_path),
    (val_preproc.tensors, val_preproc_path),
]:
    torch.save(split_tensors, split_path)

### Load features

In [52]:
# Rebuild the paths here so this cell does not depend on the save cell's variables.
train_preproc_path, val_preproc_path = (
    data_path + file_name for file_name in ("train_preproc.pth", "val_preproc.pth")
)
# NOTE: torch.load unpickles its input — only load files this notebook wrote itself.
train_preproc_data, val_preproc_data = (
    TensorDataset(*torch.load(cache_path)) for cache_path in (train_preproc_path, val_preproc_path)
)
train_preproc_loader = DataLoader(train_preproc_data, batch_size=16, shuffle=True)
val_preproc_loader = DataLoader(val_preproc_data, batch_size=16)

## Top model

### Model configuration

In [53]:
torch.manual_seed(42)
# Linear classification head on the 512-dim features from the headless ResNet-18,
# with one output logit per entry of class_names (kept in sync with the label list).
top_model = nn.Sequential(nn.Linear(512, len(class_names)))
multi_loss_fn = nn.CrossEntropyLoss(reduction='mean')
optimizer_top = optim.Adam(top_model.parameters(), lr=3e-4)

### Model training and evaluation

In [78]:
import numpy as np

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator > 0 else 0.0


def evaluate(model, data_loader, names=None):
    """Print and return per-class precision, recall and F1 of `model` on `data_loader`.

    Makes a single no-grad pass. Each sample's predicted class is the index of
    the maximum model output along dim 1, and it is compared against the labels.

    Args:
        model: module mapping an input batch to a (batch, n_classes) score tensor.
        data_loader: iterable of (x, y) batches; y holds class indices
            (integral float values also work).
        names: optional display names, one per output column. Defaults to the
            notebook-level ``class_names`` list.

    Returns:
        dict mapping each class name to {"precision", "recall", "f1"} floats;
        empty if the loader yields no batches. A metric whose denominator is
        zero (e.g. a class that is never predicted) is reported as 0.0, not nan.
    """
    model.eval()
    true_positives = false_positives = false_negatives = None
    with torch.no_grad():
        for x, y in data_loader:
            scores = model(x)
            _, predicted = torch.max(scores, 1)
            if true_positives is None:
                # Size the counters from the first batch's output width.
                n_classes = scores.shape[1]
                true_positives = np.zeros(n_classes)
                false_positives = np.zeros(n_classes)
                false_negatives = np.zeros(n_classes)
            for c in range(n_classes):
                is_predicted = predicted == c
                is_actual = y == c
                true_positives[c] += (is_predicted & is_actual).sum().item()
                false_positives[c] += (is_predicted & ~is_actual).sum().item()
                false_negatives[c] += (~is_predicted & is_actual).sum().item()

    if true_positives is None:
        return {}

    if names is None:
        names = class_names
    metrics = {}
    for i in range(n_classes):
        precision = _safe_ratio(true_positives[i], true_positives[i] + false_positives[i])
        recall = _safe_ratio(true_positives[i], true_positives[i] + false_negatives[i])
        f1 = _safe_ratio(2 * precision * recall, precision + recall)
        name = names[i]
        print(f"\n{name.capitalize()}")
        print(f"Precision: {precision:.4f}")
        print(f"Recall: {recall:.4f}")
        print(f"F1: {f1:.4f}")
        metrics[name] = {"precision": float(precision), "recall": float(recall), "f1": float(f1)}
    return metrics

def train(model, train_loader, val_loader, loss_fn, optimizer, n_epochs):
    """Fit `model` for `n_epochs` passes over `train_loader`.

    Each batch's targets are cast to long before being passed to `loss_fn`.
    After every epoch, an "EPOCH n" header (1-based) is printed, and then
    `evaluate` reports metrics on `val_loader`.
    """
    for epoch_number in range(1, n_epochs + 1):
        model.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            batch_loss = loss_fn(model(inputs), targets.long())
            batch_loss.backward()
            optimizer.step()

        print(f"\nEPOCH {epoch_number}")
        evaluate(model, val_loader)

train(
    top_model,
    train_preproc_loader,
    val_preproc_loader,
    loss_fn=multi_loss_fn,
    optimizer=optimizer_top,
    n_epochs=10,
)


EPOCH 1

Paper
Precission: 0.9670
Recall: 0.7097
F1: 0.8186

Rock
Precission: 0.7654
Recall: 1.0000
F1: 0.8671

Scissors
Precission: 0.8824
Recall: 0.8468
F1: 0.8642

EPOCH 2

Paper
Precission: 0.9620
Recall: 0.6129
F1: 0.7488

Rock
Precission: 0.7126
Recall: 1.0000
F1: 0.8322

Scissors
Precission: 0.8655
Recall: 0.8306
F1: 0.8477

EPOCH 3

Paper
Precission: 0.9639
Recall: 0.6452
F1: 0.7729

Rock
Precission: 0.7425
Recall: 1.0000
F1: 0.8522

Scissors
Precission: 0.8607
Recall: 0.8468
F1: 0.8537

EPOCH 4

Paper
Precission: 0.9615
Recall: 0.6048
F1: 0.7426

Rock
Precission: 0.7209
Recall: 1.0000
F1: 0.8378

Scissors
Precission: 0.8607
Recall: 0.8468
F1: 0.8537

EPOCH 5

Paper
Precission: 0.9639
Recall: 0.6452
F1: 0.7729

Rock
Precission: 0.7425
Recall: 1.0000
F1: 0.8522

Scissors
Precission: 0.8607
Recall: 0.8468
F1: 0.8537

EPOCH 6

Paper
Precission: 0.9630
Recall: 0.6290
F1: 0.7610

Rock
Precission: 0.7045
Recall: 1.0000
F1: 0.8267

Scissors
Precission: 0.8783
Recall: 0.8145
F1: 0.845

## Using the original dataset

### Reattach the top model

In [79]:
# Put the trained head back on the frozen backbone and score the full model on the raw validation images.
resnet.fc = top_model
evaluate(model=resnet, data_loader=val_loader);


Paper
Precission: 0.9630
Recall: 0.6290
F1: 0.7610

Rock
Precission: 0.7337
Recall: 1.0000
F1: 0.8464

Scissors
Precission: 0.8607
Recall: 0.8468
F1: 0.8537
