In [2]:
## Imports
import os
import numpy as np
from collections import OrderedDict
import scipy.spatial.distance as dist
from anytree import Node, RenderTree
from tqdm import tqdm
import glob
import pickle

import logging

class RetroTracker:
    def __init__(self, in_dir: str, maxDist: int = 15, maxAbscence: int = 3, verbose: bool = False):
        '''
        '''
        self.nextID = 0
        self.Total_time = 0
        self.objects = OrderedDict()
        self.tracks = OrderedDict()
        self.timeframe_ids = []
        self.missing = OrderedDict()
        self.merge = OrderedDict()                  ## Save timepoint where cell must be merged
        self.maxDist = maxDist
        self.maxAbscence = maxAbscence
        self.output = OrderedDict()
        self.Tree = {}
        self.parents = []
        self.in_dir = in_dir
        self.verbose = verbose

    def register(self, centroid: list, label: str):
        '''
        Adds a cellID to the tracker
        Args:
            centroid        list containing the XY coordinates of centroid
        '''
        ID = self.nextID
        self.objects[ID] = centroid       # Save current location
        self.tracks[ID] = OrderedDict()    # Create dictionary to save previous locations
        self.tracks[ID] = [centroid]
        self.missing[ID] = 0
        self.nextID += 1
        
        return ID

    def remove(self, objectID: int):
        '''
        Removes cell from list of active objects i.e. ends tracking
        Args:
            objectID        ID of the object being removed
        '''
        del self.objects[objectID]
        del self.missing[objectID]
        self.tracks[objectID] = self.tracks[objectID][:-self.maxAbscence]
        self.merge[objectID] = len(self.tracks[objectID])

    def update(self, ids: list, positions: np.ndarray, time_frame: str = None):
        '''
        Updates the status of tracked cells by either:
            - adding a new location
            - marking them as abscent
            - Stop tracking and merging into another track

        Args:
            positions       Numpy array of XY positions with shape (n,2)
        '''
        
        InputIDs = []
        ReturnIDs = []
        self.timeframe_ids.append(time_frame)
        iterations = 0
        
        ## Check if the list of positions is empty
        if len(positions[0]) == 0:
            ## Mark existing objects as missing
            for objectID in self.missing:
                self.missing[objectID] += 1
                self.tracks[objectID].append(self.tracks[objectID][-1])

                ## If object exceeds maxAbscence, remove from tracking list
                if self.missing[objectID] > self.maxAbscence:
                    self.remove(objectID)

            ## Return as there are no more updates to be done
            if self.verbose:
                logging.info('No objects to be updated')
            return [], []
        
        ## Check if no more objects are being tracked
        elif len(self.objects.values()) == 0:
            ## Return as there are no more objects being tracked
            if self.verbose:
                logging.info('No more objects being tracked')
            return None, None

        ## Match input positions to existing objects
        else:
            ## Get existing objects
            objectIDs = list(self.objects.keys())
            objects_xy = list(self.objects.values())

            ## Compute distance between current objects and new positions
            D = dist.cdist(np.array(objects_xy), positions)

            ## Sort the smallext distance for each object
            rows = D.min(axis=1).argsort()

            ## Also sort the smallest distance of new centroids and sort by rows
            cols = D.argmin(axis=1)[rows]

            ## Keep track of used rows and columns
            usedRows = set()
            usedCols = set()

            ## Check matched objects and positions
            for (row, col) in zip(rows, cols):
                if D[row,col] < self.maxDist:
                    ## If cell or position is already assigned
                    if row in usedRows or col in usedCols:
                        continue

                    # Get the objectID
                    objectID = objectIDs[row]

                    ## If object is marked missing, skip in first iteration
                    if self.missing[objectID] > 0:
                        continue

                    ## Update existing object with new position
                    self.tracks[objectID].append(self.objects[objectID])
                    self.objects[objectID] = positions[col]
                    self.missing[objectID] = 0
                    
                    InputIDs.append(ids[col])
                    ReturnIDs.append(objectID)

                    ## Update used rows and columns
                    usedRows.add(row)
                    usedCols.add(col)

            ## Get unused rows and columns that have not been assigned yet
            unusedRows = set(range(0, D.shape[0])).difference(usedRows)
            unusedCols = set(range(0, D.shape[1])).difference(usedCols)

            ## Reexamine remaining objects and positions for potential matches, 
            ## keep examining matches until no new matches are being added
            if len(unusedCols) > 0 and len(unusedRows) > 0:
                new_matches = 1
                iterations = 2
                while new_matches > 0:
                    new_matches = 0
                    obs2 = [objects_xy[x] for x in list(unusedRows)]
                    pos2 = [positions[x] for x in list(unusedCols)]
                    
                    ## Calculate new distance tree
                    D2 = dist.cdist(obs2, pos2)

                    ## Once again order the rows
                    rows = D2.min(axis=1).argsort()
                    rowsx = [list(unusedRows)[i] for i in rows]
                    cols = D2.argmin(axis=1)[rows]
                    colsx = [list(unusedCols)[i] for i in cols]

                    ## Isolate distances for rows and cols
                    Dv = [D2[k,v] for k,v in zip(rows,cols)]

                    ## loop over the combination of the (rowsx, colsx) index tuples
                    for k, (row, col) in enumerate(zip(rowsx, colsx)):
                        if Dv[k] < self.maxDist:
                            # if we have already examined either the row or
                            # column value before, ignore it
                            if row in usedRows or col in usedCols:
                                continue

                            # otherwise, grab the object ID for the current row,
                            # set its new centroid, and reset the disappeared
                            # counter
                            objectID = objectIDs[row]
                            self.tracks[objectID].append(self.objects[objectID])
                            self.objects[objectID] = positions[col]
                            self.missing[objectID] = 0
                            
                            InputIDs.append(ids[col])
                            ReturnIDs.append(objectID)

                            # indicate that we have examined each of the row and
                            # column indexes, respectively
                            usedRows.add(row)
                            usedCols.add(col)
                            new_matches += 1
                    iterations += 1

            # List all the unused rows and columns
            unusedRows = set(range(0, D.shape[0])).difference(usedRows)
            unusedCols = set(range(0, D.shape[1])).difference(usedCols)

            ## If there are still unused rows left, mark them as missing
            for row in unusedRows:
                # grab the object ID for the corresponding row
                # index and increment the missing counter
                objectID = objectIDs[row]
                self.missing[objectID] += 1
                self.tracks[objectID].append(self.objects[objectID])

                # check to see if the number of consecutive frames the object has 
                # been marked "missing" for warrants removing the object from tracking list
                if self.missing[objectID] > self.maxAbscence:
                    self.remove(objectID)

            if self.verbose:
                logging.info(f'Updated tracks for {len(self.objects.keys())} in {iterations} iterations')
            return InputIDs, ReturnIDs
    
    def merge_tracks(self):
        '''
        Takes the list of object merges and returns a collection of lineage trees. 
        Start merging from earliest timepoint up to latest
        Args:

        '''
        merge_ID_list, timepoints = np.array(list(self.merge.keys())), list(self.merge.values())
        
        for t in np.unique(timepoints)[::-1]:
            t_now = [x == t for x in timepoints] ## Get indexes of current t
            merge_IDs = merge_ID_list[t_now]  ## Get IDs for cells merging at current t
            merge_xy = [self.tracks[x][-1] for x in merge_IDs]

            object_IDs = []
            object_xy = []
            for k in self.tracks:
                if k not in merge_IDs:
                    try:
                        object_xy.append(self.tracks[k][t])
                        object_IDs.append(k)
                    except:
                        continue
            
            if not len(object_xy) == 0:
                ## Compute distance between merge objects and tracked objects
                D = dist.cdist(np.array(merge_xy), np.array(object_xy))

                ## Sort the smallext distance for each merge object
                rows = D.min(axis=1).argsort()

                ## Also sort the smallest distance of existing objects and sort by rows
                cols = D.argmin(axis=1)[rows]

                ## Keep track of used rows and columns
                usedRows = set()
                usedCols = set()

                ## Check matched objects and positions
                for (row, col) in zip(rows, cols):
                    if D[row,col] < self.maxDist:
                        ## If cell or position is already assigned
                        if row in usedRows or col in usedCols:
                            continue

                        # Get the merge_ID and object_ID
                        mergeID = merge_IDs[row]
                        objectID = object_IDs[col]

                        ## Add to tree
                        parent = str(objectID)
                        child = str(mergeID)
                        if parent not in self.Tree:
                            self.Tree[parent] = Node(parent)
                            self.parents.append(parent)
                        self.Tree[child] = Node(child, parent= self.Tree[parent])

                        ## Update used rows and columns
                        usedRows.add(row)
                        usedCols.add(col)
                                        

    def fit(self):
        '''
        Runs the full Retrotracker workflow and returns
        '''
        
        ## First list all files in directory
        File = sorted(glob.glob(os.path.join(self.in_dir, '*.pkl')))
        
        ## Register all endpoint cells in the tracker
        cells = pickle.load(open(os.path.join(self.in_dir, File[-1]), 'rb')) ## Start with last file (assumed files are in correct order)
        t = File[-1].split('.')[0]
        self.timeframe_ids.append(t)
        self.Total_time = len(File)
        
        renumbered = []
        for id, centroid in zip(cells['id'], cells['centroid']):
            new_id = self.register(centroid=centroid, label = id)
            renumbered.append(new_id)
        cells['cell_id'] = renumbered
        
        ## Generate output dict
        output = OrderedDict()
        for k in cells['cell_id']:
            self.output[k] = {}
            x = np.where(np.array(cells['cell_id']) == k)[0][0]
            self.output[k]['id'], self.output[k]['area'], self.output[k]['centroid'] = [cells['id'][x]], [cells['area'][x]], [cells['centroid'][x]]
            self.output[k]['eccentricity'], self.output[k]['image'] = [cells['eccentricity'][x]], [cells['image'][x]]
            
        ## Save renumbered dictionary
        pickle.dump(cells, open(os.path.join(self.in_dir, File[-1]), 'wb'))
        logging.info(f'Registered initial cell positions')
        
        ## Update with new list of ids and positions
        progress = tqdm(total = len(File)-1)
        for f in File[-2::-1]:
            cells = pickle.load(open(os.path.join(self.in_dir, f), 'rb'))
            inputIDs, returnIDs = self.update(cells['id'], cells['centroid'], f.split('.')[0])
            
            if inputIDs == None:
                break
            
            ## Update output dict
            for inp, ret in sorted(zip(inputIDs, returnIDs)):
                for k in ['id', 'area', 'centroid', 'eccentricity', 'image']:
                    self.output[ret][k].append(cells[k][np.where(np.array(cells['id']) == inp)[0][0]])
            progress.update(1)
        progress.close()

        ## Resolve the tree
        self.merge_tracks()

In [33]:
mdir = '/rawa/sl/embryoscope/processed/EV25_180621_processed/segmented/Region1_YFP_tl'
pos = 'pos_14'

sdir = f'{mdir}/{pos}'

tracks_table = pd.read_csv(os.path.join(sdir, 'tracks.csv'), index_col=0)
Tree = pkl.load(open(os.path.join(sdir, 'Tree.pkl'), 'rb'))
RT = pkl.load(open(os.path.join(sdir, 'RT.pkl'), 'rb'))

print(Tree)
print(np.unique(tracks_table.label))
print(RT.merge_tracks)

merge_ID_list, timepoints = np.array(list(RT.merge.keys())), list(RT.merge.values())

In [53]:
for parent in Tree:
    if Tree[parent].is_root:
        print(parent)
        print(Tree[parent].descendants)
#         if len(Tree[parent].descendants) > 0:
#             for child in Tree[parent].descendants:

42
(Node('/42/44'),)
48
(Node('/48/46'),)
40
(Node('/40/25'),)
38
(Node('/38/36'), Node('/38/36/37'))
3
(Node('/3/5'),)
24
(Node('/24/32'),)
52
(Node('/52/56'),)


In [49]:
File = sorted(os.listdir('/rawa/sl/embryoscope/processed/EV25_180621_processed/segmented/Region1_YFP_tl/pos_14/cells'))
t = File[-1].split('.')[0]
len(File)

649

In [12]:
import pandas as pd
import pickle as pkl

mdir = '/rawa/sl/embryoscope/processed/EV25_180621_processed/segmented/Region1_YFP_tl'

for pos in os.listdir(mdir):
    sdir = f'{mdir}/{pos}'

    tracks_table = pd.read_csv(os.path.join(sdir, 'tracks.csv'), index_col=0)
    Tree = pkl.load(open(os.path.join(sdir, 'Tree.pkl'), 'rb'))
    
    print(pos)
    print(Tree)
    print(np.unique(tracks_table.label))
    print('###############')

pos_24
{'4': Node('/4'), '5': Node('/4/5'), '3': Node('/4/3')}
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
###############
pos_7
{'14': Node('/14'), '11': Node('/14/11'), '10': Node('/10'), '9': Node('/10/9')}
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18]
###############
pos_16
{}
[ 0  1  2  3  4  5  6  7  8  9 10 11]
###############
pos_30
{'17': Node('/17'), '20': Node('/17/20'), '13': Node('/13'), '12': Node('/13/12'), '22': Node('/22'), '21': Node('/22/21')}
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52]
###############
pos_11
{}
[0 1 2]
###############
pos_0
{}
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
###############
pos_23
{'2': Node('/2'), '3': Node('/2/3')}
[0 1 2 3 4 5 6]
###############
pos_29
{}
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15]
###############
pos_27
{'404': Node('/404'), '387': Node('/404/387'), '68': Node('/68'), '

In [10]:
print(Tree)
print(np.unique(tracks_table.label))

{'30': Node('/30'), '27': Node('/30/27')}
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38]
