In [1]:
import json
import numpy as np
import pandas as pd
import xarray as xr

import add_path
import toolkit.allen_helpers.stimuli as st
import toolkit.pipeline.signal as ps
from toolkit.analysis.signal import bandpass_power, instantaneous_power
from toolkit.pipeline.data_io import SessionDirectory, FILES
from toolkit.pipeline.global_settings import GLOBAL_SETTINGS

pd.set_option('display.max_columns', None)  # show every column when displaying DataFrames

# Run parameters for this notebook (e.g. session_id, ecephys_structure_acronym).
# JSON text is UTF-8 (RFC 8259); give the encoding explicitly so reading the file
# does not depend on the platform's locale default encoding.
with open('config.json', encoding='utf-8') as f:
    config = json.load(f)

## Get session and load data

In [2]:
# Session and brain structure to analyze, both taken from config.json.
session_id, ecephys_structure_acronym = config['session_id'], config['ecephys_structure_acronym']
print(f"Session ID: {session_id}")

Session ID: 721123822


In [3]:
# Open the session directory for the configured structure, with LFP caching enabled.
session_dir = SessionDirectory(session_id, ecephys_structure_acronym, cache_lfp=True)
probe_info = session_dir.load_probe_info()

# Everything below needs LFP data, so stop here if the session has none.
if session_dir.has_lfp_data:
    session = session_dir.session
else:
    raise ValueError(f"Session {session_id} has no LFP data")

core - cached version: 2.2.2, loaded version: 2.7.0
  self.warn_for_ignored_namespaces(ignored_namespaces)


In [4]:
stimulus_presentations = session.stimulus_presentations
session_type = session.session_type

# Stimulus names per category for this session type.
session_stimulus_categories = st.STIMULUS_CATEGORIES[session_type]
drifting_gratings_stimuli = session_stimulus_categories['drifting_gratings']
natural_movies_stimuli = session_stimulus_categories['natural_movies']

## Analyze data

### Get frequency bands

In [5]:
wave_bands = ['beta', 'gamma']  # wave bands analyzed in this notebook (order matters for concat below)

# Analysis-wide settings shared across notebooks.
layer_of_interest, instantaneous_band = (
    GLOBAL_SETTINGS['layer_of_interest'],
    GLOBAL_SETTINGS['instantaneous_band'],
)
# Precomputed bands of interest, indexed by session_id and wave_band.
bands_of_interest = FILES.load('bands_of_interest')

In [6]:
def find_band_in_layers(bands_ds, layer_of_interest):
    """Return the frequency band in `layer_of_interest`, falling back to other layers.

    Parameters
    ----------
    bands_ds
        Wave-band dataset already selected to one stimulus and one wave band
        (presumably an ``xarray.Dataset``). It must provide a ``bands`` variable
        with a ``layer`` dimension and a scalar ``wave_band`` coordinate.
    layer_of_interest
        Layer to look in first.

    Returns
    -------
    The band in `layer_of_interest` when it is fully defined (no NaN edges).
    Otherwise, the result of ``ps.get_band_with_highest_peak(bands_ds)``,
    which is None when no layer has the band.
    """
    band = bands_ds.bands.sel(layer=layer_of_interest)
    # The band values are later used as bandpass edges, so a band with ANY NaN
    # edge is unusable. Treat a partially defined band like a missing one.
    if np.isnan(band).any():
        wave_band = bands_ds.wave_band.item().title()
        print(f"{wave_band} band not found in the layer of interest '{layer_of_interest}'. "
            "Trying to find them in other layers.")
        band = ps.get_band_with_highest_peak(bands_ds)
        if band is None:
            print(f"{wave_band} band not found in any layer.")
        else:
            print(f"{wave_band} band found in layer '{band.layer.item()}'.")
    return band

In [7]:
if session_id in bands_of_interest.session_id:
    # Precomputed bands of interest exist for this session: select them directly.
    freq_bands = bands_of_interest.sel(wave_band=wave_bands, session_id=session_id)
else:
    print("Warning: Bands of interest not found in the PSD of this session. "
        f"Trying to find them in layers other than the layer of interest '{layer_of_interest}'.")
    bands_ds = session_dir.load_wave_bands()
    # Stimulus whose wave bands define each band of interest.
    band_stimulus = {
        'beta': drifting_gratings_stimuli[0] + '_filtered',
        'gamma': natural_movies_stimuli[0],
    }
    # First attempt: this session's own wave bands (layer of interest, then other layers).
    found_bands = {
        band_name: find_band_in_layers(bands_ds.sel(stimulus=stim_name, wave_band=band_name), layer_of_interest)
        for band_name, stim_name in band_stimulus.items()
    }
    missing_bands = [band_name for band_name, band in found_bands.items() if band is None]
    if missing_bands:
        # Second attempt: average wave bands over the selected sessions of this session type.
        average_wave_bands = FILES.load('average_wave_bands', session_type=session_type, session_set='selected_sessions')
        for band_name in missing_bands:
            print(f"{band_name.title()} band not found in the session. "
                "Trying to find the average bands from selected sessions.")
            fallback_band = find_band_in_layers(
                average_wave_bands.sel(stimulus=band_stimulus[band_name], wave_band=band_name),
                layer_of_interest)
            if fallback_band is None:
                raise ValueError(f"No {band_name} band found.")
            found_bands[band_name] = fallback_band
    freq_bands = xr.concat(list(found_bands.values()), dim='wave_band')

### Process LFP

In [None]:
# LFP processing parameters.
group_width = 1  # Number of channels to the left and right of the central channel
extend_time = 1.0  # extend time at the start and end of each block to avoid boundary effect for filtering
filter_instantaneous_power = True  # whether to filter before calculating instantaneous power

# Group LFP channels around each central channel of the probe.
lfp_groups, _ = ps.get_lfp_channel_groups(
    session_dir,
    probe_info['central_channels'],
    probe_id=probe_info['probe_id'],
    width=group_width,
)

# Output: one power dataset per stimulus name.
lfp_power_dss = dict()

#### Drifting gratings

In [9]:
stim = drifting_gratings_stimuli[0]  # first drifting grating stimulus
# Window offsets relative to trial onset: the aligned window is
# (window[0], trial duration + window[1]), in the same units as the trial duration.
drifting_gratings_window = (-0.5, 0.5)
# Required: right boundary < extend_time (extend_time pads each block for filtering).
# Enforce it so that editing either value cannot silently break the constraint.
assert drifting_gratings_window[1] < extend_time, \
    "right boundary of drifting_gratings_window must be smaller than extend_time"

In [None]:
stimulus_trials = st.get_stimulus_trials(stimulus_presentations, stim)
conditions = st.presentation_conditions(stimulus_trials.presentations)
# Window relative to trial onset: starts at window[0] and ends window[1] after the trial duration.
window = (drifting_gratings_window[0], stimulus_trials.duration + drifting_gratings_window[1])
aligned_lfp, valid_trials = st.align_trials(
    lfp_groups, stimulus_trials, window=window, ignore_nan_trials='any')

if valid_trials is not None:
    # Some trials were dropped because of NaN values. Every condition must keep at
    # least one trial; then restrict the per-condition entry of `conditions` to the kept trials.
    kept_condition_ids = st.presentation_conditions(valid_trials.presentations)[1]
    if len(kept_condition_ids) != len(conditions[1]):
        lost_conditions = set(conditions[1].keys()) - set(kept_condition_ids.keys())
        raise ValueError(f"All trials are dropped by NaN values in {stim} for conditions: {lost_conditions}")
    conditions = (conditions[0], kept_condition_ids)
else:
    # align_trials returned None: no trial was dropped, keep them all.
    valid_trials = stimulus_trials
valid_blocks = st.get_stimulus_blocks(valid_trials)

In [None]:
# Band-limited power for each wave band, filtered per block and aligned to the valid trials.
lfp_bands_power = []
for wave_band in wave_bands:
    block_power = bandpass_power(ps.bandpass_filter_blocks(
        lfp_groups, valid_blocks,
        freq_bands.sel(wave_band=wave_band).values,
        extend_time=extend_time,
        include_filtered=False,
        include_amplitude=True
    ))
    lfp_bands_power.append(st.align_trials(block_power, valid_trials, window=window, ignore_nan_trials='')[0])
lfp_bands_power = xr.concat(lfp_bands_power, dim=pd.Index(wave_bands, name='wave_band'), combine_attrs='drop_conflicts')

# Input for instantaneous power. Kept under its own name so `aligned_lfp` from the
# alignment cell is not overwritten. Otherwise, re-running this cell after toggling
# `filter_instantaneous_power` to False would silently reuse stale filtered LFP.
if filter_instantaneous_power:
    block_filt = ps.bandpass_filter_blocks(lfp_groups, valid_blocks, instantaneous_band, extend_time=extend_time)
    instantaneous_lfp = st.align_trials(block_filt.filtered, valid_trials, window=window, ignore_nan_trials='')[0]
else:
    instantaneous_lfp = aligned_lfp
lfp_power_dss[stim] = xr.Dataset(
    data_vars = dict(
        instantaneous_power = instantaneous_power(instantaneous_lfp),
        band_power = lfp_bands_power,
        freq_bands = freq_bands,
    )
)

#### Natural movies

In [12]:
# Window offsets relative to trial onset: (window[0], trial duration + window[1]);
# (0, 0) covers exactly the trial duration.
natural_movies_window = (0., 0.)
# Required: right boundary < extend_time (extend_time pads each block for filtering).
# Enforce it so that editing either value cannot silently break the constraint.
assert natural_movies_window[1] < extend_time, \
    "right boundary of natural_movies_window must be smaller than extend_time"

In [None]:
for stim in natural_movies_stimuli:
    stimulus_trials = st.get_stimulus_trials(stimulus_presentations, stim)
    window = (natural_movies_window[0], stimulus_trials.duration + natural_movies_window[1])
    stimulus_blocks = st.get_stimulus_blocks(stimulus_trials)
    aligned_lfp, valid_blocks = st.align_trials_from_blocks(
        lfp_groups, stimulus_blocks, window=window, ignore_nan_trials='any')
    # `valid_blocks` is falsy when no trial survived the NaN check: skip this movie.
    if not valid_blocks:
        print(f"Warning: All trials are dropped by NaN values in {stim}.")
        continue

    # Band-limited power for each wave band, filtered per block and aligned to the valid blocks.
    per_band_power = [
        st.align_trials_from_blocks(
            bandpass_power(ps.bandpass_filter_blocks(
                lfp_groups, valid_blocks,
                freq_bands.sel(wave_band=band_name).values,
                extend_time=extend_time,
                include_filtered=False,
                include_amplitude=True,
            )),
            valid_blocks, window=window, ignore_nan_trials='',
        )[0]
        for band_name in wave_bands
    ]
    lfp_bands_power = xr.concat(
        per_band_power, dim=pd.Index(wave_bands, name='wave_band'), combine_attrs='drop_conflicts')

    # Optionally replace the raw aligned LFP with the instantaneous-band filtered LFP.
    if filter_instantaneous_power:
        filtered_blocks = ps.bandpass_filter_blocks(lfp_groups, valid_blocks, instantaneous_band, extend_time=extend_time)
        aligned_lfp = st.align_trials_from_blocks(
            filtered_blocks.filtered, valid_blocks, window=window, ignore_nan_trials='')[0]

    lfp_power_dss[stim] = xr.Dataset(data_vars={
        'instantaneous_power': instantaneous_power(aligned_lfp),
        'band_power': lfp_bands_power,
        'freq_bands': freq_bands,
    })

### Save data

In [14]:
# Save the per-stimulus power datasets (dict keyed by stimulus name) through the session directory.
# Natural-movie stimuli whose trials were all dropped by NaN values are absent from the dict.
session_dir.save_stimulus_lfp_power(lfp_power_dss)