In [1]:
import torch
from torchvision.models.detection.mask_rcnn import resnet50
import torch.optim as optim
import torch.nn as nn
import json
import torchvision

from utils import train, validation, plot_loss, plot_metrics, plot_class_metrics
from data_utils import get_dataloaders

import datetime
import os
import wandb
import numpy as np 

In [2]:
# open config file
with open('../config.json') as f:
    config = json.load(f) 

batch_size = config["batch_size"]
epochs = config["epochs"]
model_names = config["model"]
learning_rate = config["lr"]
train_loader, val_loader, _ = get_dataloaders(batch_size)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
config["n_classes"] = len(classes)
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


In [3]:
# Get cpu, gpu or mps device for training.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
config["device"] = device

Using cuda device


In [4]:
model = resnet50().to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

In [5]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     project="homework_1",
#     # track hyperparameters and run metadata
#     config=config
# )

In [6]:
train_loss, valid_loss = [], []
valid_precision, valid_recall, valid_f1 = [], [], []
valid_precision_classes, valid_recall_classes, valid_f1_classes = [], [], []
# Start the training.
for epoch in range(epochs):
    print(f"[INFO]: Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(
        train_loader, 
        optimizer, 
        criterion,
        model,
        device
    )
    valid_epoch_loss, p, r, f1, p_classes, r_classes, f1_classes = validation( 
        val_loader, 
        criterion,
        model,
        device
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    valid_precision.append(p)
    valid_recall.append(r)
    valid_f1.append(f1)

    valid_precision_classes.append(p_classes)
    valid_recall_classes.append(r_classes)
    valid_f1_classes.append(f1_classes)

    # wandb.log({"train_loss": train_epoch_loss, "valid_loss": valid_epoch_loss})
    # wandb.log({"precision": p, "recall": r, "f1_score": f1})
    # for class_id in range(config["n_classes"]):
    #     wandb.log({f'val/{classes[class_id]}_precision': p_classes[class_id]})
    #     wandb.log({f'val/{classes[class_id]}_recall': r_classes[class_id]})
    #     wandb.log({f'val/{classes[class_id]}_f1': f1_classes[class_id]})

    print(f"Training loss: {train_epoch_loss:.3f}")
    print(f"Validation loss: {valid_epoch_loss:.3f}")
    print('-'*50)
        
print('TRAINING COMPLETE')

[INFO]: Epoch 1 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.80it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 33.57it/s]


Training loss: 2.212
Validation loss: 1.892
--------------------------------------------------
[INFO]: Epoch 2 of 20
Training


100%|██████████| 352/352 [00:12<00:00, 28.38it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 36.23it/s]


Training loss: 1.818
Validation loss: 1.732
--------------------------------------------------
[INFO]: Epoch 3 of 20
Training


100%|██████████| 352/352 [00:12<00:00, 27.44it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 30.96it/s]


Training loss: 1.654
Validation loss: 1.623
--------------------------------------------------
[INFO]: Epoch 4 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.16it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 30.75it/s]


Training loss: 1.527
Validation loss: 1.579
--------------------------------------------------
[INFO]: Epoch 5 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.86it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 34.43it/s]


Training loss: 1.423
Validation loss: 1.548
--------------------------------------------------
[INFO]: Epoch 6 of 20
Training


100%|██████████| 352/352 [00:12<00:00, 28.62it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 36.98it/s]


Training loss: 1.327
Validation loss: 1.452
--------------------------------------------------
[INFO]: Epoch 7 of 20
Training


100%|██████████| 352/352 [00:12<00:00, 28.57it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 33.08it/s]


Training loss: 1.248
Validation loss: 1.451
--------------------------------------------------
[INFO]: Epoch 8 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.11it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 30.59it/s]


Training loss: 1.155
Validation loss: 1.474
--------------------------------------------------
[INFO]: Epoch 9 of 20
Training


100%|██████████| 352/352 [00:14<00:00, 24.70it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 32.07it/s]


Training loss: 1.072
Validation loss: 1.475
--------------------------------------------------
[INFO]: Epoch 10 of 20
Training


100%|██████████| 352/352 [00:12<00:00, 27.24it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 33.51it/s]


Training loss: 0.994
Validation loss: 1.513
--------------------------------------------------
[INFO]: Epoch 11 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.29it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 32.69it/s]


Training loss: 0.913
Validation loss: 1.519
--------------------------------------------------
[INFO]: Epoch 12 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.60it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 33.06it/s]


Training loss: 0.829
Validation loss: 1.550
--------------------------------------------------
[INFO]: Epoch 13 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.92it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 27.70it/s]


Training loss: 0.750
Validation loss: 1.627
--------------------------------------------------
[INFO]: Epoch 14 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.61it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 28.91it/s]


Training loss: 0.669
Validation loss: 1.700
--------------------------------------------------
[INFO]: Epoch 15 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.10it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 34.06it/s]


Training loss: 0.593
Validation loss: 1.752
--------------------------------------------------
[INFO]: Epoch 16 of 20
Training


100%|██████████| 352/352 [00:12<00:00, 27.11it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 32.46it/s]


Training loss: 0.526
Validation loss: 1.816
--------------------------------------------------
[INFO]: Epoch 17 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.80it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 29.24it/s]


Training loss: 0.470
Validation loss: 1.916
--------------------------------------------------
[INFO]: Epoch 18 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.54it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 30.66it/s]


Training loss: 0.427
Validation loss: 2.000
--------------------------------------------------
[INFO]: Epoch 19 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.60it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 31.02it/s]


Training loss: 0.386
Validation loss: 2.000
--------------------------------------------------
[INFO]: Epoch 20 of 20
Training


100%|██████████| 352/352 [00:13<00:00, 26.89it/s]


Validation


100%|██████████| 40/40 [00:01<00:00, 31.89it/s]

Training loss: 0.347
Validation loss: 2.082
--------------------------------------------------
TRAINING COMPLETE





In [7]:

result_path = "./results/SL/"

model_path = os.path.join(result_path, "model")
torch.save(model.state_dict(), model_path)

with open(os.path.join(result_path, "config.json"), "w") as f:
    json.dump(config , f)

plot_loss(train_loss, valid_loss, "epochs", "loss", os.path.join(result_path, "validation_loss"))
plot_metrics(valid_precision, valid_recall, valid_f1, "epochs", "validation_metrics", os.path.join(result_path, "validation_metrics"))

for class_id in range(config["n_classes"]):
    plot_class_metrics(np.array(valid_precision_classes).T.tolist()[class_id],
                       np.array(valid_recall_classes).T.tolist()[class_id],
                       np.array(valid_f1_classes).T.tolist()[class_id],
                       classes[class_id],
                       "epochs",
                       f"{classes[class_id]}_metrcis",
                       os.path.join(result_path, f"{classes[class_id]}_metrcis"))

<Figure size 640x480 with 0 Axes>

In [8]:
# wandb.finish()

In [9]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 