In [1]:
import torch
import config
from utils import (
    get_model_object_detector,
    collate_fn,
    get_transform,
    myOwnDataset,
    save_model,
)
from pathlib import Path

In [2]:

# Build the training dataset from the root directory and annotation file named
# in `config`, applying whatever transform `get_transform()` returns.
my_dataset = myOwnDataset(
    root=config.train_data_dir,
    annotation=config.train_coco,
    transforms=get_transform(),
)

loading annotations into memory...
Done (t=0.14s)
creating index...
index created!


In [3]:
# Batching options all come from `config`; `collate_fn` is the project's own
# batch-assembly function imported from `utils`.
loader_kwargs = dict(
    batch_size=config.train_batch_size,
    shuffle=config.train_shuffle_dl,
    num_workers=config.num_workers_dl,
    collate_fn=collate_fn,
)
data_loader = torch.utils.data.DataLoader(my_dataset, **loader_kwargs)

# Trailing expression: shows the loader's repr as the cell output.
data_loader

<torch.utils.data.dataloader.DataLoader at 0x2e388438b50>

In [4]:
# Device used for the model and every batch below. It is pinned to the CPU:
# no GPU availability check is performed here.
device_name = "cpu"
device = torch.device(device_name)

In [5]:
# Dry run over the entire loader: every batch whose first annotation has at
# least one box is moved to `device`. No model is called and nothing is
# accumulated; when the loop ends, `imgs` / `annotations` simply hold the
# last batch that was yielded.
for imgs, annotations in data_loader:
    first_has_boxes = len(annotations[0]['boxes']) != 0
    if first_has_boxes:
        imgs = [image.to(device) for image in imgs]
        annotations = [
            {key: value.to(device) for key, value in target.items()}
            for target in annotations
        ]

In [6]:
# Create the detector for `config.num_classes` classes and move it to `device`.
# `.to()` on a module returns the module itself, so chaining keeps `model`
# bound to the same object.
model = get_model_object_detector(config.num_classes).to(device)

# Trailing expression: displays the model's layer structure as the cell output.
model

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to C:\Users\billy/.cache\torch\hub\checkpoints\fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:03<00:00, 50.9MB/s] 


FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [7]:
# Make sure the output directory exists before anything is written to it.
Path("result/").mkdir(parents=True, exist_ok=True)

# Number of batches the loader yields per pass; used for progress messages.
len_dataloader = len(data_loader)

# Hand only the parameters that still require gradients to the optimizer.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=config.lr,
    momentum=config.momentum,
    weight_decay=config.weight_decay,
)

In [8]:
# Training loop.
#
# Runs `config.num_epochs` epochs. Within an epoch at most `config.num_images`
# optimisation steps are taken (one per usable batch); samples whose
# annotation contains no boxes are left out. After the final epoch the model
# and optimizer are handed to `save_model`.
for epoch in range(config.num_epochs):
    print(f"Epoch: {epoch}/{config.num_epochs}")
    model.train()
    i = 0  # optimisation steps taken so far in this epoch
    for imgs, annotations in data_loader:
        # `i` is incremented only after this check, so `>=` is required to
        # stop after exactly `config.num_images` steps (`>` ran one extra).
        if i >= config.num_images:
            break

        # Filter per sample instead of looking at annotations[0] only: with a
        # batch size above 1, an un-annotated sample at any other position
        # would otherwise reach the model, and a whole batch would be thrown
        # away just because its first sample is empty.
        # NOTE(review): assumes un-annotated samples should be excluded from
        # training, as the original first-sample check implies - confirm.
        usable = [
            (img, target)
            for img, target in zip(imgs, annotations)
            if len(target['boxes']) > 0
        ]
        if not usable:
            continue

        i += 1
        imgs = [img.to(device) for img, _ in usable]
        annotations = [
            {k: v.to(device) for k, v in target.items()} for _, target in usable
        ]

        # In training mode the model is called with (images, targets) and its
        # result is consumed as a mapping of individual loss terms, which are
        # summed into one scalar for back-propagation.
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # The denominator is the loader length, not the `num_images` cap.
        print(f"Iteration: {i}/{len_dataloader}, Loss: {losses}")

# Save the trained model
save_model('small', config.num_epochs, model, optimizer)

Epoch: 0/5
Iteration: 1/320, Loss: 3.7390995025634766
Iteration: 2/320, Loss: 2.6616415977478027
Iteration: 3/320, Loss: 2.7483763694763184
Iteration: 4/320, Loss: 1.6395183801651
Iteration: 5/320, Loss: 2.2716922760009766
Iteration: 6/320, Loss: 1.8055933713912964
Iteration: 7/320, Loss: 1.596710443496704
Iteration: 8/320, Loss: 1.7667930126190186
Iteration: 9/320, Loss: 1.5801701545715332
Iteration: 10/320, Loss: 1.6214178800582886
Iteration: 11/320, Loss: 1.6195156574249268
Iteration: 12/320, Loss: 1.5077381134033203
Iteration: 13/320, Loss: 1.6205227375030518
Iteration: 14/320, Loss: 1.4873040914535522
Iteration: 15/320, Loss: 1.6825674772262573
Iteration: 16/320, Loss: 1.3572944402694702
Iteration: 17/320, Loss: 1.5121079683303833
Iteration: 18/320, Loss: 1.3428581953048706
Iteration: 19/320, Loss: 1.739819884300232
Iteration: 20/320, Loss: 1.8016616106033325
Iteration: 21/320, Loss: 1.6763185262680054
Iteration: 22/320, Loss: 1.4962950944900513
Iteration: 23/320, Loss: 1.76019394