In [None]:
#%pip install opencv-python
%pip install tensorflow-datasets

# Object detection challenge: `WIDER` faces
***
For this challenge, we ask you to detect faces in the `WIDER` faces dataset
* We use a stripped down version of the data to speed up the training process
* Your task is to maximize for either:
    * `precision`
    * `recall`
    * `overall`, which really is just precision + recall

A fully functioning code example is provided, as always:
* End-to-end face detection with a simple `CNN` architecture in `keras`
* Feel free to adapt the code to your liking for best results.

Post your questions and results in our discord channel!

Happy hunting!

Small note:
The following cells contain the necessary code for data preprocessing etc. Just run the cells until you reach the "Welcome back" slide ;)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# image visualization utility
def show_image_with_bbs(img, bbs):
    '''Show `img` with the bounding boxes `bbs` drawn as red rectangles.

    img and bbs should be numpy arrays.
    bbs should have shape [N, 4] with format [top, left, bottom, right].
        if bbs.dtype == "float32":
            we interpret bbs as relative bbs and scale them by the
            image height/width
        otherwise (e.g. int32):
            we interpret them as absolute pixel bbs
    The array passed in by the caller is never modified.
    '''
    if bbs.dtype == "float32":
        H, W = img.shape[:2]
        # Scale a copy: multiplying in place would silently convert the
        # caller's relative boxes into absolute ones, so a second call with
        # the same array would draw boxes far outside the image.
        bbs = bbs.copy()
        bbs[:, [1,3]] *= W
        bbs[:, [0,2]] *= H
        bbs = bbs.astype("int32")

    fig = plt.figure()
    ax = plt.gca()
    plt.imshow(img)
    for y1,x1,y2,x2 in bbs:
        # min/abs make the rectangle robust against swapped corners
        rect = patches.Rectangle((min(x1,x2), min(y2,y1)),
                                 abs(x2-x1), abs(y2-y1),
                                 linewidth=1, edgecolor='r',
                                 facecolor='none')
        ax.add_patch(rect)
    plt.show()

In [None]:
import numpy as np
import tensorflow as tf

import tensorflow_datasets as tfds

# image utilities for tensorflow and tf dataset

def tf_resize_img_with_bbs(img, bbs, H, W):
    '''
    Letterbox-resizes `img` to H x W (aspect ratio kept, remainder padded)
    and moves the relative bounding boxes `bbs` ([top, left, bottom, right])
    into the coordinate frame of the padded image.
    img and bbs are expected to be Tensors.
    Returns the resized uint8 image and the transformed relative bbs.
    '''
    img_shape = tf.shape(img)
    src_h, src_w = img_shape[0], img_shape[1]
    target_h = tf.cast(H, "float32")
    target_w = tf.cast(W, "float32")

    # size of the actual image content inside the padded canvas
    ratio = tf.cast(tf.minimum(H/src_h, W/src_w), "float32")
    content_w = tf.cast(src_w, "float32") * ratio
    content_h = tf.cast(src_h, "float32") * ratio

    # padding is split evenly between both sides, hence the division by 2
    offset_x = (target_w - content_w)/2.0
    offset_y = (target_h - content_h)/2.0

    resized = tf.cast(tf.image.resize_with_pad(img, H, W), "uint8")

    # relative coords -> pixels inside the content area -> shifted by the
    # padding -> relative to the padded canvas
    top = (bbs[:, 0]*content_h + offset_y)/target_h
    left = (bbs[:, 1]*content_w + offset_x)/target_w
    bottom = (bbs[:, 2]*content_h + offset_y)/target_h
    right = (bbs[:, 3]*content_w + offset_x)/target_w

    return resized, tf.stack([top, left, bottom, right], axis=1)

def tfds_resize_img_with_bbs(H, W):
    '''
    Builds a `Dataset.map` function around `tf_resize_img_with_bbs`.
    The mapped samples must be dicts with an "image" and a "bbox" entry;
    the result is a dict with exactly those two (resized) entries.
    '''
    def resize_sample(sample):
        resized_img, resized_bbs = tf_resize_img_with_bbs(
            sample["image"], sample["bbox"], H, W)
        return {"image": resized_img, "bbox": resized_bbs}

    return resize_sample

In [None]:
# tf and tfds preprocessing functions for object detection

def tf_get_output_maps(bbs, H, W, stride):
    '''
    Encodes relative bounding boxes as the two training targets of the
    detector.

    Args:
        bbs: [N, 4] float tensor of relative [top, left, bottom, right] boxes.
        H, W: height and width of the network input (ints).
        stride: factor by which the network reduces H and W (int).
    Returns:
        cmap: [H//stride, W//stride, 1] coverage map, 1.0 in every cell that
            contains the center of a box, 0.0 elsewhere.
        bb_map: [H//stride, W//stride, 4] map holding, for each covered cell,
            the box in absolute input pixels minus the pixel position of the
            cell's top-left corner. Uncovered cells are zero.
    If several box centers fall into the same cell, the cell keeps exactly
    one of these boxes.
    '''
    th, tw = H//stride, W//stride
    num_bbs = tf.shape(bbs)[0]
    cx = (bbs[:, 1] + bbs[:, 3])/2.0
    cy = (bbs[:, 0] + bbs[:, 2])/2.0
    indices = tf.cast(tf.floor(tf.stack([cy*th, cx*tw], axis=1)), "int32")
    # A center lying exactly on (or slightly beyond) the image border would
    # yield an index outside the grid, so clamp to the valid cell range.
    indices = tf.minimum(tf.maximum(indices, 0), [[th-1, tw-1]])

    # coverage map
    # scatter_nd would SUM duplicate indices, giving coverage values > 1 when
    # two centers share a cell. An update-style scatter keeps the map binary.
    cmap = tf.tensor_scatter_nd_update(tf.zeros([th, tw, 1]), indices,
                                       tf.ones((num_bbs, 1)))

    # bb map
    # get absolute bb coords wrt to input dims
    abs_bbs = bbs * tf.cast([[H,W,H,W]],"float32")
    # we want each cell in bb_map to contain the offset of the bb in
    # absolute pixels from the top left corner of the cell:
    # multiplying the cell indices by the network's stride gives the coords
    # of that corner in input image dims
    coords = tf.cast(tf.tile(indices, [1, 2]), "float32")
    coords = coords * tf.cast([[stride,stride,stride,stride]],"float32")
    # by subtracting these coords, each bb is now expressed relative to the
    # cell it is contained in. This allows the network to learn position
    # independent representations
    rel_bbs = abs_bbs - coords
    # Same reasoning as for the coverage map: summing the offsets of two
    # boxes that share a cell would produce a meaningless box, so one box
    # wins instead.
    bb_map = tf.tensor_scatter_nd_update(tf.zeros([th, tw, 4]), indices,
                                         rel_bbs)
    return cmap, bb_map

def tfds_get_output_maps(H, W, stride):
    '''
    Builds a `Dataset.map` function around `tf_get_output_maps`.
    Every sample dict must provide a "bbox" entry; the mapped sample is a
    copy of the input dict extended by "coverage" and "bbox_map".
    '''
    def add_output_maps(sample):
        coverage, bbox_map = tf_get_output_maps(sample["bbox"], H, W, stride)
        # build a new dict so the incoming sample is left untouched
        return {**sample, "coverage": coverage, "bbox_map": bbox_map}

    return add_output_maps

def tf_bbs_from_output_maps(coverage, bbmap, H, W, stride, threshold = 0.5):
    '''
    Decodes bounding boxes from a coverage map and a bounding box map.

    Given coverage [th x tw x 1] and bbmap [th x tw x 4] tensors,
    computes the list of bounding boxes with a coverage>=threshold.
    bbs in bbmap are expected to be pixel offsets relative to the top left
    corner of their cell (cell index * stride).
    Also supports batched args of shapes (B, th, tw, 1) and (B, th, tw, 4),
    respectively.

    `H` and `W` are accepted for signature symmetry but are not used here.

    NOTE(review): the loop below iterates with a Python `range` over the
    batch size taken from `tf.shape`, which presumably requires eager
    execution -- confirm before using this inside a graph/`Dataset.map`.

    Returns:
        bbs: [N, 4] tensor of [top, left, bottom, right] boxes in ABSOLUTE
            input-image pixels (cell index * stride + predicted offset; they
            are NOT divided by H/W), or a list thereof (for batched args)

        scores: [N, 1] tensor with the coverage value of each returned bb,
            or a list thereof (for batched args)
    '''
    # unbatched inputs are temporarily given a batch axis of size 1
    return_batched = True
    if coverage.ndim == 3:
        return_batched = False
        coverage = tf.expand_dims(coverage, 0)
        bbmap = tf.expand_dims(bbmap, 0)

    batch_size = tf.shape(bbmap)[0]
    bb_list = []
    score_list = []

    for i in range(batch_size):
        # (row, col) indices of all cells whose coverage reaches the threshold
        indices = tf.where(coverage[i,:,:,0]>=threshold)
        # coverage[i] keeps its channel axis, so scores has shape [N, 1]
        scores = tf.gather_nd(coverage[i], indices)
        score_list.append(scores)

        bb_offsets = tf.gather_nd(bbmap[i], indices)
        # to reconstruct the bb according to the input image dimensions
        # we need to compute which coords in the input image correspond
        # to the cell in the output map (aka coverage, bbmap)
        # This is done simply by multiplying the indices by the stride
        # of the network
        # (row, col) is tiled to (row, col, row, col) to line up with
        # [top, left, bottom, right]
        indices = tf.cast(tf.tile(indices, [1, 2]),"float32")
        input_coords = indices * tf.cast([[stride,stride,stride,stride]],"float32")
        bbs = input_coords + bb_offsets

        bb_list.append(bbs)

    return (bb_list,score_list) if return_batched else (bb_list[0], score_list[0])

def tfds_bbs_from_output_maps(H, W, stride, threshold=0.5):
    '''
    Builds a `Dataset.map` function around `tf_bbs_from_output_maps`.
    Every sample dict must provide "coverage" and "bbox_map" entries; the
    mapped sample is a copy extended by "bbox_pred" and "scores".
    '''
    def decode_sample(sample):
        decoded_bbs, decoded_scores = tf_bbs_from_output_maps(
            sample["coverage"], sample["bbox_map"], H, W, stride, threshold)
        # build a new dict so the incoming sample is left untouched
        return {**sample, "bbox_pred": decoded_bbs, "scores": decoded_scores}

    return decode_sample

In [None]:
# metrics for object detection

def tf_iou(bbs_true, bbs_pred):
    '''
    Pair-wise intersection-over-union between two sets of boxes.

    bbs_true has shape [n_true, 4] and bbs_pred has shape [n_pred, 4], both
    in [top, left, bottom, right] order with top<=bottom and left<=right
    (the corners are assumed to be sorted already).

    Returns an [n_pred, n_true] matrix whose entry (p, t) is the IoU of
    predicted box p and true box t. Pairs without overlap are 0.
    '''
    num_true = tf.shape(bbs_true)[0]
    num_pred = tf.shape(bbs_pred)[0]

    # expand both sets onto a common [num_pred, num_true, 4] grid so that
    # position (p, t) holds the pair (pred p, true t)
    true_grid = tf.tile(bbs_true[tf.newaxis], (num_pred, 1, 1))
    pred_grid = tf.tile(bbs_pred[:, tf.newaxis], (1, num_true, 1))

    t_top, t_left, t_bottom, t_right = [true_grid[:, :, k] for k in range(4)]
    p_top, p_left, p_bottom, p_right = [pred_grid[:, :, k] for k in range(4)]

    # the intersection rectangle is bounded by the larger top/left and the
    # smaller bottom/right edge of each pair
    i_top = tf.where(t_top > p_top, t_top, p_top)
    i_left = tf.where(t_left > p_left, t_left, p_left)
    i_bottom = tf.where(t_bottom < p_bottom, t_bottom, p_bottom)
    i_right = tf.where(t_right < p_right, t_right, p_right)

    # only pairs with a strictly positive intersection are processed further
    overlapping = tf.logical_and(i_top < i_bottom, i_left < i_right)

    def only_overlapping(values):
        return tf.boolean_mask(values, overlapping)

    inter_area = ((only_overlapping(i_bottom) - only_overlapping(i_top))
                  * (only_overlapping(i_right) - only_overlapping(i_left)))

    true_area = ((only_overlapping(t_bottom) - only_overlapping(t_top))
                 * (only_overlapping(t_right) - only_overlapping(t_left)))
    pred_area = ((only_overlapping(p_bottom) - only_overlapping(p_top))
                 * (only_overlapping(p_right) - only_overlapping(p_left)))

    # union = area(true) + area(pred) - area(intersection)
    union_area = (true_area + pred_area) - inter_area
    pair_ious = inter_area / union_area

    # write the IoUs back to their (p, t) position; every other entry of the
    # matrix stays at zero, which is exactly the IoU of non-overlapping boxes
    return tf.scatter_nd(tf.where(overlapping), pair_ious, (num_pred, num_true))

def tf_precision_recall(bbs_true, bbs_pred, scores, iou_threshold = 0.5):
    '''
        Computes precision and recall values for ground truth and predicted
        bounding box arrays (N x [t,l,b,r]) and (M x [t,l,b,r]).
        `scores` (M x 1) contains the confidence in the predicted bbs; its
        trailing axis of size 1 is squeezed away internally.
        It works as follows:
            For each ground truth bounding box (in order), find all
            predicted bbs with an iou >= iou_threshold that have not been
            matched to an earlier ground truth bb.
            If there are one or more bbs:
                The bb with the highest confidence is considered a true positive
                for the current ground truth bb and becomes unavailable for
                later ground truth bbs. The other candidates stay available.
            else:
                the groundtruth bb is considered a false negative
            After the above procedure, any predicted bounding boxes that
            were never matched are false positives.
        Returns (precision, recall) as float32 scalars; a 0/0 ratio yields 0.

        NOTE(review): the Python `range`/`if` over tensor values presumably
        requires eager execution -- confirm before tracing this in a graph.
    '''
    n_true = tf.shape(bbs_true)[0]
    n_pred = tf.shape(bbs_pred)[0]

    scores = tf.squeeze(scores, axis=-1)

    # a bool map indicating which pred bbs are still unmatched
    available = tf.ones((n_pred,), "bool")
    fns = 0 # true bbs that weren't covered (false negatives)
    tps = 0 # true bbs that were covered (true positives)

    # ious[p, t] = iou of predicted bb p and true bb t
    ious = tf_iou(bbs_true, bbs_pred)

    for i in range(n_true):
        # find all pred bbs with an iou of at least the threshold,
        # which have not been selected yet
        index = tf.logical_and(ious[:,i] >= iou_threshold, available)
        if tf.reduce_any(index):
            indices = tf.where(index)
            # find the bb among the candidates with the highest score
            # (position within the candidate list, mapped back via `indices`)
            max_conf_index = tf.argmax(scores[index])
            bb_index = indices[max_conf_index]
            # update the availability index
            available = tf.tensor_scatter_nd_update(available,[bb_index],[False])
            tps += 1
        else:
            # the true bb was not covered by any prediction :(
            fns += 1

    # pred bbs that never covered a true bb (false positives)
    fps = tf.cast(tf.math.count_nonzero(available),"float32")

    # make sure all the types match for the next calculations...
    fns = tf.convert_to_tensor(fns,"float32")
    tps = tf.convert_to_tensor(tps,"float32")

    # precision: share of the predictions that matched a true bb
    precision = tf.math.divide_no_nan(tps, tps+fps)
    # recall: share of the true bbs that were "hit" by a prediction
    recall = tf.math.divide_no_nan(tps, tps+fns)

    return precision, recall

def tf_precision_recall_batch(bbs_true, bbs_pred, scores, iou_threshold=0.5):
    '''
    Batched variant of `tf_precision_recall`.
    Expected args:
        `bbs_true`: List of [Xs,4] tensors, with varying Xs
        `bbs_pred`: List of [Ys,4] tensors, with varying Ys
        `scores`: List of [Ys,1] tensors (`tf_precision_recall` squeezes
            the trailing axis).
        All lists need to have the same length (batch_size), which is taken
        from `bbs_true`.

    See `tf_precision_recall` for the semantics of the list elements.
    The samples are evaluated one after another and the mean precision and
    mean recall over the batch are returned.
    '''
    per_sample = [
        tf_precision_recall(bbs_true[k], bbs_pred[k], scores[k],
                            iou_threshold=iou_threshold)
        for k in range(len(bbs_true))
    ]
    sample_precisions = [p for p, _ in per_sample]
    sample_recalls = [r for _, r in per_sample]
    return tf.reduce_mean(sample_precisions), tf.reduce_mean(sample_recalls)

def tf_precision_recall_from_output_maps(bbs_true,
                                         coverage_pred, bb_map_pred,
                                         H,W,
                                         coverage_threshold=0.5,
                                         iou_threshold=0.5,):
    '''
    Mean precision and recall for a batch of predicted coverage/bb_map
    pairs.
    The expected args are:
        `bbs_true`: List of [X,4] tensors containing the true bbs
        `coverage_pred`: Tensor of shape [B, th, tw, 1]
        `bb_map_pred`: Tensor of shape [B, th, tw, 4]
            Values are expected to be absolute pixel offsets from the
            original image H/W.
        `H`,`W`: Height and width of input images (int)
    The network stride is derived as H // th.
    Returns the mean precision/recall over the batch
    '''
    # TODO: write a method that takes single samples of the maps and returns
    # the sample precision/recall. This should be vectorizable!
    grid_height = tf.shape(coverage_pred)[1]
    decoded_bbs, decoded_scores = tf_bbs_from_output_maps(
        coverage_pred, bb_map_pred, H, W,
        H//grid_height, threshold=coverage_threshold)
    return tf_precision_recall_batch(bbs_true, decoded_bbs, decoded_scores,
                                     iou_threshold=iou_threshold)

def tf_precision_recall_curve(bbs_true, bbs_pred, scores, thresholds=None, iou_threshold=0.5):
    '''
    Computes the PR curve for the given true and predicted bounding boxes [NO BATCHES].
    `scores` has shape (M x 1) and holds the confidence of each predicted bb.
    `thresholds` is a sequence of score cut-offs; for each cut-off only the
        predictions with score >= cut-off are evaluated.
        If None is given, every distinct score (plus 0 and 1) is used,
        in ascending order.
    `bbs_pred` should contain all bounding boxes with scores within min and max thresholds.
    Returns precision, recall (one entry per threshold) and thresholds
    '''
    flat_scores = tf.squeeze(scores, -1)

    if thresholds is None:
        distinct, _ = tf.unique(tf.concat([[0,1], flat_scores], axis=0))
        thresholds = tf.sort(distinct)

    curve_points = []
    for cutoff in thresholds:
        keep = flat_scores >= cutoff
        curve_points.append(
            tf_precision_recall(bbs_true, bbs_pred[keep],
                                tf.expand_dims(flat_scores[keep], -1),
                                iou_threshold))

    precision = tf.stack([p for p, _ in curve_points])
    recall = tf.stack([r for _, r in curve_points])
    return precision, recall, thresholds


In [None]:
# custom losses

def masked_loss(loss):
    '''
    Wraps `loss` so that predictions only count where y_true is non-zero:
    wherever y_true == 0 the prediction is replaced by 0 before `loss`
    is evaluated.
    Handy when labels exist only for the positive class (e.g. bounding
    boxes). Don't expect this to work well for classification though.
    '''
    def l(y_true, y_pred):
        keep = tf.cast(tf.not_equal(y_true, 0.), "float32")
        return loss(y_true, keep * y_pred)
    return l

In [None]:
# evaluation and validation code


def evaluate_model(dataset, model, coverage_threshold=0.5, iou_threshold=0.5):
    '''
    Runs evaluation on the given dataset and model.
    `dataset` must be from the `tfds_wider_test_pipeline`, i.e. it yields
    unbatched (image, bbs_true, (coverage, bb_map)) tuples.

    Besides the compiled losses/metrics of the model, the mean per-image
    precision and recall of the decoded bounding boxes are computed.

    Args:
        coverage_threshold: minimal coverage for a cell to emit a box.
        iou_threshold: minimal IoU for a predicted box to match a true box.
    Returns:
        dict with "precision", "recall" and one entry per model metric.

    Relies on `tqdm` being imported in the notebook before this is called.
    '''
    H,W = model.input_shape[1:3]
    # the first model output is the coverage map of shape (B, th, tw, 1)
    th = model.output_shape[0][1]
    stride = H//th

    model.reset_metrics()
    ps, rs = [], []
    for x, bbs_true, (cov, bbm) in tqdm(dataset):
        # the true bbs are in relative coords, so we convert them to the
        # absolute pixel coords the decoded predictions use
        bbs_true = bbs_true * tf.cast([[H,W,H,W]],"float32")
        # we need to create batches of size 1 for inference...
        x = tf.expand_dims(x, 0)
        cov = tf.expand_dims(cov,0)
        bbm = tf.expand_dims(bbm,0)
        y = (cov, bbm)

        y_pred = model(x, training=False)
        cov_pred, bbm_pred = y_pred

        # evaluate losses and metrics
        if model.compiled_loss is not None:
            model.compiled_loss(y, y_pred, regularization_losses=model.losses)
        if model.compiled_metrics is not None:
            model.compiled_metrics.update_state(y, y_pred)

        # calculate precision and recall
        with tf.device("/cpu:0"):
            bbs_pred, scores = tf_bbs_from_output_maps(cov_pred, bbm_pred, H,W,
                                                       stride, coverage_threshold)
            p, r = tf_precision_recall(bbs_true, bbs_pred[0], scores[0], iou_threshold)
        ps.append(p); rs.append(r)

    precision = tf.reduce_mean(ps)
    recall = tf.reduce_mean(rs)

    # Collect metrics
    result_metrics = {"precision":precision, "recall":recall}
    for metric in model.metrics:
        result = metric.result()
        if isinstance(result, dict):
            # was `results` (undefined name) -> NameError for dict metrics
            result_metrics.update(result)
        else:
            result_metrics[metric.name] = result
    return result_metrics


In [None]:
# custom callbacks
from tensorflow import keras
from tqdm import tqdm

class PlotCallback(keras.callbacks.Callback):
    '''
    Visualises the model's predictions on a fixed set of images after each
    epoch: the image with its decoded boxes on the left, the predicted
    coverage map on the right.
    '''
    def __init__(self, imgs, threshold=0.5, max_bbs=50):
        '''
        `imgs`: batch of input images to predict on
        `threshold`: minimal coverage for a cell to emit a box
        `max_bbs`: at most this many boxes are drawn per image
        '''
        super().__init__()
        self.imgs = imgs
        self.threshold = threshold
        self.max_bbs = max_bbs

    def on_epoch_end(self, *args):
        in_h, in_w = self.model.input_shape[1:3]
        # first model output = coverage map of shape (B, th, tw, 1)
        grid_h, _ = self.model.output_shape[0][1:3]
        stride = in_h//grid_h

        cov, bbm = self.model.predict(self.imgs)
        boxes_per_img, _ = tf_bbs_from_output_maps(cov, bbm, in_h, in_w,
                                                   stride, self.threshold)

        for idx in range(len(self.imgs)):
            fig, (left, right) = plt.subplots(1,2, squeeze=True)

            # left panel: input image with the first `max_bbs` decoded boxes
            left.imshow(self.imgs[idx])
            boxes = boxes_per_img[idx].numpy()[:self.max_bbs].astype("int32")
            for y1,x1,y2,x2 in boxes:
                left.add_patch(
                    patches.Rectangle((min(x1,x2), min(y2,y1)),
                                      abs(x2-x1), abs(y2-y1),
                                      linewidth=1, edgecolor='r',
                                      facecolor='none'))

            # right panel: raw coverage prediction
            heatmap = right.imshow(cov[idx,:,:,0])
            fig.colorbar(heatmap, ax=right)
            plt.show()



class MetricCallback(keras.callbacks.Callback):
    '''
    A special callback for customized metric calculation.
    It serves two purposes:
        1. To compute precision and recall, we need both
        coverage and bb_map layer outputs as inputs to a
        single metric function. This is not supported by
        keras natively. Hence, we do this in this callback.
        2. Doing 1. would result in two passes over the
        validation data: One for the validation losses etc.
        and one for precision/recall. As a solution, we
        also do the "standard" validation that keras would
        usually take care of. Specifically, for each metric
        of the model, we add a "val_" version of it based
        on the results of the validation data.
    Summing up, we add the following metrics:
        val_precision
        val_recall
        Plus a "val_" + X entry for every metric X reported by
        `evaluate_model` (e.g. the model's loss metrics).
    As a downside, the default keras epoch report won't
    show our additional metrics, so a separate "[VAL ...]" line is printed.
    However, they do show up in the returned History callback.
    '''

    def __init__(self, dataset, coverage_threshold=0.5, iou_threshold=0.5):
        '''
        `dataset` must be from the `tfds_wider_test_pipeline`!
        `coverage_threshold` and `iou_threshold` are forwarded to
        `evaluate_model`.
        '''
        # initialise the keras base class so its internal attributes exist
        super().__init__()
        self.dataset = dataset
        self.coverage_threshold = coverage_threshold
        self.iou_threshold = iou_threshold

    def on_epoch_end(self, epoch, logs=None):
        # keras normally passes a dict, but the signature allows None
        if logs is None:
            logs = {}

        metrics = evaluate_model(self.dataset, self.model,
                                 self.coverage_threshold,
                                 self.iou_threshold)
        val_metrics = {"val_"+k:v for k,v in metrics.items()}
        # mutating `logs` in place is what makes the values reach History
        logs.update(val_metrics)

        # create a small report
        line = " - ".join([f"{k}: {v:.4f}" for k,v in logs.items()])
        print(f"[VAL {epoch+1}/{self.params['epochs']}]", line)

In [None]:
# WIDER corpus loading & preprocessing

def tf_filter_bbs(bbs, min_h=0, max_h=1.0, min_w=0, max_w=1.0):
    '''
    Keeps only those relative bbs ([top, left, bottom, right]) whose height
    lies in [min_h, max_h] and whose width lies in [min_w, max_w]
    (bounds inclusive).
    '''
    heights = bbs[:,2] - bbs[:,0]
    widths = bbs[:,3] - bbs[:,1]

    height_ok = tf.logical_and(heights>=min_h, heights<=max_h)
    width_ok = tf.logical_and(widths>=min_w, widths<=max_w)

    return tf.boolean_mask(bbs, tf.logical_and(height_ok, width_ok))

def tfds_filter_bbs(min_h=0, max_h=1.0, min_w=0, max_w=1.0):
    '''
    Builds a `Dataset.map` function around `tf_filter_bbs`.
    The mapped sample is a copy of the input dict whose "bbox" entry only
    holds the boxes within the given size constraints.
    '''
    def filter_sample(sample):
        kept = tf_filter_bbs(sample["bbox"], min_h, max_h, min_w, max_w)
        # build a new dict so the incoming sample is left untouched
        return {**sample, "bbox": kept}

    return filter_sample

def tfds_wider_train_pipeline(ds, H, W, stride, batch_size):
    '''
    The pipeline to prepare the wider corpus for training.
    Keeps only images with 1 to 3 faces (images without faces or with
    more than 3 faces are dropped)
    Resizes the images to uniform height and width
    Converts the bounding boxes to coverage and bb maps
    Caches, batches and prefetches the pipeline.
    Returns a dataset of tuples (image, (coverage, bb_map))
    '''
    def has_one_to_three_faces(sample):
        # Explicit TF ops instead of Python `len(...)`/`and`: those only work
        # on tensors when AutoGraph manages to convert the function, and the
        # number of boxes is not known statically.
        n_faces = tf.shape(sample["bbox"])[0]
        return tf.logical_and(n_faces > 0, n_faces < 4)

    ds = ds.map(lambda x: {"image":x["image"], "bbox": x["faces"]["bbox"]})
    ds = ds.filter(has_one_to_three_faces)
    ds = ds.map(tfds_resize_img_with_bbs(H, W))
    ds = ds.map(tfds_get_output_maps(H,W, stride))

    # keras expects (inputs, targets); the raw bbs are not needed for training
    ds = ds.map(lambda x: (x["image"], (x["coverage"],x["bbox_map"])))
    # cache before batching so the preprocessing runs only once
    ds = ds.cache()
    ds = ds.batch(batch_size)
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

def tfds_wider_test_pipeline(ds, H, W, stride, batch_size):
    '''
    The pipeline to prepare the wider corpus for testing.

    Keeps only images with 1 to 3 faces (images without faces or with
    more than 3 faces are dropped)
    Resizes the images to uniform height and width
    Converts the bounding boxes to coverage and bb maps
    Caches and prefetches the pipeline. (no batching! `batch_size` is
    accepted for signature parity with the train pipeline but unused)

    Returns tuple of (image, bbs_true, (coverage, bb_map))
    '''
    def has_one_to_three_faces(sample):
        # Explicit TF ops instead of Python `len(...)`/`and`: those only work
        # on tensors when AutoGraph manages to convert the function, and the
        # number of boxes is not known statically.
        n_faces = tf.shape(sample["bbox"])[0]
        return tf.logical_and(n_faces > 0, n_faces < 4)

    ds = ds.map(lambda x: {"image":x["image"], "bbox": x["faces"]["bbox"]})
    ds = ds.filter(has_one_to_three_faces)
    ds = ds.map(tfds_resize_img_with_bbs(H, W))
    ds = ds.map(tfds_get_output_maps(H,W, stride))

    # the (relative) true bbs are kept for the precision/recall computation
    ds = ds.map(lambda x: (x["image"],x["bbox"], (x["coverage"],x["bbox_map"])))
    ds = ds.cache()
    ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

    return ds

def get_wider_data(H, W, stride, batch_size, val_split=None):
    '''
    Loads the WIDER corpus as three TF Datasets:
    "train": The training data after cutting off validation
    "val": The validation data, taken from the start of the train samples
    "test": Completely independent from train and val, only for
        testing purposes. (We use WIDER validation split for testing)
    Parameters:
        `H`,`W`: Height and width for the images
        `stride`: Factor by which your network scales down H and W,
            usually through Pooling. Must be an integer!
        `batch_size`: you probably know this one..
        `val_split`: float, proportion of training samples to use
            for validation. None means no validation! Should be
            0<val_split<1 or None. The proportion is rounded to whole
            percent (at least 1%, at most 99%).
    Returns:
        (train, val, test) data if val_split is not None or
        (train, test) else
    Raises:
        ValueError: if val_split is given but not within (0, 1).
    '''
    if val_split is not None:
        if not 0 < val_split < 1:
            raise ValueError(
                f"val_split must satisfy 0 < val_split < 1 or be None, got {val_split!r}")
        # round instead of truncating: e.g. 0.29*100 == 28.999999999999996,
        # which int() alone would turn into 28%. Clamp so that neither the
        # train nor the validation slice ends up empty.
        n = min(max(int(round(val_split*100)), 1), 99)
        split = [f"train[{n}%:]",f"train[:{n}%]", "validation"]
        (ds_train, ds_val, ds_test), _ = tfds.load(
            'wider_face',
            split=split,
            shuffle_files=True,
            with_info=True,
        )
        # validation uses the test pipeline: unbatched, with the true bbs
        ds_val = tfds_wider_test_pipeline(ds_val, H, W, stride, batch_size)
    else:
        (ds_train, ds_test), _ = tfds.load(
            'wider_face',
            split=['train', 'validation'],
            shuffle_files=True,
            with_info=True,
        )
    # setup data pipelines
    ds_train = tfds_wider_train_pipeline(ds_train, H, W, stride, batch_size)
    ds_test = tfds_wider_test_pipeline(ds_test, H, W, stride, batch_size)

    if val_split is not None:
        return ds_train, ds_val, ds_test
    else:
        return ds_train, ds_test


# Welcome back
***
The pre- and post processing code ends here. Let's start with the training!

First: Specify the dataset dimensions and load the WIDER corpus

Remember: `stride` is the factor by which your network reduces the image dimensions.\
E.g. `stride=8` means that a 128x256 image becomes a 16x32 output grid

In [None]:
# input image height and width (all images are resized/padded to this size)
H, W = 128, 256
# factor by which the network reduces H and W; must match the model below
stride = 8
batch_size = 64
# proportion of the training samples that is held out for validation
val_split = 0.05
# size of the output grid (coverage / bbox maps)
th, tw = H//stride, W//stride
# val_split is not None, so three datasets are returned
ds_train, ds_val, ds_test = get_wider_data(H, W, stride, batch_size, val_split)

# Building the model
***
We'll build a pretty simple model of three blocks:
* 3x `Conv2D` followed by a `MaxPool2D`
    This is the 'meat' of the network

After that, we form the model output:
* The `bbox` layer with 4 output channels, one for each bounding box coordinate
* The `coverage` layer, which acts as a classification grid that marks the output cells containing the center of a face

Finally, we build the model and print the summary()

In [None]:
from tensorflow import keras
from tensorflow.keras.layers import (
    Conv2D, GlobalMaxPool2D, SpatialDropout2D, Dense, BatchNormalization,
    Input, MaxPool2D, AvgPool2D
)

# Each block halves the spatial resolution once, so the number of blocks
# fixes the stride of the network.
N_BLOCKS = 3
CONVS_PER_BLOCK = 3
# The target maps are built for `stride`; a mismatch would only surface later
# as a shape error during training, so fail early with a clear message.
assert 2 ** N_BLOCKS == stride, (
    f"model downsamples by {2 ** N_BLOCKS} but the data uses stride={stride}")

input_layer = Input((H,W,3), name="image")
l = input_layer
# separate throwaway loop variables: the inner loop used to shadow the
# outer loop's `i`
for _block in range(N_BLOCKS):
    for _conv in range(CONVS_PER_BLOCK):
        l = Conv2D(32, (3,3), activation="relu", padding="same")(l)
    l = MaxPool2D((2,2))(l)

l = Conv2D(64, (3,3), activation="relu", padding="same")(l)

# two heads on the shared feature map: 4 box offsets and 1 coverage score
# per output cell
bbox_layer = Conv2D(4, (3,3), activation="linear", padding="same", name="bbox")(l)
coverage_layer = Conv2D(1, (3,3), activation="sigmoid", padding="same", name="coverage")(l)

# output order (coverage, bbox) matches the (coverage, bb_map) targets
model = keras.models.Model(input_layer, [coverage_layer, bbox_layer])
model.summary()

# Preparing the training
***
To have some more insight into the training process there are two Callbacks to use
* `MetricCallback`: It computes precision and recall metrics on our validation data
* `PlotCallback`: Draws `n_plot=3` example predictions of the model after each epoch.

In [None]:
# take a few validation samples for plotting
n_plot = 3
ds_plot = ds_val.take(n_plot)
# ds_val yields (image, bbs_true, (coverage, bb_map)); only the images are
# needed here and are stacked into a single batch
imgs = np.stack([x for x,_,_ in ds_plot])

# setup callbacks
cb = [
    # computes val_precision / val_recall (and other val_ metrics) per epoch
    MetricCallback(ds_val, coverage_threshold=0.2),
    # plots the predictions for `imgs` after each epoch
    PlotCallback(imgs, threshold=0.1),
]

# Time to train
***
We need two losses to train our object detector:
* `BinaryCrossentropy`: A classification loss for our coverage layer
* `MaskedMSE`: A regression loss for bounding box prediction, that only works on the positive class
    * The mask is needed since we do not have any bounding boxes for negative classes (what would they even be?)

Then, just compile the model with an optimizer of your liking and start the training!

In [None]:
# one loss per output; the keys are the names of the model's output layers
loss = {
    "coverage": "binary_crossentropy",
    # MSE evaluated only where the target bbox map is non-zero
    "bbox": masked_loss(keras.losses.MSE),
}

opt = keras.optimizers.Adam(learning_rate=0.001)
model.compile(opt, loss)

# validation is handled by the MetricCallback in `cb`, hence no
# validation_data here
model.fit(ds_train, epochs=1, verbose=1, callbacks=cb,)

# Training visualization
***

In [None]:
# one curve per entry recorded in the training history
fig, ax = plt.subplots(figsize=(17,8))
for name, values in model.history.history.items():
    ax.plot(values, label=name)

ax.set_xlabel("Epochs")
ax.legend()
plt.show()

# Model evaluation
***
We are looking for the three best models (at `iou=0.5`):
* Max precision
* Max recall
* Max overall
    * Just add recall+precision

As always: Post your solutions in the discord!

In [None]:
# evaluate on the independent test data (WIDER validation split)
metrics = evaluate_model(ds_test, model, coverage_threshold=0.2)
print("Precision", metrics["precision"])
print("Recall", metrics["recall"])
# the "overall" challenge score is simply precision + recall
print("Overall", metrics["precision"] + metrics["recall"])

# Hints
***
Here are some areas that will certainly improve your precision + recall:
* Model architecture
    * The current model is very straightforward. Can you come up with a better one?
    * Notable architectures: ResNet, DenseNet, Hour-Glass
* Bounding box post processing
    * Non-maximum suppression: Group together strongly overlapping candidates
        * This could be used in the `evaluate_model` function!
* Image resolution
    * Could input image size and the network stride have a positive impact on performance?
* `coverage_threshold`:
    * Which threshold gives the best overall results?
    * Try the evaluation code with different values!