## This notebook shows an example where a set of electrodes are selected from a dataset and then LFP is extracted from those electrodes and written to a new NWB file


### We assume that you have added an NWB file to the database (see Populate_from_NWB_tutorial notebook)

#### Load all of the relevant modules and set the environment variables. 
Note that the datadir and datadir/analysis must exist

In [1]:
%env DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
%load_ext autoreload
%autoreload 2

import pynwb
import os
from pathlib import Path

# DataJoint and DataJoint schema
import datajoint as dj

# CONFIG FOR LOCAL DATABASE - CHANGE AS NEEDED
# dj.config['database.host'] = 'localhost'
# dj.config['database.user'] = 'root'
# dj.config['database.password'] = 'tutorial'


# the commands below can be run once to update your global configuration
# c
#dj.config.save_global()

import nwb_datajoint as nd
# import ndx_franklab_novela
# import ndx_franklab_novela.probe

import warnings
warnings.simplefilter('ignore')

# Note that all of the following must exist

# data_dir = Path('/Users/loren/data/nwb_builder_test_data') # CHANGE ME TO THE BASE DIRECTORY FOR DATA STORAGE ON YOUR SYSTEM

# os.environ['NWB_DATAJOINT_BASE_DIR'] = str(data_dir)
# os.environ['KACHERY_STORAGE_DIR'] = str(data_dir / 'kachery-storage')
# os.environ['SPIKE_SORTING_STORAGE_DIR'] = str(data_dir / 'spikesorting')

env: DJ_SUPPORT_FILEPATH_MANAGEMENT=TRUE
Connecting jhbak@lmf-db.cin.ucsf.edu:3306


#### Next we select the NWB file, which corresponds to the dataset we want to extract LFP from

In [4]:
nd.common.Nwbfile()

nwb_file_name  name of the NWB file,nwb_file_abs_path
,


In [3]:
nwb_file_names = nd.common.Nwbfile().fetch('beans20190718.nwb')
# take the first one for this demonstration
nwb_file_name = nwb_file_names[0]
print(nwb_file_name)

DataJointError: Attribute `beans20190718.nwb` is not found

#### Create the standard LFP Filters. This only needs to be done once.

In [None]:
nd.common.FirFilter().create_standard_filters()

### Select every 16th electrode for LFP

In [None]:
electrode_ids = nd.common.Electrode.fetch('electrode_id')
lfp_electrode_ids = electrode_ids[range(0, len(electrode_ids), 16)]
nd.common.LFPSelection().set_lfp_electrodes(nwb_file_name, lfp_electrode_ids.tolist())

Show the list of selected electrodes. Note that the electrode_group corresponds to the physical probe the electrode was part of.

In [None]:
nd.common.LFPSelection().LFPElectrode()

### Populate the LFP table

In [None]:
nd.common.LFP().populate()

### Now that we've created the LFP object we can perform a second level of filtering for a band of interest, in this case the theta band
We first need to create the filter

In [None]:
lfp_sampling_rate = (nd.common.LFP() & {'nwb_file_name' : nwb_file_name}).fetch1('lfp_sampling_rate')
filter_name = 'Theta 5-11 Hz'
nd.common.FirFilter().add_filter(filter_name, lfp_sampling_rate, 'bandpass', [4, 5, 11, 12], 'theta filter for 1 KHz data')

In [None]:
nd.common.FirFilter()

Next we add an entry for the LFP Band and the electrodes we want to filter

In [None]:
# assume that we've filtered these electrodes; change this if not
lfp_band_electrode_ids = [0, 16, 32]

# set the interval list name corresponding to the first epoch (a sleep session)
interval_list_name = '01_s1'

# set the reference to -1 to indicate no reference for all channels
ref_elect = [-1]

# desired sampling rate
lfp_band_sampling_rate = lfp_sampling_rate // 10

In [None]:
(nd.common.IntervalList() & {'nwb_file_name' : nwb_file_name, 'interval_list_name': interval_list_name}).fetch1('valid_times')[0, 1]

In [None]:
nd.common.LFPBandSelection().set_lfp_band_electrodes(nwb_file_name, lfp_band_electrode_ids, filter_name, interval_list_name, ref_elect, lfp_band_sampling_rate)

Check to make sure it worked

In [None]:
nd.common.LFPBandSelection()

In [None]:
nd.common.LFPBandSelection().LFPBandElectrode()

In [None]:
(nd.common.LFP()& {'nwb_file_name' : nwb_file_name}).fetch_nwb()

In [None]:
nd.common.LFPBand().populate()

### Now we can plot the original signal, the LFP filtered trace, and the theta filtered trace together.
Much of the code below could be replaced by a function calls that would return the data from each electrical series, or better yet, plot the data in an electrical series.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
#get the three electrical series objects and the indices of the electrodes we band pass filtered
orig_eseries = (nd.common.Raw() & {'nwb_file_name' : nwb_file_name}).fetch_nwb()[0]['raw']
orig_elect_indices = nd.common.get_electrode_indices(orig_eseries, lfp_band_electrode_ids)

lfp_eseries = (nd.common.LFP() & {'nwb_file_name' : nwb_file_name}).fetch_nwb()[0]['lfp']
lfp_elect_indices = nd.common.get_electrode_indices(lfp_eseries, lfp_band_electrode_ids)

lfp_band_eseries = (nd.common.LFPBand() & {'nwb_file_name' : nwb_file_name}).fetch_nwb()[0]['filtered_data']
lfp_band_elect_indices = nd.common.get_electrode_indices(lfp_band_eseries, lfp_band_electrode_ids)

In [None]:
# get a list of times for the first run epoch and then select a 1 second interval 100 seconds from the beginning
times = (nd.common.IntervalList & {'interval_list_name' : interval_list_name}).fetch1('valid_times')
plottimes = [times[0][0] + 101, times[0][0] + 102]

In [None]:
# get the time indices for each dataset
orig_time_ind = np.argwhere(np.logical_and(orig_eseries.timestamps > plottimes[0], orig_eseries.timestamps < plottimes[1]))
lfp_time_ind = np.argwhere(np.logical_and(lfp_eseries.timestamps > plottimes[0], lfp_eseries.timestamps < plottimes[1]))
lfp_band_time_ind = np.argwhere(np.logical_and(lfp_band_eseries.timestamps > plottimes[0], lfp_band_eseries.timestamps < plottimes[1]))

In [None]:
plt.plot(orig_eseries.timestamps[orig_time_ind], orig_eseries.data[orig_time_ind,orig_elect_indices[1]], 'k-')
plt.plot(lfp_eseries.timestamps[lfp_time_ind], lfp_eseries.data[lfp_time_ind,lfp_elect_indices[1]], 'b-')
plt.plot(lfp_band_eseries.timestamps[lfp_band_time_ind], lfp_band_eseries.data[lfp_band_time_ind,lfp_band_elect_indices[1]], 'r-')
plt.xlabel('Time (sec)')
plt.ylabel('Amplitude (AD units)')

plt.show()