In [9]:
import torch
from torchvision.models.detection.mask_rcnn import resnet50
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import random_split
from tqdm import tqdm
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision
from torcheval.metrics.classification import MulticlassRecall
from models import OnlineModel
import matplotlib.pyplot as plt

from torcheval.metrics.functional import multiclass_recall, multiclass_precision, multiclass_f1_score

In [10]:
plt.style.use("ggplot")

In [11]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [12]:
batch_size = 256
epochs = 12

In [13]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR10 dataset split.
dataset_train = datasets.CIFAR10(
        root='data',
        train=True,
        download=True,
        transform=transform,
    )
dataset_train, dataset_val = random_split(dataset_train, [45000, 5000])
dataset_train, _ = random_split(dataset_train, [4500, 40500])
dataset_val, _ = random_split(dataset_val, [500, 4500])
dataset_test = datasets.CIFAR10(
        root='data',
        train=False,
        download=True,
        transform=transform,
    )
    # Create data loaders.
train_loader = DataLoader(
        dataset_train, 
        batch_size=batch_size,
        shuffle=True
    )
val_loader = DataLoader(
        dataset_val, 
        batch_size=batch_size,
        shuffle=True
    )
test_loader = DataLoader(
        dataset_test, 
        batch_size=batch_size,
        shuffle=True
    )

Files already downloaded and verified
Files already downloaded and verified


In [14]:
# print(multiclass_f1_score(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())
# print(multiclass_precision(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())
# print(multiclass_recall(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())

### Baseline

In [15]:
model = resnet50()
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [16]:
# f1_metric = MulticlassF1Score(num_classes=10)
# p_metric = MulticlassPrecision(num_classes=10)
# r_metric = MulticlassRecall(num_classes=10)

train_loss, valid_loss = [], []
f1_train, f1_valid = [], []
p_train, p_valid = [], []
r_train, r_valid = [], []

# Start the training.
for epoch in range(epochs):
    model.train()
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    print('Training')
    train_epoch_loss = 0.0
    counter = 0
    train_f1 = 0.0
    train_p = 0.0
    train_r = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        counter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        # f1_metric.update(outputs, labels)
        # p_metric.update(outputs, labels)
        # r_metric.update(outputs, labels)
        # train_f1 += f1_metric.compute().item()
        # train_p += p_metric.compute().item()
        # train_r += r_metric.compute().item()
        train_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
        train_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
        train_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
        loss.backward()
        optimizer.step()
        # print statistics
        train_epoch_loss += loss.item()

    train_epoch_loss = train_epoch_loss / counter
    train_f1 = train_f1 / counter
    train_p = train_p / counter
    train_r = train_r / counter

    model.eval()
    print("Validation")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)

            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    f1_train.append(train_f1)
    f1_valid.append(val_f1)

    p_train.append(train_p)
    p_valid.append(val_p)

    r_train.append(train_r)
    r_valid.append(val_r)

    print(f"Training loss: {train_epoch_loss:.3f}, f1: {train_f1}, p: {train_p}, r: {train_r}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TRAINING COMPLETE')

[INFO]: Epoch 1 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 13.15it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 17.92it/s]


Training loss: 2.490, f1: 0.0953126907762554, p: 0.10864415475063854, r: 0.10919060930609703
Validation loss: 2.351, f1: 0.07910642400383949, p: 0.12404929473996162, r: 0.12152556329965591
--------------------------------------------------
[INFO]: Epoch 2 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.12it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 28.35it/s]


Training loss: 2.410, f1: 0.12094948068261147, p: 0.1308628941575686, r: 0.13322563469409943
Validation loss: 2.566, f1: 0.0999944917857647, p: 0.11440722271800041, r: 0.10855479165911674
--------------------------------------------------
[INFO]: Epoch 3 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 16.80it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 24.30it/s]


Training loss: 2.336, f1: 0.13070274227195317, p: 0.1452573649585247, r: 0.14140580718715987
Validation loss: 2.587, f1: 0.12642822787165642, p: 0.13679759204387665, r: 0.1373981609940529
--------------------------------------------------
[INFO]: Epoch 4 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 16.91it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 27.34it/s]


Training loss: 2.285, f1: 0.1492091491818428, p: 0.16416835660735765, r: 0.16430271913607916
Validation loss: 2.395, f1: 0.14403236284852028, p: 0.1511385515332222, r: 0.16208580136299133
--------------------------------------------------
[INFO]: Epoch 5 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 16.76it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 27.92it/s]


Training loss: 2.205, f1: 0.18487298819753858, p: 0.1999764997098181, r: 0.19692341403828728
Validation loss: 2.333, f1: 0.16199962049722672, p: 0.17115431278944016, r: 0.16771823912858963
--------------------------------------------------
[INFO]: Epoch 6 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.33it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 26.68it/s]


Training loss: 2.149, f1: 0.2103632026248508, p: 0.23148226075702244, r: 0.22391275068124136
Validation loss: 2.217, f1: 0.1705378219485283, p: 0.19318005442619324, r: 0.18301738798618317
--------------------------------------------------
[INFO]: Epoch 7 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.17it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 27.29it/s]


Training loss: 2.095, f1: 0.23392324232392842, p: 0.2590810689661238, r: 0.2461659461259842
Validation loss: 2.219, f1: 0.19656091928482056, p: 0.21177832037210464, r: 0.22755196690559387
--------------------------------------------------
[INFO]: Epoch 8 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.10it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 28.05it/s]


Training loss: 2.028, f1: 0.2485276891125573, p: 0.27964261919260025, r: 0.2660578638315201
Validation loss: 2.222, f1: 0.18936192989349365, p: 0.20658008754253387, r: 0.20716143399477005
--------------------------------------------------
[INFO]: Epoch 9 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.42it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 26.76it/s]


Training loss: 1.957, f1: 0.2891757935285568, p: 0.3093995617495643, r: 0.300629672076967
Validation loss: 2.187, f1: 0.1873161941766739, p: 0.20093470811843872, r: 0.20845971256494522
--------------------------------------------------
[INFO]: Epoch 10 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.27it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 22.82it/s]


Training loss: 1.889, f1: 0.3140010353591707, p: 0.336570594045851, r: 0.3283900287416246
Validation loss: 2.172, f1: 0.21691153198480606, p: 0.2256055846810341, r: 0.23205406218767166
--------------------------------------------------
[INFO]: Epoch 11 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.16it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 28.31it/s]


Training loss: 1.824, f1: 0.3361483845445845, p: 0.3733382026354472, r: 0.35214420325226253
Validation loss: 2.143, f1: 0.22893569618463516, p: 0.23847205191850662, r: 0.23860371112823486
--------------------------------------------------
[INFO]: Epoch 12 of 12
Training


100%|██████████| 18/18 [00:01<00:00, 17.22it/s]


Validation


100%|██████████| 2/2 [00:00<00:00, 26.46it/s]

Training loss: 1.764, f1: 0.3517172535260518, p: 0.37306736409664154, r: 0.36469634705119663
Validation loss: 2.142, f1: 0.23503126949071884, p: 0.25170814990997314, r: 0.2553534358739853
--------------------------------------------------
TRAINING COMPLETE





In [17]:
torch.save(model.state_dict(), f"./weights/base")

### Pretrained feature extractor

In [18]:
model = OnlineModel().encoder
model.load_state_dict(torch.load("../BYOL/pretrained_feature_extractors/feature_extractor_20"))
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [19]:
# # Freeze all layers
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze last layer
# for param in model.fc.parameters():
#     param.requires_grad = True

In [20]:
train_loss_pretrained, valid_loss_pretrained = [], []
f1_train_pretrained, f1_valid_pretrained = [], []
p_train_pretrained, p_valid_pretrained = [], []
r_train_pretrained, r_valid_pretrained = [], []

# Start the training.
for epoch in range(epochs):
    model.train()
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    print('Training')
    train_epoch_loss = 0.0
    counter = 0
    train_f1 = 0.0
    train_p = 0.0
    train_r = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        counter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # f1_metric.update(outputs, labels)
        # p_metric.update(outputs, labels)
        # r_metric.update(outputs, labels)
        train_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
        train_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
        train_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
        loss.backward()
        optimizer.step()
        # print statistics
        train_epoch_loss += loss.item()

    train_epoch_loss = train_epoch_loss / counter
    train_f1 = train_f1 / counter
    train_p = train_p / counter
    train_r = train_r / counter

    model.eval()
    print("Validation")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)
            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    train_loss_pretrained.append(train_epoch_loss)
    valid_loss_pretrained.append(valid_epoch_loss)

    f1_train_pretrained.append(train_f1)
    f1_valid_pretrained.append(val_f1)

    p_train_pretrained.append(train_p)
    p_valid_pretrained.append(val_p)

    r_train_pretrained.append(train_r)
    r_valid_pretrained.append(val_r)

    print(f"Training loss: {train_epoch_loss:.3f}, f1: {train_f1}, p: {train_p}, r: {train_r}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TRAINING COMPLETE')


[INFO]: Epoch 1 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.74it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.35it/s]


Training loss: 2.357, f1: 0.04024852791594134, p: 0.04881262000546687, r: 0.10775816233621703
Validation loss: 2.399, f1: 0.040535902604460716, p: 0.04894971288740635, r: 0.10632399842143059
--------------------------------------------------
[INFO]: Epoch 2 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.85it/s]


Training loss: 2.347, f1: 0.0463479385814733, p: 0.05908137347756161, r: 0.11810547113418579
Validation loss: 2.353, f1: 0.059190401807427406, p: 0.06246899627149105, r: 0.1228143647313118
--------------------------------------------------
[INFO]: Epoch 3 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.85it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.69it/s]


Training loss: 2.303, f1: 0.05979946214291784, p: 0.08910970480388238, r: 0.12106376265486081
Validation loss: 2.328, f1: 0.06216698698699474, p: 0.05312344804406166, r: 0.12902161106467247
--------------------------------------------------
[INFO]: Epoch 4 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.83it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.80it/s]


Training loss: 2.310, f1: 0.0550220383124219, p: 0.07064732660849889, r: 0.12648924440145493
Validation loss: 2.346, f1: 0.05432921648025513, p: 0.0884223710745573, r: 0.11786853522062302
--------------------------------------------------
[INFO]: Epoch 5 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.83it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.78it/s]


Training loss: 2.301, f1: 0.06044065631512138, p: 0.09038944195749031, r: 0.12818620436721379
Validation loss: 2.308, f1: 0.05009664595127106, p: 0.04569598566740751, r: 0.13499633222818375
--------------------------------------------------
[INFO]: Epoch 6 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.91it/s]


Training loss: 2.294, f1: 0.07221663350032435, p: 0.09643986769434479, r: 0.13879791647195816
Validation loss: 2.310, f1: 0.06164790131151676, p: 0.04557429999113083, r: 0.13034038618206978
--------------------------------------------------
[INFO]: Epoch 7 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.88it/s]


Training loss: 2.319, f1: 0.06086772581976321, p: 0.07021175681923826, r: 0.12774895835253927
Validation loss: 2.321, f1: 0.01680778292939067, p: 0.009282146580517292, r: 0.10000000149011612
--------------------------------------------------
[INFO]: Epoch 8 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.83it/s]


Training loss: 2.312, f1: 0.07033487181696627, p: 0.08284056525573963, r: 0.13837124821212557
Validation loss: 2.337, f1: 0.06424254179000854, p: 0.05818105861544609, r: 0.12228836864233017
--------------------------------------------------
[INFO]: Epoch 9 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.91it/s]


Training loss: 2.311, f1: 0.07585275338755713, p: 0.09132749038851923, r: 0.1376114942961269
Validation loss: 2.331, f1: 0.05191346816718578, p: 0.04798607528209686, r: 0.12351730093359947
--------------------------------------------------
[INFO]: Epoch 10 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.85it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.59it/s]


Training loss: 2.292, f1: 0.07315021339390013, p: 0.10450689122080803, r: 0.14365431583589977
Validation loss: 2.332, f1: 0.06460334733128548, p: 0.061460599303245544, r: 0.13246093317866325
--------------------------------------------------
[INFO]: Epoch 11 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.72it/s]


Training loss: 2.309, f1: 0.0628015057494243, p: 0.07648795605119732, r: 0.13573211638463867
Validation loss: 2.357, f1: 0.05098233371973038, p: 0.06255516409873962, r: 0.11834509670734406
--------------------------------------------------
[INFO]: Epoch 12 of 12
Training


100%|██████████| 18/18 [00:04<00:00,  3.85it/s]


Validation


100%|██████████| 2/2 [00:00<00:00,  9.56it/s]

Training loss: 2.288, f1: 0.06886826517681281, p: 0.09209533677332932, r: 0.13297979409495989
Validation loss: 2.309, f1: 0.06492675840854645, p: 0.0791527982801199, r: 0.12506486102938652
--------------------------------------------------
TRAINING COMPLETE





In [21]:
torch.save(model.state_dict(), f"./weights/pretrained")

### Plots

In [22]:
plt.plot(f1_train_pretrained, label='f1_train_pretrained')
plt.plot(f1_valid_pretrained, label='f1_valid_pretrained')
plt.plot(f1_train, label='f1_train')
plt.plot(f1_valid, label='f1_valid')

plt.legend()
plt.savefig(f"./plots/f1", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [23]:
plt.plot(p_train_pretrained, label='p_train_pretrained')
plt.plot(p_valid_pretrained, label='p_valid_pretrained')
plt.plot(p_train, label='p_train')
plt.plot(p_valid, label='p_valid')

plt.legend()
plt.savefig(f"./plots/precision", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [24]:
plt.plot(r_train_pretrained, label='r_train_pretrained')
plt.plot(r_valid_pretrained, label='r_valid_pretrained')
plt.plot(r_train, label='r_train')
plt.plot(r_valid, label='r_valid')

plt.legend()
plt.savefig(f"./plots/recall", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [25]:
plt.plot(train_loss_pretrained, label='train_loss_pretrained')
plt.plot(valid_loss_pretrained, label='train_loss_pretrained')
plt.plot(train_loss, label='train_loss')
plt.plot(valid_loss, label='train_loss')

plt.legend()
plt.savefig(f"./plots/loss", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>