In [4]:
%load_ext autoreload
%autoreload 2

import numpy as np
from tqdm import tqdm

import sqlalchemy as sqla
from sqlalchemy import create_engine, Column, and_
from sqlalchemy.orm import Session

from ultrack.core.database import NodeDB
from ultrack.core.export.utils import solution_dataframe_from_sql

import sys
sys.path.append('..')
from tracks_interactions.db.db_model import Base, CellDB, TrackDB
from tracks_interactions.db.db_translate_functions import add_track_ids_to_tracks_df
from tracks_interactions.db.db_functions import calculate_cell_signals

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [5]:
# Location of the new SQLite database file that this notebook fills in.
new_db_path = r'D:\kasia\tracking\E6_exp\double_segmentation_ultrack\Exp6_gardener.db'
new_db_url = f'sqlite:///{new_db_path}'

# Open an engine on that file and create every table registered on Base.metadata.
engine = create_engine(new_db_url)
Base.metadata.create_all(engine)

In [6]:
# Engine for the original database (the one the tracking solution and the
# NodeDB entries are read from).
org_db_path = r'D:\kasia\tracking\E6_exp\double_segmentation_ultrack\data.db'
org_db_url = f'sqlite:///{org_db_path}'

engine_org = create_engine(org_db_url)

In [7]:
# Read the tracking solution from the original database as a dataframe,
# attach the track ids, and move the index into a regular column.
df = (
    add_track_ids_to_tracks_df(solution_dataframe_from_sql(f'sqlite:///{org_db_path}'))
    .reset_index()
)
df

  return d[key]


Unnamed: 0,id,parent_id,t,z,y,x,track_id,parent_track_id,root
0,3000015,2000015,2,0.0,195.0,3850.0,45,-1,45.0
1,3000020,2000019,2,0.0,214.0,4561.0,54,-1,54.0
2,3000021,2000023,2,0.0,218.0,5011.0,64,-1,64.0
3,3000022,2000024,2,0.0,241.0,4204.0,56,-1,56.0
4,3000028,2000022,2,0.0,239.0,3780.0,55,-1,55.0
...,...,...,...,...,...,...,...,...,...
2328359,238012044,237012158,237,0.0,8296.0,3763.0,50518,50516,50514.0
2328360,238012045,237012159,237,0.0,8310.0,4604.0,53233,-1,53233.0
2328361,238012046,237012160,237,0.0,8307.0,4994.0,34369,34367,34356.0
2328362,238012047,237012161,237,0.0,8318.0,4131.0,56264,-1,56264.0


In [8]:
# Column names of the solution dataframe.
df.columns

Index(['id', 'parent_id', 't', 'z', 'y', 'x', 'track_id', 'parent_track_id',
       'root'],
      dtype='object')

## Create a cells table

In [9]:
# TODO: this has to be changed to operate on the original database, because at
# the moment objects that are not assigned to a track are not saved in the new
# database.
# Open question: what if multiple segmentations were given to ultrack and there
# are multiple possible objects for a single cell?

def add_cell(row):
    """Build a CellDB entry for one row of the solution dataframe and add it to `session`.

    Reads two module-level sessions, which must be open when this is called:
    `session` (new database, receives the cell) and `session_db_org`
    (original database, queried for the matching NodeDB entry).

    Parameters
    ----------
    row : pandas.Series or mapping
        Must provide the keys 'id', 't', 'track_id', 'y' and 'x'.

    Raises
    ------
    LookupError
        If the original database has no NodeDB entry with this id.
    """
    # DataFrame.apply(axis=1) passes each row as a single-dtype Series, so the
    # integer columns arrive as floats when the frame also holds float columns.
    # Cast them back before they reach the database layer
    # (assumes id / t / track_id are integer columns of CellDB - TODO confirm).
    node_id = int(row['id'])

    # look up the object the mask and bounding box are taken from
    node = session_db_org.query(NodeDB).filter(NodeDB.id == node_id).first()
    if node is None:
        raise LookupError(f'node {node_id} not found in the original database')
    node_obj = node.pickle

    cell = CellDB(id=node_id,
                  t=int(row['t']),
                  track_id=int(row['track_id']),
                  row=row['y'],
                  col=row['x'])

    cell.mask = node_obj.mask

    cell.bbox_0 = int(node_obj.bbox[0])
    cell.bbox_1 = int(node_obj.bbox[1])
    cell.bbox_2 = int(node_obj.bbox[2])
    cell.bbox_3 = int(node_obj.bbox[3])

    session.add(cell)

In [10]:
# create a table of cells
# exp6 - ~ 15 min

tqdm.pandas(desc="Progress")

# add_cell reads the module-level names `session` and `session_db_org`,
# so the sessions have to be bound to exactly these names.
# The context managers close both sessions even if a row fails part-way;
# nothing is committed in that case.
with Session(engine_org) as session_db_org, Session(engine) as session:

    df.progress_apply(add_cell, axis=1)

    session.commit()

Progress: 100%|██████████| 2328364/2328364 [14:59<00:00, 2589.43it/s] 


### Add signals to the cells table

In [11]:
import dask.array as da

In [12]:
# zarr stores of the two channels
ch0_path = r'D:\kasia\tracking\E6_exp\E6_C0.zarr'
ch1_path = r'D:\kasia\tracking\E6_exp\E6_C1.zarr'

# open component 1 of each store as a dask array
ch0_da, ch1_da = (da.from_zarr(path, component=1) for path in (ch0_path, ch1_path))

In [13]:
# sanity check: the channel was opened as a dask array
isinstance(ch0_da, da.Array)

True

In [14]:
# for exp6 around 25 min

# The context manager closes the session even if a frame fails part-way.
with Session(engine) as session:

    for frame in tqdm(range(ch0_da.shape[0])):

        cells = session.query(CellDB).filter(CellDB.t == frame).all()
        if not cells:
            # no cells in this frame - skip loading the images
            continue

        # load both channels of this frame once, for all of its cells
        ch0 = ch0_da[frame].compute()
        ch1 = ch1_da[frame].compute()

        for cell in cells:

            # Calculate the measurements of this cell and
            # update its signals field with the new data
            cell.signals = calculate_cell_signals(cell, [ch0, ch1])

        # Commit the changes of this frame to the database
        session.commit()

100%|██████████| 241/241 [30:08<00:00,  7.50s/it]


## Create a tracks table

In [15]:
# One row per (track_id, parent_track_id, root) combination with the first
# and the last time point at which it appears.
df_tracks = (
    df.groupby(['track_id', 'parent_track_id', 'root'])
    .agg(t_min=('t', 'min'), t_max=('t', 'max'))
    .reset_index()
)
df_tracks

Unnamed: 0,track_id,parent_track_id,root,t_min,t_max
0,1,-1,1.0,3,14
1,2,-1,2.0,1,50
2,3,2,2.0,51,116
3,4,2,2.0,51,71
4,5,-1,5.0,1,28
...,...,...,...,...,...
56297,56298,-1,56298.0,237,239
56298,56299,-1,56299.0,237,237
56299,56300,56299,56299.0,238,239
56300,56301,56299,56299.0,238,239


In [16]:
def add_track(row):
    """Build a TrackDB entry for one row of the tracks dataframe and add it to `session`.

    Reads the module-level `session`, which must be open when this is called.

    Parameters
    ----------
    row : pandas.Series or mapping
        Must provide the keys 'track_id', 'parent_track_id', 'root',
        't_min' and 't_max'.
    """
    # DataFrame.apply(axis=1) passes each row as a single-dtype Series, so the
    # integer columns arrive as floats when the frame also holds a float column
    # (here `root`). Cast them back before they reach the database layer
    # (assumes all five TrackDB fields are integer columns - TODO confirm).
    track = TrackDB(track_id=int(row['track_id']),
                    parent_track_id=int(row['parent_track_id']),
                    root=int(row['root']),
                    t_begin=int(row['t_min']),
                    t_end=int(row['t_max']))

    session.add(track)

In [17]:
# create a table of tracks

# add_track reads the module-level name `session`, so the session has to be
# bound to exactly this name. The context manager closes it even if a row
# fails part-way; nothing is committed in that case.
with Session(engine) as session:

    df_tracks.apply(add_track, axis=1)

    session.commit()

## Tests

In [14]:
# Spot check: (t, signals['ch1_nuc']) of every cell of track 40, ordered by t.
with Session(engine) as session:

    signal_over_time = (
        session.query(CellDB.t, CellDB.signals['ch1_nuc'])
        .filter(CellDB.track_id == 40)
        .order_by(CellDB.t)
    )
    results = signal_over_time.all()

results

[(72, 1124.8367816092),
 (73, 974.61690647482),
 (74, 1056.7583497053),
 (75, 1048.63026819923),
 (76, 934.575959933222),
 (77, 870.22629969419),
 (78, 902.058441558442),
 (79, 871.124213836478),
 (80, 799.590395480226),
 (81, 896.310177705977),
 (82, 878.967032967033),
 (83, 898.824193548387),
 (84, 849.671775223499),
 (85, 868.837282780411),
 (86, 842.46015037594),
 (87, 880.224324324324),
 (88, 821.526315789474),
 (89, 791.246110325318),
 (90, 768.794466403162),
 (91, 782.233196159122),
 (92, 813.235807860262),
 (93, 920.645328719723),
 (94, 789.628180039139),
 (95, 767.321828358209),
 (96, 907.497478991597),
 (97, 797.595716198126),
 (98, 915.130136986301),
 (99, 1009.13223140496)]