In [4]:
# Some imports
import sys
import os, glob
from mfx.processlocalizations import ProcessLocalizations
import re
from matplotlib import pyplot as plt
import math
import numpy as np
from mfx import mfxcolnames as col
import numpy as np
import plotly
import plotly.express as px
import pandas as pd 
import random

In [5]:
def fig_mean_tids(indict, idx_tids, title, color = False, axis_min = [0,0,0], axis_max = [1,1,1]):
    """Plot the mean position of every cluster of the selected tracks in 3D.

    For each track ID in ``idx_tids`` the localizations are grouped by their
    ``CLS_TRACK`` label (label -1 is skipped) and one marker is drawn at the
    mean position of each group.

    Parameters
    ----------
    indict : mapping
        Localization table; the entries ``col.TID``, ``col.CLS_TRACK`` and
        ``col.LTR`` are read. ``col.LTR`` must have three columns (x, y, z).
    idx_tids : iterable
        Track IDs to include.
    title : str
        Accepted for call compatibility; not used by the function body.
    color : bool
        Accepted for call compatibility; not used by the function body
        (markers are always colored by ``CLS_TRACK``).
    axis_min, axis_max : sequence of 3 numbers
        Lower / upper axis limits for x, y and z. They are only read.

    Returns
    -------
    plotly.graph_objects.Figure
        Scatter plot with one marker per (track, cluster) pair.
    """
    color_map = {-1:'gray', 0: 'blue', 1: 'magenta', 2: 'cyan'}
    columns = ['x', 'y', 'z', 'sd', 'se', 'nl', 'CLS_TRACK', 'TID']

    # Collect one row per (track, cluster) and build the table once at the end.
    rows = []
    for tid in idx_tids:
        in_track = (indict[col.TID] == tid)
        pos = indict[col.LTR][in_track]
        cls = indict[col.CLS_TRACK][in_track]
        for id_cls in np.unique(cls):
            if id_cls == -1:
                continue
            members = pos[cls == id_cls]
            m = np.mean(members, axis = 0)
            # Pool the per-axis standard deviations into a single RMS value.
            sd = math.sqrt(np.mean(np.std(members, axis = 0)**2))
            nl = np.size(members, axis = 0)
            se = sd/math.sqrt(nl)
            rows.append(np.append(m, [sd, se, nl, id_cls, tid]))

    # Force a 2-D (n_rows, 8) shape so that a single row (or no row at all)
    # still yields a table with the expected columns.
    table = np.array(rows, dtype = float).reshape(-1, len(columns))
    df = pd.DataFrame(table, columns = columns)
    df['CLS_TRACK'] = pd.Categorical(df['CLS_TRACK'])

    fig = px.scatter_3d(df, x='x', y='y',  z = 'z', color = 'CLS_TRACK', symbol = 'TID',
                        color_discrete_map=color_map)
    fig.update_traces(marker = dict(size = 2, line=dict(width=2,
                                        color='DarkSlateGrey')), showlegend=False)
    fig.update_layout(scene_aspectmode='cube',
                      plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)',
                      scene = dict(
                          xaxis = dict(range=[axis_min[0],axis_max[0]]),
                          yaxis = dict(range=[axis_min[1],axis_max[1]]),
                          zaxis = dict(range=[axis_min[2],axis_max[2]])))
    camera = dict(
        eye=dict(x=1.25, y=3, z=1.25)
    )
    fig.update_layout(scene_camera=camera)
    return fig

def fig_tids(indict, idx_tids, title, color = False):
    """Plot all localizations of the selected tracks as a 3D scatter.

    Parameters
    ----------
    indict : mapping
        Localization table; the entries ``col.TID``, ``col.CLS_TRACK`` and
        ``col.LTR`` are read. ``col.LTR`` must have three columns (x, y, z).
    idx_tids : iterable
        Track IDs to include; must contain at least one ID.
    title : str
        Accepted for call compatibility; not used by the function body.
    color : bool
        If true, markers are colored by ``CLS_TRACK``; otherwise all markers
        are white with a dark outline.

    Returns
    -------
    list
        ``[fig, axis_min_all, axis_max_all]``: the figure and the padded
        lower / upper axis limits (arrays of length 3) used for it.

    Raises
    ------
    ValueError
        If ``idx_tids`` is empty.
    """
    color_map = {-1:'gray', 0: 'blue', 1: 'magenta', 2: 'cyan'}
    frames = []
    lower_bounds = []
    upper_bounds = []
    for tid in idx_tids:
        idx = (indict[col.TID] == tid)
        pos = indict[col.LTR][idx]
        lower_bounds.append(np.min(pos, axis = 0))
        upper_bounds.append(np.max(pos, axis = 0))
        # Append the cluster label and the track ID as 4th and 5th column.
        concat_array = np.insert(pos, 3, indict[col.CLS_TRACK][idx], axis = 1)
        concat_array = np.insert(concat_array, 4, indict[col.TID][idx], axis = 1)
        frames.append(pd.DataFrame(concat_array,
                                   columns=['x', 'y', 'z', 'CLS_TRACK', 'TID']))

    if not frames:
        raise ValueError('fig_tids needs at least one track ID in idx_tids')

    # Concatenate once; DataFrame.append no longer exists in pandas >= 2.0.
    df_all = pd.concat(frames)

    axis_min_all = np.min(lower_bounds, axis = 0)
    axis_max_all = np.max(upper_bounds, axis = 0)
    # Pad by 2 % of the absolute bound so the limits always move outward,
    # also when a coordinate is negative.
    axis_min_all = axis_min_all - 0.02*np.abs(axis_min_all)
    axis_max_all = axis_max_all + 0.02*np.abs(axis_max_all)

    df_all['CLS_TRACK'] = pd.Categorical(df_all['CLS_TRACK'])
    if color:
        fig = px.scatter_3d(df_all, x='x', y='y',  z = 'z',
                            color = 'CLS_TRACK', symbol = 'TID',
                            color_discrete_map=color_map)
        fig.update_traces(marker = dict(size = 2, line=dict(width=2,
                                        color='DarkSlateGrey')), showlegend=False)
    else:
        fig = px.scatter_3d(df_all, x='x', y='y',  z = 'z',
                            symbol = 'TID')
        fig.update_traces(marker = dict(size = 2, color = 'white',
                                        line=dict(width=2,
                                                  color='DarkSlateGrey')), showlegend=False)

    fig.update_layout(scene_aspectmode='cube',
                      plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)',
                      scene = dict(
                          xaxis = dict(range=[axis_min_all[0],axis_max_all[0]]),
                          yaxis = dict(range=[axis_min_all[1],axis_max_all[1]]),
                          zaxis = dict(range=[axis_min_all[2],axis_max_all[2]])))
    camera = dict(
        eye=dict(x=1.25, y=3, z=1.25)
    )
    fig.update_layout(scene_camera=camera)
    return [fig, axis_min_all, axis_max_all]
    

def get_random_tids_toplot(tid_splits, nr_samples = 2):
    """Draw random track IDs from the top, middle and bottom third of each label.

    For every label the entries are ranked by their second element in
    descending order (presumably a per-track split count or score -- confirm
    against the producer of ``tid_splits``), the ranking is cut into three
    contiguous tiers, and ``nr_samples`` entries are drawn at random from
    each tier.

    Parameters
    ----------
    tid_splits : dict
        Maps a label to a list of sequences whose first element is the track
        ID and whose second element is the ranking value. The lists are not
        modified.
    nr_samples : int
        Number of track IDs to draw per tier. A tier that holds fewer entries
        contributes all of them.

    Returns
    -------
    list
        One item per label (in dict order); each item is a list of three
        lists of track IDs (top, middle, bottom tier).
    """
    tids_to_plot = []
    for label in tid_splits:
        # sorted() keeps the caller's list untouched.
        ranked = sorted(tid_splits[label], key=lambda x: x[1], reverse=True)
        tids = [entry[0] for entry in ranked]
        n = len(tids)
        # Half-open tiers [0, a), [a, b), [b, n): adjacent tiers share their
        # boundary so that no index is left out.
        cuts = [0, round(n/3), round(2*n/3), n]
        tiers = []
        for start, stop in zip(cuts[:-1], cuts[1:]):
            k = min(nr_samples, stop - start)
            picked = random.sample(range(start, stop), k)
            tiers.append([tids[i] for i in picked])
        tids_to_plot.append(tiers)
    return tids_to_plot
    
def show_tids_set(indict, tid_splits, tids_to_plot_in = None, 
                  postfix   = '', wash = None, color = True, savedir = None):
    """Show (and optionally save) localization and cluster-mean plots per tier.

    For every selected label and every tier of track IDs two figures are
    shown: all localizations (``fig_tids``) and the per-cluster means
    (``fig_mean_tids``), the latter using the axis limits of the former.

    Parameters
    ----------
    indict : dict
        Maps each label of ``tid_splits`` to its localization table.
    tid_splits : dict
        Split events per label; its key order defines the label indices.
    tids_to_plot_in : list, optional
        Track IDs to plot, as returned by ``get_random_tids_toplot``. Drawn
        at random from ``tid_splits`` when omitted.
    postfix : str
        Text appended to figure titles and file names.
    wash : list of int, optional
        Indices of the labels to plot. Defaults to ``[0, 1]``.
    color : bool
        Forwarded to the plotting functions.
    savedir : str, optional
        Directory in which both figures are written as PDF. It is created if
        it does not exist. Nothing is written when omitted.

    Returns
    -------
    list
        The track IDs that were used for plotting.
    """
    keys = list(tid_splits.keys())
    if tids_to_plot_in is None:
        tids_to_plot = get_random_tids_toplot(tid_splits)
    else:
        tids_to_plot = tids_to_plot_in
    if wash is None:
        wash = [0,1]
    if savedir is not None:
        os.makedirs(savedir, exist_ok=True)

    for ikey in wash:
        label = keys[ikey]
        for tier in tids_to_plot[ikey]:
            title = '%s, TIDs %s, %s' % (label,  ' '.join(map(str, tier)), postfix)
            [fig, axis_min, axis_max] = fig_tids(indict[label], tier, title = title, color=color)
            fig.show()
            fig2 = fig_mean_tids(indict[label], tier, title = title, color=color,
                                 axis_min=axis_min, axis_max=axis_max)
            fig2.show()

            if savedir is not None:
                tid_tag = '_'.join(map(str, tier))
                imgname = os.path.join(savedir, '%s_TID%s_%s.pdf' % (label, tid_tag, postfix))
                fig.write_image(imgname, format='pdf')
                imgname = os.path.join(savedir, '%s_TID%s_%s_meanTID.pdf' % (label, tid_tag, postfix))
                fig2.write_image(imgname, format='pdf')

    return tids_to_plot

In [6]:
# Get file names

def gen_directories(maindir, subdir, keys):
    """Return a dict mapping each key to the path ``maindir/subdir/key``."""
    return {name: os.path.join(maindir, subdir, name) for name in keys}

def gen_npyfiles(dir_dict):
    """Collect ``<dir>/<name>/<name>.npy`` files below each directory.

    Every directory tree in ``dir_dict`` is walked; for each sub-directory
    ``name`` found at any depth, the file ``name/name.npy`` is recorded if it
    exists. Returns a dict with the same keys as ``dir_dict`` whose values
    are lists of the matching file paths, in ``os.walk`` order.
    """
    found = {label: [] for label in dir_dict}
    for label, top in dir_dict.items():
        for parent, subdirs, _ in os.walk(top):
            for name in subdirs:
                candidate = os.path.join(parent, name, name + '.npy')
                if os.path.exists(candidate):
                    found[label].append(candidate)
    return found


# Root directories.
# NOTE(review): these are absolute, machine-specific paths; adjust them before
# running the notebook on another computer.
OUTDIR_LOC = 'C:/Users/apoliti/Desktop/mflux_zarr_tmp_storage/analysis' # Main directory to store zarr files
OUTDIR_REM =  'Z:/siva_minflux/analysis'  # Main directory to store results 
INDIR  = 'Z:/siva_minflux/data'       # Main directory of the msr files

# For each experiment group below: build the per-key input, results and zarr
# directories, then list the .npy files found under the results directories.
# `keys` is rebound for every group; after this cell it holds the last list.

# Multiple washes with different imager strands
keys = ['Syp_ATG9', 'ZnT3_Syp', 'Syp_Picc']
indir_mwash = gen_directories(INDIR, 'Multiwash', keys)
outdir_mwash = gen_directories(OUTDIR_REM, 'Multiwash', keys)
zarrdir_mwash = gen_directories(OUTDIR_LOC, 'Multiwash', keys) 
npy_mwash = gen_npyfiles(outdir_mwash)

# Wash with a single imager strand
keys = ['Syp', 'ATG9', 'VGLUT1']
indir_swash =  gen_directories(INDIR, 'Single wash', keys)
outdir_swash =  gen_directories(OUTDIR_REM, 'Single wash', keys)
zarrdir_swash =  gen_directories(OUTDIR_LOC, 'Single wash', keys)
npy_swash = gen_npyfiles(outdir_swash)


# Consistency controls: wash with a single imager strand, but multiple times.
# These also live under the 'Multiwash' sub-directory.
keys = ['VGLUT1_VGLUT1']
indir_cwash =  gen_directories(INDIR, 'Multiwash', keys)
outdir_cwash = gen_directories(OUTDIR_REM, 'Multiwash', keys)
zarrdir_cwash = gen_directories(OUTDIR_LOC, 'Multiwash', keys) 
npy_cwash = gen_npyfiles(outdir_cwash)


In [7]:
# First pass: build a ProcessLocalizations object from the second .npy path
# listed for 'ZnT3_Syp' and cluster each track with DBSCAN.
# NOTE(review): index [1] depends on the order in which os.walk lists the
# directories -- confirm that it selects the intended file.
pl = ProcessLocalizations(npy_mwash['ZnT3_Syp'][1])
pl.STD_QUANTILE = 0.9
pl.MIN_LOCALIZATIONS = 3
pl.DBCLUSTER_SIZE = 3
pl.trim_min_localizations()

# Split tracks if needed, remove outliers in each track
pl.DBCLUSTER_EPS_TRACK = 1.5e-8

pl.cluster_tid(method=pl.CLS_METHOD_DBSCAN)
#pl.cluster_tid(method=pl.CLS_METHOD_BAYES_GMM, weight_prior = 1e-3)

# Keep the DBSCAN results: `pl` is rebound below.
tid_splits_dbscan = pl.get_split_events()
loc_dbscan = pl.loc.copy()


# Second pass: same file and same trimming settings, but clustered with the
# Bayesian GMM method. DBCLUSTER_EPS_TRACK is not set on this instance.
pl = ProcessLocalizations(npy_mwash['ZnT3_Syp'][1])
pl.STD_QUANTILE = 0.9
pl.MIN_LOCALIZATIONS = 3
pl.DBCLUSTER_SIZE = 3
pl.trim_min_localizations()
#pl.cluster_tid(method=pl.CLS_METHOD_DBSCAN)
pl.cluster_tid(method=pl.CLS_METHOD_BAYES_GMM, weight_prior = 1)
tid_splits_bayes_gmm = pl.get_split_events()
loc_bayes_gmm = pl.loc.copy()
# From here on `pl` refers to the Bayesian-GMM instance.
summary_tid = pl.summary_per_tid2()


[38;5;39m2022-10-06 10:45:58,190 [INFO] **** trim_min_localizations ****
220309_ZnT3_P1 Removed 81/361 tracks with less than 3 localizations
220309_Syp_P2 Removed 233/725 tracks with less than 3 localizations
[0m
[38;5;39m2022-10-06 10:46:00,385 [INFO] **** cluster_tid ****
220309_ZnT3_P1 MIN_SPLIT_LOCALIZATION: 6, sd_limit: 11.45 nm
Processed TID: 21 / 280, Total tracks TID2: 289
220309_Syp_P2 MIN_SPLIT_LOCALIZATION: 6, sd_limit: 11.45 nm
Processed TID: 41 / 492, Total tracks TID2: 501
[0m
[38;5;39m2022-10-06 10:46:00,531 [INFO] **** trim_min_localizations ****
220309_ZnT3_P1 Removed 81/361 tracks with less than 3 localizations
220309_Syp_P2 Removed 233/725 tracks with less than 3 localizations
[0m
[38;5;39m2022-10-06 10:46:03,877 [INFO] **** cluster_tid ****
220309_ZnT3_P1 MIN_SPLIT_LOCALIZATION: 6, sd_limit: 11.45 nm
Processed TID: 21 / 280, Total tracks TID2: 297
220309_Syp_P2 MIN_SPLIT_LOCALIZATION: 6, sd_limit: 11.45 nm
Processed TID: 41 / 492, Total tracks TID2: 528
[0m


In [None]:
# Show how localizations in TIDs are merged
random.seed(15) # fixed seed so that the same TIDs are drawn on every run
# The TIDs are drawn from the DBSCAN split events ...
tids_in = get_random_tids_toplot(tid_splits_dbscan)
# ... and only the first tier of each label is kept.
tids_in = [[x[0]] for x in tids_in]
# Figures go to a 'figures' folder next to the directory that holds the .npy file.
savedir = os.path.dirname(os.path.dirname(npy_mwash['ZnT3_Syp'][1]))
savedir = os.path.join(savedir, 'figures')
#show_tids_set(loc_dbscan, tid_splits_dbscan, tids_to_plot_in=tids_in, 
#              postfix = 'dbscan_nocolor', wash = [0], color = False, savedir = savedir)
# Plot the same TIDs with the Bayesian-GMM results for the first label only
# (wash = [0]): once without and once with cluster colors.
show_tids_set(loc_bayes_gmm, tid_splits_bayes_gmm, tids_to_plot_in=tids_in, 
              postfix = 'bayes_gmm_nocolor', wash = [0], color = False, savedir = savedir)
show_tids_set(loc_bayes_gmm, tid_splits_bayes_gmm, tids_to_plot_in=tids_in, 
              postfix = 'bayes_gmm', wash = [0], color = True, savedir = savedir)

In [10]:
# Inspect the 'eco' entry of the '220309_ZnT3_P1' table in pl.loc.
pl.loc['220309_ZnT3_P1']['eco']

array([ 90,  76, 100, ...,  53, 140,  61])

In [9]:
# NOTE(review): scratch cell -- its recorded output is an AttributeError, so
# the notebook does not survive "Restart & Run All". Remove or replace it.
pl['']

AttributeError: 'ProcessLocalizations' object has no attribute 'mfx_all'