In [101]:
import torch
from torchvision.models.detection.mask_rcnn import resnet50
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import random_split
from tqdm import tqdm
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision
from torcheval.metrics.classification import MulticlassRecall
from models import OnlineModel
import matplotlib.pyplot as plt

from torcheval.metrics.functional import multiclass_recall, multiclass_precision, multiclass_f1_score

In [117]:
plt.style.use("ggplot")

In [102]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [103]:
batch_size = 132
epochs = 12

In [104]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR10 dataset split.
dataset_train = datasets.CIFAR10(
        root='data',
        train=True,
        download=True,
        transform=transform,
    )
dataset_train, dataset_val = random_split(dataset_train, [45000, 5000])
dataset_test = datasets.CIFAR10(
        root='data',
        train=False,
        download=True,
        transform=transform,
    )
    # Create data loaders.
train_loader = DataLoader(
        dataset_train, 
        batch_size=batch_size,
        shuffle=True
    )
val_loader = DataLoader(
        dataset_val, 
        batch_size=batch_size,
        shuffle=True
    )
test_loader = DataLoader(
        dataset_test, 
        batch_size=batch_size,
        shuffle=True
    )

Files already downloaded and verified
Files already downloaded and verified


In [105]:
# print(multiclass_f1_score(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())
# print(multiclass_precision(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())
# print(multiclass_recall(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())

### Baseline

In [106]:
model = resnet50()
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [107]:
# f1_metric = MulticlassF1Score(num_classes=10)
# p_metric = MulticlassPrecision(num_classes=10)
# r_metric = MulticlassRecall(num_classes=10)

train_loss, valid_loss = [], []
f1_train, f1_valid = [], []
p_train, p_valid = [], []
r_train, r_valid = [], []

# Start the training.
for epoch in range(epochs):
    model.train()
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    print('Training')
    train_epoch_loss = 0.0
    counter = 0
    train_f1 = 0.0
    train_p = 0.0
    train_r = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        counter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        # f1_metric.update(outputs, labels)
        # p_metric.update(outputs, labels)
        # r_metric.update(outputs, labels)
        # train_f1 += f1_metric.compute().item()
        # train_p += p_metric.compute().item()
        # train_r += r_metric.compute().item()
        train_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
        train_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
        train_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
        loss.backward()
        optimizer.step()
        # print statistics
        train_epoch_loss += loss.item()

    train_epoch_loss = train_epoch_loss / counter
    train_f1 = train_f1 / counter
    train_p = train_p / counter
    train_r = train_r / counter

    model.eval()
    print("Validation")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)

            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    f1_train.append(train_f1)
    f1_valid.append(val_f1)

    p_train.append(train_p)
    p_valid.append(val_p)

    r_train.append(train_r)
    r_valid.append(val_r)

    print(f"Training loss: {train_epoch_loss:.3f}, f1: {train_f1}, p: {train_p}, r: {train_r}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TRAINING COMPLETE')

[INFO]: Epoch 1 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 24.76it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 41.65it/s]


Training loss: 2.160, f1: 0.2029059614297931, p: 0.22636827637504273, r: 0.22446793235860263
Validation loss: 1.939, f1: 0.29392848516765396, p: 0.30897745999850723, r: 0.3097578236146977
--------------------------------------------------
[INFO]: Epoch 2 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 25.53it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 50.31it/s]


Training loss: 1.854, f1: 0.3189296936534367, p: 0.3480951831050632, r: 0.340619915216899
Validation loss: 1.808, f1: 0.33772325986310053, p: 0.35967447961631577, r: 0.3596550390908593
--------------------------------------------------
[INFO]: Epoch 3 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 25.98it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 50.72it/s]


Training loss: 1.690, f1: 0.375406042659038, p: 0.40277233151746283, r: 0.3964501596615811
Validation loss: 1.722, f1: 0.372089811061558, p: 0.4018392131516808, r: 0.39422643498370524
--------------------------------------------------
[INFO]: Epoch 4 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 25.74it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 48.26it/s]


Training loss: 1.556, f1: 0.42409037765869295, p: 0.44838714914238, r: 0.4435473195443755
Validation loss: 1.624, f1: 0.40172286331653595, p: 0.4221214330510089, r: 0.42241808144669785
--------------------------------------------------
[INFO]: Epoch 5 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 25.89it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 47.93it/s]


Training loss: 1.454, f1: 0.4642480275323314, p: 0.48884829660314966, r: 0.48264462061641505
Validation loss: 1.547, f1: 0.41993111920984166, p: 0.454171979113629, r: 0.448409787918392
--------------------------------------------------
[INFO]: Epoch 6 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 25.92it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 49.14it/s]


Training loss: 1.348, f1: 0.5043389805076409, p: 0.5260700152591526, r: 0.5225621636661966
Validation loss: 1.525, f1: 0.44468695239016887, p: 0.46824101554720027, r: 0.45958794182852697
--------------------------------------------------
[INFO]: Epoch 7 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 26.10it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 48.18it/s]


Training loss: 1.265, f1: 0.5364448082936474, p: 0.5574696200334431, r: 0.5533415780039477
Validation loss: 1.499, f1: 0.45670212726843984, p: 0.468539906175513, r: 0.47138010043846934
--------------------------------------------------
[INFO]: Epoch 8 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 26.23it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 48.06it/s]


Training loss: 1.167, f1: 0.5722026310183785, p: 0.5888978008650615, r: 0.5880960119434815
Validation loss: 1.506, f1: 0.47703952146203893, p: 0.4876447500366914, r: 0.4899917699788746
--------------------------------------------------
[INFO]: Epoch 9 of 12
Training


100%|██████████| 341/341 [00:13<00:00, 25.86it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 48.99it/s]


Training loss: 1.086, f1: 0.60226726488284, p: 0.6194304608879201, r: 0.6186515049151311
Validation loss: 1.585, f1: 0.464857070069564, p: 0.4945778556560215, r: 0.4763180880170119
--------------------------------------------------
[INFO]: Epoch 10 of 12
Training


100%|██████████| 341/341 [00:12<00:00, 26.47it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 50.33it/s]


Training loss: 0.990, f1: 0.64002692681953, p: 0.6543057034791739, r: 0.6551058484661963
Validation loss: 1.551, f1: 0.4724411486010802, p: 0.49175210297107697, r: 0.48280450936995056
--------------------------------------------------
[INFO]: Epoch 11 of 12
Training


100%|██████████| 341/341 [00:12<00:00, 26.92it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 50.19it/s]


Training loss: 0.915, f1: 0.666286860568083, p: 0.6801112397325354, r: 0.6811666859210062
Validation loss: 1.591, f1: 0.48884361433355433, p: 0.4999043086641713, r: 0.500759596103116
--------------------------------------------------
[INFO]: Epoch 12 of 12
Training


100%|██████████| 341/341 [00:12<00:00, 26.67it/s]


Validation


100%|██████████| 38/38 [00:00<00:00, 51.47it/s]

Training loss: 0.838, f1: 0.6933437963385036, p: 0.7062799093310784, r: 0.70750704766019
Validation loss: 1.618, f1: 0.4763316661119461, p: 0.49766919644255386, r: 0.4832678719570762
--------------------------------------------------
TRAINING COMPLETE





In [108]:
torch.save(model.state_dict(), f"./weights/base")

### Pretrained feature extractor

In [126]:
model = OnlineModel().encoder
model.load_state_dict(torch.load("../BYOL/pretrained_feature_extractors/feature_extractor_20"))
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [127]:
# # Freeze all layers
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze last layer
# for param in model.fc.parameters():
#     param.requires_grad = True

In [128]:
train_loss_pretrained, valid_loss_pretrained = [], []
f1_train_pretrained, f1_valid_pretrained = [], []
p_train_pretrained, p_valid_pretrained = [], []
r_train_pretrained, r_valid_pretrained = [], []

# Start the training.
for epoch in range(epochs):
    model.train()
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    print('Training')
    train_epoch_loss = 0.0
    counter = 0
    train_f1 = 0.0
    train_p = 0.0
    train_r = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        counter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # f1_metric.update(outputs, labels)
        # p_metric.update(outputs, labels)
        # r_metric.update(outputs, labels)
        train_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
        train_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
        train_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
        loss.backward()
        optimizer.step()
        # print statistics
        train_epoch_loss += loss.item()

    train_epoch_loss = train_epoch_loss / counter
    train_f1 = train_f1 / counter
    train_p = train_p / counter
    train_r = train_r / counter

    model.eval()
    print("Validation")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)
            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    train_loss_pretrained.append(train_epoch_loss)
    valid_loss_pretrained.append(valid_epoch_loss)

    f1_train_pretrained.append(train_f1)
    f1_valid_pretrained.append(val_f1)

    p_train_pretrained.append(train_p)
    p_valid_pretrained.append(val_p)

    r_train_pretrained.append(train_r)
    r_valid_pretrained.append(val_r)

    print(f"Training loss: {train_epoch_loss:.3f}, f1: {train_f1}, p: {train_p}, r: {train_r}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TRAINING COMPLETE')


[INFO]: Epoch 1 of 12
Training


100%|██████████| 341/341 [00:45<00:00,  7.52it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 19.90it/s]


Training loss: 2.334, f1: 0.055799128378460135, p: 0.06313767095978572, r: 0.12487802776860352
Validation loss: 2.380, f1: 0.06332263698507297, p: 0.057490208921463865, r: 0.1346162732102369
--------------------------------------------------
[INFO]: Epoch 2 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.69it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 19.87it/s]


Training loss: 2.300, f1: 0.0761635543893247, p: 0.08386690702147323, r: 0.14603866507842744
Validation loss: 2.271, f1: 0.06752822729513834, p: 0.08216146284126137, r: 0.15840311799394458
--------------------------------------------------
[INFO]: Epoch 3 of 12
Training


100%|██████████| 341/341 [00:45<00:00,  7.57it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 19.62it/s]


Training loss: 2.279, f1: 0.08955873857058388, p: 0.09652604766823679, r: 0.15659440407130726
Validation loss: 2.300, f1: 0.09284839534053677, p: 0.08163645677268505, r: 0.16530683107281985
--------------------------------------------------
[INFO]: Epoch 4 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.69it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 19.87it/s]


Training loss: 2.251, f1: 0.09877555280967541, p: 0.10928929058010103, r: 0.1610668741406933
Validation loss: 2.377, f1: 0.09876088366696709, p: 0.0859151927656249, r: 0.1741296370562754
--------------------------------------------------
[INFO]: Epoch 5 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.68it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 20.86it/s]


Training loss: 2.240, f1: 0.1076941308387913, p: 0.11490007932366164, r: 0.1700982015547165
Validation loss: 2.323, f1: 0.10828044736071636, p: 0.13184672829351926, r: 0.1677242715499903
--------------------------------------------------
[INFO]: Epoch 6 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.70it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 20.45it/s]


Training loss: 2.236, f1: 0.10929248354171028, p: 0.12038027163584036, r: 0.172061949968338
Validation loss: 2.244, f1: 0.12401773270807769, p: 0.11132986843585968, r: 0.18175539276317546
--------------------------------------------------
[INFO]: Epoch 7 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.68it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 20.27it/s]


Training loss: 2.231, f1: 0.11265094884701314, p: 0.12330422542792611, r: 0.1752409615780601
Validation loss: 2.216, f1: 0.11401189667613883, p: 0.10282595828175545, r: 0.1806074404402783
--------------------------------------------------
[INFO]: Epoch 8 of 12
Training


100%|██████████| 341/341 [00:43<00:00,  7.76it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 20.61it/s]


Training loss: 2.211, f1: 0.11753976642211511, p: 0.12863790889508214, r: 0.18165116516813154
Validation loss: 2.181, f1: 0.10980267018864029, p: 0.1249152835654585, r: 0.16980401348126561
--------------------------------------------------
[INFO]: Epoch 9 of 12
Training


100%|██████████| 341/341 [00:45<00:00,  7.53it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 19.35it/s]


Training loss: 2.192, f1: 0.125351970092205, p: 0.13759354331538823, r: 0.1853097205724884
Validation loss: 2.184, f1: 0.13567283239803815, p: 0.11849062242790272, r: 0.18477229146580948
--------------------------------------------------
[INFO]: Epoch 10 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.70it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 20.79it/s]


Training loss: 2.187, f1: 0.12290377280185999, p: 0.133584844445553, r: 0.18521007849586324
Validation loss: 2.222, f1: 0.12207499244495441, p: 0.13700973762101248, r: 0.18571655216969943
--------------------------------------------------
[INFO]: Epoch 11 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.73it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 20.13it/s]


Training loss: 2.178, f1: 0.12939596449262586, p: 0.14287980223112792, r: 0.18852837700560646
Validation loss: 2.191, f1: 0.12217146001364056, p: 0.12641450498057039, r: 0.18444752065758957
--------------------------------------------------
[INFO]: Epoch 12 of 12
Training


100%|██████████| 341/341 [00:44<00:00,  7.64it/s]


Validation


100%|██████████| 38/38 [00:01<00:00, 19.88it/s]

Training loss: 2.182, f1: 0.1288635454023164, p: 0.13980741245000244, r: 0.19021832956072174
Validation loss: 2.188, f1: 0.13237648692570234, p: 0.12137670138556707, r: 0.19988089721453817
--------------------------------------------------
TRAINING COMPLETE





In [129]:
torch.save(model.state_dict(), f"./weights/pretrained")

### Plots

In [130]:
plt.plot(f1_train_pretrained, label='f1_train_pretrained')
plt.plot(f1_valid_pretrained, label='f1_valid_pretrained')
plt.plot(f1_train, label='f1_train')
plt.plot(f1_valid, label='f1_valid')

plt.legend()
plt.savefig(f"./plots/f1", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [131]:
plt.plot(p_train_pretrained, label='p_train_pretrained')
plt.plot(p_valid_pretrained, label='p_valid_pretrained')
plt.plot(p_train, label='p_train')
plt.plot(p_valid, label='p_valid')

plt.legend()
plt.savefig(f"./plots/precision", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [132]:
plt.plot(r_train_pretrained, label='r_train_pretrained')
plt.plot(r_valid_pretrained, label='r_valid_pretrained')
plt.plot(r_train, label='r_train')
plt.plot(r_valid, label='r_valid')

plt.legend()
plt.savefig(f"./plots/recall", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [133]:
plt.plot(train_loss_pretrained, label='train_loss_pretrained')
plt.plot(valid_loss_pretrained, label='train_loss_pretrained')
plt.plot(train_loss, label='train_loss')
plt.plot(valid_loss, label='train_loss')

plt.legend()
plt.savefig(f"./plots/loss", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>