In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib widget

In [2]:
import os
from os import path
import itertools
import pickle

import numpy as np
import pandas as pd
import scipy as sp
from tqdm.auto import tqdm
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import cm
import seaborn as sns
from glob import glob
import umap
from IPython.utils.capture import capture_output
import sklearn as sk

from ipympl.backend_nbagg import Canvas
Canvas.header_visible.default_value = False

In [3]:
data_directory = (
    '/allen/programs/braintv/workgroups/tiny-blue-dot/'
    'differentiation/refactor/data'
)

hierarchy = {
    'Input' : -100,
    'stimulus' : -100,
    'Stim' : -100,
    'TH' : -10,
    'LG' : -9,
    'LGv' : -8,
    'LGd' : -7,
    'LP' : -6,
    'THx' : -5,
    'THx_VISp' : -4,
    'VISp' : 0,
    'VISpl' : 2,
    'VISl' : 4,
    'VISli' : 6,
    'VISrl' : 8,
    'VISal' : 10,
    'VISpm' : 12,
    'VISam' : 14,
    'VISpor' : 16,
    'VISa' : 18,
    'SC' : 24,
    'VISmma' : 20,
    'VISmmp' : 20,
    'VIS' : 22,
    'HVAs' : 21,
    'VisCtx' : 21.5,
    'AllVis' : 22,
    'PF' : 25,
    'MB' : 30,
    'hipp' : 38,
    'CAx' : 39,
    'CA' : 40,
    'CA1' : 41,
    'CA2' : 42,
    'CA3' : 43,
    'DG' : 50,
}

hierarchy_score = {
    'LGd':-0.5150279628298357,
    'VISp':-0.35733209934482374,
    'VISl':-0.09388855125761343,#LM
    'VISrl':-0.05987132463908328,
    'LP':0.10524780962600731,
    'VISal':0.15221797920142832,
    'VISpm':0.32766807486511995,
    'VISam':0.440986074378801
}

colors = {
    x : i/len(hierarchy) for i, x in enumerate(hierarchy.keys())
}

layer_depths = {
    'L1' : 100,
    'L2/3' : 210,
    'L4' : 120,
    'L5' : 220,
    'L6' : 200,
}

region_sets = {
    'VisCtx' : ['VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'HVAs' : ['VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'],
    'AllVis' : [
        'LGd', 'LP', 'TH', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'
    ],
    'THx' : ['LGd', 'LP', 'TH'],
    'hipp' : ['CA', 'CA1', 'CA2', 'CA3', 'DG', 'DG-mo', 'DG-po', 'DG-sg'],
}

relevant_regions = [
    'LGd', 'LP', 'TH', 'VISp', 'VISl', 'VISrl', 'VISal', 'VISpm', 'VISam'
]

stim_colors_bg = {
    'spontaneous' : cm.Greys(0.3, 0.3),
    'gabors' : cm.Reds(0.7, 0.3),
    'flashes' : cm.Reds(0.3, 0.3),
    'drifting_gratings' : cm.Blues(0.8, 0.3),
    'drifting_gratings_contrast' : cm.Blues(0.99, 0.3),
    'static_gratings' : cm.Blues(0.5, 0.3),
    'natural_movie_three' : cm.Greens(0.9, 0.3),
    'natural_movie_one' : cm.Greens(0.6, 0.3),
    'natural_movie_one_shuffled' : cm.Purples(0.6, 0.1),
    'Spontaneous' : cm.Greys(0.6, 0.3),
    'Artificial (simple)' : cm.Reds(0.6, 0.3),
    'Artificial (complex)' : cm.Blues(0.8, 0.3),
    'Natural' : cm.Greens(0.8, 0.3),
    'spontaneous' : cm.Greys(0.5, 0.5),
    'simple' : cm.Reds(0.5, 0.5),
    'complex' : cm.Blues(0.5, 0.5),
    'natural' : cm.Greens(0.5, 0.5),
    'shuffled' : cm.Purples(0.4, 0.5)
}

stim_colors = {
    'spontaneous' : cm.Greys(0.3, 0.8),
    'gabors' : cm.Reds(0.3, 0.8),
    'flashes' : cm.Reds(0.8, 0.8),
    'drifting_gratings' : cm.Blues(0.5, 0.8),
    'drifting_gratings_contrast' : cm.Blues(0.7, 0.8),
    'static_gratings' : cm.Blues(0.9, 0.8),
    'natural_movie_three' : cm.Greens(0.75, 0.8),
    'natural_movie_two' : cm.Greens(0.6, 0.8),
    'natural_movie_one' : cm.Greens(0.45, 0.8),
    'natural_movie_one_shuffled' : cm.Purples(0.3, 0.8),
    'Spontaneous' : cm.Greys(0.6, 0.8),
    'Artificial (simple)' : cm.Reds(0.6, 0.8),
    'Artificial (complex)' : cm.Blues(0.8, 0.8),
    'Natural' : cm.Greens(0.8, 0.8),
    'spontaneous' : cm.Greys(0.8, 0.8),
    'simple' : cm.Reds(0.8, 0.8),
    'complex' : cm.Blues(0.8, 0.8),
    'natural' : cm.Greens(0.8, 0.8),
    'shuffled' : cm.Purples(0.4, 0.8)
}

stim_cat_colors = {
    'spontaneous' : cm.Greys(0.8, 0.8),
    'simple' : cm.Reds(0.8, 0.8),
    'complex' : cm.Blues(0.8, 0.8),
    'natural' : cm.Greens(0.8, 0.8),
    'shuffled' : cm.Purples(0.4, 0.8)
    
}

stim_cat_colors_bg = {
    'spontaneous' : cm.Greys(0.5, 0.5),
    'simple' : cm.Reds(0.5, 0.5),
    'complex' : cm.Blues(0.5, 0.5),
    'natural' : cm.Greens(0.5, 0.5),
    'shuffled' : cm.Purples(0.4, 0.5)
    
}

stimulus_categories = {
    'drifting_gratings' : 'complex',
    'drifting_gratings_contrast' : 'simple',
    'flashes' : 'simple',
    'gabors' : 'simple',
    'natural_movie_one_shuffled' : 'shuffled',
    'natural_movies' : 'natural',
    'natural_movie_one' : 'natural',
    'natural_movie_three' : 'natural',
    'spontaneous' : 'spontaneous',
    'static_gratings' : 'complex'
}

stim_by_putative_meaning = {
    'stimulus_name' : [
        'spontaneous', 'natural_movie_one_shuffled', 'flashes', 'gabors',
        'drifting_gratings_contrast', 'drifting_gratings', 'static_gratings',
        'natural_movie_one', 'natural_movie_three', 'natural_movies'
    ],
    'stimulus_category' : [
        'spontaneous', 'simple', 'shuffled', 'complex', 'natural'
    ]
}

In [4]:
session_ids = [
    path.basename(x)
    .strip('.pkl')
    .strip('fr_') for x in glob(
        path.join(data_directory, 'fr_*')
    )
]

In [5]:
# ensembles of neurons to apply differentiation to
rs = True
unit_set = []
unit_set  = [
    f'region == "%s" & snr > 2.5 & RS == {rs}'%reg for reg in [
        'VISp', 'VISl', 'VISal', 'VISam', 'VISpm', 'VISrl', 'LGd', 'LP'
    ]
]
unit_set += [
    f'region in @region_sets.get("%s") & snr > 2.5 & RS == {rs}'%s for s in region_sets.keys()
]
unit_set += [
    f'layer == "%s" & region == "%s" & snr > 2.5 & RS == {rs}'%(l, r) for l, r in itertools.product(
        ['L2/3', 'L4', 'L5', 'L6'], ['VISp', 'VISl', 'VISal', 'VISpm', 'VISam', 'VISrl']
    )
]
unit_set += [
    f'layer == "%s" & region in @region_sets.get("%s") & snr > 2.5 & RS == {rs}'%(l, r) for l, r in itertools.product(
        ['L2/3', 'L4', 'L5', 'L6'], list(region_sets.keys())[:2]
    )
]

In [6]:
aggr_interval_s = 0.25 # interval over which spikes are aggregated to define state

In [7]:
def load_fr(session):
    return pd.read_pickle(
        path.join(data_directory, f'fr_{session}.pkl')
    )

def load_units(session):
    return pd.read_pickle(
        path.join(data_directory, f'units_{session}.pkl')
    )

def load_stimulus_table(session):
    return pd.read_pickle(
        path.join(data_directory, f'stimulus_{session}.pkl')
    )

def load_running(session):
    return pd.read_pickle(
        path.join(data_directory, f'running_{session}.pkl')
    )

def get_unit_filters(units):
    filter_strings = units.split(' & ')
    filters = {}
    for filt in filter_strings:
        if ' = ' in filt:
            filters[filt.split(' = ')[0]] = filt.split(' = ')[1].strip('""')
        if '==' in filt:
            filters[filt.split(' == ')[0]] = filt.split(' == ')[1].strip('""')
        if '>' in filt:
            filters[filt.split(' > ')[0]] = float(filt.split(' > ')[1])
        if '@' in filt:
            key = filt.split(' ')[0]
            value = filt.split('get')[1].split('"')[1]
            filters[key] = value
    return filters

# Single session UMAP

In [8]:
# load firing rates for a session
session = session_ids[0]
units = load_units(session)
firing_rates = load_fr(session)[
    units[(units.snr>2.5)&(
        units.region.isin(relevant_regions+['DG', 'CA1', 'CA2', 'CA3'])
    )].index
]

In [9]:
n_steps = int(aggr_interval_s/np.diff(firing_rates.index).mean())
firing_rates = firing_rates.rolling(n_steps, center=True).mean()
firing_rates = firing_rates[int(n_steps/2)::n_steps].dropna()

In [10]:
# most neurons have very few unique firing rates
n_unique_states_per_neuron = firing_rates.nunique()
f, ax = plt.subplots(figsize=(3, 2.4), tight_layout=True)
n_unique_states_per_neuron.plot.hist(ax=ax, bins=100, log=True)
ax.set_xlabel('# unique FR per neuron');

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [11]:
# add stimulus information to the index
stim_table = load_stimulus_table(session)
firing_rates.index = pd.MultiIndex.from_frame(
    pd.MultiIndex.from_frame(stim_table)
    .drop_duplicates()
    .to_frame(index=False)
    .set_index('time')
    .reindex(
        firing_rates.index, method='ffill'
    ).rename_axis('time').reset_index().bfill()
)

idx = firing_rates.index.to_frame()
idx['stimulus_name'] = idx.stimulus_name.map(
    lambda x: x if x!='drifting_gratings_75_repeats' else 'drifting_gratings'
)
idx['stimulus_name'] = idx.stimulus_name.map(
    lambda x: x if x!='natural_movie_one_more_repeats' else 'natural_movie_one'
)
firing_rates.index = pd.MultiIndex.from_frame(idx)
firing_rates = firing_rates.drop(
    ['natural_scenes', 'dot_motion'], level='stimulus_name', errors='ignore'
)
firing_rates.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,51,52,56,62,63,65,67,75,83,88,...,2423,2428,2429,2430,2431,2432,2435,2443,2448,2453
time,stimulus_name,block,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
25.004997,spontaneous,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.88,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25.254997,spontaneous,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,3.88,0.0,0.0,0.0
25.504997,spontaneous,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25.754997,spontaneous,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,3.88,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
26.004997,spontaneous,-1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,19.62,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


## # states on coarse-grained firing rates

### Basic implementation to count states

In [34]:
# let us look at the different discrete states in 5 s window for one set of units
df = firing_rates.iloc[20:40][[715, 715]]#[units.index[units.eval(unit_set[0])]]
display(df)
print(f'{df.astype(str).agg("-".join, axis=1).nunique()} unique rows found: ', end='')
print(df.astype(str).agg("-".join, axis=1).unique())

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,715,715
time,stimulus_name,block,Unnamed: 3_level_1,Unnamed: 4_level_1
30.004996,spontaneous,-1.0,3.86,3.86
30.254996,spontaneous,-1.0,0.0,0.0
30.504996,spontaneous,-1.0,3.88,3.88
30.754996,spontaneous,-1.0,3.88,3.88
31.004996,spontaneous,-1.0,0.0,0.0
31.254996,spontaneous,-1.0,0.0,0.0
31.504996,spontaneous,-1.0,0.0,0.0
31.754996,spontaneous,-1.0,0.0,0.0
32.004996,spontaneous,-1.0,0.0,0.0
32.254996,spontaneous,-1.0,0.0,0.0


3 unique rows found:['3.86-3.86' '0.0-0.0' '3.88-3.88']


### let us now coarse-grain the firing rates and count states including all VISp units

---

In [283]:
df = firing_rates[units.index[units.eval(unit_set[8])]]

# there will be some units that fire very rarely and spuriously
# and it would be good to not count them in the state
# so we will drop units that have very low firing rate
# and very low variance
f, ax = plt.subplots(figsize=(3, 2.4), tight_layout=True)
sns.ecdfplot(df.mean().rename('mean firing rate'), ax=ax)
ax.axvline(2, c='k', lw=0.5);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [284]:
# also restrict to static_gratings for this example
df = df[df.columns[df.mean()>4]]
df = df.xs('gabors', level='stimulus_name')

# binarize into firing or non-firing
# df = (df > df.mean()).astype(int)
df = (df > df.mean()*2).astype(int)

# group time into 5 s windows
idx = df.index.to_frame()
idx['window'] = np.array(np.arange(len(idx))/20, dtype=int)
df.index = pd.MultiIndex.from_frame(idx)

df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,97,99,103,104,108,113,270,292,303,306,...,1081,1215,1421,1432,1434,1441,2184,2191,2386,2397
time,block,window,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1
88.254989,0.0,0,0,1,0,1,0,0,0,0,0,0,...,0,0,1,0,0,0,1,0,0,0
88.504989,0.0,0,0,1,1,1,1,0,0,0,0,0,...,0,0,1,0,0,0,0,1,1,0
88.754989,0.0,0,1,1,1,0,0,0,0,1,0,1,...,0,0,1,1,1,0,0,1,1,0
89.004989,0.0,0,0,1,1,0,0,0,0,0,0,0,...,0,0,1,0,0,0,0,0,1,1
89.254989,0.0,0,0,1,0,0,0,0,0,1,0,1,...,1,1,1,1,1,1,0,1,1,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999.004873,0.0,182,1,0,0,0,1,0,0,1,0,0,...,0,1,0,0,0,0,0,0,1,0
999.254872,0.0,182,1,0,0,0,0,0,0,0,0,1,...,0,1,0,0,0,1,0,1,0,0
999.504872,0.0,182,1,0,0,0,0,0,0,0,0,1,...,0,1,0,0,0,0,0,0,0,0
999.754872,0.0,182,1,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


In [285]:
n_discrete_state = df.groupby('window').apply(
    lambda _df: _df.astype(str).agg("-".join, axis=1).nunique()
)

n_discrete_state.groupby(n_discrete_state).size()

8      1
14     4
15     4
16     8
17    19
18    31
19    35
20    81
dtype: int64

---

In [136]:
# another approach is to compute pairwise manhattan distances
# and threshold on the distance (a few steps distance is not 'real')
# to calculate the number of states
df = firing_rates[units.index[units.eval(unit_set[0])]]
df = (df>4).astype(int)

_dists = sp.spatial.distance.pdist(df.iloc[:20], metric='cityblock')

In [142]:
f, ax = plt.subplots(figsize=(3, 2.4), tight_layout=True)
ax.hist(_dists, bins=20)
ax.set_xlabel('# non-matched units')
ax.set_ylabel('count')
print(f'{(_dists<=2).sum()} pairs have identical firing patterns (upto 2 differences).')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

29 pairs have identical firing patterns (upto 2 differences).


### To be continued further...

## Compute # states for all stimuli

In [290]:
def count_n_states(df):
    # binarize the firing rates
    df = (df > 2).astype(int)
    
    # window the responses
    idx = df.index.to_frame()
    idx['window'] = np.array(np.arange(len(idx))/20, dtype=int)
    df.index = pd.MultiIndex.from_frame(idx)
    
    df = df[df.mean().sort_values()[-30:].index]
    
    # compute number of states in each window
    return df.groupby('window').apply(
        lambda _df: _df.astype(str).agg("-".join, axis=1).nunique()
    )[:-1]

In [293]:
df = firing_rates[units.index[units.eval(unit_set[8])]]
df = df[df.columns[df.mean()>2]]
n_states = df.groupby('stimulus_name').apply(count_n_states)
n_bits = np.log2(n_states)

In [294]:
f, ax = plt.subplots(figsize=(5, 3.5), tight_layout=True)
sns.histplot(
    data=n_bits.rename('n_bits').reset_index(), hue='stimulus_name', x='n_bits',
    element='poly', multiple='dodge', fill=False, bins=10
);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [295]:
n_bits_full = {}
for uf in tqdm(unit_set):
    df = firing_rates[units.index[units.eval(uf)]]
    df = df[df.columns[df.mean()>5]]
    n_states = df.groupby('stimulus_name').apply(count_n_states)
    ufstr = get_unit_filters(uf)
    n_bits_full[(ufstr["region"], ufstr.get("layer", "all"))] = np.log2(n_states)
n_bits_full = pd.concat(n_bits_full)

HBox(children=(IntProgress(value=0, max=45), HTML(value='')))




In [296]:
n_bits_melted = n_bits_full.rename_axis(index=['area', 'layer', 'stimulus_name', 'window']).rename('n_bits').reset_index()
n_bits_melted['stimulus_category'] = n_bits_melted.stimulus_name.map(stimulus_categories)
n_bits_melted = n_bits_melted[n_bits_melted.layer=='all']
n_bits_melted

Unnamed: 0,area,layer,stimulus_name,window,n_bits,stimulus_category
0,VISp,all,drifting_gratings,0,3.459432,complex
1,VISp,all,drifting_gratings,1,3.459432,complex
2,VISp,all,drifting_gratings,2,3.459432,complex
3,VISp,all,drifting_gratings,3,3.321928,complex
4,VISp,all,drifting_gratings,4,3.321928,complex
...,...,...,...,...,...,...
11292,hipp,all,static_gratings,145,5.087463,complex
11293,hipp,all,static_gratings,146,4.954196,complex
11294,hipp,all,static_gratings,147,5.169925,complex
11295,hipp,all,static_gratings,148,5.169925,complex


In [298]:
palette = {
    k:v for k, v in stim_colors_bg.items(
    ) if k in n_bits_melted.stimulus_category.unique()
}
hue_order = palette.keys()
area_order = relevant_regions+['HVAs', 'VisCtx', 'hipp']
f, ax = plt.subplots(figsize=(9, 3), tight_layout=True)
sns.boxplot(
    data=n_bits_melted[n_bits_melted.area.isin(relevant_regions)],
    x='area', hue='stimulus_category', y='n_bits', ax=ax,
    linewidth=0.2, order=area_order, hue_order=hue_order,
    showfliers=False, palette=palette
)
ax.legend(fontsize=9, loc=0);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

## UMAP

In [112]:
# choose and ensemble and run UMAP on it
ufs = [0, 3, 6, 9, 12, 8, 37, 39]
embeddings = {}
f, axes = plt.subplots(len(ufs), 1, figsize=(4, 2.5*len(ufs)), tight_layout=True)
for i, ax in zip(ufs, tqdm(axes)):
    ax.set_ylabel(
        f'{get_unit_filters(uf)["region"]} '
        f'{get_unit_filters(uf).get("layer", "all")}', fontsize=12
    )
    ax.tick_params(labelbottom=False, labelleft=False)
    
    uf = unit_set[i]
    _fr = firing_rates[units.index[units.eval(uf)]]

    # umap for FRs
    ur = umap.UMAP(n_components=2, n_neighbors=5, min_dist=0.00005)
    try:
        ur.fit(_fr)
    except:
        continue
    _u1, _u2 = ur.embedding_.T
    embeddings[uf] = (_u1, _u2)

#     f, ax = plt.subplots(figsize=(5, 3), tight_layout=True)
    ax.scatter(
        _u1, _u2, s=0.2, c=_fr.index.get_level_values('stimulus_name').map(stim_colors_bg)
    )
patches = [
    mpl.patches.Patch(
        color=c, label=s#.replace('_', '\n')
    ) for s, c in stim_colors_bg.items()
]
axes[0].legend(handles=patches[:-4], fontsize=5, loc=(1.01, 0))

embeddings = {
    (
        f'{get_unit_filters(k)["region"]}_'
        f'{get_unit_filters(k).get("layer", "all")}'
    ) : pd.DataFrame(
        v, index=['_u1', '_u2'],
        columns=firing_rates.index
    ).T for k, v in embeddings.items()
}
embeddings = pd.concat(embeddings, axis=0)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

In [113]:
embeddings

Unnamed: 0_level_0,Unnamed: 1_level_0,region,VISp_all,VISp_all,VISam_all,VISam_all,HVAs_all,HVAs_all,hipp_all,hipp_all,VisCtx_all,VisCtx_all,VisCtx_L2/3,VisCtx_L2/3,VisCtx_L4,VisCtx_L4
Unnamed: 0_level_1,Unnamed: 1_level_1,dimension,_u1,_u2,_u1,_u2,_u1,_u2,_u1,_u2,_u1,_u2,_u1,_u2,_u1,_u2
time,stimulus_name,block,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2,Unnamed: 14_level_2,Unnamed: 15_level_2,Unnamed: 16_level_2
24.000001,spontaneous,-1.0,-1.879092,7.027840,4.828933,7.944551,0.674857,5.588076,-0.408944,0.485752,12.733475,7.922928,-2.010412,8.359682,-2.960763,5.916942
24.245001,spontaneous,-1.0,-2.164254,6.953552,5.955440,8.068870,0.672025,5.583872,-0.927814,1.359087,12.723401,7.949181,-1.854814,8.728300,-2.584673,6.242374
24.490001,spontaneous,-1.0,-2.453646,9.305155,2.146131,6.276027,-1.227550,-1.080440,0.233207,5.265184,12.863957,5.083230,-0.175271,7.034948,-2.003042,3.054516
24.735001,spontaneous,-1.0,-1.821368,8.452673,-0.592142,4.669295,-1.224704,-1.070249,-0.084242,1.094926,12.749323,5.150882,1.358155,7.901293,-2.019891,3.050668
24.980001,spontaneous,-1.0,-2.348502,10.576662,1.557855,6.345692,-1.225630,-1.080211,0.424295,0.360404,12.776420,5.112898,-0.050133,6.007826,-2.038952,2.997816
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
10519.310631,drifting_gratings_contrast,15.0,-1.551568,-0.462221,7.583153,6.691725,10.896751,0.670862,11.231109,4.696738,-1.298108,9.029122,3.304136,2.094615,8.661822,0.958105
10519.555631,drifting_gratings_contrast,15.0,2.220345,1.268372,1.274056,2.205120,6.871519,-1.880029,5.525356,-1.331730,2.929660,5.835488,-0.232901,2.197772,3.880001,-2.131428
10519.800631,drifting_gratings_contrast,15.0,2.352018,1.514933,0.973367,1.050737,6.929215,-2.012832,5.529591,-1.335747,2.952611,5.855663,-0.427489,1.953249,3.855363,-2.342837
10520.045631,drifting_gratings_contrast,15.0,2.355478,1.505100,0.973786,1.051597,6.931222,-2.018980,5.530997,-1.335248,2.949387,5.850908,-0.429134,1.952536,3.855862,-2.342337


In [129]:
sss = (
    embeddings.stack('region').groupby(['region', 'stimulus_name'])
    .apply(lambda df: ((df - df.mean())**2).sum(1).median())
)
sss = sss.unstack('stimulus_name')
sss.index = pd.MultiIndex.from_tuples(
    sss.index.map(lambda x: tuple(x.split('_')))
)

In [133]:
sss

Unnamed: 0,stimulus_name,drifting_gratings,drifting_gratings_contrast,flashes,gabors,natural_movie_one,natural_movie_three,spontaneous,static_gratings
HVAs,all,5.950599,4.101438,5.821878,2.25842,17.338629,3.393766,15.687114,3.857556
VISam,all,8.391865,9.035812,9.483674,5.351234,9.993698,6.591458,5.820198,7.982512
VISp,all,16.935859,3.462924,3.84819,1.887063,6.8246,2.812909,6.03902,5.235059
VisCtx,L2/3,9.503614,3.116247,3.803986,2.319319,10.19624,4.807827,5.572714,4.726826
VisCtx,L4,4.397786,4.045384,6.619796,2.599113,16.547688,4.103304,10.799205,5.014759
VisCtx,all,8.685151,4.403894,4.328388,1.052668,9.175288,2.260832,17.975956,3.153235
hipp,all,9.874713,7.13371,3.688223,1.675609,9.75799,9.09615,11.751623,5.379917


# All sessions

In [234]:
def compute_embeddings(session, unit_set, n_components=2):
    # load firing rates for a session
    units = load_units(session)
    firing_rates = load_fr(session)[
        units[(units.snr>2.5)&(
            units.region.isin(relevant_regions+['DG', 'CA1', 'CA2', 'CA3'])
        )].index
    ]
    
    n_steps = int(aggr_interval_s/np.diff(firing_rates.index).mean())
    firing_rates = firing_rates.rolling(n_steps, center=True).mean()
    firing_rates = firing_rates[int(n_steps/2)::n_steps].dropna()
    
    # add stimulus information to the index
    stim_table = load_stimulus_table(session)
    firing_rates.index = pd.MultiIndex.from_frame(
        pd.MultiIndex.from_frame(stim_table)
        .drop_duplicates()
        .to_frame(index=False)
        .set_index('time')
        .reindex(
            firing_rates.index, method='ffill'
        ).rename_axis('time').reset_index().bfill()
    )

    idx = firing_rates.index.to_frame()
    idx['stimulus_name'] = idx.stimulus_name.map(
        lambda x: x if x!='drifting_gratings_75_repeats' else 'drifting_gratings'
    )
    idx['stimulus_name'] = idx.stimulus_name.map(
        lambda x: x if x!='natural_movie_one_more_repeats' else 'natural_movie_one'
    )
    firing_rates.index = pd.MultiIndex.from_frame(idx)
    firing_rates = firing_rates.drop(
        ['natural_scenes', 'dot_motion'], level='stimulus_name', errors='ignore'
    )
    
    embeddings = {}
    for uf in tqdm(unit_set, desc=session):
        _fr = firing_rates[units.index[units.eval(uf)]]

        # umap for FRs
        ur = umap.UMAP(n_components=n_components, n_neighbors=5, min_dist=0.00005)
        try:
            with capture_output():
                ur.fit(_fr)
        except:
            continue
        embeddings[uf] = ur.embedding_.T
    
    embeddings = {
        (
            f'{get_unit_filters(k)["region"]}_'
            f'{get_unit_filters(k).get("layer", "all")}'
        ) : pd.DataFrame(
            v, columns=firing_rates.index
        ).T for k, v in embeddings.items()
    }
    embeddings = pd.concat(embeddings, axis=0)
    return embeddings

In [235]:
n_components = 3
embeddings = {}
for session in tqdm(session_ids[:30]):
    for n_components in [2, 3, 8]:
        fn = path.join(data_directory, f'embeddings_{n_components}d_{session}.pkl')
        if path.exists(fn):
            embeddings[(session, n_components)] = pd.read_pickle(fn)
        else:
            embeddings[(session, n_components)] = compute_embeddings(session, unit_set)
            embeddings[(session, n_components)].to_pickle(fn)
embeddings = pd.concat(embeddings)
idx = embeddings.index.to_frame()
idx['stimulus_category'] = idx.stimulus_name.map(stimulus_categories)
embeddings.index = pd.MultiIndex.from_frame(idx)
embeddings.rename_axis(index={0:'session', 1:'n_components', 2:'area'}, inplace=True)

HBox(children=(IntProgress(value=0, max=30), HTML(value='')))




In [236]:
sss = (
    embeddings.groupby(level=['session', 'area', 'stimulus_category'])
    .apply(lambda df: ((df - df.mean())**2).sum(1).median())
)
sss = sss.unstack('stimulus_category')
sss.index = pd.MultiIndex.from_tuples(
    sss.index.to_frame().apply(
        lambda x: tuple([x[0]]+x[1].split('_')), axis=1
    ).values
)
sss

Unnamed: 0,Unnamed: 1,stimulus_category,complex,natural,shuffled,simple,spontaneous
719161530,AllVis,all,38.906452,22.782698,,59.347466,55.436737
719161530,HVAs,L2/3,9.906996,8.165577,,4.617380,20.950821
719161530,HVAs,L4,32.889069,13.910753,,1.938536,31.626175
719161530,HVAs,L5,19.992950,19.290602,,31.693428,26.270130
719161530,HVAs,L6,53.129005,76.085052,,62.185127,49.156952
...,...,...,...,...,...,...,...
847657808,VisCtx,L4,34.477982,22.996954,52.217041,39.636559,42.982311
847657808,VisCtx,L5,10.022230,12.974141,1.413202,5.448433,5.105828
847657808,VisCtx,L6,29.150499,31.603088,13.830759,59.418831,45.111977
847657808,VisCtx,all,34.952847,36.761467,55.381035,67.711594,66.792542


In [237]:
area_order = [x for x in hierarchy.keys() if x in sss.index.levels[1]]
area_order

['LGd',
 'LP',
 'THx',
 'VISp',
 'VISl',
 'VISrl',
 'VISal',
 'VISpm',
 'VISam',
 'HVAs',
 'VisCtx',
 'AllVis',
 'hipp']

In [238]:
f, ax = plt.subplots(figsize=(6, 2.4), tight_layout=True)
sss.groupby(level=[1, 2]).mean().xs('all', level=1).loc[area_order].apply(
    lambda c: c.plot(
        marker='o', ax=ax, color=stim_colors[c.name], label=c.name,
        yerr=sss.groupby(level=[1, 2]).std().xs('all', level=1).loc[area_order, c.name]
    )
)
ax.set_xticks(range(len(area_order)))
ax.set_xticklabels(area_order, fontsize=9)
ax.legend(loc=0, fontsize=7, ncol=2)
ax.set_ylabel(f'spread in\n{n_components} dimensions',fontsize=9)
ax.tick_params(labelsize=8);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [239]:
hue_order = ['spontaneous', 'simple', 'complex', 'natural', 'shuffled']
palette = {s:stim_colors_bg[s] for s in hue_order}
f, ax = plt.subplots(figsize=(8, 3), tight_layout=True)
sns.boxplot(
    data=(
        sss.xs('all', level=2).rename_axis(index=['session', 'region'])
        .stack().rename('spread').reset_index()
    ), x='region', y='spread', hue='stimulus_category', ax=ax, linewidth=0.2,
    order=area_order, hue_order=hue_order, showfliers=False, palette=palette
)
ax.legend(fontsize=9);

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …