Import packages

In [None]:
import cv2
import os, sys
import xml.etree.ElementTree as ET
import xmltodict, json
import numpy as np
import PIL.Image as Image
import PIL.ImageColor as ImageColor
import PIL.ImageDraw as ImageDraw
import PIL.ImageFont as ImageFont
from time import sleep
import psutil

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

Display image+bounding box

In [None]:
def _measure_text(font, text):
  """Returns (width, height) of `text` rendered with `font`.

  Pillow 10 removed `ImageFont.getsize`; on those versions the right/bottom
  extent reported by `getbbox` is used instead.
  """
  if hasattr(font, 'getsize'):
    return font.getsize(text)
  return font.getbbox(text)[2:4]


def draw_bounding_box_on_image(image, ymin, xmin, ymax, xmax,color='red',
                               thickness=4, display_str_list=(), use_normalized_coordinates=True):
  """Draws a bounding box, with optional text labels, onto an image in place.

  Bounding box coordinates can be specified in either absolute (pixel) or
  normalized coordinates by setting the use_normalized_coordinates argument.

  Each string in display_str_list is displayed on a separate line above the
  bounding box in black text on a rectangle filled with the input 'color'.
  If the labels do not fit between the top of the image and the top of the
  bounding box, they are displayed below the bounding box instead.

  Args:
    image: a PIL.Image object; it is modified in place.
    ymin: ymin of bounding box.
    xmin: xmin of bounding box.
    ymax: ymax of bounding box.
    xmax: xmax of bounding box.
    color: color to draw bounding box. Default is red.
    thickness: line thickness. Default value is 4.
    display_str_list: sequence of strings to display in box
                      (each to be shown on its own line).
    use_normalized_coordinates: If True (default), treat coordinates
      ymin, xmin, ymax, xmax as relative to the image.  Otherwise treat
      coordinates as absolute.

  Returns:
    None. The box and labels are drawn directly on `image`.
  """
  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size
  if use_normalized_coordinates:
    (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                  ymin * im_height, ymax * im_height)
  else:
    (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
  draw.line([(left, top), (left, bottom), (right, bottom),
             (right, top), (left, top)], width=thickness, fill=color)
  try:
    font = ImageFont.truetype('arial.ttf', 24)
  except IOError:
    # arial.ttf is not available on this system; use Pillow's built-in font.
    font = ImageFont.load_default()

  # Measure every label once; the sizes are needed both for the placement
  # decision and for drawing.
  display_strs = list(display_str_list)
  text_sizes = [_measure_text(font, ds) for ds in display_strs]

  # If the total height of the display strings added to the top of the bounding
  # box exceeds the top of the image, stack the strings below the bounding box
  # instead of above.
  # Each display_str has a top and bottom margin of 0.05x.
  total_display_str_height = (1 + 2 * 0.05) * sum(h for _, h in text_sizes)

  if top > total_display_str_height:
    text_bottom = top
  else:
    text_bottom = bottom + total_display_str_height

  # Reverse list and print from bottom to top.
  for display_str, (text_width, text_height) in zip(reversed(display_strs),
                                                    reversed(text_sizes)):
    margin = np.ceil(0.05 * text_height)
    draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                    (left + text_width, text_bottom)], fill=color)
    draw.text((left + margin, text_bottom - text_height - margin),
              display_str, fill='black', font=font)
    # Move up by the full height of the label just drawn (text plus both
    # margins) so the next label sits directly on top of it without overlap.
    text_bottom -= text_height + 2 * margin

List files in directory

In [None]:
# Directory holding the augmented images and their annotation files.
annotated_path = '../data/augmented'

# List all files in the directory. os.listdir returns entries in arbitrary
# order, so sort them to keep any index derived from this list reproducible
# between runs and machines.
files = sorted(os.listdir(annotated_path))
n_files = len(files)
print("Nr of files: ", n_files)

Generate annotated images

In [None]:
# 5 colours for the bounding boxes, based on the classes 1 to 5
colours = {'one': 'red', 'two': 'green', 'three': 'blue', 'four': 'yellow', 'five': 'orange'}

# storage path for final result; create it up front so saving cannot fail
# on a missing directory
save_path = '../data/annotated_combined/'
os.makedirs(save_path, exist_ok=True)

# only use .xml files
files_xml = [f for f in files if f.endswith('.xml')]

for file_id, filename in enumerate(files_xml):
    # read the annotation xml for the current file
    xml_path = os.path.join(annotated_path, filename)
    root = ET.parse(xml_path).getroot()

    # collect every bounding box as (ymin, xmin, ymax, xmax) pixel integers
    boxes = [
        tuple(int(bndbox.find(tag).text) for tag in ('ymin', 'xmin', 'ymax', 'xmax'))
        for bndbox in root.iter('bndbox')
    ]

    # files without any bounding box produce no output image
    if not boxes:
        continue

    # the image shares the annotation's base name; swap only the extension so
    # an 'xml' substring elsewhere in the name is left untouched
    image_name = os.path.splitext(filename)[0] + '.png'
    image_path = os.path.join(save_path, 'gt_' + str(file_id) + '_.png')

    # the context manager closes the underlying image file when done
    with Image.open(os.path.join(annotated_path, image_name)) as image_pil:
        for ymin, xmin, ymax, xmax in boxes:
            # every box is labelled with the same fixed class name
            class_name = 'one'

            # draw bounding box on image (absolute pixel coordinates)
            draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, colours[class_name],
                                       1, display_str_list=[class_name], use_normalized_coordinates=False)

        # save once, after all boxes have been drawn
        image_pil.save(image_path)



Load finetuned model

In [None]:

# Number of output classes for the new head: one foreground class plus
# the background class.
num_classes = 2

# Build Faster R-CNN (ResNet-50 FPN) with weights pre-trained on COCO.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# The replacement head has to accept the same number of input features as
# the classification layer of the head it replaces.
roi_heads = model.roi_heads
in_features = roi_heads.box_predictor.cls_score.in_features

# Swap the pre-trained box predictor for a fresh one sized for num_classes.
roi_heads.box_predictor = FastRCNNPredictor(in_channels=in_features,
                                            num_classes=num_classes)


Testing finetuned model

In [None]:

# NOTE(review): this builds a fresh COCO-pretrained model and rebinds `model`,
# replacing any model object already defined in the session — confirm that
# testing the stock model (rather than a customised one) is intended here.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# dataset = PennFudanDataset('PennFudanPed', get_transform(train=True))

# data_loader = torch.utils.data.DataLoader(
#  dataset, batch_size=2, shuffle=True, num_workers=4,
#  collate_fn=utils.collate_fn)

# For Training
# images,targets = next(iter(data_loader))

# images = list(image for image in images)
# targets = [{k: v for k, v in t.items()} for t in targets]

# output = model(images,targets)   # Returns losses and detections

# For inference
model.eval()

# A single random 3x300x400 tensor stands in for an input image.
x = [torch.rand(3, 300, 400)]
print(x[0].shape)

# Gradients are not needed for a forward-only check; disabling autograd
# avoids building the computation graph and keeps memory use down.
with torch.no_grad():
    predictions = model(x)       # Returns predictions
print(predictions)