In [1]:
import torch
from groceries_dataset import GroceriesDataset, collate_fn
from utils.engine import train_one_epoch, evaluate
from torch.utils.data import DataLoader
import config_maskrcnn as config
import torchvision

In [2]:
train_groceries_dataset = GroceriesDataset(root_dir=config.train_root_dir, ann_file=config.train_ann_file,
                                           transform='train')
test_groceries_dataset = GroceriesDataset(root_dir=config.test_root_dir, ann_file=config.test_ann_file,
                                          transform='test')

loading annotations into memory...
Done (t=0.07s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [3]:
train_dataloader = DataLoader(
    train_groceries_dataset,
    batch_size=config.batch_size,
    shuffle=config.train_shuffle,
    num_workers=config.number_of_workers,
    collate_fn=collate_fn,
)

In [4]:
test_dataloader = DataLoader(
    test_groceries_dataset,
    batch_size=config.batch_size,
    shuffle=config.test_shuffle,
    num_workers=config.number_of_workers,
    collate_fn=collate_fn,
)

In [5]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [6]:
model = torchvision.models.detection.maskrcnn_resnet50_fpn_v2(weights='DEFAULT', progress=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, 10)

in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 512
model.roi_heads.mask_predictor = torchvision.models.detection.mask_rcnn.MaskRCNNPredictor(in_features_mask,
                                                                                          hidden_layer, 10)
model.to(device)

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
         

In [7]:
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, weight_decay=0.0005)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

In [8]:
num_epochs = 12
for e in range(num_epochs):
    train_one_epoch(model, optimizer, train_dataloader, device, e, print_freq=10)
    lr_scheduler.step()
    evaluate(model, test_dataloader, device=device)

model_scripted = torch.jit.script(model)
model_scripted.save('groceries_maskrcnn.pt')

  return F.conv2d(input, weight, bias, self.stride,


Epoch: [0]  [  0/725]  eta: 0:13:36  lr: 0.000024  loss: 4.3522 (4.3522)  loss_classifier: 2.3070 (2.3070)  loss_box_reg: 0.6972 (0.6972)  loss_mask: 1.0432 (1.0432)  loss_objectness: 0.2591 (0.2591)  loss_rpn_box_reg: 0.0457 (0.0457)  time: 1.1264  data: 0.2983  max mem: 4632
Epoch: [0]  [ 10/725]  eta: 0:06:46  lr: 0.000162  loss: 3.8881 (3.8727)  loss_classifier: 2.2423 (2.2200)  loss_box_reg: 0.4993 (0.5348)  loss_mask: 1.0315 (1.0205)  loss_objectness: 0.0656 (0.0826)  loss_rpn_box_reg: 0.0088 (0.0148)  time: 0.5692  data: 0.0354  max mem: 5099
Epoch: [0]  [ 20/725]  eta: 0:06:14  lr: 0.000300  loss: 3.1429 (3.2677)  loss_classifier: 1.9183 (1.8460)  loss_box_reg: 0.4039 (0.4643)  loss_mask: 0.8428 (0.8733)  loss_objectness: 0.0547 (0.0703)  loss_rpn_box_reg: 0.0088 (0.0137)  time: 0.5021  data: 0.0097  max mem: 5099
Epoch: [0]  [ 30/725]  eta: 0:06:02  lr: 0.000438  loss: 1.8866 (2.7042)  loss_classifier: 0.8486 (1.4389)  loss_box_reg: 0.3571 (0.4417)  loss_mask: 0.5384 (0.7456) 

In [9]:
# model_scripted = torch.jit.script(model)
# model_scripted.save('groceries_maskrcnn.pt')

In [10]:
import cv2 as cv

In [11]:
# test_image = test_groceries_dataset[0][0]
# test_img = test_image.permute(1, 2, 0).numpy()
# cv.imshow('Original image', test_img)
# cv.waitKey(0)
# cv.destroyAllWindows()

In [12]:
# model.eval()
# with torch.no_grad():
#     prediction = model([test_image.to(device)])

In [13]:
# prediction[0]

In [14]:
# for i in range(len(prediction[0]['masks'])):
#     if prediction[0]['scores'][i] < 0.5:
#         continue
#     mask = prediction[0]['masks'][i, 0].mul(255).byte().cpu().numpy()
#     _, mask = cv.threshold(mask, 127, 255, cv.THRESH_BINARY)
#     cv.imwrite(f'mask_{i}.png', mask)
#     # cv.imshow('Mask', mask)
#     # cv.waitKey(0)
#     # cv.destroyAllWindows()