In [None]:
import scipy as sp
import scipy.io
import os
import numpy as np
import pandas as pd
import glob
import csv
import random as rand
from tqdm import tnrange, tqdm_notebook
from collections import Iterable
import matplotlib.pylab as plt
import random as rand
from ipywidgets import *
from scipy import stats
import importlib
import sys
sys.path.append(os.getcwd()+'/../')
from utils import utils
from utils import utils, zscores
from utils import plotting_utils as pu
from tqdm import tqdm
import warnings

import cmocean
cmap = cmocean.cm.thermal
colors = cmap

# from utils import auc_methods as ama
# import matplotlib.patches as patches
# from matplotlib import gridspec
# from sklearn.metrics import roc_curve, auc
# from sklearn.model_selection import train_test_split
# from sklearn.preprocessing import label_binarize
# from matplotlib_venn import venn2

%load_ext autoreload
%autoreload 2
import matplotlib as mpl
# Start from a clean slate: drop any figures left over from earlier runs.
plt.close('all')

# Print-sized typography: 6 pt Arial for text, tick labels and axis labels.
font = dict(family = 'Arial', weight = 'normal', size = 6)
mpl.rc('font', **font)
mpl.rcParams.update({
    'xtick.labelsize': 6,
    'ytick.labelsize': 6,
    'axes.labelsize': 6,
})

# Font type 42 for both vector backends (the original cell assigned these twice).
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

size_mult = 1

In [None]:
# Close any open figures before switching the typography.
plt.close('all')

# Screen-sized typography: overrides the 6 pt settings from the first cell
# with 16 pt Arial for text, tick labels and axis labels.
font = dict(family = 'Arial', weight = 'normal', size = 16)
mpl.rc('font', **font)
mpl.rcParams.update({
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'axes.labelsize': 16,
})

# Font type 42 for both vector backends.
mpl.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

In [None]:
# Directory holding the processed HDF5 files. The original machine-specific
# path stays as the default so existing behavior is unchanged, but it can be
# overridden without editing the notebook by setting CROSSMODAL_DATA_DIR.
data_directory = os.environ.get(
    'CROSSMODAL_DATA_DIR',
    r'C:\Users\Eric\Documents\09-12-2021\DATA\Crossmodal_only',
)

# Trial log: the cells below read its 'trial_label', 'trial_type', 'uni_id',
# 'block_type' and 'spike_counts(stim_aligned)' columns.
log_df = pd.read_hdf(os.path.join(data_directory, 'log_df_processed_02-28-2019.h5'), 'fixed')
# Unit table: only its 'uni_id' column is used below.
unit_key_df = pd.read_hdf(os.path.join(data_directory, 'unit_key_df_processed_02-28-2019.h5'), 'fixed')

In [None]:
# Quick look at the size (rows, columns) of the unit table.
unit_key_df.shape

In [None]:
# Restrict the log to the touch-stimulus hit/miss trials of one example unit.
target_unit = log_df['uni_id'] == '07903-22-164t2'
target_trials = log_df['trial_label'].isin(['Touch Stim Hit', 'Touch Stim Miss'])
example = log_df[target_unit & target_trials]

# Boolean label per trial: True for hits, False for misses.
pos_trial = 'Touch Stim Hit'
labels = (example['trial_label'] == pos_trial).values

# One row per trial; keep bins 39..58.
# first 500ms after stim onset; divide by 40 to go from per sec FR to raw spike count
spikes = example['spike_counts(stim_aligned)'].values
spikes = np.vstack(spikes)[:, 39:59]/40


In [None]:
def calc_euc_dist(labels, spikes, shuff = False):
    """
    Euclidean distance between the trial-averaged responses of the two classes.

    Parameters
    ----------
    labels : array-like of bool, one entry per trial
        True marks a "positive" trial, False a "negative" one. The input is
        coerced to a boolean mask, so lists and 0/1 integer arrays are accepted
        (previously a list crashed on `~labels` and integers were silently
        used as row indices).
    spikes : 2-D array, trials x time bins
    shuff : bool
        If True, the labels are randomly permuted first (null distribution).

    Returns
    -------
    float
        Norm of (mean over positive trials - mean over negative trials),
        where each mean is taken per time bin (axis 0).
    """
    labels = np.asarray(labels, dtype = bool)
    spikes = np.asarray(spikes)

    if shuff:
        labels = np.random.permutation(labels)

    pos_mean = np.mean(spikes[labels, :], axis = 0)
    neg_mean = np.mean(spikes[~labels, :], axis = 0)

    return np.linalg.norm(pos_mean-neg_mean)

def cacl_mean_diff(labels, spikes, shuff = True):
    """
    Difference between the grand-mean activity of positive and negative trials.

    NOTE: the name (a typo of "calc") and the default `shuff = True` (unlike
    `calc_euc_dist`, which defaults to False) are kept for backward
    compatibility with existing callers.

    Parameters
    ----------
    labels : array-like of bool, one entry per trial
        Coerced to a boolean mask, so lists and 0/1 integer arrays are accepted
        (previously a list crashed on `~labels` and integers were silently
        used as row indices).
    spikes : 2-D array, trials x time bins
    shuff : bool
        If True, the labels are randomly permuted first (null distribution).

    Returns
    -------
    float
        mean(all values of positive trials) - mean(all values of negative
        trials); the sign is preserved, so the statistic is directional.
    """
    labels = np.asarray(labels, dtype = bool)
    spikes = np.asarray(spikes)

    if shuff:
        labels = np.random.permutation(labels)

    pos_mean = np.mean(spikes[labels, :])
    neg_mean = np.mean(spikes[~labels, :])

    return pos_mean-neg_mean
                    

def permutation_test(labels, spikes, num_iter = 1000):
    """
    Permutation test on the Euclidean distance between class-mean responses.

    Parameters
    ----------
    labels : array-like of bool (one per trial), or None
    spikes : 2-D array (trials x time bins), or None
    num_iter : int
        Number of label shuffles used to build the null distribution.

    Returns
    -------
    float
        Fraction of shuffles whose distance is >= the observed distance.
        NaN is returned when the test is undefined: no data (None / empty)
        or all trials belong to a single class. Previously the single-class
        case produced a NaN statistic, every `>=` comparison was False and
        the function reported p = 0.0, i.e. a spurious "significant" unit.
    """
    if labels is None or spikes is None:
        return np.nan
    labels = np.asarray(labels, dtype = bool)
    if labels.size == 0 or labels.all() or not labels.any():
        return np.nan

    real_value = calc_euc_dist(labels, spikes, shuff = False)
    # Explicit array so the comparison below does not rely on NumPy's
    # reflected operator handling of a plain Python list.
    shuff_values = np.array([calc_euc_dist(labels, spikes, shuff = True) for _ in range(num_iter)])
    return np.mean(shuff_values >= real_value)

def mean_permutation_test(labels, spikes, num_iter = 1000):
    """
    One-sided permutation test on the difference of grand-mean activity
    (positive trials minus negative trials).

    Parameters
    ----------
    labels : array-like of bool (one per trial), or None
    spikes : 2-D array (trials x time bins), or None
    num_iter : int
        Number of label shuffles used to build the null distribution.

    Returns
    -------
    float
        Fraction of shuffles whose mean difference is >= the observed one, so
        small values mean "positive trials are higher than negative trials".
        NaN is returned when the test is undefined: no data (None / empty,
        as yielded by `compare_to_baseline_unit_generator` for units without
        trials) or all trials in a single class. Previously the single-class
        case silently returned p = 0.0 because NaN comparisons are all False.
    """
    if labels is None or spikes is None:
        return np.nan
    labels = np.asarray(labels, dtype = bool)
    if labels.size == 0 or labels.all() or not labels.any():
        return np.nan

    real_value = cacl_mean_diff(labels, spikes, shuff = False)
    # Explicit array so the comparison below does not rely on NumPy's
    # reflected operator handling of a plain Python list.
    shuff_values = np.array([cacl_mean_diff(labels, spikes, shuff = True) for _ in range(num_iter)])
    return np.mean(shuff_values >= real_value)

In [None]:
# List the columns available in the unit table.
unit_key_df.columns

In [None]:
def compare_trials_unit_generator(log_df, unit_list, pos_trial, neg_trial, time_window, stim_length = 'long'):
    """
    Generator yielding, for every unit, the post-stimulus spike counts of its
    `pos_trial` / `neg_trial` trials plus a boolean label per trial, so that
    they can be run through a permutation test.

    Parameters
    ----------
    log_df : pandas.DataFrame
        Trial log; the columns 'trial_type', 'trial_label', 'uni_id' and
        'spike_counts(stim_aligned)' are read.
    unit_list : iterable
        Unit ids (values of 'uni_id') to iterate over.
    pos_trial, neg_trial : str
        Values of 'trial_label' that define the two classes.
    time_window : sequence of two numbers
        [start, stop] relative to stim onset, in the units implied by the
        hard-coded conversion (divided by 0.025 and offset by 39 bins).
    stim_length : str
        'long' selects the 'Stim_*_NoCue' trial types, anything else the
        '1CycStim_*_NoCue' ones.

    Yields
    ------
    (unit, labels, spikes)
        labels : bool array, True where the trial label equals `pos_trial`.
        spikes : 2-D array, trials x bins in the window, divided by 40.
        For a unit without matching trials `(unit, None, None)` is yielded,
        mirroring `compare_to_baseline_unit_generator` (previously
        `np.vstack` raised on the empty selection).
    """
    if stim_length == 'long':
        stims = ['Stim_Som_NoCue', 'Stim_Vis_NoCue']
    else:
        stims = ['1CycStim_Vis_NoCue','1CycStim_Som_NoCue']

    bin_width = 0.025   # width of one bin of 'spike_counts(stim_aligned)'
    onset_bin = 39      # offset added to the window (1/0.025 - 1 in the original expression)

    # Round before truncating: a quotient such as 0.15/0.025 evaluates to
    # 5.999999999999999, and plain truncation could land one bin short.
    time_window_bins = np.round(np.array(time_window)/bin_width + onset_bin, 9).astype(int)
    first_bin, last_bin = time_window_bins[0], time_window_bins[1]

    subset_stim_length = log_df['trial_type'].isin(stims)
    subset_trialtype = log_df['trial_label'].isin([pos_trial, neg_trial])
    target_trials = log_df[subset_stim_length & subset_trialtype]

    for unit in unit_list:
        unit_df = target_trials[target_trials['uni_id'] == unit]
        if unit_df.shape[0] == 0:
            # nothing to stack for this unit; let the caller decide how to skip it
            yield unit, None, None
            continue

        labels = (unit_df['trial_label'] == pos_trial).values

        # first x ms after stim onset; divide by 40 to go from per sec FR to raw spike count
        spikes = np.vstack(unit_df['spike_counts(stim_aligned)'].values)[:, first_bin:last_bin]/40

        yield unit, labels, spikes

def compare_to_baseline_unit_generator(log_df, unit_list, pos_trial,time_window, stim_length = 'long'):
    """
    Generator yielding, for every unit, spike counts after and before stim
    onset for each `pos_trial` trial, with labels marking which rows are
    post-stim, so that they can be run through a permutation test.

    Parameters
    ----------
    log_df : pandas.DataFrame
        Trial log; the columns 'trial_type', 'trial_label', 'uni_id' and
        'spike_counts(stim_aligned)' are read.
    unit_list : iterable
        Unit ids (values of 'uni_id') to iterate over.
    pos_trial : str
        Value of 'trial_label' selecting the trials to analyse.
    time_window : sequence of two numbers
        [start, stop] of the post-stim window, in the units implied by the
        hard-coded conversion (divided by 0.025 and offset by 39 bins).
    stim_length : str
        'long' selects the 'Stim_*_NoCue' trial types, anything else the
        '1CycStim_*_NoCue' ones.

    Yields
    ------
    (unit, labels, spikes)
        spikes : 2-D array; the post-stim rows of all trials stacked on top of
                 the equally long baseline rows (the bins ending at bin 39),
                 both divided by 40.
        labels : bool array, True for post-stim rows, False for baseline rows.
        For a unit without matching trials `(unit, None, None)` is yielded.

    Raises
    ------
    ValueError
        If the window is longer than the 39 bins available before bin 39, so
        no baseline of equal length exists (this previously surfaced as an
        opaque shape-mismatch error from `np.vstack`).
    """
    if stim_length == 'long':
        stims = ['Stim_Som_NoCue', 'Stim_Vis_NoCue']
    else:
        stims = ['1CycStim_Vis_NoCue','1CycStim_Som_NoCue']

    bin_width = 0.025   # width of one bin of 'spike_counts(stim_aligned)'
    onset_bin = 39      # offset added to the window (1/0.025 - 1 in the original expression)

    #compare same length vectors before and after stim onset
    # Round before truncating: a quotient such as 0.15/0.025 evaluates to
    # 5.999999999999999, and plain truncation could land one bin short.
    time_window_bins = np.round(np.array(time_window)/bin_width + onset_bin, 9).astype(int)
    n_bins = time_window_bins[1] - time_window_bins[0]
    baseline_window_bins = [onset_bin - n_bins, onset_bin]
    if baseline_window_bins[0] < 0:
        raise ValueError(
            f'time_window {list(time_window)} spans {n_bins} bins but only '
            f'{onset_bin} baseline bins are available before stim onset'
        )

    subset_stim_length = log_df['trial_type'].isin(stims)
    subset_trialtype = log_df['trial_label']==pos_trial
    target_trials = log_df[subset_stim_length & subset_trialtype]

    for unit in unit_list:
        unit_df = target_trials[target_trials['uni_id'] == unit]
        if unit_df.shape[0] ==  0:
            labels = None
            spikes = None

        else:
            # stack the per-trial vectors once and slice both windows from it
            all_bins = np.vstack(unit_df['spike_counts(stim_aligned)'].values)
            post_stim_spikes = all_bins[:, time_window_bins[0]:time_window_bins[1]]/40
            pre_stim_spikes = all_bins[:, baseline_window_bins[0]:baseline_window_bins[1]]/40

            spikes = np.vstack([post_stim_spikes,pre_stim_spikes])

            # True for the post-stim rows, False for the baseline rows
            labels = np.repeat([True, False], [post_stim_spikes.shape[0], pre_stim_spikes.shape[0]])
        yield unit, labels, spikes
    

### Statistics for elevated activity over baseline for different trial types

In [None]:
# The permutation tests shuffle labels with the global NumPy RNG; seed it so
# the p-values computed in this cell are reproducible across kernel restarts.
np.random.seed(0)


def baseline_pvals(trial_log, trial_label, time_window):
    """
    Per-unit p-values of the one-sided mean permutation test comparing the
    post-stim window with an equally long pre-stim baseline.

    trial_log   : DataFrame of trials to draw from (log_df or a subset of it)
    trial_label : value of 'trial_label' selecting the trials
    time_window : [start, stop] of the post-stim window
    Returns a dict {unit id: p-value}. Units for which the generator yields
    no trials (labels is None) are skipped instead of crashing the test.
    """
    unit_gen = compare_to_baseline_unit_generator(trial_log, unit_key_df['uni_id'], trial_label, time_window, stim_length = 'long')
    return {unit:mean_permutation_test(labels, spikes) for unit, labels, spikes in unit_gen if labels is not None}


# This result is required by the summary figure below; it was commented out,
# which made the figure cell fail with a NameError on a fresh kernel.
print("hits vs baseline")
all_pvals_hit_baseline = baseline_pvals(log_df, 'Touch Stim Hit', [0,0.5])

print("miss vs baseline")
all_pvals_miss_baseline = baseline_pvals(log_df, 'Touch Stim Miss', [0,0.5])

print("hits vs baseline: 0-150ms")
all_pvals_hit_baseline_stim_period = baseline_pvals(log_df, 'Touch Stim Hit', [0,0.15])

print("miss vs baseline: 0-150ms")
all_pvals_miss_baseline_stim_period = baseline_pvals(log_df, 'Touch Stim Miss', [0,0.15])

print("hits vs baseline: 150-500ms")
all_pvals_hit_baseline_post_stim_period = baseline_pvals(log_df, 'Touch Stim Hit', [0.15,0.5])

print("miss vs baseline 150-500ms")
all_pvals_miss_baseline_post_stim_period = baseline_pvals(log_df, 'Touch Stim Miss', [0.15,0.5])

# False alarms are taken from 'Whisker' blocks only.
print("Touch block FA vs baseline")
touch_blocks = log_df[log_df['block_type'] == 'Whisker']
all_pvals_FA_baseline = baseline_pvals(touch_blocks, 'Visual Stim FA', [0,0.5])

print("correct rejection vs baseline")
all_pvals_CR_baseline = baseline_pvals(log_df, 'Visual Stim CR', [0,0.5])

In [None]:
# Id of the example unit used in the single-unit cells of this notebook.
'07903-22-164t2'

In [None]:
# Single-unit sanity check: Euclidean-distance permutation test of post-stim
# vs. baseline activity on 'Touch Stim Miss' trials of the example unit.
#
# The result is stored under its own name. The original cell wrote it to
# `all_pvals_FA_baseline`, overwriting the all-unit false-alarm p-values that
# the summary figure plots as the "FA" curve, and printed a label ("Touch
# block FA") that did not match the trials actually tested.
print("Example unit: miss vs baseline")

unit_gen = compare_to_baseline_unit_generator(log_df, ['07903-22-164t2'], 'Touch Stim Miss', [0,0.5], stim_length = 'long')
example_pvals_miss_baseline = {unit:permutation_test(labels, spikes) for unit, labels, spikes in unit_gen if labels is not None}
example_pvals_miss_baseline


### Statistics for comparing activity between trial types

In [None]:
# One-sided test per unit: is the mean activity in the 0-0.5 window higher on
# 'Touch Stim Hit' than on 'Touch Stim Miss' trials?
unit_gen = compare_trials_unit_generator(log_df,unit_key_df['uni_id'], 'Touch Stim Hit', 'Touch Stim Miss',[0, 0.5], stim_length = 'long')
# Skip units for which no labels are available, consistent with the guard
# used for the false-alarm comparison.
all_pvals_hit_miss_post_stim = {unit:mean_permutation_test(labels, spikes) for unit, labels, spikes in unit_gen if labels is not None}

In [None]:
def cum_dist(arr, bins):
    """
    Cumulative histogram of `arr`.

    Returns a two-element list: the running total of the per-bin counts and
    the bin edges reported by `np.histogram` (one more edge than counts).
    """
    counts, edges = np.histogram(arr, bins = bins)
    running_total = np.cumsum(counts)
    return [running_total, edges]

In [None]:
# Number of units with a hit-vs-miss p-value.
len(all_pvals_hit_miss_post_stim)

In [None]:
# Number of units with a hit-vs-baseline p-value for the 0-150ms window.
len(all_pvals_hit_baseline_stim_period)

In [None]:
P_VALUE_BINS = np.arange(0,1.001, 0.001)   # bin edges of the empirical p-value CDFs
ALPHA = 0.05                                # significance threshold marked on every panel


def plot_pval_cdfs(ax, groups, line_colors, names, y_positions, title):
    """
    Draw one cumulative p-value distribution per group on `ax`.

    groups      : list of dicts {unit id: p-value}
    line_colors : matplotlib color per group
    names       : legend text per group; the fraction of p-values below ALPHA
                  is appended in parentheses
    y_positions : vertical position (axes coordinates) of each legend text
    title       : panel title
    Replaces three copy-pasted loops that differed only in these arguments.
    """
    for group, c, name, y_pos in zip(groups, line_colors, names, y_positions):
        pvals = np.array(list(group.values()))
        frac_below_a = np.mean(pvals < ALPHA)
        counts, edges = cum_dist(pvals, P_VALUE_BINS)
        # normalise the cumulative counts to a fraction of the units in the group
        ax.plot(edges[:-1], counts/len(group), color = c)
        ax.text(.65, y_pos, name + f' ({frac_below_a:.2f})', color = c, transform = ax.transAxes)
    ax.set_title(title, pad = 20)


fig, axes = plt.subplots(4,1, figsize = (7,18))

compare_to_baseline_vals = [
    all_pvals_hit_baseline,
    all_pvals_miss_baseline,
    all_pvals_FA_baseline,
    all_pvals_CR_baseline
]

compare_to_baseline_vals_stim = [
    all_pvals_hit_baseline_stim_period,
    all_pvals_miss_baseline_stim_period
]

compare_to_baseline_vals_post_stim = [
    all_pvals_hit_baseline_post_stim_period,
    all_pvals_miss_baseline_post_stim_period
]

# Panel 0: every trial type against its own baseline.
plot_pval_cdfs(axes[0], compare_to_baseline_vals, ['C0', 'k', 'C2', 'C3'],
               ['Hit', 'Miss', 'FA', 'CR'], [0.4, 0.3,0.2,0.1],
               'Elevated above baseline (0-500ms)')

# Panel 1: hit vs. miss (single curve, no legend text).
hit_miss_counts, hit_miss_edges = cum_dist(list(all_pvals_hit_miss_post_stim.values()), P_VALUE_BINS)
axes[1].plot(hit_miss_edges[:-1], hit_miss_counts/len(all_pvals_hit_miss_post_stim), '-', color = 'blue')
axes[1].set_title('Activity in "Hit" > activity in "Miss" (0-500ms)', pad = 20)

# Panels 2 and 3: hit / miss against baseline, split into two time windows.
plot_pval_cdfs(axes[2], compare_to_baseline_vals_stim, ['C0', 'k'], ['Hit', 'Miss'], [0.4, 0.3],
               'Elevated above baseline (0-150ms)')
plot_pval_cdfs(axes[3], compare_to_baseline_vals_post_stim, ['C0', 'k'], ['Hit', 'Miss'], [0.4, 0.3],
               'Elevated above baseline (150-500ms)')
axes[3].set_xlabel('p-value')

for ax in axes:
    # spelling of the axis label fixed ("Cummulative" -> "Cumulative")
    ax.set_ylabel('Cumulative fraction\nof neurons')
    ax.set_ylim(0,1)
    ax.set_xlim(0,1)
    ax.axvline(ALPHA, linestyle = '--', color = 'k')
    ax.set_xticks([0.05,0.5,1])

fig.subplots_adjust(hspace = .35, left = 0.3)
# fig.savefig('permutation_test_fig2.png')

In [None]:
# Manual run of the Euclidean-distance permutation test on the `labels` /
# `spikes` arrays currently in the notebook namespace, with 10000 shuffles.

# Per-bin mean response of each class (kept for inspection).
neg_mean = spikes[~labels, :].mean(axis = 0)
pos_mean = spikes[labels, :].mean(axis = 0)

# Observed statistic, then the null distribution from shuffled labels.
real_value = calc_euc_dist(labels, spikes, shuff = False)
shuff_values = [
    calc_euc_dist(labels, spikes, shuff = True)
    for _ in range(10000)
]

# Fraction of shuffles at least as large as the observed distance.
np.mean(shuff_values >= real_value)

In [None]:

# Observed (unshuffled) Euclidean distance from the cell above.
real_value

In [None]:
# Largest distance obtained from the shuffled labels, for comparison with real_value.
np.max(shuff_values)

In [None]:
# Columns available in the example-unit subset of the trial log.
example.columns

In [None]:
# Check that every trial of the example unit is a touch hit or miss.
# `labels` is a NumPy boolean array and has no `.isin`, so the original
# expression raised AttributeError; the membership test belongs on the
# 'trial_label' column the labels were derived from.
example['trial_label'].isin(['Touch Stim Hit', 'Touch Stim Miss'])

In [None]:
# Display the spike-count matrix as one stacked 2-D array.
np.vstack(spikes)