# Ball detection in basketball images

We consider the task of detecting the ball in basketball images.
For this, we use a fully convolutional neural network trained against a binary target that contains `1`s on the pixels belonging to the ball and `0`s elsewhere. The heatmap output by the network is compared against this binary target.

Input image | Output heatmap | Target
--- | --- | ---
![RGB input image](https://arena-data.keemotion.com/tmp/gva/input_image.png) | ![output heatmap](https://arena-data.keemotion.com/tmp/gva/output_heatmap.png) | ![binary target](https://arena-data.keemotion.com/tmp/gva/input_target.png)

### Peak Local Maxima layer implementation
We extract the local maxima from a given heatmap, similarly to the function `skimage.feature.peak_local_max`. The layer must operate on a batch of images.

In [11]:
import tensorflow as tf
import numpy as np

class PeakLocalMax():
    def __init__(self, min_distance=20, threshold_abs=0.5):
        """
            Find peaks in a batch of images as boolean mask. Peaks are the local
            maxima in a region of 2 * min_distance + 1 (i.e. peaks are separated
            by at least min_distance).

            If there are multiple local maxima with identical pixel intensities
            inside the region defined by 'min_distance', the coordinates of all
            such pixels are returned.

            Pixels on the image border can be peaks: no border exclusion is
            applied.

            Arguments:
                - min_distance (int): Minimum number of pixels separating peaks
                in a region of 2 * min_distance + 1 (i.e. peaks are separated by
                at least min_distance). To find all the local maxima, use
                min_distance=1.
                - threshold_abs (float): Minimum intensity of peaks (inclusive).
        """
        self.min_distance = min_distance
        self.threshold_abs = threshold_abs

    def __call__(self, batch_heatmap):
        """
            Performs the peak-local-max operation on batch_heatmap.

            Arguments:
                - batch_heatmap: a float tensor (or array-like convertible to
                one, e.g. a numpy array) of shape [B,H,W] in [0,1] containing
                B images of width W and height H.
            Returns:
                Returns a boolean tensor of shape [B,H,W] with
                - True: on local maxima whose value is >= threshold_abs.
                - False: elsewhere.
        """
        # Convert up front so numpy arrays and nested lists are accepted and
        # every comparison below is a tensor op on a single dtype. Without
        # this, `list >= float` raises TypeError.
        batch_heatmap = tf.convert_to_tensor(batch_heatmap)

        # Grey-level dilation: each pixel is replaced by the maximum of its
        # (2*min_distance+1) x (2*min_distance+1) neighbourhood. max_pool2d
        # needs a channel axis; 'SAME' padding keeps the [H,W] size.
        batch_heatmap_with_channels = tf.expand_dims(batch_heatmap, axis=-1)
        window = 2 * self.min_distance + 1
        dilation = tf.nn.max_pool2d(batch_heatmap_with_channels, ksize=window, strides=1, padding='SAME')
        dilation = dilation[:, :, :, 0]  # Remove channel axis

        # A pixel is a local maximum iff it equals the max of its neighbourhood
        # (ties within a neighbourhood all count as maxima).
        local_maxima = tf.equal(batch_heatmap, dilation)
        above_threshold = batch_heatmap >= self.threshold_abs
        return tf.logical_and(local_maxima, above_threshold)

# Toy 1x4x4 heatmap. With min_distance=2, the three non-zero pixels are the
# expected peaks; zero-valued flat areas fall below threshold_abs=0.5.
# (Named `example_heatmap` rather than `input` to avoid shadowing the builtin.)
example_heatmap = np.array([[[0.5, 0, 0, 0], [0, 0, 0, 0], [0.0, 0, 0, 0], [0.76, 0, 0, 0.89]]])
layer = PeakLocalMax(min_distance=2)
print(layer(example_heatmap))


tf.Tensor(
[[[ True False False False]
  [False False False False]
  [False False False False]
  [ True False False  True]]], shape=(1, 4, 4), dtype=bool)


### Detection metric layer implementation

Let's consider the following binary maps:
- `batch_hitmap` has `1`s for each ball candidate and `0`s elsewhere (such a map has a single pixel set to `1` per candidate)
- `batch_target` has `1`s for every pixel on the ball and `0`s elsewhere (such a map contains multiple pixels set to `1` on the ball)

We compute the following detection metrics:
- 1 true-positive for each **candidate** where there is a ball;
- 1 false-positive for each **candidate** where there is no ball;
- 1 true-negative for each **image** without any ball and without any candidates;
- 1 false-negative for each **ball** that is not detected by a candidate.

In [None]:
class ComputeElementaryMetrics():
    def __init__(self):
        """
            Computes the elementary detection metrics given a map of ball
            candidates and a target map.

            The elementary metrics are:
                - 1 true-positive (TP) for each candidate where there is a ball
                - 1 false-positive (FP) for each candidate where there is no ball
                - 1 true-negative (TN) for each image without any ball and any candidates
                - 1 false-negative (FN) for each ball that is not detected by a candidate
        """
        pass

    def __call__(self, batch_hitmap, batch_target):
        """
            Performs the elementary metrics computation on the batch given
            batch_hitmap and batch_target.

            Arguments:
                - batch_hitmap: a uint8 tensor of shape [B,H,W] in {0,1}
                containing B maps of candidates of width W and height H. It has
                1s for each ball candidate and 0s elsewhere.
                - batch_target: a uint8 tensor of shape [B,H,W] in {0,1}
                containing B targets corresponding to the WxH input images. It
                has 1s where there is a ball and 0s elsewhere.

            Returns:
                A dictionary with keys "batch_TP", "batch_FP", "batch_TN" and
                "batch_FN". Each value is an int32 scalar tensor holding that
                count summed over the whole batch.

            Note:
                Balls are counted per image: an image whose target contains
                any ball pixel counts as one ball, which is a FN if no
                candidate lies on it. This assumes at most one ball per image
                (TODO confirm; several balls per image would need
                connected-component labelling of the target).
        """
        batch_hitmap_as_bool = tf.cast(batch_hitmap, tf.bool)
        batch_target_as_bool = tf.cast(batch_target, tf.bool)
        hits_on_ball = tf.logical_and(batch_hitmap_as_bool, batch_target_as_bool)

        # Candidate-level counts: every candidate pixel is either on a ball
        # (TP) or not (FP).
        batch_TP = tf.reduce_sum(tf.cast(hits_on_ball, tf.int32))
        batch_FP = tf.reduce_sum(
            tf.cast(tf.logical_and(batch_hitmap_as_bool, tf.logical_not(batch_target_as_bool)), tf.int32))

        # Image-level flags. reduce_any on booleans is used instead of summing
        # the uint8 maps, whose sum wraps modulo 256 (e.g. a 256-pixel ball
        # would otherwise look like "no ball").
        image_has_candidate = tf.reduce_any(batch_hitmap_as_bool, axis=[1, 2])
        image_has_ball = tf.reduce_any(batch_target_as_bool, axis=[1, 2])
        ball_is_detected = tf.reduce_any(hits_on_ball, axis=[1, 2])

        # TN: one per image with neither a ball nor a candidate.
        batch_TN = tf.reduce_sum(tf.cast(tf.logical_and(
            tf.logical_not(image_has_candidate),
            tf.logical_not(image_has_ball),
        ), tf.int32))

        # FN: one per ball (not per ball pixel) that no candidate hits.
        batch_FN = tf.reduce_sum(tf.cast(tf.logical_and(
            image_has_ball,
            tf.logical_not(ball_is_detected),
        ), tf.int32))

        return {
            "batch_TP": batch_TP,
            "batch_FP": batch_FP,
            "batch_TN": batch_TN,
            "batch_FN": batch_FN,
        }

### Ball Detection evaluation metrics

Using the number of **true** and **false** **positives** and **negatives** computed on each batch of data, you are asked to compute relevant detection metrics on the whole dataset (for each epoch) and print them to standard output. Please include a succinct interpretation of the metrics you selected.

For this, we provide the structure of `ComputeDetectionMetricsCallback`, which implements callbacks invoked before and after each batch and each epoch. They receive a `state` dictionary containing the current elementary metrics, as well as other state variables that are not needed here.


In [None]:
class ComputeDetectionMetricsCallback():
    """
        Accumulates the elementary detection counts (TP, FP, TN, FN) over an
        epoch and reports dataset-level detection metrics.

        Metrics are computed from the epoch totals (micro-averaging), not by
        averaging per-batch ratios. Averaging ratios weights every batch
        equally regardless of how many candidates/balls it contains, and it
        breaks on batches with no candidates or no balls (0 / 0).

        Interpretation:
            - Precision = TP / (TP + FP): fraction of ball candidates that
              actually lie on a ball. Low precision means many false alarms.
            - Recall = TP / (TP + FN): fraction of balls that are found. Low
              recall means missed balls. Since TP counts candidates,
              several candidates on the same ball each count as a TP, which
              can inflate recall.
            - F1 score: harmonic mean of precision and recall. It is high only
              when both are high.
        TN counts images (not candidates), so it does not enter these ratios.
        It is only reported as a raw count.

        A ratio whose denominator is 0 is reported as 0.0.
    """

    def __init__(self):
        self._reset()

    def _reset(self):
        """Clear the accumulated counts and the derived metrics."""
        self.total_TP = 0
        self.total_FP = 0
        self.total_TN = 0
        self.total_FN = 0
        self.num_batches = 0
        self.precision = 0.0
        self.recall = 0.0
        self.f1_score = 0.0

    @staticmethod
    def _safe_ratio(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is 0."""
        return numerator / denominator if denominator else 0.0

    def _update_metrics(self):
        """Recompute precision, recall and F1 from the accumulated totals."""
        self.precision = self._safe_ratio(self.total_TP, self.total_TP + self.total_FP)
        self.recall = self._safe_ratio(self.total_TP, self.total_TP + self.total_FN)
        self.f1_score = self._safe_ratio(
            2 * self.precision * self.recall, self.precision + self.recall)

    def on_epoch_begin(self, state):
        """ called at the beginning of each epoch """
        self._reset()

    def on_epoch_end(self, state):
        """ called at the end of each epoch """
        print(f'Precision: {self.precision}')
        print(f'Recall: {self.recall}')
        print(f'F1 score: {self.f1_score}')
        print(f'TP: {self.total_TP}, FP: {self.total_FP}, '
              f'TN: {self.total_TN}, FN: {self.total_FN}')

    def on_batch_begin(self, state):
        """ called before processing each batch """
        self.num_batches += 1

    def on_batch_end(self, state):
        """ called after processing each batch """
        # state contains "batch_TP", "batch_FP", "batch_TN", "batch_FN" for the
        # current batch. int() turns scalar tensors/arrays into plain ints so
        # the totals (and printed metrics) are ordinary Python numbers.
        self.total_TP += int(state['batch_TP'])
        self.total_FP += int(state['batch_FP'])
        self.total_TN += int(state['batch_TN'])
        self.total_FN += int(state['batch_FN'])
        self._update_metrics()
        





