In [1]:
import torch
from torchvision.models.detection.mask_rcnn import resnet50
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import random_split
from tqdm import tqdm
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision
from torcheval.metrics.classification import MulticlassRecall
from models import OnlineModel
import matplotlib.pyplot as plt

from torcheval.metrics.functional import multiclass_recall, multiclass_precision, multiclass_f1_score

In [2]:
plt.style.use("ggplot")

In [3]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [4]:
batch_size = 132

In [5]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR10 dataset split.
dataset_test = datasets.CIFAR10(
        root='data',
        train=False,
        download=True,
        transform=transform,
    )
# Create data loaders.
test_loader = DataLoader(
        dataset_test, 
        batch_size=batch_size,
        shuffle=True
    )

Files already downloaded and verified


### Baseline

In [6]:
model = resnet50()
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model.load_state_dict(torch.load("./weights/base"))
model = model.to(device)

criterion = nn.CrossEntropyLoss()

In [7]:

# Start the training.
for epoch in range(1):
    model.eval()
    print("Test")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(test_loader), total=len(test_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)

            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    print(f"Test loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TESTING COMPLETE')

Test


100%|██████████| 76/76 [00:02<00:00, 32.10it/s]

Test loss: 1.860, f1: 0.4070377836101933, p: 0.41810584185939087, r: 0.41986486315727234
--------------------------------------------------
TESTING COMPLETE





### Pretrained feature extractor

In [8]:
model = OnlineModel().encoder
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model.load_state_dict(torch.load("./weights/pretrained"))
model = model.to(device)

criterion = nn.CrossEntropyLoss()

In [9]:
# # Freeze all layers
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze last layer
# for param in model.fc.parameters():
#     param.requires_grad = True

In [10]:
# Start the training.
for epoch in range(1):
    model.eval()
    print("Test")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(test_loader), total=len(test_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)

            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    print(f"Test loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TESTING COMPLETE')

Test


100%|██████████| 76/76 [00:03<00:00, 19.25it/s]

Test loss: 2.255, f1: 0.10257887600087806, p: 0.10109993928161107, r: 0.17410038931197241
--------------------------------------------------
TESTING COMPLETE



