In [2]:
import os,sys
sys.path.append(os.path.abspath('./../'))

import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torch.utils.data import DataLoader
import coco.transforms as T
from coco.engine import train_one_epoch
from coco.utils import * 
from dataset import PennFudanDataset

import matplotlib.pyplot as plt
import cv2
%matplotlib inline 

In [None]:
# --- Notebook configuration -------------------------------------------------
# Folder the dataset is read from, folder checkpoints are written to,
# and the number of passes over the training loader.
data_path = './data/PennFudanPed'
save_path = './parameters'
num_epoch = 20

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def get_transform(train):
    """Assemble the image/target transform pipeline for a dataset.

    The pipeline always starts with ``T.PILToTensor()`` followed by
    ``T.ConvertImageDtype(torch.float)``. When ``train`` is truthy a
    ``T.RandomHorizontalFlip(0.5)`` step is appended (presumably a flip
    probability of 0.5 -- see ``coco.transforms`` to confirm).

    Args:
        train: Whether the pipeline is for training data.

    Returns:
        A ``T.Compose`` wrapping the selected transforms in order.
    """
    pipeline = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
    if train:
        # Augmentation is only applied to training samples.
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [None]:
# Two dataset objects over the same folder: one with the training
# transforms, one with the evaluation transforms.
trainset = PennFudanDataset(data_path, get_transform(train=True))
testset = PennFudanDataset(data_path, get_transform(train=False))

# Split by position: every index except the last is used for training,
# the final index is held out as the single test sample.
indices = list(range(len(trainset)))
dataset = torch.utils.data.Subset(trainset, indices[:-1])
dataset_test = torch.utils.data.Subset(testset, indices[-1:])

# One sample per batch; `collate_fn` comes from the star import of coco.utils.
trainLoader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)
testLoader = DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Backbone: the convolutional feature stack of MobileNetV2, created without
# requesting any pretrained weights. To start from locally saved weights,
# load them into `backbone` right after it is created:
#   mobilenet_weight = torch.load(os.path.join(save_path, 'mobilenet.pth'))
#   backbone.load_state_dict(mobilenet_weight, strict=False)
backbone = torchvision.models.mobilenet_v2().features

# Channel count of the backbone's output feature map, attached as an
# attribute so the detector can read it.
backbone.out_channels = 1280

# A single feature map, with five anchor sizes x three aspect ratios.
# Uncomment to inspect the generated anchors:
#   print(anchor_generator.cell_anchors)
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),),
    aspect_ratios=((0.5, 1.0, 2.0),),
)

# RoI pooling over the feature map named '0', producing 7x7 crops.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0'],
    output_size=7,
    sampling_ratio=2,
)

# num_classes=2 -- presumably background + pedestrian; confirm against the dataset.
model = FasterRCNN(
    backbone,
    num_classes=2,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)
print(model.rpn)
model = model.to(device)

# Plain SGD with weight decay over every model parameter.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.01,
    weight_decay=0.0005,
)

In [None]:
# Create the checkpoint folder *before* training: torch.save() does not create
# missing directories, and discovering that only after the last epoch would
# throw away the whole training run.
os.makedirs(save_path, exist_ok=True)

# Run `num_epoch` passes over the training loader, logging every 10 iterations.
for epoch in range(num_epoch):
    train_one_epoch(model, optimizer, trainLoader, device, epoch, print_freq=10)

# Persist only the learned parameters (state_dict), not the whole module.
torch.save(model.state_dict(), os.path.join(save_path, 'detector.pth'))

In [None]:
# Pull the first (and, given the split above, only) batch from the test loader.
imgs, targets = next(iter(testLoader))
img, target = imgs[0], targets[0]

# Channels-first tensor -> channels-last NumPy array for plotting.
sample = img.permute(1, 2, 0).cpu().numpy()

# Ground-truth boxes as integer coordinates.
boxes = target['boxes'].cpu().numpy().astype(int)
print(boxes)

In [None]:
# Switch to inference behaviour and run on the CPU.
# NOTE: this rebinds the notebook-level `device`, so any cell executed after
# this one that reads `device` will see the CPU device.
model.eval()
device = torch.device('cpu')
model = model.to(device)

# Inference does not need gradients: disabling autograd avoids building a
# graph and keeps the returned tensors free of grad history.
with torch.no_grad():
    outputs = model(img.unsqueeze(0))

# Make sure every output tensor lives on `device`.
outputs = [{k: v.to(device) for k, v in t.items()} for t in outputs]

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(16, 8))

mean_score = torch.mean(outputs[0]['scores'])

# Draw on a fresh C-contiguous copy rather than on `sample` itself:
#  * `sample` is the NumPy view of a permuted tensor, so its memory layout is
#    not C-contiguous, which OpenCV drawing functions do not handle reliably
#    (depending on the version they raise or draw onto a temporary copy);
#  * leaving `sample` untouched makes this cell safe to re-run without
#    stacking rectangles from earlier executions.
canvas = sample.copy()

# Predicted boxes: only those scoring above 0.5 are drawn.
for box, score in zip(outputs[0]['boxes'].int(), outputs[0]['scores']):
    print(box, score)
    if score > 0.5:
        x1, y1, x2, y2 = box.tolist()
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (225, 0, 0), 3)

# Ground-truth boxes, drawn in a second colour.
for box in targets[0]['boxes'].int():
    x1, y1, x2, y2 = box.tolist()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 0, 255), 3)

ax.set_axis_off()
ax.imshow(canvas)
