In [3]:
from collections.abc import Iterable
import os

from adjustText import adjust_text
import colorsys
from datetime import datetime, timedelta
from dateutil import tz
from hdmf.backends.hdf5.h5_utils import H5DataIO
from hdmf.container import Container
from hdmf.data_utils import DataChunkIterator
import latex
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import pickle
from pynwb import load_namespaces, get_class, register_class, NWBFile, TimeSeries, NWBHDF5IO
from pynwb.file import MultiContainerInterface, NWBContainer, Device, Subject
from pynwb.ophys import ImageSeries, OnePhotonSeries, OpticalChannel, ImageSegmentation, PlaneSegmentation, Fluorescence, DfOverF, CorrectedImageStack, MotionCorrection, RoiResponseSeries, ImagingPlane
from pynwb.core import NWBDataInterface
from pynwb.epoch import TimeIntervals
from pynwb.behavior import SpatialSeries, Position
from pynwb.image import ImageSeries
import pywt
import scipy.io as sio
import scipy
from scipy.stats import multivariate_normal, spearmanr
from scipy.optimize import linear_sum_assignment
import seaborn as sns
import skimage.io as skio
from sklearn.neighbors import NearestNeighbors
from sklearn.decomposition import PCA
from tifffile import TiffFile
import tifffile

from networkx import kamada_kawai_layout

import sys
sys.path.append('NWBelegans/Analysis')

from atlas import loadmat, NPAtlas, NWBAtlas
from process_file import get_nwb_neurons, get_dataset_neurons, get_dataset_online, combine_datasets, get_pairings, get_color_discrim, get_neur_nums
from stats import get_summary_stats, analyze_pairs, get_accuracy
from visualization import plot_num_heatmap, plot_std_heatmap, plot_summary_stats, plot_color_discrim, plot_accuracies, plot_visualizations_atlas, plot_visualizations_data, plot_atlas2d_super
from utils import covar_to_coord, convert_coordinates, maha_dist, run_linear_assignment

# ndx_mulitchannel_volume is the novel NWB extension for multichannel optophysiology in C. elegans
#from ndx_multichannel_volume import CElegansSubject, OpticalChannelReferences, OpticalChannelPlus, ImagingVolume, VolumeSegmentation, MultiChannelVolume, MultiChannelVolumeSeries

In [4]:
#Optional import if you want to open figures in a separate window, add %matplotlib qt to top of any code box if you want figures to open in a separate window 
import PyQt6.QtCore
os.environ["QT_API"] = "pyqt6"

In [6]:
atlas = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_match_full_nosplit.pkl', ganglia='Data/neuron_ganglia.csv') # Load atlas
atlas_df = atlas.get_df()
atlas_neurons = np.asarray(atlas_df['ID'])
atlas.df = atlas.df.drop(atlas.df[atlas.df['ID']=='IL1V'].index)

In [7]:
def maha_dist(data, mu, sigma):

    data_mu = data-mu
    inv_sigma = np.linalg.inv(sigma)
    left_data = np.dot(data_mu, inv_sigma)
    mahal = np.dot(left_data, data_mu.T)

    return np.sqrt(mahal)

def run_linear_assignment(df_data, atlas):
    #df_data should have columns 'X', 'Y', 'Z', 'R', 'G', 'B', 'ID' 

    df_assigns = df_data.copy()

    mu = atlas.mu
    sigma = atlas.sigma
    neurons = np.asarray(atlas.neurons)

    xyzrgb = np.asarray(df_data[['X','Y','Z','R','G','B']])
    gt_labels = np.asarray(df_data['ID'])

    assigns = np.empty((xyzrgb.shape[0],5),np.dtype('U100'))
    assign_cost = np.zeros((xyzrgb.shape[0], 3, 5)) #total, position, color in second dimension, top 5 ranks in third 

    cost_mat = np.zeros((xyzrgb.shape[0], mu.shape[0]))
    cost_pos = np.zeros((xyzrgb.shape[0], mu.shape[0]))
    cost_col = np.zeros((xyzrgb.shape[0], mu.shape[0]))

    for i in range(xyzrgb.shape[0]):
        for j in range(mu.shape[0]):
            cost = maha_dist(xyzrgb[i,:], mu[j,:], sigma[:,:,j])
            cost_pos[i,j] = maha_dist(xyzrgb[i,:3], mu[j,:3], sigma[:3,:3,j])
            cost_col[i,j] = maha_dist(xyzrgb[i,3:], mu[j,3:], sigma[3:,3:,j])

            cost_mat[i,j] = cost

    for k in range(5):

        row_inds, col_inds = linear_sum_assignment(cost_mat)

        assigns[row_inds,k] = np.asarray(neurons[col_inds])

        assign_cost[row_inds, 0, k] = cost_mat[row_inds, col_inds]
        assign_cost[row_inds, 1, k] = cost_pos[row_inds, col_inds]
        assign_cost[row_inds, 2, k] = cost_col[row_inds, col_inds]

        cost_mat[row_inds, col_inds] = np.inf

    df_assigns['assign_1'] = assigns[:,0]
    df_assigns['assign_2'] = assigns[:,1]
    df_assigns['assign_3'] = assigns[:,2]
    df_assigns['assign_4'] = assigns[:,3]
    df_assigns['assign_5'] = assigns[:,4]

    return df_assigns, assign_cost

In [8]:
def get_accuracies(folder, atlas):
    acc_df = pd.DataFrame(columns=['Total_neurons','Percent_IDd', 'Percent_top1', 'Percent_top2', 'Percent_top3', 'Percent_top4', 'Percent_top5', 'Filename'])
    for file in os.listdir(folder):
        if not file[-4:] == '.csv':
            continue

        df_data = pd.read_csv(folder + '/'+file)
        df_data = df_data.rename(columns={"aligned_x":"X","aligned_y":"Y","aligned_z":"Z", "aligned_R":"R", "aligned_G":"G", "aligned_B":"B"})

        df, costs = run_linear_assignment(df_data, atlas)

        IDd = df[~df['ID'].isnull()]

        per_ID = len(IDd.index)/len(df.index)

        total_neurons = len(df.index)

        corr1 = df.loc[df['ID']==df['assign_1']]
        corr2 = df.loc[df['ID']==df['assign_2']]
        corr3 = df.loc[df['ID']==df['assign_3']]
        corr4 = df.loc[df['ID']==df['assign_4']]
        corr5 = df.loc[df['ID']==df['assign_5']]
            
        corr_cum_2 = pd.concat([corr1,corr2]).drop_duplicates().reset_index(drop=True)
        corr_cum_3 = pd.concat([corr_cum_2,corr3]).drop_duplicates().reset_index(drop=True)
        corr_cum_4 = pd.concat([corr_cum_3,corr4]).drop_duplicates().reset_index(drop=True)
        corr_cum_5 = pd.concat([corr_cum_4, corr5]).drop_duplicates().reset_index(drop=True)

        per_corr_1 = len(corr1.index)/len(IDd.index)
        per_corr_2 = len(corr_cum_2.index)/len(IDd.index)
        per_corr_3 = len(corr_cum_3.index)/len(IDd.index)
        per_corr_4 = len(corr_cum_4.index)/len(IDd.index)
        per_corr_5 = len(corr_cum_5.index)/len(IDd.index)

        acc_df.loc[len(acc_df.index)] = [total_neurons,per_ID, per_corr_1, per_corr_2, per_corr_3, per_corr_4, per_corr_5, file[:-4]]

    return acc_df

In [10]:
'''
Get accuracy values for each dataset using the trained atlas and the roughly aligned point clouds. If you would like
to test on datasets that have not been pre-aligned, please use the neuroPAL_ID software which has the atlas and alignment
code pre-compiled
'''

NP_atlas_match = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_NPonly.pkl', ganglia='Data/neuron_ganglia.csv') #Atlas trained on just original 10 NeuroPAL datasets
NP_atlas_unmatch = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_NPunmatch.pkl', ganglia='Data/neuron_ganglia.csv')

accs_NP = get_accuracies('Data/aligned_heads/aligned_NP', NP_atlas_match)
accs_NP_unmatch = get_accuracies('Data/aligned_heads/aligned_NP_nomatch', NP_atlas_unmatch)

In [13]:
for i in range(5):
    full_atlas_match = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_split/exgroup'+str(i)+'.pkl', ganglia='Data/neuron_ganglia.csv') 
    full_atlas_unmatch = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_split_unmatch/exgroup'+str(i)+'.pkl', ganglia='Data/neuron_ganglia.csv')

    accs_match = get_accuracies('Data/aligned_heads/aligned_split/group'+str(i+1), full_atlas_match)
    accs_unmatch = get_accuracies('Data/aligned_heads/aligned_split_nomatch/group'+str(i+1), full_atlas_unmatch)

    if i==0:
        accs_full = accs_match
        accs_full_unmatch = accs_unmatch
    else:
        accs_full = pd.concat([accs_full, accs_match])
        accs_full_unmatch = pd.concat([accs_full_unmatch, accs_unmatch])

In [14]:
accs_NP_unmatch_CRF = pd.read_csv('Data/Acc_CRFID/original.csv')
accs_full_CRF = pd.read_csv('Data/Acc_CRFID/multi_colorcorr.csv')

In [15]:
accs_full_CPD = pd.read_csv('Data/Acc_CPD/match_all_temp.csv')

In [16]:
accs_pretrain_fDNC = pd.read_csv('Data/Acc_fDNC/fDNC_pretrain__acc_top1_and_top5_combined.csv')

In [17]:
# Skipping files that have obvious artifacts or known alignment issues
skipfiles = ['20231013-9-30-0', '20230412-20-15-17', '2023-01-23-01', '20239828-11-14-0', '2023-01-05-01', '2023-01-10-14', '2022-06-28-07', '2022-07-26-01', '2023-01-19-15', '2022-07-15-06', '2022-08-02-01', '2023-01-09-08', '2023-01-09-28', '2023-01-10-14', '2023-01-17-14', '2023-01-19-22', '2023-01-23-01', '7_YAaLR', '11_YAaLR', '20_YAaLR', '38_YAaDV', '55_YAaDV', '56_YAaDV', '62_YAaLR', '64_YAaDV', '70_YAaLR', '76_YAaDV']

In [18]:
group_assigns = pd.read_csv('Data/group_assigns.csv')
dataset_dict = {row['Filename']:row['Dataset'] for i, row in group_assigns.iterrows()}

In [19]:
%matplotlib qt

def gen_plots_benchmark(accs_full_CPD, accs_NP_unmatch_stat_atlas, accs_full_stat_atlas, accs_NP_unmatch_CRF_ID, accs_full_CRF_ID, accs_pretrain_fDNC, dataset_dict, skipfiles):

    filenames = accs_NP_unmatch_stat_atlas['Filename']

    df_dataset = pd.DataFrame(columns = ['Filename', 'Dataset', 'Label','Atlas', 'Model', 'Accuracy', 'Rank', 'Num_labeled'])

    df_quartiles = accs_NP_unmatch[~accs_NP_unmatch['Filename'].isin(skipfiles)]
    
    neur_quartiles = np.quantile(df_quartiles['Total_neurons']*df_quartiles['Percent_IDd'], [0,0.25,0.5,0.75,1])

    print(neur_quartiles)

    thresh = neur_quartiles[2]
    thresh = 100

    for file in filenames:
        if file in skipfiles:
            continue
        acc_NP_stat = accs_NP_unmatch_stat_atlas.loc[accs_NP_unmatch_stat_atlas['Filename']==file]
        acc_full_stat = accs_full_stat_atlas.loc[accs_full_stat_atlas['Filename']==file]
        acc_NP_CRF = accs_NP_unmatch_CRF_ID.loc[accs_NP_unmatch_CRF_ID['Filename']==file]
        acc_full_CRF = accs_full_CRF_ID.loc[accs_full_CRF_ID['Filename']==file]
        acc_full_CPD = accs_full_CPD.loc[accs_full_CPD['Filename']==file]
        acc_pretrain_fDNC = accs_pretrain_fDNC.loc[accs_pretrain_fDNC['Filename']==file]

        dataset = dataset_dict[file]

        per_IDd = acc_NP_stat.iloc[0]['Percent_IDd']
        total_neurons = acc_NP_stat.iloc[0]['Total_neurons']

        num_label = per_IDd*total_neurons

        for i in range(1,6):
            df_dataset.loc[len(df_dataset.index)] = [file, dataset,'CPD (best template)', 'Full','CPD', acc_full_CPD.iloc[0]['top'+str(i)], i, num_label]
            df_dataset.loc[len(df_dataset.index)] = [file, dataset,'StatAtlas (base)' ,'Base', 'StatAtlas',  acc_NP_stat.iloc[0]['Percent_top'+str(i)], i, num_label]
            df_dataset.loc[len(df_dataset.index)] = [file, dataset,'StatAtlas (retrain)','Full', 'StatAtlas',  acc_full_stat.iloc[0]['Percent_top'+str(i)], i, num_label]
            df_dataset.loc[len(df_dataset.index)] = [file, dataset,'CRF (base)','Base', 'CRFID',  acc_NP_CRF.iloc[0]['top'+str(i)], i, num_label]
            df_dataset.loc[len(df_dataset.index)] = [file, dataset,'CRF (retrain)','Full', 'CRFID',  acc_full_CRF.iloc[0]['top'+str(i)], i, num_label]
            df_dataset.loc[len(df_dataset.index)] = [file, dataset,'fDNC (base)','Base', 'fDNC',  acc_pretrain_fDNC.iloc[0]['top'+str(i)], i, num_label]
            
    
    for label in ['CPD (best template)', 'StatAtlas (base)', 'StatAtlas (retrain)', 'CRF (base)', 'CRF (retrain)', 'fDNC (base)']:

        print(label + ': low labels')
        print('Rank 1 mean and std')
        print(np.mean(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==1)&(df_dataset['Num_labeled']<thresh)]['Accuracy'])))
        print(np.std(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==1)&(df_dataset['Num_labeled']<thresh)]['Accuracy'])))
        print('Rank 5 mean and std')
        print(np.mean(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==5)&(df_dataset['Num_labeled']<thresh)]['Accuracy'])))
        print(np.std(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==5)&(df_dataset['Num_labeled']<thresh)]['Accuracy'])))
        print(label + ': high labels')
        print('Rank 1 mean and std')
        print(np.mean(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==1)&(df_dataset['Num_labeled']>=thresh)]['Accuracy'])))
        print(np.std(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==1)&(df_dataset['Num_labeled']>=thresh)]['Accuracy'])))
        print('Rank 5 mean and std')
        print(np.mean(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==5)&(df_dataset['Num_labeled']>=thresh)]['Accuracy'])))
        print(np.std(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==5)&(df_dataset['Num_labeled']>=thresh)]['Accuracy'])))
        print(label + ': laverage')
        print('Rank 1 mean and std')
        print(np.mean(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==1)]['Accuracy'])))
        print(np.std(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==1)]['Accuracy'])))
        print('Rank 5 mean and std')
        print(np.mean(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==5)]['Accuracy'])))
        print(np.std(np.asarray(df_dataset[(df_dataset['Label']==label)&(df_dataset['Rank']==5)]['Accuracy'])))
    
    fig, axs = plt.subplots(2,1, sharex=True)

    palette = sns.color_palette('colorblind')
    color1 = palette[3]
    color2 = palette[2]
    color3 = palette[0]
    color4 = palette[8]
    color5 = palette[4]
    color6 = palette[6]

    sns.violinplot(ax = axs[0], data=df_dataset[df_dataset['Rank']==1], x='Label', y='Accuracy', hue='Model', cut=0, inner='point', density_norm='width', inner_kws = {})
    sns.violinplot(ax = axs[1], data=df_dataset[df_dataset['Rank']==5], x='Label', y='Accuracy', hue='Model', cut=0, inner='point', density_norm='width', inner_kws = {})

    df_rank = df_dataset[df_dataset['Rank']==1]
    df_rank5 = df_dataset[df_dataset['Rank']==5]

    stats = df_rank.groupby('Label')['Accuracy'].agg(['mean']).reset_index()
    stats5 = df_rank5.groupby('Label')['Accuracy'].agg(['mean']).reset_index()

    for i, row in stats.iterrows():
        cat_index = np.where(df_rank['Label'].unique() == row['Label'])[0][0]
        axs[0].axhline(y=row['mean'], color='red', linestyle='-', linewidth=3, xmin=cat_index/len(df_rank['Label'].unique()), xmax=(cat_index+1)/len(df_rank['Label'].unique()))

    for i, row in stats5.iterrows():
        cat_index = np.where(df_rank5['Label'].unique() == row['Label'])[0][0]
        axs[1].axhline(y=row['mean'], color='red', linestyle='-', linewidth=3, xmin=cat_index/len(df_rank5['Label'].unique()), xmax=(cat_index+1)/len(df_rank5['Label'].unique()))

    axs[0].set_ylim((0,1))
    axs[0].set(xlabel=None)
    axs[0].set_title('Model performance: top ranked assignment')
    axs[0].set_yticks([0,0.25,0.5,0.75,1.0])

    axs[1].set_ylim((0,1))
    axs[1].set(xlabel=None)
    axs[1].set_title('Model performance: top 5 assignments')
    axs[1].set_yticks([0,0.25,0.5,0.75,1.0])
    axs[1].set_xticklabels(['', 'Base', 'Retrained', 'Base', 'Retrained'])

    axs[0].spines[['right', 'top']].set_visible (False)
    axs[0].axhline(1.0, ls='--', c='grey')
    axs[0].axhline(0.75, ls='--', c='grey')
    axs[0].axhline(0.5, ls='--', c='grey')
    axs[0].axhline(0.25, ls='--', c='grey')

    axs[1].spines[['right', 'top']].set_visible (False)
    axs[1].axhline(1.0, ls='--', c='grey')
    axs[1].axhline(0.75, ls='--', c='grey')
    axs[1].axhline(0.5, ls='--', c='grey')
    axs[1].axhline(0.25, ls='--', c='grey')

    plt.show()

    return df_dataset

df_dataset = gen_plots_benchmark(accs_full_CPD, accs_NP_unmatch, accs_full, accs_NP_unmatch_CRF, accs_full_CRF, accs_pretrain_fDNC, dataset_dict, skipfiles)



[ 33.    53.25  69.   155.5  184.  ]
CPD (best template): low labels
Rank 1 mean and std
0.3850112050363181
0.1535105216395126
Rank 5 mean and std
0.6464095162526102
0.14108426302043475
CPD (best template): high labels
Rank 1 mean and std
0.3883926904121189
0.08531298227392174
Rank 5 mean and std
0.6681316874276109
0.09713432191677596
CPD (best template): laverage
Rank 1 mean and std
0.3860544292480013
0.1361732452938966
Rank 5 mean and std
0.6531110371470252
0.12951995549746173
StatAtlas (base): low labels
Rank 1 mean and std
0.40485059579775506
0.1173072176469839
Rank 5 mean and std
0.6353057706543227
0.11803104883514566
StatAtlas (base): high labels
Rank 1 mean and std
0.4100999598332594
0.07654098783333474
Rank 5 mean and std
0.7173536710308535
0.05494133851038358
StatAtlas (base): laverage
Rank 1 mean and std
0.4064700804470064
0.10643715852570117
Rank 5 mean and std
0.6606184207704863
0.10954788897962951
StatAtlas (retrain): low labels
Rank 1 mean and std
0.6602282675580602
0.140

  axs[1].set_xticklabels(['', 'Base', 'Retrained', 'Base', 'Retrained'])


In [20]:
def get_acc_table(df_dataset):
    acc_table = pd.DataFrame(columns=['Worm', 'Dataset', 'CPD', 'fDNC', 'StatAtlas (base)', 'StatAtlas (retrain)', 'CRF (base)', 'CRF (retrain)'])
    for file in df_dataset['Filename'].unique():
        df_file = df_dataset[(df_dataset['Filename']==file)&(df_dataset['Rank']==1)]
        dataset= df_file.iloc[0]['Dataset']
        CPD = df_file[df_file['Label']=='CPD (best template)'].iloc[0]['Accuracy']
        Stat_base = df_file[df_file['Label']=='StatAtlas (base)'].iloc[0]['Accuracy']
        Stat_full = df_file[df_file['Label']=='StatAtlas (retrain)'].iloc[0]['Accuracy']
        CRF_base = df_file[df_file['Label']=='CRF (base)'].iloc[0]['Accuracy']
        CRF_full = df_file[df_file['Label']=='CRF (retrain)'].iloc[0]['Accuracy']
        fDNC_base = df_file[df_file['Label']=='fDNC (base)'].iloc[0]['Accuracy']

        acc_table.loc[len(acc_table.index)] = [file, dataset, CPD, fDNC_base, Stat_base, Stat_full, CRF_base, CRF_full]

    return acc_table

acc_table = get_acc_table(df_dataset)


In [21]:
acc_table = acc_table.sort_values(['Dataset','Worm'])
acc_table.to_csv('Data/summary_acc_ID.csv')

## Number of labels 

In [22]:
fig, axs = plt.subplots()

sns.histplot(data= df_dataset[(df_dataset['Rank']==1) & (df_dataset['Label']=='CPD (best template)')], x='Num_labeled', bins=16)
axs.set_xlabel('Number of ground truth labels')
axs.spines[['right', 'top']].set_visible (False)

plt.show()


In [23]:
num_labels = np.asarray(df_dataset[(df_dataset['Rank']==1) & (df_dataset['Label']=='CPD (best template)')]['Num_labeled'])

print(sum(num_labels>=100))
print(sum(num_labels<100))

29
65


## Confusion matrix of neural identities

In [None]:
'''
Get accuracy values for each dataset using the trained atlas and the roughly aligned point clouds. If you would like
to test on datasets that have not been pre-aligned, please use the neuroPAL_ID software which has the atlas and alignment
code pre-compiled
'''

NP_atlas_match = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_NPonly.pkl', ganglia='Data/neuron_ganglia.csv') #Atlas trained on just original 10 NeuroPAL datasets
NP_atlas_unmatch = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_NPunmatch.pkl', ganglia='Data/neuron_ganglia.csv')

accs_NP = get_accuracies('Data/aligned_heads/aligned_NP', NP_atlas_match)
accs_NP_unmatch = get_accuracies('Data/aligned_heads/aligned_NP_nomatch', NP_atlas_unmatch)

In [32]:
neur_df = NP_atlas_match.df
neurons = np.asarray(neur_df['ID'])

In [35]:
print(neurons)
print(np.argwhere(neurons=='I1L')[0][0])

['M3R' 'M3L' 'M4' 'NSMR' 'NSML' 'I2R' 'MCL' 'I3' 'MCR' 'I2L' 'MI' 'I1L'
 'I1R' 'RMEV' 'CEPVR' 'CEPVL' 'BAGR' 'BAGL' 'OLQVR' 'OLQVL' 'RMER' 'URAVR'
 'RMEL' 'URAVL' 'RIPR' 'RMED' 'RIPL' 'OLQDR' 'URYVR' 'URBR' 'URBL' 'URYVL'
 'OLQDL' 'IL1VR' 'URYDR' 'IL1VL' 'IL2VR' 'URYDL' 'OLLR' 'IL1R' 'IL1DR'
 'IL1L' 'IL2R' 'OLLL' 'IL2L' 'IL1DL' 'URADR' 'IL2VL' 'URADL' 'IL2DR'
 'IL2DL' 'URXR' 'URXL' 'CEPDR' 'ALA' 'CEPDL' 'RID' 'RMGL' 'RMGR' 'ADAL'
 'ADAR' 'AQR' 'ADEL' 'ADER' 'FLPL' 'FLPR' 'RICL' 'RICR' 'AIZL' 'AIZR'
 'AINL' 'ASJL' 'ASJR' 'RIMR' 'RIML' 'AINR' 'AVDL' 'AVDR' 'AVJL' 'AVBL'
 'AVJR' 'AUAL' 'RIBL' 'AUAR' 'RIBR' 'AVBR' 'ASEL' 'ASIL' 'RIVR' 'ASIR'
 'RIVL' 'ASER' 'AVHR' 'AVHL' 'ASHL' 'ASGR' 'ASGL' 'ASHR' 'AWCR' 'AWCL'
 'AIBR' 'AWAL' 'AIBL' 'SIBDR' 'AWAR' 'AWBL' 'ADLL' 'AWBR' 'SIBDL' 'ADLR'
 'RMDR' 'ADFR' 'AFDL' 'ADFL' 'AVER' 'AFDR' 'RMDL' 'AVEL' 'ASKL' 'ASKR'
 'AVAR' 'AVAL' 'RIAR' 'RIAL' 'SMDVR' 'SMDVL' 'RMDVR' 'SAAVR' 'SAAVL'
 'RMDVL' 'RIS' 'AVKL' 'AIML' 'AVKR' 'AIYL' 'AVL' 'AIMR' 'SIAVR' 'SMBVR

In [25]:
NP_atlas_match.neurons

['URADR',
 'OLQDL',
 'M3R',
 'URYDR',
 'AFDR',
 'I1R',
 'SAADL',
 'AIBL',
 'ASGR',
 'SABVL',
 'RMEV',
 'RIVR',
 'ASER',
 'I5',
 'SMBVR',
 'IL2DL',
 'RMDDR',
 'SMDDR',
 'RIBL',
 'RMGR',
 'ASIL',
 'ADAR',
 'M2L',
 'M2R',
 'URYVL',
 'M1',
 'SAAVR',
 'RMDVL',
 'RMDDL',
 'IL2L',
 'SMDVR',
 'SIBDL',
 'RIFR',
 'SAAVL',
 'AVBR',
 'AVKR',
 'NSMR',
 'RIAL',
 'ADLL',
 'OLQDR',
 'MCL',
 'AVDL',
 'RMED',
 'ASHL',
 'RIAR',
 'AVFR',
 'AVG',
 'AVDR',
 'ALA',
 'AS1',
 'AIYL',
 'BAGR',
 'ASHR',
 'SIAVR',
 'BAGL',
 'M4',
 'VB1',
 'RICR',
 'RID',
 'IL1DR',
 'RMDL',
 'IL2R',
 'RMHR',
 'AWCR',
 'AVER',
 'AWAR',
 'ASKL',
 'AIZL',
 'ASJL',
 'DB1',
 'RMHL',
 'URADL',
 'SIBVL',
 'RMER',
 'RIPL',
 'SMBDL',
 'SIAVL',
 'ASGL',
 'AIZR',
 'I2L',
 'AVAL',
 'AVEL',
 'AUAR',
 'RIGR',
 'IL1VL',
 'AINR',
 'M5',
 'IL2VR',
 'ASJR',
 'SAADR',
 'AIMR',
 'ADLR',
 'I6',
 'AVL',
 'RICL',
 'RMDVR',
 'SMDVL',
 'VA1',
 'CEPDL',
 'SMBDR',
 'AVJR',
 'SABD',
 'OLLR',
 'I2R',
 'ASKR',
 'URYVR',
 'AIAL',
 'CEPVR',
 'OLQVR',
 'HMC',
 'R

In [50]:
x = np.ones(100)
y = np.arange(100)

print(np.expand_dims(y,axis=1)@np.expand_dims(x,axis=0))

[[ 0.  0.  0. ...  0.  0.  0.]
 [ 1.  1.  1. ...  1.  1.  1.]
 [ 2.  2.  2. ...  2.  2.  2.]
 ...
 [97. 97. 97. ... 97. 97. 97.]
 [98. 98. 98. ... 98. 98. 98.]
 [99. 99. 99. ... 99. 99. 99.]]


In [53]:
def get_confusion_matrix(folder, atlas):
    acc_df = pd.DataFrame(columns=['Total_neurons','Percent_IDd', 'Percent_top1', 'Percent_top2', 'Percent_top3', 'Percent_top4', 'Percent_top5', 'Filename'])
    neurons = np.asarray(atlas.df['ID'])

    confusion_matrix = np.zeros((len(neurons), len(neurons)))
    tot_gt = np.zeros(len(neurons))

    for file in os.listdir(folder):
        if not file[-4:] == '.csv':
            continue

        df_data = pd.read_csv(folder + '/'+file)
        df_data = df_data.rename(columns={"aligned_x":"X","aligned_y":"Y","aligned_z":"Z", "aligned_R":"R", "aligned_G":"G", "aligned_B":"B"})

        df, costs = run_linear_assignment(df_data, atlas)

        IDd = df[~df['ID'].isnull()]

        per_ID = len(IDd.index)/len(df.index)

        total_neurons = len(df.index)

        for k, row in IDd.iterrows():
            gt = row['ID']
            assign = row['assign_1']

            if not gt in neurons:
                continue
            
            i = np.argwhere(neurons==gt)[0][0]
            j = np.argwhere(neurons==assign)[0][0]
            
            confusion_matrix[i,j] +=1
            tot_gt[i] +=1

    tot_gt_matrix =  np.expand_dims(tot_gt, axis=1) @ np.expand_dims(np.ones(len(neurons)), axis=0)
    print(tot_gt_matrix.shape)

    confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)

    return confusion_matrix_scaled, neurons



(194, 194)


In [98]:
def get_confusion_split(folder, atlas_folder, neurons):
    acc_df = pd.DataFrame(columns=['Total_neurons','Percent_IDd', 'Percent_top1', 'Percent_top2', 'Percent_top3', 'Percent_top4', 'Percent_top5', 'Filename'])
    confusion_matrix = np.zeros((len(neurons), len(neurons)))
    tot_gt = np.zeros(len(neurons))

    for group in range(5):
        atlas = NWBAtlas(f'{atlas_folder}/exgroup{group}.pkl', ganglia='Data/neuron_ganglia.csv')

        for file in os.listdir(f'{folder}/group{group+1}'):
            if not file[-4:] == '.csv':
                continue
            
            df_data = pd.read_csv(f'{folder}/group{group+1}/{file}')
            df_data = df_data.rename(columns={"aligned_x":"X","aligned_y":"Y","aligned_z":"Z", "aligned_R":"R", "aligned_G":"G", "aligned_B":"B"})

            df, costs = run_linear_assignment(df_data, atlas)

            IDd = df[~df['ID'].isnull()]

            per_ID = len(IDd.index)/len(df.index)

            total_neurons = len(df.index)

            for k, row in IDd.iterrows():
                gt = row['ID']
                assign = row['assign_1']

                if not gt in neurons:
                    continue
                
                i = np.argwhere(neurons==gt)[0][0]
                j = np.argwhere(neurons==assign)[0][0]
                
                confusion_matrix[i,j] +=1
                tot_gt[i] +=1

    tot_gt_matrix =  np.expand_dims(tot_gt, axis=1) @ np.expand_dims(np.ones(len(neurons)), axis=0)
    print(tot_gt_matrix.shape)

    confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)

    return confusion_matrix_scaled, neurons



In [128]:
atlas = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_NPonly.pkl', ganglia='Data/neuron_ganglia.csv')
confusion, neurons = get_confusion_matrix('Data/aligned_heads/aligned_NP', atlas)

correct_stat_base = np.nan_to_num(np.diagonal(confusion))

correct_matrix = np.multiply(confusion, np.eye(confusion.shape[0]))
confusion_Stat_base = confusion - correct_matrix

confusion, neurons = get_confusion_split('Data/aligned_heads/aligned_split', 'Data/atlases/2024_03_11_split', neurons)

correct_matrix = np.multiply(confusion, np.eye(confusion.shape[0]))
confusion_Stat_retrain = confusion - correct_matrix

correct_stat_retrain = np.nan_to_num(np.diagonal(confusion))

confusion, neurons = get_confusion_from_predict('Data/predict_CRF_original', neurons)

correct_matrix = np.multiply(confusion, np.eye(confusion.shape[0]))
confusion_CRF_base = confusion - correct_matrix

correct_CRF_base = np.nan_to_num(np.diagonal(confusion))

confusion, neurons = get_confusion_from_predict('Data/predict_CRF_full', neurons)

correct_CRF_retrain = np.nan_to_num(np.diagonal(confusion))

correct_matrix = np.multiply(confusion, np.eye(confusion.shape[0]))
confusion_CRF_retrain = confusion - correct_matrix

(194, 194)
(194, 194)


  confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)


(194, 194)


  confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)


(194, 194)


  confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)


In [142]:
atlas = NWBAtlas(atlas_file = 'Data/atlases/2024_03_11_NPonly.pkl', ganglia='Data/neuron_ganglia.csv')
confusion_stat_base, neurons = get_confusion_matrix('Data/aligned_heads/aligned_NP', atlas)

confusion_stat_base = np.nan_to_num(confusion_stat_base)

confusion_stat_retrain, neurons = get_confusion_split('Data/aligned_heads/aligned_split', 'Data/atlases/2024_03_11_split', neurons)

confusion_stat_retrain = np.nan_to_num(confusion_stat_retrain)

confusion_CRF_base, neurons = get_confusion_from_predict('Data/predict_CRF_original', neurons)

confusion_CRF_base = np.nan_to_num(confusion_CRF_base)

confusion_CRF_retrain, neurons = get_confusion_from_predict('Data/predict_CRF_full', neurons)

confusion_CRF_retrain = np.nan_to_num(confusion_CRF_retrain)


(194, 194)
(194, 194)


  confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)


(194, 194)


  confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)


(194, 194)


  confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)


In [178]:
plt.rcParams.update({'font.size':1})
fig, axs = plt.subplots()
sns.heatmap(data=confusion_stat_base, cmap = 'Reds', vmin=0, vmax=1, norm=LogNorm())
axs.set_xticks(np.arange(len(neurons)), neurons, fontsize=5)
axs.set_yticks(np.arange(len(neurons)), neurons, fontsize=5)

for tick in axs.xaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)

for tick in axs.yaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)

plt.show()

2024-06-12 21:34:16.059 python[81280:7884534] +[CATransaction synchronize] called within transaction
  el.exec() if hasattr(el, 'exec') else el.exec_()
2024-06-12 21:34:22.286 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:34:27.041 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:34:28.374 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:34:28.636 python[81280:7884534] +[CATransaction synchronize] called within transaction


In [179]:
plt.rcParams.update({'font.size':5})
fig, axs = plt.subplots()
sns.heatmap(data=confusion_stat_retrain, cmap = 'Greens', vmin=0, vmax=1, norm=LogNorm())
axs.set_xticks(np.arange(len(neurons)), neurons, fontsize=5)
axs.set_yticks(np.arange(len(neurons)), neurons, fontsize=5)

for tick in axs.xaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)

for tick in axs.yaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)


plt.show()

2024-06-12 21:35:14.930 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:35:21.040 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:35:22.330 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:35:22.593 python[81280:7884534] +[CATransaction synchronize] called within transaction


In [180]:
plt.rcParams.update({'font.size':5})
fig, axs = plt.subplots()
sns.heatmap(data=confusion_CRF_base, cmap = 'Blues', vmin=0, vmax=1, norm=LogNorm())
axs.set_xticks(np.arange(len(neurons)), neurons, fontsize=5)
axs.set_yticks(np.arange(len(neurons)), neurons, fontsize=5)

for tick in axs.xaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)

for tick in axs.yaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)


plt.show()

2024-06-12 21:35:49.808 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:36:03.788 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:36:06.055 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:36:06.318 python[81280:7884534] +[CATransaction synchronize] called within transaction


In [181]:
plt.rcParams.update({'font.size':5})
fig, axs = plt.subplots()
sns.heatmap(data=confusion_CRF_retrain, cmap = 'Oranges', vmin=0, vmax=1, norm=LogNorm())
axs.set_xticks(np.arange(len(neurons)), neurons, fontsize=5)
axs.set_yticks(np.arange(len(neurons)), neurons, fontsize=5)

for tick in axs.xaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)

for tick in axs.yaxis.get_major_ticks()[1::2]:
    tick.set_pad(20)

plt.show()

2024-06-12 21:36:37.451 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:36:43.072 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:36:44.586 python[81280:7884534] +[CATransaction synchronize] called within transaction
2024-06-12 21:36:44.848 python[81280:7884534] +[CATransaction synchronize] called within transaction


In [108]:
def get_confusion_from_predict(folder, neurons):
    acc_df = pd.DataFrame(columns=['Total_neurons','Percent_IDd', 'Percent_top1', 'Percent_top2', 'Percent_top3', 'Percent_top4', 'Percent_top5', 'Filename'])

    confusion_matrix = np.zeros((len(neurons), len(neurons)))
    tot_gt = np.zeros(len(neurons))

    for file in os.listdir(folder):
        if not file[-4:] == '.csv':
            continue

        df = pd.read_csv(folder + '/'+file)

        IDd = df[~df['GT'].isnull()]

        per_ID = len(IDd.index)/len(df.index)

        total_neurons = len(df.index)

        for k, row in IDd.iterrows():
            gt = row['GT']
            assign = row['Top1']

            if not gt in neurons:
                continue
            
            i = np.argwhere(neurons==gt)[0][0]
            j = np.argwhere(neurons==assign)[0][0]
            
            confusion_matrix[i,j] +=1
            tot_gt[i] +=1

    tot_gt_matrix =  np.expand_dims(tot_gt, axis=1) @ np.expand_dims(np.ones(len(neurons)), axis=0)
    print(tot_gt_matrix.shape)

    confusion_matrix_scaled = np.divide(confusion_matrix, tot_gt_matrix)

    return confusion_matrix_scaled, neurons



## Visualize incorrect predictions

In [148]:
correct_stat_base = np.diagonal(confusion_stat_base)
correct_stat_retrain = np.diagonal(confusion_stat_retrain)
correct_CRF_base = np.diagonal(confusion_CRF_base)
correct_CRF_retrain = np.diagonal(confusion_CRF_retrain)

In [173]:
print(confusion_stat_retrain[15,15])

0.6527777777777778


In [174]:
print(np.argwhere(confusion_stat_retrain[15,:]!=0))

[[ 14]
 [ 15]
 [ 35]
 [ 55]
 [116]
 [148]
 [151]
 [153]
 [155]
 [157]
 [159]
 [161]]


In [175]:
print(neurons[np.argwhere(confusion_stat_retrain[15,:]!=0)])

[['CEPVR']
 ['CEPVL']
 ['IL1VL']
 ['CEPDL']
 ['RMDL']
 ['SIBVL']
 ['SIADL']
 ['RMDDL']
 ['RIR']
 ['SIBVR']
 ['SMDDL']
 ['RIH']]


In [171]:
print(np.argwhere(neurons=='CEPVL'))

[[15]]


In [172]:
print(correct_stat_retrain[15])

0.6527777777777778
