
Differentiation in the visual behavior dataset

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

from os import path
from glob import glob
import pickle
import itertools

import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import cm
from IPython.display import display
from IPython.utils.capture import capture_output
from tqdm.auto import tqdm
with capture_output():
    tqdm.pandas()
import h5py

from differentiation import spectral_differentiation as specD

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False

In [2]:
data_directory = '/allen/programs/braintv/workgroups/tiny-blue-dot/differentiation/refactor/behavior/'

In [3]:
region_sets = {
    'VisCtx' : ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'HVAs' : ['VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'THx_VISp' : ['LGd', 'LP', 'TH', 'VISp'],
    'AllVis' : ['LGd', 'LP', 'TH', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'THx' : ['LGd', 'LP', 'TH'],
    'hipp' : ['CA', 'CA1', 'CA2', 'CA3', 'DG', 'DG-mo', 'DG-po', 'DG-sg'],
}

relevant_regions = [
    'LGd', 'LP', 'TH', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'
]

layer_order = ['L2/3', 'L4', 'L5']

hierarchy = {
    'Input' : -100,
    'stimulus' : -100,
    'Stim' : -100,
    'TH' : -10,
    'LG' : -9,
    'LGv' : -8,
    'LGd' : -7,
    'LP' : -6,
    'THx' : -5,
    'THx_VISp' : -4,
    'VISp' : 0,
    'VISpl' : 2,
    'VISl' : 4,
    'VISli' : 6,
    'VISrl' : 8,
    'VISal' : 10,
    'VISpm' : 12,
    'VISam' : 14,
    'VISpor' : 16,
    'VISa' : 18,
    'SC' : 24,
    'VISmma' : 20,
    'VISmmp' : 20,
    'VIS' : 22,
    'HVAs' : 21,
    'HVAs_ss' : 21.25,
    'VisCtx' : 21.5,
    'VisCtx_ss' : 21.75,
    'AllVis' : 22,
    'PF' : 25,
    'MB' : 30,
    'hipp' : 38,
    'CAx' : 39,
    'CA' : 40,
    'CA1' : 41,
    'CA2' : 42,
    'CA3' : 43,
    'DG' : 50,
}

# groups of units to apply differentiation to
unit_set = []
unit_set  = ['region == "%s"'%reg for reg in [
    'VISp', 'VISl', 'VISal', 'VISam', 'VISpm', 'VISrl', 'LGd', 'LP'
]]
unit_set += ['region in @region_sets.get("%s")'%s for s in region_sets.keys()]

unit_set += ['layer == "%s" & region == "%s"'%(l, r) for l, r in itertools.product(
    ['L2/3', 'L4', 'L5'], ['VISp', 'VISl', 'VISal', 'VISpm', 'VISam', 'VISrl']
)]
unit_set += ['layer == "%s" & region in @region_sets.get("%s")'%(l, s) for l, s in itertools.product(
    ['L2/3', 'L4', 'L5'], ['HVAs', 'VisCtx']
)]
print(f'Applying to {len(unit_set)} groups.')

min_units = 20

Applying to 38 groups.


# Download and organize data

In [4]:
raw_data_path = path.join(data_directory, 'vis_behavior_npx.hdf5')
with h5py.File(raw_data_path, 'r') as f:
    mouse = '09132019_461027'
    print(f.keys())
    print(f[mouse].keys())
    lick_times = f[mouse]['lickTimes'][:]
    flash_times = f[mouse]['behaviorFlashTimes'][:]

<KeysViewHDF5 ['03212019_409096', '03262019_417882', '03272019_417882', '04042019_408528', '04052019_408528', '04102019_408527', '04112019_408527', '04252019_421323', '04262019_421323', '04302019_422856', '05162019_423749', '05172019_423749', '06072019_427937', '06122019_423745', '07112019_429084', '07122019_429084', '08082019_423744', '08092019_423744', '08132019_423750', '08142019_423750', '09052019_459521', '09062019_459521', '09122019_461027', '09132019_461027']>
<KeysViewHDF5 ['behaviorChangeTimes', 'behaviorFlashTimes', 'behaviorOmitFlashTimes', 'behaviorRunDx', 'behaviorRunSpeed', 'behaviorRunTime', 'ccfRegion', 'changeImage', 'flashImage', 'inCortex', 'initialImage', 'isiRegion', 'lickTimes', 'omitFlashImage', 'passiveChangeTimes', 'passiveFlashTimes', 'passiveOmitFlashTimes', 'passiveRunDx', 'passiveRunSpeed', 'passiveRunTime', 'response', 'rewardTimes', 'sdfs', 'spikeTimes', 'units']>


In [5]:
sampling_rate = 200
win = np.exp(-(np.arange(11)-5)**2/4)

def get_mouse_ids():
    with h5py.File(raw_data_path, 'r') as f:
        mouse_ids = list(f.keys())
    return mouse_ids

mouse_ids = get_mouse_ids()

def get_mouse_data(mouse_id):
    with h5py.File(raw_data_path, 'r') as f:
        data = {}
        
        omissions = pd.DataFrame(np.array([
            f[mouse_id]['behaviorOmitFlashTimes'][:],
            [x.decode() for x in f[mouse_id]['omitFlashImage'][:]]
        ]).T, columns=['time', 'image']).set_index('time')
        omissions['type'] = 'omission'
        omissions['session_type'] = 'active'
        try:
            pomissions = pd.DataFrame(np.array([
                f[mouse_id]['passiveOmitFlashTimes'][:],
                [x.decode() for x in f[mouse_id]['omitFlashImage'][:]]
            ]).T, columns=['time', 'image']).set_index('time')
            pomissions['type'] = 'omission'
            pomissions['session_type'] = 'passive'
            omissions = omissions.append(pomissions)
        except:
            pass
        
        changes = pd.DataFrame(np.array([
            f[mouse_id]['behaviorChangeTimes'][:],
            f[mouse_id]['rewardTimes'][:],
            [x.decode() for x in f[mouse_id]['changeImage'][:]],
            [x.decode() for x in f[mouse_id]['initialImage'][:]],
            [x.decode() for x in f[mouse_id]['response'][:]]
        ]).T, columns=[
            'time', 'reward_time', 'image', 'old_image', 'response'
        ]).set_index('time')
        changes['type'] = 'change'
        changes['session_type'] = 'active'
        try:
            pchanges = pd.DataFrame(np.array([
                f[mouse_id]['passiveChangeTimes'][:],
                [x.decode() for x in f[mouse_id]['changeImage'][:]],
                [x.decode() for x in f[mouse_id]['initialImage'][:]],
                [x.decode() for x in f[mouse_id]['response'][:]]
            ]).T, columns=[
                'time', 'image', 'old_image', 'response'
            ]).set_index('time')
            pchanges['type'] = 'change'
            pchanges['session_type'] = 'passive'
            changes = changes.append(pchanges)
        except:
            pass
        
        flashes = pd.DataFrame(np.array([
            f[mouse_id]['behaviorFlashTimes'][:],
            [x.decode() for x in f[mouse_id]['flashImage'][:]]
        ]).T, columns=['time', 'image']).set_index('time')
        flashes['type'] = 'flash'
        flashes['session_type'] = 'active'
        try:
            pflashes = pd.DataFrame(np.array([
                f[mouse_id]['passiveFlashTimes'][:],
                [x.decode() for x in f[mouse_id]['flashImage'][:]]
            ]).T, columns=['time', 'image']).set_index('time')
            pflashes['type'] = 'flash'
            pflashes['session_type'] = 'passive'
            flashes = flashes.append(pflashes)
        except:
            pass
        
        behavior = pd.concat([omissions, changes, flashes])
        behavior.index = np.round(behavior.index.astype('float32'), 3)
        behavior = behavior.sort_index()
        data['behavior'] = behavior
        
        data['running'] = pd.DataFrame(np.array([
            f[mouse_id]['behaviorRunTime'][:],
            f[mouse_id]['behaviorRunSpeed'][:],
            f[mouse_id]['behaviorRunDx'][:],
        ]).T, columns=['time', 'speed', 'dx']).set_index('time')
        try:
            prun = pd.DataFrame(np.array([
                f[mouse_id]['passiveRunTime'][:],
                f[mouse_id]['passiveRunSpeed'][:],
                f[mouse_id]['passiveRunDx'][:],
            ]).T, columns=['time', 'speed', 'dx']).set_index('time')
            data['running'] = data['running'].append(prun)
        except:
            pass
        
        data['lick_times'] = f[mouse_id]['lickTimes'][:]
        
        ccfRegions = {}
        inCortex = {}
        spikeTimes = {}
        for probe in f[mouse_id]['ccfRegion'].keys():
            ccfRegions[probe] = np.array([
                x.decode() for x in f[mouse_id]['ccfRegion'][probe][:]
            ])
            inCortex[probe] = f[mouse_id]['inCortex'][probe][:]
            spikeTimes[probe] = {}
            for unit in f[mouse_id]['units'][probe][:]:
                spikeTimes[probe][unit.decode()] = f[mouse_id]['spikeTimes'][probe][unit][:]
        
        # binarize spiking data
        spikes = {}
        for probe in spikeTimes.keys():
            for i, unit in enumerate(spikeTimes[probe].keys()):
                spikes[(probe, unit, i)] = spikeTimes[probe][unit][:, 0]
        spikes = pd.Series(spikes).rename_axis(['probe', 'unit', 'unit_idx'])
        print(mouse_id, ': total number of units = ', len(spikes))

        # extract units metadata
        units = spikes.index.to_frame(index=False)
        areas = units.apply(lambda r: ccfRegions[r[0]][r[2]], axis=1)
        inCtx = units.apply(lambda r: inCortex[r[0]][r[2]], axis=1)
        layers = areas.str.extract('(\d.*)')[0].fillna('').rename('layer')
        areas = areas.str.rstrip('12/3456ab').fillna('')
        units['area'] = areas
        units['layer'] = layers
        units['inCtx'] = inCtx
        units = units.drop('unit_idx', axis=1)
        data['units'] = units
        spikes.index = pd.MultiIndex.from_frame(units)
        data['spikes'] = spikes
        
        # compute firing rates via convolution with a gaussian (see top of cell for window function)
        n_units = len(spikes)
        maxtime = spikes.apply(lambda x: max(x)).max().round(4)
        maxtimems = np.rint(maxtime*1000).astype(int)#+1
        if not path.exists(f'data/sessions_uint8fr/{mouse_id}.npy'):
            frdata = np.zeros((
                n_units,
                np.rint(maxtime*sampling_rate).astype(int)
            ), dtype='uint8')
            for unit in range(n_units):
                st_int = np.array(spikes.iloc[unit]*1000, dtype=int)
                fr = np.zeros(maxtimems, dtype='uint8')
                fr[st_int[st_int<maxtimems]] = 1
                sample_rate = int(1000/sampling_rate)
                frdata[unit] = (250*np.convolve(
                    fr, win, mode='same'
                )).astype('uint8')[int(sample_rate/2)::sample_rate][:frdata.shape[1]]
            np.save(f'data/sessions_uint8fr/{mouse_id}.npy', frdata)
        frdata = np.load(f'data/sessions_uint8fr/{mouse_id}.npy', mmap_mode='r')
        data['fr'] = pd.DataFrame(
            frdata, index=pd.MultiIndex.from_frame(units),
            columns=np.linspace(0, maxtimems, frdata.shape[1], endpoint=False)/1000
        )
    return data

In [6]:
# compute firing rates for all sessions in the dataset
for mouse_id in tqdm(mouse_ids):
    if path.exists(path.join(data_directory, f'spikes_{mouse_id}.pkl')):
        continue
    behavior, running, lick_times, units, spikes, fr = get_mouse_data(mouse_id).values()
    spikes.to_pickle(path.join(data_directory, f'spikes_{mouse_id}.pkl'))

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




In [7]:
def compute_differentiation(
    mouse_id, unit_set, norm, function, sampling_rate=200,
    window_length=0.6, state_length=0.1, n_units=None, nrep=1
):
    if n_units is None:
        nrep = 1
    units_bar = tqdm(range(len(unit_set)), desc=mouse_id)
    
    differentiation = []
    
    # load all data
    behavior, running, lick_times, up, spikes, fr = get_mouse_data(mouse_id).values()
    
    # add licking and behavior data to firing rate frame
    behavior = behavior.groupby('time', group_keys=False).apply(
        lambda df: df[df.type=='change'] if len(df)>1 else df
    )
    behavior['stim_id'] = range(len(behavior))

    licks = pd.Series(False, index=fr.columns, name='lick').rename_axis('time')
    idx = pd.Index(lick_times).reindex(licks.index, method='ffill', limit=1)
    licks.loc[idx[0][idx[1]>-1]] = True
    fr.columns = pd.MultiIndex.from_frame(
        licks.reset_index().set_index('time', drop=False).join(
            running.reindex(
                licks.index, method='nearest'
            ).speed.rolling(1000, center=True).mean().bfill().ffill()
        ).join(
            behavior.reindex(
                licks.index, method='ffill',
                limit=int(window_length*sampling_rate)
            )
        )
    )

    t = fr.columns.to_frame(index=False)
    fr = fr[t[~t.type.isna()].time]
    
    idx = fr.columns.to_frame(index=False)
    idx = idx.groupby(
        'stim_id', group_keys=False
    ).apply(lambda df: df.iloc[:int(window_length*sampling_rate)])
    _fr = fr[pd.MultiIndex.from_frame(idx)]
    
    up = up.rename({'area':'region'}, axis=1)
    up['layer'] = up.layer.apply(lambda x: f'L{x}' if len(x)>0 else '')
    
    # compute differentiation
    for i, units_name in enumerate(unit_set):
        units = up[up.eval(unit_set[i])]
        if len(units) == 0:
            units_bar.update()
            continue
        for k in range(nrep):
            units = up[up.eval(unit_set[i])]
            # subsample units if required
            if n_units is not None:
                try:
                    nu = n_units[get_unit_filters(units_name).get('layer', '-')]
                except:
    #                 print(f'skipping {get_unit_filters(units_name).get("layer", "-")}')
                    continue
                if nu > len(units):
                    continue
#                 print(get_unit_filters(units_name).get('layer', '-'), ':', nu, '/', len(units))
                try:
                    units = units.sample(int(nu))
#                     print(sorted(units.index), end=':')
                except:
                    continue

            # extract firing rate for selected units
            unit_fr = _fr.iloc[units.index].T

            # normalize firing rates depending on 'norm'
            if norm=='cohort_full_ts':
                unit_fr = unit_fr / unit_fr.values.mean()
            else:
                raise ValueError(f'normalization {norm} not implemented.')

            # reshape for single shot spectral differentiation calculation
            unit_fr_local = unit_fr.iloc[
                :round((
                    unit_fr.shape[0]//(window_length*sampling_rate)
                )*window_length*sampling_rate)
            ].copy()
            unit_fr_local = np.reshape(
                unit_fr_local.T.values,
                (unit_fr_local.shape[1], -1, int(window_length*sampling_rate))
            ).transpose(1, 0, 2)

            # compute spectral differentiation
            df = function(
                unit_fr_local, sample_rate=sampling_rate, window_length=state_length
            )

            # get median differentiation
            df = np.median(df, axis=1)
#             print(df[:5])

            # put it into a nice series indexed by time
    #         times = unit_fr.index.get_level_values('time')
    #         times = np.linspace(times[0], times[-1], df.shape[0], False)
    #         times = times + np.diff(times).mean()/2
            times = unit_fr.groupby('stim_id').apply(
                lambda df: df.index.get_level_values('time')[0]
            )
            df = pd.Series(
                df, index=times,
                name=(window_length, state_length,
                      f'{units_name}{k if k>0 else ""} & n_units = {len(units)}')
            )
            idx = fr.columns.to_frame(index=False).set_index('time')
            idx['lick'] = idx.lick.rolling(20).sum()
            idx = idx.reindex(df.index, method='nearest')
            df.index = pd.MultiIndex.from_frame(
                idx.rename_axis('time').reset_index()
            )
            differentiation.append(df)
        units_bar.update()
    if len(differentiation)==0:
        return pd.DataFrame()
    differentiation = pd.concat(differentiation, axis=1)
    differentiation = differentiation.sort_index(axis=1)
    differentiation = differentiation.rename_axis(
        columns=['window_length', 'state_length', 'region']
    ).droplevel([0, 1], axis=1)
    return differentiation

# utility functions for properly renaming the columns of the differentiation dataframe
def get_unit_filters(units):
    filter_strings = units.split(' & ')
    filters = {}
    for filt in filter_strings:
        if ' = ' in filt:
            filters[filt.split(' = ')[0]] = filt.split(' = ')[1].strip('""')
        if '==' in filt:
            filters[filt.split(' == ')[0]] = filt.split(' == ')[1].strip('""')
        if '>' in filt:
            filters[filt.split(' > ')[0]] = float(filt.split(' > ')[1])
        if '@' in filt:
            key = filt.split(' ')[0]
            value = filt.split('get')[1].split('"')[1]
            filters[key] = value
    return filters

def rename_columns(c):
    props = get_unit_filters(c.name)
    return dict(
        layer=props.get('layer', '-'),
        area=props.get('region', '-'),
        n_units=int(props['n_units'])
    )

In [8]:
# compute differentiation for all mice
win_ms = 300 # 150/300/600
sta_ms = 60 # 25/60/100
differentiation, n_units = {}, {}
for mouse_id in tqdm(mouse_ids[::-1]):
    fname = path.join(
        data_directory,
        f'spectral_differentiation_{mouse_id}_{win_ms}_{sta_ms}.pkl'
    )
    if not path.exists(fname):
        diffn = compute_differentiation(
            mouse_id, unit_set, 'cohort_full_ts', specD,
            window_length=win_ms/1000, state_length=sta_ms/1000
        )
        try:
            diffn.columns = pd.MultiIndex.from_frame(
                pd.DataFrame(
                    list(
                        diffn.columns.to_frame()
                        .apply(rename_columns, axis=1)
                    )
                )
            )
        except:
            diffn.columns = pd.MultiIndex.from_arrays(
                [['']*len(diffn), ['']*len(diffn), ['']*len(diffn)],
                names=['area', 'layer', 'n_units']
            )
        diffn.to_pickle(fname)
    else:
        diffn = pd.read_pickle(fname)
    
    n_units[mouse_id] = diffn.columns.to_frame(index=False)
    if diffn is not None:
        # normalize differentiation by number of units
        differentiation[mouse_id] = diffn.droplevel(2, axis=1) / n_units[mouse_id].set_index(
            ['layer', 'area']
        ).n_units**0.5
        # drop areas with fewer than min_units units
        differentiation[mouse_id] = differentiation[mouse_id].T[n_units[mouse_id].set_index(
            ['layer', 'area']
        ).n_units>min_units].T
differentiation = pd.concat(
    differentiation,
    names=[
        'expt', 'time', 'lick', 'running_speed', 'image', 'type',
        'session_type', 'reward_time', 'old_image', 'response', 'stim_id'
    ]
)

idx = differentiation.index.to_frame(index=False)
idx['mouse'] = idx.expt.apply(lambda x: x.split('_')[1])
idx['date'] = idx.expt.apply(lambda x: x.split('_')[0])
_idx = idx[['mouse', 'date']].drop_duplicates()
_idx = _idx.join(_idx.groupby('mouse', group_keys=False).apply(
    lambda df: df.date.apply(lambda c: list(df.date.unique()).index(c))
).rename('sid')).set_index(['mouse', 'date'])
idx['sid'] = _idx.loc[pd.MultiIndex.from_frame(idx[['mouse', 'date']])].values
idx = idx[[
    'mouse', 'date', 'sid', 'time', 'lick', 'running_speed', 'image',
    'type', 'session_type', 'reward_time', 'old_image', 'response'
]]
differentiation.index = pd.MultiIndex.from_frame(idx)
differentiation = differentiation.sort_index()
idx = differentiation.index.to_frame(index=False)
differentiation

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,layer,-,-,-,-,-,-,-,-,-,-,...,L4,L4,L5,L5,L5,L5,L5,L5,L5,L5
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,area,AllVis,HVAs,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,VISp,...,VISrl,VisCtx,HVAs,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx
mouse,date,sid,time,lick,running_speed,image,type,session_type,reward_time,old_image,response,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2
408527,04102019,1,77.110017,,16.503895,im065,flash,active,,,,11958.140213,11124.471042,,,,15831.283387,10141.424439,,,15831.283387,...,,11238.115544,8614.646261,,,,,21771.238608,,8163.782888
408527,04102019,1,77.860017,0.0,17.680735,im065,flash,active,,,,7666.359327,8012.647396,,,,3668.458954,6447.752733,,,3668.458954,...,,5903.141107,7605.328594,,,,,16538.132159,,6903.410907
408527,04102019,1,78.610017,0.0,17.784651,im065,flash,active,,,,7056.836043,7277.055470,,,,4457.160067,5525.575640,,,4457.160067,...,,5839.011537,6218.592590,,,,,18312.700708,,5646.563639
408527,04102019,1,79.360017,0.0,17.952055,im065,flash,active,,,,5818.645570,6114.179636,,,,3189.455738,5597.129285,,,3189.455738,...,,5432.962644,5759.468855,,,,,13752.593899,,5250.887400
408527,04102019,1,80.115017,0.0,15.611155,im065,flash,active,,,,7446.265775,7883.748461,,,,2451.922211,6891.963253,,,2451.922211,...,,6895.668216,6422.481711,,,,,20640.634511,,5812.278967
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
461027,09132019,0,9158.000000,0.0,10.850420,im085,flash,passive,,,,2259.790927,,,,1483.305250,2119.066436,,,,3548.983127,...,,5370.331735,,,,,,,,2099.467129
461027,09132019,0,9158.750000,0.0,14.521681,im085,flash,passive,,,,2932.830702,,,,1453.029131,2622.603292,,,,3524.785057,...,,4841.838387,,,,,,,,1832.228623
461027,09132019,0,9159.500000,0.0,20.543412,im085,flash,passive,,,,2623.400090,,,,1453.132408,2377.524174,,,,4662.932396,...,,4378.172432,,,,,,,,2567.106036
461027,09132019,0,9160.250000,0.0,29.017218,im085,flash,passive,,,,2513.131066,,,,1435.988370,2321.078575,,,,4305.267075,...,,8122.430405,,,,,,,,2040.232205


In [9]:
n_units = pd.concat(
    n_units, names=['mouse', 'idx']
).set_index(
    ['layer', 'area'], append=True
).droplevel('idx')

idx = n_units.index.to_frame(index=False)
idx['date'] = idx.mouse.apply(lambda x: x.split('_')[0])
idx['mouse'] = idx.mouse.apply(lambda x: x.split('_')[1])
n_units.index = pd.MultiIndex.from_frame(idx)

n_units = n_units[n_units.n_units>min_units]

n_units

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,n_units
mouse,layer,area,date,Unnamed: 4_level_1
461027,L4,VisCtx,09132019,22
461027,L5,VisCtx,09132019,23
461027,-,VISp,09132019,46
461027,-,AllVis,09132019,116
461027,-,THx,09132019,62
...,...,...,...,...
409096,-,AllVis,03212019,277
409096,-,HVAs,03212019,156
409096,-,THx_VISp,03212019,121
409096,-,VisCtx,03212019,277


In [10]:
# no correlation between diffn and n_units after normalizing for n_units
_y = differentiation.groupby(['mouse', 'date']).mean()
_x = n_units.unstack(['layer', 'area'])

f, ax = plt.subplots(figsize=(4, 3), tight_layout=True)
ax.scatter(_x.stack().stack()**0.5, _y.stack().stack(), s=2)
ax.set_xlabel('# units')
ax.set_ylabel('differentiation')
sp.stats.linregress(_y.stack().stack(), _x.stack().stack().n_units**0.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

LinregressResult(slope=-7.248491870573729e-05, intercept=9.745794205958918, rvalue=-0.15349740743931195, pvalue=0.0013736812729955887, stderr=2.250271093469086e-05)

In [11]:
len(n_units.index.levels[0])

14

In [12]:
print(
    f"{n_units.xs('-', level='layer').xs('AllVis', level='area').mean().values[0]}"
    f" +/- {n_units.xs('-', level='layer').xs('AllVis', level='area').std().values[0]:.2f}"
    " neurons per mouse."
)
n_units.xs('-', level='layer').xs('AllVis', level='area').sum()

318.25 +/- 146.97 neurons per mouse.


n_units    7638
dtype: int64

In [13]:
# compute differentiation in VisCtx / HVAs after subsampling fewer units to match the numbers in individual areas

areas = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam']

n_median_cortical_units = n_units.groupby(['mouse', 'date', 'layer']).apply(
    lambda df: np.median(df[df.index.get_level_values('area').isin(areas)])
)
idx = n_median_cortical_units.index.to_frame(index=False)
idx['mouse'] = idx.apply(lambda r: f'{r.date}_{r.mouse}', axis=1)
n_median_cortical_units.index = pd.MultiIndex.from_frame(idx[['mouse', 'layer']])
n_median_cortical_units.sort_values().dropna()

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


mouse            layer
05172019_423749  L4        21.0
03262019_417882  L2/3      21.0
07122019_429084  L4        22.0
04102019_408527  L5        23.0
04262019_421323  L5        23.0
                          ...  
04052019_408528  -        105.0
08132019_423750  -        112.0
04042019_408528  -        115.0
05162019_423749  -        130.5
07112019_429084  -        210.0
Length: 75, dtype: float64

In [14]:
# groups of units to apply differentiation to
unit_set_ss = []
unit_set_ss += ['region in @region_sets.get("%s")'%s for s in ['HVAs', 'VisCtx']]
unit_set_ss += ['layer == "%s" & region in @region_sets.get("%s")'%(l, s) for l, s in itertools.product(
    ['L2/3', 'L4', 'L5'], ['HVAs', 'VisCtx']
)]
unit_set_ss

['region in @region_sets.get("HVAs")',
 'region in @region_sets.get("VisCtx")',
 'layer == "L2/3" & region in @region_sets.get("HVAs")',
 'layer == "L2/3" & region in @region_sets.get("VisCtx")',
 'layer == "L4" & region in @region_sets.get("HVAs")',
 'layer == "L4" & region in @region_sets.get("VisCtx")',
 'layer == "L5" & region in @region_sets.get("HVAs")',
 'layer == "L5" & region in @region_sets.get("VisCtx")']

In [15]:
# compute differentiation for all mice after subsampling units
# win_ms = 300 # 150/300/600
# sta_ms = 60 # 25/60/100
differentiation_ss, n_units_ss = {}, {}
for mouse_id in tqdm(mouse_ids[::-1]):
    fname = path.join(
        data_directory,
        f'spectral_differentiation_ss_{mouse_id}_{win_ms}_{sta_ms}.pkl'
    )
    if not path.exists(fname):
        diffn = compute_differentiation(
            mouse_id, unit_set_ss, 'cohort_full_ts', specD,
            window_length=win_ms/1000, state_length=sta_ms/1000,
            n_units=n_median_cortical_units.loc[mouse_id], nrep=10
        )
#         display(diffn)
        try:
            diffn.columns = pd.MultiIndex.from_frame(
                pd.DataFrame(
                    list(
                        diffn.columns.to_frame()
                        .apply(rename_columns, axis=1)
                    )
                )
            )
        except:
            diffn.columns = pd.MultiIndex.from_arrays(
                [['']*len(diffn), ['']*len(diffn), ['']*len(diffn)],
                names=['area', 'layer', 'n_units']
            )
        diffn.to_pickle(fname)
    else:
        diffn = pd.read_pickle(fname)
    
    n_units_ss[mouse_id] = diffn.columns.to_frame(index=False)
    if diffn is not None:
        # normalize differentiation by number of units
        differentiation_ss[mouse_id] = diffn.droplevel(2, axis=1) / n_units_ss[mouse_id].set_index(
            ['layer', 'area']
        ).n_units**0.5
        # drop areas with fewer than 20 units
        differentiation_ss[mouse_id] = differentiation_ss[mouse_id].T[n_units_ss[mouse_id].set_index(
            ['layer', 'area']
        ).n_units>min_units].T
        # keep the mean differentiation across trials
        differentiation_ss[mouse_id] = differentiation_ss[mouse_id].groupby(
            ['layer', 'area'], axis=1
        ).mean()#apply(lambda df: df.iloc[:, 0])#
differentiation_ss = pd.concat(
    differentiation_ss,
    names=[
        'expt', 'time', 'lick', 'running_speed', 'image', 'type',
        'session_type', 'reward_time', 'old_image', 'response', 'stim_id'
    ]
)

idx = differentiation_ss.index.to_frame(index=False)
idx['mouse'] = idx.expt.apply(lambda x: x.split('_')[1])
idx['date'] = idx.expt.apply(lambda x: x.split('_')[0])
_idx = idx[['mouse', 'date']].drop_duplicates()
_idx = _idx.join(_idx.groupby('mouse', group_keys=False).apply(
    lambda df: df.date.apply(lambda c: list(df.date.unique()).index(c))
).rename('sid')).set_index(['mouse', 'date'])
idx['sid'] = _idx.loc[pd.MultiIndex.from_frame(idx[['mouse', 'date']])].values
idx = idx[[
    'mouse', 'date', 'sid', 'time', 'lick', 'running_speed', 'image',
    'type', 'session_type', 'reward_time', 'old_image', 'response'
]]
differentiation_ss.index = pd.MultiIndex.from_frame(idx)
differentiation_ss = differentiation_ss.sort_index()
idx = differentiation_ss.index.to_frame(index=False)

differentiation_ss.columns = differentiation_ss.columns.map(lambda x: (x[0], f'{x[1]}_ss'))
differentiation_ss

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,layer,-,-,L2/3,L2/3,L4,L4,L5,L5
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,area,HVAs_ss,VisCtx_ss,HVAs_ss,VisCtx_ss,HVAs_ss,VisCtx_ss,HVAs_ss,VisCtx_ss
mouse,date,sid,time,lick,running_speed,image,type,session_type,reward_time,old_image,response,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2
408527,04102019,1,77.110017,,16.503895,im065,flash,active,,,,10446.709975,10538.491826,,,,,,
408527,04102019,1,77.860017,0.0,17.680735,im065,flash,active,,,,7901.851354,7468.201940,,,,,,
408527,04102019,1,78.610017,0.0,17.784651,im065,flash,active,,,,6119.167810,5449.059575,,,,,,
408527,04102019,1,79.360017,0.0,17.952055,im065,flash,active,,,,6058.558553,5557.586276,,,,,,
408527,04102019,1,80.115017,0.0,15.611155,im065,flash,active,,,,7024.421435,6014.130046,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
461027,09132019,0,9158.000000,0.0,10.850420,im085,flash,passive,,,,,3904.497210,,,,,,
461027,09132019,0,9158.750000,0.0,14.521681,im085,flash,passive,,,,,4464.198703,,,,,,
461027,09132019,0,9159.500000,0.0,20.543412,im085,flash,passive,,,,,4543.065849,,,,,,
461027,09132019,0,9160.250000,0.0,29.017218,im085,flash,passive,,,,,4931.542755,,,,,,


In [16]:
differentiation = differentiation.join(differentiation_ss).sort_index(axis=1)
differentiation

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,layer,-,-,-,-,-,-,-,-,-,-,...,L5,L5,L5,L5,L5,L5,L5,L5,L5,L5
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,area,AllVis,HVAs,HVAs_ss,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,...,HVAs,HVAs_ss,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx,VisCtx_ss
mouse,date,sid,time,lick,running_speed,image,type,session_type,reward_time,old_image,response,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2
408527,04102019,1,77.110017,,16.503895,im065,flash,active,,,,11958.140213,11124.471042,10446.709975,,,,15831.283387,10141.424439,,,...,8614.646261,,,,,,21771.238608,,8163.782888,
408527,04102019,1,77.860017,0.0,17.680735,im065,flash,active,,,,7666.359327,8012.647396,7901.851354,,,,3668.458954,6447.752733,,,...,7605.328594,,,,,,16538.132159,,6903.410907,
408527,04102019,1,78.610017,0.0,17.784651,im065,flash,active,,,,7056.836043,7277.055470,6119.167810,,,,4457.160067,5525.575640,,,...,6218.592590,,,,,,18312.700708,,5646.563639,
408527,04102019,1,79.360017,0.0,17.952055,im065,flash,active,,,,5818.645570,6114.179636,6058.558553,,,,3189.455738,5597.129285,,,...,5759.468855,,,,,,13752.593899,,5250.887400,
408527,04102019,1,80.115017,0.0,15.611155,im065,flash,active,,,,7446.265775,7883.748461,7024.421435,,,,2451.922211,6891.963253,,,...,6422.481711,,,,,,20640.634511,,5812.278967,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
461027,09132019,0,9158.000000,0.0,10.850420,im085,flash,passive,,,,2259.790927,,,,,1483.305250,2119.066436,,,,...,,,,,,,,,2099.467129,
461027,09132019,0,9158.750000,0.0,14.521681,im085,flash,passive,,,,2932.830702,,,,,1453.029131,2622.603292,,,,...,,,,,,,,,1832.228623,
461027,09132019,0,9159.500000,0.0,20.543412,im085,flash,passive,,,,2623.400090,,,,,1453.132408,2377.524174,,,,...,,,,,,,,,2567.106036,
461027,09132019,0,9160.250000,0.0,29.017218,im085,flash,passive,,,,2513.131066,,,,,1435.988370,2321.078575,,,,...,,,,,,,,,2040.232205,


In [17]:
def compute_mfr(mouse_id, unit_set, kind='mean_wrt_flash', win_size_ms=300):
    units_bar = tqdm(range(len(unit_set)), desc=mouse_id)
    mfr = []
    
    behavior, running, lick_times, up, spikes, fr = get_mouse_data(mouse_id).values()
    # remove flash / change duplicates
    behavior = behavior.groupby('time', group_keys=False).apply(
        lambda df: df[df.type=='change'] if len(df)>1 else df
    )
    behavior['stim_id'] = range(len(behavior))

    # add licking and behavior data to firing rate frame
    licks = pd.Series(False, index=fr.columns, name='lick').rename_axis('time')
    idx = pd.Index(lick_times).reindex(licks.index, method='ffill', limit=1)
    licks.loc[idx[0][idx[1]>-1]] = True
    fr.columns = pd.MultiIndex.from_frame(
        licks.reset_index().set_index('time', drop=False).join(
            running.reindex(
                licks.index, method='nearest'
            ).speed.rolling(1000, center=True).mean().bfill().ffill()
        ).join(
            behavior.reindex(
                licks.index, method='ffill',
                limit=int(0.6*sampling_rate)
            )
        )
    )

    t = fr.columns.to_frame(index=False)
    fr = fr[t[~t.type.isna()].time]

    idx = fr.columns.to_frame(index=False)
    idx = idx.groupby('stim_id', group_keys=False).apply(
        lambda df: df.iloc[:int(0.6*sampling_rate)]
    )
    _fr = fr[pd.MultiIndex.from_frame(idx)]
    
    if kind=='mean_wrt_flash':
        _mffr = _fr.T[_fr.T.index.get_level_values('type')=='flash']
        _fr = ((_fr.T - _mffr.mean())/_mffr.std()).T
    
    up = up.rename({'area':'region'}, axis=1)
    up['layer'] = up.layer.apply(lambda x: f'L{x}' if len(x)>0 else '')
    
    for i, units_name in enumerate(unit_set):
        units = up[up.eval(unit_set[i])]
        if len(units) == 0:
            units_bar.update()
            continue

        # extract firing rate for selected units
        unit_fr = _fr.iloc[units.index].T
        df = unit_fr.mean(1)
        
        _mfr = {}
        df.groupby('stim_id').apply(lambda g: _mfr.update({
            tuple(g.index.to_frame().iloc[0])[:-1]:g[:win_size_ms//5].mean()
        }))
        _mfr = pd.Series(_mfr, name=f'{units_name} & n_units = {len(units)}')
    
        mfr.append(_mfr)
        units_bar.update()
    if len(mfr)==0:
        return pd.DataFrame()
    mfr = pd.concat(mfr, axis=1)
    mfr = mfr.sort_index(axis=1)
    return mfr
    mfr = mfr.rename_axis(
        columns=['region']
    ).droplevel([0, 1], axis=1)
    return mfr

In [18]:
# compute mfr for all mice for same window parameters as differentiation
# win_ms = 300 # 150/300/600
# sta_ms = 60 # 25/60/100
kind = 'mean_wrt_flash'
mfr = {}
for mouse_id in tqdm(mouse_ids[::-1]):
    fname = path.join(
        data_directory,
        f'behavior_mfr_{mouse_id}_{win_ms}_{sta_ms}.pkl'
    )
    if not path.exists(fname):
        diffn = compute_mfr(
            mouse_id, unit_set, win_size_ms=win_ms, kind='mean_wrt_flash'
        )
        try:
            diffn.columns = pd.MultiIndex.from_frame(
                pd.DataFrame(
                    list(
                        diffn.columns.to_frame()
                        .apply(rename_columns, axis=1)
                    )
                )
            )
        except:
            diffn.columns = pd.MultiIndex.from_arrays(
                [['']*len(diffn), ['']*len(diffn), ['']*len(diffn)],
                names=['area', 'layer', 'n_units']
            )
        diffn.to_pickle(fname)
    else:
        diffn = pd.read_pickle(fname)
    
    nu = diffn.columns.to_frame(index=False)
    if diffn is not None:
        mfr[mouse_id] = diffn.droplevel(2, axis=1)
        # drop areas with fewer than min_units units
        mfr[mouse_id] = mfr[mouse_id].T[nu.set_index(['layer', 'area']).n_units>min_units].T
        
mfr = pd.concat(mfr, names=[
    'expt', 'time', 'lick', 'running_speed', 'image', 'type',
    'session_type', 'reward_time', 'old_image', 'response'
])

idx = mfr.index.to_frame(index=False)
idx['mouse'] = idx.expt.apply(lambda x: x.split('_')[1])
idx['date'] = idx.expt.apply(lambda x: x.split('_')[0])
_idx = idx[['mouse', 'date']].drop_duplicates()
_idx = _idx.join(_idx.groupby('mouse', group_keys=False).apply(
    lambda df: df.date.apply(
        lambda c: list(df.date.unique()).index(c)
    )
).rename('sid')).set_index(['mouse', 'date'])
idx['sid'] = _idx.loc[
    pd.MultiIndex.from_frame(idx[['mouse', 'date']])
].values
idx = idx[[
    'mouse', 'date', 'sid', 'time', 'lick', 'running_speed', 'image',
    'type', 'session_type', 'reward_time', 'old_image', 'response'
]]
mfr.index = pd.MultiIndex.from_frame(idx)
mfr = mfr.sort_index()
mfr = mfr.replace([np.inf, -np.inf], np.nan)
idx = differentiation.index.to_frame(index=False)
mfr

HBox(children=(IntProgress(value=0, max=24), HTML(value='')))




Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,Unnamed: 6_level_0,Unnamed: 7_level_0,Unnamed: 8_level_0,Unnamed: 9_level_0,Unnamed: 10_level_0,layer,-,-,-,-,-,-,-,-,-,-,...,L4,L4,L5,L5,L5,L5,L5,L5,L5,L5
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,area,AllVis,HVAs,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,VISp,...,VISrl,VisCtx,HVAs,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx
mouse,date,sid,time,lick,running_speed,image,type,session_type,reward_time,old_image,response,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2,Unnamed: 17_level_2,Unnamed: 18_level_2,Unnamed: 19_level_2,Unnamed: 20_level_2,Unnamed: 21_level_2,Unnamed: 22_level_2,Unnamed: 23_level_2,Unnamed: 24_level_2,Unnamed: 25_level_2,Unnamed: 26_level_2,Unnamed: 27_level_2,Unnamed: 28_level_2,Unnamed: 29_level_2,Unnamed: 30_level_2,Unnamed: 31_level_2,Unnamed: 32_level_2
408527,04102019,1,77.110017,False,16.503895,im065,flash,active,,,,0.192562,0.181356,,,,0.252669,0.343690,,,0.252669,...,,0.204801,0.185314,,,,,0.198823,,0.203566
408527,04102019,1,77.860017,False,17.680735,im065,flash,active,,,,0.055447,0.060794,,,,0.026768,0.180335,,,0.026768,...,,0.057216,0.072113,,,,,0.097658,,0.079115
408527,04102019,1,78.610017,False,17.784651,im065,flash,active,,,,0.069885,0.069686,,,,0.070949,0.111654,,,0.070949,...,,0.055960,0.122702,,,,,0.119834,,0.122510
408527,04102019,1,79.360017,False,17.952055,im065,flash,active,,,,0.020651,0.024535,,,,-0.000178,0.087078,,,-0.000178,...,,0.011397,0.042805,,,,,0.055147,,0.043504
408527,04102019,1,80.115017,False,15.611155,im065,flash,active,,,,0.050543,0.057281,,,,0.014407,0.144031,,,0.014407,...,,0.059712,0.043734,,,,,0.077505,,0.044403
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
461027,09132019,0,9158.000000,False,10.850420,im085,flash,passive,,,,0.045082,,,,0.042819,0.038526,,,,0.032611,...,,0.018326,,,,,,,,0.046305
461027,09132019,0,9158.750000,False,14.521681,im085,flash,passive,,,,0.067094,,,,0.035864,0.060435,,,,0.094288,...,,0.134854,,,,,,,,0.019145
461027,09132019,0,9159.500000,False,20.543412,im085,flash,passive,,,,0.028772,,,,0.004329,0.024379,,,,0.052003,...,,0.056017,,,,,,,,0.017057
461027,09132019,0,9160.250000,False,29.017218,im085,flash,passive,,,,0.049667,,,,0.012673,0.039631,,,,0.076774,...,,0.151484,,,,,,,,0.039630


In [19]:
def bhc(stats, alpha=0.01):
    '''
    Benjamini-Hochberg multiple comparison correction
    stats must have a pval column
    '''
    stats = stats.sort_values('pval')
    stats['rank'] = range(len(stats))
    stats['thresh'] = (stats['rank']+1)/len(stats)*alpha
    idx = stats.index[stats.pval<stats.thresh]
    if len(idx)==0:
        print('No significant values.')
        return None
    else:
        _s = stats[stats.pval<stats.thresh].copy()
        _s['rank2'] = range(len(_s))
        return _s[_s['rank']<=_s['rank2']].drop('rank2', axis=1)

In [20]:
# each point is a stimulation instance
idx = differentiation.index.to_frame(index=False)
idx['response'] = idx.response.replace(['correctReject', 'noChange', 'falseAlarm'], None)#.replace(['correctReject'], None)
idx['is_running'] = True
idx.loc[idx.index[idx.running_speed>1], 'is_running'] = False
_differentiation = pd.DataFrame(differentiation, index=pd.MultiIndex.from_frame(idx), copy=True)
is_running = -1

if is_running>-1:
    diffn = _differentiation.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
else:
    diffn = _differentiation.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
display(diffn['-'].head())

f, _axes = plt.subplots(4, 2, figsize=(12, 7), constrained_layout=True, sharex=True, sharey=True)

for axes, session_type in zip(_axes.T, ['active', 'shuffle']):# active, passive and shuffle options
    for ax, layer in zip(axes, diffn.columns.levels[0]):
        areas = [a for a in hierarchy.keys() if a in diffn[layer].columns]
        st = session_type
        if st=='shuffle':
            st = 'active'
        if is_running > -1:
            _dfn = diffn.loc[(st, is_running), layer]/diffn.loc[st, layer].mean()
        else:
            _dfn = diffn.loc[(st), layer]/diffn.loc[st, layer].mean()
        if session_type=='shuffle':
            _dfn.index = pd.MultiIndex.from_frame(_dfn.index.to_frame().sample(frac=1))
        (
            _dfn.stack()
            .droplevel(1).to_frame().reset_index()
            .rename(columns={'level_1': 'idx', 0: 'differentiation'})
            .pipe(
                (sns.boxplot, 'data'),
                x='area', y='differentiation', hue='response', ax=ax,
                fliersize=2, showfliers=False, order=areas
            )
        )
        ax.legend(fontsize=7)
        ax.set_ylabel(f'{layer}\ndifferentiation')
    axes[0].set_title(f'{session_type} ({"running" if is_running==1 else "resting" if is_running==0 else ""})' , fontsize=9);

Unnamed: 0_level_0,Unnamed: 1_level_0,area,AllVis,HVAs,HVAs_ss,LGd,LP,THx,THx_VISp,VISal,VISam,VISl,VISp,VISpm,VISrl,VisCtx,VisCtx_ss,hipp
session_type,response,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1
active,hit,0,9465.235754,8871.586551,7889.744193,,,,6482.144965,3416.331878,,,6482.144965,14842.750977,4681.478445,9465.235754,9829.243377,3687.836615
active,hit,1,8681.645885,8505.386443,6242.956034,,,,8134.070216,7957.684861,,,8134.070216,11241.847651,3533.351814,8681.645885,6843.551226,3397.808238
active,hit,2,7682.529746,7485.630978,8231.100489,,,,5702.026156,6259.862077,,,5702.026156,14350.133308,5169.386919,7682.529746,7824.195904,3680.94789
active,hit,3,9626.889008,9365.771794,9315.201548,,,,9281.875891,8257.173957,,,9281.875891,26736.726842,3236.190602,9626.889008,8017.841483,3870.548263
active,hit,4,8643.776516,9167.4964,7307.599256,,,,2225.547938,6058.061301,,,2225.547938,8564.9193,4012.919223,8643.776516,5800.112252,3489.183731


Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [21]:
# for each col, get a pval for the difference between hits and misses
# then pass the pval dataframe to bhc with alpha to get corrected pvals

In [30]:
# each point is a stimulation instance
areas = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs_ss', 'VisCtx_ss']#, 'HVAs', 'VisCtx'
idx = differentiation.index.to_frame(index=False)
idx['response'] = idx.response.replace(['correctReject', 'noChange', 'falseAlarm'], None)#.replace(['correctReject'], None)
idx['is_running'] = True
idx.loc[idx.index[idx.running_speed>0.15], 'is_running'] = False
_differentiation = pd.DataFrame(differentiation, index=pd.MultiIndex.from_frame(idx), copy=True)
is_running = -1

if is_running>-1:
    diffn = _differentiation.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
else:
    diffn = _differentiation.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))

with sns.axes_style('white'):
    f, axes = plt.subplots(4, 1, figsize=(4.5, 4.5), constrained_layout=True, sharex=True)

session_type = 'active'
for ax, layer in zip(axes, diffn.columns.levels[0]):
#     areas = [a for a in hierarchy.keys() if a in diffn[layer].columns]
    st = session_type
    if st=='shuffle':
        st = 'active'
    if is_running > -1:
        _dfn = diffn.loc[(st, is_running), layer]#/diffn.loc[st, layer].mean()
    else:
        _dfn = diffn.loc[(st), layer]#/diffn.loc[st, layer].mean()
    stats = _dfn.apply(lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)).T
    stats.columns = ['stastic', 'pval']
    stats = bhc(stats, 1e-4)
    if session_type=='shuffle':
        _dfn.index = pd.MultiIndex.from_frame(_dfn.index.to_frame().sample(frac=1))
    (
        _dfn.stack()
        .droplevel(1).to_frame().reset_index()
        .rename(columns={'level_1': 'idx', 0: 'differentiation'})
        .pipe(
            (sns.boxplot, 'data'),
            x='area', y='differentiation', hue='response', ax=ax,
            fliersize=2, showfliers=False, order=areas, width=0.4,
            linewidth=1, palette={'hit':cm.Reds(0.6, 0.6), 'miss':cm.Greys(0.6, 0.6)}
        )
    )
    yt = ax.get_yticks()
    ax.set_yticks([yt[1], yt[-2]])
    yl = ax.get_ylim()[1]
    for i, a in enumerate(areas):
        if a in stats.index:
#             print(i)
            ax.annotate('*', (i, 0.9), xycoords=('data', 'axes fraction'), ha='center')
            ax.plot([i-0.12, i+0.12], [yl*0.9, yl*0.9], c='r')
#     ax.yaxis.set_major_locator(plt.MaxNLocator(2))
#     ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, y: f'{x/1e5}'))
    ax.legend(fontsize=7).set_visible(False)
    ax.set_ylabel(f'ND ({layer if layer!="-" else "all layers"})', fontsize=8, labelpad=0)
    ax.set_xlabel('')
    ax.set_xticklabels([x.get_text().rstrip('_ss') for x in ax.get_xticklabels()])
    ax.tick_params(axis='both', which='major', labelsize=8, pad=-1)
    ax.yaxis.get_offset_text().set_size(8)
    bottom = True
    if list(axes).index(ax)==len(axes)-1:
        bottom = False
    sns.despine(ax=ax, left=False, bottom=bottom)
# axes[0].set_title(f'{session_type} ({"running" if is_running==1 else "resting" if is_running==0 else ""})' , fontsize=9)
# axes[0].set_title(
#     f'spectral differentiation ({"all trials" if is_running==-1 else "running" if is_running==1 else "resting"})',
#     fontsize=10
# )
# axes[0].set_title('spectral differentiation')
axes[0].legend(loc=(0.06, 0.8), fontsize=8, frameon=False, ncol=1)
f.align_ylabels(axes);
f.savefig('fig_behavior_boxplots.svg')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [31]:
# at the level of individual mice, there is no significant difference between hits and misses
diffn = _differentiation.groupby(level=['mouse', 'date', 'session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
_dfn = diffn.xs('active', level='session_type')

stats = _dfn.groupby('is_running').apply(
    lambda df: df.groupby(
        ['mouse', 'date', 'response']
    ).mean().droplevel([0, 1]).apply(
        lambda c: sp.stats.ttest_ind(
            c['hit'].dropna(), c['miss'].dropna(),
            equal_var=False
        )
    ).T
)
stats = stats.swaplevel(0, -1).loc[areas].swaplevel(0, -1).sort_index().dropna()
stats.columns = ['stastic', 'pval']
# stats = stats.groupby('is_running').apply(lambda s: bhc(s, 0.5))
stats.sort_values('pval')

  **kwargs)
  ret = ret.dtype.type(ret / rcount)


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,stastic,pval
is_running,layer,area,Unnamed: 3_level_1,Unnamed: 4_level_1
False,L4,VISam,1.703512,0.161096
False,L4,VISpm,1.782522,0.303176
False,-,VISrl,0.77711,0.447244
True,-,VisCtx_ss,0.74591,0.493948
True,L5,VisCtx_ss,0.8003,0.530197
False,-,HVAs_ss,0.616244,0.540911
False,-,VISl,0.599602,0.55406
True,-,VISrl,0.649363,0.562389
False,L2/3,VisCtx_ss,0.575705,0.570371
False,L5,VISal,0.593784,0.599293


In [32]:
# each point is a stimulation instance
areas = ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']
idx = mfr.index.to_frame(index=False)
idx['response'] = idx.response.replace(['correctReject', 'noChange', 'falseAlarm'], None)#.replace(['correctReject'], None)
idx['is_running'] = True
idx.loc[idx.index[idx.running_speed>1], 'is_running'] = False
_mfr = pd.DataFrame(mfr.replace([-np.inf, np.inf], np.nan), index=pd.MultiIndex.from_frame(idx), copy=True)
is_running = -1 # -1 for all trials, 0 for resting and 1 for running

if is_running>-1:
    diffn = _mfr.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
else:
    diffn = _mfr.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))

with sns.axes_style('white'):
    f, axes = plt.subplots(4, 1, figsize=(4.5, 4.5), constrained_layout=True, sharex=True)

session_type = 'active'
for ax, layer in zip(axes, diffn.columns.remove_unused_levels().levels[0]):
#     areas = [a for a in hierarchy.keys() if a in diffn[layer].columns]
    st = session_type
    if st=='shuffle':
        st = 'active'
    if is_running > -1:
        _dfn = diffn.loc[(st, is_running), layer]#/diffn.loc[st, layer].mean()
    else:
        _dfn = diffn.loc[(st), layer]#/diffn.loc[st, layer].mean()
    stats = _dfn.apply(lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)).T
    stats.columns = ['stastic', 'pval']
    stats = bhc(stats, 1e-4)
    if session_type=='shuffle':
        _dfn.index = pd.MultiIndex.from_frame(_dfn.index.to_frame().sample(frac=1))
    (
        _dfn.stack()
        .droplevel(1).to_frame().reset_index()
        .rename(columns={'level_1': 'idx', 0: 'differentiation'})
        .pipe(
            (sns.boxplot, 'data'),
            x='area', y='differentiation', hue='response', ax=ax,
            fliersize=2, showfliers=False, order=areas, width=0.4,
            linewidth=1, palette={'hit':cm.Reds(0.6, 0.6), 'miss':cm.Greys(0.6, 0.6)}
        )
    )
    yt = ax.get_yticks()
    ax.set_yticks([yt[1], yt[-2]])
    yl = ax.get_ylim()[1]
    if stats is not None:
        for i, a in enumerate(areas):
            if a in stats.index:
    #             print(i)
                ax.annotate('*', (i, 0.9), xycoords=('data', 'axes fraction'), ha='center')
                ax.plot([i-0.12, i+0.12], [yl*0.9, yl*0.9], c='r')
#     ax.yaxis.set_major_locator(plt.MaxNLocator(2))
#     ax.yaxis.set_major_formatter(ticker.FuncFormatter(lambda x, y: f'{x/1e5}'))
    ax.legend(fontsize=7).set_visible(False)
    ax.set_ylabel(f'{layer}\ndifferentiation', fontsize=8, labelpad=0)
    ax.set_xlabel('')
    ax.tick_params(axis='both', which='major', labelsize=8, pad=-1)
    ax.yaxis.get_offset_text().set_size(8)
    bottom = True
    if list(axes).index(ax)==len(axes)-1:
        bottom = False
    sns.despine(ax=ax, left=False, bottom=bottom)
# axes[0].set_title(f'{session_type} ({"running" if is_running==1 else "resting" if is_running==0 else ""})' , fontsize=9)
axes[0].set_title(f'firing rate ({"all trials" if is_running==-1 else "running" if is_running==1 else "resting"})')
axes[0].legend(loc=2, fontsize=9, frameon=False, ncol=2)
f.align_ylabels(axes);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [33]:
(differentiation.index.get_level_values('running_speed')<0.15).sum()/len(differentiation)*100

8.899454634679309

In [34]:
areas = ['LP', 'THx', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs_ss', 'VisCtx_ss']#, 'HVAs', 'VisCtx', 'AllVis', 'hipp']
idx = differentiation.index.to_frame(index=False)
idx['response'] = idx.response.replace(['correctReject', 'noChange', 'falseAlarm'], None)
idx['is_running'] = True
idx.loc[idx.index[idx.running_speed>1], 'is_running'] = False
_differentiation = pd.DataFrame(differentiation, index=pd.MultiIndex.from_frame(idx), copy=True)

def get_dfn_stats(_differentiation, is_running=-1, pval=1e-4):
    if is_running>-1:
        diffn = _differentiation.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
    else:
        diffn = _differentiation.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
    areas = [a for a in hierarchy.keys() if a in diffn[layer].columns]
    if is_running > -1:
        _dfn = diffn.loc[(st, is_running)]#/diffn.loc[st, layer].mean()
    else:
        _dfn = diffn.loc[(st)]#/diffn.loc[st, layer].mean()
    stats = _dfn.apply(lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)).T
    stats.columns = ['stastic', 'pval']
    stats = bhc(stats, pval)
    return _dfn, stats

f, axes = plt.subplots(2, 1, figsize=(3.5, 2.4), sharex=True, constrained_layout=True)
cbar_ax = f.add_axes([0.88, .2, .03, .75])
f.tight_layout(rect=[0.04, 0.04, 0.9, 1])

for ax, running in zip(axes, [0, 1]):
    _dfn, stats = get_dfn_stats(_differentiation, is_running=running)
    _dfn = _dfn.groupby('response').mean().diff(-1).stack().loc['hit'].T
    areas = [a for a in hierarchy.keys() if a in areas]
    _stats = pd.DataFrame(index=_dfn.index, columns=_dfn.columns)
    _stats.loc[stats.pval.unstack().index, stats.pval.unstack().columns] = stats.pval.unstack()
    _stats[~_stats.isna()] = '*'
#     display(_dfn[areas])
#     display(_stats[areas])
    sns.heatmap(
        _dfn.T.reindex(areas).T, cmap='RdBu_r', robust=True, center=0, ax=ax,
        vmin=-1000, vmax=4000, annot=_stats.T.reindex(areas).T.fillna(''), fmt='s',
        cbar_kws={'ticks':[-1000, 4000], 'format':'%.0e'},
        cbar_ax=cbar_ax, annot_kws={"size":7}
    )
    cbar_ax.yaxis.label.set_size(7)
    ax.set_ylabel(f'{"running" if running else "resting"}', fontsize=7, labelpad=0.5)
    ax.set_xlabel('')
    ax.set_yticks([0.5, 1.5, 2.5, 3.5])
    ax.set_yticklabels(['all', 'L2/3', 'L4', 'L5'], rotation=0, fontsize=7)
    ax.tick_params(left=False, bottom=False, axis='both', which='major', labelsize=7, pad=-1)

cbar_ax.tick_params(labelsize=6, pad=-2)
cbar_ax.set_ylabel(
    '$\Delta$ differentiation (hit - miss)', labelpad=-10
)
ax.set_xticks(range(len(areas)))
ax.set_xticklabels(areas, rotation=45, fontsize=7)#, ha='right', rotation_mode='anchor')
f.suptitle(f'differentiation ({win_ms}ms post stimulation)', fontsize=8);
f.savefig('fig_behavior_delta_df.pdf')
# f.subplots_adjust(0.1, 0.1, 0.88, 0.9);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [35]:
idx = mfr.index.to_frame(index=False)
idx['response'] = idx.response.replace(['correctReject', 'noChange', 'falseAlarm'], None)
idx['is_running'] = True
idx.loc[idx.index[idx.running_speed>1], 'is_running'] = False
_mfr = pd.DataFrame(mfr.replace([np.inf, -np.inf], np.nan), index=pd.MultiIndex.from_frame(idx), copy=True)

def get_dfn_stats(_mfr, is_running=-1, pval=1e-4):
    if is_running>-1:
        diffn = _mfr.groupby(level=['session_type', 'is_running', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
    else:
        diffn = _mfr.groupby(level=['session_type', 'response'], dropna=False).apply(lambda df: df.reset_index(drop=True))
    areas = [a for a in hierarchy.keys() if a in diffn[layer].columns]
    if is_running > -1:
        _dfn = diffn.loc[(st, is_running)]#/diffn.loc[st, layer].mean()
    else:
        _dfn = diffn.loc[(st)]#/diffn.loc[st, layer].mean()
    stats = _dfn.apply(lambda c: sp.stats.ttest_ind(c['hit'].dropna(), c['miss'].dropna(), equal_var=False)).T
    stats.columns = ['stastic', 'pval']
    stats = bhc(stats, pval)
    return _dfn, stats

f, axes = plt.subplots(2, 1, figsize=(3.5, 2.4), sharex=True, constrained_layout=True)
cbar_ax = f.add_axes([0.88, .2, .03, .75])
f.tight_layout(rect=[0.04, 0.04, 0.9, 1])

for ax, running in zip(axes, [0, 1]):
    _dfn, stats = get_dfn_stats(_mfr, is_running=running)
    _dfn = _dfn.groupby('response').mean().diff(-1).stack().loc['hit'].T
    areas = [a for a in hierarchy.keys() if a in ['LP', 'THx', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']]#, 'AllVis', 'hipp']]
    _stats = pd.DataFrame(index=_dfn.index, columns=_dfn.columns)
    _stats.loc[stats.pval.unstack().index, stats.pval.unstack().columns] = stats.pval.unstack()
    _stats[~_stats.isna()] = '*'
#     display(_dfn[areas])
#     display(_stats[areas])
    sns.heatmap(
        _dfn.T.reindex(areas).T, cmap='RdBu_r', robust=True, center=0, ax=ax,
        vmin=-0.05, vmax=0.05, annot=_stats.T.reindex(areas).T.fillna(''), fmt='s',
        cbar_kws={'ticks':[-0.05, 0.05]},
        cbar_ax=cbar_ax, annot_kws={"size":7}
    )
    cbar_ax.yaxis.label.set_size(7)
    ax.set_ylabel(f'{"running" if running else "resting"}', fontsize=7, labelpad=0.5)
    ax.set_xlabel('')
    ax.set_yticks([0.5, 1.5, 2.5, 3.5])
    ax.set_yticklabels(['all', 'L2/3', 'L4', 'L5'], rotation=0, fontsize=7)
    ax.tick_params(left=False, bottom=False, axis='both', which='major', labelsize=7, pad=-1)

cbar_ax.tick_params(labelsize=6, pad=-2)
cbar_ax.set_ylabel(
    '$\Delta$ z-scored firing rate (hit - miss)', labelpad=-10
)
ax.set_xticks(range(len(areas)))
ax.set_xticklabels(areas, rotation=45, fontsize=7)#, ha='right', rotation_mode='anchor')
f.suptitle(f'firing rates ({win_ms}ms post stimulation)', fontsize=8);
f.savefig('fig_behavior_delta_fr.pdf')
# f.subplots_adjust(0.1, 0.1, 0.88, 0.9);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [37]:
f, ax = plt.subplots(1, 1, figsize=(3.5, 2.4), constrained_layout=True)
cbar_ax = f.add_axes([0.88, .2, .03, .75])
f.tight_layout(rect=[0.04, 0.04, 0.9, 1])
running = -1

_dfn, stats = get_dfn_stats(_differentiation, is_running=running)
_dfn = _dfn.groupby('response').mean().diff(-1).stack().loc['hit'].T
areas = [a for a in hierarchy.keys() if a in areas]
_stats = pd.DataFrame(index=_dfn.index, columns=_dfn.columns)
_stats.loc[stats.pval.unstack().index, stats.pval.unstack().columns] = stats.pval.unstack()
_stats[~_stats.isna()] = '*'
#     display(_dfn[areas])
#     display(_stats[areas])
sns.heatmap(
    _dfn.T.reindex(areas).T, cmap='RdBu_r', robust=True, center=0, ax=ax,
    vmin=-1000, vmax=4000, annot=_stats.T.reindex(areas).T.fillna(''), fmt='s',
    cbar_kws={'ticks':[-1000, 4000], 'format':'%.0e'},
    cbar_ax=cbar_ax, annot_kws={"size":7}
)
cbar_ax.yaxis.label.set_size(7)
# ax.set_ylabel(f'{"running" if running else "resting"}', fontsize=7, labelpad=0.5)
ax.set_xlabel('')
ax.set_yticks([0.5, 1.5, 2.5, 3.5])
ax.set_yticklabels(['all', 'L2/3', 'L4', 'L5'], rotation=0, fontsize=7)
ax.tick_params(left=False, bottom=False, axis='both', which='major', labelsize=7, pad=-1)

cbar_ax.tick_params(labelsize=6, pad=-2)
cbar_ax.set_ylabel(
    '$\Delta$ ND (hit - miss)', labelpad=-10
)
ax.set_xticks(range(len(areas)))
ax.set_xticklabels(areas, rotation=45, fontsize=7)#, ha='right', rotation_mode='anchor')
f.suptitle(f'ND ({win_ms}ms post stimulation)', fontsize=8);
f.savefig('fig_behavior_delta_df_all_trials.svg')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  This is separate from the ipykernel package so we can avoid doing imports until
  This is separate from the ipykernel package so we can avoid doing imports until


In [39]:
f, ax = plt.subplots(1, 1, figsize=(3.5, 2.4), constrained_layout=True)
cbar_ax = f.add_axes([0.88, .2, .03, .75])
f.tight_layout(rect=[0.04, 0.04, 0.9, 1])
running = -1

_dfn, stats = get_dfn_stats(_mfr, is_running=running)
_dfn = _dfn.groupby('response').mean().diff(-1).stack().loc['hit'].T
areas = [a for a in hierarchy.keys() if a in ['LP', 'THx', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam', 'HVAs', 'VisCtx']]#, 'AllVis', 'hipp']]
_stats = pd.DataFrame(index=_dfn.index, columns=_dfn.columns)
_stats.loc[stats.pval.unstack().index, stats.pval.unstack().columns] = stats.pval.unstack()
_stats[~_stats.isna()] = '*'
#     display(_dfn[areas])
#     display(_stats[areas])
sns.heatmap(
    _dfn.T.reindex(areas).T, cmap='RdBu_r', robust=True, center=0, ax=ax,
    vmin=-0.05, vmax=0.05, annot=_stats.T.reindex(areas).T.fillna(''), fmt='s',
    cbar_kws={'ticks':[-0.05, 0.05]},
    cbar_ax=cbar_ax, annot_kws={"size":7}
)
cbar_ax.yaxis.label.set_size(7)
# ax.set_ylabel(f'{"running" if running else "resting"}', fontsize=7, labelpad=0.5)
ax.set_xlabel('')
ax.set_yticks([0.5, 1.5, 2.5, 3.5])
ax.set_yticklabels(['all', 'L2/3', 'L4', 'L5'], rotation=0, fontsize=7)
ax.tick_params(left=False, bottom=False, axis='both', which='major', labelsize=7, pad=-1)

cbar_ax.tick_params(labelsize=6, pad=-2)
cbar_ax.set_ylabel(
    '$\Delta$ z-scored firing rate (hit - miss)', labelpad=-10
)
ax.set_xticks(range(len(areas)))
ax.set_xticklabels(areas, rotation=45, fontsize=7)#, ha='right', rotation_mode='anchor')
f.suptitle(f'firing rate ({win_ms}ms post stimulation)', fontsize=8);
f.savefig('fig_behavior_delta_fr_all_trials.svg')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

  This is separate from the ipykernel package so we can avoid doing imports until
  This is separate from the ipykernel package so we can avoid doing imports until


---