In [1]:
import os
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sp
import scipy.signal  # explicit submodule import so that `sp.signal` is always available
import matplotlib.pyplot as plt
%matplotlib inline

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache
from allensdk.brain_observatory.ecephys.visualization import plot_spike_counts, raster_plot

# Do not truncate columns when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

# Example cache directory: downloaded data is stored under this path.
output_dir = './ecephys_cache_dir/'

# The manifest file sits inside the cache directory and is handed to the project cache.
manifest_path = os.path.join(output_dir, "manifest.json")
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path)

In [2]:
# Structure acronym used below to select units, channels and LFP files.
ecephys_structure_acronyms = 'VISp'
# Session to analyze (797828357 was noted next to it, presumably as an alternative).
session_id = 721123822

In [3]:
# Fetch the session, then keep only the units whose structure matches the one of interest.
session = cache.get_session_data(session_id)
sel_units = session.units.loc[session.units['ecephys_structure_acronym'].eq(ecephys_structure_acronyms)]

In [4]:
# Find the unit closest to the centroid of the selected units (in CCF coordinates).
ccf_coords = ['anterior_posterior_ccf_coordinate', 'dorsal_ventral_ccf_coordinate', 'left_right_ccf_coordinate']
units_coord = sel_units[ccf_coords].values
# NaN-aware statistics: with plain mean/argmin a single missing coordinate turns
# every distance into NaN and argmin silently returns the first unit.
# If every unit lacks coordinates, nanargmin raises ValueError instead.
units_coord_mean = np.nanmean(units_coord, axis=0)
center_unit_id = sel_units.index[np.nanargmin(np.sum((units_coord - units_coord_mean) ** 2, axis=1))]

# Channel (position on the probe) and probe of the center unit.
channel_index = sel_units.loc[center_unit_id, 'probe_channel_number']
probe_id = sel_units.loc[center_unit_id, 'probe_id']

# Resolve (probe, channel number) to the id used as index of the channels table.
channel_id = session.channels[(session.channels.probe_channel_number == channel_index) &
                              (session.channels.probe_id == probe_id)].index[0]

In [5]:
# Table of stimulus presentations of this session.
stimulus_presentations = session.stimulus_presentations
# LFP sampling rate of the probe that holds the center unit.
fs = session.probes.loc[probe_id]['lfp_sampling_rate']

## Load data

In [6]:
# Channel-group LFP file of this session / probe / structure inside the cache directory.
filepath = os.path.join(
    output_dir,
    'session_%d' % session_id,
    'lfp_probe%d_%s_channel_groups.nc' % (probe_id, ecephys_structure_acronyms),
)
# Attach the sampling rate as an attribute and expose the group dimension as `channel`.
lfp_array = xr.open_dataset(filepath).assign_attrs(fs=fs).rename({'group_id': 'channel'})
lfp_array

In [7]:
# The channel-to-group table shares the LFP file's name, with a `.csv` extension.
# Only the extension is swapped: `str.replace('.nc', '.csv')` would also rewrite
# any other '.nc' occurring earlier in the path (e.g. in a directory name).
channel_group_map = pd.read_csv(os.path.splitext(filepath)[0] + '.csv', index_col='id')
# Dorsal-ventral CCF coordinate of every channel group.
group_dv_ccf = dict(zip(channel_group_map['group_id'], channel_group_map['dorsal_ventral_ccf_coordinate']))
# Channel group that contains the channel of the center unit.
group_id = channel_group_map.loc[channel_id, 'group_id']
channel_group_map

Unnamed: 0_level_0,group_id,anterior_posterior_ccf_coordinate,dorsal_ventral_ccf_coordinate,left_right_ccf_coordinate
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
850245925,7,8545.625000,1435.500000,8019.625000
850245931,7,8545.625000,1435.500000,8019.625000
850245937,7,8545.625000,1435.500000,8019.625000
850245943,6,8575.875000,1347.250000,8024.875000
850245949,6,8575.875000,1347.250000,8024.875000
...,...,...,...,...
850246077,0,8844.181818,583.363636,8054.818182
850246079,0,8844.181818,583.363636,8054.818182
850246081,0,8844.181818,583.363636,8054.818182
850246083,0,8844.181818,583.363636,8054.818182


In [8]:
# Per-unit PCA table for the chosen stimulus, indexed by unit id.
stimulus_name = 'drifting_gratings'
filepath = os.path.join(
    output_dir,
    'session_%d' % session_id,
    'lfp_probe%d_%s_%s_units_pca.csv' % (probe_id, ecephys_structure_acronyms, stimulus_name),
)
pca_df = pd.read_csv(filepath, index_col='unit_id')

## Analyze spike entrainment to LFP

In [9]:
def align_trials(lfp_array, presentation_ids, onset_times, window=(0., 1.)):
    """Cut a continuous recording into trial segments aligned to stimulus onset.

    Parameters
    ----------
    lfp_array : object with an ``fs`` attribute and a ``time`` coordinate
        The recording (an xarray dataset in this notebook). ``fs`` is used as
        the sampling rate that defines the spacing of the trial time grid.
    presentation_ids : sequence
        One label per trial, in the same order as ``onset_times``.
    onset_times : sequence of float
        Onset time of every trial, on the scale of the ``time`` coordinate.
    window : (float, float)
        Start and stop of the segment relative to onset; the stop is excluded.

    Returns
    -------
    The recording with ``time`` replaced by the two dimensions
    ``presentation_id`` and ``time_from_presentation_onset``. Samples are
    picked by nearest-neighbour look-up, so the relative time coordinate is the
    nominal grid rather than the exact sample time.
    """
    start, stop = window[0], window[1]
    relative_times = np.arange(start, stop, 1 / lfp_array.fs)
    # Absolute sample times of all trials, trial after trial.
    absolute_times = np.concatenate([relative_times + onset for onset in onset_times])
    trial_index = pd.MultiIndex.from_product(
        (presentation_ids, relative_times),
        names=('presentation_id', 'time_from_presentation_onset'))
    nearest_samples = lfp_array.sel(time=absolute_times, method='nearest')
    # Relabel the flat time axis with (trial, relative time) and split it in two dimensions.
    return nearest_samples.assign(time=trial_index).unstack('time')

def align_gratings(stimulus_presentations, stimulus_name='drifting_gratings'):
    """Select the presentations of one stimulus and drop its null trials.

    Parameters
    ----------
    stimulus_presentations : pandas.DataFrame
        Presentation table with the columns ``stimulus_name``, ``orientation``,
        ``stimulus_condition_id``, ``start_time`` and ``duration``.
    stimulus_name : str
        Name of the stimulus to keep.

    Returns
    -------
    presentations : pandas.DataFrame
        Rows of the requested stimulus, without the null conditions.
    presentations_ids : numpy.ndarray
        Index values of ``presentations``.
    presentations_times : numpy.ndarray
        ``start_time`` of every kept presentation.
    trial_duration : scalar
        Longest ``duration`` among the kept presentations (NaN if none is left).
    """
    presentations = stimulus_presentations[stimulus_presentations.stimulus_name == stimulus_name]
    # Compare through pandas: the result is always element-wise, also when the
    # column is purely numeric (then simply no row is flagged).
    is_null = presentations['orientation'] == 'null'
    if is_null.any():
        # Drop every condition that has a null orientation, not only the first one found.
        null_conditions = presentations.loc[is_null, 'stimulus_condition_id'].unique()
        presentations = presentations[~presentations['stimulus_condition_id'].isin(null_conditions)]

    presentations_times = presentations['start_time'].values
    presentations_ids = presentations.index.values
    trial_duration = presentations['duration'].max()
    return presentations, presentations_ids, presentations_times, trial_duration

def presentation_conditions(presentations, condtion_types):
    """Arrange stimulus condition ids on a grid of condition values.

    Parameters
    ----------
    presentations : pandas.DataFrame
        Presentation table holding one column per entry of ``condtion_types``
        (values convertible to float) and a ``stimulus_condition_id`` column.
    condtion_types : sequence of str
        Names of the condition columns; any number of them is accepted.

    Returns
    -------
    condition_id : xarray.DataArray
        Condition id for every combination of the unique condition values,
        with one dimension per condition type.
    cond_presentation_id : dict
        Maps each condition id to the index labels of its presentations.

    Raises
    ------
    KeyError
        If some combination of condition values was never presented.
    """
    conditions = {c: np.unique(presentations[c]).astype(float) for c in condtion_types}
    columns = list(conditions)
    # Keys are cast to float like the grid values, so the look-up does not
    # depend on how the values are stored in the table.
    cond_id_map = {
        tuple(map(float, values)): cond_id
        for values, cond_id in zip(presentations[columns].values,
                                   presentations['stimulus_condition_id'])
    }
    # Every combination of condition values, in 'ij' (row-major) order.
    grid = np.meshgrid(*conditions.values(), indexing='ij')
    condition_id = [cond_id_map[combo] for combo in zip(*map(np.ravel, grid))]
    shape = tuple(len(values) for values in conditions.values())
    condition_id = xr.DataArray(np.reshape(condition_id, shape), coords=conditions,
                                dims=columns, name='condition_id')
    cond_presentation_id = {c: presentations.index[presentations['stimulus_condition_id'] == c]
                            for c in condition_id.values.ravel()}
    return condition_id, cond_presentation_id

In [10]:
stimulus_name = 'drifting_gratings'
(drifting_gratings_presentations,
 grating_ids,
 grating_times,
 grating_duration) = align_gratings(session.stimulus_presentations, stimulus_name=stimulus_name)

# Grid of condition ids over orientation x temporal frequency, and the trials of each condition.
condition_id, cond_presentation_id = presentation_conditions(
    drifting_gratings_presentations,
    condtion_types=['orientation', 'temporal_frequency'],
)

# Trial segments padded by 0.2 before onset and 0.2 after the longest trial duration.
aligned_lfp = align_trials(
    lfp_array, grating_ids, grating_times,
    window=(-0.2, grating_duration + 0.2),
)

### Get LFP phase

In [11]:
# Pass band of the Butterworth filter, in the frequency units of `aligned_lfp.fs`.
filt_band = [20., 40.]
bfilt, afilt = sp.signal.butter(4, filt_band, btype='bandpass', fs=aligned_lfp.fs)
# Filter and Hilbert-transform along the within-trial time axis.
axis = aligned_lfp.LFP.dims.index('time_from_presentation_onset')
# `hilbert` defaults to the last axis, so the time axis must be passed explicitly;
# otherwise the result is only right when time happens to be the last dimension.
analytic = sp.signal.hilbert(sp.signal.filtfilt(bfilt, afilt, aligned_lfp.LFP, axis=axis), axis=axis)
# Store envelope and instantaneous phase next to the LFP, with identical dims/coords.
aligned_lfp = aligned_lfp.assign(amplitude=aligned_lfp.LFP.copy(data=np.abs(analytic)))
aligned_lfp = aligned_lfp.assign(phase=aligned_lfp.LFP.copy(data=np.angle(analytic)))

In [12]:
# Display the trial-aligned dataset, which now also carries `amplitude` and `phase`.
aligned_lfp

### Get unit channels

In [13]:
# Channels that lie in the structure of interest.
channels = session.channels.loc[session.channels['structure_acronym'].eq(ecephys_structure_acronyms)]
# Look-up from a channel's number on the probe to its id (the index of the channels table).
probe_channel_number_to_id = dict(zip(channels['probe_channel_number'], channels.index))
# Channel group of every selected unit, indexed by unit id.
unit_channel = pd.Series(
    [channel_group_map.loc[probe_channel_number_to_id[number], 'group_id']
     for number in sel_units['probe_channel_number']],
    index=sel_units.index,
)

### Get unit spike times

In [14]:
def get_spike_phase(spike_times, aligned_lfp, unit_channel):
    """Sum the LFP phase vectors at the spike times of every unit and trial.

    Parameters
    ----------
    spike_times : pandas.DataFrame
        One row per spike with the columns ``unit_id``,
        ``stimulus_presentation_id`` and
        ``time_since_stimulus_presentation_onset``.
    aligned_lfp : xarray.Dataset
        Trial-aligned dataset with a ``phase`` variable that can be selected by
        ``channel``, ``presentation_id`` and ``time_from_presentation_onset``.
    unit_channel : pandas.Series
        Channel label of every unit, indexed by unit id.

    Returns
    -------
    xarray.Dataset
        Dimensions ``unit_id`` x ``presentation_id`` with two variables:
        ``resultant_phase``, the complex sum of ``exp(1j * phase)`` over the
        spikes of a unit in a trial (0 when there is no spike), and
        ``spike_number``, the number of spikes.

    Raises
    ------
    KeyError
        If a spike refers to a unit or presentation that is not in
        ``unit_channel`` / ``aligned_lfp``.
    """
    unit_ids = unit_channel.index
    presentation_ids = aligned_lfp.presentation_id.to_index()
    shape = (unit_ids.size, presentation_ids.size)

    # Pre-allocate the (unit, presentation) grid of spike lists. Converting
    # nested lists with np.array(..., dtype=object) would yield a 3-D array
    # whenever all lists have the same length (e.g. when all are empty).
    spike_trains = np.empty(shape, dtype=object)
    for cell in np.ndindex(shape):
        spike_trains[cell] = []
    for row in spike_times.itertuples(index=False):
        i = unit_ids.get_loc(row.unit_id)
        j = presentation_ids.get_loc(row.stimulus_presentation_id)
        spike_trains[i, j].append(row.time_since_stimulus_presentation_onset)

    resultant_phase = np.zeros(shape, dtype=complex)
    spike_number = np.zeros(shape, dtype=int)
    for i, u in enumerate(unit_ids):
        # Phase on the channel assigned to this unit.
        unit_phase = aligned_lfp.phase.sel(channel=unit_channel[u])
        for j, p in enumerate(presentation_ids):
            spk_train = spike_trains[i, j]
            spike_number[i, j] = len(spk_train)
            if not spk_train:
                continue  # no spike: the resultant stays 0, no look-up needed
            # Phase at the sample nearest to each spike time.
            phase = unit_phase.sel(presentation_id=p).sel(time_from_presentation_onset=spk_train, method='nearest')
            resultant_phase[i, j] = np.sum(np.exp(1j * phase))

    spike_phase = xr.DataArray(resultant_phase, name='resultant_phase',
                               dims=('unit_id', 'presentation_id'),
                               coords={'unit_id': unit_ids, 'presentation_id': presentation_ids}
                               ).to_dataset(name='resultant_phase')
    spike_phase = spike_phase.assign(spike_number=spike_phase.resultant_phase.copy(data=spike_number))
    return spike_phase

In [15]:
# Spike times of the selected units during the grating presentations.
spike_times = session.presentationwise_spike_times(
    stimulus_presentation_ids=grating_ids,
    unit_ids=sel_units.index,
)
# Per unit and trial: resultant phase vector and spike count.
spike_phase = get_spike_phase(spike_times, aligned_lfp, unit_channel)

In [16]:
# Display the dataset returned by `get_spike_phase`.
spike_phase

In [19]:
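# # Save data (disabled)
# # NOTE(review): the original line saved `spike_trains_ds`, which is not defined in this
# # notebook; the dataset computed above is `spike_phase`. Its `resultant_phase` variable is
# # complex-valued - confirm the netCDF backend can store it before enabling this cell.
# filepath = os.path.join(output_dir, 'session_%d' % session_id,
#                         'lfp_probe%d_%s_%s_units_plv.nc' % (probe_id, ecephys_structure_acronyms, stimulus_name))
# spike_phase.to_netcdf(filepath) # save per-unit, per-trial spike phase results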