### This notebook contains basic scripts for organizing Astellas data into hierarchical dicts
***

In [None]:
import os
import time
import sys
import copy
import bisect
import pickle

import cv2

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from skimage import transform

from scipy import stats
from scipy.spatial import distance
import scipy.signal as signal
import scipy.interpolate as interpolate
from scipy import ndimage as nd

from sklearn import metrics
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import TheilSenRegressor
from sklearn.neighbors import DistanceMetric

from skimage import transform
from tqdm.notebook import tqdm

import isx

from astellaslib import *

***
### Populate dictionary with synced neural and COM data. Each experiment's stat_dict is primary data object.
#### These should be saved as .pkl files and loaded/extended by subsequent analyses rather than created from scratch

In [None]:
# Experiment names, genotype groups and date strings used throughout the notebook.
experiments = [
    'linear_social_saline',
    'linear_social_1.5mgkg_rBaclofen',
]
genotypes = ['FMR1CTRL', 'FMR1KO']
dates = [
    '20200401',
    '20200402',
    '20200408',
    '20200415',
]

# The genotype list doubles as the list of conditions (same list object).
conditions = genotypes

In [None]:
# Initialize stat_dict as stat_dict[genotype][experiment][subject][stat].
# Every stat starts as an empty list and is replaced by an array in the loop below.
stats_to_do = ['trace_mtx', 'raster_mtx', 'COM']
stat_dict = dict()
for c in conditions:
    stat_dict[c] = {}
    for e in experiments:
        stat_dict[c][e] = {}
        subjs = df_sub.subject_ID.loc[(df_sub.condition == e) & (df_sub.genotype == c)].values
        for s in subjs:
            stat_dict[c][e][s] = {st: [] for st in stats_to_do}
print(stat_dict)


def _frame_align(z_mtx, raster_mtx, start_samples, frame_win):
    """Reduce sample-rate matrices to one row per behavior-video frame.

    Parameters
    ----------
    z_mtx : 2-D array indexed as [sample, cell]
        Trace values; reduced with np.nanmean over each window.
    raster_mtx : 2-D array indexed as [sample, cell]
        Event raster; reduced with np.nanmax over each window.
    start_samples : sequence of int
        First sample index of the window belonging to each video frame.
    frame_win : int
        Window length in samples.

    Returns
    -------
    (z_out, r_out) : float32 array of shape (n_frames, z_mtx.shape[1]) and
        int16 array of shape (n_frames, raster_mtx.shape[1]).
    """
    n_frames = len(start_samples)
    z_out = np.zeros((n_frames, z_mtx.shape[1]), dtype=np.float32)
    r_out = np.zeros((n_frames, raster_mtx.shape[1]), dtype=np.int16)
    for row, x_start in enumerate(start_samples):
        # All cells are reduced at once; the window is identical for every cell.
        z_out[row, :] = np.nanmean(z_mtx[x_start:(x_start + frame_win), :], axis=0)
        r_out[row, :] = np.nanmax(raster_mtx[x_start:(x_start + frame_win), :], axis=0)
    return z_out, r_out


for the_cond in conditions:
    print('\n\n', the_cond)
    for the_exp in experiments:
        print('\t', the_exp)
        df_cond = df_sub.loc[(df_sub.condition == the_exp) & (df_sub.genotype == the_cond)]
        cellsets_cond = get_cellset_paths(df_cond)  # collect paths to cellsets of interest
        subj_ids = df_cond.subject_ID.unique()
        for subj in subj_ids:  # each subject:
            print('\n', subj)
            df_subj = df_cond.loc[df_cond.subject_ID == subj]

            # Load cellset and gpio csv file, use them to create frame-lookup vector:
            subj_ca_path = fix_data_path(df_subj.data_dir_ca.values[0])
            subj_behav_path = fix_data_path(df_subj.data_dir_behavior.values[0])
            # os.path.join works whether or not subj_ca_path ends with a separator.
            gpio_fn = [os.path.join(subj_ca_path, i) for i in os.listdir(subj_ca_path) if 'gpio.csv' in i][0]
            cellset_fn = [i for i in cellsets_cond if subj in i][0]
            frame_lookup = make_frame_lookup(gpio_fn, cellset_fn)
            cs_time = get_cellset_timescale(cellset_fn)
            cs_fs = round(1 / (cs_time[1]), 8)
            print('Ca imaging sampling rate: {}'.format(cs_fs))

            # Load COM csv file:
            com_fn = behavior_dir + df_subj.behav_data_basename.values[0] + '_COM.csv'
            the_com = np.loadtxt(com_fn, delimiter=',')

            # Flip center-of-mass x-values if test side is on the right.
            # This aligns all data so that test side = left.
            the_test_side = test_side[the_exp][subj]
            print('test side: {}'.format(the_test_side))
            if the_test_side == 'r':
                the_com[:, 1] = frame_width - the_com[:, 1]
            stat_dict[the_cond][the_exp][subj]['COM'] = the_com

            # Build trace_matrix / z-matrix, event-matrix, etc:
            [the_trace_mtx, _, the_isx_raster_mtx] = cellset_data_volumes(cellset_fn, isx_ed_threshold=3, autosort_snr=4)
            # NOTE(review): the matrices are indexed below as [sample, cell], so axis=1
            # z-scores each time point across cells. Per-cell z-scoring over time would
            # be axis=0 — confirm which is intended before changing.
            z_mtx = stats.zscore(the_trace_mtx, axis=1, nan_policy='omit')

            # Make time-base for frame-averaged activity. Each time point references the
            # start of a behavior video frame. The lookup of the first imaging sample of
            # every frame is done once here rather than once per cell.
            frame_ids = np.arange(1, max(frame_lookup))
            first_samples = np.asarray(
                [np.argwhere(frame_lookup == i).flatten()[0] for i in frame_ids], dtype=int)
            tvect = np.asarray([cs_time[k] for k in first_samples])
            # Frames per second = number of video frames / time of the last frame start.
            # (len(frame_lookup) counts imaging samples, not video frames.)
            behav_fs = len(tvect) / max(tvect)
            print('behavior video fps: {}'.format(behav_fs))

            # Extract activity for each video frame and assemble into matrix.
            # The window is 50 ms of imaging samples, but never shorter than one sample:
            # a zero-length window would leave np.nanmax with nothing to reduce.
            frame_win = max(1, int(cs_fs * .05))
            start_samples = [int(cs_fs * t) for t in tvect]
            print('Building frame-aligned data matrix...')
            sub_z_mtx, sub_r_mtx = _frame_align(z_mtx, the_isx_raster_mtx, start_samples, frame_win)

            stat_dict[the_cond][the_exp][subj]['trace_mtx'] = sub_z_mtx
            stat_dict[the_cond][the_exp][subj]['raster_mtx'] = sub_r_mtx

        
   

In [None]:
# Save the dictionary. The context manager closes the file handle even if pickling
# fails, so the data is flushed to disk before anything reads the file back.
fn = 'frame_aligned_data_baclofen.pkl'
with open(fn, 'wb') as pkl_file:
    pickle.dump(stat_dict, pkl_file)

***
### Compute auROC values (event-tuning) for events/annotations of interest from stat_dict, and add fields to stat_dict:

In [None]:
sides = (95, 230)
event_types = ['nose-to-right', ('nose-to-ag', 'nose-to-nose')]  # these should be changed according to analysis of interest
align_pad = 200
phase_dict = {'preference': (6500, 6500+11500)}
n_iter = 500


def _shuffled_auroc(trace_mtx, align_raster, n_iter):
    """Compute auROC and a shuffle-based p-value for every column of trace_mtx.

    Parameters
    ----------
    trace_mtx : 2-D array indexed as [frame, cell]
    align_raster : 1-D array of 0/1 event markers, one entry per frame
    n_iter : int
        Number of circularly shifted traces used for the null distribution.

    Returns
    -------
    (auroc, auroc_p) : two lists with one entry per cell.

    Notes
    -----
    Shifts are drawn from np.random and no seed is set in this cell, so the
    p-values vary between runs unless a seed is set elsewhere.
    """
    auroc = []
    auroc_p = []
    for the_trace in trace_mtx.transpose():
        the_auroc = metrics.roc_auc_score(align_raster, the_trace)
        auroc.append(the_auroc)
        # time-shuffle traces for null distribution:
        shift_auroc = []
        for _ in range(n_iter):
            shift_trace = np.roll(the_trace, np.random.choice(np.arange(len(the_trace))))
            shift_auroc.append(metrics.roc_auc_score(align_raster, shift_trace))
        auroc_p.append(prob_from_dist(the_auroc, shift_auroc, hist_range=(0, 1), tails=2))
    return auroc, auroc_p


def _add_auroc_fields(subj_entry, event_types, phase_dict, n_iter):
    """Add 'auroc_<labels>' and 'auroc_p_<labels>' fields to one subject's dict.

    Parameters
    ----------
    subj_entry : dict
        Must hold 'trace_mtx' (indexed [frame, cell]) and 'labels' (a DataFrame
        with 'frame' and 'label' columns). Modified in place.
    event_types : list of str or tuple of str
        A tuple pools several labels into one event.
    phase_dict : dict mapping phase name -> (first_frame, last_frame)
    n_iter : int
        Number of shuffles for the null distribution.

    Notes
    -----
    Pooled labels are joined with ', ' and no trailing separator, which is the
    key format the later frac_dict cells look up. The keys do not include the
    phase name, so with more than one phase the last one (in sorted order)
    overwrites the earlier ones.
    """
    if 'labels' not in subj_entry:
        print('\t** no labels found, skipping')
        return

    for the_phase in sorted(phase_dict.keys()):
        frames = (phase_dict[the_phase][0], phase_dict[the_phase][1])
        # traces restricted to the phase of interest
        trace_mtx = subj_entry['trace_mtx'][frames[0]:frames[1], :]
        df_labels = subj_entry['labels']

        for event_type in event_types:
            if isinstance(event_type, str):
                event_type = [event_type]
            print('\t', event_type)

            # Event frames relative to the phase start. Events outside the phase are
            # dropped: negative values would otherwise wrap around and mark the wrong
            # frames, and values past the end would raise an IndexError.
            align_pnts = df_labels.frame.loc[df_labels.label.isin(event_type)].values - frames[0]
            in_phase = (align_pnts >= 0) & (align_pnts < trace_mtx.shape[0])
            align_pnts = align_pnts[in_phase].astype(int)

            # convert alignment points to raster:
            align_raster = np.zeros((trace_mtx.shape[0],))

            if len(align_pnts):
                align_raster[align_pnts] = 1
                print('\t', trace_mtx.shape, align_raster.shape,
                      min(np.argwhere(align_raster == 1)), max(np.argwhere(align_raster == 1)))
                # find cells with significant auROC based on time-shuffled distribution:
                auroc, auroc_p = _shuffled_auroc(trace_mtx, align_raster, n_iter)
            else:
                print('\t** no events found')
                auroc = []
                auroc_p = []

            label_str = ', '.join(event_type)
            subj_entry['auroc_' + label_str] = auroc
            subj_entry['auroc_p_' + label_str] = auroc_p
            print('\n')


for the_cond in stat_dict.keys():  # each genotype
    print('\n', the_cond)
    cond_keys = sorted(stat_dict[the_cond].keys())
    # Experiment names contain 'linear'; if the first key does, this level holds
    # experiments (repeated measures) rather than subjects.
    if cond_keys and 'linear' in cond_keys[0]:
        print('** Repeated measures detected **')
        for the_exp in stat_dict[the_cond].keys():  # each experiment condition
            print(the_exp)
            for the_subj in stat_dict[the_cond][the_exp]:  # each subject
                print('\t', the_subj)
                _add_auroc_fields(stat_dict[the_cond][the_exp][the_subj], event_types, phase_dict, n_iter)
    else:  # same as above loop but without repeated measures for each subject:
        print('** No repeated measures detected **')
        for the_subj in stat_dict[the_cond]:  # each subject
            print('\t', the_subj)
            _add_auroc_fields(stat_dict[the_cond][the_subj], event_types, phase_dict, n_iter)
                        


In [None]:
# Save the dictionary (this overwrites 'frame_aligned_data_baclofen.pkl' with the
# current stat_dict). The context manager closes the file handle even if pickling fails.
fn = 'frame_aligned_data_baclofen.pkl'
with open(fn, 'wb') as pkl_file:
    pickle.dump(stat_dict, pkl_file)

***
### This script analyzes and pools auROC values across events and separates positive from negative modulation. It generates frac_dict, which is the input data for the plotting functions.

In [None]:
# Behavior annotations to analyze in the cells that follow.
event_types = ['rearing left', 'rearing right', 'grooming']
print(event_types)

# Preview the result keys derived from the events: one '<event>_pos_mod' and one
# '<event>_neg_mod' per entry.
preview_keys = [''.join(ev) + '_pos_mod' for ev in event_types]
preview_keys += [''.join(ev) + '_neg_mod' for ev in event_types]
print(preview_keys)

In [None]:
# Show the current event list as the cell output (sanity check before the analysis cells).
event_types

### Fraction of modulated neurons:

In [None]:
# Fraction of significantly modulated cells per subject and event.
# Output: frac_dict[genotype]['<event>_pos_mod' | '<event>_neg_mod'] -> list with one
# value per subject (0 when there are no modulated cells or no events in the window).
p_thresh = 0.01    # maximum shuffle p-value for a cell to count as modulated
roc_thresh = .6    # auROC >= roc_thresh: positive; auROC <= 1 - roc_thresh: negative
align_pad = 200
# NOTE(review): the auROC cell above uses phase_dict['preference'] = (6500, 6500+11500),
# while this window starts at 6200 — confirm the two windows are meant to differ.
frames = (6200,6200+11500)
#frames = (10,5500)

frac_dict = dict.fromkeys(stat_dict)

for the_cond in stat_dict.keys(): # each genotype
    print('\n',the_cond)

    # ''.join() leaves a plain string unchanged and concatenates the labels of a tuple.
    stat_keys = ([''.join(i) + '_pos_mod' for i in event_types] +
                 [''.join(i) + '_neg_mod' for i in event_types])
    frac_dict[the_cond] = {key: [] for key in stat_keys}
    print(frac_dict[the_cond].keys())

    for the_subj in sorted(stat_dict[the_cond]): # each subject
        print('\t',the_subj)

        trace_mtx = stat_dict[the_cond][the_subj]['trace_mtx']
        df_labels = stat_dict[the_cond][the_subj]['labels']

        for event_type in event_types:
            labels = [event_type] if isinstance(event_type, str) else list(event_type)

            # create raster from event points, ignoring frames that fall outside the
            # recording (negative indices would wrap around, large ones would raise):
            align_pnts = df_labels.frame.loc[df_labels.label.isin(labels)].values
            in_range = (align_pnts >= 0) & (align_pnts < trace_mtx.shape[0])
            align_raster = np.zeros((trace_mtx.shape[0],))
            align_raster[align_pnts[in_range].astype(int)] = 1
            events_in_window = bool(align_raster[frames[0]:frames[1]].any())

            # pooled labels are joined with ', ' to match the auROC field names:
            event_type = ', '.join(labels)
            print('\t',event_type)

            # load auROC and auROC_p vectors:
            the_auroc_p = stat_dict[the_cond][the_subj]['auroc_p_' + event_type]
            the_auroc = stat_dict[the_cond][the_subj]['auroc_' + event_type]

            signif = np.argwhere(np.asarray(the_auroc_p) <= p_thresh).flatten()
            pos_mod = np.intersect1d(signif, np.argwhere(np.asarray(the_auroc) >= roc_thresh).flatten())
            neg_mod = np.intersect1d(signif, np.argwhere(np.asarray(the_auroc) <= (1-roc_thresh)).flatten())

            # positively, then negatively modulated cells:
            for suffix, mod_cells in (('_pos_mod', pos_mod), ('_neg_mod', neg_mod)):
                stat_str = (event_type + suffix).replace(', ','')
                if len(mod_cells) and events_in_window: # check for modulated cells and events in window
                    frac_dict[the_cond][stat_str].append(len(mod_cells)/len(the_auroc))
                else:
                    frac_dict[the_cond][stat_str].append(0)
        
            

### auROC of tuned ensembles:

In [None]:
# Events / pooled event groups for the ensemble analysis. A tuple pools several labels
# into one event. Each entry must be listed only once: the analysis loop appends one
# value per entry per subject, so a repeated entry ('boxing' was listed twice) would
# give that event twice as many values as the others.
event_types = ['nose-to-nose', 'nose-to-ag', 'boxing', 'nose-to-right', 'nose-to-left',
               'approach left', 'approach right', ('nose-to-left', 'nose-to-nose', 'nose-to-ag')]
print(event_types)
print([''.join(i) + '_pos_mod' for i in event_types] + [''.join(i) + '_neg_mod' for i in event_types])

In [None]:
# auROC of the averaged activity of each tuned ensemble, per subject and event.
# Output: frac_dict[genotype]['<event>_pos_mod' | '<event>_neg_mod'] -> list with one
# auROC per subject (np.nan when there are no modulated cells or no events in the window).
p_thresh = 0.01    # maximum shuffle p-value for a cell to count as modulated
roc_thresh = .6    # auROC >= roc_thresh: positive; auROC <= 1 - roc_thresh: negative
align_pad = 200
# NOTE(review): the auROC cell above uses phase_dict['preference'] = (6500, 6500+11500),
# while this window starts at 6200 — confirm the two windows are meant to differ.
frames = (6200,6200+11500)
#frames = (10,5500)

frac_dict = dict.fromkeys(stat_dict)

for the_cond in stat_dict.keys(): # each genotype
    print('\n',the_cond)

    # ''.join() leaves a plain string unchanged and concatenates the labels of a tuple.
    stat_keys = ([''.join(i) + '_pos_mod' for i in event_types] +
                 [''.join(i) + '_neg_mod' for i in event_types])
    frac_dict[the_cond] = {key: [] for key in stat_keys}
    print(frac_dict[the_cond].keys())

    for the_subj in sorted(stat_dict[the_cond]): # each subject
        print('\t',the_subj)

        trace_mtx = stat_dict[the_cond][the_subj]['trace_mtx']
        df_labels = stat_dict[the_cond][the_subj]['labels']

        for event_type in event_types:
            labels = [event_type] if isinstance(event_type, str) else list(event_type)

            # create raster from event points, ignoring frames that fall outside the
            # recording (negative indices would wrap around, large ones would raise):
            align_pnts = df_labels.frame.loc[df_labels.label.isin(labels)].values
            in_range = (align_pnts >= 0) & (align_pnts < trace_mtx.shape[0])
            align_raster = np.zeros((trace_mtx.shape[0],))
            align_raster[align_pnts[in_range].astype(int)] = 1
            window_raster = align_raster[frames[0]:frames[1]]
            events_in_window = bool(window_raster.any())

            # pooled labels are joined with ', ' to match the auROC field names:
            event_type = ', '.join(labels)
            print('\t',event_type)

            # load auROC and auROC_p vectors:
            the_auroc_p = stat_dict[the_cond][the_subj]['auroc_p_' + event_type]
            the_auroc = stat_dict[the_cond][the_subj]['auroc_' + event_type]

            signif = np.argwhere(np.asarray(the_auroc_p) <= p_thresh).flatten()
            pos_mod = np.intersect1d(signif, np.argwhere(np.asarray(the_auroc) >= roc_thresh).flatten())
            neg_mod = np.intersect1d(signif, np.argwhere(np.asarray(the_auroc) <= (1-roc_thresh)).flatten())

            # positively, then negatively modulated cells:
            for suffix, mod_cells in (('_pos_mod', pos_mod), ('_neg_mod', neg_mod)):
                stat_str = (event_type + suffix).replace(', ','')
                if len(mod_cells) and events_in_window: # check for modulated cells and events in window
                    # auROC of the ensemble-averaged trace against the event raster:
                    avg_trace = np.mean(trace_mtx[:, mod_cells], axis=1)
                    auc = metrics.roc_auc_score(window_raster, avg_trace[frames[0]:frames[1]])
                    frac_dict[the_cond][stat_str].append(auc)
                else:
                    frac_dict[the_cond][stat_str].append(np.nan)
        
            