### RetinaNet

Next we look at another pre-trained model: RetinaNet, a network that tries to balance accuracy and speed. It aims to be a middle ground between the Faster R-CNN model with a ResNet50 backbone and the Faster R-CNN model with a MobileNet V3 backbone.

In [1]:
from os import listdir
from os import path

import cv2
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import cuda
from torch import device
from torchvision.models.detection import RetinaNet_ResNet50_FPN_Weights
from torchvision.models.detection import retinanet_resnet50_fpn

In [2]:
join = path.join

# directories that hold the frames and their annotations.
img_path = 'PRW/frames'
ann_path = 'PRW/annotations'

# full path of every frame, ordered by file name.
img_names = [join(img_path, name) for name in sorted(listdir(img_path))]

# annotation paths (currently unused, kept for reference).
# ann_names = sorted(list(listdir(ann_path)))
# ann_names = [join(ann_path, name) for name in ann_names]

In [3]:
# Human-readable class names indexed by the integer label id the detector
# returns (91 entries). Index 0 is the background placeholder, index 1 is
# 'person', and 'N/A' marks ids that have no category name.
# NOTE(review): presumably the COCO label set used by the pre-trained
# weights -- confirm against the torchvision weights metadata.
CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

# preprocessing pipeline applied to every input image: its single step
# converts the image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

def predict(image, model, device, detection_threshold):
    """Run the detector on a single image and keep the confident detections.

    Args:
        image: input image accepted by the module-level ``transform``
            (the notebook passes a PIL image).
        model: detection model. It is called on a batch holding one image
            tensor and must return a list whose first element is a dict
            with ``'labels'``, ``'scores'`` and ``'boxes'`` tensors.
        device: device the image tensor is moved to before inference.
        detection_threshold: minimum score (inclusive) a detection needs
            in order to be kept.

    Returns:
        A ``(boxes, labels)`` tuple of NumPy arrays: the boxes of the kept
        detections cast to ``np.int32`` and the label ids of those boxes.
    """
    # transform the image to a tensor and add a batch dimension.
    batch = transform(image).to(device).unsqueeze(0)

    # inference only: disable autograd so no gradient state is recorded.
    with torch.no_grad():
        outputs = model(batch)

    # the batch holds a single image, so only the first result is relevant.
    detections = outputs[0]
    pred_labels = detections['labels'].detach().cpu().numpy()
    pred_scores = detections['scores'].detach().cpu().numpy()
    pred_bboxes = detections['boxes'].detach().cpu().numpy()

    # build the score mask once and apply it to both boxes and labels.
    keep = pred_scores >= detection_threshold
    boxes = pred_bboxes[keep].astype(np.int32)
    labels = pred_labels[keep]

    return boxes, labels

In [4]:
def draw_boxes(boxes, labels, image):
    """Return a channel-swapped array copy of ``image`` with person boxes drawn.

    Only boxes whose label id equals 1 are drawn. ``boxes[i]`` and
    ``labels[i]`` are taken to describe the same detection, and each box is
    read as ``(x1, y1, x2, y2)`` corner coordinates.
    """
    person_label = 1
    box_color = [255, 0, 0]
    line_thickness = 2

    # convert the image to an array and swap its first and third channels.
    # NOTE(review): presumably this turns a PIL RGB image into the BGR layout
    # OpenCV expects -- confirm.
    canvas = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)

    for idx, box in enumerate(boxes):
        # skip every detection that is not a person.
        if labels[idx] != person_label:
            continue
        top_left = (int(box[0]), int(box[1]))
        bottom_right = (int(box[2]), int(box[3]))
        cv2.rectangle(canvas, top_left, bottom_right, box_color, line_thickness)

    return canvas

In [5]:
# create the model: a RetinaNet with a ResNet50-FPN backbone, initialised
# with the default pre-trained weights that torchvision ships for it (the
# checkpoint is downloaded into the local torch hub cache when missing).
weights = RetinaNet_ResNet50_FPN_Weights.DEFAULT
model = retinanet_resnet50_fpn(weights=weights)

Downloading: "https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth" to /Users/chiahaohsutai/.cache/torch/hub/checkpoints/retinanet_resnet50_fpn_coco-eeacb38b.pth
100%|██████████| 130M/130M [00:32<00:00, 4.22MB/s] 


In [6]:
# set the device: use the GPU when CUDA is available, otherwise the CPU.
# `torch.device` is referenced through the module because this assignment
# rebinds the bare name `device` to an instance; calling the bare name would
# fail with "'torch.device' object is not callable" when the cell is re-run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
# load the first frame.
image = Image.open(img_names[0])

# put the model in evaluation mode and move it to the selected device.
model.eval()
model.to(device)

# keep only the detections that score at least 0.6.
boxes, labels = predict(image, model, device, detection_threshold=0.6)

In [10]:
# draw the person boxes on the frame; draw_boxes returns the annotated array.
img = draw_boxes(boxes, labels, image)

# display the image.
cv2.imshow('Image', img)
# keep the window open for up to 5000 ms (a key press ends the wait early).
cv2.waitKey(5000)
cv2.destroyWindow('Image')
# NOTE(review): presumably this short extra wait gives the GUI backend time to
# actually close the window -- confirm. Its return value is the cell output.
cv2.waitKey(1)

-1

The results are not as accurate as those of the Faster R-CNN model with the ResNet50 backbone.