Differentiation in the visual behavior dataset

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

import os
from os import path
from glob import glob
import pickle
import itertools

import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import cm
from IPython.display import display
from IPython.utils.capture import capture_output
from tqdm.auto import tqdm
# Register tqdm's pandas integration (e.g. `progress_apply`); capture_output()
# suppresses anything this prints.
with capture_output():
    tqdm.pandas()
import h5py

from differentiation import spectral_differentiation as specD

# Hide ipympl's figure header on interactive canvases by switching the default of
# the Canvas `header_visible` trait (presumably affects canvases created afterwards).
from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False

In [2]:
# Root directory holding the raw HDF5 file and every cached intermediate result.
# Set DIFFERENTIATION_DATA_DIR to run outside the Allen Institute filesystem;
# the default is the original location.
data_directory = os.environ.get(
    'DIFFERENTIATION_DATA_DIR',
    '/allen/programs/braintv/workgroups/tiny-blue-dot/differentiation/refactor/behavior/'
)

In [3]:
# Named groups of regions that are pooled together and analysed as one "area".
region_sets = {
    'VisCtx': ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'HVAs': ['VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'THx_VISp': ['LGd', 'LP', 'TH', 'VISp'],
    'AllVis': ['LGd', 'LP', 'TH', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'THx': ['LGd', 'LP', 'TH'],
    'hipp': ['CA', 'CA1', 'CA2', 'CA3', 'DG', 'DG-mo', 'DG-po', 'DG-sg'],
}

relevant_regions = ['LGd', 'LP', 'TH', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam']

layer_order = ['L2/3', 'L4', 'L5']

# Sort key used to order regions / region groups along the visual hierarchy
# (smaller value = earlier). Insertion order is kept because later cells
# iterate over hierarchy.keys().
hierarchy = {
    'Input': -100,
    'stimulus': -100,
    'Stim': -100,
    'TH': -10,
    'LG': -9,
    'LGv': -8,
    'LGd': -7,
    'LP': -6,
    'THx': -5,
    'THx_VISp': -4,
    'VISp': 0,
    'VISpl': 2,
    'VISl': 4,
    'VISli': 6,
    'VISrl': 8,
    'VISal': 10,
    'VISpm': 12,
    'VISam': 14,
    'VISpor': 16,
    'VISa': 18,
    'SC': 24,
    'VISmma': 20,
    'VISmmp': 20,
    'VIS': 22,
    'HVAs': 21,
    'VisCtx': 21.5,
    'AllVis': 22,
    'PF': 25,
    'MB': 30,
    'hipp': 38,
    'CAx': 39,
    'CA': 40,
    'CA1': 41,
    'CA2': 42,
    'CA3': 43,
    'DG': 50,
}

# Groups of units to apply differentiation to. Each entry is a DataFrame.eval
# expression over the unit metadata columns `region` and `layer`; `@region_sets`
# is resolved from the scope where the expression is evaluated.
_single_regions = ['VISp', 'VISl', 'VISal', 'VISam', 'VISpm', 'VISrl', 'LGd', 'LP']
_layered_regions = ['VISp', 'VISl', 'VISal', 'VISpm', 'VISam', 'VISrl']
unit_set = (
    [f'region == "{reg}"' for reg in _single_regions]
    + [f'region in @region_sets.get("{name}")' for name in region_sets]
    + [
        f'layer == "{lay}" & region == "{reg}"'
        for lay, reg in itertools.product(layer_order, _layered_regions)
    ]
    + [
        f'layer == "{lay}" & region in @region_sets.get("{name}")'
        for lay, name in itertools.product(layer_order, ['HVAs', 'VisCtx'])
    ]
)
print(f'Applying to {len(unit_set)} groups.')

Applying to 38 groups.


# Download and organize data

In [4]:
# Peek at the raw HDF5 layout using one example session.
raw_data_path = path.join(data_directory, 'vis_behavior_npx.hdf5')
mouse = '09132019_461027'
with h5py.File(raw_data_path, 'r') as f:
    session_group = f[mouse]
    print(f.keys())
    print(session_group.keys())
    lick_times = session_group['lickTimes'][:]
    flash_times = session_group['behaviorFlashTimes'][:]

<KeysViewHDF5 ['03212019_409096', '03262019_417882', '03272019_417882', '04042019_408528', '04052019_408528', '04102019_408527', '04112019_408527', '04252019_421323', '04262019_421323', '04302019_422856', '05162019_423749', '05172019_423749', '06072019_427937', '06122019_423745', '07112019_429084', '07122019_429084', '08082019_423744', '08092019_423744', '08132019_423750', '08142019_423750', '09052019_459521', '09062019_459521', '09122019_461027', '09132019_461027']>
<KeysViewHDF5 ['behaviorChangeTimes', 'behaviorFlashTimes', 'behaviorOmitFlashTimes', 'behaviorRunDx', 'behaviorRunSpeed', 'behaviorRunTime', 'ccfRegion', 'changeImage', 'flashImage', 'inCortex', 'initialImage', 'isiRegion', 'lickTimes', 'omitFlashImage', 'passiveChangeTimes', 'passiveFlashTimes', 'passiveOmitFlashTimes', 'passiveRunDx', 'passiveRunSpeed', 'passiveRunTime', 'response', 'rewardTimes', 'sdfs', 'spikeTimes', 'units']>


In [5]:
# Sampling rate (Hz) of the firing-rate traces built in get_mouse_data.
sampling_rate = 200
# 11-tap Gaussian kernel exp(-k**2 / 4) for k = -5..5 (sigma = sqrt(2) taps),
# used to smooth the 1 ms spike raster in get_mouse_data.
win = np.exp(-((np.arange(11) - 5) ** 2) / 4)


def get_mouse_ids():
    """Return the top-level keys of the raw HDF5 file (one per recording session)."""
    with h5py.File(raw_data_path, 'r') as h5:
        return list(h5.keys())


mouse_ids = get_mouse_ids()

def _decode_strings(dataset):
    """Decode every element of an HDF5 byte-string dataset to ``str``."""
    return [x.decode() for x in dataset[:]]


def _event_frame(columns, arrays, event_type, session_type):
    """Stack ``arrays`` column-wise into a time-indexed event table.

    ``np.array(arrays).T`` turns mixed numeric/string inputs into a string
    array, so the ``time`` index comes back as strings; get_mouse_data
    converts it to float afterwards.
    """
    frame = pd.DataFrame(np.array(arrays).T, columns=columns).set_index('time')
    frame['type'] = event_type
    frame['session_type'] = session_type
    return frame


def _append_if_available(frame, build_extra):
    """Return ``frame`` with ``build_extra()`` appended, or ``frame`` unchanged.

    Used for the passive-replay datasets, which may be absent from a session
    (h5py raises ``KeyError``) or presumably may not line up with the
    active-session labels (``ValueError`` when stacking/framing arrays of
    different lengths). Any other error propagates instead of being swallowed.
    """
    try:
        extra = build_extra()
    except (KeyError, ValueError):
        return frame
    return pd.concat([frame, extra])


def get_mouse_data(mouse_id):
    """Load one recording session from the raw HDF5 file.

    Returns a dict whose values callers unpack positionally, in this order:

    - 'behavior': stimulus events (omissions, changes, flashes of the active
      session, plus the passive replay when present), indexed by float time
      rounded to 3 decimals.
    - 'running': running ``speed`` / ``dx`` indexed by time.
    - 'lick_times': array of lick times.
    - 'units': per-unit metadata (probe, unit, area, layer, inCtx).
    - 'spikes': Series of spike-time arrays indexed by the unit metadata.
    - 'fr': uint8 smoothed firing-rate traces (units x samples at
      ``sampling_rate`` Hz), cached in data/sessions_uint8fr/<mouse_id>.npy.
    """
    with h5py.File(raw_data_path, 'r') as f:
        session = f[mouse_id]
        data = {}

        # --- stimulus events -------------------------------------------------
        omit_images = _decode_strings(session['omitFlashImage'])
        omissions = _event_frame(
            ['time', 'image'],
            [session['behaviorOmitFlashTimes'][:], omit_images],
            'omission', 'active',
        )
        omissions = _append_if_available(omissions, lambda: _event_frame(
            ['time', 'image'],
            [session['passiveOmitFlashTimes'][:], omit_images],
            'omission', 'passive',
        ))

        change_labels = [
            _decode_strings(session['changeImage']),
            _decode_strings(session['initialImage']),
            _decode_strings(session['response']),
        ]
        changes = _event_frame(
            ['time', 'reward_time', 'image', 'old_image', 'response'],
            [session['behaviorChangeTimes'][:], session['rewardTimes'][:], *change_labels],
            'change', 'active',
        )
        changes = _append_if_available(changes, lambda: _event_frame(
            ['time', 'image', 'old_image', 'response'],
            [session['passiveChangeTimes'][:], *change_labels],
            'change', 'passive',
        ))

        flash_images = _decode_strings(session['flashImage'])
        flashes = _event_frame(
            ['time', 'image'],
            [session['behaviorFlashTimes'][:], flash_images],
            'flash', 'active',
        )
        flashes = _append_if_available(flashes, lambda: _event_frame(
            ['time', 'image'],
            [session['passiveFlashTimes'][:], flash_images],
            'flash', 'passive',
        ))

        behavior = pd.concat([omissions, changes, flashes])
        behavior.index = np.round(behavior.index.astype('float32'), 3)
        data['behavior'] = behavior.sort_index()

        # --- running -----------------------------------------------------------
        running_columns = ['time', 'speed', 'dx']
        running = pd.DataFrame(np.array([
            session['behaviorRunTime'][:],
            session['behaviorRunSpeed'][:],
            session['behaviorRunDx'][:],
        ]).T, columns=running_columns).set_index('time')
        data['running'] = _append_if_available(running, lambda: pd.DataFrame(np.array([
            session['passiveRunTime'][:],
            session['passiveRunSpeed'][:],
            session['passiveRunDx'][:],
        ]).T, columns=running_columns).set_index('time'))

        data['lick_times'] = session['lickTimes'][:]

        # --- units and spikes ----------------------------------------------
        ccf_regions, in_cortex, spike_times = {}, {}, {}
        for probe in session['ccfRegion'].keys():
            ccf_regions[probe] = np.array(_decode_strings(session['ccfRegion'][probe]))
            in_cortex[probe] = session['inCortex'][probe][:]
            spike_times[probe] = {
                unit.decode(): session['spikeTimes'][probe][unit][:]
                for unit in session['units'][probe][:]
            }

        # one spike-time array per unit, keyed by (probe, unit, index within probe)
        spikes = pd.Series({
            (probe, unit, i): times[:, 0]
            for probe, probe_units in spike_times.items()
            for i, (unit, times) in enumerate(probe_units.items())
        }).rename_axis(['probe', 'unit', 'unit_idx'])
        print(mouse_id, ': total number of units = ', len(spikes))

        # extract units metadata
        units = spikes.index.to_frame(index=False)
        areas = units.apply(lambda r: ccf_regions[r['probe']][r['unit_idx']], axis=1)
        in_ctx = units.apply(lambda r: in_cortex[r['probe']][r['unit_idx']], axis=1)
        # Split e.g. 'VISp2/3' into layer '2/3' and area 'VISp'.
        # NOTE(review): rstrip removes a *set* of trailing characters, so 'CA1'
        # becomes area 'CA' with layer '1', and 'VISa' would become 'VIS' —
        # confirm this is intended.
        layers = areas.str.extract(r'(\d.*)')[0].fillna('').rename('layer')
        areas = areas.str.rstrip('12/3456ab').fillna('')
        units['area'] = areas
        units['layer'] = layers
        units['inCtx'] = in_ctx
        units = units.drop('unit_idx', axis=1)
        data['units'] = units
        spikes.index = pd.MultiIndex.from_frame(units)
        data['spikes'] = spikes

        # --- firing rates --------------------------------------------------
        # 1 ms spike raster -> convolution with `win` -> scaled by 250, stored as
        # uint8 and down-sampled to `sampling_rate`; cached on disk.
        n_units = len(spikes)
        maxtime = spikes.apply(lambda x: max(x)).max().round(4)
        maxtimems = np.rint(maxtime * 1000).astype(int)
        cache_path = f'data/sessions_uint8fr/{mouse_id}.npy'
        if not path.exists(cache_path):
            os.makedirs(path.dirname(cache_path), exist_ok=True)
            step = int(1000 / sampling_rate)  # raster samples (ms) per output sample
            frdata = np.zeros((
                n_units,
                np.rint(maxtime * sampling_rate).astype(int)
            ), dtype='uint8')
            for unit in range(n_units):
                st_int = np.array(spikes.iloc[unit] * 1000, dtype=int)
                raster = np.zeros(maxtimems, dtype='uint8')
                raster[st_int[st_int < maxtimems]] = 1
                smoothed = 250 * np.convolve(raster, win, mode='same')
                # Overlapping kernels exceed 255 (two spikes 3 ms apart give
                # 250 * (1 + e**-2.25) ~ 276). Out-of-range float->uint8 casts do
                # not saturate, so clip first to get 255 instead of garbage.
                frdata[unit] = np.clip(smoothed, 0, 255).astype('uint8')[
                    step // 2::step
                ][:frdata.shape[1]]
            np.save(cache_path, frdata)
        frdata = np.load(cache_path, mmap_mode='r')
        data['fr'] = pd.DataFrame(
            frdata, index=pd.MultiIndex.from_frame(units),
            columns=np.linspace(0, maxtimems, frdata.shape[1], endpoint=False) / 1000
        )
    return data

In [6]:
# compute firing rates for all sessions in the dataset
for mouse_id in tqdm(mouse_ids):
    if path.exists(path.join(data_directory, f'spikes_{mouse_id}.pkl')):
        continue
    behavior, running, lick_times, units, spikes, fr = get_mouse_data(mouse_id).values()
    spikes.to_pickle(path.join(data_directory, f'spikes_{mouse_id}.pkl'))

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




In [7]:
def compute_differentiation(
    mouse_id, unit_set, norm, function, sampling_rate=200,
    window_length=0.6, state_length=0.1
):
    """Compute per-stimulus spectral differentiation for groups of units in one session.

    Parameters
    ----------
    mouse_id : str
        Session key in the raw HDF5 file (see ``get_mouse_data``).
    unit_set : list of str
        ``DataFrame.eval`` expressions over the unit metadata (``region``,
        ``layer``) selecting each group of units.
    norm : str
        Firing-rate normalization. Only ``'cohort_full_ts'`` (divide by the
        group's grand-mean rate) is implemented.
    function : callable
        Called as ``function(x, sample_rate=sampling_rate, window_length=state_length)``
        with ``x`` shaped (n_windows, n_units, samples_per_window); its result
        is reduced with a median over axis 1.
    sampling_rate : int
        Sampling rate (Hz) of the firing-rate traces.
    window_length, state_length : float
        Post-stimulus analysis window and state length, in seconds.

    Returns
    -------
    pandas.DataFrame
        One row per stimulus, indexed by the stimulus metadata (with ``lick``
        replaced by a trailing 20-sample lick count), and one column per
        non-empty unit group named ``'<expr> & n_units = <k>'``. An empty
        DataFrame if no group selected any unit.

    Raises
    ------
    ValueError
        If ``norm`` is not implemented.
    """
    # Samples per post-stimulus window. round() keeps int() from truncating a
    # float product that lands just below an integer.
    n_win = int(round(window_length * sampling_rate))

    # load all data
    behavior, running, lick_times, up, spikes, fr = get_mouse_data(mouse_id).values()

    # Where several events share a timestamp, keep only the 'change' rows.
    behavior = behavior.groupby('time', group_keys=False).apply(
        lambda df: df[df.type == 'change'] if len(df) > 1 else df
    )
    behavior['stim_id'] = range(len(behavior))

    # Flag the samples that a limit-1 forward-fill reindex of the lick times
    # onto the sample grid associates with a lick.
    licks = pd.Series(False, index=fr.columns, name='lick').rename_axis('time')
    idx = pd.Index(lick_times).reindex(licks.index, method='ffill', limit=1)
    licks.loc[idx[0][idx[1] > -1]] = True

    # Label every sample with the lick flag, the smoothed running speed and the
    # stimulus it follows (forward-filled for at most one window).
    fr.columns = pd.MultiIndex.from_frame(
        licks.reset_index().set_index('time', drop=False).join(
            running.reindex(
                licks.index, method='nearest'
            ).speed.rolling(1000, center=True).mean().bfill().ffill()
        ).join(
            behavior.reindex(licks.index, method='ffill', limit=n_win)
        )
    )

    # Keep only samples that follow a stimulus ...
    t = fr.columns.to_frame(index=False)
    fr = fr[t[~t.type.isna()].time]

    # ... and at most the first n_win of them per stimulus.
    idx = fr.columns.to_frame(index=False)
    idx = idx.groupby(
        'stim_id', group_keys=False
    ).apply(lambda df: df.iloc[:n_win])
    _fr = fr[pd.MultiIndex.from_frame(idx)]

    # Metadata used to label the output rows. It is the same for every unit
    # group, so it is built once: 'lick' becomes a trailing 20-sample count.
    stim_meta = fr.columns.to_frame(index=False).set_index('time')
    stim_meta['lick'] = stim_meta.lick.rolling(20).sum()

    up = up.rename({'area': 'region'}, axis=1)
    up['layer'] = up.layer.apply(lambda x: f'L{x}' if len(x) > 0 else '')

    differentiation = []
    units_bar = tqdm(range(len(unit_set)), desc=mouse_id)
    try:
        for units_name in unit_set:
            units = up[up.eval(units_name)]
            if len(units) == 0:
                units_bar.update()
                continue

            # (samples x units) firing rates; `up` has a RangeIndex, so its
            # labels double as row positions in `_fr`.
            unit_fr = _fr.iloc[units.index].T

            if norm == 'cohort_full_ts':
                unit_fr = unit_fr / unit_fr.values.mean()
            else:
                raise ValueError(f'normalization {norm} not implemented.')

            # Cut the time axis into consecutive n_win-sample windows ->
            # (n_windows, n_units, n_win). Windows line up with stimuli as long
            # as each stimulus contributed exactly n_win samples.
            n_windows = unit_fr.shape[0] // n_win
            windows = np.reshape(
                unit_fr.iloc[:n_windows * n_win].T.values,
                (unit_fr.shape[1], -1, n_win)
            ).transpose(1, 0, 2)

            diff = function(
                windows, sample_rate=sampling_rate, window_length=state_length
            )
            diff = np.median(diff, axis=1)

            # Label each value with the time of its stimulus' first sample.
            times = unit_fr.groupby('stim_id').apply(
                lambda g: g.index.get_level_values('time')[0]
            )
            series = pd.Series(
                diff, index=times,
                name=(window_length, state_length,
                      f'{units_name} & n_units = {len(units)}')
            )
            meta = stim_meta.reindex(series.index, method='nearest')
            series.index = pd.MultiIndex.from_frame(
                meta.rename_axis('time').reset_index()
            )
            differentiation.append(series)
            units_bar.update()
    finally:
        units_bar.close()

    if len(differentiation) == 0:
        return pd.DataFrame()
    differentiation = pd.concat(differentiation, axis=1).sort_index(axis=1)
    return differentiation.rename_axis(
        columns=['window_length', 'state_length', 'region']
    ).droplevel([0, 1], axis=1)

# utility functions for properly renaming the columns of the differentiation dataframe
def get_unit_filters(units):
    """Parse a unit-group label into a ``{field: value}`` dict.

    The label is split on ``' & '`` and every clause is tested against four
    rules, in this order (a later match overwrites an earlier one):

    - ``'k = v'``               -> ``k: 'v'``
    - ``'k == "v"'``            -> ``k: 'v'`` (double quotes stripped)
    - ``'k > x'``               -> ``k: float(x)``
    - ``'k in @sets.get("v")'`` -> ``k: 'v'``
    """
    parsed = {}
    for clause in units.split(' & '):
        if ' = ' in clause:
            key, value = clause.split(' = ')[:2]
            parsed[key] = value.strip('"')
        if '==' in clause:
            pieces = clause.split(' == ')
            parsed[pieces[0]] = pieces[1].strip('"')
        if '>' in clause:
            pieces = clause.split(' > ')
            parsed[pieces[0]] = float(pieces[1])
        if '@' in clause:
            parsed[clause.split(' ')[0]] = clause.split('get')[1].split('"')[1]
    return parsed

def rename_columns(c):
    """Turn a column (whose ``.name`` is a unit-group label) into its
    ``layer`` / ``area`` / ``n_units`` fields; a missing layer or area is ``'-'``."""
    fields = get_unit_filters(c.name)
    layer = fields.get('layer', '-')
    area = fields.get('region', '-')
    count = int(fields['n_units'])
    return {'layer': layer, 'area': area, 'n_units': count}

In [8]:
# compute differentiation for all mice
win_ms = 300  # analysis window length in ms (alternatives: 150/300/600)
sta_ms = 60  # state length in ms (alternatives: 25/60/100)
differentiation, n_units = {}, {}
for mouse_id in tqdm(mouse_ids[::-1]):
    fname = path.join(
        data_directory,
        f'spectral_differentiation_{mouse_id}_{win_ms}_{sta_ms}.pkl'
    )
    if path.exists(fname):
        diffn = pd.read_pickle(fname)
    else:
        diffn = compute_differentiation(
            mouse_id, unit_set, 'cohort_full_ts', specD,
            window_length=win_ms/1000, state_length=sta_ms/1000
        )
        try:
            # '<expr> & n_units = k' column names -> (layer, area, n_units) levels
            diffn.columns = pd.MultiIndex.from_frame(
                pd.DataFrame(
                    list(
                        diffn.columns.to_frame()
                        .apply(rename_columns, axis=1)
                    )
                )
            )
        except Exception:
            # Fallback (e.g. for the empty frame returned when no unit group
            # matched): one empty label per *column*, with the same level order
            # that rename_columns produces.
            n_cols = len(diffn.columns)
            diffn.columns = pd.MultiIndex.from_arrays(
                [[''] * n_cols for _ in range(3)],
                names=['layer', 'area', 'n_units']
            )
        diffn.to_pickle(fname)

    n_units[mouse_id] = diffn.columns.to_frame(index=False)
    differentiation[mouse_id] = diffn.droplevel(2, axis=1)

differentiation = pd.concat(
    differentiation,
    names=[
        'expt', 'time', 'lick', 'running_speed', 'image', 'type',
        'session_type', 'reward_time', 'old_image', 'response', 'stim_id'
    ]
)

# Split the session key 'MMDDYYYY_mouse' into mouse / date, and number each
# mouse's sessions (sid) in order of first appearance, which follows the
# reversed mouse_ids iteration above.
idx = differentiation.index.to_frame(index=False)
idx['mouse'] = idx.expt.apply(lambda x: x.split('_')[1])
idx['date'] = idx.expt.apply(lambda x: x.split('_')[0])
idx['sid'] = idx.groupby('mouse').date.transform(lambda d: pd.factorize(d)[0])
idx = idx[[
    'mouse', 'date', 'sid', 'time', 'lick', 'running_speed', 'image',
    'type', 'session_type', 'reward_time', 'old_image', 'response'
]]
differentiation.index = pd.MultiIndex.from_frame(idx)
differentiation = differentiation.sort_index()
idx = differentiation.index.to_frame(index=False)
differentiation

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,layer,-,-,-,-,-,-,-,-,-,-,...,L4,L4,L5,L5,L5,L5,L5,L5,L5,L5
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,area,AllVis,HVAs,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,VISp,...,VISrl,VisCtx,HVAs,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx
mouse,date,sid,time,lick,running_speed,image,type,session_type,reward_time,old_image,response,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2
408527,04102019,1,77.110017,,16.503895,im065,flash,active,,,,143995.076363,122873.798786,,,,75924.167922,51711.321109,,9675.406830,75924.167922,...,13971.644066,78666.808810,58427.373628,5641.570307,,12752.492008,1805.963111,104411.192417,5877.420420,55968.075714
408527,04102019,1,77.860017,0.0,17.680735,im065,flash,active,,,,92315.190916,88502.583197,,,,17593.311095,32877.217003,,8996.704857,17593.311095,...,8832.495968,41321.987748,51581.848154,2877.759060,,12129.144896,575.882417,79314.095545,8881.156282,47327.400742
408527,04102019,1,78.610017,0.0,17.784651,im065,flash,active,,,,84975.558637,80377.704806,,,,21375.788752,28175.018012,,5813.560203,21375.788752,...,11426.549122,40873.080762,42176.546978,5979.700650,,7365.526827,490.279280,87824.627332,6090.059133,38710.889985
408527,04102019,1,79.360017,0.0,17.952055,im065,flash,active,,,,70065.770951,67533.321403,,,,15296.092372,28539.871445,,5916.936556,15296.092372,...,10528.206917,38030.738508,39062.618305,6642.951130,,6987.155798,405.905286,65955.123348,7005.739471,35998.270357
408527,04102019,1,80.115017,0.0,15.611155,im065,flash,active,,,,89664.913584,87078.847925,,,,11759.005833,35142.255112,,5539.734848,11759.005833,...,13024.662352,48269.677513,43559.390275,6473.813186,,5332.223415,405.905286,98989.005651,5272.978995,39846.977040
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
461027,09132019,0,9158.000000,0.0,10.850420,im085,flash,passive,,,,24338.693145,20860.015367,4375.303335,,11679.557220,22021.984392,20860.015367,,,24070.374673,...,,25189.088607,11001.789099,11001.789099,,,8223.397782,,,10068.690641
461027,09132019,0,9158.750000,0.0,14.521681,im085,flash,passive,,,,31587.553360,42373.462089,5237.135715,,11441.162817,27254.892898,42373.462089,,,23906.255375,...,,22710.235075,11001.789099,11001.789099,,,6820.561711,,,8787.059787
461027,09132019,0,9159.500000,0.0,20.543412,im085,flash,passive,,,,28254.883681,28258.969800,4569.361060,,11441.976023,24707.955992,28258.969800,,,31625.546199,...,,20535.448975,194.295502,194.295502,,,10097.156491,,,12311.408053
461027,09132019,0,9160.250000,0.0,29.017218,im085,flash,passive,,,,27067.249947,33038.435128,4313.449490,,11306.983736,24121.356126,33038.435128,,,29199.741967,...,,38097.575580,35446.758454,35446.758454,,,7104.464658,,,9784.609921


In [9]:
# One row per (session, layer, area) unit group with its unit count.
n_units_long = pd.concat(n_units, names=['mouse', 'idx'])
n_units = n_units_long.set_index(['layer', 'area'], append=True).droplevel('idx')
n_units

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,n_units
mouse,layer,area,Unnamed: 3_level_1
09132019_461027,L2/3,VISal,3
09132019_461027,L2/3,VISp,6
09132019_461027,L2/3,HVAs,3
09132019_461027,L2/3,VisCtx,9
09132019_461027,L4,VISal,2
...,...,...,...
03212019_409096,-,AllVis,277
03212019_409096,-,HVAs,156
03212019_409096,-,THx_VISp,121
03212019_409096,-,VisCtx,277


In [10]:
# Number of sessions with unit counts (the 'mouse' level holds session keys
# such as '09132019_461027'). MultiIndex.levels can keep values no longer
# present in the index, so count the labels that are actually used.
n_units.index.get_level_values('mouse').nunique()

24

In [11]:
# Mean +/- std of the 'AllVis' unit count over the entries of n_units (one per
# recording session), then the total across them.
allvis = n_units.xs('-', level='layer').xs('AllVis', level='area')
mean_units, std_units = allvis.mean().iloc[0], allvis.std().iloc[0]
print(f"{mean_units} +/- {std_units} neurons per mouse.")
allvis.sum()

318.25 +/- 146.97300846547063 neurons per mouse.


n_units    7638
dtype: int64

In [12]:
def compute_mfr(mouse_id, unit_set, kind='mean_wrt_flash', win_size_ms=300):
    """Mean firing rate of each unit group over the first ``win_size_ms`` after every stimulus.

    Parameters
    ----------
    mouse_id : str
        Session key in the raw HDF5 file (see ``get_mouse_data``).
    unit_set : list of str
        ``DataFrame.eval`` expressions over the unit metadata (``region``,
        ``layer``) selecting each group of units.
    kind : str
        ``'mean_wrt_flash'`` z-scores every unit with the mean and standard
        deviation of its samples in flash windows before averaging. Any other
        value leaves the firing rates unnormalized.
    win_size_ms : int
        Length (ms) of the averaged post-stimulus window. At most the first
        0.6 s after each stimulus is kept.

    Returns
    -------
    pandas.DataFrame
        One row per stimulus, keyed by the metadata of its first sample (all
        column levels of the labelled firing-rate frame except ``stim_id``),
        and one column per non-empty unit group named
        ``'<expr> & n_units = <k>'``. An empty DataFrame if no group selected
        any unit.
    """
    # Number of firing-rate samples averaged per stimulus (60 for 300 ms at 200 Hz).
    n_avg = win_size_ms * sampling_rate // 1000
    # Longest post-stimulus window kept per stimulus (0.6 s).
    max_window = int(0.6 * sampling_rate)

    behavior, running, lick_times, up, spikes, fr = get_mouse_data(mouse_id).values()
    # Where several events share a timestamp, keep only the 'change' rows.
    behavior = behavior.groupby('time', group_keys=False).apply(
        lambda df: df[df.type == 'change'] if len(df) > 1 else df
    )
    behavior['stim_id'] = range(len(behavior))

    # Label every sample with the lick flag, the smoothed running speed and
    # the stimulus it follows.
    licks = pd.Series(False, index=fr.columns, name='lick').rename_axis('time')
    idx = pd.Index(lick_times).reindex(licks.index, method='ffill', limit=1)
    licks.loc[idx[0][idx[1] > -1]] = True
    fr.columns = pd.MultiIndex.from_frame(
        licks.reset_index().set_index('time', drop=False).join(
            running.reindex(
                licks.index, method='nearest'
            ).speed.rolling(1000, center=True).mean().bfill().ffill()
        ).join(
            behavior.reindex(licks.index, method='ffill', limit=max_window)
        )
    )

    # Keep only post-stimulus samples, at most max_window per stimulus.
    t = fr.columns.to_frame(index=False)
    fr = fr[t[~t.type.isna()].time]

    idx = fr.columns.to_frame(index=False)
    idx = idx.groupby('stim_id', group_keys=False).apply(
        lambda df: df.iloc[:max_window]
    )
    _fr = fr[pd.MultiIndex.from_frame(idx)]
    if kind == 'mean_wrt_flash':
        samples = _fr.T
        flash_samples = samples[samples.index.get_level_values('type') == 'flash']
        _fr = ((samples - flash_samples.mean()) / flash_samples.std()).T

    up = up.rename({'area': 'region'}, axis=1)
    up['layer'] = up.layer.apply(lambda x: f'L{x}' if len(x) > 0 else '')

    mfr = []
    units_bar = tqdm(range(len(unit_set)), desc=mouse_id)
    try:
        for units_name in unit_set:
            units = up[up.eval(units_name)]
            if len(units) == 0:
                units_bar.update()
                continue

            # population mean rate of the selected units at every sample
            group_rate = _fr.iloc[units.index].T.mean(1)
            # average over the first n_avg samples of each stimulus, keyed by
            # the stimulus' first-sample metadata (stim_id dropped)
            stim_means = {
                tuple(g.index.to_frame().iloc[0])[:-1]: g[:n_avg].mean()
                for _, g in group_rate.groupby('stim_id')
            }
            mfr.append(pd.Series(
                stim_means, name=f'{units_name} & n_units = {len(units)}'
            ))
            units_bar.update()
    finally:
        units_bar.close()

    if len(mfr) == 0:
        return pd.DataFrame()
    return pd.concat(mfr, axis=1).sort_index(axis=1)

In [13]:
# Mean firing rate per stimulus for every session, cached under the same window
# parameters (win_ms, sta_ms) as the differentiation results.
kind = 'mean_wrt_flash'
mfr = {}
for mouse_id in tqdm(mouse_ids[::-1]):
    # NOTE: `kind` is not part of the cache file name; delete cached files after changing it.
    fname = path.join(
        data_directory,
        f'behavior_mfr_{mouse_id}_{win_ms}_{sta_ms}.pkl'
    )
    if path.exists(fname):
        session_mfr = pd.read_pickle(fname)
    else:
        session_mfr = compute_mfr(
            mouse_id, unit_set, win_size_ms=win_ms, kind=kind
        )
        try:
            # '<expr> & n_units = k' column names -> (layer, area, n_units) levels
            session_mfr.columns = pd.MultiIndex.from_frame(
                pd.DataFrame(
                    list(
                        session_mfr.columns.to_frame()
                        .apply(rename_columns, axis=1)
                    )
                )
            )
        except Exception:
            # Fallback (e.g. for the empty frame compute_mfr returns when no
            # unit group matched): one empty label per *column*, with the same
            # level order that rename_columns produces.
            n_cols = len(session_mfr.columns)
            session_mfr.columns = pd.MultiIndex.from_arrays(
                [[''] * n_cols for _ in range(3)],
                names=['layer', 'area', 'n_units']
            )
        session_mfr.to_pickle(fname)
    mfr[mouse_id] = session_mfr.droplevel(2, axis=1)

mfr = pd.concat(mfr, names=[
    'expt', 'time', 'lick', 'running_speed', 'image', 'type',
    'session_type', 'reward_time', 'old_image', 'response'
])

# Split the session key 'MMDDYYYY_mouse' into mouse / date, and number each
# mouse's sessions (sid) in order of first appearance.
idx = mfr.index.to_frame(index=False)
idx['mouse'] = idx.expt.apply(lambda x: x.split('_')[1])
idx['date'] = idx.expt.apply(lambda x: x.split('_')[0])
idx['sid'] = idx.groupby('mouse').date.transform(lambda d: pd.factorize(d)[0])
idx = idx[[
    'mouse', 'date', 'sid', 'time', 'lick', 'running_speed', 'image',
    'type', 'session_type', 'reward_time', 'old_image', 'response'
]]
mfr.index = pd.MultiIndex.from_frame(idx)
mfr = mfr.sort_index()
# Leave `idx` holding the differentiation metadata, as at the end of the
# differentiation cell (not the mfr index built above).
idx = differentiation.index.to_frame(index=False)
mfr

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,layer,-,-,-,-,-,-,-,-,-,-,...,L4,L4,L5,L5,L5,L5,L5,L5,L5,L5
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,area,AllVis,HVAs,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,VISp,...,VISrl,VisCtx,HVAs,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx
mouse,date,sid,time,lick,running_speed,image,type,session_type,reward_time,old_image,response,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2
408527,04102019,1,77.110017,False,16.503895,im065,flash,active,,,,0.192562,0.181356,,,,0.252669,0.343690,,0.579777,0.252669,...,0.228116,0.204801,0.185314,0.160744,,0.660715,0.970135,0.198823,-0.046285,0.203566
408527,04102019,1,77.860017,False,17.680735,im065,flash,active,,,,0.055447,0.060794,,,,0.026768,0.180335,,0.164436,0.026768,...,0.029052,0.057216,0.072113,0.102845,,0.252809,0.373203,0.097658,-0.058077,0.079115
408527,04102019,1,78.610017,False,17.784651,im065,flash,active,,,,0.069885,0.069686,,,,0.070949,0.111654,,0.343702,0.070949,...,0.075617,0.055960,0.122702,0.135395,,0.434950,0.114436,0.119834,-0.021127,0.122510
408527,04102019,1,79.360017,False,17.952055,im065,flash,active,,,,0.020651,0.024535,,,,-0.000178,0.087078,,0.106189,-0.000178,...,0.004004,0.011397,0.042805,0.281510,,0.061480,0.072886,0.055147,-0.057300,0.043504
408527,04102019,1,80.115017,False,15.611155,im065,flash,active,,,,0.050543,0.057281,,,,0.014407,0.144031,,0.079999,0.014407,...,0.129377,0.059712,0.043734,0.263404,,0.044645,0.072474,0.077505,-0.073634,0.044403
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
461027,09132019,0,9158.000000,False,10.850420,im085,flash,passive,,,,0.045082,0.132772,0.108851,,0.042819,0.038526,0.132772,,,0.032611,...,,0.018326,0.102153,0.102153,,,0.037928,,,0.046305
461027,09132019,0,9158.750000,False,14.521681,im085,flash,passive,,,,0.067094,0.156164,0.142916,,0.035864,0.060435,0.156164,,,0.094288,...,,0.134854,0.049745,0.049745,,,0.014555,,,0.019145
461027,09132019,0,9159.500000,False,20.543412,im085,flash,passive,,,,0.028772,0.087531,0.125406,,0.004329,0.024379,0.087531,,,0.052003,...,,0.056017,-0.057436,-0.057436,,,0.028231,,,0.017057
461027,09132019,0,9160.250000,False,29.017218,im085,flash,passive,,,,0.049667,0.183887,0.025980,,0.012673,0.039631,0.183887,,,0.076774,...,,0.151484,0.214721,0.214721,,,0.013366,,,0.039630


In [14]:
def get_hitrate(diffn, ax=None, hrw=20):
    """Centered rolling hit rate over the rows of ``diffn`` that have a response.

    Parameters
    ----------
    diffn : pandas.DataFrame
        Frame whose row MultiIndex includes ``response`` and ``time`` levels.
    ax : matplotlib Axes, optional
        If given, the hit rate is plotted against ``time`` on it.
    hrw : int
        Rolling window length in response rows.

    Returns
    -------
    pandas.Series
        The fraction of 'hit' responses in a centered window of ``hrw`` rows,
        indexed by the full index of the rows with a non-null response.
    """
    trials = diffn.index.to_frame(index=False)
    trials = trials.loc[trials.response.dropna().index]
    is_hit = trials.response.eq('hit').astype(int)
    hitrate = is_hit.rolling(hrw, center=True).mean()
    hitrate.index = pd.MultiIndex.from_frame(trials)
    if ax is not None:
        by_time = hitrate.swaplevel(0, 'time')
        by_time.droplevel(
            list(range(1, by_time.index.nlevels))
        ).plot(ax=ax, marker='.')
    return hitrate

f, ax = plt.subplots(figsize=(6, 2.6), tight_layout=True)
# Rolling hit rate over a centered window of 50 change trials, one line per session.
hitrate = differentiation.groupby(
    ['mouse', 'date'], group_keys=False
).apply(get_hitrate, ax=ax, hrw=50)
ax.set(xlabel='time (s)', ylabel='hit rate')

# Spread each session's hit rate onto all of its stimulus rows. Fill within a
# session only: a global ffill/bfill would carry the previous session's last
# value into the rows before this session's first valid (centered-rolling) value.
session_levels = ['mouse', 'date']
hitrate = (
    hitrate.reindex(differentiation.index)
    .groupby(level=session_levels).ffill()
    .groupby(level=session_levels).bfill()
)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [14]:
def bhc(stats, alpha=0.01):
    '''
    Benjamini-Hochberg false-discovery-rate correction.

    Parameters
    ----------
    stats : DataFrame with a `pval` column, one row per tested hypothesis.
    alpha : false discovery rate.

    Returns
    -------
    The rows of `stats` that are declared significant, sorted by p-value, with
    two extra columns: `rank` (0-based rank of the p-value) and `thresh`
    ((rank + 1) / m * alpha). Returns None (and prints a message) when no
    hypothesis is significant. The input frame is not modified.
    '''
    stats = stats.sort_values('pval')
    n_tests = len(stats)
    stats['rank'] = range(n_tests)
    stats['thresh'] = (stats['rank'] + 1) / n_tests * alpha
    # NaN p-values never pass (comparisons with NaN are False)
    passing = stats['pval'] <= stats['thresh']
    if not passing.any():
        print('No significant values.')
        return None
    # Step-up rule: reject every hypothesis ranked at or below the LARGEST
    # rank that passes its threshold, even if some lower-ranked p-values
    # individually miss their own threshold.
    k_max = stats.loc[passing, 'rank'].max()
    return stats[stats['rank'] <= k_max].copy()

In [23]:
# Box plots of differentiation per area for hits vs misses, one row per layer:
# left column = active sessions, right column = active sessions with the trial
# labels shuffled (control). Each area is normalised by its mean.
# each point is a stimulation instance
idx = differentiation.index.to_frame(index=False)
# Keep only hits and misses; every other response becomes missing. `mask` is
# used because the meaning of `replace(list, None)` differs across pandas
# versions (older versions forward-fill instead of inserting None).
idx['response'] = idx.response.mask(
    idx.response.isin(['correctReject', 'noChange', 'falseAlarm'])
)
# a trial counts as running when its running speed exceeds 0.15
idx['is_running'] = idx.running_speed > 0.15
# Relabel the rows in place of the old index. `pd.DataFrame(df, index=...)`
# would *reindex* instead, turning rows whose labels changed into NaN.
_differentiation = differentiation.copy()
_differentiation.index = pd.MultiIndex.from_frame(idx)
# -1: pool running and resting trials; 0 / 1: resting / running trials only
is_running = -1

if is_running > -1:
    diffn = _differentiation.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
else:
    diffn = _differentiation.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
display(diffn['-'].head())

f, _axes = plt.subplots(4, 2, figsize=(12, 7), constrained_layout=True, sharex=True, sharey=True)

for axes, session_type in zip(_axes.T, ['active', 'shuffle']):  # 'passive' is also available
    # the shuffle control re-uses the active trials
    st = 'active' if session_type == 'shuffle' else session_type
    for ax, layer in zip(axes, diffn.columns.remove_unused_levels().levels[0]):
        areas = [a for a in hierarchy.keys() if a in diffn[layer].columns]
        # normalise each area by its mean over all trials of this session type
        if is_running > -1:
            _dfn = diffn.loc[(st, is_running), layer] / diffn.loc[st, layer].mean()
        else:
            _dfn = diffn.loc[st, layer] / diffn.loc[st, layer].mean()
        if session_type == 'shuffle':
            # permute the row labels (hence hit/miss labels) reproducibly
            _dfn.index = pd.MultiIndex.from_frame(
                _dfn.index.to_frame().sample(frac=1, random_state=0)
            )
        (
            _dfn.stack()
            .droplevel(1).to_frame().reset_index()
            .rename(columns={'level_1': 'idx', 0: 'differentiation'})
            .pipe(
                (sns.boxplot, 'data'),
                x='area', y='differentiation', hue='response', ax=ax,
                fliersize=2, showfliers=False, order=areas
            )
        )
        ax.legend(fontsize=7)
        ax.set_ylabel(f'{layer}\ndifferentiation')
    axes[0].set_title(f'{session_type} ({"running" if is_running==1 else "resting" if is_running==0 else ""})' , fontsize=9)

Unnamed: 0_level_0,Unnamed: 1_level_0,area,AllVis,HVAs,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx,hipp
session_type,response,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
active,hit,0,113976.531541,97989.876251,,,,31087.275164,17419.942913,,9315.345402,31087.275164,101756.774015,29976.088098,113976.531541,53568.921104
active,hit,1,104540.860029,93945.063851,,,,39009.630353,40576.390387,,5024.417016,39009.630353,77070.224562,22624.49064,104540.860029,49356.015577
active,hit,2,92509.908538,82681.496595,,,,27345.956785,31919.158883,,9410.272002,27345.956785,98379.55743,33100.226672,92509.908538,53468.856582
active,hit,3,115923.094488,103448.330671,,,,44514.312991,42103.491132,,7857.719251,44514.312991,183297.764375,20721.730481,115923.094488,56222.961091
active,hit,4,104084.852429,101258.307363,,,,10673.35296,30890.172786,,6484.687541,10673.35296,58718.128404,25695.220338,104084.852429,50683.321273


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
# For each column, compute a p-value for the hit-vs-miss difference,
# then pass the p-value DataFrame to bhc with alpha to keep only the rows that survive the FDR correction.

In [24]:
# Figure: spectral differentiation for hits vs misses in active sessions, one
# row of box plots per layer. Areas whose hit/miss Welch t-test survives the
# Benjamini-Hochberg correction (alpha=1e-4) get a red bar and a '*'.
# each point is a stimulation instance
areas = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']
idx = differentiation.index.to_frame(index=False)
# keep only hits and misses; every other response becomes missing
# (`mask` avoids the version-dependent meaning of `replace(list, None)`)
idx['response'] = idx.response.mask(
    idx.response.isin(['correctReject', 'noChange', 'falseAlarm'])
)
# a trial counts as running when its running speed exceeds 0.15
idx['is_running'] = idx.running_speed > 0.15
# relabel rows; the DataFrame constructor would reindex and blank changed rows
_differentiation = differentiation.copy()
_differentiation.index = pd.MultiIndex.from_frame(idx)
# -1: pool running and resting trials; 0 / 1: resting / running trials only
is_running = -1

if is_running > -1:
    diffn = _differentiation.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
else:
    diffn = _differentiation.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))

with sns.axes_style('white'):
    f, axes = plt.subplots(4, 1, figsize=(4.5, 4.5), constrained_layout=True, sharex=True)

session_type = 'active'
for ax, layer in zip(axes, diffn.columns.remove_unused_levels().levels[0]):
    # the shuffle control re-uses the active trials
    st = 'active' if session_type == 'shuffle' else session_type
    if is_running > -1:
        _dfn = diffn.loc[(st, is_running), layer]
    else:
        _dfn = diffn.loc[st, layer]
    # Welch t-test hit vs miss per area, then FDR correction;
    # bhc returns None when nothing is significant
    stats = _dfn.apply(lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)).T
    stats.columns = ['statistic', 'pval']
    stats = bhc(stats, 1e-4)
    if session_type == 'shuffle':
        _dfn.index = pd.MultiIndex.from_frame(
            _dfn.index.to_frame().sample(frac=1, random_state=0)
        )
    (
        _dfn.stack()
        .droplevel(1).to_frame().reset_index()
        .rename(columns={'level_1': 'idx', 0: 'differentiation'})
        .pipe(
            (sns.boxplot, 'data'),
            x='area', y='differentiation', hue='response', ax=ax,
            fliersize=2, showfliers=False, order=areas, width=0.4,
            linewidth=1, palette={'hit':cm.Reds(0.6, 0.6), 'miss':cm.Greys(0.6, 0.6)}
        )
    )
    # keep only two y ticks
    yt = ax.get_yticks()
    ax.set_yticks([yt[1], yt[-2]])
    yl = ax.get_ylim()[1]
    if stats is not None:
        for i, a in enumerate(areas):
            if a in stats.index:
                ax.annotate('*', (i, 0.9), xycoords=('data', 'axes fraction'), ha='center')
                ax.plot([i-0.12, i+0.12], [yl*0.9, yl*0.9], c='r')
    ax.legend(fontsize=7).set_visible(False)
    ax.set_ylabel(f'{layer if layer!="-" else "all layers"}\ndifferentiation', fontsize=8, labelpad=0)
    ax.set_xlabel('')
    ax.tick_params(axis='both', which='major', labelsize=8, pad=-1)
    ax.yaxis.get_offset_text().set_size(8)
    # only the bottom panel keeps its bottom spine
    sns.despine(ax=ax, left=False, bottom=ax is not axes[-1])
axes[0].set_title('spectral differentiation')
axes[0].legend(loc=(0.06, 0.8), fontsize=8, frameon=False, ncol=1)
f.align_ylabels(axes);
f.savefig('fig_behavior_boxplots.pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# Figure: firing rate for hits vs misses in active sessions, one row of box
# plots per layer. Areas whose hit/miss Welch t-test survives the
# Benjamini-Hochberg correction (alpha=1e-4) get a red bar and a '*'.
# each point is a stimulation instance
areas = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']
idx = mfr.index.to_frame(index=False)
# keep only hits and misses; every other response becomes missing
# (`mask` avoids the version-dependent meaning of `replace(list, None)`)
idx['response'] = idx.response.mask(
    idx.response.isin(['correctReject', 'noChange', 'falseAlarm'])
)
# a trial counts as running when its running speed exceeds 0.15
idx['is_running'] = idx.running_speed > 0.15
# infinite rates become missing; rows are relabelled (the DataFrame
# constructor would reindex and blank rows whose labels changed)
_mfr = mfr.replace([-np.inf, np.inf], np.nan)
_mfr.index = pd.MultiIndex.from_frame(idx)
# -1: pool running and resting trials; 0 / 1: resting / running trials only
is_running = -1

if is_running > -1:
    diffn = _mfr.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
else:
    diffn = _mfr.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))

with sns.axes_style('white'):
    f, axes = plt.subplots(4, 1, figsize=(4.5, 4.5), constrained_layout=True, sharex=True)

session_type = 'active'
for ax, layer in zip(axes, diffn.columns.remove_unused_levels().levels[0]):
    # the shuffle control re-uses the active trials
    st = 'active' if session_type == 'shuffle' else session_type
    if is_running > -1:
        _dfn = diffn.loc[(st, is_running), layer]
    else:
        _dfn = diffn.loc[st, layer]
    # Welch t-test hit vs miss per area, then FDR correction;
    # bhc returns None when nothing is significant
    stats = _dfn.apply(lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)).T
    stats.columns = ['statistic', 'pval']
    stats = bhc(stats, 1e-4)
    if session_type == 'shuffle':
        _dfn.index = pd.MultiIndex.from_frame(
            _dfn.index.to_frame().sample(frac=1, random_state=0)
        )
    (
        _dfn.stack()
        .droplevel(1).to_frame().reset_index()
        .rename(columns={'level_1': 'idx', 0: 'differentiation'})
        .pipe(
            (sns.boxplot, 'data'),
            x='area', y='differentiation', hue='response', ax=ax,
            fliersize=2, showfliers=False, order=areas, width=0.4,
            linewidth=1, palette={'hit':cm.Reds(0.6, 0.6), 'miss':cm.Greys(0.6, 0.6)}
        )
    )
    # keep only two y ticks
    yt = ax.get_yticks()
    ax.set_yticks([yt[1], yt[-2]])
    yl = ax.get_ylim()[1]
    if stats is not None:
        for i, a in enumerate(areas):
            if a in stats.index:
                ax.annotate('*', (i, 0.9), xycoords=('data', 'axes fraction'), ha='center')
                ax.plot([i-0.12, i+0.12], [yl*0.9, yl*0.9], c='r')
    ax.legend(fontsize=7).set_visible(False)
    # this figure shows firing rates, not differentiation
    ax.set_ylabel(f'{layer if layer!="-" else "all layers"}\nfiring rate', fontsize=8, labelpad=0)
    ax.set_xlabel('')
    ax.tick_params(axis='both', which='major', labelsize=8, pad=-1)
    ax.yaxis.get_offset_text().set_size(8)
    # only the bottom panel keeps its bottom spine
    sns.despine(ax=ax, left=False, bottom=ax is not axes[-1])
axes[0].set_title('firing rate')
axes[0].legend(loc=2, fontsize=9, frameon=False, ncol=2)
f.align_ylabels(axes);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [25]:
# areas shown in the heatmap ('AllVis' and 'hipp' can also be added)
areas = ['LP', 'THx', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']
idx = differentiation.index.to_frame(index=False)
# keep only hits and misses; every other response becomes missing
# (`mask` avoids the version-dependent meaning of `replace(list, None)`)
idx['response'] = idx.response.mask(
    idx.response.isin(['correctReject', 'noChange', 'falseAlarm'])
)
# a trial counts as running when its running speed exceeds 0.15
idx['is_running'] = idx.running_speed > 0.15
# relabel rows; the DataFrame constructor would reindex and blank changed rows
_differentiation = differentiation.copy()
_differentiation.index = pd.MultiIndex.from_frame(idx)

def get_dfn_stats(_differentiation, is_running=-1, pval=1e-4, session_type='active'):
    '''
    Select one session type and test hit vs miss for every column.

    Parameters
    ----------
    _differentiation : DataFrame with a row MultiIndex containing the levels
        `session_type` and `response` (and `is_running` when
        `is_running > -1`). Columns are presumably (layer, area) pairs --
        TODO confirm against the callers' `.stack()` / `.unstack()` usage.
    is_running : -1 pools all trials; otherwise only trials whose
        `is_running` level matches this value are kept.
    pval : false discovery rate passed to `bhc`.
    session_type : session type to select. This replaces the previous
        implicit use of a global `st` (which earlier cells set to 'active').

    Returns
    -------
    _dfn : the selected trials, indexed by (response, trial number within
        the response group).
    stats : Welch t-test rows (`statistic`, `pval`) that pass `bhc`, or None
        when none do.
    '''
    if is_running > -1:
        levels = ['session_type', 'is_running', 'response']
        key = (session_type, is_running)
    else:
        levels = ['session_type', 'response']
        key = session_type
    # renumber trials 0..n-1 inside every group so groups can be indexed uniformly
    diffn = _differentiation.groupby(level=levels, dropna=False).apply(
        lambda df: df.reset_index(drop=True)
    )
    _dfn = diffn.loc[key]
    stats = _dfn.apply(
        lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)
    ).T
    stats.columns = ['statistic', 'pval']
    stats = bhc(stats, pval)
    return _dfn, stats

# Heatmap of mean(hit) - mean(miss) differentiation per (layer, area); the
# top row uses is_running=0 ("resting"), the bottom row is_running=1
# ("running"). '*' marks cells that pass the FDR test in get_dfn_stats.
f, axes = plt.subplots(2, 1, figsize=(3.5, 2.4), sharex=True, constrained_layout=True)
cbar_ax = f.add_axes([0.88, .2, .03, .75])
f.tight_layout(rect=[0.04, 0.04, 0.9, 1])

# order the selected areas along the visual hierarchy (loop-invariant)
areas = [a for a in hierarchy.keys() if a in areas]

for ax, running in zip(axes, [0, 1]):
    _dfn, stats = get_dfn_stats(_differentiation, is_running=running)
    # response groups are sorted ('hit' before 'miss'), so diff(-1) on the
    # 'hit' row is hit - miss; result: rows = layers, columns = areas
    _dfn = _dfn.groupby('response').mean().diff(-1).stack().loc['hit'].T
    _stats = pd.DataFrame(index=_dfn.index, columns=_dfn.columns)
    # get_dfn_stats returns stats=None when no test is significant
    if stats is not None:
        sig_pvals = stats.pval.unstack()
        _stats.loc[sig_pvals.index, sig_pvals.columns] = sig_pvals
    _stats[~_stats.isna()] = '*'
    sns.heatmap(
        _dfn.T.reindex(areas).T, cmap='RdBu_r', robust=True, center=0, ax=ax,
        vmin=-10000, vmax=40000, annot=_stats.T.reindex(areas).T.fillna(''), fmt='s',
        cbar_kws={'ticks':[-10000, 40000], 'format':'%.0e'},
        cbar_ax=cbar_ax, annot_kws={"size":7}
    )
    cbar_ax.yaxis.label.set_size(7)
    ax.set_ylabel(f'{"running" if running else "resting"}', fontsize=7, labelpad=0.5)
    ax.set_xlabel('')
    ax.set_yticks([0.5, 1.5, 2.5, 3.5])
    ax.set_yticklabels(['all', 'L2/3', 'L4', 'L5'], rotation=0, fontsize=7)
    ax.tick_params(left=False, bottom=False, axis='both', which='major', labelsize=7, pad=-1)

cbar_ax.tick_params(labelsize=6, pad=-2)
# raw string: '\D' is not a valid escape sequence in a normal string literal
cbar_ax.set_ylabel(
    r'$\Delta$ differentiation (hit - miss)', labelpad=-10
)
ax.set_xticks(range(len(areas)))
ax.set_xticklabels(areas, rotation=45, fontsize=7)
f.suptitle(f'differentiation ({win_ms}ms post stimulation)', fontsize=8);
f.savefig('fig_behavior_delta_df.pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [20]:
idx = mfr.index.to_frame(index=False)
# keep only hits and misses; every other response becomes missing
# (`mask` avoids the version-dependent meaning of `replace(list, None)`)
idx['response'] = idx.response.mask(
    idx.response.isin(['correctReject', 'noChange', 'falseAlarm'])
)
# a trial counts as running when its running speed exceeds 0.15
idx['is_running'] = idx.running_speed > 0.15
# infinite rates become missing; rows are relabelled (the DataFrame
# constructor would reindex and blank rows whose labels changed)
_mfr = mfr.replace([np.inf, -np.inf], np.nan)
_mfr.index = pd.MultiIndex.from_frame(idx)

def get_dfn_stats(_mfr, is_running=-1, pval=1e-4, session_type='active'):
    '''
    Select one session type and test hit vs miss for every column.

    Parameters
    ----------
    _mfr : DataFrame with a row MultiIndex containing the levels
        `session_type` and `response` (and `is_running` when
        `is_running > -1`). Columns are presumably (layer, area) pairs --
        TODO confirm against the callers' `.stack()` / `.unstack()` usage.
    is_running : -1 pools all trials; otherwise only trials whose
        `is_running` level matches this value are kept.
    pval : false discovery rate passed to `bhc`.
    session_type : session type to select. This replaces the previous
        implicit use of a global `st` (which earlier cells set to 'active').

    Returns
    -------
    _dfn : the selected trials, indexed by (response, trial number within
        the response group).
    stats : Welch t-test rows (`statistic`, `pval`) that pass `bhc`, or None
        when none do.
    '''
    if is_running > -1:
        levels = ['session_type', 'is_running', 'response']
        key = (session_type, is_running)
    else:
        levels = ['session_type', 'response']
        key = session_type
    # renumber trials 0..n-1 inside every group so groups can be indexed uniformly
    diffn = _mfr.groupby(level=levels, dropna=False).apply(
        lambda df: df.reset_index(drop=True)
    )
    _dfn = diffn.loc[key]
    stats = _dfn.apply(
        lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)
    ).T
    stats.columns = ['statistic', 'pval']
    stats = bhc(stats, pval)
    return _dfn, stats

# Heatmap of mean(hit) - mean(miss) firing rate per (layer, area); the top
# row uses is_running=0 ("resting"), the bottom row is_running=1
# ("running"). '*' marks cells that pass the FDR test in get_dfn_stats.
f, axes = plt.subplots(2, 1, figsize=(3.5, 2.4), sharex=True, constrained_layout=True)
cbar_ax = f.add_axes([0.88, .2, .03, .75])
f.tight_layout(rect=[0.04, 0.04, 0.9, 1])

# areas shown, ordered along the visual hierarchy ('AllVis' and 'hipp' can
# also be added); loop-invariant, so computed once
areas = [a for a in hierarchy.keys() if a in ['LP', 'THx', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']]

for ax, running in zip(axes, [0, 1]):
    _dfn, stats = get_dfn_stats(_mfr, is_running=running)
    # response groups are sorted ('hit' before 'miss'), so diff(-1) on the
    # 'hit' row is hit - miss; result: rows = layers, columns = areas
    _dfn = _dfn.groupby('response').mean().diff(-1).stack().loc['hit'].T
    _stats = pd.DataFrame(index=_dfn.index, columns=_dfn.columns)
    # get_dfn_stats returns stats=None when no test is significant
    if stats is not None:
        sig_pvals = stats.pval.unstack()
        _stats.loc[sig_pvals.index, sig_pvals.columns] = sig_pvals
    _stats[~_stats.isna()] = '*'
    sns.heatmap(
        _dfn.T.reindex(areas).T, cmap='RdBu_r', robust=True, center=0, ax=ax,
        vmin=-0.05, vmax=0.05, annot=_stats.T.reindex(areas).T.fillna(''), fmt='s',
        cbar_kws={'ticks':[-0.05, 0.05]},
        cbar_ax=cbar_ax, annot_kws={"size":7}
    )
    cbar_ax.yaxis.label.set_size(7)
    ax.set_ylabel(f'{"running" if running else "resting"}', fontsize=7, labelpad=0.5)
    ax.set_xlabel('')
    ax.set_yticks([0.5, 1.5, 2.5, 3.5])
    ax.set_yticklabels(['all', 'L2/3', 'L4', 'L5'], rotation=0, fontsize=7)
    ax.tick_params(left=False, bottom=False, axis='both', which='major', labelsize=7, pad=-1)

cbar_ax.tick_params(labelsize=6, pad=-2)
# raw string: '\D' is not a valid escape sequence in a normal string literal
cbar_ax.set_ylabel(
    r'$\Delta$ z-scored firing rate (hit - miss)', labelpad=-10
)
ax.set_xticks(range(len(areas)))
ax.set_xticklabels(areas, rotation=45, fontsize=7)
f.suptitle(f'firing rates ({win_ms}ms post stimulation)', fontsize=8);
f.savefig('fig_behavior_delta_fr.pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



---