In [1]:
import torch
from torchvision.models.detection.mask_rcnn import resnet50
import torch.optim as optim
import torch.nn as nn
import json
import torchvision

from utils import train, validation, plot_loss, plot_metrics, plot_class_metrics
from data_utils import get_dataloaders

import datetime
import os
import wandb
import numpy as np 

In [2]:
class MLP(nn.Module):
    def __init__(self, input_dim=2048) -> None:
        super(MLP, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 256)
        )
    
    def forward(self, x):
        return self.net(x)

In [3]:
class OnlineModel(nn.Module):
    def __init__(self) -> None:
        super(OnlineModel, self).__init__()
        self.encoder = torchvision.models.resnet50()
        self.encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.encoder.fc = nn.Identity()

        self.represent = MLP()

        self.predictor = MLP(input_dim=256)
    
    def forward(self, x):
        x = self.encoder(x)
        x = self.represent(x)
        x = self.predictor(x)
        return x


In [4]:
# open config file
with open('../config.json') as f:
    config = json.load(f) 

batch_size = config["batch_size"]
epochs = config["epochs"]
model_names = config["model"]
learning_rate = config["lr"]
train_loader, val_loader, _ = get_dataloaders(batch_size)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
config["n_classes"] = len(classes)
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
config["device"] = device

Using cuda device


In [6]:
model = OnlineModel()
model.load_state_dict(torch.load("./results/SSL/pretrained/online"))
model.encoder.fc = nn.Linear(in_features=2048, out_features=1000, bias=True)
model = model.encoder
model = model.to(device)

In [7]:
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

In [8]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [9]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="homework_1",
#     # track hyperparameters and run metadata
#     config=config
# )

In [10]:
train_loss, valid_loss = [], []
valid_precision, valid_recall, valid_f1 = [], [], []
valid_precision_classes, valid_recall_classes, valid_f1_classes = [], [], []
# Start the training.
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(
        train_loader, 
        optimizer, 
        criterion,
        model,
        device
    )
    valid_epoch_loss, p, r, f1, p_classes, r_classes, f1_classes = validation( 
        val_loader, 
        criterion,
        model,
        device
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    valid_precision.append(p)
    valid_recall.append(r)
    valid_f1.append(f1)

    valid_precision_classes.append(p_classes)
    valid_recall_classes.append(r_classes)
    valid_f1_classes.append(f1_classes)

    # wandb.log({"train_loss": train_epoch_loss, "valid_loss": valid_epoch_loss})
    # wandb.log({"precision": p, "recall": r, "f1_score": f1})
    # for class_id in range(config["n_classes"]):
    #     wandb.log({f'val/{classes[class_id]}_precision': p_classes[class_id]})
    #     wandb.log({f'val/{classes[class_id]}_recall': r_classes[class_id]})
    #     wandb.log({f'val/{classes[class_id]}_f1': f1_classes[class_id]})

    print(f"Training loss: {train_epoch_loss:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}")
    print('-'*50)
        
print('TRAINING COMPLETE')

[INFO]: Epoch 1 of 20
Training


100%|██████████| 36/36 [00:03<00:00, 10.42it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 21.38it/s]


Training loss: 3.339
Validation loss: 2.251
--------------------------------------------------
[INFO]: Epoch 2 of 20
Training


100%|██████████| 36/36 [00:02<00:00, 16.60it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.31it/s]


Training loss: 2.277
Validation loss: 2.259
--------------------------------------------------
[INFO]: Epoch 3 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 18.85it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 28.71it/s]


Training loss: 2.263
Validation loss: 2.198
--------------------------------------------------
[INFO]: Epoch 4 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.87it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 30.65it/s]


Training loss: 2.212
Validation loss: 2.272
--------------------------------------------------
[INFO]: Epoch 5 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.54it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.17it/s]


Training loss: 2.225
Validation loss: 2.174
--------------------------------------------------
[INFO]: Epoch 6 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.79it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.14it/s]


Training loss: 2.185
Validation loss: 2.159
--------------------------------------------------
[INFO]: Epoch 7 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.09it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 28.49it/s]


Training loss: 2.138
Validation loss: 2.221
--------------------------------------------------
[INFO]: Epoch 8 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.93it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 27.72it/s]


Training loss: 2.225
Validation loss: 2.122
--------------------------------------------------
[INFO]: Epoch 9 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.78it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 31.84it/s]


Training loss: 2.218
Validation loss: 2.232
--------------------------------------------------
[INFO]: Epoch 10 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.46it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 28.77it/s]


Training loss: 2.171
Validation loss: 2.112
--------------------------------------------------
[INFO]: Epoch 11 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.22it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.91it/s]


Training loss: 2.110
Validation loss: 2.098
--------------------------------------------------
[INFO]: Epoch 12 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.90it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 26.73it/s]


Training loss: 2.131
Validation loss: 2.054
--------------------------------------------------
[INFO]: Epoch 13 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.00it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 28.41it/s]


Training loss: 2.106
Validation loss: 2.115
--------------------------------------------------
[INFO]: Epoch 14 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.99it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 27.45it/s]


Training loss: 2.081
Validation loss: 2.129
--------------------------------------------------
[INFO]: Epoch 15 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.08it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 28.71it/s]


Training loss: 2.086
Validation loss: 2.032
--------------------------------------------------
[INFO]: Epoch 16 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 20.71it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 27.50it/s]


Training loss: 2.028
Validation loss: 2.032
--------------------------------------------------
[INFO]: Epoch 17 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.96it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.85it/s]


Training loss: 2.000
Validation loss: 2.028
--------------------------------------------------
[INFO]: Epoch 18 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.95it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.16it/s]


Training loss: 2.057
Validation loss: 2.022
--------------------------------------------------
[INFO]: Epoch 19 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.90it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.26it/s]


Training loss: 2.028
Validation loss: 2.006
--------------------------------------------------
[INFO]: Epoch 20 of 20
Training


100%|██████████| 36/36 [00:01<00:00, 19.73it/s]


Validation


100%|██████████| 4/4 [00:00<00:00, 29.50it/s]


Training loss: 1.990
Validation loss: 1.998
--------------------------------------------------
TRAINING COMPLETE


In [11]:

result_path = "./results/SL_pretrained/"
# os.makedirs(result_path)
model_path = os.path.join(result_path, "model")
torch.save(model.state_dict(), model_path)

with open(os.path.join(result_path, "config.json"), "w") as f:
    json.dump(config , f)

plot_loss(train_loss, valid_loss, "epochs", "loss", os.path.join(result_path, "validation_loss"))
plot_metrics(valid_precision, valid_recall, valid_f1, "epochs", "validation_metrics", os.path.join(result_path, "validation_metrics"))

for class_id in range(config["n_classes"]):
    plot_class_metrics(np.array(valid_precision_classes).T.tolist()[class_id],
                       np.array(valid_recall_classes).T.tolist()[class_id],
                       np.array(valid_f1_classes).T.tolist()[class_id],
                       classes[class_id],
                       "epochs",
                       f"{classes[class_id]}_metrcis",
                       os.path.join(result_path, f"{classes[class_id]}_metrcis"))

<Figure size 640x480 with 0 Axes>

In [12]:
# wandb.finish()