# Tracking of YeastMate detections over timeseries

By default, YeastMate processes every 2D input image independently. If you want to specifically analyze time series, this notebook provides the code to match cells across frames via simple overlap-based tracking.

## 1) Function definitions

Run the cell below to define the function we will be using to track our results: ```track_yeastmate_results(masks, jsons, min_iou=0.5)```. It expects a list of masks and a list of detection JSON-dicts (one entry per frame, in the same order) and matches labels between consecutive frames whose overlap (intersection over union) is above the minimum ```min_iou```.

In [None]:
from itertools import count
from operator import sub
from copy import deepcopy

import numpy as np
from skimage.metrics import contingency_table
from skimage.segmentation import relabel_sequential
from scipy.optimize import linear_sum_assignment

def get_mask_ious(mask1, mask2):
    """Compute the pairwise intersection-over-union of the labels in two masks.

    Both masks are relabelled sequentially first, so row ``i`` / column ``j``
    of the result refer to sequential label ``i + 1`` of ``mask1`` and
    sequential label ``j + 1`` of ``mask2``, not to the original label values.

    Returns a 2D array of IoU values with the background (label 0) row and
    column removed.
    """
    # relabel so that the object labels of each mask start at 1
    seq1 = relabel_sequential(mask1)[0]
    seq2 = relabel_sequential(mask2)[0]

    # number of pixels shared by every (label in seq1, label in seq2) pair
    overlap = contingency_table(seq1, seq2).toarray()

    # the contingency table of a mask with itself holds the label areas on its diagonal
    areas1 = np.diag(contingency_table(seq1, seq1).toarray())
    areas2 = np.diag(contingency_table(seq2, seq2).toarray())

    # |A or B| = |A| + |B| - |A and B| for every label pair
    total = np.add.outer(areas1, areas2) - overlap
    ious = overlap / total

    # first row and column belong to the background label 0
    return ious[1:, 1:]

def match_ious(ious, min_iou):
    """Find the best one-to-one matching of rows to columns in an IoU matrix.

    Parameters
    ----------
    ious : 2D array-like of pairwise IoU values (rows vs. columns).
        The input is not modified.
    min_iou : matches must have an IoU strictly greater than this value.

    Returns
    -------
    (rows, cols) : index arrays of equal length; ``rows[k]`` is matched to
        ``cols[k]``. Rows / columns without an accepted match are left out.
    """
    # work on a copy so that the caller's matrix is left untouched
    scores = np.array(ious, dtype=float)

    # pairs that can never be accepted must not steer the assignment
    scores[scores <= min_iou] = 0
    rows, cols = linear_sum_assignment(scores, maximize=True)

    # the assignment always pairs up as many rows/cols as possible,
    # so drop the pairs that do not actually overlap enough
    valid = scores[rows, cols] > min_iou
    return rows[valid], cols[valid]

def relabel_mask_next_frame(mask_prev, mask_next, min_iou, max_label=0):
    """Relabel ``mask_next`` so that objects matched to ``mask_prev`` inherit its labels.

    Objects of ``mask_next`` that are matched to an object of ``mask_prev``
    (IoU above ``min_iou``) receive the label of that object. All other
    objects receive new labels counting up from ``max_label + 1``, in
    ascending order of their original label.

    Parameters
    ----------
    mask_prev : label mask of the previous frame (0 = background).
    mask_next : label mask of the next frame (0 = background).
    min_iou : minimum IoU for two objects to be matched.
    max_label : highest label in use so far; new labels start above it.
        NOTE(review): callers are expected to pass a value that is at least
        the largest label of ``mask_prev``, otherwise new labels may collide.

    Returns
    -------
    (relabelled mask, updated maximum label, label_map) where ``label_map``
    maps every original label of ``mask_next`` to its new label.
    """
    # previous mask as sequential starting from 1
    # we need the inverse map indices of matched rows to values in prev. frame
    seq_mask_prev, _, inv_prev = relabel_sequential(mask_prev)
    # next mask starts at previous max value + 1
    offset = max_label + 1
    seq_mask_next, fwd_next, inv_next = relabel_sequential(mask_next, offset)

    # get matching prev -> next (indices correspond to sequential idxs in seq_masks)
    rows, cols = match_ious(get_mask_ious(seq_mask_prev, seq_mask_next), min_iou)

    # keep track of unmatched labels in next frame
    unmatched_keys = set(np.unique(mask_next))
    # discard instead of remove: a mask without any background pixel has no 0
    unmatched_keys.discard(0)

    # standard map of labels
    label_map = dict()
    for r, c in zip(rows, cols):
        # translate sequential indices back to the original labels
        label_next = inv_next[c + offset]
        label_prev = inv_prev[r + 1]
        # change map to value from previous frame
        fwd_next[label_next] = label_prev
        # also keep in standard map
        label_map[label_next] = label_prev
        unmatched_keys.remove(label_next)

    # re-index unmatched labels from start values (to prevent jumps in labels)
    # sorted: set iteration order is arbitrary, this keeps the result deterministic
    dst = count(offset)
    for src in sorted(unmatched_keys):
        v = next(dst)
        fwd_next[src] = v
        label_map[src] = v

    # apply to mask
    seq_mask_next = fwd_next[mask_next]

    return seq_mask_next, np.max([max_label, np.max(seq_mask_next)]), label_map

def match_boxes(boxes_map1, boxes_map2, min_iou, max_label=0):
    """Match the boxes of two frames by IoU and derive new labels for the second frame.

    Parameters
    ----------
    boxes_map1, boxes_map2 : dicts mapping a label (anything ``int()`` accepts)
        to a box of four numbers ``(min0, min1, max0, max1)``.
    min_iou : matches must have an IoU strictly greater than this value.
    max_label : highest label in use so far. New labels are counted up from
        above both this value and the largest key of ``boxes_map1``.

    Returns
    -------
    (label_map, max_label) : ``label_map`` maps every key of ``boxes_map2``
        (as int) to the key of its matched box in ``boxes_map1`` or, if
        unmatched, to a new label. ``max_label`` is the updated maximum.
    """
    # keys as int, in the same order as the box arrays built below
    keys1 = [int(k) for k in boxes_map1]
    keys2 = [int(k) for k in boxes_map2]

    res = dict()
    if keys1 and keys2:
        a = np.array(list(boxes_map1.values()), dtype=float)
        b = np.array(list(boxes_map2.values()), dtype=float)

        # extent of the pairwise overlap along both axes (negative if disjoint)
        overlap0 = np.minimum.outer(a[:, 2], b[:, 2]) - np.maximum.outer(a[:, 0], b[:, 0])
        overlap1 = np.minimum.outer(a[:, 3], b[:, 3]) - np.maximum.outer(a[:, 1], b[:, 1])
        # product of intersection lengths along both axes, clipped at 0
        intersection = np.maximum(overlap0, 0) * np.maximum(overlap1, 0)

        # box areas: product of (max - min) along both axes
        areas_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
        areas_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
        union = np.add.outer(areas_a, areas_b) - intersection

        # two zero-area boxes have an empty union; treat that as "no overlap"
        # instead of producing NaN, which linear_sum_assignment rejects
        iou = np.zeros_like(intersection)
        np.divide(intersection, union, out=iou, where=union > 0)

        # pairs that can never be accepted must not steer the assignment
        iou[iou <= min_iou] = 0
        rows, cols = linear_sum_assignment(iou, maximize=True)

        # map from keys2 to keys1 for valid matches
        for r, c in zip(rows, cols):
            if iou[r, c] > min_iou:
                res[keys2[c]] = keys1[r]

    # for unmatched labels, give a new, sequential label
    # start above every label that can be inherited from boxes_map1 as well,
    # so a new label can never equal an inherited one
    start = max(max_label, max(keys1, default=0))
    label_ctr = count(start + 1)
    for k2 in keys2:
        if k2 not in res:
            res[k2] = next(label_ctr)

    # return map and maximum label
    return res, np.max([max_label, max(res.values()) if res else 0])

def relabel_json_next_frame(json_prev, json_next, single_cell_map : dict, min_iou, max_label=0):
    """Relabel the detections of the next frame consistently with the previous frame.

    Compound objects (class starting with '1': mating, '2': budding) are
    matched between ``json_prev`` and ``json_next`` by box IoU, one class at
    a time. The resulting label maps are merged with ``single_cell_map`` and
    applied to ``json_next``.

    Returns the relabelled copy of ``json_next`` and the updated maximum label.
    """
    def boxes_of_class(detections, class_id):
        # boxes of all detections whose first class entry equals class_id
        return {key: det['box'] for key, det in detections.items() if det['class'][0] == class_id}

    # combined map of the relabelling to be done in json_next
    combined_map = dict(single_cell_map)

    # mating ('1') first, then budding ('2'); max_label is threaded through
    # so that both classes draw new labels from the same counter
    for class_id in ('1', '2'):
        class_map, max_label = match_boxes(
            boxes_of_class(json_prev, class_id),
            boxes_of_class(json_next, class_id),
            min_iou,
            max_label,
        )
        combined_map.update(class_map)

    # actually update the json dict
    return update_json(json_next, combined_map), max_label

def update_json(json_next, label_map):
    """Return a relabelled copy of the detections dict ``json_next``.

    The dict keys, every detection's 'id' and all entries of its 'links'
    are replaced by their new label from ``label_map``. The input dict is
    left untouched.

    NB: ids in label_map are int but json_next has str labels, the
    relabelled result uses str again.
    """
    def remap(label):
        # str -> int for the lookup, back to str for the result
        return str(label_map[int(label)])

    res = {}
    for old_key, detection in json_next.items():
        # copy, so that nested values are not shared with the input
        updated = deepcopy(detection)
        updated['id'] = remap(updated['id'])
        updated['links'] = list(map(remap, updated['links']))
        res[remap(old_key)] = updated
    return res

def track_yeastmate_results(masks, jsons, min_iou=0.5):
    """Match labels across the frames of a time series.

    Every frame is relabelled against the already relabelled previous frame:
    matched objects keep their label, all others get a new one.

    Parameters
    ----------
    masks : sequence of label masks, one per frame, in frame order.
    jsons : sequence of detection dicts, one per frame, in the same order.
    min_iou : minimum overlap (IoU) for two objects to be matched.

    Returns
    -------
    (out_masks, out_jsons) : lists of relabelled masks and detection dicts.
        The first frame is passed through unchanged.
    """
    # nothing to track in an empty time series
    if len(masks) == 0:
        return [], []

    # get maximum label in first frame; the mask is considered as well so
    # that new labels cannot collide with labels missing from the detections
    first_json_labels = [int(k) for k in jsons[0]]
    max_label = max(first_json_labels + [0, int(np.max(masks[0], initial=0))])

    # NB: we assume the first frames to be already sequentially labelled
    out_masks = [masks[0]]
    out_jsons = [jsons[0]]

    for i in range(1, len(masks)):

        # update mask for next frame
        next_mask, max_label, label_map = relabel_mask_next_frame(out_masks[-1], masks[i], min_iou, max_label)
        out_masks.append(next_mask)

        # update json for next frame: match compound objects, update labels
        json_next_updated, max_label = relabel_json_next_frame(out_jsons[-1], jsons[i], label_map, min_iou, max_label)
        out_jsons.append(json_next_updated)

    return out_masks, out_jsons

## 2) Read YeastMate output

In the cells below, we read the ```*_mask.tif``` and ```*_detections.json``` files produced by the standalone GUI for a folder containing a timeseries. **When adapting this to your own data, make sure that you list and read the files in the correct order.**

In [None]:
import json
from glob import glob

# glob patterns for the per-frame files; each '?' matches exactly one character
mask_pattern = 'C:/Users/david/Desktop/yit_ds1/yit_ds1_BF_frame???_mask.tif'
detection_pattern = 'C:/Users/david/Desktop/yit_ds1/yit_ds1_BF_frame???_detections.json'

# sort by name so that masks and detections are listed in the same order
mask_files = sorted(glob(mask_pattern))
detection_files = sorted(glob(detection_pattern))

# show the first entries of both lists to make sure files are sorted
(mask_files[:10], detection_files[:10])

In [None]:
from skimage.io import imread

# read one label mask per frame
masks = [imread(mask_file) for mask_file in mask_files]

# read the 'detections' entry of every detection file
jsons = []
for detection_file in detection_files:
    # JSON text is UTF-8; do not rely on the platform default encoding
    with open(detection_file, 'r', encoding='utf-8') as fd:
        jsons.append(json.load(fd)['detections'])

## 2) ALTERNATIVE: Predict and track from code

If you want to run the YeastMate detection from code instead, you can apply tracking directly to the predicted masks and detections.

In [None]:
from glob import glob

# glob pattern for the raw frames; each '?' matches exactly one character
frame_pattern = 'C:/Users/david/Desktop/yit_ds1/yit_ds1_BF_frame???.tif'

# make sure that the frames of your timeseries are sorted correctly
files = sorted(glob(frame_pattern))

# show the first entries to check the order
files[:10]

In [None]:
from skimage.io import imread
from yeastmatedetector.inference import YeastMatePredictor

# set up the predictor from the model configuration file
predictor = YeastMatePredictor('../models/yeastmate.yaml')

# per-frame results, collected in the same order as `files`
jsons, masks = [], []

# load every raw image and predict detections and mask with YeastMate
for frame_file in files:
    frame_detections, frame_mask = predictor.inference(imread(frame_file))
    jsons.append(frame_detections)
    masks.append(frame_mask)

## 3) Track timeseries

Once you have a list of masks and detection dicts, you can simply use ```track_yeastmate_results``` to match labels over time.

In [None]:
# match labels across all frames; min_iou=0.25 is passed explicitly here
# instead of relying on the default of track_yeastmate_results
updated_masks, updated_jsons = track_yeastmate_results(masks, jsons, min_iou=0.25)

To visualize the tracked masks quickly, you can use napari:

In [None]:
from napari import view_image

# stack the per-frame masks into one array (frames along the first axis)
mask_stack = np.stack(updated_masks)
view_image(mask_stack)

## 4) Save results

If you have loaded output from the standalone GUI, you can overwrite the mask and detection files with updated versions using the cell below:

In [None]:
from skimage.io import imsave

for mask_file, detection_file, updated_mask, updated_json in zip(mask_files, detection_files, updated_masks, updated_jsons):

    # overwrite mask file with updated version
    imsave(mask_file, updated_mask)

    # read old json again, so that everything except 'detections' is kept
    with open(detection_file, 'r', encoding='utf-8') as fd:
        json_old = json.load(fd)

    # replace 'detections' in json
    json_old['detections'] = updated_json

    # serialize first: opening the file with 'w' truncates it, so a
    # serialization error must surface before the old content is gone
    serialized = json.dumps(json_old, indent=1)

    # overwrite json with updated version
    with open(detection_file, 'w', encoding='utf-8') as fd:
        fd.write(serialized)
