# Tracking using feature maps

### Basics
**Author**: Paul Willot  
**Creation date**: 2019/01/15  
**Tested on**: Ubuntu 18.04 and OSX 10.14 (Mojave).


### Idea

The goal is to **track individual objects through time** in a video.

My attempt to solve this relies on a [pre-trained MaskRCNN model](https://github.com/facebookresearch/maskrcnn-benchmark), inferring the bounding boxes independently for each image.

As expected for such a common task, solutions have already been proposed, from simply matching bounding boxes across frames by their IOU, with a Kalman filter predicting where each box moves ([Simple Online and Realtime Tracking](https://arxiv.org/abs/1602.00763) paper), to more advanced ones that also compare learned appearance features and use smarter association heuristics ([DeepSORT](https://arxiv.org/abs/1703.07402)).
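For reference, the IOU (intersection over union) of two boxes is the area of their overlap divided by the area of their union: 1 for identical boxes, 0 for boxes that don't overlap. A minimal sketch, assuming boxes in `[x_min, y_min, x_max, y_max]` format (the format the MaskRCNN model outputs later in this notebook); `iou` is a hypothetical helper, not code from the SORT paper:

```python
def iou(box_a, box_b):
    """Intersection over union of two [x_min, y_min, x_max, y_max] boxes."""
    # Corners of the intersection rectangle
    x0 = max(box_a[0], box_b[0])
    y0 = max(box_a[1], box_b[1])
    x1 = min(box_a[2], box_b[2])
    y1 = min(box_a[3], box_b[3])
    # Clamp at zero so disjoint boxes get no intersection area
    inter = max(0.0, x1 - x0) * max(0.0, y1 - y0)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# A box shifted by half its width overlaps its previous position with IOU 1/3
print(iou([0, 0, 10, 10], [5, 0, 15, 10]))
```

SORT scores how well each detection overlaps each track's predicted box in this way; this notebook tries a distance between feature maps instead.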

But code for this already exists, so it's no fun to just reproduce it! I wanted to try something a bit different.  
So below I **extract the feature maps associated with each bounding box**, keep track of which feature maps belong to which object over time, and **compute the distance between objects across frames**.  
Objects across frames are then **paired by increasing distance** (each object can be paired only once). Leftover objects are considered new, and so are objects whose distance is above a threshold: they are considered too distant to be paired.
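This pairing rule is a greedy assignment. A minimal standalone sketch of it, assuming a distance matrix with previously tracked objects as rows and current detections as columns (the layout `compute_dist` uses below); `greedy_match` is a hypothetical name, and the notebook's own version lives in `update_obj_dict`:

```python
import numpy as np


def greedy_match(dist_mat, threshold):
    """Greedily pair rows (previous objects) with columns (current detections).

    Candidate pairs are visited by increasing distance; a pair is kept only if
    its distance is below `threshold` and neither side is already paired.
    Returns the (row, col) pairs and the unpaired columns (new objects).
    """
    _, n_new = dist_mat.shape
    # Flat indices of all cells, closest first (stable sort keeps ties in order)
    order = np.argsort(dist_mat, axis=None, kind="stable")
    used_rows, used_cols = set(), set()
    pairs = []
    for flat in order:
        row, col = divmod(int(flat), n_new)
        if dist_mat[row, col] >= threshold:
            break  # every remaining pair is even further apart
        if row in used_rows or col in used_cols:
            continue
        used_rows.add(row)
        used_cols.add(col)
        pairs.append((row, col))
    new_cols = [c for c in range(n_new) if c not in used_cols]
    return pairs, new_cols


# Greedy is not optimal: it takes the 0.1 pair first and is left with 0.9,
# a total of 1.0, while pairing (0, 1) and (1, 0) would total 0.35
dist = np.array([[0.1, 0.2], [0.15, 0.9]])
print(greedy_match(dist, threshold=1.0))  # ([(0, 0), (1, 1)], [])
```

SORT, for comparison, solves this assignment optimally (the Hungarian algorithm); an optimal solver such as `scipy.optimize.linear_sum_assignment` could replace the greedy loop, which trades optimality for simplicity.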

### Data

I used a pre-trained model, so I don't need to train anything.  
So instead of using a nice dataset, I just searched YouTube for "dashcam", copied the first interesting links, downloaded them with `youtube-dl`, and extracted some frames with `ffmpeg`.

If you ran `setup.sh`, you should now have some data under `local_data`, already split into frames.

Of course, this should work for any video (and most content, not only driving footage), so you could do something like this to create a new set:
```sh
youtube-dl -f 135 "URL" -o "set42.%(ext)s"
mkdir -p set42
ffmpeg -i "set42.mp4" -ss 00:00:10 -t 00:00:16 -vf fps=4 set42/img_%d.png
```

### Notes

See the markdown blocks for some general comments per section.

The code itself is documented and should be pretty straightforward.


### Bunch of imports

```python
import numpy as np
import cv2
from IPython.display import display

import joblib

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

from maskrcnn_benchmark.config import cfg
from predictor import COCODemo, to_image_list
import torch

import re
from pathlib import Path
from collections import OrderedDict
```

### Prepare the PyTorch model
The first time this is called, it will download about 300 MB of weights.

I'm using a pre-trained MaskRCNN implementation by Facebook Research. Despite the pace of deep learning research, it's still close to the state of the art (as of January 2019), so it will do.  
Also, it's nicely packaged :D

I set up a color palette so that tracked objects keep their color through time.  
It's a global variable in this notebook, so not very clean; with a bit more time I would put most of the code below in a class.

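```python
config_file = "./maskrcnn-benchmark/configs/caffe2/e2e_mask_rcnn_R_50_FPN_1x_caffe2.yaml"

# Load the model config and run on CPU (use "cuda" instead if a GPU is available)
cfg.merge_from_file(config_file)
cfg.merge_from_list(["MODEL.DEVICE", "cpu"])

# Detections scoring below confidence_threshold are discarded
coco_demo = COCODemo(
    cfg,
    min_image_size=800,
    confidence_threshold=0.7,
)

# Indexed by object ID (modulo its length) so each tracked object keeps its color
palette = sns.color_palette("deep", 20)
```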

### Extract feature maps
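Get the feature maps associated with each bounding box. Rather than calling `coco_demo.run_on_opencv_image`, this runs the model's backbone, RPN and ROI heads step by step, so that the per-detection feature maps produced by the ROI heads can be kept alongside the boxes.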

```python
# Caching predictions to disk because they take ~12s each to generate,
# feel free to remove it if using a GPU or a beefier machine.
mem = joblib.Memory("./joblib_cache", verbose=0)

@mem.cache
def get_pred_feat(image):
    """Make predictions on an image and keep the feature maps."""
    original_image = image.copy()

    # Set up the image: resize/normalize, then batch it
    img = coco_demo.transforms(original_image)
    image_list = to_image_list(img, coco_demo.cfg.DATALOADER.SIZE_DIVISIBILITY)
    image_list = image_list.to(coco_demo.device)

    # Predict, running each stage of the model by hand so that the
    # per-detection feature maps `x` from the ROI heads are kept
    with torch.no_grad():
        targets = None
        features = coco_demo.model.backbone(image_list.tensors)
        proposals, _ = coco_demo.model.rpn(image_list, features, targets)
        x, result, _ = coco_demo.model.roi_heads(features, proposals, targets)

    # Filter predictions by confidence, then sort by decreasing score
    scores = result[0].get_field("scores")
    keep = torch.nonzero(scores > coco_demo.confidence_threshold).squeeze(1)
    pred = result[0][keep]
    scores = pred.get_field("scores")
    _, idx = scores.sort(0, descending=True)
    pred = pred[idx]

    # Resize boxes to the original image size, given as (width, height),
    # and move them to CPU so they can be converted to NumPy later
    pred = pred.resize(original_image.shape[:-1][::-1]).to("cpu")

    # Apply the same filtering and ordering to the feature maps
    feat_map = x[keep]
    feat_map = feat_map[idx].cpu().numpy()

    return pred, feat_map
```
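The fiddly part of `get_pred_feat` is keeping boxes and feature maps aligned. The code relies on the ROI heads returning one feature map per detection, in the same order as the detections, so the confidence filter (`keep`) and the score sort (`idx`) must be applied to both. A toy NumPy sketch of that bookkeeping, with made-up scores and tiny 2x2 arrays standing in for the real feature maps:

```python
import numpy as np

# Toy stand-ins for one frame: 4 detections, each with a score and a 2x2 "feature map"
scores = np.array([0.9, 0.3, 0.75, 0.95])
features = np.arange(4 * 2 * 2, dtype=float).reshape(4, 2, 2)
threshold = 0.7

# Step 1: keep detections above the confidence threshold (like `keep`)
keep = np.nonzero(scores > threshold)[0]
kept_scores = scores[keep]
kept_features = features[keep]

# Step 2: sort the kept detections by decreasing score (like `idx`)
idx = np.argsort(-kept_scores, kind="stable")
sorted_scores = kept_scores[idx]
sorted_features = kept_features[idx]

# keep[idx] maps each final detection back to its row in the raw output,
# so score i and feature map i still describe the same detection
original_rows = keep[idx]
print(original_rows)  # [3 0 2]
```

Printing `keep[idx]` like this is a cheap sanity check when debugging the real function.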

### Two ways to compute distance

`compute_dist` is just the (Euclidean) norm of the difference between two feature maps, computed for every pair of tracked object and current detection.

`compute_bbox_dist` simply measures the distance between the centers of two bounding boxes.

It turns out that the distance between feature maps alone was enough for tracking to work, but I didn't know that beforehand.
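Both distances can also be computed for all pairs at once with NumPy broadcasting instead of nested Python loops. A sketch under two assumptions: feature maps are compared as flattened vectors (which is what `np.linalg.norm` does on a difference of arrays when no axis is given), and a box center is the midpoint of its corners in `[x_min, y_min, x_max, y_max]` format. The function names are hypothetical:

```python
import numpy as np


def _flatten(x):
    """Keep the first axis (one row per object), flatten everything else."""
    return x.reshape(x.shape[0], int(np.prod(x.shape[1:])))


def pairwise_feature_dist(prev_feats, new_feats):
    """Euclidean distance between every previous and every new feature map.

    Returns an (n_prev, n_new) matrix, like the nested loops in `compute_dist`.
    """
    a = _flatten(np.asarray(prev_feats, dtype=float))
    b = _flatten(np.asarray(new_feats, dtype=float))
    # (n_prev, 1, d) - (1, n_new, d) broadcasts to (n_prev, n_new, d)
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)


def box_centers(boxes):
    """Centers of [x_min, y_min, x_max, y_max] boxes, shape (n, 2)."""
    boxes = np.asarray(boxes, dtype=float)
    return (boxes[:, :2] + boxes[:, 2:]) / 2


def pairwise_center_dist(prev_centers, new_boxes):
    """Euclidean distance between stored centers and the centers of new boxes."""
    prev_centers = np.asarray(prev_centers, dtype=float)
    diff = prev_centers[:, None, :] - box_centers(new_boxes)[None, :, :]
    return np.linalg.norm(diff, axis=-1)


# Same center for the first box, a (3, 4) offset for the second
print(pairwise_center_dist([[5, 10]], [[0, 0, 10, 20], [2, 6, 14, 22]]))  # [[0. 5.]]
```

The broadcast difference has shape `(n_previous, n_current, n_values)`, so memory grows with the product of the object counts times the feature map size; with many objects per frame, looping over rows is the safer choice. Note also that a box center is `(x_min + x_max) / 2`, whereas `(x_max - x_min) / 2` is half the box width.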

```python
def compute_dist(dict_objs: dict, features: np.ndarray):
    """Compute the distance between previous features and current ones.

    Returns a (n_previous_objects, n_current_detections) matrix.
    """
    dist_mat = np.zeros((len(dict_objs), features.shape[0]))

    for idxi, v in enumerate(dict_objs.values()):
        prev_feat = v["features"]
        for idxj, feat in enumerate(features):
            # Euclidean norm of the (flattened) difference
            dist_mat[idxi, idxj] = np.linalg.norm(prev_feat - feat)

    return dist_mat.round(2)

def compute_bbox_dist(dict_objs: dict, pred: "BoxList"):
    """Compute the distance between stored bbox centers and current ones."""
    bboxes = pred.bbox.data.numpy()
    dist_mat = np.zeros((len(dict_objs), bboxes.shape[0]))

    for idxi, v in enumerate(dict_objs.values()):
        prev_center = v["center"]
        for idxj, bbox in enumerate(bboxes):
            # bbox is [x_min, y_min, x_max, y_max]; its center is the midpoint
            assert bbox[2] > bbox[0]
            center = np.array([
                (bbox[0] + bbox[2]) / 2,
                (bbox[1] + bbox[3]) / 2,
            ])
            # Euclidean distance between the two centers
            dist_mat[idxi, idxj] = np.linalg.norm(prev_center - center)

    return dist_mat.round(2)

def merge_dist(
    dist_features: np.ndarray, dist_bbox: np.ndarray, alpha: float):
    """Rescale and merge distances.

    Alpha controls the weight of each: 1 for distance from features only.
    Scaling constants are hardcoded for quick prototyping.
    """
    a = dist_features / 200
    b = dist_bbox / 50
    return a * alpha + b * (1 - alpha)
```

### Keeping track of objects

Main functions to keep track of objects over time.

It's just a big dictionary keyed by object ID, where IDs are simply assigned in the order objects appear.

```python
def init_obj_dict(pred: "BoxList", features: np.ndarray):
    """Initialize a dict with all objects from current frame."""
    labels = pred.get_field("labels")
    # bbox are each: [x_min, y_min, x_max, y_max]
    bboxes = pred.bbox.data.numpy()
    current_objs = {
        k: {
            # age is to keep track of how long the object
            # has not been seen, and discard old ones
            "age": 0,
            "features": v,
            "color": palette[k % len(palette)],
            "label": coco_demo.CATEGORIES[int(labels[k])],
            # center of bbox: midpoint of its corners
            "center": np.array([
                (bboxes[k][0] + bboxes[k][2]) / 2,
                (bboxes[k][1] + bboxes[k][3]) / 2,
            ])
        } for k, v in enumerate(features)}
    return OrderedDict(sorted(current_objs.items()))

def update_obj_dict(
    dict_objs: dict,
    features: np.ndarray,
    dist_mat: np.ndarray,
    pred: "BoxList",
    age_limit: int = 2):
    """Update the dict of objects. Also return objects to plot at this frame.

    This method is a bit big and should be divided but well...
    """

    # Threshold for distance with previous objects, above is considered new
    threshold = 1.

    # All (distance, current index, previous index) pairs, closest first.
    # dist_mat rows are previous objects, columns are current detections.
    n_prev, n_new = dist_mat.shape
    candidates = sorted(
        ((dist_mat[b, a], a, b) for b in range(n_prev) for a in range(n_new)),
        key=lambda t: t[0],
    )

    # Greedily pair objects by increasing distance, each object at most once
    unmatched_prev = set(range(n_prev))
    unmatched_new = set(range(n_new))
    match_list = []  # (current index, previous index)
    for dist, a, b in candidates:
        if dist >= threshold:
            # Every remaining pair is too distant
            break
        if a in unmatched_new and b in unmatched_prev:
            unmatched_new.remove(a)
            unmatched_prev.remove(b)
            match_list.append((a, b))

    new_labels = pred.get_field("labels")
    new_bboxes = pred.bbox.data.numpy()

    # Prepare new dictionary of updated objects in current frame
    new_dict = OrderedDict()
    _keys = list(dict_objs.keys())
    to_plot = [(a, _keys[b]) for a, b in match_list]
    # Objects still here in current frame: object key -> current index
    still_here = {_keys[b]: a for a, b in match_list}

    for k, v in dict_objs.items():
        age = v["age"]

        # If the object is still in the frame, reset its age
        if k in still_here:
            age = -1
            new_id = still_here[k]
            # Update feature map
            v["features"] = features[new_id]
            # Update bbox center
            _bbox = new_bboxes[new_id]
            v["center"] = np.array([
                (_bbox[0] + _bbox[2]) / 2,
                (_bbox[1] + _bbox[3]) / 2,
            ])
            # Update label
            v["label"] = coco_demo.CATEGORIES[int(new_labels[new_id])]

        # Don't keep the object if it has been missing for too long
        if age > age_limit:
            continue

        # Increment all ages (matched objects go back to 0)
        v["age"] = age + 1
        new_dict[k] = v

    # Add new objects
    biggest_key = max(new_dict.keys(), default=0)
    for k in sorted(unmatched_new):
        biggest_key += 1
        new_dict[biggest_key] = {
            "age": 0,
            "features": features[k],
            "color": palette[biggest_key % len(palette)],
            "label": coco_demo.CATEGORIES[int(new_labels[k])],
            "center": np.array([
                (new_bboxes[k][0] + new_bboxes[k][2]) / 2,
                (new_bboxes[k][1] + new_bboxes[k][3]) / 2,
            ])
        }
        to_plot.append((k, biggest_key))

    # Object keys ordered like the detections in `pred`
    to_plot = [x[1] for x in sorted(to_plot, key=lambda x: x[0])]
    return new_dict, to_plot
```
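The `age` bookkeeping in `update_obj_dict` is what lets an object vanish for a few frames (occlusion, a missed detection) and still keep its ID. A minimal sketch of just that rule on plain dicts with made-up IDs; `age_objects` is a hypothetical helper that mirrors the loop above (a match resets the age to 0, otherwise it grows by 1, and an object whose age already exceeds `age_limit` is dropped), except that it returns copies instead of mutating the objects in place:

```python
def age_objects(objects, matched_ids, age_limit=2):
    """Return the objects kept for the next frame, with updated ages.

    objects: dict mapping object ID -> {"age": int, ...}
    matched_ids: IDs paired with a detection in the current frame
    """
    kept = {}
    for obj_id, obj in objects.items():
        # A match sets the age to -1 so the +1 below brings it back to 0
        age = -1 if obj_id in matched_ids else obj["age"]
        # Forget objects that have been missing for too many frames
        if age > age_limit:
            continue
        # Copy instead of mutating the caller's dict
        kept[obj_id] = {**obj, "age": age + 1}
    return kept


# Object 7 is never matched again: it survives three missed frames,
# then is dropped on the fourth
objects = {7: {"age": 0}}
for frame in range(4):
    objects = age_objects(objects, matched_ids=set())
    print(frame, objects)
```

With the default `age_limit=2`, an unmatched object therefore survives three missed frames; raising the limit tolerates longer occlusions, at the cost of more stale objects competing for matches.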

### Plotting, loading and saving

Convenience methods to plot and save figures.

```python
def natsort(l, key=lambda x: x):
    """Alphanumeric sort, for convenient listing of files"""
    r = re.compile("([0-9]+)")

    def alphanum_key(s):
        return [int(c) if c.isdigit() else c.lower() for c in r.split(key(s))]

    return sorted(l, key=alphanum_key)

def fill_bbox(bbox, ax, img, color, crop=False, alpha=0.6, label=None):
    """Draw a filled, semi-transparent bounding box on `ax`."""
    # bbox is [x_min, y_min, x_max, y_max]
    t = bbox.data.numpy().round().astype(int)
    x0, y0 = t[0], t[1]
    x1, y1 = t[2], t[3]
    if crop:
        # Clip the box to the image borders
        x1 = min(img.shape[1], x1)
        y1 = min(img.shape[0], y1)
    ax.fill([x0, x0, x1, x1], [y0, y1, y1, y0], alpha=alpha, color=color, label=label)

def plot_bbox(img, pred, meta, ids):
    """Plot the image with one colored box per detection.

    `ids[i]` is the tracked object ID of the i-th box in `pred`.
    """
    fig, ax = plt.subplots(figsize=(16, 6))
    colors = []
    labels = []
    for i in ids:
        colors.append(meta[i]["color"])
        labels.append("{:>3}, {:<12}".format(i, meta[i]["label"]))
    # OpenCV loads images as BGR, matplotlib expects RGB
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    for idx, bbox in enumerate(pred.bbox):
        fill_bbox(bbox, ax, img, colors[idx], label=labels[idx])

    # Sort the legend by object ID for readability
    hand, lab = ax.get_legend_handles_labels()
    if len(hand) > 0:
        # sort both labels and handles by labels
        lab, hand = zip(*sorted(zip(lab, hand), key=lambda t: int(t[0].split(",")[0])))
        # Crop legends too long
        hand, lab = hand[:20], lab[:20]
        ax.legend(hand, lab, loc=6, bbox_to_anchor=(1., 0.5), frameon=False)

    return fig, ax

def save_fig(
    fig, ax, filename, save_dir="./results/"
):
    """Save a figure and create directory if necessary."""
    root = Path(save_dir)
    root.mkdir(parents=True, exist_ok=True)
    path = root.joinpath(filename)
    # Without an extension, matplotlib saves in its default format (PNG)
    fig.savefig(str(path), bbox_inches="tight", pad_inches=0.1)
    plt.close(fig)
```
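`natsort` matters because `ffmpeg` numbers the frames without zero padding (`img_1.png`, `img_2.png`, up to `img_10.png` and beyond), and a plain string sort would put `img_10` before `img_2`. A quick standalone check of the same splitting idea; `natural_key` is a hypothetical name for what `alphanum_key` does inside `natsort`:

```python
import re


def natural_key(s):
    """Split digit runs from the rest, e.g. "img_10" -> ["img_", 10, ""]."""
    return [int(c) if c.isdigit() else c.lower() for c in re.split(r"([0-9]+)", s)]


names = ["img_10", "img_2", "img_1"]
print(sorted(names))                   # ['img_1', 'img_10', 'img_2']
print(sorted(names, key=natural_key))  # ['img_1', 'img_2', 'img_10']
```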

### Predicting

Now we can just put it together and make our predictions.

```python
def predict_and_update(obj_di: dict, img_path: str, alpha=0.5):
    """Convenience method to init or update objects dict, and make predictions."""

    # Get the image and predict
    img = cv2.imread(str(img_path))
    pred, fm = get_pred_feat(img)

    if obj_di is None:
        # Initialize with objects from current frame
        obj_di = init_obj_dict(pred, fm)
        to_plot = list(range(len(obj_di)))
    else:
        # Compare with previous frames
        dist_features = compute_dist(obj_di, fm)
        dist_bbox = compute_bbox_dist(obj_di, pred)
        dist_mat = merge_dist(dist_features, dist_bbox, alpha=alpha)
        obj_di, to_plot = update_obj_dict(obj_di, fm, dist_mat, pred)
    return img, pred, obj_di, to_plot
```

```python
# No objects to start with
obj_di = None

# Location of the directory containing the images.
# To use another set, swap in one of the commented-out paths below.
img_dir = Path("./local_data/set1/")
# img_dir = Path("./local_data/set2/")
# img_dir = Path("./local_data/set3/")

# Iterate over all frames in order, and save the predictions
for img_path in natsort(img_dir.glob("*.png"), key=lambda x: x.stem):
    img, pred, obj_di, to_plot = predict_and_update(obj_di, img_path, alpha=0.8)

    fig, ax = plot_bbox(img, pred, obj_di, to_plot)

    save_fig(fig, ax, img_path.stem, save_dir=Path("./results").joinpath(
        img_path.parent.name))
```

### Result

Clearly, the results are far from perfect, but tracking does work in most simple cases.  
It works in the sense that objects whose orientation doesn't change drastically are matched from frame to frame and keep their ID.

I played with the `alpha` parameter a bit: ignoring the distance between bounding boxes (setting `alpha=1`) doesn't make much difference, so the feature maps seem to be enough for simple tracking.
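Concretely, `merge_dist` combines the two distances for a tracked object $i$ and a current detection $j$ as below, with $d^{\text{feat}}_{ij}$ from `compute_dist`, $d^{\text{box}}_{ij}$ from `compute_bbox_dist`, and the hardcoded scales 200 and 50:

$$
d_{ij} = \alpha \, \frac{d^{\text{feat}}_{ij}}{200} + (1 - \alpha) \, \frac{d^{\text{box}}_{ij}}{50}
$$

A pair can only be accepted if $d_{ij} < 1$, the threshold in `update_obj_dict`. With $\alpha = 1$ this reduces to $d^{\text{feat}}_{ij} < 200$. With the $\alpha = 0.8$ used in the loop above, both weights happen to be equal, since $0.8 / 200 = 0.2 / 50 = 0.004$, so the condition becomes $d^{\text{feat}}_{ij} + d^{\text{box}}_{ij} < 250$.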

Note: I generated the GIF using this command:
```sh
ffmpeg -f image2 -framerate 4 -y -i ./results/set1/img_%d.png set1.gif
```

Again, if you ran `setup.sh`, you should already have some results under `./results`.