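In this notebook, we download the Neuropixels Visual Coding data using the Allen SDK (the download cell is commented out once the files are cached),  
and extract unit metadata, spike times, behavioral data (running speed) and stimulus tables from each session's NWB file.  
We convolve the spike times with a Gaussian kernel to obtain firing rates, and save the unit metadata, running speed, stimulus tables and firing rates as per-session pickle files.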

# Accessing Neuropixels Visual Coding Data

In [1]:
from os import environ, path

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

%matplotlib widget

In [2]:
# Approximate laminar thickness per cortical layer, ordered L1 -> L6
# (presumably in micrometres -- TODO confirm against assign_approx_layers).
layer_depths = dict(zip(
    ['L1', 'L2/3', 'L4', 'L5', 'L6'],
    [100, 210, 120, 220, 200],
))

# Visual cortical areas; only units in these regions are treated as cortical.
ctx_regions = 'VISp VISl VISrl VISal VISpm VISam'.split()

In [3]:
# Machine-specific locations. Override them with environment variables rather
# than editing this cell:
#   src_directory  - AllenSDK cache root (manifest.json and the downloaded
#                    session_<id>/session_<id>.nwb files read below)
#   data_directory - destination for the per-session pickles written below
src_directory = environ.get(
    'NEUROPIXELS_SRC_DIR',
    '/home/saurabh.gandhi/Projects/tiny-blue-dot/differentiation/refactor/tmp',
)
data_directory = environ.get(
    'NEUROPIXELS_DATA_DIR',
    '/allen/programs/braintv/workgroups/tiny-blue-dot/differentiation/refactor/data',
)

manifest_path = path.join(src_directory, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)
sessions = cache.get_session_table()

print(f'Total number of sessions: {len(sessions)}')
sessions.head()

Total number of sessions: 58


Unnamed: 0_level_0,published_at,specimen_id,session_type,age_in_days,sex,full_genotype,unit_count,channel_count,probe_count,ecephys_structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
715093703,2019-10-03T00:00:00Z,699733581,brain_observatory_1.1,118.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,884,2219,6,"[CA1, VISrl, nan, PO, LP, LGd, CA3, DG, VISl, ..."
719161530,2019-10-03T00:00:00Z,703279284,brain_observatory_1.1,122.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,755,2214,6,"[TH, Eth, APN, POL, LP, DG, CA1, VISpm, nan, N..."
721123822,2019-10-03T00:00:00Z,707296982,brain_observatory_1.1,125.0,M,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,444,2229,6,"[MB, SCig, PPT, NOT, DG, CA1, VISam, nan, LP, ..."
732592105,2019-10-03T00:00:00Z,717038288,brain_observatory_1.1,100.0,M,wt/wt,824,1847,5,"[grey, VISpm, nan, VISp, VISl, VISal, VISrl]"
737581020,2019-10-03T00:00:00Z,718643567,brain_observatory_1.1,108.0,M,wt/wt,568,2218,6,"[grey, VISmma, nan, VISpm, VISp, VISl, VISrl]"


In [4]:
# Set to True to keep only wild-type (wt/wt) animals; by default every session is processed.
WT_ONLY = False

filtered_sessions = (
    sessions[sessions.full_genotype.str.find('wt/wt') > -1] if WT_ONLY else sessions
)

filtered_sessions.head()

Unnamed: 0_level_0,published_at,specimen_id,session_type,age_in_days,sex,full_genotype,unit_count,channel_count,probe_count,ecephys_structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
715093703,2019-10-03T00:00:00Z,699733581,brain_observatory_1.1,118.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,884,2219,6,"[CA1, VISrl, nan, PO, LP, LGd, CA3, DG, VISl, ..."
719161530,2019-10-03T00:00:00Z,703279284,brain_observatory_1.1,122.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,755,2214,6,"[TH, Eth, APN, POL, LP, DG, CA1, VISpm, nan, N..."
721123822,2019-10-03T00:00:00Z,707296982,brain_observatory_1.1,125.0,M,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,444,2229,6,"[MB, SCig, PPT, NOT, DG, CA1, VISam, nan, LP, ..."
732592105,2019-10-03T00:00:00Z,717038288,brain_observatory_1.1,100.0,M,wt/wt,824,1847,5,"[grey, VISpm, nan, VISp, VISl, VISal, VISrl]"
737581020,2019-10-03T00:00:00Z,718643567,brain_observatory_1.1,108.0,M,wt/wt,568,2218,6,"[grey, VISmma, nan, VISpm, VISp, VISl, VISrl]"


In [5]:
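# # Download the NWB file of every filtered session into the cache under src_directory.
# # The cells below read these files from there, so uncomment and run this once.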
# 
# for session_index in tqdm(filtered_sessions.index.values):
#     cache.get_session_data(
#         session_index, isi_violations_maximum = np.inf,
#         amplitude_cutoff_maximum = np.inf,
#         presence_ratio_minimum = -np.inf
#     )

# Read, reformat and export session-wise data

In [6]:
import h5py # required for reading the nwb file

# Units-table dataset names, keyed by the short attribute names used in this
# notebook. Some names changed between NWB versions, hence the fallback table.
attr_names = dict(
    id='id',
    channel='channel',
    snr='snr',
    d_prime='d_prime',
    isi_viol='isi_viol',
    region='ccf_structure',
    probe='probe',
    hpos='xpos_probe',
    vpos='ypos_probe',
)

# Fallback dataset names, tried when the primary name is missing from 'units'.
# 'region', 'probe', 'hpos' and 'vpos' point at the unit 'id' dataset only as
# placeholders: _read_activity later overwrites those columns with values
# looked up in the electrodes table.
attr_names_alt = dict(
    channel='peak_channel_id',
    isi_viol='isi_violations',
    region='id',
    probe='id',
    hpos='id',
    vpos='id',
)

def _as_str(values):
    """Return `values` as a numpy str array, decoding bytes elements as UTF-8.

    h5py >= 3 reads variable-length HDF5 strings as ``bytes``; ``.astype(str)``
    on such an array yields text like "b'VISp'" instead of "VISp".
    """
    return np.array(
        [v.decode('utf-8') if isinstance(v, bytes) else str(v) for v in values],
        dtype=str,
    )


def _read_units(sf, rs_min_waveform_duration):
    """Build the unit metadata table, with position/probe/region taken from the electrodes table."""
    grp_units = sf['units']
    attrs = [
        'channel', 'id', 'snr', 'd_prime', 'isi_viol',
        'region', 'probe', 'hpos', 'vpos'
    ]
    columns = {}
    for attr in attrs:
        # Dataset names differ between NWB versions: try the primary name, then the fallback.
        try:
            columns[attr] = grp_units[attr_names[attr]][:]
        except KeyError:
            columns[attr] = grp_units[attr_names_alt[attr]][:]
    units = pd.DataFrame(columns)  # RangeIndex: row position in the NWB units table
    units['RS'] = grp_units['waveform_duration'][:] > rs_min_waveform_duration

    electrodes = sf['general/extracellular_ephys/electrodes']
    cols = ['hpos', 'vpos', 'probe', 'region']
    probes = pd.DataFrame(
        {
            'hpos': electrodes['probe_horizontal_position'][:],
            'vpos': electrodes['probe_vertical_position'][:],
            'probe': _as_str(electrodes['probe_id'][:]),
            'region': _as_str(electrodes['location'][:]),
        },
        index=electrodes['id'][:],
    )
    # Look up each unit's channel in the electrodes table and overwrite the
    # placeholder columns with the electrode's position, probe and region.
    electrode_info = probes.loc[units['channel'], cols]
    electrode_info.index = units.index
    units[cols] = electrode_info
    return units


def _read_spike_times(sf, unit_index):
    """Split the concatenated spike-time vector into one array per unit label."""
    # spike_times_index[i] is the exclusive end offset of unit i's slice.
    ends = sf['units/spike_times_index'][:]
    all_times = sf['units/spike_times'][:]
    spike_times = {}
    start = 0
    for uid, end in zip(unit_index, ends):
        spike_times[uid] = all_times[start:end]
        start = end
    return spike_times


def _read_stim_table(sf):
    """Collect presentation onsets from every interval table except 'invalid_times', sorted by time."""
    tables = []
    for stim in sf['intervals/'].keys():
        if stim == 'invalid_times':
            continue
        start_times = sf[f'intervals/{stim}/start_time'][:]
        stimulus_names = sf[f'intervals/{stim}/stimulus_name'][:]
        try:
            stimulus_blocks = sf[f'intervals/{stim}/stimulus_block'][:]
        except KeyError:
            # Tables without a stimulus_block dataset are marked with block -1.
            stimulus_blocks = -np.ones(len(start_times))
        tables.append(pd.DataFrame(
            data=[start_times, stimulus_names, stimulus_blocks],
            index=['time', 'stimulus_name', 'block'],
        ).T)
    if not tables:
        return pd.DataFrame(columns=['time', 'stimulus_name', 'block'])
    return pd.concat(tables).sort_values('time').reset_index(drop=True)


def _read_running(sf):
    """Running speed as a DataFrame with columns 'times' and 'running_speed'."""
    return pd.Series(
        data=sf['processing/running/running_speed/data'][:],
        index=sf['processing/running/running_speed_end_times/timestamps'][:],
        name='running_speed',
    ).rename_axis('times').reset_index()


def _read_activity(src_file, rs_min_waveform_duration=0.4):
    """Read unit metadata, spike times, running speed and the stimulus table from an NWB file.

    Parameters
    ----------
    src_file : str
        Path to a session NWB file.
    rs_min_waveform_duration : float, optional
        Units whose 'waveform_duration' is strictly greater than this value get
        ``RS=True``. Default 0.4 (presumably ms -- TODO confirm).

    Returns
    -------
    units : pandas.DataFrame
        One row per unit, indexed by row position (0..n-1) in the NWB units
        table; the NWB unit id is in column 'id'.
    spike_times : dict
        Maps each ``units`` index label to that unit's spike-time array.
    running : pandas.DataFrame
        Columns 'times' and 'running_speed'.
    stim_table : pandas.DataFrame
        Columns 'time', 'stimulus_name', 'block', sorted by 'time'.

    Raises
    ------
    ValueError
        If the file has a top-level 'nwb_version' entry. That layout is not
        handled; previously the function returned None, which failed later at
        the caller's tuple unpacking.
    """
    with h5py.File(src_file, mode='r') as sf:
        if 'nwb_version' in sf.keys():
            raise ValueError(
                f'{src_file}: unsupported NWB layout (top-level "nwb_version" entry)'
            )
        units = _read_units(sf, rs_min_waveform_duration)
        spike_times = _read_spike_times(sf, units.index)
        stim_table = _read_stim_table(sf)
        running = _read_running(sf)
    return units, spike_times, running, stim_table

def get_firing_rates(
    spike_times, sampling_rate=200,
    win=np.exp(-(np.arange(11)-5)**2/4)
):
    """Bin spike trains and smooth them with a kernel to estimate firing rates.

    Parameters
    ----------
    spike_times : dict
        Maps a unit label to a 1-D array of spike times (assumed seconds).
        Output rows follow the dict's iteration order.
    sampling_rate : float, optional
        Bins per unit time (default 200, i.e. 5 ms bins for times in seconds).
    win : numpy.ndarray, optional
        Smoothing kernel. It is rescaled by ``sampling_rate / win.sum()``, so the
        trace of one isolated spike integrates to one spike. Default: an 11-tap
        Gaussian with sigma = sqrt(2) bins.

    Returns
    -------
    data : numpy.ndarray of uint8, shape (n_units, n_bins)
        Smoothed rates, truncated to integers and clipped to [0, 255].
        Several spikes falling in the same bin are counted once.
    times : numpy.ndarray, shape (n_bins,)
        Start time of each bin, ``arange(n_bins) / sampling_rate``.
    """
    # Recording length: last spike over all units plus one time unit of padding.
    maxtime = max(
        (st.max() if len(st) > 0 else 0 for st in spike_times.values()),
        default=0,
    ) + 1
    n_bins = int(np.rint(maxtime * sampling_rate))
    scale = sampling_rate / win.sum()

    data = np.zeros((len(spike_times), n_bins), dtype='uint8')
    for i, st in enumerate(spike_times.values()):
        # Truncation maps a spike at time t to bin floor(t * sampling_rate).
        st_int = np.array(st * sampling_rate, dtype=int)
        binned = np.zeros(n_bins, dtype='uint8')
        binned[st_int] = 1  # spikes sharing a bin still count as one
        smoothed = scale * np.convolve(binned, win, mode='same')
        # Saturate rather than let values above 255 wrap around in the uint8 cast.
        data[i] = np.clip(smoothed, 0, 255).astype('uint8')

    # Label each bin with its exact start time; linspace(0, maxtime, n_bins)
    # drifts from the true bin edges when maxtime * sampling_rate is not integral.
    times = np.arange(n_bins) / sampling_rate
    return data, times

---

In [31]:
# How often does a unit fire twice within 5 ms? get_firing_rates bins spikes at
# 200 Hz (5 ms bins) by default and marks each bin at most once, so the share of
# inter-spike intervals shorter than 5 ms is an upper bound on how often that
# binning merges spikes.
ISI_LIMIT_MS = 5   # interval threshold; spike times are assumed to be in seconds
MIN_SNR = 2.5      # minimum unit SNR for inclusion

means = []
for session_index in tqdm(filtered_sessions.index.values):
    source = path.join(src_directory, f'session_{session_index}', f'session_{session_index}.nwb')
    units, spike_times, running, stim_table = _read_activity(source)

    unit_ids = units[(units.RS) & (units.snr > MIN_SNR) & (units.region.isin(ctx_regions))].index
    n_viol = []
    for spt in unit_ids:
        dt = np.diff(spike_times[spt])
        if len(dt) == 0:
            continue  # fewer than two spikes: no interval to test (would divide by zero)
        n_viol.append(np.mean(dt < ISI_LIMIT_MS / 1000) * 100)
    if not n_viol:
        # A NaN session mean would poison the summary and break the histogram below.
        print(f'{session_index}: no qualifying cortical units, skipped')
        continue
    print(f'{session_index}: {np.mean(n_viol):.2f} +- {np.std(n_viol):.2f} % spikes violate {ISI_LIMIT_MS} ms interval')
    means.append(np.mean(n_viol))

print(f'{np.mean(means):.2f} +- {np.std(means):.2f} % of spikes per neuron violate {ISI_LIMIT_MS} ms condition.')
fig, ax = plt.subplots(figsize=(4, 3))
ax.hist(means, bins=20)
ax.axvline(np.mean(means), color='k', lw=0.5)
ax.annotate('mean across\nsessions', (2.2, 8))
ax.set_ylabel('# sessions')
ax.set_xlabel('mean % of violations across cortical neurons');

---

In [8]:
# Export one set of pickles per session. A session is skipped when its fr_ file
# already exists; fr_ is written last, so an interrupted session is redone.
# NOTE(review): assign_approx_layers is not defined anywhere in this notebook;
# it must be provided (earlier cell or imported module) before this cell runs.
for session_index in tqdm(filtered_sessions.index.values):
    fr_path = f'{data_directory}/fr_{session_index}.pkl'
    if path.exists(fr_path):
        continue

    source = path.join(src_directory, f'session_{session_index}', f'session_{session_index}.nwb')
    units, spike_times, running, stim_table = _read_activity(source)

    # Layers are assigned only within visual-cortex regions; the left join
    # presumably leaves other units with a missing 'layer' value.
    layers = (
        units
        .groupby('region')
        .apply(lambda grp: assign_approx_layers(grp) if grp.name in ctx_regions else None)
        .dropna()
        .droplevel(0)
        .rename('layer')
    )
    units = units.join(layers)

    rates, times = get_firing_rates(spike_times)
    fr = pd.DataFrame(rates.T, index=times)

    # Order matters: the fr_ file acts as the "done" marker, so it goes last.
    outputs = (
        ('units', units),
        ('running', running),
        ('stimulus', stim_table),
        ('fr', fr),
    )
    for prefix, frame in outputs:
        frame.to_pickle(
            f'{data_directory}/{prefix}_{session_index}.pkl'
        )

HBox(children=(IntProgress(value=0, max=58), HTML(value='')))




In [59]:
# def process_units(df, session=None):
#     layers = df.groupby('ecephys_structure_acronym').apply(
#         lambda _df: assign_approx_layers(_df) if _df.name in ctx_regions else None
#     ).dropna().droplevel(0).rename('layer')
#     df = df.join(layers)
    
#     ct = pd.cut(
#         df.waveform_duration, bins=[0, 0.4, 100],
#         right=False, include_lowest=True, labels=[False, True]
#     ).rename('RS')
#     df = df.join(ct.astype(bool))
    
#     df = df.rename({'ecephys_structure_acronym':'region'}, axis=1)
#     if session is not None:
#         df.to_pickle(f'{data_directory}/units_{session}.pkl')
#         return 1
#     return df

# units = cache.get_units()
# units.groupby('ecephys_session_id').apply(lambda df: process_units(df, df.name))

# up = process_units(units[units.ecephys_session_id==session_index])
# up

# Exploring a single-session dataset

In [56]:
# Inspect the HDF5 layout of the last session file opened above (`source` is
# left over from the session loop, so run that cell first).
# Other paths worth inspecting (groups: .keys(); datasets: [:]):
#   units, units/waveform_duration, units/peak_channel_id,
#   processing, processing/running, processing/running/running_speed/data,
#   processing/eye_tracking_rig_metadata/eye_tracking_rig_metadata,
#   general/extracellular_ephys/electrodes
#       (/id, /location, /probe_vertical_position),
#   processing/stimulus/timestamps,
#   intervals, intervals/invalid_times, intervals/invalid_times/start_time
with h5py.File(source, mode='r') as nwb_file:
    general_group = nwb_file['general']
    display(general_group.keys())

<KeysViewHDF5 ['devices', 'extracellular_ephys', 'institution', 'session_id', 'stimulus', 'subject']>

{20,
 40,
 60,
 80,
 100,
 120,
 140,
 160,
 180,
 200,
 220,
 240,
 260,
 280,
 300,
 320,
 340,
 360,
 380,
 400,
 420,
 440,
 460,
 480,
 500,
 520,
 540,
 560,
 580,
 600,
 620,
 640,
 660,
 680,
 700,
 720,
 740,
 760,
 780,
 800,
 820,
 840,
 860,
 880,
 900,
 920,
 940,
 960,
 980,
 1000,
 1020,
 1040,
 1060,
 1080,
 1100,
 1120,
 1140,
 1160,
 1180,
 1200,
 1220,
 1240,
 1260,
 1280,
 1300,
 1320,
 1340,
 1360,
 1380,
 1400,
 1420,
 1440,
 1460,
 1480,
 1500,
 1520,
 1540,
 1560,
 1580,
 1600,
 1620,
 1640,
 1660,
 1680,
 1700,
 1720,
 1740,
 1760,
 1780,
 1800,
 1820,
 1840,
 1860,
 1880,
 1900,
 1920,
 1940,
 1960,
 1980,
 2000,
 2020,
 2040,
 2060,
 2080,
 2100,
 2120,
 2140,
 2160,
 2180,
 2200,
 2220,
 2240,
 2260,
 2280,
 2300,
 2320,
 2340,
 2360,
 2380,
 2400,
 2420,
 2440,
 2460,
 2480,
 2500,
 2520,
 2540,
 2560,
 2580,
 2600,
 2620,
 2640,
 2660,
 2680,
 2700,
 2720,
 2740,
 2760,
 2780,
 2800,
 2820,
 2840,
 2860,
 2880,
 2900,
 2920,
 2940,
 2960,
 2980,
 3000,
 30