In [3]:
# General
import sys
import os.path as op
from time import time
from collections import OrderedDict as od
from importlib import reload
import warnings
from glob import glob
import itertools
import h5py

# Scientific
import numpy as np
import pandas as pd
pd.options.display.max_rows = 200
pd.options.display.max_columns = 999
import scipy.io as sio

# Stats
import scipy.stats as stats
import statsmodels.api as sm
from statsmodels.formula.api import ols
import random
from sklearn.preprocessing import minmax_scale
from sklearn.decomposition import PCA
from scipy.ndimage.filters import gaussian_filter1d

# Plots
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import matplotlib as mpl
from matplotlib.lines import Line2D
import matplotlib.patches as patches
# Matplotlib defaults for all figures in this notebook.
# The color-cycle hex codes are written rc-file style (no leading '#');
# mpl.cycler validates them like an rcParams entry, which accepts that form.
colors = ['1f77b4', 'd62728', '2ca02c', 'ff7f0e', '9467bd',
          '8c564b', 'e377c2', '7f7f7f', 'bcbd22', '17becf']
mpl.rcParams.update({
    # Grid and line defaults.
    'grid.linewidth': 0.1,
    'grid.alpha': 0.75,
    'lines.linewidth': 1,
    'lines.markersize': 3,
    # Tick labels and tick marks.
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'xtick.major.width': 0.8,
    'ytick.major.width': 0.8,
    # Axes: color cycle, spines, labels, layering.
    'axes.prop_cycle': mpl.cycler('color', colors),
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.formatter.offset_threshold': 2,
    'axes.labelsize': 14,
    'axes.labelpad': 8,
    'axes.titlesize': 16,
    'axes.grid': False,
    'axes.axisbelow': True,
    # Legend.
    'legend.loc': 'upper right',
    'legend.fontsize': 14,
    'legend.frameon': False,
    # Figure size, resolution, and subplot spacing.
    'figure.dpi': 300,
    'figure.titlesize': 16,
    'figure.figsize': (10, 4),
    'figure.subplot.wspace': 0.25,
    'figure.subplot.hspace': 0.25,
    # Fonts and saved-file format (fonttype 42 embeds TrueType fonts in PDFs).
    'font.sans-serif': ['Helvetica'],
    'savefig.format': 'pdf',
    'pdf.fonttype': 42,
})

# Personal
sys.path.append('/home1/dscho/code/general')
sys.path.append('/home1/dscho/code/projects/manning_replication')
sys.path.append('/home1/dscho/code/projects')
import data_io as dio
import array_operations as aop
from eeg_plotting import plot_trace, plot_trace2
from time_cells import spike_sorting, spike_preproc, events_preproc, events_proc, time_bin_analysis, time_cell_plots

# Font sizes keyed by text role.
font = dict(tick=12, label=14, annot=12, fig=16)

# Colors: shade `c` of an `n`-step seaborn sequential palette for each hue,
# followed by a YlOrBr shade (from a longer palette) and black.
n = 4
c = 2
colors = [sns.color_palette(palette_name, n)[c]
          for palette_name in ('Blues', 'Reds', 'Greens', 'Purples', 'Oranges', 'Greys')]
colors += [sns.color_palette('YlOrBr', n + 3)[c], 'k']

# 501-step diverging colormap: blue shade -> white -> red shade.
cmap = sns.palettes.blend_palette((colors[0], 'w', colors[1]), 501)

# Named figure widths. NOTE(review): presumably inches for journal column
# layouts (e.g. 'nat1w' / 'nat2w' single/double column) -- confirm.
colws = od([
    # String-keyed widths.
    ('1', 6.55),
    ('2-1/2', 3.15),
    ('2-1/3', 2.1),
    ('2-2/3', 4.2),
    ('3', 2.083),
    ('4', 1.525),
    ('5', 1.19),
    ('6', 0.967),
    # Integer-keyed widths.
    (1, 2.05),
    (2, 3.125),
    (3, 6.45),
    # 'nat*' widths.
    ('nat1w', 3.50394),
    ('nat2w', 7.20472),
    ('natl', 9.72441),
])

# Root directory for this project's data and analysis outputs.
proj_dir = '/home1/dscho/projects/time_cells'

In [4]:
# Get sessions: the unique session IDs are the text before the first '-' in
# each event pickle's filename; subjects are the text before the first '_'.
sessions = np.unique([
    op.basename(pkl_path).split('-')[0]
    for pkl_path in glob(op.join(proj_dir, 'analysis', 'events', '*.pkl'))
])
print('{} subjects, {} sessions'.format(
    len(np.unique([sess.split('_')[0] for sess in sessions])),
    len(sessions),
))

10 subjects, 12 sessions


In [6]:
load_single_file = False

start_time = time()

# Load all time OLS result files, either from a single cached dataframe or by
# combining the per-unit output pickles and annotating each row.
# NOTE(review): this cache path (tmp/...-433units.pkl) differs from where the
# save cell writes (all_units_time/...-{n}units.pkl); confirm which file is
# intended before enabling load_single_file.
filename = op.join(proj_dir, 'analysis', 'unit_to_behav', 'tmp',
                   'ols-time_bin-model_pairs-433units.pkl')
if load_single_file and op.exists(filename):
    ols_pairs = dio.open_pickle(filename)
else:
    ols_pairs_files = glob(op.join(proj_dir, 'analysis', 'unit_to_behav', '*-time_bin-model_pairs.pkl'))
    print('Found OLS outputs for {} neurons'.format(len(ols_pairs_files)))

    # Read each per-unit file. Frames are collected in a list and concatenated
    # once; concatenating inside the loop re-copies the growing frame on every
    # iteration. warnings.catch_warnings() restores the previous warning
    # filters on exit, even if an exception is raised. The old
    # warnings.resetwarnings() erased every filter, including ones configured
    # before this cell.
    bad_files = []
    unit_frames = []
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        for ols_file in ols_pairs_files:
            try:
                unit_frames.append(dio.open_pickle(ols_file))
            except Exception as err:
                # Record and report unreadable files rather than silently
                # swallowing everything (a bare except also traps
                # KeyboardInterrupt).
                bad_files.append(ols_file)
                print('Could not load {}: {!r}'.format(ols_file, err))
        if not unit_frames:
            raise RuntimeError('No readable OLS output files ({} found, {} bad)'
                               .format(len(ols_pairs_files), len(bad_files)))
        ols_pairs = (pd.concat(unit_frames)
                     .sort_values(['subj_sess', 'neuron'])
                     .reset_index(drop=True))

    # Restrict dataframe columns to those that we want to keep.
    keep_cols = ['subj_sess', 'neuron', 'gameState', 'testvar', 'full', 'llf_full', 'lr', 'z_lr', 'emp_pval']
    ols_pairs = ols_pairs[keep_cols].reset_index(drop=True)

    # Remove variables that we don't care about.
    exclude_vars = ['is_moving', 'dig_performed']
    ols_pairs = ols_pairs.loc[~ols_pairs['testvar'].isin(exclude_vars)].reset_index(drop=True)

    # Organize categorical columns (values not listed in test_vars become NaN).
    test_vars = ['time', 'place', 'head_direc', 'base_in_view', 'gold_in_view']
    test_var_cat = pd.CategoricalDtype(test_vars, ordered=True)
    ols_pairs['testvar'] = ols_pairs['testvar'].astype(test_var_cat)

    # Per-row annotations: hemisphere/region labels and firing-rate summaries.
    roi_map2 = spike_preproc.roi_mapping(n=2)
    roi_map3 = spike_preproc.roi_mapping(n=3)
    roi_map4 = spike_preproc.roi_mapping(n=4)
    roi_map5 = spike_preproc.roi_mapping(n=5)
    val_map = od([('hem', []),
                  ('roi', []),
                  ('roi_gen2', []),
                  ('roi_gen3', []),
                  ('roi_gen4', []),
                  ('roi_gen5', []),
                  ('spike_mat'        , []),
                  ('mean_frs'         , []),
                  ('sem_frs'          , []),
                  ('fr_mean'          , []),
                  ('fr_max'           , []),
                  ('fr_max_ind'       , []),
                  ('sparsity'         , [])])
    # Rows are sorted by session, so each session's spikes are loaded once.
    # Starting from None (instead of testing `'event_spikes' in dir()`) avoids
    # depending on a variable left over from an earlier run in the kernel.
    event_spikes = None
    for _, row in ols_pairs.iterrows():
        if event_spikes is None or event_spikes.subj_sess != row['subj_sess']:
            event_spikes = time_bin_analysis.load_event_spikes(row['subj_sess'], verbose=False)
        # NOTE(review): assumes roi_lookup returns a hemisphere character
        # followed by the region code, and that the text before '-' in the
        # neuron name identifies its channel -- confirm.
        hemroi = spike_preproc.roi_lookup(row['subj_sess'], row['neuron'].split('-')[0])
        hem = hemroi[0]
        roi = hemroi[1:]
        roi_gen2 = roi_map2.get(roi, np.nan)
        roi_gen3 = roi_map3.get(roi, np.nan)
        roi_gen4 = roi_map4.get(roi, np.nan)
        roi_gen5 = roi_map5.get(roi, np.nan)
        spike_mat = event_spikes.get_spike_mat(row['neuron'], row['gameState'])
        # NOTE(review): the x2 presumably converts counts per 500-ms time bin
        # to Hz -- confirm the bin width.
        mean_frs = time_bin_analysis.get_mean_frs(spike_mat) * 2
        sem_frs = time_bin_analysis.get_sem_frs(spike_mat) * 2
        fr_mean = np.mean(mean_frs.values)
        fr_max = np.max(mean_frs.values)
        fr_max_ind = np.argmax(mean_frs.values)
        sparsity = time_bin_analysis.get_sparsity(spike_mat)

        val_map['hem'].append(hem)
        val_map['roi'].append(roi)
        val_map['roi_gen2'].append(roi_gen2)
        val_map['roi_gen3'].append(roi_gen3)
        val_map['roi_gen4'].append(roi_gen4)
        val_map['roi_gen5'].append(roi_gen5)
        val_map['spike_mat'].append(spike_mat.values.tolist())
        val_map['mean_frs'].append(mean_frs.tolist())
        val_map['sem_frs'].append(sem_frs.tolist())
        val_map['fr_mean'].append(fr_mean)
        val_map['fr_max'].append(fr_max)
        val_map['fr_max_ind'].append(fr_max_ind)
        val_map['sparsity'].append(sparsity)

    ols_pairs.insert(0, 'subj', ols_pairs['subj_sess'].apply(lambda x: x.split('_')[0]))
    ols_pairs.insert(2, 'subj_sess_unit', ols_pairs.apply(lambda x: '{}-{}'.format(x['subj_sess'], x['neuron']), axis=1))

    # Insert hemisphere and region info directly after the 'neuron' column.
    for loc, col_name in enumerate(['hem', 'roi', 'roi_gen2', 'roi_gen3', 'roi_gen4', 'roi_gen5'], start=4):
        ols_pairs.insert(loc, col_name, val_map[col_name])
    # Append the remaining per-row values as new trailing columns.
    for col_name in val_map:
        if col_name not in ols_pairs:
            ols_pairs[col_name] = val_map[col_name]
    ols_pairs['roi_gen2'] = ols_pairs['roi_gen2'].astype(pd.CategoricalDtype(['MTL', 'Cortex'], ordered=True))
    ols_pairs['roi_gen3'] = ols_pairs['roi_gen3'].astype(pd.CategoricalDtype(['Hippocampus', 'MTL', 'Cortex'], ordered=True))
    ols_pairs['roi_gen4'] = ols_pairs['roi_gen4'].astype(pd.CategoricalDtype(['Hippocampus', 'MTL', 'Frontal', 'Cortex'], ordered=True))
    ols_pairs['roi_gen5'] = ols_pairs['roi_gen5'].astype(pd.CategoricalDtype(['Hippocampus', 'MTL', 'Frontal', 'Temporal', 'Cortex'], ordered=True))

    # Test significance: uncorrected empirical p-value, plus a Holm correction
    # applied within each (unit, gameState, full) group -- presumably one
    # p-value per test variable in each group.
    alpha = 0.05
    ols_pairs['sig'] = ols_pairs['emp_pval'] < alpha
    ols_pairs['sig_holm'] = ols_pairs.groupby(['subj_sess_unit', 'gameState', 'full'])['emp_pval'].transform(lambda x: sm.stats.multipletests(x, alpha, method='holm')[0])

    print('{} bad files'.format(len(bad_files)))

# Get rid of the random cortex neurons (insula, occipital) in 1/10 subjects.
ols_pairs = ols_pairs.query("(roi_gen5!='Cortex')").reset_index(drop=True)
ols_pairs['roi_gen5'] = ols_pairs['roi_gen5'].astype(pd.CategoricalDtype(['Hippocampus', 'MTL', 'Frontal', 'Temporal'], ordered=True))

# Find the highest firing rate time bin.
def _fr_max_ind(spike_mat, bins=10):
    return np.argmax([v.sum() for v in np.split(np.sum(spike_mat, axis=0), bins)])

# Add the index of the highest-firing segment when each unit's spike matrix is
# split into 4, 5, and 10 time segments, placed right after 'fr_max_ind'.
# A column that already exists (e.g. in a cached frame loaded via
# load_single_file, since the save cell stores these columns) is overwritten in
# place rather than re-inserted, because DataFrame.insert raises on duplicates.
icol = ols_pairs.columns.tolist().index('fr_max_ind') + 1
for n_segments in (4, 5, 10):
    max_col = 'fr_max_ind{}'.format(n_segments)
    max_inds = ols_pairs['spike_mat'].apply(_fr_max_ind, args=(n_segments,))
    if max_col in ols_pairs:
        ols_pairs[max_col] = max_inds
    else:
        ols_pairs.insert(icol, max_col, max_inds)
    icol += 1

print('ols_pairs:', ols_pairs.shape)

print('Done in {:.1f}s'.format(time() - start_time))

Found OLS outputs for 476 neurons
0 bad files
ols_pairs: (5236, 29)
Done in 1527.1s


In [21]:
# Save the combined OLS pairs for every session as one dataframe; the filename
# records how many distinct units it contains.
save_output = 1
overwrite = 1

if save_output:
    filename = op.join(
        proj_dir, 'analysis', 'unit_to_behav', 'all_units_time',
        'ols-time_bin-model_pairs-{}units.pkl'.format(len(ols_pairs['subj_sess_unit'].unique())),
    )
    # Skip the write only when overwriting is off and the file already exists.
    if not (not overwrite and op.exists(filename)):
        dio.save_pickle(ols_pairs, filename)

Saved /home1/dscho/projects/time_cells/analysis/unit_to_behav/tmp/ols-time_bin-model_pairs-457units.pkl


In [9]:
# Display the region-code -> region-group mapping behind the roi_gen5 column.
spike_preproc.roi_mapping(n=5)

{'A': 'MTL',
 'AC': 'Frontal',
 'AH': 'Hippocampus',
 'AI': 'Cortex',
 'EC': 'MTL',
 'FOp': 'Frontal',
 'FOP': 'Frontal',
 'FSG': 'Temporal',
 'HGa': 'Temporal',
 'MFG': 'Frontal',
 'MH': 'Hippocampus',
 'O': 'Cortex',
 'OF': 'Frontal',
 'PHG': 'MTL',
 'PI-SMG': 'Frontal',
 'pSMA': 'Frontal',
 'TO': 'Temporal',
 'TP': 'Temporal',
 'TPO': 'Temporal'}