In [None]:
import tensorflow as tf
from keras.models import load_model

import os
import nd2

import numpy as np
import pandas as pd

import stardist
from stardist.models import StarDist2D
from csbdeep.utils import normalize

from btrack.io import localizations_to_objects
import btrack

from skimage import io
import skimage

from scipy import ndimage as ndi

import edt
from scipy.ndimage import distance_transform_edt

import ray
import zarr
import dask.array as da

# Axes passed as `axis` to csbdeep's `normalize` when preparing the H2B image for StarDist.
axis_norm = (0,1)

# Turn off jedi-based tab completion in the IPython completer.
%config Completer.use_jedi = False

In [None]:
# Load the pre-trained models: a Keras CNN (`model`) used to classify nuclei crops and
# a StarDist2D model (`modelStar`) used for nuclei segmentation.
# The .h5 path is relative to the current working directory; StarDist is given
# name='stardist' and basedir='models' with no config (None).
model = load_model('cnn_model_onlyintensityimage.h5')
modelStar =  StarDist2D(None, name='stardist', basedir='models')

In [None]:
# Load the nd2 file selected in a pop-up dialog, using the nd2 library.

import tkinter as tk
from tkinter import filedialog

root = tk.Tk()
root.wm_attributes('-topmost', True)
try:
    path = filedialog.askopenfile()
finally:
    # always tear the Tk root down, even if the dialog raised
    root.destroy()

# askopenfile() returns None when the dialog is cancelled.
if path is None:
    raise FileNotFoundError('No nd2 file was selected in the file dialog.')

# Only `path.name` is used (here and when the output folder is built later), so release
# the text-mode handle that askopenfile() opened; `.name` stays available after close().
path.close()

# The ND2File is kept open on purpose: `viewnd2da` is a lazy dask array that reads from it.
nd2file = nd2.ND2File(path.name)
viewnd2da = nd2file.to_dask()


In [None]:
# Check the nd2 file in napari to choose the position and timepoints for cropping.
import napari

viewer = napari.Viewer()

# channel_axis=-3: the third-from-last axis is treated as the channel axis
# (assumes the array ends in (C, Y, X) — confirm against the nd2 axis order).
viewer.add_image(viewnd2da, channel_axis=-3)

In [None]:
# Extract channel information from the metadata and assign the channel indices to
# constants used for channel selection later on. This is specific to nd2 files.

# Reset first so that an index left over from a previously analysed file can never be
# reused silently. OCT stays None for two-channel ('GFP em 1' / 'GFP em 2') acquisitions.
ERK = H2B = OCT = None

channels = nd2file.metadata.channels
if 'GFP em' in channels[0].channel.name:
    for channel in channels:
        channel_name = channel.channel.name
        if 'GFP em 1' in channel_name:
            ERK = channel.channel.index
        elif 'GFP em 2' in channel_name:
            H2B = channel.channel.index
else:
    for channel in channels:
        channel_name = channel.channel.name
        if 'GFP' in channel_name:
            ERK = channel.channel.index
        elif 'Cy3' in channel_name:
            OCT = channel.channel.index
        elif 'Cy5' in channel_name:
            H2B = channel.channel.index

# ERK and H2B are required by every later cell; fail here with a readable message
# instead of with a NameError (or a stale index) further down.
if ERK is None or H2B is None:
    found = [channel.channel.name for channel in channels]
    raise ValueError(f'Could not identify the ERK and H2B channels among {found}')

# get image dimensions:

Y_dim = nd2file.attributes.heightPx
X_dim = nd2file.attributes.widthPx

# get time metadata and convert from ms to min.
# The first experiment loop either exposes `periodMs` directly or holds a list of
# `periods`; only the missing-attribute case triggers the fallback.
time_params = nd2file.experiment[0].parameters
try:
    t_interval = time_params.periodMs/60000
except AttributeError:
    t_interval = time_params.periods[0].periodMs/60000

# select position and crop time and load into memory
pos = 4
t_crop = slice(80, 500)  # timepoints loaded into memory
volume = viewnd2da[t_crop, pos, ...].compute()

In [None]:

# Preview all automatic skimage thresholds on the Gaussian-smoothed (sigma=20) sum of the
# ERK and H2B channels of the last loaded timepoint. The colony segmentation that follows
# uses threshold_triangle on the same kind of smoothed image.
skimage.filters.try_all_threshold(skimage.filters.gaussian(volume[-1,ERK,...]+volume[-1,H2B,...],sigma=20))


In [None]:
# Worker pool for the colony segmentation function defined next, whose mask is later used
# to exclude nuclei/debris detections outside the colony. Parallelized using ray.
from ray.util.multiprocessing import Pool


# Pool() is created with default arguments; this cell never closes it.
pool=Pool()

def binary_processing(image):
    """Segment the colony in one timepoint and derive its outline.

    The ERK and H2B channels are summed, smoothed with a Gaussian (sigma=20) and
    thresholded with the triangle method. Small objects and holes (below 10000 px)
    are removed, remaining holes are filled and the mask is dilated with a disk of
    radius 5. The outline is the mask minus its binary erosion, with the outermost
    image rows and columns cleared.

    Parameters
    ----------
    image : array
        One timepoint, indexed by channel first so that image[ERK] and image[H2B]
        are 2D images (assumed (C, Y, X) — confirm against the nd2 axis order).

    Returns
    -------
    binary : 2D bool array
        Colony mask; all False if the segmentation raised an error.
    edge : 2D bool array
        Colony outline; all False if the segmentation raised an error.
    """
    smooth = skimage.filters.gaussian(image[ERK] + image[H2B], sigma=20)

    # Segmentation may be impossible, e.g. when the colony spans the whole field of view.
    # Only ordinary errors are turned into empty masks; KeyboardInterrupt and SystemExit
    # are no longer swallowed.
    try:
        thresh = skimage.filters.threshold_triangle(smooth)
        binary = smooth > thresh
        binary = skimage.morphology.remove_small_objects(binary, min_size=10000)
        binary = ndi.binary_fill_holes(binary)
        binary = skimage.morphology.remove_small_holes(binary, area_threshold=10000)
        binary = skimage.morphology.binary_dilation(binary, footprint=skimage.morphology.disk(5))

        # the edge is what remains of the mask after removing its eroded interior
        interior = ndi.binary_erosion(binary)
        edge = binary.copy()
        edge[interior] = 0

        # clear the image border so the field-of-view boundary is not counted as colony edge
        edge[0, :] = 0
        edge[:, 0] = 0
        edge[-1, :] = 0
        edge[:, -1] = 0

    except Exception:
        # if thresholding fails create empty arrays.
        binary = np.zeros_like(image[H2B], dtype=np.bool_)
        edge = np.zeros_like(image[H2B], dtype=np.bool_)
    return binary, edge

# One task per timepoint: iterating `volume` yields the frames along its first axis.
results = pool.map(binary_processing, list(volume))


In [None]:
# Retrieve the colony mask and edge stacks from the pool results.

binary = np.stack([result[0] for result in results])
edge = np.stack([result[1] for result in results])

# If thresholding cannot be performed for all timepoints (because the colony eventually
# spans the whole field of view), volume, binary and edge are cropped to the timepoints
# before the first failure. A timepoint counts as failed when its mask is entirely True
# or entirely False.

failed = np.all(binary, axis=(1, 2)) | np.all(~binary, axis=(1, 2))

if failed.any():
    # index of the first failed timepoint
    cropt = int(np.argmax(failed))
    # as in the original slicing ([:cropt-1]), one extra timepoint before the first
    # failure is dropped as well
    n_keep = max(cropt - 1, 0)
else:
    # np.argmax returns 0 for an all-False vector, which previously turned the slice
    # into [:-1] and silently dropped the last timepoint. Keep everything instead.
    cropt = binary.shape[0]
    n_keep = cropt

if n_keep == 0:
    raise ValueError('Colony segmentation failed at the start of the time series; '
                     'no timepoints are left to analyse.')

edge = edge[:n_keep, ...]
binary = binary[:n_keep, ...]
volume = volume[:n_keep, ...]

In [None]:
# Inspect the cropped volume together with the colony mask and its edge in napari.
import napari

viewer = napari.Viewer()

viewer.add_image(volume, channel_axis=-3)
viewer.add_image(binary)
viewer.add_image(edge)
# Disabled layers: final_nuc and cyto_label are only created in later cells;
# distlabel is not defined anywhere in the visible notebook.
#viewer.add_labels(distlabel)
#viewer.add_labels(final_nuc)
#viewer.add_labels(cyto_label)

In [None]:
## Segment nuclei with StarDist and classify every nucleus with the CNN, one timepoint at a time.

CROP_SIZE = 64  # side length (px) of the square crops fed to the classification CNN


def _pad_crop(crop, size=CROP_SIZE):
    """Centre a 2D nucleus crop on a zero-filled (size, size) uint16 canvas.

    Crops that exceed `size` along an axis are centre-cropped along that axis first;
    without this the assignment into the canvas raised a broadcasting ValueError.
    Crops that already fit are placed exactly as before.
    """
    canvas = np.zeros((size, size), np.uint16)

    # centre-crop oversized axes (no-op when the crop already fits)
    y0 = max((crop.shape[0] - size) // 2, 0)
    x0 = max((crop.shape[1] - size) // 2, 0)
    crop = crop[y0:y0 + size, x0:x0 + size]

    yy = (size - crop.shape[0]) // 2
    xx = (size - crop.shape[1]) // 2
    canvas[yy:yy + crop.shape[0], xx:xx + crop.shape[1]] = crop
    return canvas


nuc_list = []
df_frames = []
for i, image in enumerate(volume):

    # StarDist prediction on the percentile-normalized H2B channel
    img_norm = normalize(image[H2B], 1, 99.8, axis=axis_norm)
    label, detail = modelStar.predict_instances(img_norm, prob_thresh=0.47)

    # multiply label by binary image of colony to remove segmentations not in the colony.
    label = label*binary[i]
    df_class = pd.DataFrame(
        skimage.measure.regionprops_table(
            label, intensity_image=image[H2B], properties=('label', 'intensity_image', 'area')
        )
    )
    df_class['t'] = i
    df_class['t_min'] = df_class['t']*t_interval

    # prepare nuclei crops for cell-cycle-state classification; all crops are padded to the same size.
    crops = [_pad_crop(crop) for crop in df_class['intensity_image']]

    if crops:
        # add a trailing channel axis: (n, CROP_SIZE, CROP_SIZE, 1)
        images_topre = np.stack(crops)[..., np.newaxis]
        images_topre = normalize(images_topre, 1, 99.8)
        prediction = np.argmax(model.predict(images_topre), axis=-1)
    else:
        # no nucleus in this timepoint: np.stack([]) would raise, so skip the CNN
        prediction = np.zeros(0, dtype=np.int64)

    df_class['states'] = prediction.tolist()
    # the per-nucleus crops are only needed for classification
    del df_class['intensity_image']
    tf.keras.backend.clear_session()

    nuc_list.append(label)
    df_frames.append(df_class)
    print(i)

# create a single 3D numpy array out of the list of labels and a single pandas dataframe from the list

final_nuc = np.stack(nuc_list)
df = pd.concat(df_frames)

In [None]:
## Perform nuclear label expansion and shrinking for the cytoplasm/nucleus ratio calculation, and compute the distance of
## each nucleus from the colony edge using a distance map created from the edge image. Uses ray to process all timepoints in parallel.

# The per-timepoint inputs are assembled in the pool.starmap call after the function definition.

from ray.util.multiprocessing import Pool

# A new Pool with default arguments; the pool created for the colony segmentation is not reused.
pool = Pool()

def process_ERK_dist(idx, i, l, e):
    """Measure the ERK cytoplasm/nucleus ratio and the edge distance for one timepoint.

    Parameters
    ----------
    idx : int
        Timepoint index; stored in column 't'.
    i : array
        Image of this timepoint, indexed by channel first (i[ERK], i[H2B], i[OCT]).
    l : 2D integer array
        Nuclear label image of this timepoint.
    e : 2D bool array
        Colony edge image of this timepoint.

    Returns
    -------
    df_ERK_meas : pandas.DataFrame
        One row per label that is present in every measurement (inner merges), with
        columns 'label', 'mean_intensity_nuc', 'y', 'x', 'mean_intensity_cyto',
        'mean_intensity_OCT' (only when an OCT channel exists), 'mean_intensity_H2B',
        'CNr', 't' and 'dist'.
    cyto_label : 2D array
        Cytoplasm ring label image.
    """
    # The OCT channel only exists for three-channel acquisitions. Tolerate both an
    # undefined OCT and OCT = None instead of crashing on i[OCT].
    try:
        oct_channel = OCT
    except NameError:
        oct_channel = None

    def _mean_intensity(labels, channel, column, extra=()):
        """Per-label mean intensity of one channel, with the value column renamed."""
        table = skimage.measure.regionprops_table(
            labels, intensity_image=i[channel], properties=('label', 'mean_intensity') + tuple(extra)
        )
        return pd.DataFrame(table).rename(columns={'mean_intensity': column})

    # cytoplasm ring: labels expanded by 4 px minus labels expanded by 1 px
    cyto_label = skimage.segmentation.expand_labels(l, distance=4) - skimage.segmentation.expand_labels(l, distance=1)

    #### shrinking nuclei for cyto/nuc ratio calculation (approach by Lucien Hinderling)
    # Keep only label pixels whose edt.edt distance is at least `distance`.
    # The original additionally looked up the "nearest label" of every kept pixel with a
    # second distance transform; a kept pixel always lies inside a label, so its nearest
    # label is its own value and the lookup reduces to reading `l` directly.
    distance = 1.5
    inner_distance = edt.edt(l)
    shrunknuc_label = np.where(inner_distance >= distance, l, 0).astype(l.dtype, copy=False)
    ####

    # mean intensities: ERK in shrunken nuclei (plus centroid) and in the cytoplasm ring,
    # OCT (if available) and H2B in the full nuclei
    df_ERK_meas = _mean_intensity(shrunknuc_label, ERK, 'mean_intensity_nuc', extra=('centroid',))
    df_ERK_meas = df_ERK_meas.merge(_mean_intensity(cyto_label, ERK, 'mean_intensity_cyto'), on='label')
    if oct_channel is not None:
        df_ERK_meas = df_ERK_meas.merge(_mean_intensity(l, oct_channel, 'mean_intensity_OCT'), on='label')
    df_ERK_meas = df_ERK_meas.merge(_mean_intensity(l, H2B, 'mean_intensity_H2B'), on='label')

    # Cyto/Nuc ratio
    df_ERK_meas['CNr'] = df_ERK_meas['mean_intensity_cyto']/df_ERK_meas['mean_intensity_nuc']
    df_ERK_meas['t'] = idx

    # centroids are cast to integers so they can index the distance map below
    df_ERK_meas.rename(columns={'centroid-0':'y', 'centroid-1':'x'}, inplace=True)
    df_ERK_meas = df_ERK_meas.astype({'y':'uint16','x':'uint16'})

    # distance transform of the inverted edge image, sampled at every centroid in one
    # vectorized lookup (replaces a per-label dataframe filter)
    disttrans = ndi.distance_transform_edt(e==0)
    df_ERK_meas['dist'] = disttrans[df_ERK_meas['y'].to_numpy(), df_ERK_meas['x'].to_numpy()]

    return df_ERK_meas, cyto_label
 
# One task per timepoint: (timepoint index, image, nuclear labels, colony edge).
results = pool.starmap(
    process_ERK_dist,
    [
        (idx, frame_image, frame_labels, frame_edge)
        for idx, (frame_image, frame_labels, frame_edge) in enumerate(zip(volume, final_nuc, edge))
    ],
)


In [None]:
# Each pool result is a (measurement dataframe, cytoplasm label image) pair.

# stack the cytoplasm label images into one array
cyto_label = np.stack([cyto_frame for _, cyto_frame in results])

# concatenate the ERK measurement dataframes and merge them with the classification
# dataframe `df` on timepoint and label
erk_tables = [erk_table for erk_table, _ in results]
df_ERKmeas = pd.concat(erk_tables).merge(df, on=['t','label'])

In [None]:

# btrack uses the 'label' column for the cell state annotation, so the nuclear label is
# moved to 'label_nuc', the predicted state becomes 'label', and 'states' is set to 5.
# All three values are taken from the dataframe as it is before this statement.
df_ERKmeas = df_ERKmeas.assign(
    label_nuc=df_ERKmeas['label'],
    label=df_ERKmeas['states'],
    states=5,
)

In [None]:
# Convert the dataframe to btrack objects. btrack optimization will take too long if too many
# tracks are identified prior to the optimization step. In my experience more than 25000 will result in excessively long
# processing. Therefore one may have to reduce the time period or omit optimization and lineage tracing.
objects_to_track = localizations_to_objects(df_ERKmeas)#[(df_tracking['t']<500) &(df_tracking['t']>50)])

In [None]:
from btrack.constants import BayesianUpdates
# Run btrack on the localized nuclei.
# initialise a tracker session using a context manager
with btrack.BayesianTracker() as tracker:

    # configure the tracker using a config file (path relative to the working directory)
    tracker.configure_from_file('cell_config3.json')
    
    # use APPROXIMATE to speed up tracking
    tracker.update_method = BayesianUpdates.APPROXIMATE
    # NOTE(review): presumably in pixels — confirm against the btrack documentation
    tracker.max_search_radius = 25
    
    # append the objects to be tracked
    tracker.append(objects_to_track)

    # set the volume (Z axis volume is set very large for 2D data)
    tracker.volume=((0, X_dim), (0, Y_dim), (-1e5,1e5))
    
    
    # run the tracking
    tracker.track()
 
    # generate hypotheses and run the global optimizer, only required if lineage tracking is desired.
    #tracker.optimize(options={'tm_lim': int(6e6)})
    
    # keep the resulting tracks; they are converted into a dataframe in the next cell
    tracks = tracker.tracks
    
    # get the tracks in a format for napari visualization
    data, properties, graph = tracker.to_napari(ndim=2)

In [None]:
# Turn every track into a dataframe and combine them into one table with a fresh index.
track_tables = [pd.DataFrame(track.to_dict()) for track in tracks]
df_ERKmeas = pd.concat(track_tables).reset_index(drop=True)

# Drop rows with missing values; per the original note these are btrack dummy objects,
# which would cause problems with the ERK cyto/nuc measurements.
df_ERKmeas.dropna(inplace=True)

# Restore the nuclear label as 'label' (needed for merging/remapping later on).
df_ERKmeas['label'] = df_ERKmeas['label_nuc'].astype('uint')

# Remove the columns that are not needed any more.
df_ERKmeas.drop(['z', 'label_nuc', 'dummy'], axis=1, inplace=True)

In [None]:
# Create the output directory <base_path><acquisition folder>/pos<pos>/ for saving.
base_path = r'//izbkingston.izb.unibe.ch/imaging.data/mic01-imaging/Yannick/'
# the acquisition folder is the parent directory of the selected nd2 file
folder = path.name.split('/')[-2]
dest = f'{base_path}{folder}/pos{pos}/'

# exist_ok=True replaces the separate exists() check, which was open to a race between
# the check and the creation.
os.makedirs(dest, exist_ok=True)

In [None]:
# Save the napari tracking data (properties, graph, data), one pickle file each.

import pickle

pickle_targets = [
    (f'{dest}properties.pkl', properties),
    (f'{dest}graph.pkl', graph),
    (f'{dest}data.pkl', data),
]
for pickle_path, payload in pickle_targets:
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)

In [None]:
# Save the measurements as csv and the labels, edge, binary and cyto_label stacks as zarr arrays.
# NOTE(review): the csv file name is spelled 'EKRmeas' (not 'ERKmeas'); kept unchanged so
# existing readers of that file keep working.
df_ERKmeas.to_csv(f'{dest}EKRmeas.csv')

zarr_targets = [
    (f'{dest}labels.zarr', final_nuc),
    (f'{dest}edge.zarr', edge),
    (f'{dest}binary.zarr', binary),
    (f'{dest}cyto_label.zarr', cyto_label),
]
for store_path, stack in zarr_targets:
    zarr.save(store_path, stack)

In [None]:
# Remap labels to track IDs for color consistency in napari, and map measurements (ERK)
# onto the segmentation. Can be used to remap any data to the labels.
from ray.util.multiprocessing import Pool

# One dataframe per frame of final_nuc, in frame order. Iterating over
# df_ERKmeas['t'].unique() instead yields timepoints in order of first appearance across
# tracks and skips frames without tracked objects, which misaligns the list with
# final_nuc / cyto_label when they are zipped together.
remap_columns = ['CNr','dist','label','ID']
df_list_t = [
    df_ERKmeas.loc[df_ERKmeas['t'] == t, remap_columns]
    for t in range(final_nuc.shape[0])
]

pool = Pool()

def remap_labels(label, cyto, df_t):
    """Map the nuclear labels of one timepoint to track IDs and to scaled ERK ratios.

    Parameters
    ----------
    label : 2D integer array
        Nuclear label image of the timepoint.
    cyto : 2D integer array
        Cytoplasm ring label image of the timepoint.
    df_t : pandas.DataFrame
        Rows of this timepoint with at least the columns 'label', 'ID' and 'CNr'.

    Returns
    -------
    remap : array
        `label` with every label replaced by its track ID.
    remap_cyto : array
        `cyto` with every label replaced by its track ID.
    remapERK : uint16 array
        `label` with every label replaced by 1000 * CNr.
    """
    # the same label -> ID mapping serves both the nuclear and the cytoplasm image;
    # map_array returns new arrays, so the inputs need no defensive copies
    in_map = df_t['label'].to_numpy()
    id_map = df_t['ID'].to_numpy()
    remap = skimage.util.map_array(label, in_map, id_map)
    remap_cyto = skimage.util.map_array(cyto, in_map, id_map)

    # for ERK remapping the CNr value is multiplied by 1000 to fit a 16 bit image and save
    # space compared to a 32 bit float image. Values are clipped to the uint16 range first:
    # a plain cast wraps around for ratios above 65.535 and is undefined for inf.
    uint16_max = np.iinfo(np.uint16).max
    cnr_map = np.clip(1000*df_t['CNr'].to_numpy(), 0, uint16_max).astype('uint16')
    remapERK = skimage.util.map_array(label, in_map, cnr_map)

    return remap, remap_cyto, remapERK

# One task per timepoint: (nuclear labels, cytoplasm labels, dataframe of that timepoint).
results = pool.starmap(remap_labels, list(zip(final_nuc, cyto_label, df_list_t)))

# each result is an (ID-remapped nuclei, ID-remapped cytoplasm, ERK-remapped nuclei) triple
remap = np.stack([id_frame for id_frame, _, _ in results])
remap_cyto = np.stack([cyto_frame for _, cyto_frame, _ in results])
remapERK = np.stack([erk_frame for _, _, erk_frame in results])

In [None]:
# Visualize the results in napari: raw channels, ERK ratio image, ID-remapped nuclear and
# cytoplasm labels, colony edge and the btrack tracks.

import napari
viewer = napari.Viewer()

viewer.add_image(volume, channel_axis=-3)
# remapERK holds 1000 * CNr per nucleus as uint16
viewer.add_image(remapERK)

viewer.add_labels(remap)
viewer.add_labels(remap_cyto)

viewer.add_image(edge)

# data, properties and graph come from tracker.to_napari(ndim=2)
viewer.add_tracks(data=data, properties=properties, graph=graph)

In [None]:
# Serial remapping of the nuclear labels to track IDs, frame by frame.
# Note: this reassigns `remap`, which the pool-based remapping cell also produces.
remap = []
for frame_idx, frame_labels in enumerate(final_nuc):
    # rows of the tracking dataframe belonging to this frame
    rows_t = df_ERKmeas[df_ERKmeas['t'] == frame_idx]
    remap.append(
        skimage.util.map_array(
            frame_labels.copy(),
            rows_t['label'].to_numpy(),
            rows_t['ID'].to_numpy(),
        )
    )
remap = np.stack(remap)

In [None]:
# Serial remapping of the cytoplasm ring labels to track IDs, frame by frame.
remapcyto = []
for frame_idx, frame_cyto in enumerate(cyto_label):
    # rows of the tracking dataframe belonging to this frame
    rows_t = df_ERKmeas[df_ERKmeas['t'] == frame_idx]
    remapcyto.append(
        skimage.util.map_array(
            frame_cyto.copy(),
            rows_t['label'].to_numpy(),
            rows_t['ID'].to_numpy(),
        )
    )
remapcyto = np.stack(remapcyto)

In [None]:
# Serial remapping of the nuclear labels to the ERK cyto/nuc ratio, frame by frame.
# Note: this reassigns `remapERK`, which the pool-based remapping cell also produces.
remapERK = []

for i in range(final_nuc.shape[0]):
    in_arr = final_nuc[i].copy()
    # select the rows of this frame once instead of once per column
    rows_t = df_ERKmeas[df_ERKmeas['t']==i]
    in_map = rows_t['label'].to_numpy()
    # CNr is stored as 1000 * ratio in uint16. Clip to the uint16 range before casting:
    # a plain cast wraps around for ratios above 65.535 and is undefined for inf.
    new_map = np.clip(1000*rows_t['CNr'].to_numpy(), 0, np.iinfo(np.uint16).max)
    new_map = new_map.astype('uint16')
    remapped_ID_out = skimage.util.map_array(in_arr,in_map,new_map)
    remapERK.append(remapped_ID_out)
remapERK = np.stack(remapERK)