In [1]:
import torch
from torchvision.models.detection.mask_rcnn import resnet50
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.utils.data import random_split
from tqdm import tqdm
from torcheval.metrics import MulticlassF1Score, MulticlassPrecision
from torcheval.metrics.classification import MulticlassRecall
from models import OnlineModel
import matplotlib.pyplot as plt

from torcheval.metrics.functional import multiclass_recall, multiclass_precision, multiclass_f1_score

In [2]:
plt.style.use("ggplot")

In [3]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [4]:
batch_size = 132
epochs = 12

In [5]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# CIFAR10 dataset split.
dataset_train = datasets.CIFAR10(
        root='data',
        train=True,
        download=True,
        transform=transform,
    )
dataset_train, dataset_val = random_split(dataset_train, [45000, 5000])
dataset_train, _ = random_split(dataset_train, [22500, 22500])
dataset_val, _ = random_split(dataset_val, [2500, 2500])
dataset_test = datasets.CIFAR10(
        root='data',
        train=False,
        download=True,
        transform=transform,
    )
    # Create data loaders.
train_loader = DataLoader(
        dataset_train, 
        batch_size=batch_size,
        shuffle=True
    )
val_loader = DataLoader(
        dataset_val, 
        batch_size=batch_size,
        shuffle=True
    )
test_loader = DataLoader(
        dataset_test, 
        batch_size=batch_size,
        shuffle=True
    )

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# print(multiclass_f1_score(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())
# print(multiclass_precision(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())
# print(multiclass_recall(torch.tensor([1, 1, 1, 1]), torch.tensor([0, 1, 1, 1]), num_classes=2, average="micro").item())

### Baseline

In [7]:
model = resnet50()
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [8]:
# f1_metric = MulticlassF1Score(num_classes=10)
# p_metric = MulticlassPrecision(num_classes=10)
# r_metric = MulticlassRecall(num_classes=10)

train_loss, valid_loss = [], []
f1_train, f1_valid = [], []
p_train, p_valid = [], []
r_train, r_valid = [], []

# Start the training.
for epoch in range(epochs):
    model.train()
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    print('Training')
    train_epoch_loss = 0.0
    counter = 0
    train_f1 = 0.0
    train_p = 0.0
    train_r = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        counter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)

        loss = criterion(outputs, labels)
        # f1_metric.update(outputs, labels)
        # p_metric.update(outputs, labels)
        # r_metric.update(outputs, labels)
        # train_f1 += f1_metric.compute().item()
        # train_p += p_metric.compute().item()
        # train_r += r_metric.compute().item()
        train_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
        train_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
        train_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
        loss.backward()
        optimizer.step()
        # print statistics
        train_epoch_loss += loss.item()

    train_epoch_loss = train_epoch_loss / counter
    train_f1 = train_f1 / counter
    train_p = train_p / counter
    train_r = train_r / counter

    model.eval()
    print("Validation")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)

            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    f1_train.append(train_f1)
    f1_valid.append(val_f1)

    p_train.append(train_p)
    p_valid.append(val_p)

    r_train.append(train_r)
    r_valid.append(val_r)

    print(f"Training loss: {train_epoch_loss:.3f}, f1: {train_f1}, p: {train_p}, r: {train_r}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TRAINING COMPLETE')

[INFO]: Epoch 1 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 24.44it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 44.78it/s]


Training loss: 2.277, f1: 0.15967786473314666, p: 0.18232630589726376, r: 0.18164236281524626
Validation loss: 2.064, f1: 0.2377324425860455, p: 0.2609345183560723, r: 0.26057056768944387
--------------------------------------------------
[INFO]: Epoch 2 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.36it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 50.60it/s]


Training loss: 1.971, f1: 0.2658039810713272, p: 0.2920386440049835, r: 0.2873434432243046
Validation loss: 1.963, f1: 0.299129298643062, p: 0.3550274074077606, r: 0.3296929684124495
--------------------------------------------------
[INFO]: Epoch 3 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.39it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 48.37it/s]


Training loss: 1.828, f1: 0.3217411463895039, p: 0.3502986650717886, r: 0.34279891099149024
Validation loss: 1.799, f1: 0.3307813751070123, p: 0.36953987573322494, r: 0.36009037180950765
--------------------------------------------------
[INFO]: Epoch 4 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.22it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 49.68it/s]


Training loss: 1.711, f1: 0.37190921020786666, p: 0.40620695115530003, r: 0.3923878190461655
Validation loss: 1.806, f1: 0.3598628091184716, p: 0.3906688298049726, r: 0.387567433871721
--------------------------------------------------
[INFO]: Epoch 5 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 25.75it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 46.06it/s]


Training loss: 1.610, f1: 0.40808809221836556, p: 0.43983984150384603, r: 0.4288514496987326
Validation loss: 1.804, f1: 0.3588969566320118, p: 0.44709938921426473, r: 0.37335120690496343
--------------------------------------------------
[INFO]: Epoch 6 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.29it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 46.69it/s]


Training loss: 1.512, f1: 0.4452161916166718, p: 0.47245932880200836, r: 0.4631484319940645
Validation loss: 1.663, f1: 0.39680272027065877, p: 0.42419762517276566, r: 0.4184788073364057
--------------------------------------------------
[INFO]: Epoch 7 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.19it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 49.88it/s]


Training loss: 1.411, f1: 0.48025719969593295, p: 0.5040956885493987, r: 0.49761131545256454
Validation loss: 1.702, f1: 0.4005860639245887, p: 0.43500926463227524, r: 0.410144893746627
--------------------------------------------------
[INFO]: Epoch 8 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.15it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 49.53it/s]


Training loss: 1.315, f1: 0.5167572859086489, p: 0.540629065524765, r: 0.5337334114905686
Validation loss: 1.655, f1: 0.4231670404735364, p: 0.44568766731964915, r: 0.43830866248984085
--------------------------------------------------
[INFO]: Epoch 9 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.20it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 50.13it/s]


Training loss: 1.210, f1: 0.558632936568288, p: 0.5792612123210528, r: 0.5734564862404651
Validation loss: 1.671, f1: 0.42926897971253647, p: 0.44193595020394577, r: 0.44244615968905
--------------------------------------------------
[INFO]: Epoch 10 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.37it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 50.32it/s]


Training loss: 1.122, f1: 0.5871852545710335, p: 0.6058851295744466, r: 0.6032781602694974
Validation loss: 1.686, f1: 0.44081987205304596, p: 0.4596732823472274, r: 0.4501904139393254
--------------------------------------------------
[INFO]: Epoch 11 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.05it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 50.14it/s]


Training loss: 1.022, f1: 0.6272597680663505, p: 0.6461585106556875, r: 0.6410943328985694
Validation loss: 1.757, f1: 0.4266056904667302, p: 0.4450730179485522, r: 0.4364403564678995
--------------------------------------------------
[INFO]: Epoch 12 of 12
Training


100%|██████████| 171/171 [00:06<00:00, 26.34it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 45.45it/s]


Training loss: 0.910, f1: 0.6701274217917905, p: 0.6838088614201685, r: 0.6840545343376739
Validation loss: 1.849, f1: 0.41023112598218414, p: 0.42431852848906265, r: 0.41876817847553055
--------------------------------------------------
TRAINING COMPLETE


In [9]:
torch.save(model.state_dict(), f"./weights/base")

### Pretrained feature extractor

In [27]:
model = OnlineModel().encoder
model.load_state_dict(torch.load("../BYOL/pretrained_feature_extractors/feature_extractor_20"))
model.fc = nn.Linear(in_features=2048, out_features=10, bias=True)
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [28]:
# # Freeze all layers
# for param in model.parameters():
#     param.requires_grad = False

# # Unfreeze last layer
# for param in model.fc.parameters():
#     param.requires_grad = True

In [29]:
train_loss_pretrained, valid_loss_pretrained = [], []
f1_train_pretrained, f1_valid_pretrained = [], []
p_train_pretrained, p_valid_pretrained = [], []
r_train_pretrained, r_valid_pretrained = [], []

# Start the training.
for epoch in range(epochs):
    model.train()
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")

    print('Training')
    train_epoch_loss = 0.0
    counter = 0
    train_f1 = 0.0
    train_p = 0.0
    train_r = 0.0
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        counter += 1
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # f1_metric.update(outputs, labels)
        # p_metric.update(outputs, labels)
        # r_metric.update(outputs, labels)
        train_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
        train_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
        train_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
        loss.backward()
        optimizer.step()
        # print statistics
        train_epoch_loss += loss.item()

    train_epoch_loss = train_epoch_loss / counter
    train_f1 = train_f1 / counter
    train_p = train_p / counter
    train_r = train_r / counter

    model.eval()
    print("Validation")
    valid_epoch_loss = 0
    counter = 0
    val_f1 = 0.0
    val_p = 0.0
    val_r = 0.0
    with torch.no_grad():
        for i, data in tqdm(enumerate(val_loader), total=len(val_loader)):
            counter += 1
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            # forward + backward + optimize
            outputs = model(inputs)
            # f1_metric.update(outputs, labels)
            # p_metric.update(outputs, labels)
            # r_metric.update(outputs, labels)
            val_f1 += multiclass_f1_score(outputs, labels, num_classes=10, average="macro").item()
            val_p += multiclass_precision(outputs, labels, num_classes=10, average="macro").item()
            val_r += multiclass_recall(outputs, labels, num_classes=10, average="macro").item()
            loss = criterion(outputs, labels)
            # print statistics
            valid_epoch_loss += loss.item()

    valid_epoch_loss = valid_epoch_loss / counter
    val_f1 = val_f1 / counter
    val_p = val_p / counter
    val_r = val_r / counter

    train_loss_pretrained.append(train_epoch_loss)
    valid_loss_pretrained.append(valid_epoch_loss)

    f1_train_pretrained.append(train_f1)
    f1_valid_pretrained.append(val_f1)

    p_train_pretrained.append(train_p)
    p_valid_pretrained.append(val_p)

    r_train_pretrained.append(train_r)
    r_valid_pretrained.append(val_r)

    print(f"Training loss: {train_epoch_loss:.3f}, f1: {train_f1}, p: {train_p}, r: {train_r}")
    print(f"Validation loss: {valid_epoch_loss:.3f}, f1: {val_f1}, p: {val_p}, r: {val_r}")
    print('-'*50)
        
print('TRAINING COMPLETE')


[INFO]: Epoch 1 of 12
Training


100%|██████████| 171/171 [00:23<00:00,  7.38it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.57it/s]


Training loss: 2.344, f1: 0.052479198126242174, p: 0.06121627670271616, r: 0.1214033098161569
Validation loss: 2.344, f1: 0.05501940375880191, p: 0.058860479118792636, r: 0.12648913695624
--------------------------------------------------
[INFO]: Epoch 2 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.63it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.41it/s]


Training loss: 2.325, f1: 0.06206332596988357, p: 0.07082748656536926, r: 0.13234263196674703
Validation loss: 2.299, f1: 0.09166883265501574, p: 0.10100280415070684, r: 0.15183730384236888
--------------------------------------------------
[INFO]: Epoch 3 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.48it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.37it/s]


Training loss: 2.314, f1: 0.07213638998480917, p: 0.07509509750985001, r: 0.14077807379047774
Validation loss: 2.300, f1: 0.11028131528904564, p: 0.1339574309163972, r: 0.16038350211946586
--------------------------------------------------
[INFO]: Epoch 4 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.72it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.36it/s]


Training loss: 2.300, f1: 0.07589666527962824, p: 0.08532858187365427, r: 0.14335826832300042
Validation loss: 2.275, f1: 0.0704966771759485, p: 0.06718317860443342, r: 0.17032013441386976
--------------------------------------------------
[INFO]: Epoch 5 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.60it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.25it/s]


Training loss: 2.280, f1: 0.08571522749350434, p: 0.09677466481096214, r: 0.14968316474853202
Validation loss: 2.260, f1: 0.10548561654592815, p: 0.11674446083213154, r: 0.16440760775616295
--------------------------------------------------
[INFO]: Epoch 6 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.62it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.71it/s]


Training loss: 2.277, f1: 0.08715400154217642, p: 0.09709458608637776, r: 0.1507808120801435
Validation loss: 2.275, f1: 0.07598483738930602, p: 0.08491428500335467, r: 0.13554035950648158
--------------------------------------------------
[INFO]: Epoch 7 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.63it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.93it/s]


Training loss: 2.251, f1: 0.10029977509937091, p: 0.11701521971290223, r: 0.16245700615017036
Validation loss: 2.294, f1: 0.09490355927693217, p: 0.11063489768850177, r: 0.15820480569412834
--------------------------------------------------
[INFO]: Epoch 8 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.68it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.63it/s]


Training loss: 2.267, f1: 0.09851713894664893, p: 0.10954755741819652, r: 0.16437925128211753
Validation loss: 2.224, f1: 0.0968161754702267, p: 0.07788651240499396, r: 0.17265939869378744
--------------------------------------------------
[INFO]: Epoch 9 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.53it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.84it/s]


Training loss: 2.246, f1: 0.10399737562003888, p: 0.11594073546298764, r: 0.16574484231876352
Validation loss: 2.239, f1: 0.1298699186820733, p: 0.13958551734685898, r: 0.18682367158563515
--------------------------------------------------
[INFO]: Epoch 10 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.52it/s]


Validation


100%|██████████| 19/19 [00:01<00:00, 17.00it/s]


Training loss: 2.253, f1: 0.10448776791144533, p: 0.11430787899645797, r: 0.16733678434675897
Validation loss: 2.244, f1: 0.14677839451714567, p: 0.16801856301332774, r: 0.1718111116635172
--------------------------------------------------
[INFO]: Epoch 11 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.53it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.74it/s]


Training loss: 2.225, f1: 0.11320835317087452, p: 0.12661587537337116, r: 0.1743916268977854
Validation loss: 2.203, f1: 0.11329112241142675, p: 0.12726050536883504, r: 0.16821127502541794
--------------------------------------------------
[INFO]: Epoch 12 of 12
Training


100%|██████████| 171/171 [00:22<00:00,  7.60it/s]


Validation


100%|██████████| 19/19 [00:00<00:00, 19.99it/s]


Training loss: 2.216, f1: 0.11192384985281013, p: 0.12184209879814532, r: 0.1738760229962611
Validation loss: 2.314, f1: 0.09945314946143251, p: 0.08922977294576795, r: 0.1774150740943457
--------------------------------------------------
TRAINING COMPLETE


In [30]:
torch.save(model.state_dict(), f"./weights/pretrained")

### Plots

In [31]:
plt.plot(f1_train_pretrained, label='f1_train_pretrained')
plt.plot(f1_valid_pretrained, label='f1_valid_pretrained')
plt.plot(f1_train, label='f1_train')
plt.plot(f1_valid, label='f1_valid')

plt.legend()
plt.savefig(f"./plots/f1", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [32]:
plt.plot(p_train_pretrained, label='p_train_pretrained')
plt.plot(p_valid_pretrained, label='p_valid_pretrained')
plt.plot(p_train, label='p_train')
plt.plot(p_valid, label='p_valid')

plt.legend()
plt.savefig(f"./plots/precision", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [33]:
plt.plot(r_train_pretrained, label='r_train_pretrained')
plt.plot(r_valid_pretrained, label='r_valid_pretrained')
plt.plot(r_train, label='r_train')
plt.plot(r_valid, label='r_valid')

plt.legend()
plt.savefig(f"./plots/recall", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>

In [34]:
plt.plot(train_loss_pretrained, label='train_loss_pretrained')
plt.plot(valid_loss_pretrained, label='train_loss_pretrained')
plt.plot(train_loss, label='train_loss')
plt.plot(valid_loss, label='train_loss')

plt.legend()
plt.savefig(f"./plots/loss", bbox_inches='tight')
plt.clf()

<Figure size 640x480 with 0 Axes>