In [None]:
# Re-import edited modules (e.g. the local axon_tracking package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os,sys, kimimaro, pickle
import seaborn as sns
import pandas as pd
import numpy as np
from skimage.morphology import disk, ball, square
import scipy.ndimage as nd
import matplotlib.pyplot as plt
import matplotlib.animation as anim
import cloudvolume as cv
from IPython.display import HTML
from matplotlib.collections import LineCollection
sys.path.append("/home/phornauer/Git/axon_tracking/")
import axon_tracking.skeletonization as skel
import axon_tracking.template_extraction as te
import axon_tracking.visualization as vis
import axon_tracking.quantification as quant

In [None]:
# Template / skeletonization parameters passed to skel.full_skeletonization,
# skel.perform_path_qc and quant.get_sliding_window_velocity.
params = {
    'noise_threshold': -2,
    'abs_threshold': -0.2,
    'upsample': [1, 1, 1],
}
# NOTE(review): the base rate of 20000 is divided by the third upsample factor;
# upsampling would usually multiply the rate. This is harmless while the factor is 1 - confirm intent.
params['sampling_rate'] = 20000 / params['upsample'][2]  # [Hz]
params.update({
    'ms_cutout': [1.5, 5.0],
    'filter_footprint': ball(1),  # np.ones((2, 2, 3)) was noted as an alternative
    'max_velocity': 1,  # [m/s]
    'max_noise_level': 0.25,
})

In [None]:
# Skeletonization settings handed to skel.full_skeletonization.
skel_params = {
    'scale': 1,
    'const': 2,  # 3 was noted as an alternative
    'dust_threshold': 5,
    'anisotropy': (1, 1, 1),
    'tick_threshold': 5,
    'n_jobs': 16,
}

In [None]:
# Path quality-control settings. They are unpacked as keyword arguments into
# skel.perform_path_qc and also passed to skel.full_skeletonization.
qc_params = {
    'window_size': 7,
    'max_duplicate_ratio': 0.3,
    'min_r2': 0.8,
    'vel_range': [0.25, 1.25],
    'min_length': 1,
}

In [None]:
def condition_from_well_id(well_id):
    """Map a well index to its experimental condition.

    Wells below 12 are region 0; all others are region 1. Within each block of
    six consecutive wells (``well_id % 6``):

    * positions 0-1 get AAV id 0,
    * positions 2-3 get AAV id 128,
    * positions 4-5 get AAV id 129.

    Args:
        well_id: Well index (the notebook iterates 0-23).

    Returns:
        Tuple ``(reg, aav_id)``.
    """
    position = well_id % 6
    if position < 2:
        aav_id = 0
    elif position > 3:
        aav_id = 129
    else:
        aav_id = 128
    reg = 0 if well_id < 12 else 1
    return reg, aav_id

In [None]:
# Experiment being processed; every per-well folder lives directly under root_path.
root_path = ('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/'
             'Maxtwo/phornauer/Chemogenetics/Low_dose_range/concatenated/')
week = 2           # written to the 'age' column for every unit
experiment_id = 1  # written to the 'experiment' column for every unit

# Names of the per-unit feature lists. The data_dict cell below takes column_names[1:].
column_names = [
    'full_velocities', 'velocity_mean', 'velocity_std', 'template_size',
    'branch_point_count', 'branch_dist_mean', 'branch_dist_std',
    'branch_length_mean', 'branch_length_std', 'longest_axon', 'terminal_count',
    'projection_dist_mean', 'projection_dist_std', 'well_id',
    'aav', 'age', 'experiment', 'region', 'responder', 'unit_id',
]

In [None]:
# One independent, empty list per feature; the extraction loops append one value per unit to each.
(full_velocities, velocity_mean, velocity_std, template_size,
 branch_point_count, branch_dist_mean, branch_dist_std,
 branch_length_mean, branch_length_std, longest_axon, terminal_count,
 projection_dist_mean, projection_dist_std, well_id, aav, age,
 experiment, region, unit_id, responder) = ([] for _ in range(20))

In [None]:
# Skeletonize every good unit (responders, then non-responders) of each well.
# Each QC'd skeleton is cached as <well>/analysis/<template_id>_skel.pkl, and one
# value per feature is appended to the lists initialised above.
for w in range(24):
    stream_id = 'well' + str(w).zfill(3)
    well_path = os.path.join(root_path, stream_id)
    if not os.path.exists(os.path.join(well_path,'good_responders.npy')):
        continue
    responders = np.load(os.path.join(well_path,'good_responders.npy'))
    non_responders = np.load(os.path.join(well_path,'good_non_responders.npy'))

    analysis_path = os.path.join(root_path, stream_id, 'analysis')
    os.makedirs(analysis_path, exist_ok=True)

    template_ids = np.concatenate((responders,non_responders)).astype('int')
    for template_id in template_ids:
        try:
            # Skeletonize and path-QC once per unit (both steps were previously run twice).
            qc_skeleton, scaled_qc_list = skel.full_skeletonization(root_path, stream_id, template_id, params, skel_params, qc_params)
            if len(scaled_qc_list) < 1:
                continue
            path_list, r2s, full_vels, lengths = skel.perform_path_qc(scaled_qc_list, params, **qc_params)

            with open(os.path.join(analysis_path, str(template_id) + '_skel.pkl'), 'wb') as file_name:
                pickle.dump(qc_skeleton, file_name)

            # Compute every feature before touching the shared lists. An exception
            # part-way through then cannot leave the lists with unequal lengths,
            # which would break (or misalign) the DataFrame built from them.
            good_full_vels = [full_vels[x] for x in range(len(full_vels)) if r2s[x] > 0.9 and lengths[x] > 5]
            mean_vel, std_vel = quant.get_sliding_window_velocity(scaled_qc_list, params, window_size=6, min_r2=0.9)
            branch_dists = quant.get_branch_point_dists(qc_skeleton)
            branch_lengths = quant.get_branch_lengths(scaled_qc_list)
            dists = quant.get_projection_dists(qc_skeleton)
            reg, aav_id = condition_from_well_id(w)
            if template_id in responders:
                is_responder = 1
            else:
                is_responder = 0
                if template_id not in non_responders:
                    print('template id ' + str(template_id) + ' not found')

            row = [
                (full_velocities, np.mean(np.unique(good_full_vels))),
                (velocity_mean, mean_vel),
                (velocity_std, std_vel),
                (template_size, quant.get_simple_template_size(scaled_qc_list)),
                (branch_point_count, quant.get_branch_point_count(qc_skeleton)),
                (branch_dist_mean, np.mean(branch_dists)),
                (branch_dist_std, np.std(branch_dists)),
                (branch_length_mean, np.mean(branch_lengths)),
                (branch_length_std, np.std(branch_lengths)),
                (longest_axon, quant.get_longest_path(qc_skeleton)),
                (terminal_count, quant.get_terminal_count(qc_skeleton)),
                (projection_dist_mean, np.mean(dists)),
                (projection_dist_std, np.std(dists)),
                (well_id, w),
                (aav, aav_id),
                (age, week),
                (experiment, experiment_id),
                (region, reg),
                (unit_id, template_id),
                (responder, is_responder),
            ]
        except Exception as e:
            # Best-effort over many units: report which unit failed and move on.
            print(f'{stream_id} template {template_id}: {e!r}')
            continue

        for feature_list, value in row:
            feature_list.append(value)

In [None]:
# Inspect one unit: reload its cached skeleton and re-run path QC on the
# branches derived from it.
w = 14
stream_id = 'well' + str(w).zfill(3)
well_path = os.path.join(root_path, stream_id)
responders = np.load(os.path.join(well_path,'good_responders.npy'))
non_responders = np.load(os.path.join(well_path,'good_non_responders.npy'))

analysis_path = os.path.join(root_path, stream_id, 'analysis')

template_ids = np.concatenate((responders,non_responders)).astype('int')
# Cast to int, matching the int-cast ids the skeletonization loop used to name the
# pickle files. If the .npy holds floats, str() would otherwise give e.g. '12.0_skel.pkl'.
template_id = int(non_responders[7])
with open(os.path.join(analysis_path,str(template_id) + '_skel.pkl'),'rb') as file_name:
    qc_skeleton = pickle.load(file_name)

all_branches = skel.branches_from_paths(qc_skeleton)
scaled_qc_list, r2s, vels, lengths = skel.perform_path_qc(all_branches, params,**qc_params)

In [None]:
# Mean projection distance of the currently loaded skeleton.
# NOTE(review): `skeleton` is only an alias of qc_skeleton *at this point*; it goes
# stale once qc_skeleton is reassigned by a later cell.
skeleton = qc_skeleton
np.mean(quant.get_projection_dists(skeleton))

In [None]:
# Rebuild the per-feature lists from the skeletons cached by the skeletonization
# loop, without re-skeletonizing.
# NOTE(review): this appends to the same lists as that loop. Re-run the
# list-initialisation cell first if both loops are executed in one session.
for w in range(24):
    stream_id = 'well' + str(w).zfill(3)
    well_path = os.path.join(root_path, stream_id)
    if not os.path.exists(os.path.join(well_path,'good_responders.npy')):
        continue
    responders = np.load(os.path.join(well_path,'good_responders.npy'))
    non_responders = np.load(os.path.join(well_path,'good_non_responders.npy'))

    analysis_path = os.path.join(root_path, stream_id, 'analysis')

    template_ids = np.concatenate((responders,non_responders)).astype('int')
    for template_id in template_ids:
        skel_file = os.path.join(analysis_path, str(template_id) + '_skel.pkl')
        if not os.path.exists(skel_file):
            # Units skipped (empty skeleton) or failed during skeletonization have no cache file.
            continue
        try:
            with open(skel_file, 'rb') as file_name:
                qc_skeleton = pickle.load(file_name)

            # Only the skeleton is cached, so the branch list must be rebuilt per unit.
            # Previously a stale global `scaled_qc_list` was reused for every unit.
            # NOTE(review): assumes branches_from_paths yields the branch list that
            # full_skeletonization returns as scaled_qc_list - confirm.
            scaled_qc_list = skel.branches_from_paths(qc_skeleton)
            path_list, r2s, full_vels, lengths = skel.perform_path_qc(scaled_qc_list, params, **qc_params)

            # Compute every feature before appending, so a failure cannot leave the
            # lists with unequal lengths.
            good_full_vels = [full_vels[x] for x in range(len(full_vels)) if r2s[x] > 0.9 and lengths[x] > 5]
            mean_vel, std_vel = quant.get_sliding_window_velocity(scaled_qc_list, params, window_size=6, min_r2=0.9)
            branch_dists = quant.get_branch_point_dists(qc_skeleton)
            branch_lengths = quant.get_branch_lengths(scaled_qc_list)
            dists = quant.get_projection_dists(qc_skeleton)
            reg, aav_id = condition_from_well_id(w)
            if template_id in responders:
                is_responder = 1
            else:
                is_responder = 0
                if template_id not in non_responders:
                    print('template id ' + str(template_id) + ' not found')

            row = [
                (full_velocities, np.mean(np.unique(good_full_vels))),
                (velocity_mean, mean_vel),
                (velocity_std, std_vel),
                (template_size, quant.get_simple_template_size(scaled_qc_list)),
                (branch_point_count, quant.get_branch_point_count(qc_skeleton)),
                (branch_dist_mean, np.mean(branch_dists)),
                (branch_dist_std, np.std(branch_dists)),
                (branch_length_mean, np.mean(branch_lengths)),
                (branch_length_std, np.std(branch_lengths)),
                (longest_axon, quant.get_longest_path(qc_skeleton)),
                (terminal_count, quant.get_terminal_count(qc_skeleton)),
                (projection_dist_mean, np.mean(dists)),
                (projection_dist_std, np.std(dists)),
                (well_id, w),
                (aav, aav_id),
                (age, week),
                (experiment, experiment_id),
                (region, reg),
                (unit_id, template_id),
                (responder, is_responder),
            ]
        except Exception as e:
            print(f'{stream_id} template {template_id}: {e!r}')
            continue

        for feature_list, value in row:
            feature_list.append(value)

In [None]:
# Gather the per-feature lists into a column dict for the DataFrame.
# column_names[0] ('full_velocities') is not included.
# The notebook namespace is read via globals() explicitly instead of relying on
# module-level locals() happening to be the same dict.
data_dict = {column_name: globals()[column_name] for column_name in column_names[1:]}

# Every column must hold exactly one value per unit. Report mismatches by name here
# rather than failing later with an anonymous length error inside pd.DataFrame.
column_lengths = {name: len(values) for name, values in data_dict.items()}
assert len(set(column_lengths.values())) <= 1, f'feature lists have unequal lengths: {column_lengths}'

In [None]:
# List each column with its length (all lengths should be equal).
for name, values in data_dict.items():
    print(name, len(values))

In [None]:
# Unit-by-feature table. Exact duplicate rows are dropped (presumably units that
# were appended twice when both extraction loops ran).
result_df = pd.DataFrame.from_dict(data_dict)
final_df = result_df.drop_duplicates()

In [None]:
# (rows, columns) of the de-duplicated feature table.
final_df.shape

In [None]:
# Save the feature table next to the per-well folders of the experiment being
# processed. The path is derived from root_path so it follows root_path when that changes.
final_df.to_pickle(os.path.join(root_path, 'new_skel_features.pkl'))

In [None]:
# Load previously saved feature tables of the low- and large-dose-range experiments.
# NOTE(review): this reads 'skel_features.pkl', not the 'new_skel_features.pkl'
# written above - confirm which version is intended.
exp_1 = pd.read_pickle('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics/Low_dose_range/concatenated/skel_features.pkl')
exp_2 = pd.read_pickle('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics/Large_dose_range/concatenated/skel_features.pkl')
# Disabled: Chemogenetics_2 week 2 / week 3 tables.
#exp_3 = pd.read_pickle('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics_2/Week_2/concatenated/skel_features.pkl')
#exp_4 = pd.read_pickle('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics_2/Week_3/concatenated/skel_features.pkl')

In [None]:
# Pool the experiment tables (exp_3/exp_4 could be appended to the list) and drop exact duplicate rows.
full = pd.concat([exp_1, exp_2]).drop_duplicates()

In [None]:
# Region-1 units of experiment 1 with aav < 129. With condition_from_well_id's
# ids this keeps the controls (0) and AAV id 128.
# (A preceding assignment without the experiment filter was immediately overwritten and is removed.)
sel_df = full.query('aav<129 and region==1 and experiment==1')

In [None]:
# Available feature columns of the selection.
sel_df.columns

In [None]:
import matplotlib as mpl
# 10-entry resampled viridis used for the violin-plot palettes.
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9.
# pyplot.get_cmap(name, lut) is the supported equivalent.
cmap = plt.get_cmap('viridis', 10)

In [None]:
# NumPy record array of the selection (index dropped), used by the .npy / .mat exports.
numpy_df = sel_df.to_records(index=False)

In [None]:
# Export the selection as .npy.
# NOTE(review): this writes into Large_dose_range although root_path points at
# Low_dose_range and the selection filters experiment==1 - confirm the target folder.
np.save('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics/Large_dose_range/concatenated/skel_features.npy', numpy_df)

In [None]:
# Export the same record array for MATLAB under the variable name "data".
# NOTE(review): this import belongs in the top import cell.
from scipy.io import savemat
savemat('/net/bs-filesvr02/export/group/hierlemann/intermediate_data/Maxtwo/phornauer/Chemogenetics/Large_dose_range/concatenated/skel_features.mat', {"data":numpy_df})

In [None]:
# Mean projection distance per condition, split by responder status (hue).
# The tick labels assume exactly two aav groups in sel_df (presumably 0 = control, 128 = AAV).
# NOTE(review): plt.legend(labels=...) without handles relabels existing artists in
# creation order, with "_" entries hidden; confirm "EXC"/"INH" map to responder 0/1.
g = sns.violinplot(data=sel_df,x="aav",y="projection_dist_mean",hue="responder",saturation=1,split=True,gap=.05,palette=cmap([2,8]),density_norm="area",bw_adjust=0.8,
                   inner_kws=dict(box_width=5,color="0.3"))
plt.xticks(ticks=[0,1],labels=["Control", "AAV"])
plt.ylabel("Mean projection distance")
plt.xlabel("")
#plt.box(False)
plt.legend(title="Putative cell type", labels=["EXC","_","_","_","INH"],frameon=False,loc="upper left")
plt.show()

In [None]:
# Skeleton size per condition, not split by responder status
# (the commented lines are the earlier split-by-responder variant).
#g = sns.violinplot(data=sel_df,x="aav",y="template_size",hue="responder",saturation=0.5,split=True,gap=.1,density_norm="count",bw_adjust=0.8)
g = sns.violinplot(data=sel_df,x="aav",y="template_size",saturation=0.5,density_norm="area",bw_adjust=0.8)
plt.xticks(ticks=[0,1],labels=["Control", "AAV"])
plt.ylabel("Skeleton size (#electrodes)")
plt.xlabel("")
#plt.legend(title="Putative cell type", labels=["EXC","_","_","_","INH"],frameon=False,loc="upper left")
plt.show()

In [None]:
# Terminal count per condition, split by responder status.
# NOTE(review): same label-only legend relabelling as the projection-distance plot; verify the EXC/INH mapping.
g = sns.violinplot(data=sel_df,x="aav",y="terminal_count",hue="responder",saturation=0.5,split=True,gap=.1,density_norm="count",bw_adjust=0.8)
plt.xticks(ticks=[0,1],labels=["Control", "AAV"])
plt.ylabel("Terminal count")
plt.xlabel("")
plt.legend(title="Putative cell type", labels=["EXC","_","_","_","INH"],frameon=False,loc="upper left")
plt.show()

In [None]:
# Split violins of template size with individual units overlaid as a swarm on the same axes.
# NOTE(review): both calls use hue="responder", so the legend entries are likely duplicated.
g = sns.catplot(data=sel_df,x="aav",y="template_size",hue="responder",saturation=0.5,split=True,density_norm="count",inner=None, kind="violin")
sns.swarmplot(data=sel_df, x="aav", y="template_size", hue="responder", size=4, edgecolor='k',linewidth=0.5,ax=g.ax)

In [None]:
# First two vertex coordinates of the skeleton's terminal vertices (presumably x/y).
qc_skeleton.vertices[qc_skeleton.terminals(),:2]

In [None]:
# Debugging: re-skeletonize the unit given by the current stream_id / template_id globals.
qc_skeleton, scaled_qc_list = skel.full_skeletonization(root_path, stream_id, template_id, params, skel_params, qc_params)

In [None]:
# Path QC of the branches returned above: per-path fit r2, velocity and length.
path_list, r2s, full_vels, lengths = skel.perform_path_qc(scaled_qc_list, params,**qc_params)

In [None]:
# Number of branches returned by the skeletonization.
len(scaled_qc_list)

In [None]:
# Mean of the distinct path velocities with r2 > 0.9 and length > 5 (same filter as the extraction loop).
good_full_vels = [full_vels[x] for x in range(len(full_vels)) if r2s[x] > 0.9 and lengths[x]>5]
np.mean(np.unique(good_full_vels))

In [None]:
# Sliding-window velocity (window 6, min r2 0.9); only the std is displayed.
mean_vel, std_vel = quant.get_sliding_window_velocity(scaled_qc_list,params,window_size=6,min_r2=0.9)
std_vel

In [None]:
# Template size of the current unit. Display the value just computed; the
# previous version displayed the global per-unit `template_size` feature list instead.
temp_size = quant.get_simple_template_size(scaled_qc_list)
temp_size

In [None]:
# Number of branch points in the current skeleton.
branch_points = quant.get_branch_point_count(qc_skeleton)
branch_points

In [None]:
# Mean of the branch-point distances of the current skeleton.
branch_dists = quant.get_branch_point_dists(qc_skeleton)
np.mean(branch_dists)

In [None]:
# Mean branch length over the current branch list.
branch_lengths = quant.get_branch_lengths(scaled_qc_list)
np.mean(branch_lengths)

In [None]:
# Longest path through the current skeleton (stored as 'longest_axon' in the feature table).
longest_path = quant.get_longest_path(qc_skeleton)
longest_path

In [None]:
# Terminal count of the current skeleton. This uses its own name because
# `terminal_count` is the per-unit feature list; overwriting it with an int broke
# re-running the data_dict / DataFrame cells.
n_terminals = quant.get_terminal_count(qc_skeleton)
n_terminals

In [None]:
# Projection distances of the current skeleton. Use qc_skeleton directly: the
# `skeleton` alias was bound earlier and is stale after qc_skeleton is recomputed.
dists = quant.get_projection_dists(qc_skeleton)
dists

In [None]:
# Delay contour plot of the current skeleton over its template.
# NOTE(review): `capped_template` is not defined in any visible cell, so this raises
# NameError on a fresh kernel. save_path=[] presumably means "do not save" - confirm.
vis.plot_delay_contour(capped_template,qc_skeleton,params,skel_params,radius=5,save_path=[])

In [None]:
# Filled contour plot of the current skeleton over its template.
# NOTE(review): `capped_template` is not defined in any visible cell (NameError on a fresh kernel).
vis.plot_filled_contour(capped_template,qc_skeleton,params,radius=5,save_path=[])