## Object Detection
Fastai DL Course Lesson 8 and 9 - [video](https://www.youtube.com/watch?v=0frKXR-2PBY&feature=youtu.be), [notebook](https://github.com/fastai/fastai/blob/master/courses/dl2/pascal.ipynb)

In [0]:
# Set up fastai for this Colab runtime by running the course's setup script.
# NOTE(review): this pipes a script downloaded over the network straight into bash;
# review its contents before running.
!curl -s https://course.fast.ai/setup/colab | bash

### Download and prepare data for learning

In [0]:
# Create the directory that will hold the whole dataset
!mkdir pascal

# Download the Pascal 2007 data set containing images and annotations, extract it,
# and move the extracted VOCdevkit directory under pascal/
!wget http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar
!tar xf VOCtrainval_06-Nov-2007.tar
!mv VOCdevkit pascal

# Download the annotations in JSON format (as the above dataset has them in XML)
# The annotations contain multiple object types along with respective bounding boxes for each image
!wget https://storage.googleapis.com/coco-dataset/external/PASCAL_VOC.zip
!unzip PASCAL_VOC.zip

# Move the JSON annotation files into pascal/ and remove the extraction directory
# (rmdir only succeeds if nothing else was left inside it)
!mv PASCAL_VOC/*.json pascal
!rmdir PASCAL_VOC

In [0]:
from pathlib import Path
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json

from fastai.vision import *

#-----------------------------------------------------
# Locate the two important parts of the dataset, relative to the current directory,
# and print a peek at each:
#   - the 'pascal' directory, which holds the info JSON file
#   - the JPEGImages directory, which holds the images themselves
# Returns (data_dir, info_json, jpeg_dir) as Path objects.
#-----------------------------------------------------
def list_data_dirs ():
  data_dir = Path.cwd() / 'pascal'
  voc_dir = data_dir / 'VOCdevkit' / 'VOC2007'
  jpeg_dir = voc_dir / 'JPEGImages'
  info_json = data_dir / 'pascal_train2007.json'

  # Peek into each directory (only the first 3 image files are shown)
  for listing in (list(data_dir.iterdir()), list(voc_dir.iterdir()), list(jpeg_dir.iterdir())[:3]):
    print(listing)

  return data_dir, info_json, jpeg_dir

#-----------------------------------------------------
# Bounding boxes in the dataset are specified as [top-left-col, top-left-row, width, height]
# (ie. [x, y, width, height]). We convert them to
# [top-left-row, top-left-col, bottom-right-row, bottom-right-col] so that they are
# consistent with numpy's row-first indexing. The bottom-right corner is inclusive,
# hence the -1.
#-----------------------------------------------------
def hw_bb(bb):
  col, row, width, height = bb[0], bb[1], bb[2], bb[3]
  return np.array([row, col, row + height - 1, col + width - 1])

#-----------------------------------------------------
# Reverse of the conversion above: turn
# [top-left-row, top-left-col, bottom-right-row, bottom-right-col] (inclusive corners)
# back into [top-left-col, top-left-row, width, height]
#-----------------------------------------------------
def bb_hw(a):
  row0, col0, row1, col1 = a[0], a[1], a[2], a[3]
  return np.array([col0, row0, col1 - col0 + 1, row1 - row0 + 1])

#-----------------------------------------------------
# Parse the annotation JSON file and build three lookup dictionaries from it.
#
# info_json: Path of the JSON file (eg. pascal_train2007.json)
# Returns (image_dict, annot_dict, category_dict):
#   image_dict:    {image id: image file name}
#   annot_dict:    defaultdict {image id: [(category name, bbox), ...]} where each bbox
#                  has been converted by hw_bb() to
#                  [top-left-row, top-left-col, bottom-right-row, bottom-right-col].
#                  Annotations flagged 'ignore' are skipped; a missing 'ignore' flag
#                  counts as not ignored.
#   category_dict: {category id: category name}
# Also prints the JSON's top-level keys and the first two entries of each list, as a peek.
#-----------------------------------------------------
def process_json (info_json):
  # The JSON's top-level keys include:
  #    'images': a list of records giving each image's 'id' and 'file_name'
  #    'annotations': a list of records giving an 'image_id', a 'bbox', a 'category_id' and an 'ignore' flag
  #    'categories': a list of records mapping a category 'id' to its 'name' eg. car, horse etc.
  with info_json.open() as fp:
    info_dict = json.load(fp)
  print (info_dict.keys())

  IMAGES, ANNOTATIONS, CATEGORIES = 'images', 'annotations', 'categories'
  print (info_dict[IMAGES][:2])
  print (info_dict[ANNOTATIONS][:2])
  print (info_dict[CATEGORIES][:2])

  image_dict = {image['id']: image['file_name'] for image in info_dict[IMAGES]}
  category_dict = {category['id']: category['name'] for category in info_dict[CATEGORIES]}

  # defaultdict(list) starts an image id off with an empty list the first time it is seen
  annot_dict = collections.defaultdict(list)
  for annot in info_dict[ANNOTATIONS]:
    if annot.get('ignore', 0):
      continue
    # Tuple of (category name, [bounding box coordinates])
    new_tuple = (category_dict[annot['category_id']], hw_bb(annot['bbox']))
    annot_dict[annot['image_id']].append(new_tuple)

  return (image_dict, annot_dict, category_dict)

# Locate the dataset and build the lookup dictionaries
data_dir, info_json, jpeg_dir = list_data_dirs ()
image_dict, annot_dict, category_dict = process_json (info_json)

# Spot-check: one file name, one category name and one image's annotations
print(image_dict[17])
print (category_dict[2])
print(list(annot_dict.items())[3])
# As the last expression of the cell, the opened image is displayed as the cell's output
open_image (jpeg_dir/image_dict[17])

In [0]:
#-----------------------------------------------------
# Some Matplotlib utility functions
#-----------------------------------------------------

#-----------------------------------------------------
# Show an image on the given axes, or on a new figure of size `figsize` when no axes
# are given. Both axes are hidden. Returns the axes used.
#-----------------------------------------------------
def show_img(im, figsize=None, ax=None):
    if not ax:
        _, ax = plt.subplots(figsize=figsize)
    ax.imshow(im)
    for axis in (ax.get_xaxis(), ax.get_yaxis()):
        axis.set_visible(False)
    return ax

#-----------------------------------------------------
# Give a matplotlib artist (text or shape) a black outline of width `lw`.
# A trick to making text visible regardless of background is to use white text with black outline, or vice versa
#-----------------------------------------------------
def draw_outline(o, lw):
    black_edge = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([black_edge, patheffects.Normal()])

#-----------------------------------------------------
# Draw an unfilled white rectangle b = [x, y, width, height] on `ax`, with a black outline.
# b[:2] is the corner handed to matplotlib, and *b[-2:] splats the last two values in as the
# width and height (a shortcut compared to writing out b[-2], b[-1]).
#-----------------------------------------------------
def draw_rect(ax, b):
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor='white', lw=2)
    draw_outline(ax.add_patch(rect), 4)

#-----------------------------------------------------
# Write bold white text `txt` at position xy on `ax` (aligned by its top edge),
# outlined in black so it is readable on any background
#-----------------------------------------------------
def draw_text(ax, xy, txt, sz=14):
    style = dict(verticalalignment='top', color='white', fontsize=sz, weight='bold')
    draw_outline(ax.text(*xy, txt, **style), 1)

#-----------------------------------------------------
# Show the image data, then for every (category name, bbox) annotation draw the box
# and write the category name at the box's top-left corner
#-----------------------------------------------------
def draw_image(im_data, annot):
  ax = show_img(im_data, figsize=(12,6))
  for name, box in annot:
    # bb_hw() turns [tl_row, tl_col, br_row, br_col] back into [x, y, width, height]
    xywh = bb_hw(box)
    draw_rect(ax, xywh)
    draw_text(ax, xywh[:2], name)
  
#-----------------------------------------------------
# Given the image_id, draw the image along with the bounding boxes and category
# names of each object in the image
#
# image_dict: {image id: file name}
# annot_dict: {image id: [(category name, bbox), ...]}
# jpeg_dir:   Path of the directory holding the image files
#-----------------------------------------------------
def draw_idx(image_id, image_dict, annot_dict, jpeg_dir):
  im = open_image(jpeg_dir/image_dict[image_id])

  # Fastai open_image() returns the image data with the channel as the first dimension
  # (a torch tensor in fastai v1), while matplotlib wants the channel last. permute()
  # reorders the dimensions directly, instead of relying on np.transpose() falling back
  # to a numpy conversion when the tensor's own transpose() signature does not match.
  im_data = im.data.permute(1, 2, 0).numpy()

  # Draw the image with the annotations
  draw_image(im_data, annot_dict[image_id])
    
# Example: draw image 12 with all of its annotated boxes and category names
draw_idx(12, image_dict, annot_dict, jpeg_dir)

In [0]:
#-----------------------------------------------------
# Create a dictionary mapping each image_id to its largest bounding box
#-----------------------------------------------------


#-----------------------------------------------------
# Helper function to find the largest bounding box for an image.
#
# annot: iterable of (category name, bbox) tuples, where bbox is
#        [top-left-row, top-left-col, bottom-right-row, bottom-right-col]
# Returns the tuple whose box has the largest area; on a tie the earliest one wins.
# Raises IndexError when annot is empty.
#-----------------------------------------------------
def largest_bbox (annot):
  # Area from the top-left and bottom-right corners
  def area(item):
    box = item[1]
    return (box[2] - box[0]) * (box[3] - box[1])

  # One O(n) pass with max() is enough: there is no need to sort every box.
  # max() returns the first maximal item, the same choice as the first item of a
  # stable descending sort.
  largest = max(annot, key=area, default=None)
  if largest is None:
    raise IndexError("largest_bbox() needs at least one annotation")
  return largest

# We map each image to a list with only a single bounding box ie. {image_id: [largest_bbox(annot)]} rather
# than without a list ie. {image_id: largest_bbox(annot)} so that we can continue to use draw_idx()
# which expects a list.
largest_annot_dict = {image_id: [largest_bbox(annot)] for image_id, annot in annot_dict.items()}

# Compare all of image 23's annotations with just its largest one
draw_idx(23, image_dict, largest_annot_dict, jpeg_dir)
annot_dict[23], largest_annot_dict[23]

In [0]:
#-----------------------------------------------------
# Convert the bounding box from a numpy array to a string of coordinates
# separated by spaces, eg. [1, 2, 3, 4] -> '1.0 2.0 3.0 4.0'
#
# Supporting function for creating the CSV file
#-----------------------------------------------------
def bbox2str (bbox):
  return " ".join(str(float(coord)) for coord in bbox)

#-----------------------------------------------------
# Write the largest-object data out to a CSV file using Pandas, so that we can create a
# Fastai Dataset easily by importing from the CSV. One row is written per image.
#
# csv_file: where to write the CSV
# images:   optional {image id: file name} mapping (defaults to the global image_dict)
# annots:   optional {image id: [(category name, bbox), ...]} mapping (defaults to the
#           global largest_annot_dict); only the first tuple of each list is used
# Returns the DataFrame that was written.
#
# Several variations of columns hold the bounding box data, all converted to floats:
#   'box'        - a single string with all the coordinates separated by spaces
#   'tl' / 'br'  - the top-left / bottom-right point, each as a list [row, col]
#   'tlr', 'tlc' - the top-left point as separate row and column values; similarly
#                  'brr', 'brc' for the bottom-right point
# Only 'tl' and 'br' are actually used later.
#-----------------------------------------------------
def save_csv (csv_file, images=None, annots=None):
  if images is None:
    images = image_dict
  if annots is None:
    annots = largest_annot_dict

  columns = ['file', 'category', 'box', 'tl', 'br', 'tlr', 'tlc', 'brr', 'brc']
  data = {col: [] for col in columns}
  # Build every column in a single pass over the annotations
  for image_id, annot in annots.items():
    category_name, bbox = annot[0]
    tlr, tlc = float(bbox[0]), float(bbox[1])
    brr, brc = float(bbox[2]), float(bbox[3])
    data['file'].append(images[image_id])
    data['category'].append(category_name)
    data['box'].append(bbox2str(bbox))
    data['tl'].append([tlr, tlc])
    data['br'].append([brr, brc])
    data['tlr'].append(tlr)
    data['tlc'].append(tlc)
    data['brr'].append(brr)
    data['brc'].append(brc)

  df = pd.DataFrame(data, columns=columns)
  df.to_csv(csv_file, index=False)
  return (df)
  
# Write the largest-object CSV next to the dataset and preview its first rows
csv_file = data_dir/'data.csv'
df = save_csv (csv_file)
df.head()

### Single Object Classification of Largest Object (without bounding box)

In [0]:
#-----------------------------------------------------
# The first step is to do single object classification of the largest object for each
# image, without considering the bounding boxes
#
# Build the data bunch as usual from the CSV file. Since it resizes all images to size 224x224,
# we use the SQUISH method (shrink both sides without preserving aspect ratio) rather than a
# CROP method (preserve aspect ratio, make the smaller side=224 and then crop from centre)
#-----------------------------------------------------
tfms = get_transforms()
data_oc = ImageDataBunch.from_csv(data_dir, 'VOCdevkit/VOC2007/JPEGImages', csv_labels=csv_file, ds_tfms=tfms, 
                               size=224, resize_method=ResizeMethod.SQUISH, bs=64).normalize(imagenet_stats)
# Preview a 3-row grid of samples from the data bunch
data_oc.show_batch(rows=3, figsize=(7,6))

In [0]:
#-----------------------------------------------------
# Create a CNN learner object based on resnet34 transfer learning, tracking accuracy,
# and use the Learning Rate Finder to plot loss against learning rate
#-----------------------------------------------------
learn_oc = cnn_learner(data_oc, models.resnet34, metrics=[accuracy])
learn_oc.lr_find()
learn_oc.recorder.plot()

In [0]:
#-----------------------------------------------------
# Train using the Fit-One-Cycle policy (cyc_len=2), with the maximum learning rate
# picked from the LR Finder plot above
#-----------------------------------------------------
lr = 2e-2
learn_oc.fit_one_cycle(cyc_len=2, max_lr=slice(lr))

In [0]:
#-----------------------------------------------------
# Unfreeze the earlier (pretrained) layers so that they are trained too, and use the
# LR Finder again to choose learning rates for fine-tuning
#-----------------------------------------------------
learn_oc.unfreeze()
learn_oc.lr_find()
learn_oc.recorder.plot()

In [0]:
#-----------------------------------------------------
# Train again using the Fit-One-Cycle, with learning rates ranging from 1e-6 to 1e-4
# (fastai spreads a slice across the layer groups), then freeze the earlier layers again
#-----------------------------------------------------
learn_oc.fit_one_cycle(cyc_len=2, max_lr=slice(1e-6,1e-4))
learn_oc.freeze()

In [0]:
#-----------------------------------------------------
# Look at the results ie. incorrect predictions, using top losses
# (the 9 images with the highest loss, without the heatmap overlay)
#-----------------------------------------------------
interp_oc = ClassificationInterpretation.from_learner(learn_oc)
interp_oc.plot_top_losses(9, figsize=(15,11), heatmap=False)

In [0]:
#-----------------------------------------------------
# Look at the results ie. confusion matrix of actual vs predicted categories
#-----------------------------------------------------
interp_oc.plot_confusion_matrix(figsize=(8,6))

In [0]:
#-----------------------------------------------------
# Look at the results ie. classes incorrectly predicted, listing only the class pairs
# confused at least twice (min_val=2)
#-----------------------------------------------------
interp_oc.most_confused(min_val=2)

### Regression of the Largest Bounding Box (without classification)

In [0]:
#-----------------------------------------------------
# We want to learn how to predict just the bounding box. The bounding box
# consists of 4 floating point numbers corresponding to the coordinates of the top-left
# and bottom-right corners. So we can model this as a multi-value regression problem.
#
# We create the data bunch such that:
#  X: images
#  y: two image points for the two bounding box corners, using the PointsItemList
#-----------------------------------------------------

#-----------------------------------------------------
# Label function: given an image file path, find that image's row in the CSV
# dataframe `df` by file name and return its two corner points as a tensor
# [[tlr, tlc], [brr, brc]] (row first), as required by PointsItemList.
# If several rows share the file name, the first one is used.
#-----------------------------------------------------
def bbox_label_func(filepath):
    file_name = Path(filepath).name
    first_match = df.loc[df['file'] == file_name].iloc[0]
    corners = [first_match['tl'], first_match['br']]
    return tensor(corners)

#-----------------------------------------------------
# The PointsItemList is suitable since our 'y' labels are two bounding box points on an
# image. Those labels are computed using a function, by looking up the bounding-box
# coordinates from our CSV dataframe
#
# tfm_y=True is required so that the transforms (the data augmentation from
# get_transforms() and the resize) are applied to the target points as well as to the
# image. Without it the points keep their un-augmented positions and no longer line up
# with a flipped/rotated/zoomed image. The object-detection databunch further down
# passes tfm_y=True for the same reason.
#-----------------------------------------------------
data_bb = (PointsItemList.from_df(df, path=data_dir, folder='VOCdevkit/VOC2007/JPEGImages')
      .split_by_rand_pct()
      .label_from_func(bbox_label_func)
      .transform(get_transforms(), resize_method=ResizeMethod.SQUISH, size=224, tfm_y=True)
      .databunch(bs=64))
data_bb.normalize(imagenet_stats)

# Since this is a PointsItemList, it shows the images along with small red dots for the
# two bounding box points
data_bb.show_batch(rows=3, figsize=(7,6))

In [0]:
# Show the first 8 training images together with their bounding-box points
data_bb.train_ds.x.show_xys(data_bb.train_ds.x[0:8], data_bb.train_ds.y[0:8])

In [0]:
#-----------------------------------------------------
# Create a resnet34-based CNN learner for the bounding-box points (no metrics are
# passed) and run the LR Finder.
#
# (I think) The PointsItemList has built-in intelligence to create a learner with the
# appropriate final layers suitable for multi-value regression, and the appropriate
# loss function.
#
# There is no need to use custom head to add the custom layers and to manually
# add a custom loss function
# NOTE(review): confirm by inspecting learn_bb.loss_func and the last layers of learn_bb.model
#-----------------------------------------------------
learn_bb = cnn_learner(data_bb, models.resnet34)
learn_bb.lr_find()
learn_bb.recorder.plot()

In [0]:
# Train the bounding-box regressor with Fit-One-Cycle (cyc_len=2)
lr = 2e-2
learn_bb.fit_one_cycle(cyc_len=2, max_lr=slice(lr))

In [0]:
# Ways to inspect the trained model's results
# https://docs.fast.ai/basic_train.html#See-results
# A dataset item is unpacked as an (input, target) pair (assumed, as in fastai v1's LabelList);
# predict() and one_item() both need the input item, so pass `x` rather than calling them
# with no argument or with the undefined names `learn` / `item`.
x, y = learn_bb.data.train_ds[0]
learn_bb.predict(x)
learn_bb.get_preds()
learn_bb.validate()
learn_bb.pred_batch()
learn_bb.data.one_item(x)

In [0]:
#-----------------------------------------------------
# Show the learner's results on a sample of the data.
# NB: Fastai doesn't yet have a class for Regression Interpretation similar to
# ClassificationInterpretation
#-----------------------------------------------------
learn_bb.show_results()

In [0]:
# Figure out how to look at model dataloader input data
# Figure out how to look at model result outputs
# Figure out whether transforms are doing the bounding box as well or not

### Object Detection with RetinaNet from scratch, based on [this](https://github.com/fastai/course-v3/blob/master/nbs/dl2/pascal.ipynb) notebook

**================= Prepare Data ===============**

In [0]:
# Ask IPython to display the value of every top-level expression in a cell, not only the
# last one (by default, only the last line of a cell is displayed as the cell's output).
# NOTE(review): the original author reports that these two lines did not take effect in
# their environment.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

from fastai.vision import *
import json

# If you ever get the bizarre "CUDA device-side assert" error, then uncomment these lines to get a meaningful stack trace so you can debug
#import os
#os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
#print("CUDA ", os.environ['CUDA_LAUNCH_BLOCKING']) 

# Load the Pascal VOC 2007 dataset (which is smaller than the 2012 dataset)
path = untar_data(URLs.PASCAL_2007)
# Read the JSON inside a context manager so the file handle is closed deterministically
# (json.load(open(...)) leaves closing it to the garbage collector)
with open(path/'train.json') as train_json_file:
    annots = json.load(train_json_file)

# We ignore the parsed Json. Instead we use fastai built-in get_annotations to prepare the train
# and validation data. For each image file, there are one or more bounding boxes and class labels
train_images, train_lbl_bbox = get_annotations(path/'train.json')
val_images, val_lbl_bbox = get_annotations(path/'valid.json')
#tst_images, tst_lbl_bbox = get_annotations(path/'test.json')
train_images[1], train_lbl_bbox[1]

In [0]:
# Open a sample image
img = open_image(path/'train'/train_images[1])
# ImageBBox is Fastai's class to represent bounding boxes on an image. We create an ImageBBox object from the list of bounding boxes. This will allow us
# to apply data augmentation to our bounding box. We need to give it the height and the width of the
# original picture, the list of bounding boxes, the list of category ids and the classes list (to map an id to a class).
# NOTE(review): the ids [0, 1] and classes ['person', 'horse'] appear to be hard-coded for this particular sample image.
bbox = ImageBBox.create(*img.size, train_lbl_bbox[1][0], [0, 1], classes=['person', 'horse'])
img.show(figsize=(6,4), y=bbox)

In [0]:
# If we apply the same transform (here a rotation by -10) to our image and the ImageBBox object, they stay aligned
img = img.rotate(-10)
bbox = bbox.rotate(-10)
img.show(figsize=(6,4), y=bbox)

In [0]:
#------------------------------------------------------
# Use the Datablock API to load a Databunch
#------------------------------------------------------
# Pool the train and validation annotations into one lookup table keyed by image file name
images, lbl_bbox = train_images + val_images, train_lbl_bbox + val_lbl_bbox
img2bbox = {img_name: img_annots for img_name, img_annots in zip(images, lbl_bbox)}

def get_y_func(o):
    "Return the annotations for image file `o` (anything with a `.name`), looked up by file name."
    return img2bbox[o.name]

def get_data(bs, size):
    """Build the object-detection DataBunch from the images in path/'train'.

    bs:   batch size
    size: target size for the transforms, which (tfm_y=True) also move the boxes
    The files listed in val_images form the validation set; labels come from get_y_func.
    """
    item_list = ObjectItemList.from_folder(path/'train')
    labelled = (item_list.split_by_files(val_images)
                         .label_from_func(get_y_func)
                         .transform(get_transforms(), size=size, tfm_y=True))
    # Images may have different numbers of bounding boxes, so the bb_pad_collate collate
    # function pads each batch's boxes up to the largest count in that batch
    return labelled.databunch(path=path, bs=bs, collate_fn=bb_pad_collate)
  
# Batch size 64 and image size 128, then preview a 3-row grid of samples with their boxes
data = get_data(64,128)
data.show_batch(rows=3)

**========================= Plotting Functions for Visualisation only ======================**

In [0]:
#-----------------------------------------------------
# Utility functions to visualise images, with grids and bounding boxes
#-----------------------------------------------------

import pdb
import IPython.core.debugger as db

import matplotlib.cm as cmx
import matplotlib.colors as mcolors
from cycler import cycler

def get_cmap(N):
    "Return a function mapping a value in [0, N-1] to an RGBA colour from matplotlib's 'Set3' colormap."
    index_norm = mcolors.Normalize(vmin=0, vmax=N-1)
    scalar_map = cmx.ScalarMappable(norm=index_norm, cmap='Set3')
    return scalar_map.to_rgba

# Precompute 12 distinct colours; show_boxes() draws box i in color_list[i % num_color]
num_color = 12
cmap = get_cmap(num_color)
color_list = [cmap(float(x)) for x in range(num_color)]

def draw_outline(o, lw):
    "Outline matplotlib artist `o` with a black stroke of width `lw` so it stays visible on any background."
    black_stroke = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([black_stroke, patheffects.Normal()])

def draw_rect(ax, b, color='white'):
    "Draw an unfilled rectangle b = [x, y, width, height] on `ax` in `color`, with a black outline."
    rect = patches.Rectangle(b[:2], *b[-2:], fill=False, edgecolor=color, lw=2)
    draw_outline(ax.add_patch(rect), 4)

def draw_text(ax, xy, txt, sz=14, color='white'):
    "Write bold `txt` at `xy` on `ax` (top-aligned) in `color`, with a thin black outline."
    text_style = {'verticalalignment': 'top', 'color': color, 'fontsize': sz, 'weight': 'bold'}
    draw_outline(ax.text(*xy, txt, **text_style), 1)

def show_boxes(boxes):
    """Plot `boxes` (N rows of [centre_r, centre_c, height, width], each with a .numpy() method)
    in the -1..1 coordinate system, outlining and numbering each box in its own colour."""
    _, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.set_xlim(-1, 1)
    ax.set_ylim(1, -1)
    for idx, box in enumerate(boxes):
        bb = box.numpy()
        # Top-left corner in (x, y) = (column, row) order, as matplotlib expects
        left, top = bb[1] - bb[3]/2, bb[0] - bb[2]/2
        colour = color_list[idx % num_color]
        draw_rect(ax, [left, top, bb[3], bb[2]], color=colour)
        draw_text(ax, [left, top], str(idx), color=colour)

def show_anchors(ancs, size):
    """Scatter-plot anchor centres `ancs` (rows of [r, c], row first) over a grid of
    `size` = (rows, cols) cells spanning -1..1, labelling each centre with its index."""
    _, ax = plt.subplots(1, 1, figsize=(5, 5))
    n_rows, n_cols = size[0], size[1]
    ax.set_xticks(np.linspace(-1, 1, n_cols + 1))
    ax.set_yticks(np.linspace(-1, 1, n_rows + 1))
    ax.grid()
    rows, cols = ancs[:, 0], ancs[:, 1]
    ax.scatter(cols, rows)  # x is the column, y is the row
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.set_xlim(-1, 1)
    ax.set_ylim(1, -1)  # flip the y axis so -1 is the top and 1 the bottom
    for idx, (x, y) in enumerate(zip(cols, rows)):
        ax.annotate(idx, xy=(x, y))

**===================== Supporting functions needed for Loss Function =====================**

In [0]:
#-----------------------------------------------------
# Define a grid of cells on the image, given size as a tuple (number of rows, number of columns)
# Compute the parameters of the grid eg. the coordinates of its edges, and height and width of each cell.
#
# Our convention is that y is first (like in numpy or PyTorch), and that all coordinates are
# scaled from -1 to 1: -1 is the top/left edge and 1 is the bottom/right edge. This matches
# show_anchors() and show_boxes(), which use xlim(-1, 1) and ylim(1, -1).
# So, to avoid confusion, for our variable names which represent coordinates, rather than use 'x' and 'y', we use 'r' and 'c' which
# stand for 'row' (ie. vertical coordinates) and 'column' (ie. horizontal coordinates)
#
# Returns (size_r, size_c, top_r, bottom_r, left_c, right_c, cell_h, cell_w)
#-----------------------------------------------------
def grid_info(size):
  # Number of rows, number of columns in the grid
  size_r, size_c = size
  # Coordinates of Top Row, Bottom Row, Left Column and Right Column
  top_r, bottom_r, left_c, right_c = -1, 1, -1, 1
  # Height and Width of the Grid
  grid_h, grid_w = bottom_r - top_r, right_c - left_c
  # Height and Width of Each cell
  cell_h, cell_w = grid_h/size_r, grid_w/size_c

  return (size_r, size_c, top_r, bottom_r, left_c, right_c, cell_h, cell_w)

#-----------------------------------------------------
# Return an array with the coordinates of all the grid cell centres. It has shape (total number of centres, 2)
# where the total number of centres = number of grid rows * number of grid columns.
#
# Coordinates use the -1 to +1 system above, row first: [[ctr1_r, ctr1_c], [ctr2_r, ctr2_c]....]
# Centres are ordered row by row, top to bottom, and left to right within a row.
#-----------------------------------------------------
def create_grid_centres (size):
  size_r, size_c, top_r, bottom_r, left_c, right_c, cell_h, cell_w = grid_info(size)

  # Centre values in the vertical direction, from top to bottom
  val_r = np.linspace (top_r + cell_h/2, bottom_r - cell_h/2, size_r)
  # Centre values in the horizontal direction, from left to right
  val_c = np.linspace (left_c + cell_w/2, right_c - cell_w/2, size_c)

  # Tile the horizontal values as [a b c] -> [a b c a b c ...]
  pts_c = np.tile(val_c, size_r)
  # Repeat the vertical values as [u v w] -> [u u u v v v w w w]
  pts_r = np.repeat(val_r, size_c)

  # Pair the two halves column-wise to give [[u, a], [u, b], [u, c], [v, a], ....]
  ctrs = np.stack((pts_r, pts_c), axis=1)
  return (ctrs)

# Quick visual check: the centres of a 2 x 5 grid, numbered in the order they are generated
gsize=(2, 5)
bar = create_grid_centres (gsize)
print (bar)
show_anchors (bar, gsize)

In [0]:
#-----------------------------------------------------
# We use different grid sizes (ie. gsizes). For each grid size, we define several anchor boxes
# per grid cell. The anchor boxes are defined only by their height and width, and are assumed to
# be centred on the centre of the grid cell. The height and width of the anchor boxes is calculated
# using a formula based on a list of 'ratios' and 'scales'
#
# Here we use 5 grid sizes: (1,1), (2,2), (4,4), (8,8) and (16,16). For each grid size, for each
# grid cell, we have 9 anchors (3 ratios x 3 scales). eg. a (2, 2) grid has 4 (ie. 2 * 2) grid cells.
# That means a total of 36 (ie. 4 * 9) anchors for this grid.
# Similarly we have anchors for all the other grid sizes.
#-----------------------------------------------------
ratios = [1/2,1,2]
scales = [1,2**(-1/3), 2**(-2/3)] 
#Paper used [1,2**(1/3), 2**(2/3)] but a bigger size (600) too, so the largest feature map gave anchors that cover less of the image.
gsizes = [(2**i,2**i) for i in range(5)]
gsizes.reverse() #Predictions come in the order of the smallest feature map to the biggest. NOTE(review): after reverse() gsizes runs from (16,16) down to (1,1)
print(ratios, scales, gsizes)

#-----------------------------------------------------
# Build the anchors for every grid size, and every grid cell of each grid size.
# Returns a float32 PyTorch tensor of shape (total number of anchors, 4), one row
# [centre_r, centre_c, height, width] per anchor, ordered by grid size (in the order
# given), then by grid cell, then by (ratio, scale) combination.
#
# sizes:          list of grid sizes (number of rows, number of columns)
# ratios, scales: every ratio is combined with every scale, so 3 ratios and 3 scales
#                 give 9 anchors per grid cell
#-----------------------------------------------------
def create_anchor_boxes (sizes, ratios, scales):
  # Every (ratio, scale) pair, ratio-major: [[ratio_1, scale_1], [ratio_1, scale_2], ....[ratio_m, scale_n]]
  pairs = np.stack(np.meshgrid(ratios, scales, indexing='ij'), axis=-1).reshape(-1, 2)
  ratio_v, scale_v = pairs[:, 0], pairs[:, 1]

  # Height/width proportions of each anchor relative to a grid cell. They are the same
  # for every grid size; only the cell size they are multiplied by changes.
  aspect_hw = np.stack([scale_v * np.sqrt(ratio_v), scale_v * np.sqrt(1/ratio_v)], axis=1)

  # Start from an empty (0, 4) array so that an empty `sizes` still gives a (0, 4) result
  per_grid = [np.empty(shape=(0, 4))]
  for h, w in sizes:
    cell_h, cell_w = grid_info((h, w))[-2:]
    # Actual anchor height and width for this grid: 4 * proportion * cell size
    anchor_hw = 4 * aspect_hw * np.array([cell_h, cell_w])
    # Pair every cell centre with every (height, width), giving rows of
    # (centre_r, centre_c, height, width)
    per_grid.append(cartesian(create_grid_centres((h, w)), anchor_hw))

  all_anchors = np.concatenate(per_grid, axis=0)
  return torch.from_numpy(all_anchors.astype('float32'))

#-----------------------------------------------------
# Cartesian product of the rows of two 2D arrays.
# Given the first array of shape (m, 2) and the second array of shape (n, 2), it
# returns an array of shape (m * n, 4) where every row of the first array is joined
# with every row of the second, first-array-major:
#   [[a1, a2, u1, u2], [a1, a2, v1, v2], [b1, b2, u1, u2], ...]
#-----------------------------------------------------
def cartesian (f2d, s2d):
  n_first, n_second = f2d.shape[0], s2d.shape[0]
  # Each row of the first array repeated n_second times in a row: a, a, b, b, c, c
  left = np.repeat(f2d, n_second, axis=0)
  # The whole second array stacked n_first times: u, v, u, v, u, v
  right = np.tile(s2d, (n_first, 1))
  return np.concatenate((left, right), axis=1)

# Leftover check against another anchor implementation (kept for reference)
#foo = create_anchors(gsizes, ratios, scales)
#anchors[70:80], foo[70:80], anchors.shape, foo.shape

# Build all anchors for every grid size, then draw anchors 900-908: the 9 anchors
# (3 ratios x 3 scales) of one grid cell of the first (16 x 16) grid
anchors = create_anchor_boxes (gsizes, ratios, scales)
show_boxes(anchors[900:909])

In [0]:
#-----------------------------------------------------
# Convert the output activations predicted by the model's regressor to bounding box coordinates
#
# activations: tensor (number of anchors, 4) of raw predictions p_y, p_x, p_h, p_w
# anchors:     tensor (number of anchors, 4) of (anc_y, anc_x, anc_h, anc_w)
# Returns a float32 tensor (number of anchors, 4) of (ctr_y, ctr_x, height, width) on the
# same device as `activations`.
#
# The activations are first dampened by [0.1, 0.1, 0.2, 0.2] (this helps regularization), then:
#     center = [p_y * anc_h + anc_y, p_x * anc_w + anc_x]
#     height = anc_h * exp(p_h)
#     width  = anc_w * exp(p_w)
#
# Everything stays in PyTorch so that this works on CUDA tensors and on tensors that require
# grad, and gradients flow back to the model (a round trip through numpy breaks all three).
#-----------------------------------------------------
def prediction_to_bbox (activations, anchors):
  activations = activations.float()
  anchors = anchors.to(activations)

  scale_val = activations.new_tensor([0.1, 0.1, 0.2, 0.2])
  activations_mul = activations * scale_val

  # First two columns are (p_y, p_x); last two columns are (p_h, p_w)
  activations_xy, activations_hw = activations_mul[:, :2], activations_mul[:, 2:]
  # First two columns are (anc_y, anc_x); last two columns are (anc_h, anc_w)
  anchors_xy, anchors_hw = anchors[:, :2], anchors[:, 2:]

  # Use the above formula to compute the bounding box centre and height and width
  bboxes_ctrs = anchors_hw * activations_xy + anchors_xy
  bboxes_hw = anchors_hw * torch.exp(activations_hw)

  # Concatenate the two centre columns and the two height/width columns
  return torch.cat([bboxes_ctrs, bboxes_hw], dim=1)

#-----------------------------------------------------
# Inverse of the conversion above: convert bounding boxes (ctr_y, ctr_x, height, width) into
# the activations the regressor would have to output for the given anchors
# (anc_y, anc_x, anc_h, anc_w):
#     p_y, p_x = (ctr - anc_ctr) / anc_hw
#     p_h, p_w = log(hw / anc_hw)
# then undo the [0.1, 0.1, 0.2, 0.2] dampening. Returns a float32 tensor on the same device
# as `bboxes`. Pure PyTorch, so CUDA tensors and tensors that require grad are supported.
#-----------------------------------------------------
def bbox_to_prediction (bboxes, anchors):
  bboxes = bboxes.float()
  anchors = anchors.to(bboxes)

  # First two columns are centres, last two columns are height and width
  bboxes_ctrs, bboxes_hw = bboxes[:, :2], bboxes[:, 2:]
  anchors_xy, anchors_hw = anchors[:, :2], anchors[:, 2:]

  # Use the above formula to compute the activations
  activations_xy = (bboxes_ctrs - anchors_xy) / anchors_hw
  activations_hw = torch.log(bboxes_hw / anchors_hw)

  # Concatenate column-wise, then reverse the scaling applied when converting predictions
  scale_val = bboxes.new_tensor([0.1, 0.1, 0.2, 0.2])
  return torch.cat([activations_xy, activations_hw], dim=1) / scale_val

# Round-trip check on a 3 x 4 grid: one anchor per cell, centred on the cell and as big as
# the cell (2/size in each direction). Small random activations are converted to boxes and
# back again; the two printed tensors should match.
# NOTE(review): this reassigns the global `anchors` built by create_anchor_boxes() earlier.
size=(3,4)
anchors_np = create_grid_centres(size)
anchors = torch.from_numpy(anchors_np).float()
anchors = torch.cat([anchors, torch.tensor([2/size[0],2/size[1]]).expand_as(anchors)], 1)
activations = torch.randn(size[0]*size[1], 4) * 0.1

bboxes = prediction_to_bbox (activations, anchors)

activations_reverse = bbox_to_prediction (bboxes, anchors)
print(activations, activations_reverse)

show_boxes(bboxes)

In [0]:
#-----------------------------------------------------
# Convert coordinates for an array of boxes from (ctr, height, width) to 
# (top left, bottom right)
#-----------------------------------------------------
def ctr_hw_to_tl_br(boxes):
  """Convert boxes from (ctr_y, ctr_x, height, width) to (top, left, bottom, right).

  Args:
    boxes: tensor of shape (n, 4).

  Returns:
    float32 tensor of shape (n, 4) on the same device as `boxes`. Pure torch ops
    are used, so this works for CUDA tensors and keeps the autograd graph.
  """
  # First two columns are centre coordinates, last two are height and width
  centre, hw = boxes[:, :2], boxes[:, 2:]

  # Top Left is (ctr_y - height / 2, ctr_x - width / 2)
  top_left = centre - hw/2
  # Bottom right is (ctr_y + height / 2, ctr_x + width / 2)
  bottom_right = centre + hw/2

  return torch.cat([top_left, bottom_right], dim=1).float()

#-----------------------------------------------------
# Inverse of the above function. Converts coordinates for an array of boxes 
# from (top left, bottom right) to (ctr, height, width)
#-----------------------------------------------------
def tl_br_to_ctr_hw(boxes):
  """Convert boxes from (top, left, bottom, right) to (ctr_y, ctr_x, height, width).

  Args:
    boxes: tensor of shape (n, 4).

  Returns:
    float32 tensor of shape (n, 4) on the same device as `boxes`. Pure torch ops
    are used, so this works for CUDA tensors and keeps the autograd graph.
  """
  # First two columns are top-left, last two columns are bottom-right
  tl, br = boxes[:, :2], boxes[:, 2:]

  # Height and width are (bottom-right-y - top-left-y, bottom-right-x - top-left-x)
  hw = br - tl
  # Centre is (top-left-y + height / 2, top-left-x + width / 2)
  centre = tl + hw/2

  return torch.cat([centre, hw], dim=1).float()

#-----------------------------------------------------
# Calculate the area of intersection between an array of anchor boxes
# and an array of ground truth target boxes. Compute intersection
# for every combination of anchor and target
#-----------------------------------------------------
def intersection(anchors_ctrhw, targets_ctrhw):
  """Area of overlap between every anchor and every target.

  Args:
    anchors_ctrhw: tensor (m, 4), rows are (ctr_y, ctr_x, height, width).
    targets_ctrhw: tensor (n, 4), same layout.

  Returns:
    float32 tensor (m, n) where [i, j] is the overlap area of anchor i and target j,
    on the device of `targets_ctrhw`. Computed entirely with torch ops, so it works
    on GPU tensors.
  """
  def _corners(boxes):
    # (ctr, hw) -> (top-left, bottom-right), each of shape (k, 2)
    boxes = boxes.float()
    ctr, hw = boxes[:, :2], boxes[:, 2:]
    return ctr - hw / 2, ctr + hw / 2

  anchors_tl, anchors_br = _corners(anchors_ctrhw)
  targets_tl, targets_br = _corners(targets_ctrhw)

  # Broadcast (m, 1, 2) against (1, n, 2) -> (m, n, 2): one pair per (anchor, target)
  # combination, similar to a cartesian product. The intersection box starts at the
  # larger top-left and ends at the smaller bottom-right.
  intersection_tl = torch.max(anchors_tl[:, None, :], targets_tl[None, :, :])
  intersection_br = torch.min(anchors_br[:, None, :], targets_br[None, :, :])

  # Negative height/width means the boxes do not overlap, so clamp to 0
  intersection_hw = (intersection_br - intersection_tl).clamp(min=0)

  # Height * width -> (m, n) areas
  intersection_area = intersection_hw[..., 0] * intersection_hw[..., 1]
  return intersection_area.to(device=targets_ctrhw.device)

#-----------------------------------------------------
# Calculate the IoU between an array of anchor boxes and an array of 
# ground truth target boxes using 
# IoU = area of intersection / area of union
#-----------------------------------------------------
def iou(anchors_ctrhw, targets_ctrhw):
  """Intersection-over-union of every anchor with every target.

  Both arguments hold one (ctr_y, ctr_x, height, width) box per row. The result is
  an (m, n) matrix whose entry [i, j] is IoU(anchor i, target j).
  """
  overlap = intersection(anchors_ctrhw, targets_ctrhw)

  # Box area = height * width (columns 2 and 3)
  area_a = anchors_ctrhw[:, 2] * anchors_ctrhw[:, 3]
  area_t = targets_ctrhw[:, 2] * targets_ctrhw[:, 3]

  # (m, 1) + (1, n) broadcasts to (m, n): the summed area of each anchor/target pair.
  # Subtracting the overlap avoids counting the shared region twice.
  union = area_a[:, None] + area_t[None, :] - overlap
  return overlap / union

#-----------------------------------------------------
# Match each anchor to the targets with the following rules:
#   for each anchor we take the maximum IoU with any of the targets.
#   if that maximum is less than 'bkg_thr', the anchor is matched to background
#   if that maximum is greater than 'match_thr', the anchor is matched to that target, and
#     the classifier's target becomes the category of that target
#   otherwise (between the two thresholds, 0.4 and 0.5 by default) the anchor is ignored
#     in the loss computation
#
# We return the matches, with one value per anchor which is the index of the matching
# ground truth target, -1 if it matches 'background', and -2 if it is 'ignore'
#-----------------------------------------------------
def match_anchors(anchors_ctrhw, targets_ctrhw, match_thr=0.5, bkg_thr=0.4):
  """Assign each anchor to a target index, background (-1) or ignore (-2).

  Args:
    anchors_ctrhw: tensor (m, 4) of anchors in (ctr, hw) format.
    targets_ctrhw: tensor (n, 4) of ground-truth boxes in (ctr, hw) format.
    match_thr: an IoU strictly above this matches the anchor to its best target.
    bkg_thr: an IoU strictly below this makes the anchor background.

  Returns:
    int64 tensor (m,) on the anchors' device. All anchors are -2 (ignore) when
    there are no targets.
  """
  # Start every anchor as 'ignore'. The tensor lives on the anchors' device so that
  # the boolean masks computed below (on the same device) can index it on GPU.
  matches = torch.full((anchors_ctrhw.shape[0],), -2, dtype=torch.long,
                       device=anchors_ctrhw.device)
  # If target is empty, return
  if targets_ctrhw.numel() == 0: return matches

  # (m, n) IoU matrix, reduced to the best target per anchor
  iou_val = iou(anchors_ctrhw, targets_ctrhw)
  max_iou_val, max_iou_idx = iou_val.max(dim=1)

  # Match is 'background' if the max overlap is less than the background threshold
  matches[max_iou_val < bkg_thr] = -1

  # Match is the index of the best target if the max overlap exceeds the match threshold
  matched = max_iou_val > match_thr
  matches[matched] = max_iou_idx[matched]

  return matches

# Round trip (ctr, hw) -> (tl, br) -> (ctr, hw): boxes_reverse should equal boxes_ctrhw
boxes_ctrhw = tensor([[5., 5., 2., 2.], [12., 12., 3., 3.]])
boxes_tlbr = ctr_hw_to_tl_br(boxes_ctrhw)
boxes_reverse = tl_br_to_ctr_hw(boxes_tlbr)
print(boxes_tlbr, boxes_reverse)

# Match the 3 x 4 grid of anchors built in the previous cell against three hand-made
# targets in (ctr, hw) format.
show_boxes(anchors)
targets = torch.tensor([[0.,0.,2.,2.], [-0.5,-0.5,1.,1.], [1/3,0.5,0.5,0.5]])
show_boxes(targets)
# NOTE(review): in a notebook only the last expression of a cell is displayed, so the
# IoU matrix below is computed but not shown; wrap it in print() to inspect it.
iou(anchors, targets)
match_anchors(anchors, targets)

In [0]:
#-----------------------------------------------------
# One-hot encode the targets with the convention that the class of index 0 is 
# the background, which is the absence of any other classes. That is coded by a row of zeros
#-----------------------------------------------------
def onehot_encode(values, n_classes):
  """One-hot encode 1-based class values; 0 (background) becomes an all-zero row.

  Args:
    values: 1-d integer tensor of class values in 0..n_classes.
    n_classes: number of real (non-background) classes, i.e. output columns.

  Returns:
    float32 tensor (len(values), n_classes) on the same device as `values`.
  """
  n_rows = values.shape[0]
  encoded = torch.zeros(n_rows, n_classes, dtype=torch.float32, device=values.device)

  # Only non-background rows get a 1; value v lands in column v - 1
  # (e.g. 3 -> [0 0 1 ...]) because column 0 is the first real class.
  is_object = values != 0
  rows = torch.arange(n_rows, device=values.device)[is_object]
  cols = (values - 1)[is_object]

  encoded[rows, cols] = 1
  return encoded

# Quick checks of onehot_encode. NOTE(review): in a notebook only the last expression
# is displayed, so the first result is computed but not shown.
onehot_encode(LongTensor([1,2,0,1,3]),3)
onehot_encode(tensor([3, 4, 0]), 5)

###  ===================== Retinanet model =====================
![alt text](https://github.com/fastai/course-v3/raw/e08c4b712f459f77247df2f6fb04465579cd6653/nbs/dl2/images/retinanet.png)

In [0]:
from fastai.vision import *
from fastai.vision.models.unet import _get_sfs_idxs, model_sizes, hook_outputs
import pdb
import IPython.core.debugger as db

#-----------------------------------------------------
# Subclass to handle the Upsample, the Lateral Connection and the Merging of the two
#-----------------------------------------------------
class LateralUpsampleMerge(nn.Module):
    """Top-down FPN step: upsample `x` and add a 1x1-projected encoder feature map.

    Args:
        enc_lateral_channels: depth of the encoder layer the hook is attached to.
        n_channels: depth of the top-down pathway (and of this module's output).
        hook: fastai hook whose `.stored` holds that encoder layer's latest output.
    """
    def __init__(self, enc_lateral_channels, n_channels, hook):
        super().__init__()
        self.hook = hook
        # 1x1, stride 1 projection of the encoder features to n_channels
        self.lateral = conv2d(enc_lateral_channels, n_channels, ks=1, bias=True)

    def forward(self, x):
        skip = self.hook.stored
        projected = self.lateral(skip)
        # Nearest-neighbour upsample to the encoder feature map's spatial size, then merge by addition
        enlarged = F.interpolate(x, skip.shape[-2:], mode='nearest')
        return projected + enlarged

#-----------------------------------------------------
# Create the RetinaNet model
#
# It is based on a pre-trained 'encoder' CNN which acts as the backbone. It adds
# several Conv layers on top of the existing Conv layers of the backbone.
# 
# The backbone's built-in layers C1 through C5 is the 'bottom-up pathway'.
# We create layers P3 through P7 as the 'top-down pathway'
#
# There is a lateral connection from some Cx layers to the corresponding Px layers. This lateral
# connection is 'merged' with the top-down connections.
# 
# Some Px layers then have a 'smoothing' Conv layer. Finally, after that, there is a 'detector head'
# which consists of two similar sub-networks, one for classifier and one for the bounding box regressor
#
# An nn.Module has two required methods - init() and forward()
# In the init() method, we only define all the layers. The layers stay standalone and are not actually 'wired' to
# one another here. That happens in the forward() method, when we actually pass in input data and do the 
# computation as data flows from one layer to another, which implicitly wires them up as per the computation.
#-----------------------------------------------------
class RetinaNet(nn.Module):
  """RetinaNet built on a pre-trained encoder.

  forward() returns [class_pred, bbox_pred, p_sizes]:
    class_pred - (batch, n_locations * n_anchors, n_classes) raw logits
    bbox_pred  - (batch, n_locations * n_anchors, 4) box activations
    p_sizes    - [[h, w], ...] of the P3..P7 feature maps
  """
  def __init__(self, encoder:nn.Module, n_classes, n_anchors=9, final_bias=0.):
    super().__init__()

    # All layers that we create will have the same depth of 256
    self.n_channels = 256

    self.n_classes = n_classes
    self.encoder = encoder

    # BOTTOM-UP PATHWAY
    # The encoder has many more layers internally, but we designate five of them
    # as C1 through C5. These are the layers which downsample from the previous layer by a factor of 2
    # resulting in a size that is half the previous. For example, if the input image size is (256x256), the layers are:
    # C1 - 128 x 128 x 64
    # C2 - 64 x 64 x 256
    # C3 - 32 x 32 x 512
    # C4 - 16 x 16 x 1024
    # C5 - 8 x 8 x 2048
    #
    # Note that if the image size is (128x128) the sizes of C1 through C5 will be half of the numbers shown.
    # eg. C1 will be 64 x 64 x 64 and so on. The depth will not change.

    # Get the depth of layers C3, C4 and C5 from the encoder. This also registers the
    # encoder hooks (self.hooks) that P4 and P3 read from below, so it must run first.
    c3_channels, c4_channels, c5_channels  = self._encoder_layers()

    # The depth of all other layers is the same
    p6_channels = self.n_channels
    p7_channels = self.n_channels

    # TOP DOWN PATHWAY
    # The layers are:
    # P7 - 2 x 2 x 256   - ReLU activation then 3x3 with stride=2 on P6.
    # P6 - 4 x 4 x 256   - 3x3 with stride=2 on C5.
    # P5 - 8 x 8 x 256   - 1x1 with stride=1 on C5.
    # P4 - 16 x 16 x 256 - up-sample from P5 merged with lateral connection from C4
    # P3 - 32 x 32 x 256 - up-sample from P4 merged with lateral connection from C3

    self.p5 = conv2d(c5_channels, self.n_channels, ks=1, bias=True)
    self.p6 = conv2d(c5_channels, p6_channels, stride=2, bias=True)
    self.p7 = nn.Sequential(nn.ReLU(), conv2d(p6_channels, p7_channels, stride=2, bias=True))
    # self.hooks is ordered C1..C4, so index 3 is C4 and index 2 is C3
    self.p4 = LateralUpsampleMerge(c4_channels, self.n_channels, self.hooks[3])
    self.p3 = LateralUpsampleMerge(c3_channels, self.n_channels, self.hooks[2])

    # SMOOTHERS - 3x3 with stride=1 on P5, P4 and P3. So sizes remain the same. All depths are 256.
    self.smoothers = nn.ModuleList([conv2d(self.n_channels, self.n_channels, bias=True) for _ in range(3)])

    # DETECTOR HEAD - Consists of two similar sub-networks viz. Classifier and Bounding Box Regressor
    # Each sub-net consists of:
    #    4 Conv layers (with 3x3, stride=1, depth=256) followed by
    #    Predictor Conv layer (with 3x3, stride=1, depth=KA for Classifier and 4A for Regressor
    # where K = number of classes and A = number of anchors
    #
    # The Detector Head is connected in parallel, with each of P3 through P7 ie.
    # P7 -> Detector
    # P6 -> Detector
    # P5 -> Smoother -> Detector
    # P4 -> Smoother -> Detector
    # P3 -> Smoother -> Detector
    #
    # The final step of the Detector is a Sigmoid activation. However we don't apply that
    # here in the model, but inside the loss function when we take a Sigmoid of the model's output

    self.classifier = self._head_detector(self.n_classes * n_anchors, final_bias=final_bias)
    self.bbox = self._head_detector(4 * n_anchors)

  #-----------------------------------------------------
  # Perform the model's calculations layer by layer, given the input data. Here is where we actually wire up the output
  # of one layer to the input of the next etc.
  #-----------------------------------------------------
  def forward(self, x):
    # Get the output after applying all the encoder layers to the input image, ending up with the output of the final C5 layer.
    # Running the encoder also refreshes the hooked C1-C4 outputs used by P4 and P3.
    c5 = self.encoder(x)

    # Outputs of each of the P3-P7 layers. We have only 5 layers, but we allocate a list of size 8, just so that we could index
    # the list using indices p[3] - p[7] for readability instead of allocating a list of size 5 and using indices p[0] - p[4]
    p = [None] * 8
    p[5] = self.p5(c5.clone())    # Compute P5 from C5
    p[6] = self.p6(c5)            # Compute P6 from C5
    p[7] = self.p7(p[6])          # Compute P7 from P6
    p[4] = self.p4(p[5])          # Compute P4 from P5 on the top-down path
    p[3] = self.p3(p[4])          # Compute P3 from P4 on the top-down path

    # Apply smoother layers to P3-P5
    for i, smoother in enumerate(self.smoothers):
      p[3 + i] = smoother(p[3 + i])

    # Apply the Detector head on each of P3-P7, and then concatenate all the outputs.
    # In Pytorch, Conv2d() shape is (batch, depth channel, height, width), which we permute to (batch, height, width, depth)
    #    For the classifier, the depth is KA which is then reshaped to (batch, height * width * A, K)
    #    For the bbox regressor, the depth is 4A which is then reshaped to (batch, height * width * A, 4)
    # Now the outputs from P3-P7 can be concatenated together.
    #
    # In other words, every grid cell (ie. height x width) and every anchor for each grid cell (ie. A) gets unrolled and stacked linearly one below the other in the 2nd dimension

    class_pred = torch.cat([self.classifier(act).permute(0,2,3,1).contiguous().view(act.size(0), -1, self.n_classes) for act in p[3:8]],1)
    bbox_pred = torch.cat([self.bbox(act).permute(0,2,3,1).contiguous().view(act.size(0), -1, 4) for act in p[3:8]],1)

    # Get a list of all the height and width of P3-P7 as [[32, 32], [16, 16], ...]
    p_sizes = [[layer.size(2), layer.size(3)] for layer in p[3:8]]

    return([class_pred, bbox_pred, p_sizes])

  #-----------------------------------------------------
  # Get information about some of the internal layers of the encoder, which we're calling C1-C5
  #-----------------------------------------------------
  def _encoder_layers(self):
    # Assume image size
    imsize = (256,256)

    # Use a Fastai utility function from the Unet model, to get the sizes of all the main Conv layers of the encoder
    # These layers are: [1, 64, 128, 128], [1, 64, 128, 128], [1, 64, 128, 128], [1, 64, 64, 64], [1, 256, 64, 64]),
    #                      [1, 512, 32, 32], [1, 1024, 16, 16], [1, 2048, 8, 8]
    enc_sizes = model_sizes(self.encoder, size=imsize)

    # Use another Fastai Unet function to get the indices of the layers in the above list, where the size changes ie. [2, 4, 5, 6]
    # These are the layers which we are designating as C1-C4, and then the final layer ie. index #7, is designated as C5
    enc_layer_idxs = _get_sfs_idxs(enc_sizes)

    # Use another Fastai Unet function to get the 'hooks' for each of the C1-C4 layers.
    # They are removed again in __del__.
    self.hooks = hook_outputs([self.encoder[i] for i in enc_layer_idxs])

    # Append -1 to get the index of the last layer C5. Then get the depth of all layers C1-C5
    enc_layer_idxs.append(-1)
    enc_layer_channels = [enc_sizes[i][1] for i in enc_layer_idxs]

    # Return the depths of C3-C5 (we don't use C1 and C2 for anything)
    return (enc_layer_channels[2:5])

  #-----------------------------------------------------
  # Build the Detector Head
  #-----------------------------------------------------
  def _head_detector(self, n_predictor_channels, final_bias=0., n_conv=4):
    # n_conv Conv layers built with fastai's conv_layer (default 3x3 kernel, stride=1),
    # without normalisation
    layers = [conv_layer(self.n_channels, self.n_channels, bias=True, norm_type=None) for _ in range(n_conv)]

    # Final predictor layer: a plain conv2d (default 3x3 kernel, stride=1)
    predictor_layer = conv2d(self.n_channels, n_predictor_channels, bias=True)

    # Initialise the Weights to 0 and the Bias to final_bias, so every initial
    # prediction of this head equals final_bias
    predictor_layer.bias.data.zero_().add_(final_bias)
    predictor_layer.weight.data.fill_(0)

    # Build a Sequential module from the conv layers plus the predictor
    layers += [predictor_layer]
    return nn.Sequential(*layers)

  def __del__(self):
    # Remove the forward hooks registered on the encoder in _encoder_layers(); the encoder
    # may be shared with other models, so leftover hooks would keep storing activations.
    if hasattr(self, "hooks"): self.hooks.remove()

# Build the Resnet50 encoder, and then the RetinaNet model
# (cut=-2 drops the last two children of the resnet, keeping the convolutional body)
encoder = create_body(models.resnet50, cut=-2)
# This is temporary. Initialise the random number seed so that we can get reproducible results between my version of
# the Retinanet model and the Fastai version
torch.manual_seed(0)
# Why final_bias=-4? That's because we want the network to predict background easily at the beginning (since it's the
# most common class). At first the final convolution of the classifier is initialized with weights=0 and that bias, so
# it will return -4 for everyone. Passed through a sigmoid, that gives a probability of roughly 0.02 (sigmoid(-4) ~ 0.018).
# `data` is presumably the DataBunch created in an earlier cell; data.c is its number of classes.
model = RetinaNet(encoder, data.c, final_bias=-4)
# Smoke test on a single random 256x256 RGB image
test_img = torch.randn(1, 3, 256, 256)
out = model(test_img)

In [0]:
class LateralUpsampleMerge_f(nn.Module):
    "Merge the features coming from the downsample path (in `hook`) with the upsample path."
    def __init__(self, ch, ch_lat, hook):
        super().__init__()
        # hook.stored holds the latest output of the hooked encoder layer (ch_lat channels)
        self.hook = hook
        # 1x1 conv projecting the encoder features from ch_lat to ch channels
        self.conv_lat = conv2d(ch_lat, ch, ks=1, bias=True)

    def forward(self, x):
        encoder_feats = self.hook.stored
        lateral = self.conv_lat(encoder_feats)
        # Nearest-neighbour upsample x to the encoder map's spatial size and add
        return lateral + F.interpolate(x, encoder_feats.shape[-2:], mode='nearest')
      
class RetinaNet_f(nn.Module):
    "Implements RetinaNet from https://arxiv.org/abs/1708.02002"
    def __init__(self, encoder:nn.Module, n_classes, final_bias=0., chs=256, n_anchors=9, flatten=True):
        super().__init__()
        self.n_classes,self.flatten = n_classes,flatten
        imsize = (256,256)
        # Sizes of the encoder layers for a dummy 256x256 input, and the indices where the
        # spatial size changes; reversed so the deepest hooked layer comes first.
        sfs_szs = model_sizes(encoder, size=imsize)
        sfs_idxs = list(reversed(_get_sfs_idxs(sfs_szs)))
        self.sfs = hook_outputs([encoder[i] for i in sfs_idxs])
        self.encoder = encoder
        self.c5top5 = conv2d(sfs_szs[-1][1], chs, ks=1, bias=True)
        self.c5top6 = conv2d(sfs_szs[-1][1], chs, stride=2, bias=True)
        self.p6top7 = nn.Sequential(nn.ReLU(), conv2d(chs, chs, stride=2, bias=True))
        # Two top-down merges using the two deepest hooked layers (C4 then C3)
        self.merges = nn.ModuleList([LateralUpsampleMerge_f(chs, sfs_szs[idx][1], hook)
                                     for idx,hook in zip(sfs_idxs[0:2], self.sfs[0:2])])
        self.smoothers = nn.ModuleList([conv2d(chs, chs, 3, bias=True) for _ in range(3)])
        self.classifier = self._head_subnet(n_classes, n_anchors, final_bias, chs=chs)
        self.box_regressor = self._head_subnet(4, n_anchors, 0., chs=chs)

    def _head_subnet(self, n_classes, n_anchors, final_bias=0., n_conv=4, chs=256):
        "Helper function to create one of the subnet for regression/classification."
        layers = [conv_layer(chs, chs, bias=True, norm_type=None) for _ in range(n_conv)]
        layers += [conv2d(chs, n_classes * n_anchors, bias=True)]
        # Zero weights plus a constant bias: every initial prediction equals final_bias
        layers[-1].bias.data.zero_().add_(final_bias)
        layers[-1].weight.data.fill_(0)
        return nn.Sequential(*layers)

    def _apply_transpose(self, func, p_states, n_classes):
        #Final result of the classifier/regressor is bs * (k * n_anchors) * h * w
        #We make it bs * h * w * n_anchors * k then flatten in bs * -1 * k so we can contenate
        #all the results in bs * anchors * k (the non flatten version is there for debugging only)
        if not self.flatten:
            sizes = [[p.size(0), p.size(2), p.size(3)] for p in p_states]
            return [func(p).permute(0,2,3,1).view(*sz,-1,n_classes) for p,sz in zip(p_states,sizes)]
        else:
            return torch.cat([func(p).permute(0,2,3,1).contiguous().view(p.size(0),-1,n_classes) for p in p_states],1)

    def forward(self, x):
        c5 = self.encoder(x)
        # [P5, P6, P7], then the merges prepend P4 and P3 -> [P3, P4, P5, P6, P7]
        p_states = [self.c5top5(c5.clone()), self.c5top6(c5)]
        p_states.append(self.p6top7(p_states[-1]))
        for merge in self.merges: p_states = [merge(p_states[0])] + p_states
        # Smooth P3, P4 and P5
        for i, smooth in enumerate(self.smoothers[:3]):
            p_states[i] = smooth(p_states[i])
        return [self._apply_transpose(self.classifier, p_states, self.n_classes),
                self._apply_transpose(self.box_regressor, p_states, 4),
                [[p.size(2), p.size(3)] for p in p_states]]

    def __del__(self):
        # Remove the forward hooks this model registered on the (possibly shared) encoder
        if hasattr(self, "sfs"): self.sfs.remove()
          
# Build the reference (fastai) RetinaNet with the same seed, for comparison with `out`.
# NOTE(review): this reuses `encoder`, so its hooks are registered on the same encoder as
# `model`'s. The class count is hard-coded to 10; presumably it should equal data.c for the
# outputs to be comparable (confirm).
torch.manual_seed(0)
model_f = RetinaNet_f(encoder, 10, final_bias=-4)
out_f = model_f(test_img)

**===================== Loss Function =====================**

In [0]:
#-----------------------------------------------------
# The RetinaNet model spits out an absurdly high number of predictions: for the features P3 to P7 with an image size of 256, we 
# have 32*32 + 16*16 + 8*8 + 4*4 +2*2 locations possible in one of the five feature maps, which gives 1,364 possible detections, 
# multiplied by the number of anchors we choose to attribute to each location (9 below), which makes 12,276 possible hits.
#
# The model itself isn't actually aware of things like grid cells, anchor boxes and coordinates. It just generates a lot of numbers as its output
# predictions. We choose to interpret those numbers as predicted bounding boxes per anchor box and so on, and all of that intelligence
# is in the Loss Function.
#
# Many of the 12276 predicted bounding boxes aren't going to correspond to any object in the picture, and we need to somehow match 
# all those predictions to either nothing or a given bounding box in the picture.
#
# The Loss Function has two parts, one for the classifier and one for the regressor. For the regression, we will use the 
# L1 (potentially smoothed) loss between the predicted activations for an anchor that matches a given object (we ignore 
# the no match or matches to background) and the corresponding bounding box (after going through bbox_to_prediction).
# For the classification, we use the focal loss, which is a variant of the binary cross entropy used when we have a lot 
# imbalance between the classes to predict (here we will very often have to predict 'background').
#-----------------------------------------------------
class RetinaNetFocalLoss(nn.Module):
    """RetinaNet loss: regression loss on matched anchors + focal loss on classification.

    Anchors matched to an object contribute to both parts, anchors matched to
    background only to the classification part, and 'ignore' anchors to neither.
    """

    def __init__(self, gamma:float=2., alpha:float=0.25, scales:Collection[float]=None,
                 ratios:Collection[float]=None, reg_loss:LossFunction=F.smooth_l1_loss):
        super().__init__()
        # gamma: focusing exponent; alpha: weight of the positive (object) focal-loss term
        self.gamma, self.alpha, self.reg_loss = gamma, alpha, reg_loss
        # Anchor scales and aspect ratios, forwarded to create_anchor_boxes()
        self.scales, self.ratios = scales, ratios
        # Class value used by the collate function for padding slots
        self.pad_idx = 0

    #-----------------------------------------------------
    # The grid sizes are different depending on the size of the input image. So figure out whether
    # the grid sizes for the current image are different than what the Loss Function was initialised with.
    # If the sizes are different, we have to recreate the anchors.
    #
    # We return True if the passed in grid sizes are different than what we had saved previously
    #-----------------------------------------------------
    def _change_anchors(self, sizes:Sizes) -> bool:
        if not hasattr(self, 'sizes'): return True

        # A different number of feature maps means different anchors (zip() below would
        # otherwise silently ignore the extra entries)
        if len(self.sizes) != len(sizes): return True

        # Check if any of the grid sizes are different than what we had saved previously
        for sz1, sz2 in zip(self.sizes, sizes):
            if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True

        return False

    #-----------------------------------------------------
    # Create the anchor boxes
    #-----------------------------------------------------
    def _create_anchors(self, sizes:Sizes, device:torch.device):
        self.sizes = sizes
        self.anchors = create_anchor_boxes(self.sizes, self.ratios, self.scales).to(device)

    #-----------------------------------------------------
    # Reverse the padding added when the databunch was collated. Judging by the slice
    # below, padding slots (class == pad_idx) come before the real objects. Real classes
    # are stored shifted by one so they don't collide with pad_idx; that shift is undone here.
    # Boxes are also converted from (tl, br) to (ctr, hw).
    #-----------------------------------------------------
    def _unpad(self, bbox_tgt, class_tgt):
        real_idxs = torch.nonzero(class_tgt - self.pad_idx)
        if real_idxs.numel() == 0:
            # Every slot is padding (image without objects): keep nothing instead of
            # calling torch.min() on an empty tensor, which raises
            i = class_tgt.shape[0]
        else:
            i = torch.min(real_idxs)
        return tl_br_to_ctr_hw(bbox_tgt[i:]), class_tgt[i:]-1+self.pad_idx

    #-----------------------------------------------------
    # Calculate the bounding box loss
    #-----------------------------------------------------
    def _bbox_loss(self, bp, bt, matches):
        # Get all anchor boxes which match a ground truth object (ie. we filter out matches with background)
        bbox_mask = matches >= 0

        # Count of anchor boxes which match some ground truth object
        # Note that multiple anchor boxes might match the same ground truth object
        n_anchors_matching_target = bbox_mask.sum()

        if n_anchors_matching_target != 0:
            # Get the corresponding bounding box predictions
            # bp reduces in length from (number of anchor boxes, 4) -> (number of matching anchor boxes, 4)
            bp = bp[bbox_mask]

            # Get the corresponding bounding box ground truth targets
            # This changes the length of bt from (number of ground truth objects, 4) -> (number of matching anchor boxes, 4)
            # Note that matches contains the index of the object in the target array ie. which-th ground truth object
            bt = bt[matches[bbox_mask]]

            # Convert the target bounding box coordinates to activations
            bt = bbox_to_prediction (bt, self.anchors[bbox_mask])

            # Calculate loss using the regression loss (smooth L1 by default)
            bb_loss = self.reg_loss(bp, bt)
        else:
            # Zero loss if no anchor boxes matched any ground truth
            bb_loss = 0

        return(bb_loss, n_anchors_matching_target)

    #-----------------------------------------------------
    # Calculate the classification loss
    #-----------------------------------------------------
    def _class_loss(self, cp, ct, matches):
        # Get the mask of all anchor boxes which match either a ground truth object or background
        # (the caller has already shifted matches by +1, so background is 0 and ignore is -1)
        class_mask = matches >= 0

        # We treat background as Class 0. So now the real classes which went from 0 to n_classes-1
        # will now go from 1 to n_classes.
        ct = ct + 1

        # We also prepend one element to the beginning of the class target array. This element
        # corresponds to 'background' and will have a value of 0
        # ct is now [0, class of first ground truth object, class of second ground truth object, ...]
        # ct has shape (number of ground truth objects + 1)
        ct = torch.cat([ct.new_zeros(1), ct])

        # For the matching anchor boxes in mask, get the corresponding class targets
        # This changes the length of ct from (number of ground truth objects + 1) -> (number of matching anchor boxes)
        ct = ct[matches[class_mask]]

        # One-hot encode each value in ct. et has shape (number of matching anchor boxes, number of classes)
        et = onehot_encode (ct, cp.shape[1])

        # For the matching anchor boxes in mask, get the corresponding class predictions
        # cp reduces in length from (number of anchor boxes) -> (number of matching anchor boxes)
        cp = cp[class_mask]

        # Calculate loss using focal loss
        class_loss = self._focal_loss(cp, et)
        return(class_loss)

    #-----------------------------------------------------
    # Focal loss summed over all elements:
    #   -alpha * y * (1-p)^gamma * log(p) - (1-alpha) * (1-y) * p^gamma * log(1-p)
    # with p = sigmoid(cp). The logs are computed from the logits with logsigmoid,
    # because torch.log(sigmoid(x)) underflows to log(0) = -inf once |x| exceeds ~17
    # in float32, turning the zero-weighted terms into 0 * -inf = NaN.
    #-----------------------------------------------------
    def _focal_loss(self, cp, et):
        sig_cp = torch.sigmoid(cp)
        log_p = F.logsigmoid(cp)            # log(p)
        log_1m_p = F.logsigmoid(-cp)        # log(1 - p)

        y_1 = self.alpha * et * ((1 - sig_cp) ** self.gamma) * log_p
        y_0 = (1 - self.alpha) * (1 - et) * (sig_cp ** self.gamma) * log_1m_p
        f_loss = torch.sum (- y_1 - y_0)

        # The alternative formulations below (_focal_loss_fizyr, _focal_loss_kdp_t,
        # _focal_loss_kdbasic_log, _focal_loss_fastai) are kept for manual comparison
        # but are no longer computed on every call.
        return(f_loss)

    #-----------------------------------------------------
    # This is the main method which drives the loss calculation. We are given the final
    # outputs from the model along with the ground truth classifier and bounding box targets.
    #
    # We match each anchor box to a ground truth bounding box, or to the background. If it matches
    # a bounding box, then the anchor box is also associated with the corresponding class of that
    # bounding box.
    #
    # We then calculate the bounding box loss and the classification loss. For the classification
    # loss we use a focal loss calculation which is an improved version of binary cross entropy loss
    #-----------------------------------------------------
    def forward(self, output, bbox_tgts, class_tgts):
        # Bounding Box Target shape is (batch, number of ground truth objects, 4)
        # Class Target shape is (batch, number of ground truth objects)

        # Get the output class predictions and bounding box predictions
        # Class Pred shape is (batch, number of anchor boxes, number of classes)
        # Bounding Box Pred shape is (batch, number of anchor boxes, 4)
        class_pred, bbox_pred, p_sizes = output

        # Create anchor boxes for the grid sizes in p_sizes
        if self._change_anchors(p_sizes): self._create_anchors(p_sizes, class_pred.device)

        # Initialise the total loss
        total_loss = 0.

        # Go through the prediction and target values for each image
        for (cp, bp, ct, bt) in zip(class_pred, bbox_pred, class_tgts, bbox_tgts):
            # Shapes are:
            # cp - (number of anchor boxes, number of classes). Raw logits for each class
            # bp - (number of anchor boxes, 4)
            # ct - (number of ground truth objects)
            # bt - (number of ground truth objects, 4)

            # Reverse the padding that was applied when the Databunch was created
            bt, ct = self._unpad(bt,ct)

            # Match the anchors with ground truth bounding boxes
            # Matches has shape (number of anchor boxes). Gives the matching target index, or background (or Ignore)
            matches = match_anchors(self.anchors, bt)

            # Calculate the bounding box loss
            bb_loss, n_anchors_matching_target = self._bbox_loss(bp, bt, matches)

            # In 'matches', anchor boxes which match 'background' have a value of -1. And anchor boxes which match
            # an object have a value of the index of the object in the target array ie. which-th ground truth
            # object in class_tgts and bbox_tgts. We add 1 to each value, so that now 'background' matches have a
            # value of 0. The class_tgts now has to be extended, which happens inside _class_loss()
            matches = matches + 1

            # Calculate the classification loss, normalised by the number of positive anchors
            class_loss = self._class_loss(cp, ct, matches) / torch.clamp(n_anchors_matching_target, min=1.)

            # Loss for this image is Bounding Box Loss + Classification Loss
            image_loss = bb_loss + class_loss

            # Total loss (for all images) is the sum of loss for each image
            total_loss += image_loss

        # Finally we get the average loss by dividing the summed total loss by the number of images
        n_images = class_tgts.shape[0]
        return(total_loss / n_images)

    def _focal_loss_fizyr (self, sig_cp, et):
        alpha_k = torch.ones_like(et) * self.alpha
        alpha_k = torch.where(et > 0., alpha_k, 1 - alpha_k)
        p_k = torch.where(et > 0., 1 - sig_cp, sig_cp)
        weight_k = alpha_k * (p_k ** self.gamma)
        weight_k = weight_k.detach()

        f_loss_kw = torch.sum (weight_k * F.binary_cross_entropy(sig_cp, et, reduction='none'))
        f_loss_kred = F.binary_cross_entropy(sig_cp, et, weight_k, reduction='sum')
        return (f_loss_kw, f_loss_kred)

    def _focal_loss_kdp_t (self, sig_cp, et):
        alpha_t = self.alpha * et + (1 - self.alpha) * (1 - et)
        p_t = sig_cp * et + (1 - sig_cp) * (1 - et)
        weight = alpha_t * ((1 - p_t) ** self.gamma)
        weight = weight.detach()
        f_loss_bce_red = F.binary_cross_entropy(sig_cp, et, weight, reduction='sum')
        return (f_loss_bce_red)

    def _focal_loss_kdbasic_log (self, sig_cp, et):
        alpha_t = self.alpha * et + (1 - self.alpha) * (1 - et)
        p_t = sig_cp * et + (1 - sig_cp) * (1 - et)
        weight = alpha_t * ((1 - p_t) ** self.gamma)
        weight = weight.detach()

        f_loss_log = torch.sum (-weight * torch.log(p_t))
        return (f_loss_log)

    def _focal_loss_fastai (self, cp, sig_cp, et):
        weight_f = et * (1-sig_cp) + (1-et) * sig_cp
        weight_f = weight_f.detach()
        alpha_f = (1-et) * self.alpha + et * (1-self.alpha)
        weight_f.pow_(self.gamma).mul_(alpha_f)
        clas_loss = F.binary_cross_entropy_with_logits(cp, et, weight_f, reduction='sum')
        return (clas_loss)

# 3 aspect ratios x 3 scales = 9 anchors per location, matching RetinaNet's default n_anchors=9
ratios = [1/2,1,2]
scales = [1,2**(-1/3), 2**(-2/3)] 
# One image with 5 random target boxes in (tl, br) format. NOTE(review): random values can give
# bottom-right above/left of top-left (negative sizes -> log of a negative number), so the loss
# value is only a smoke test.
bbox_targ = torch.randn(1, 5, 4)
# Class targets drawn from 1..data.c, so none equals pad_idx (0) and nothing is unpadded away
clas_tgt = torch.randint(1, data.c + 1, (1, 5))

crit = RetinaNetFocalLoss(scales=scales, ratios=ratios)
loss = crit(out, bbox_targ, clas_tgt)

In [0]:
def tlbr2cthw(boxes):
    """Convert (N, 4) boxes from [top, left, bottom, right] to [center_0, center_1, size_0, size_1]."""
    top_left, bot_right = boxes[:, :2], boxes[:, 2:]
    return torch.cat([(top_left + bot_right) / 2, bot_right - top_left], 1)
  
def cthw2tlbr(boxes):
    """Convert (N, 4) boxes from [center_0, center_1, size_0, size_1] to [top, left, bottom, right]."""
    centers, half = boxes[:, :2], boxes[:, 2:] / 2
    return torch.cat([centers - half, centers + half], 1)

def bbox_to_activ(bboxes, anchors, flatten=True):
    """Return the regression targets the model should predict on `anchors` for `bboxes`.

    Both inputs are center/size boxes (last dim = [c0, c1, s0, s1]). Centers are encoded
    as offsets relative to the anchor size and sizes as log-ratios; the result is divided
    by [0.1, 0.1, 0.2, 0.2].

    If `flatten` is False, `bboxes` and `anchors` are parallel sequences and a list of
    encoded targets (one per pair) is returned.
    """
    if not flatten:
        # Recurse on each pair (previously referenced an undefined `acts` and the
        # inverse function `activ_to_bbox`).
        return [bbox_to_activ(bb, anc) for bb, anc in zip(bboxes, anchors)]
    t_centers = (bboxes[..., :2] - anchors[..., :2]) / anchors[..., 2:]
    # 1e-8 guards against log(0) for degenerate boxes
    t_sizes = torch.log(bboxes[..., 2:] / anchors[..., 2:] + 1e-8)
    return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))

def encode_class(idxs, n_classes):
    """One-hot encode class indices into a float (len(idxs), n_classes) tensor.

    Index 0 (background / no object) becomes an all-zero row; class k >= 1 sets column k - 1.
    """
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0
    # Row indices live on the same device as `idxs`: a CPU LongTensor cannot be
    # masked with a CUDA boolean mask in recent PyTorch versions.
    rows = torch.arange(len(idxs), device=idxs.device)
    target[rows[mask], idxs[mask] - 1] = 1
    return target

def create_grid(size):
    """Create the (row, col) centers of an H x W grid of cells covering [-1, 1] x [-1, 1].

    `size` is an int (square grid) or an (H, W) tuple. Returns an (H*W, 2) float tensor in
    row-major order; a dimension of length 1 gets the single center 0.
    """
    H, W = size if isinstance(size, tuple) else (size, size)
    cols = torch.linspace(-1 + 1/W, 1 - 1/W, W) if W > 1 else torch.tensor([0.])
    rows = torch.linspace(-1 + 1/H, 1 - 1/H, H) if H > 1 else torch.tensor([0.])
    # Broadcasting replaces the deprecated torch.ger outer products and the
    # uninitialised FloatTensor buffer.
    grid = torch.stack([rows[:, None].expand(H, W), cols[None, :].expand(H, W)], dim=-1)
    return grid.reshape(-1, 2)

def create_anchors(sizes, ratios, scales, flatten=True):
    """Build anchors [center_row, center_col, h, w] for every (h, w) feature-map size in `sizes`.

    Each grid cell gets len(ratios) * len(scales) anchors. With `flatten` the result is one
    (N, 4) tensor; otherwise a list of (h, w, n_anchors, 4) tensors, one per size.
    """
    # relative (height, width) for each ratio/scale pair, ratio-major order
    aspects = torch.tensor([[[s * math.sqrt(r), s * math.sqrt(1 / r)] for s in scales]
                            for r in ratios]).view(-1, 2)
    n_asp = aspects.size(0)
    per_level = []
    for h, w in sizes:
        # scale by 4 grid cells (cell size is 2/h x 2/w) so neighbouring anchors overlap
        dims = 4 * (aspects * torch.tensor([2 / h, 2 / w])).unsqueeze(0)
        centers = create_grid((h, w)).unsqueeze(1)
        n_cells = centers.size(0)
        level = torch.cat([centers.expand(n_cells, n_asp, 2), dims.expand(n_cells, n_asp, 2)], 2)
        per_level.append(level.view(h, w, n_asp, 4))
    if not flatten:
        return per_level
    return torch.cat([lvl.view(-1, 4) for lvl in per_level], 0)

def intersection_f(anchors, targets):
    """Pairwise intersection areas of center/size `anchors` (A, 4) and `targets` (T, 4) -> (A, T)."""
    a_box = cthw2tlbr(anchors).unsqueeze(1)
    t_box = cthw2tlbr(targets).unsqueeze(0)
    n_a, n_t = a_box.size(0), t_box.size(1)
    a_box, t_box = a_box.expand(n_a, n_t, 4), t_box.expand(n_a, n_t, 4)
    lo = torch.max(a_box[..., :2], t_box[..., :2])
    hi = torch.min(a_box[..., 2:], t_box[..., 2:])
    overlap = torch.clamp(hi - lo, min=0)
    return overlap[..., 0] * overlap[..., 1]

def IoU_values(anchors, targets):
    """Pairwise IoU of center/size `anchors` (A, 4) and `targets` (T, 4) -> (A, T)."""
    inter = intersection_f(anchors, targets)
    area_a = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1)
    area_t = (targets[:, 2] * targets[:, 3]).unsqueeze(0)
    # 1e-8 avoids division by zero for degenerate boxes
    return inter / (area_a + area_t - inter + 1e-8)

def match_anchors_f(anchors, targets, match_thr=0.5, bkg_thr=0.4):
    """Assign each anchor a target index from its best IoU.

    Returns a LongTensor with one entry per anchor:
      >= 0 : index of the matched target (best IoU > match_thr)
      -1   : background (best IoU < bkg_thr)
      -2   : ignored (IoU in between, or no targets at all)
    """
    matches = torch.full((anchors.size(0),), -2, dtype=torch.long, device=anchors.device)
    if targets.numel() == 0:
        return matches
    best_iou, best_tgt = torch.max(IoU_values(anchors, targets), 1)
    matches[best_iou < bkg_thr] = -1
    is_match = best_iou > match_thr
    matches[is_match] = best_tgt[is_match]
    # Forcing every target onto its highest-IoU anchor is intentionally not done:
    #vals,idxs = torch.max(ious,0)
    #matches[idxs] = targets.new_tensor(list(range(targets.size(0)))).long()
    return matches

class RetinaNetFocalLoss_f(nn.Module):
    "Reference (fastai-style) RetinaNet loss: `reg_loss` on matched anchors plus sigmoid focal loss."

    def __init__(self, gamma:float=2., alpha:float=0.25,  pad_idx:int=0, scales:Collection[float]=None,
                 ratios:Collection[float]=None, reg_loss:LossFunction=F.smooth_l1_loss):
        super().__init__()
        self.gamma,self.alpha,self.pad_idx,self.reg_loss = gamma,alpha,pad_idx,reg_loss
        self.scales = ifnone(scales, [1,2**(-1/3), 2**(-2/3)])
        self.ratios = ifnone(ratios, [1/2,1,2])

    def _change_anchors(self, sizes:Sizes) -> bool:
        "True when no anchors are cached yet or they were built for different feature-map sizes."
        if not hasattr(self, 'sizes'): return True
        # a different number of levels also invalidates the cache (zip would hide it)
        if len(self.sizes) != len(sizes): return True
        for sz1, sz2 in zip(self.sizes, sizes):
            if sz1[0] != sz2[0] or sz1[1] != sz2[1]: return True
        return False

    def _create_anchors(self, sizes:Sizes, device:torch.device):
        "Cache `sizes` and the flattened anchors built for them on `device`."
        self.sizes = sizes
        anchors = create_anchors(sizes, self.ratios, self.scales)
        self.anchors = anchors.to(device)

    def _unpad(self, bbox_tgt, clas_tgt):
        """Drop the leading padding entries (class == pad_idx) of one image's targets.

        Boxes are converted to center/size form and class ids shifted down by one.
        """
        nonpad = torch.nonzero(clas_tgt-self.pad_idx)
        # An all-padding target has no objects: slice everything away instead of calling
        # torch.min on an empty tensor (which raises).
        i = torch.min(nonpad) if nonpad.numel() > 0 else len(clas_tgt)
        return tlbr2cthw(bbox_tgt[i:]), clas_tgt[i:]-1+self.pad_idx

    def _focal_loss(self, clas_pred, clas_tgt):
        "Summed sigmoid focal loss of logits `clas_pred` against class ids `clas_tgt` (0 = background)."
        encoded_tgt = encode_class(clas_tgt, clas_pred.size(1))
        ps = torch.sigmoid(clas_pred.detach())
        weights = encoded_tgt * (1-ps) + (1-encoded_tgt) * ps
        alphas = (1-encoded_tgt) * self.alpha + encoded_tgt * (1-self.alpha)
        weights.pow_(self.gamma).mul_(alphas)
        return F.binary_cross_entropy_with_logits(clas_pred, encoded_tgt, weights, reduction='sum')

    def _one_loss(self, clas_pred, bbox_pred, clas_tgt, bbox_tgt):
        "Loss for one image: box regression on matched anchors + focal loss / number of matched anchors."
        bbox_tgt, clas_tgt = self._unpad(bbox_tgt, clas_tgt)
        matches = match_anchors_f(self.anchors, bbox_tgt)
        bbox_mask = matches>=0
        if bbox_mask.sum() != 0:
            bbox_pred = bbox_pred[bbox_mask]
            bbox_tgt = bbox_tgt[matches[bbox_mask]]
            b = bbox_to_activ(bbox_tgt, self.anchors[bbox_mask]).float()
            bb_loss = self.reg_loss(bbox_pred, b)
        else: bb_loss = 0.
        # Shift so background (-1) becomes 0 and ignored anchors (-2) become -1 (dropped below).
        matches.add_(1)
        clas_tgt = clas_tgt + 1
        clas_mask = matches>=0
        clas_pred = clas_pred[clas_mask]
        # Prepend class 0 so that matches == 0 (background) looks up "no object".
        bkg = clas_tgt.new_zeros(1).long()
        clas_tgt = torch.cat([bkg, clas_tgt])
        clas_tgt = clas_tgt[matches[clas_mask]]
        return bb_loss + self._focal_loss(clas_pred, clas_tgt)/torch.clamp(bbox_mask.sum(), min=1.)

    def forward(self, output, bbox_tgts, clas_tgts):
        "Average per-image loss over the batch; `output` is (clas_preds, bbox_preds, sizes)."
        clas_preds, bbox_preds, sizes = output
        if self._change_anchors(sizes): self._create_anchors(sizes, clas_preds.device)
        losses = [self._one_loss(cp, bp, ct, bt)
                  for (cp, bp, ct, bt) in zip(clas_preds, bbox_preds, clas_tgts, bbox_tgts)]
        return sum(losses)/clas_tgts.size(0)

# Same smoke test with the fastai-style reference loss, for comparison with `loss`.
crit_f = RetinaNetFocalLoss_f(scales=scales, ratios=ratios)
loss_f = crit_f(out_f, bbox_tgts=bbox_targ, clas_tgts=clas_tgt)

In [0]:
def retina_net_split(model):
    """Split `model` into three layer groups for discriminative learning rates.

    The groups are: encoder children [:6], encoder children [6:], and every
    top-level child of `model` after the first.
    """
    enc_layers = list(model.encoder.children())
    head_layers = list(model.children())[1:]
    return [enc_layers[:6], enc_layers[6:], head_layers]

# Learner with the custom RetinaNet loss, split into 3 layer groups; find a LR with the encoder frozen.
learn = Learner(data, model, loss_func=crit)
learn = learn.split(retina_net_split)
learn.freeze()
learn.lr_find()
learn.recorder.plot()  # skip_end=5 can hide the noisy tail of the curve

In [0]:
# Stage 1 (checkpoint name suggests 128px images): train with the encoder frozen, then save.
learn.fit_one_cycle(cyc_len=5, max_lr=1e-4)
learn.save('stage1-128')

In [0]:
# Stage 2: fine-tune all layer groups, with LRs spread from 1e-6 to 5e-5 across the groups.
learn.unfreeze()
learn.fit_one_cycle(cyc_len=10, max_lr=slice(1e-6, 5e-5))
learn.save('stage2-128')

In [0]:
# Next resolution: swap in new data (presumably bs=32, size=192 — confirm against get_data),
# re-check the LR with the encoder frozen, then train.
learn.data = get_data(32,192)
learn.freeze()
learn.lr_find()
learn.recorder.plot()
learn.fit_one_cycle(cyc_len=5, max_lr=1e-4)

In [0]:
# Fine-tune all layer groups at this resolution.
learn.unfreeze()
learn.fit_one_cycle(cyc_len=10, max_lr=slice(1e-6, 5e-5))

In [0]:
# Largest resolution (presumably bs=24, size=256): retrain with the encoder frozen.
learn.data = get_data(24,256)
learn.freeze()
learn.fit_one_cycle(cyc_len=5, max_lr=1e-4)

In [0]:
# Final fine-tuning of all layer groups.
learn.unfreeze()
learn.fit_one_cycle(cyc_len=10, max_lr=slice(1e-6, 5e-5))

In [0]:
def _draw_outline(o:Patch, lw:int):
    "Give matplotlib artist `o` a black stroke `lw` wide so it stays visible on any background."
    outline = patheffects.Stroke(linewidth=lw, foreground='black')
    o.set_path_effects([outline, patheffects.Normal()])

def draw_rect(ax:plt.Axes, b:Collection[int], color:str='white', text=None, text_size=14):
    """Draw an unfilled outlined rectangle `b` = [x, y, width, height] on `ax`.

    If `text` is given, it is written in bold at the rectangle's (x, y) corner.
    """
    xy, wh = b[:2], b[-2:]
    rect = ax.add_patch(patches.Rectangle(xy, *wh, fill=False, edgecolor=color, lw=2))
    _draw_outline(rect, 4)
    if text is None:
        return
    label = ax.text(*xy, text, verticalalignment='top', color=color, fontsize=text_size, weight='bold')
    _draw_outline(label, 1)

def unpad(bbox_tgt, class_tgt, pad_idx=0):
    """Drop the leading padding entries (class == pad_idx) of one image's targets.

    Returns (boxes converted by tl_br_to_ctr_hw, class ids shifted down by one).
    """
    nonpad = torch.nonzero(class_tgt - pad_idx)
    # An all-padding target has no objects: slice everything away instead of calling
    # torch.min on an empty tensor (which raises).
    i = torch.min(nonpad) if nonpad.numel() > 0 else len(class_tgt)
    return tl_br_to_ctr_hw(bbox_tgt[i:]), class_tgt[i:] - 1 + pad_idx

def process_output (output, i, detect_thresh=0.25):
    """Decode the predictions for image `i` of a batch `output` = (class_preds, bbox_activ_preds, p_sizes).

    Anchors are rebuilt from `p_sizes` with the module-level `ratios` and `scales`. Only anchors
    whose best sigmoid class score exceeds `detect_thresh` are kept; their boxes are clipped
    to [-1, 1] and returned in center/size form.

    Returns (boxes, scores, class_ids), or three empty lists if nothing passes the threshold.
    """
    class_preds, bbox_activ_preds, p_sizes = output
    logits, activ = class_preds[i], bbox_activ_preds[i]
    anchors = create_anchor_boxes(p_sizes, ratios, scales).to(logits.device)
    boxes = prediction_to_bbox (activ, anchors)
    scores, class_ids = torch.sigmoid(logits).max(1)
    keep = scores > detect_thresh

    clipped_tlbr = torch.clamp(ctr_hw_to_tl_br(boxes[keep]), min=-1, max=1)
    kept_boxes = tl_br_to_ctr_hw(clipped_tlbr)
    kept_classes = class_ids[keep]
    kept_scores = scores[keep]
    if kept_classes.numel() == 0:
        return [], [], []
    return (kept_boxes, kept_scores, kept_classes)

def show_preds(img, output, idx, detect_thresh=0.25, classes=None, ax=None):
    """Plot `img` with the NMS-filtered detections for batch item `idx` of `output`.

    Detections are center/size boxes in [-1, 1] coordinates (row, col order). They are
    converted to pixel [top, left, height, width] before drawing. If given, `classes` is
    indexed with class_id + 1 (index 0 is skipped, presumably background).
    """
    bbox_pred, scores, preds = process_output(output, idx, detect_thresh)
    if len(scores) != 0:
        to_keep = nms(bbox_pred, scores)
        bbox_pred, preds, scores = bbox_pred[to_keep].cpu(), preds[to_keep].cpu(), scores[to_keep].cpu()
        t_sz = torch.Tensor([*img.size])[None].float()
        # center -> top-left corner, still in [-1, 1] coordinates
        bbox_pred[:,:2] = bbox_pred[:,:2] - bbox_pred[:,2:]/2
        # A coordinate maps to pixels via (x + 1) * size / 2, so a length scales by size / 2
        # (scaling by the full size drew every box twice as large).
        bbox_pred[:,:2] = (bbox_pred[:,:2] + 1) * t_sz/2
        bbox_pred[:,2:] = bbox_pred[:,2:] * t_sz/2
        bbox_pred = bbox_pred.long()
    if ax is None: _, ax = plt.subplots(1,1)
    img.show(ax=ax)
    for bbox, c, scr in zip(bbox_pred, preds, scores):
        txt = str(c.item()) if classes is None else classes[c.item()+1]
        # draw_rect expects [x, y, width, height] = [left, top, w, h]
        draw_rect(ax, [bbox[1],bbox[0],bbox[3],bbox[2]], text=f'{txt} {scr:.2f}')

def nms(bbox_detect, scores, thresh=0.3):
    """Greedy non-maximum suppression.

    Repeatedly keeps the highest-scoring remaining box and discards every remaining box
    whose IoU with it is >= `thresh`. Returns a LongTensor of indices into
    `bbox_detect` / `scores`, highest score first.
    """
    keep = []
    # Track indices into the ORIGINAL tensors. The previous version appended positions
    # within the shrinking remainder, which are not valid indices after the first pass.
    # Sorting once suffices because boolean filtering preserves order.
    remaining = scores.argsort(descending=True)
    while len(remaining) > 0:
        best = remaining[0]
        keep.append(best.item())
        rest = remaining[1:]
        if len(rest) == 0:
            break
        iou_val = iou(bbox_detect[best].view(-1, 4), bbox_detect[rest]).view(-1)
        remaining = rest[iou_val < thresh]
    return LongTensor(keep)

def nms_f(boxes, scores, thresh=0.3):
    """fastai's reference greedy NMS; returns indices (into the original `boxes`) of kept boxes, best first."""
    order = scores.argsort(descending=True)
    boxes, scores = boxes[order], scores[order]
    # positions into the sorted tensors, filtered alongside boxes/scores
    positions = torch.LongTensor(range_of(scores))
    kept = []
    while len(scores) > 0:
        kept.append(order[positions[0]])
        overlaps = iou(boxes, boxes[:1]).squeeze()
        survivors = overlaps < thresh
        if len(survivors.nonzero()) == 0:
            break
        boxes, scores, positions = boxes[survivors], scores[survivors], positions[survivors]
    return LongTensor(kept)

def show_results(learn, start=0, n=5, detect_thresh=0.35, figsize=(10,25)):
    """Show `n` validation images from index `start`: ground truth (left) vs predictions (right).

    Predictions come from a single validation batch, so `start + n` must not exceed the batch
    size. Pairing batch item k with `valid_ds[k]` presumably relies on the validation loader
    not shuffling.
    """
    xb, _ = learn.data.one_batch(DatasetType.Valid, cpu=False)
    with torch.no_grad():
        preds = learn.model.eval()(xb)
    _, axs = plt.subplots(n, 2, figsize=figsize)
    for row in range(n):
        k = start + row
        img, bbox = learn.data.valid_ds[k]
        img.show(ax=axs[row, 0], y=bbox)
        show_preds(img, preds, k, detect_thresh=detect_thresh, classes=learn.data.classes, ax=axs[row, 1])

# Predict on the first validation batch and show the detections for one of its images.
img,target = next(iter(data.valid_dl))
with torch.no_grad():
    # eval() so batch-norm uses running statistics (and does not update them), as in show_results
    output = learn.model.eval()(img)
idx = 0
img2 = data.valid_ds[idx][0]
show_preds(img2, output, idx, detect_thresh=0.25, classes=data.classes)

In [0]:
show_results(learn=learn, start=10, n=5)  # validation images 10..14

In [0]:
def get_predictions(output, idx, detect_thresh=0.05):
    """Thresholded, NMS-filtered detections for batch item `idx`.

    Returns (boxes, class_ids, scores) — note the order differs from process_output —
    or three empty lists when nothing is detected.
    """
    boxes, scores, class_ids = process_output(output, idx, detect_thresh)
    if not len(scores):
        return [], [], []
    keep = nms(boxes, scores)
    return boxes[keep], class_ids[keep], scores[keep]

def compute_ap(precision, recall):
    """Area under the interpolated precision/recall curve.

    `precision` and `recall` are sequences ordered by decreasing detection score. Precision
    is replaced by its running maximum from the right (the envelope) before integrating
    over the points where recall changes.
    """
    rec = np.concatenate(([0.], list(recall), [1.]))
    prec = np.concatenate(([0.], list(precision), [0.]))
    prec = np.maximum.accumulate(prec[::-1])[::-1]
    steps = np.flatnonzero(rec[1:] != rec[:-1])
    return np.sum((rec[steps + 1] - rec[steps]) * prec[steps + 1])

def class_AP(model, dl, n_classes, iou_thresh=0.5, detect_thresh=0.35, num_keep=100):
    """Collect every post-NMS detection over `dl` into a DataFrame for AP analysis.

    One row per detection, with columns: 'Class Pred', 'Class Score', 'IOU' (best IoU with a
    ground-truth box of the same image), 'IOU IDX' (that box's index), 'IOU Class' (its class),
    'TP' (classes agree and IOU > iou_thresh), 'IMG' (running image index over `dl`) and
    'Pred Seq' (rank within the predicted class by descending score).
    `n_classes` and `num_keep` are currently unused.
    """
    cols = ['Class Pred', 'Class Score', 'IOU', 'IOU IDX', 'IOU Class']
    frames = []
    img_offset = 0
    with torch.no_grad():
        for xb, target in progress_bar(dl):
            output = model(xb)
            n_images = target[0].size(0)
            # Iterate over the real batch size (was a hard-coded range(10), which skipped
            # images in larger batches and indexed out of range in smaller ones).
            for i in range(n_images):
                bbox_detect, class_detect, class_max_detect = get_predictions(output, i, detect_thresh)
                tgt_bbox, tgt_class = unpad(target[0][i], target[1][i])
                if len(bbox_detect) == 0 or len(tgt_bbox) == 0:
                    continue
                max_iou, matches = iou(bbox_detect, tgt_bbox).max(1)
                rows = torch.cat((class_detect.view(-1, 1).float(), class_max_detect.view(-1, 1),
                                  max_iou.view(-1, 1), matches.view(-1, 1).float(),
                                  tgt_class[matches].view(-1, 1).float()), 1)
                ap_df = pd.DataFrame(rows.cpu().numpy(), columns=cols)
                ap_df['TP'] = (ap_df['Class Pred'] == ap_df['IOU Class']) & (ap_df['IOU'] > iou_thresh)
                # global index, so rows from different batches never share an image id
                ap_df['IMG'] = img_offset + i
                frames.append(ap_df)
            img_offset += n_images
    if not frames:
        return pd.DataFrame(columns=cols + ['TP', 'IMG', 'Pred Seq'])
    # DataFrame.append was removed in pandas 2.0; concatenate once instead
    full_df = pd.concat(frames)
    full_df = full_df.sort_values(['Class Pred', 'Class Score'], ascending=[True, False])
    full_df['Pred Seq'] = full_df.groupby('Class Pred')['Class Score'].rank(method='first', ascending=False)
    return full_df
  
def compute_class_AP_f(model, dl, n_classes, iou_thresh=0.5, detect_thresh=0.35, num_keep=100):
    """Per-class average precision over `dl` (fastai-style reference implementation).

    A detection is a true positive when its best-IoU ground-truth box has IoU >= iou_thresh
    and the same class, and that box was not already claimed by an earlier detection in the
    same image. Returns a list of `n_classes` AP values (0. where a class has no true
    positive). `num_keep` is unused.
    """
    tps, clas, p_scores = [], [], []
    classes, n_gts = LongTensor(range(n_classes)),torch.zeros(n_classes).long()
    with torch.no_grad():
        for xb,target in progress_bar(dl):
            output = model(xb)
            for i in range(target[0].size(0)):
                bbox_pred, preds, scores = get_predictions(output, i, detect_thresh)
                tgt_bbox, tgt_clas = unpad(target[0][i], target[1][i])
                if len(bbox_pred) != 0:
                    if len(tgt_bbox) != 0:
                        max_iou, matches = iou(bbox_pred, tgt_bbox).max(1)
                        detected = []
                        # `j`, not `i`, so the image index is not shadowed
                        for j in range_of(preds):
                            if (max_iou[j] >= iou_thresh and matches[j] not in detected
                                    and tgt_clas[matches[j]] == preds[j]):
                                detected.append(matches[j])
                                tps.append(1)
                            else: tps.append(0)
                    else:
                        # no ground truth in this image: every detection is a false positive
                        tps.extend([0] * len(preds))
                    clas.append(preds.cpu())
                    p_scores.append(scores.cpu())
                n_gts += (tgt_clas.cpu()[:,None] == classes[None,:]).sum(0)
    if not p_scores:
        # nothing detected anywhere: torch.cat([]) would raise
        return [0.] * n_classes
    tps, p_scores, clas = torch.tensor(tps), torch.cat(p_scores,0), torch.cat(clas,0)
    fps = 1-tps
    idx = p_scores.argsort(descending=True)
    tps, fps, clas = tps[idx], fps[idx], clas[idx]
    aps = []
    for cls in range(n_classes):
        tps_cls, fps_cls = tps[clas==cls].float().cumsum(0), fps[clas==cls].float().cumsum(0)
        if tps_cls.numel() != 0 and tps_cls[-1] != 0:
            precision = tps_cls / (tps_cls + fps_cls + 1e-8)
            recall = tps_cls / (n_gts[cls] + 1e-8)
            aps.append(compute_ap(precision, recall))
        else: aps.append(0.)
    return aps

# Detection table for the validation set (data.c - 1: presumably excluding the background class).
L = class_AP(learn.model, data.valid_dl, n_classes=data.c - 1)
# Per-class AP with the fastai-style reference implementation:
#L_f = compute_class_AP_f(learn.model, data.valid_dl, data.c-1)
#for ap,cl in zip(L_f, data.classes[1:]): print(f'{cl}: {ap:.6f}')
L

In [0]:
data  # display the DataBunch (was an incomplete `data.` attribute access, a SyntaxError)

In [0]:
# Verify that the outputs from Fastai and my implementations are identical
# Verify that the fastai reference model and my implementation produce identical outputs:
# print the shapes, then exact element-wise equality per feature level and for the heads.
class_pred_f, bbox_pred_f, sizes_f, acts_f = out_f
print (class_pred_f.shape, bbox_pred_f.shape, sizes_f, 'f done')

my_class_pred, my_bbox_pred, my_sizes, my_acts = out
print (my_class_pred.shape, my_bbox_pred.shape, my_sizes, 'my done')

for act_ref, act_mine in zip(acts_f, my_acts):
    same_level = torch.all(torch.eq(act_ref, act_mine))
    print (act_ref.shape, act_mine.shape, same_level)
same_cls = torch.all(torch.eq(class_pred_f, my_class_pred))
same_box = torch.all(torch.eq(bbox_pred_f, my_bbox_pred))
print (same_cls, same_box)

### Temporary

In [0]:
def tlbr2cthw(boxes):
    """Convert (N, 4) boxes from [top, left, bottom, right] to [center_0, center_1, size_0, size_1]."""
    top_left, bot_right = boxes[:, :2], boxes[:, 2:]
    return torch.cat([(top_left + bot_right) / 2, bot_right - top_left], 1)
  
def cthw2tlbr(boxes):
    """Convert (N, 4) boxes from [center_0, center_1, size_0, size_1] to [top, left, bottom, right]."""
    centers, half = boxes[:, :2], boxes[:, 2:] / 2
    return torch.cat([centers - half, centers + half], 1)

def encode_class(idxs, n_classes):
    """One-hot encode class indices into a float (len(idxs), n_classes) tensor.

    Index 0 (background / no object) becomes an all-zero row; class k >= 1 sets column k - 1.
    """
    target = idxs.new_zeros(len(idxs), n_classes).float()
    mask = idxs != 0
    # Row indices live on the same device as `idxs`: a CPU LongTensor cannot be
    # masked with a CUDA boolean mask in recent PyTorch versions.
    rows = torch.arange(len(idxs), device=idxs.device)
    target[rows[mask], idxs[mask] - 1] = 1
    return target

def bbox_to_activ(bboxes, anchors, flatten=True):
    """Return the regression targets the model should predict on `anchors` for `bboxes`.

    Both inputs are center/size boxes (last dim = [c0, c1, s0, s1]). Centers are encoded
    as offsets relative to the anchor size and sizes as log-ratios; the result is divided
    by [0.1, 0.1, 0.2, 0.2].

    If `flatten` is False, `bboxes` and `anchors` are parallel sequences and a list of
    encoded targets (one per pair) is returned.
    """
    if not flatten:
        # Recurse on each pair (previously referenced an undefined `acts` and the
        # inverse function `activ_to_bbox`).
        return [bbox_to_activ(bb, anc) for bb, anc in zip(bboxes, anchors)]
    t_centers = (bboxes[..., :2] - anchors[..., :2]) / anchors[..., 2:]
    # 1e-8 guards against log(0) for degenerate boxes
    t_sizes = torch.log(bboxes[..., 2:] / anchors[..., 2:] + 1e-8)
    return torch.cat([t_centers, t_sizes], -1).div_(bboxes.new_tensor([[0.1, 0.1, 0.2, 0.2]]))

def intersection_f(anchors, targets):
    """Pairwise intersection areas of center/size `anchors` (A, 4) and `targets` (T, 4) -> (A, T)."""
    a_box = cthw2tlbr(anchors).unsqueeze(1)
    t_box = cthw2tlbr(targets).unsqueeze(0)
    n_a, n_t = a_box.size(0), t_box.size(1)
    a_box, t_box = a_box.expand(n_a, n_t, 4), t_box.expand(n_a, n_t, 4)
    lo = torch.max(a_box[..., :2], t_box[..., :2])
    hi = torch.min(a_box[..., 2:], t_box[..., 2:])
    overlap = torch.clamp(hi - lo, min=0)
    return overlap[..., 0] * overlap[..., 1]

def IoU_values(anchors, targets):
    """Pairwise IoU of center/size `anchors` (A, 4) and `targets` (T, 4) -> (A, T)."""
    inter = intersection_f(anchors, targets)
    area_a = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1)
    area_t = (targets[:, 2] * targets[:, 3]).unsqueeze(0)
    # 1e-8 avoids division by zero for degenerate boxes
    return inter / (area_a + area_t - inter + 1e-8)

def match_anchors_f(anchors, targets, match_thr=0.5, bkg_thr=0.4):
    """Assign each anchor a target index from its best IoU.

    Returns a LongTensor with one entry per anchor:
      >= 0 : index of the matched target (best IoU > match_thr)
      -1   : background (best IoU < bkg_thr)
      -2   : ignored (IoU in between, or no targets at all)
    """
    matches = torch.full((anchors.size(0),), -2, dtype=torch.long, device=anchors.device)
    if targets.numel() == 0:
        return matches
    best_iou, best_tgt = torch.max(IoU_values(anchors, targets), 1)
    matches[best_iou < bkg_thr] = -1
    is_match = best_iou > match_thr
    matches[is_match] = best_tgt[is_match]
    # Forcing every target onto its highest-IoU anchor is intentionally not done:
    #vals,idxs = torch.max(ious,0)
    #matches[idxs] = targets.new_tensor(list(range(targets.size(0)))).long()
    return matches

match_anchors_f(anchors, targets, match_thr=0.5, bkg_thr=0.4)  # default thresholds, spelled out

In [0]:
# Compare the fastai-style matcher with my own on boxes decoded from small random
# activations around a 3x4 grid of one-cell anchors.
size=(3,4)
anchors_np = create_grid_centres(size)
anchors = torch.from_numpy(anchors_np).float()
anchors = torch.cat([anchors, torch.tensor([2/size[0],2/size[1]]).expand_as(anchors)], 1)
activations = 0.1 * torch.randn(size[0]*size[1], 4)
bboxes_np = prediction_to_bbox (activations, anchors)
bboxes = torch.from_numpy(bboxes_np)

# Only a cell's last expression is displayed, so show both results side by side.
match_anchors_f(anchors,bboxes), match_anchors(anchors,bboxes)

In [0]:
# Compare both matchers on a 2x2 grid of unit anchors whose targets are copies of the
# anchors, plus one extra anchor centered at (-0.5, 0) with size (1, 1.8).
anchors_np = create_grid_centres((2,2))
anchors = torch.from_numpy(anchors_np).float()
anchors = torch.cat([anchors, torch.tensor([1.,1.]).expand_as(anchors)], 1)
targets = anchors.clone()
anchors = torch.cat([anchors, torch.tensor([[-0.5,0.,1.,1.8]])], 0)

# Only a cell's last expression is displayed, so show both results side by side.
match_anchors_f(anchors,targets), match_anchors(anchors,targets)

### Obsolete

In [0]:
torch.arange(16, dtype=torch.long).view(4, 4)  # 4x4 grid of flat indices (not displayed mid-cell)

def create_grid(size):
    """Create the (row, col) centers of an H x W grid of cells covering [-1, 1] x [-1, 1].

    `size` is an int (square grid) or an (H, W) tuple. Returns an (H*W, 2) float tensor in
    row-major order; a dimension of length 1 gets the single center 0.
    """
    H, W = size if isinstance(size, tuple) else (size, size)
    cols = torch.linspace(-1 + 1/W, 1 - 1/W, W) if W > 1 else torch.tensor([0.])
    rows = torch.linspace(-1 + 1/H, 1 - 1/H, H) if H > 1 else torch.tensor([0.])
    # Broadcasting replaces the deprecated torch.ger outer products and the
    # uninitialised FloatTensor buffer.
    grid = torch.stack([rows[:, None].expand(H, W), cols[None, :].expand(H, W)], dim=-1)
    return grid.reshape(-1, 2)


def create_anchors(sizes, ratios, scales, flatten=True):
    """Build anchors [center_row, center_col, h, w] for every (h, w) feature-map size in `sizes`.

    Each grid cell gets len(ratios) * len(scales) anchors. With `flatten` the result is one
    (N, 4) tensor; otherwise a list of (h, w, n_anchors, 4) tensors, one per size.
    """
    # relative (height, width) for each ratio/scale pair, ratio-major order
    aspects = torch.tensor([[[s * math.sqrt(r), s * math.sqrt(1 / r)] for s in scales]
                            for r in ratios]).view(-1, 2)
    n_asp = aspects.size(0)
    per_level = []
    for h, w in sizes:
        # scale by 4 grid cells (cell size is 2/h x 2/w) so neighbouring anchors overlap
        dims = 4 * (aspects * torch.tensor([2 / h, 2 / w])).unsqueeze(0)
        centers = create_grid((h, w)).unsqueeze(1)
        n_cells = centers.size(0)
        level = torch.cat([centers.expand(n_cells, n_asp, 2), dims.expand(n_cells, n_asp, 2)], 2)
        per_level.append(level.view(h, w, n_asp, 4))
    if not flatten:
        return per_level
    return torch.cat([lvl.view(-1, 4) for lvl in per_level], 0)

# Visualise the cell centers of a 4x4 grid.
size = (4, 4)
show_anchors(create_grid(size), size)

ratios = [1/2, 1, 2]
scales = [1, 2**(-1/3), 2**(-2/3)]
# The paper used scales [1, 2**(1/3), 2**(2/3)] but also larger (600px) images, so its largest
# feature map gave anchors covering less of the image than they would here.
# Sizes are listed 16x16 first down to 1x1, matching the order the predictions come in.
sizes = [(2**i, 2**i) for i in reversed(range(5))]
anchors = create_anchors(sizes, ratios, scales)
anchors.size()

show_boxes(anchors[900:909])

In [0]:
def activ_to_bbox(acts, anchors, flatten=True):
    """Decode model activations `acts` into center/size boxes on `anchors`.

    centers = anchor_center + anchor_size * acts[..., :2]
    sizes   = anchor_size * exp(acts[..., 2:])
    No 0.1/0.2 scaling is applied here, whereas `bbox_to_activ` divides by it, so the two
    are not exact inverses.
    If `flatten` is False, `acts` and `anchors` are parallel sequences and a list is returned.
    """
    if not flatten:
        return [activ_to_bbox(act, anc) for act, anc in zip(acts, anchors)]
    centers = anchors[..., 2:] * acts[..., :2] + anchors[..., :2]
    # size offsets are the last two activations (was acts[..., :2], the center offsets)
    sizes = anchors[..., 2:] * torch.exp(acts[..., 2:])
    return torch.cat([centers, sizes], -1)
  
# Decode small random activations on a 3x4 grid of anchors, each exactly one cell
# (2/H by 2/W in [-1, 1] coordinates), and plot the resulting boxes.
size = (3, 4)
anchors = create_grid(size)
anchors = torch.cat([anchors, torch.tensor([2/size[0], 2/size[1]]).expand_as(anchors)], 1)
activations = 0.1 * torch.randn(size[0]*size[1], 4)
bboxes = activ_to_bbox(activations, anchors)

show_boxes(bboxes)

In [0]:
def cthw2tlbr(boxes):
    "Convert center/size format `boxes` to top/left bottom/right corners."
    top_left = boxes[:,:2] - boxes[:,2:]/2
    bot_right = boxes[:,:2] + boxes[:,2:]/2
    return torch.cat([top_left, bot_right], 1)

# Dedented to top level: the stray 2-space indent was an IndentationError.
def intersection(anchors, targets):
    """Pairwise intersection areas of center/size `anchors` (A, 4) and `targets` (T, 4) -> (A, T)."""
    ancs, tgts = cthw2tlbr(anchors), cthw2tlbr(targets)
    a, t = ancs.size(0), tgts.size(0)
    ancs, tgts = ancs.unsqueeze(1).expand(a,t,4), tgts.unsqueeze(0).expand(a,t,4)
    top_left_i = torch.max(ancs[...,:2], tgts[...,:2])
    bot_right_i = torch.min(ancs[...,2:], tgts[...,2:])
    sizes = torch.clamp(bot_right_i - top_left_i, min=0)
    return sizes[...,0] * sizes[...,1]

def IoU_values(anchors, targets):
    """Pairwise IoU of center/size `anchors` (A, 4) and `targets` (T, 4) -> (A, T)."""
    inter = intersection(anchors, targets)
    area_a = (anchors[:, 2] * anchors[:, 3]).unsqueeze(1)
    area_t = (targets[:, 2] * targets[:, 3]).unsqueeze(0)
    # 1e-8 avoids division by zero for degenerate boxes
    return inter / (area_a + area_t - inter + 1e-8)

def match_anchors(anchors, targets, match_thr=0.5, bkg_thr=0.4):
    """Assign each anchor a target index from its best IoU.

    Returns a LongTensor with one entry per anchor:
      >= 0 : index of the matched target (best IoU > match_thr)
      -1   : background (best IoU < bkg_thr)
      -2   : ignored (IoU in between, or no targets at all)
    """
    matches = torch.full((anchors.size(0),), -2, dtype=torch.long, device=anchors.device)
    if targets.numel() == 0:
        return matches
    best_iou, best_tgt = torch.max(IoU_values(anchors, targets), 1)
    matches[best_iou < bkg_thr] = -1
    is_match = best_iou > match_thr
    matches[is_match] = best_tgt[is_match]
    # Forcing every target onto its highest-IoU anchor is intentionally not done:
    #vals,idxs = torch.max(ious,0)
    #matches[idxs] = targets.new_tensor(list(range(targets.size(0)))).long()
    return matches
 
# Three hand-made center/size targets: the whole image, the top-left quadrant, and a small box.
targets = torch.tensor([[0., 0., 2., 2.],
                        [-0.5, -0.5, 1., 1.],
                        [1/3, 0.5, 0.5, 0.5]])
show_boxes(targets)

match_anchors(anchors, targets, match_thr=0.5, bkg_thr=0.4)

In [0]:
# 2x2 grid of unit anchors; targets are copies of them, then one extra anchor centered at
# (-0.5, 0) with size (1, 1.8) is appended (it is not a target).
anchors = create_grid((2, 2))
anchors = torch.cat([anchors, torch.tensor([1., 1.]).expand_as(anchors)], 1)
targets = anchors.clone()
anchors = torch.cat([anchors, torch.tensor([[-0.5, 0., 1., 1.8]])], 0)
match_anchors(anchors, targets, match_thr=0.5, bkg_thr=0.4)

### Temp

In [0]:
# Deliberately aggressive augmentations, used below to visualise their effect.
try_tfms = get_transforms(do_flip=True, flip_vert=True, max_rotate=45.,
                          max_zoom=3.2, max_lighting=.8)

def get_ex():
    "Open the example image (entry 17 of `image_dict`) used by the augmentation demos."
    return open_image(jpeg_dir / image_dict[17])

def plots_f(rows, cols, width, height, **kwargs):
    """Show a `rows` x `cols` grid of randomly augmented copies of the example image.

    `kwargs` are forwarded to `apply_tfms` (e.g. `size=224`).
    """
    _, axes = plt.subplots(rows, cols, figsize=(width, height))
    # Plain loop: the images are drawn for their side effect, nothing is collected
    # (was a list comprehension with an unused enumerate index).
    for ax in axes.flatten():
        get_ex().apply_tfms(try_tfms[0], **kwargs).show(ax=ax)
plots_f(rows=2, cols=8, width=20, height=6, size=224)

In [0]:
import pdb
import IPython.core.debugger as db

# Compare the three resize methods on the example image (224px, zero padding).
_,axs = plt.subplots(1,3,figsize=(9,3))
for rsz,ax in zip([ResizeMethod.CROP, ResizeMethod.PAD, ResizeMethod.SQUISH], axs):
    # (a leftover db.set_trace() breakpoint here stopped execution on every iteration)
    get_ex().apply_tfms([crop_pad()], size=224, resize_method=rsz, padding_mode='zeros').show(ax=ax, title=rsz.name.lower())

In [0]:
# Multi-label classification DataBunch: space-separated labels in column 2 of `csv_file`,
# images squished to 224px and normalised with ImageNet statistics.
tfms = get_transforms()
data = ImageDataBunch.from_csv(
    data_dir, 'VOCdevkit/VOC2007/JPEGImages',
    csv_labels=csv_file, label_delim=' ', label_col=2, ds_tfms=tfms,
    size=224, resize_method=ResizeMethod.SQUISH, bs=64,
).normalize(imagenet_stats)
data.show_batch(rows=3, figsize=(7,6))

In [0]:
# Single-box regression DataBunch: the four coordinate columns of `df` are float targets.
data2 = (ImageList
         .from_df(df, path=data_dir, folder='VOCdevkit/VOC2007/JPEGImages')
         .split_by_rand_pct()
         .label_from_df(cols=['tlr', 'tlc', 'brr', 'brc'], label_cls=FloatList)
         .transform(get_transforms(), resize_method=ResizeMethod.SQUISH, size=224)
         .databunch(bs=64)
         .normalize(imagenet_stats))
data2.show_batch(rows=3, figsize=(7,6))

In [0]:
# Bbox regression head: flatten the backbone's final feature map into 4 coordinates.
# 25088 = 512 * 7 * 7, presumably resnet34 features for 224px inputs — update if `size` changes.
head_reg4 = nn.Sequential(Flatten(), nn.Linear(25088,4))
# `accuracy` (argmax classification accuracy) is meaningless for 4 continuous coordinates;
# report mean absolute error, which matches the L1 training loss.
learn = cnn_learner(data2, models.resnet34, custom_head=head_reg4, loss_func=nn.L1Loss(), metrics=[mean_absolute_error])