In [1]:
import torch
import config
from utils import (
    get_model_object_detector,
    collate_fn,
    get_transform,
    myOwnDataset,
    save_model,
)
from pathlib import Path

In [2]:

my_dataset = myOwnDataset(
    root=config.train_data_dir, annotation=config.train_coco, transforms=get_transform()
)

loading annotations into memory...
Done (t=0.05s)
creating index...
index created!


In [3]:
data_loader = torch.utils.data.DataLoader(
    my_dataset,
    batch_size=config.train_batch_size,
    shuffle=config.train_shuffle_dl,
    num_workers=config.num_workers_dl,
    collate_fn=collate_fn,
)
data_loader

<torch.utils.data.dataloader.DataLoader at 0x7f9fd844a4a0>

In [4]:
# select device (whether GPU or CPU)
device = torch.device("cpu")

In [5]:
for imgs, annotations in data_loader:
    if len(annotations[0]['boxes']) == 0:
        continue

    imgs = list(img.to(device) for img in imgs)
    annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]

In [6]:
model = get_model_object_detector(config.num_classes)
model.to(device)



FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [7]:
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params, lr=config.lr, momentum=config.momentum, weight_decay=config.weight_decay
)

len_dataloader = len(data_loader)

Path("result/").mkdir(parents=True, exist_ok=True)

In [8]:
for epoch in range(config.num_epochs):
    print(f"Epoch: {epoch}/{config.num_epochs}")
    model.train()
    i = 0
    for imgs, annotations in data_loader:
        if i > config.num_images:
            break
        if len(annotations[0]['boxes']) == 0:
            continue

        i += 1
        imgs = list(img.to(device) for img in imgs)
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        print(f"Iteration: {i}/{len_dataloader}, Loss: {losses}")

# Save the trained model
save_model('small', config.num_epochs, model, optimizer)

Epoch: 0/5
Iteration: 1/5406, Loss: 1.652025818824768
Iteration: 2/5406, Loss: 0.5345838069915771
Iteration: 3/5406, Loss: 0.44206321239471436
Iteration: 4/5406, Loss: 0.2499985247850418
Iteration: 5/5406, Loss: 0.32561054825782776
Iteration: 6/5406, Loss: 0.30881208181381226
Iteration: 7/5406, Loss: 0.2843503952026367
Iteration: 8/5406, Loss: 1.3471219539642334
Iteration: 9/5406, Loss: 0.30974382162094116
Iteration: 10/5406, Loss: 0.8179190754890442
Iteration: 11/5406, Loss: 0.40990716218948364
Iteration: 12/5406, Loss: 0.25864091515541077
Iteration: 13/5406, Loss: 0.17518705129623413
Iteration: 14/5406, Loss: 0.16952385008335114
Iteration: 15/5406, Loss: 0.2508592903614044
Iteration: 16/5406, Loss: 0.27716881036758423
Iteration: 17/5406, Loss: 0.4696740508079529
Iteration: 18/5406, Loss: 0.27247685194015503
Iteration: 19/5406, Loss: 0.7643455266952515
Iteration: 20/5406, Loss: 0.17919102311134338
Iteration: 21/5406, Loss: 0.17627660930156708
Iteration: 22/5406, Loss: 0.16716778278350