In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

import isx

In [None]:
def twoaxis(axname, lwidth=1.5):
    """Give a matplotlib axes an open, L-shaped (two-axis) frame.

    The bottom and left spines and the tick marks are drawn with width
    ``lwidth``; the top and right spines get a line width of 0 so they
    no longer show.

    Parameters
    ----------
    axname : matplotlib Axes
        Axes to format in place.
    lwidth : float, optional
        Line width for the kept spines and for the ticks (default 1.5).
    """
    kept_spines = ('bottom', 'left')
    hidden_spines = ('top', 'right')

    axname.tick_params(width=lwidth)
    for side in kept_spines:
        axname.spines[side].set_linewidth(lwidth)
    for side in hidden_spines:
        axname.spines[side].set_linewidth(0)

### Specify flags and auto-curation filters:

In [None]:
# --- pipeline switches ---
auto_curate = True     # False -> keep each cellset's existing accept/reject curation
detect_events = True   # False -> reuse an existing event .isxd file when one is present
export_metrics = True  # False -> do not write a curated-metrics .csv per cellset

# --- auto-curation filters (cells must satisfy all three) ---
max_comps = 1.0       # number of spatial components must equal this
min_snr = 5.0         # signal-to-noise ratio must exceed this
min_event_rate = 0.0  # event rate must exceed this

# --- event detection ---
event_thresh = 3.5  # SNR threshold above local MAD for event detection

### Metrics to save:

In [None]:
# Columns copied from each cellset's metrics table into the curated output;
# 'cellName' keeps every exported row identifiable.
stats_to_do = [
    'snr',
    'eventRate',
    'decayMedian',
    'overallAreaInPixels',
    'cellName',
]


### Load a cellset or all cellsets in a directory:

In [None]:
# Cellsets to analyze: list full paths here, or leave the list empty ([] or [''])
# to pick up every '*ICA.isxd' file in data_dir.
cellset_list = ['']
data_dir = '/ariel/science/mmiller/data/longo_lab/'

if not any(cellset_list):
    print('getting cellsets from directory:\n')
    # sorted() gives a reproducible processing order (os.listdir order is arbitrary);
    # os.path.join works whether or not data_dir ends with a separator.
    cellset_list = sorted(os.path.join(data_dir, fname)
                          for fname in os.listdir(data_dir)
                          if fname.endswith('ICA.isxd'))

print(cellset_list)

In [None]:
# For each cellset: (re)detect events, optionally auto-curate, compute cell metrics,
# then keep and summarize the metrics of accepted cells only.
# NOTE: stats_to_do is redefined (without 'cellName') by the plotting cells further
# down; re-run the "Metrics to save" cell before re-running this one.
for the_cellset in cellset_list:
    print(the_cellset)
    # per-file output directory; named out_dir so the data_dir config variable is not clobbered
    out_dir = os.path.dirname(the_cellset)

    # Event detection: always when detect_events is set, otherwise only if no events file exists.
    events_fn = isx.make_output_file_path(the_cellset, out_dir, 'ED')
    if detect_events or not os.path.isfile(events_fn):
        print('Detecting events with {} SNR threshold...'.format(event_thresh))
        if os.path.isfile(events_fn):
            os.remove(events_fn)  # remove the old output before regenerating it
        isx.event_detection(the_cellset, events_fn, threshold=event_thresh)
    if auto_curate:
        print('Auto curating cellset...')
        isx.auto_accept_reject(the_cellset, events_fn,
                               filters=[('# Comps', '=', max_comps),
                                        ('SNR', '>', min_snr),
                                        ('Event Rate', '>', min_event_rate)])

    # Compile metrics (regenerated from scratch on every run):
    mets_fn = isx.make_output_file_path(the_cellset, out_dir, 'metrics')
    if os.path.isfile(mets_fn):
        os.remove(mets_fn)
    print('Compiling cellset metrics...')
    isx.cell_metrics(the_cellset, events_fn, mets_fn)

    # Names of cells whose status is 'accepted' (read after any auto-curation above):
    cellset = isx.CellSet.read(the_cellset)
    accepted_cells = [cellset.get_cell_name(cell_idx)
                      for cell_idx in range(cellset.num_cells)
                      if cellset.get_cell_status(cell_idx) == 'accepted']

    # Keep the requested metric columns for accepted cells only (no in-place ops on a slice):
    df_mets = pd.read_csv(mets_fn)
    df_out = (df_mets.loc[df_mets.cellName.isin(accepted_cells), stats_to_do]
              .reset_index(drop=True))
    if export_metrics:
        curated_fn = isx.make_output_file_path(the_cellset, out_dir, 'curated_metrics', ext='csv')
        print('Saving metrics .csv file...')
        df_out.to_csv(curated_fn)
    # nanmedian so a few cells with missing values do not turn the summary into NaN
    print('\t{} accepted cells\n\tMedian SNR: {}\n\tMedian event rate: {}\n\tMedian Decay (s): {}'
          .format(len(accepted_cells),
                  round(np.nanmedian(df_out.snr), 2),
                  round(np.nanmedian(df_out.eventRate), 5),
                  round(np.nanmedian(df_out.decayMedian), 3)))
    print('\n')
    

***
### Plot histograms of one cellset's curated metrics:

In [None]:
stats_to_do = ['snr', 'eventRate', 'decayMedian', 'overallAreaInPixels']
bins = 25     # number of histogram bins per metric
plotsize = 4  # panel size in inches (height; each panel is drawn 1.5x as wide)

# Plotting range for each metric, as (low, high):
stat_range = dict(snr=(2, 20),
                  eventRate=(0, 0.25),
                  decayMedian=(0, 10),
                  overallAreaInPixels=(25, 300),
                  numCells=(0, 500))

# Human-readable axis label for each metric:
stat_labels = dict(snr='SNR',
                   eventRate='Ca events per second',
                   decayMedian='event decay (s)',
                   overallAreaInPixels='cell area (pixels)',
                   numCells='number of cells')

In [None]:
# The single cellset whose curated metrics .csv (written by the processing loop
# when export_metrics is True) is plotted below.
cellset_list = ['/ariel/science/mmiller/data/longo_lab/2020-04-03-16-51-53_video-PP-BP-MC-DFF-PCA-ICA.isxd']
metrics_list = [
    isx.make_output_file_path(cs_path, os.path.dirname(cs_path), 'curated_metrics', ext = 'csv')
    for cs_path in cellset_list
]
print(metrics_list)
# index_col=0 consumes the unnamed index column that DataFrame.to_csv wrote
df_mets = pd.read_csv(metrics_list[0], index_col=0)

In [None]:
cmap = plt.get_cmap('tab10')

n_stats = len(stats_to_do)
# squeeze=False keeps `ax` a 2-D array even for a single metric, so the loop always works
f, ax = plt.subplots(1, n_stats, figsize=((plotsize*1.5)*n_stats, plotsize), squeeze=False)

for a, stat in zip(ax.ravel(), stats_to_do):  # one panel per metric
    # drop NaNs: np.histogram cannot auto-range NaN data and np.median would return NaN
    vals = df_mets[stat].dropna()
    med = np.median(vals)

    h, hb = np.histogram(vals, bins=bins)
    # prepend a 0 count so the step outline starts at the baseline on the first bin edge
    a.step(hb, np.insert(h, 0, 0), color=cmap(0))
    a.vlines(med, ymin=0, ymax=max(h), color='crimson', alpha=.5)  # median marker
    a.grid(True, alpha=.5)
    a.set_xlabel(stat_labels[stat], fontsize=14)
    a.set_ylabel('count', fontsize=14)
    a.set_title('median = {}'.format(round(med, 4)), loc='center')
    a.set_xlim(0, max(hb))
    twoaxis(a)

plt.show()

***
### Violin plots of each metric for all cellsets in a directory (one figure per metric, experiments side by side):

In [None]:
# 'numCells' is not a metrics-table column; the violin cell computes it as the accepted-cell count.
stats_to_do = ['snr', 'eventRate', 'decayMedian', 'overallAreaInPixels', 'numCells']

# y-axis range for each metric, as (low, high):
stat_range = dict(snr=(2, 20),
                  eventRate=(0, 0.25),
                  decayMedian=(0, 10),
                  overallAreaInPixels=(25, 300),
                  numCells=(0, 500))

# Human-readable y-axis label for each metric:
stat_labels = dict(snr='SNR',
                   eventRate='Ca events per second',
                   decayMedian='event decay (s)',
                   overallAreaInPixels='cell area (pixels)',
                   numCells='number of cells')

In [None]:
# Cellsets to compare: list full paths here, or leave the list empty ([] or [''])
# to pick up every '*ICA.isxd' file in data_dir.
cellset_list = ['']
data_dir = '/ariel/science/mmiller/data/longo_lab/'

if not any(cellset_list):
    print('getting cellsets from directory:\n')
    # sorted() gives a reproducible order (os.listdir order is arbitrary);
    # os.path.join works whether or not data_dir ends with a separator.
    cellset_list = sorted(os.path.join(data_dir, fname)
                          for fname in os.listdir(data_dir)
                          if fname.endswith('ICA.isxd'))

# Metrics files named exactly as in the processing loop ('metrics' suffix, default extension).
metrics_list = [isx.make_output_file_path(cs_path, os.path.dirname(cs_path), 'metrics')
                for cs_path in cellset_list]
# Experiment label = first 10 characters of the file name (the date for names like
# '2020-04-03-16-51-53_...'); recordings from the same day therefore share a label.
datename_list = [os.path.basename(met_fn)[:10] for met_fn in metrics_list]
print(metrics_list, '\n')
print(datename_list)

In [None]:
xjog = 1  # NOTE(review): not used in this cell
cmap = plt.get_cmap('tab10')

# Load every experiment once (not once per metric): accepted-cell rows and accepted count.
# Sorting (cellset, metrics, label) triples together keeps the three lists aligned.
experiments = []
for cellset_fn, met_fn, datename in sorted(zip(cellset_list, metrics_list, datename_list)):
    df = pd.read_csv(met_fn)
    cellset = isx.CellSet.read(cellset_fn)
    accepted_cells = [cellset.get_cell_name(cell_idx)
                      for cell_idx in range(cellset.num_cells)
                      if cellset.get_cell_status(cell_idx) == 'accepted']
    experiments.append((datename, df.loc[df.cellName.isin(accepted_cells)], len(accepted_cells)))

# stat_dict[datename][stat] -> accepted-cell values (array), or the accepted-cell count for 'numCells'.
# NOTE: experiments sharing a datename share one entry, so the later file's values win.
stat_dict = {}

for stat in stats_to_do:  # one figure per metric
    f, ax = plt.subplots(1, 1, figsize=(10, 4))
    for x_val, (datename, df_acc, n_accepted) in enumerate(experiments):
        # setdefault (not `= {}`) so values stored for earlier metrics are kept
        entry = stat_dict.setdefault(datename, {})

        if stat == 'numCells':
            entry[stat] = n_accepted
            ax.scatter(x_val, n_accepted, s=2000, marker='_', linewidth=4, color=cmap(0), alpha=1)
            continue

        entry[stat] = df_acc[stat].values
        plot_vals = df_acc[stat].dropna().values  # the KDE and the median need finite data
        # the violin's KDE needs at least two points; the median marker needs at least one
        if plot_vals.size > 1:
            parts = ax.violinplot(plot_vals, positions=[x_val], showextrema=False)
            for pc in parts['bodies']:
                pc.set_facecolor(cmap(0))
                pc.set_edgecolor(cmap(0))
                pc.set_alpha(.5)
        if plot_vals.size > 0:
            ax.scatter(x_val, np.median(plot_vals), s=2000, marker='_', linewidth=4, color=cmap(0), alpha=1)

    ax.set_ylim(stat_range[stat][0], stat_range[stat][1])
    ax.set_ylabel(stat_labels[stat], fontsize=14)
    # one tick per experiment, labelled in plotting order
    ax.set_xticks(np.arange(len(experiments)))
    ax.set_xticklabels([datename for datename, _, _ in experiments], fontsize=12)
    ax.set_xlabel('experiment', fontsize=14)
    ax.grid(True, alpha=.5)
    twoaxis(ax)
    plt.show()
    