In [6]:
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from PIL import ImageDraw
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import transforms as T
from datasets.esri_datascience_challenge_2019 import ESRI_challenge_2019
from engine import train_one_epoch, evaluate
from utils import collate_fn

In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Define the training and test datasets

In [8]:
# Root directory of the ESRI data-science-challenge-2019 data. The default is the
# original author's local copy; set the ESRI_DATA_ROOT environment variable to
# point at another location without editing the notebook.
DATA_ROOT = os.environ.get("ESRI_DATA_ROOT",
                           "D:\\Documents\\AiTLAS\\aitlas\\ESRI-challenge\\data\\ESRI")

# Both splits share the same root and only differ in `subset`; each applies
# T.ToTensor() as its only transform.
dataset = ESRI_challenge_2019(root=DATA_ROOT,
                              subset="train",
                              transforms=T.ToTensor())

dataset_test = ESRI_challenge_2019(root=DATA_ROOT,
                                   subset="test",
                                   transforms=T.ToTensor())

## Faster R-CNN fine-tuning on ESRI

In [9]:
# Class labels: ['null', 'pool', 'car'] -> num_classes = 3 (the default below)
def get_object_detection_model(num_classes=3):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a freshly sized box head.

    The detector is created from torchvision's COCO-pretrained weights
    (``pretrained=True``). Only the final ``FastRCNNPredictor`` is swapped out, so
    that the classifier and box regressor produce ``num_classes`` outputs while
    keeping the original head's input feature size.

    Args:
        num_classes: number of output classes for the new predictor head
            (default 3, matching the labels ['null', 'pool', 'car']).

    Returns:
        The torchvision Faster R-CNN model, ready for fine-tuning.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Reuse the feature width the COCO head was built for, then replace the head.
    heads = detector.roi_heads
    feature_width = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(feature_width, num_classes)

    return detector

### Define the data loaders:

In [10]:
# Training and test loaders share batching/collation settings; only shuffling
# differs: reshuffle the training set each epoch, keep the test order fixed.
# num_workers=0 loads batches in the main process.
shared_loader_kwargs = dict(batch_size=2, num_workers=0, collate_fn=collate_fn)

data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, **shared_loader_kwargs)
data_loader_test = torch.utils.data.DataLoader(dataset_test, shuffle=False, **shared_loader_kwargs)

Build the model, move it to the GPU (falling back to the CPU when no GPU is available), and attach an SGD optimizer and a step learning-rate scheduler:

In [11]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed torch's global RNG so that runs are reproducible. This covers the new
# predictor head's initial weights and the training shuffle order, because the
# DataLoader draws its sampler seed from this generator each time it is iterated.
# Some CUDA kernels may still be non-deterministic.
SEED = 42
torch.manual_seed(SEED)

# Class labels: ['null', 'pool', 'car'].
num_classes = 3

# get the model using our helper function and move it to the selected device
model = get_object_detection_model(num_classes)
model.to(device)

# SGD with momentum and weight decay, over the trainable parameters only.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# Multiply the learning rate by gamma=0.1 (i.e. divide by 10) every 3 calls to
# lr_scheduler.step(); the training loop calls it once per epoch.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

### Train for 10 epochs with `train_one_epoch`, evaluating on the test set after each epoch

In [13]:
# Fixed-length training: each epoch trains on the full training loader, advances
# the epoch-based StepLR schedule, then scores the current weights on the test split.
num_epochs = 10

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch,
                    print_freq=100)  # log progress every 100 iterations
    lr_scheduler.step()              # StepLR counts epochs, so step once per epoch
    evaluate(model, data_loader_test, device=device)

Epoch: [0]  [   0/1614]  eta: 0:11:56  lr: 0.000010  loss: 1.5903 (1.5903)  loss_classifier: 1.2246 (1.2246)  loss_box_reg: 0.0627 (0.0627)  loss_objectness: 0.2864 (0.2864)  loss_rpn_box_reg: 0.0166 (0.0166)  time: 0.4440  data: 0.0270  max mem: 2974
Epoch: [0]  [ 100/1614]  eta: 0:10:01  lr: 0.000509  loss: 0.4711 (0.6442)  loss_classifier: 0.1743 (0.3113)  loss_box_reg: 0.1841 (0.1546)  loss_objectness: 0.0579 (0.1515)  loss_rpn_box_reg: 0.0091 (0.0269)  time: 0.4006  data: 0.0252  max mem: 2974
Epoch: [0]  [ 200/1614]  eta: 0:09:24  lr: 0.001009  loss: 0.4186 (0.5544)  loss_classifier: 0.1155 (0.2352)  loss_box_reg: 0.2038 (0.1919)  loss_objectness: 0.0242 (0.1005)  loss_rpn_box_reg: 0.0057 (0.0268)  time: 0.4005  data: 0.0248  max mem: 2974
Epoch: [0]  [ 300/1614]  eta: 0:08:45  lr: 0.001508  loss: 0.2531 (0.4837)  loss_classifier: 0.0663 (0.1892)  loss_box_reg: 0.1300 (0.1829)  loss_objectness: 0.0284 (0.0825)  loss_rpn_box_reg: 0.0085 (0.0292)  time: 0.4073  data: 0.0272  max me

KeyboardInterrupt: 

In [14]:
# Hand cached-but-unused CUDA memory back to the driver after the interrupted run.
# Tensors still referenced (e.g. by `model` or `optimizer`) stay allocated.
if torch.cuda.is_available():
    torch.cuda.empty_cache()