In [2]:
%load_ext autoreload
%autoreload 2

import os
import zarr
import dask.array as da
import napari
import sys
import numpy as np

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

from ete3 import Tree

sys.path.append('../libraries')
import family_graph as fg
from cells_database import Base, CellDB, TrackDB

In [3]:
os.environ['NAPARI_PERFMON'] = '0'

In [4]:
labels_zarr_path = r'D:\kasia\tracking\E6_exp\code\tests\example_track_labels.zarr'
labels = da.from_zarr(labels_zarr_path)
labels.shape

(241, 8396, 8401)

In [5]:
# get an example small set of labels
labels_small = labels[:10, 4000:4500, 4000:4500].compute()
labels_small.shape

(10, 500, 500)

In [6]:
# display the labels
viewer = napari.Viewer()
labels_layer = viewer.add_labels(labels,name='Labels')
fg.viewer = viewer

In [47]:
# get access to the database
new_db_path = r'D:\kasia\tracking\E6_exp\code\gardener_20_dev\cells_database_2tables - Copy.db'
engine = create_engine(f'sqlite:///{new_db_path}')

# send engine to the library
fg.engine = engine

In [8]:
# create a plot widget
t_max = viewer.dims.range[0][1]
plot_widget = fg.build_lineage_widget(t_max)
fg.plot_widget = plot_widget


# add lineage graph
viewer.window.add_dock_widget(plot_widget,area='bottom',name='family')

<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x1fc82832f80>

In [9]:
# connect lineage graph update
labels_layer.events.selected_label.connect(fg.update_lineage_display)

<function family_graph.update_lineage_display(event)>

In [10]:
import track_module as tm 
from sqlalchemy import and_
import pandas as pd
from sqlalchemy.orm import aliased

In [26]:
def get_descendants(engine,active_label):
    
    """
    Function to recursively get all descendants of a given label.
    input:
        engine - sqlalchemy engine
        active_label - label for which we want to get descendants
    output:
        descendants - list of descendants as row objects 
    """

    with Session(engine) as session:

        cte = session.query(TrackDB).filter\
        (TrackDB.track_id == active_label).cte(recursive=True)
        
        cte_alias = aliased(cte, name='cte_alias')

        cte = cte.union_all(\
            session.query(TrackDB).filter\
            (TrackDB.parent_track_id == cte_alias.c.track_id)
        )

        descendants = session.query(cte).all()

        return descendants

In [27]:
def get_track_bbox(query):

    """
    Helper function that returns bounding box of a track.
    input:
        query - list of row objects
    output:
        row_start, row_stop, column_start, column_stop- bounding box
        t_stop - last time point of the track
    """

    # find bounding boxes of the track
    row_start = min(cell.bbox_0 for cell in query)
    row_stop = max(cell.bbox_2 for cell in query)
    column_start = min(cell.bbox_1 for cell in query)
    column_stop = max(cell.bbox_3 for cell in query)
    t_stop = max(cell.t for cell in query)

    return (t_stop, row_start, row_stop, column_start, column_stop)

In [56]:
def modify_trackDB(engine,descendants,active_label,new_track,current_frame):
    
    """
    Function to modify track_id for a given label and all its descendants.
    input:
        engine - sqlalchemy engine
        descendants - descendants always starting with the active_label itself
        active_label - label for which we want to get descendants
        new_track - new track_id
        current_frame - time point of a cut
    output:
        None
    """

    with Session(engine) as session:

        # get the acual track and check what will be done
        record = session.query(TrackDB).filter_by(track_id=descendants[0].track_id).first()

        # when a track is truly cut
        if record.t_begin < current_frame:

            # add completely new track to start the new family
            track = TrackDB(track_id = new_track,\
            parent_track_id = -1,\
            root = new_track,\
            t_begin = current_frame,\
            t_end = record.t_end)

            session.add(track)

            # modify the end of the track 
            record.t_end = current_frame - 1
            
        # if just deataching from mitosis or by accident clicked on the first label that is unconnected
        else:

            # indicate that the track is now first in the family
            record.parent_track_id = -1

            # there will be no new_track entry added
            new_track = active_label

        # changes for true descendants 
        for tr in descendants[1:]:

            # Fetch the actual record from the database
            record = session.query(TrackDB).filter_by(track_id=tr.track_id).first()

            # change the value of the root track
            record.root = new_track

            if record.parent_track_id == active_label:

                 # indicate that the track is now first in the family
                record.parent_track_id = new_track

        session.commit()

In [29]:
def cut_cellsDB(engine,descendants,active_label,current_frame):

    """
    Function to change track_id in cellsDB.
    To minimize the number of queries returns the bounding box of the track
    and the value for a new track_id.
    input:
        engine - sqlalchemy engine
        descendants - list of descendants as row objects
        active_label - label for which the track is cut
        current_frame - current time point  
    output:
        new_track - new track_id
        track_bbox - bounding box of the track
    """

    with Session(engine) as session:
        
        # query CellDB 
        # order by time
        query = session.query(CellDB).filter\
        (and_(CellDB.track_id == active_label, CellDB.t >= current_frame))\
        .order_by(CellDB.t).all()

        assert len(query) > 0, 'No cells found for the given track'

        # change the parent id of the first cell
        query[0].parent_id = -1

        # the cell is actually the beginning of the track
        # only changes to the connections
        if (query[0].t == descendants[0].t_begin):

            # no need for changing the active track_id
            new_track = active_label

            # there will be no changes in the labels layer
            track_bbox = None

        else:

            # find new track number
            new_track = tm.newTrack_number(engine)

            # change track_ids for the cells
            for cell in query:
                cell.track_id = new_track

            # get the track_bbox
            track_bbox = get_track_bbox(query)

            session.commit()

    return new_track, track_bbox

In [42]:
# get the position in time
current_frame = viewer.dims.current_step[0]

# get my label
active_label = int(viewer.layers['Labels'].selected_label)

# find new track number
newTrack = tm.newTrack_number(engine)

In [52]:
print(active_label)
print(newTrack)
print(current_frame)
print(descendants)  

11430
75019
16
[(11430, 11429, 11429, 13, 47), (11432, 11430, 11429, 48, 54), (11431, 11430, 11429, 48, 182), (11433, 11432, 11429, 55, 55), (11436, 11432, 11429, 55, 80), (11435, 11433, 11429, 56, 111), (11434, 11433, 11429, 56, 180)]


In [57]:
# get descendants
descendants = get_descendants(engine,active_label)

# cut cellsDB
new_track, track_bbox = cut_cellsDB(engine,descendants,active_label,current_frame)

# cut trackDB
modify_trackDB(engine,descendants,active_label,new_track,current_frame)

In [45]:
print(new_track)
print(track_bbox)

75019
(47, 3941, 4010, 4153, 4209)


In [None]:
# modify labels if needed
if track_bbox is not None:

    sel = labels[current_frame:track_bbox[0],track_bbox[1]:track_bbox[2],track_bbox[3]:track_bbox[4]
    sel[sel == active_label] = newTrack
    labels[current_frame:track_bbox[0],track_bbox[1]:track_bbox[2],track_bbox[3]:track_bbox[4]] = sel

    viewer.layers['Labels'].data = labels

In [58]:
with Session(engine) as session:

    query = session.query(TrackDB).filter(TrackDB.root == new_track).all()

query


[Track 11431 from 48 to 182,
 Track 11432 from 48 to 54,
 Track 11433 from 55 to 55,
 Track 11434 from 56 to 180,
 Track 11435 from 56 to 111,
 Track 11436 from 55 to 80,
 Track 75019 from 16 to 15]

In [53]:
with Session(engine) as session:

    query = session.query(CellDB).filter\
        (and_(CellDB.track_id == active_label, CellDB.t >= current_frame))\
    .order_by(CellDB.t).all()

In [54]:
query

[17004006 from frame 16 with track_id 11430 at (3959,4190),
 18004011 from frame 17 with track_id 11430 at (3958,4188),
 19004042 from frame 18 with track_id 11430 at (3960,4186),
 20004162 from frame 19 with track_id 11430 at (3961,4186),
 21004244 from frame 20 with track_id 11430 at (3960,4184),
 22004265 from frame 21 with track_id 11430 at (3959,4183),
 23004320 from frame 22 with track_id 11430 at (3958,4182),
 24004298 from frame 23 with track_id 11430 at (3957,4180),
 25004413 from frame 24 with track_id 11430 at (3959,4186),
 26004425 from frame 25 with track_id 11430 at (3954,4175),
 27004476 from frame 26 with track_id 11430 at (3955,4174),
 28004499 from frame 27 with track_id 11430 at (3959,4180),
 29004646 from frame 28 with track_id 11430 at (3960,4178),
 30004605 from frame 29 with track_id 11430 at (3962,4174),
 31004688 from frame 30 with track_id 11430 at (3963,4174),
 32004702 from frame 31 with track_id 11430 at (3965,4172),
 33004717 from frame 32 with track_id 11

In [151]:
print(row_start)
print(row_stop)
print(column_start)
print(column_stop)
print(current_frame)
print(t_stop)

0
0
0
0
19
0


In [142]:
viewer.layers['Labels'].data = labels

In [110]:
with Session(engine) as session:
    
    query = session.query(CellDB).filter((CellDB.track_id == newTrack))
    df = pd.read_sql(query.statement, engine)

In [111]:
df.head()

Unnamed: 0,track_id,t,id,parent_id,row,col,bbox_0,bbox_1,bbox_2,bbox_3,mask
0,75019,71,72009349,-1,5578,4428,5561,4413,5597,4445,"[[False, False, False, False, False, False, Fa..."
1,75019,72,73009365,72009349,5570,4425,5553,4410,5588,4441,"[[False, False, False, False, False, False, Fa..."
2,75019,73,74009387,73009365,5569,4423,5551,4409,5589,4439,"[[False, False, False, False, False, False, Fa..."
3,75019,74,75009333,74009387,5563,4421,5545,4408,5583,4437,"[[False, False, False, False, False, False, Fa..."
4,75019,75,76009314,75009333,5558,4417,5540,4401,5576,4434,"[[False, False, False, False, False, False, Fa..."


### Test

In [9]:
my_root = 1

plot_view = plot_widget.getItem(0,0)
plot_view.clear()

# buid the tree
tree = fg.build_Newick_tree(engine, my_root)

In [10]:
# update the widget with the tree
plot_view = fg.render_tree_view(plot_view,tree,viewer)