# Guo-Inagaki-2017 - A DataJoint example

The data and results presented in this notebook pertain to the paper:
>Zengcai V. Guo, Hidehiko K. Inagaki, Kayvon Daie, Shaul Druckmann, Charles R. Gerfen & Karel Svoboda. "Maintenance of persistent activity in a frontal thalamocortical loop" (2017) Nature
(https://dx.doi.org/10.1038/nature22324)


This notebook demonstrates working with a DataJoint data pipeline: querying data, applying data conditioning, and reproducing some key figures in the paper. The original data, in *NWB 2.0* format, have been ingested into a DataJoint data pipeline (the pipeline schema is shown below). To validate that the original data were ingested completely, this example reproduces figures 3b,e, 6b,e and 4b,e,h.

In [None]:
from datetime import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from pipeline import reference, subject, acquisition, stimulation, analysis #, behavior, ephys, action

In [None]:
import datajoint as dj
all_erd = dj.ERD(reference) + dj.ERD(subject) + dj.ERD(acquisition) + dj.ERD(stimulation) + dj.ERD(analysis)
dj.ERD(all_erd)

## Reproduce Figure 4b, e, h
First, we demonstrate the queries needed to reproduce figures 4b, 4e and 4h of the paper. These figures show population-level membrane potentials from the intracellular recordings of the entire study, in response to photostimulation. To investigate these responses, we visualize the membrane potentials time-locked to the onset of the "delay" period, during which photostimulation was delivered. This trial-based segmentation has already been performed and its results stored in DataJoint; however, appropriate queries are still required to retrieve them.
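
To make the idea of event-locked segmentation concrete, here is a minimal NumPy sketch. It is not the pipeline's implementation: the function name `segment_trace`, the synthetic trace, and the 1 kHz sampling rate are all assumptions made for illustration.

```python
import numpy as np

def segment_trace(trace, fs, event_time, pre_dur, post_dur):
    """Cut the window [event_time - pre_dur, event_time + post_dur) out of a 1-D trace.

    Returns the segment and its timestamps relative to the event (event at t = 0).
    """
    start = int(round((event_time - pre_dur) * fs))
    stop = int(round((event_time + post_dur) * fs))
    if start < 0 or stop > trace.size:
        raise ValueError("requested window extends beyond the recording")
    segment = trace[start:stop]
    timestamps = np.arange(segment.size) / fs - pre_dur
    return segment, timestamps

# Toy example: 10 s of a synthetic signal sampled at 1 kHz (hypothetical values)
fs = 1000.0
trace = np.sin(2 * np.pi * np.arange(int(10 * fs)) / fs)
segment, timestamps = segment_trace(trace, fs, event_time=4.0, pre_dur=1.5, post_dur=3.0)
```

With a 1.5 s pre-event and 3 s post-event window, each segment spans 4.5 s and the event sits at timestamp 0, which is the time axis the plots in this section rely on.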

Specifically, we first need to query membrane potential recordings from all sessions and categorize them by:
+ Photostim location: contralateral ALM, Thalamus, and M1
+ Trial condition: good trials without stimulation (control) and with stimulation (stim)

In this data pipeline, a recording session contains the acquired intracellular recordings together with the photostimulation information, which specifies the brain location where the stimulation was performed. We use this photostimulation information to constrain the queries of trial-segmented membrane potentials.
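
The two-step constraint described here (select photostimulation entries by location, then keep the cells recorded in those sessions) can be sketched with plain Python dictionaries. This is only an analogy for what a dictionary restriction does, not DataJoint code: the `restrict` helper, the `session_id` and `cell_id` fields, and the rows themselves are hypothetical.

```python
def restrict(rows, condition):
    """Keep the rows whose fields match every key/value pair in `condition`."""
    return [row for row in rows
            if all(row.get(k) == v for k, v in condition.items())]

# Hypothetical photostimulation entries (not real data from the pipeline)
photostims = [
    {'session_id': 1, 'brain_region': 'ALM', 'hemisphere': 'right'},
    {'session_id': 2, 'brain_region': 'VM', 'hemisphere': 'left'},
    {'session_id': 3, 'brain_region': 'M1', 'hemisphere': 'left'},
    {'session_id': 4, 'brain_region': 'VM', 'hemisphere': 'left'},
]
# Hypothetical cells, one per session
cells = [{'session_id': s, 'cell_id': f'cell_{s}'} for s in (1, 2, 3, 4)]

# Step 1: photostimulation entries at the thalamic (VM) site
thal_photostims = restrict(photostims, {'brain_region': 'VM', 'hemisphere': 'left'})

# Step 2: cells recorded in the sessions that received that photostimulation
thal_sessions = {p['session_id'] for p in thal_photostims}
thal_cells = [c for c in cells if c['session_id'] in thal_sessions]
```

In the pipeline itself, the shared primary-key attributes between tables play the role that `session_id` plays in this toy version.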

In [None]:
# -- Get the keys of all whole-cell recordings
cell_keys = acquisition.Cell.fetch('KEY')

In [None]:
# -- Backtrack session, and get photostim info
# photostim keys for contra ALM stimulation
region_dict = {'brain_region':'ALM', 'hemisphere':'right'}
contraALM_photostims = (acquisition.PhotoStimulation & cell_keys & region_dict).fetch('KEY')
# photostim keys for thalamus stimulation
region_dict = {'brain_region':'VM', 'hemisphere':'left'}
thal_photostims = (acquisition.PhotoStimulation & cell_keys & region_dict).fetch('KEY')
# photostim keys for M1 stimulation
region_dict = {'brain_region':'M1', 'hemisphere':'left'}
m1_photostims = (acquisition.PhotoStimulation & cell_keys & region_dict).fetch('KEY')

In [None]:
# -- define cell restrictor for each stim location
contraALM_stim_cells = (acquisition.Cell & contraALM_photostims).fetch('KEY')
thal_stim_cells = (acquisition.Cell & thal_photostims).fetch('KEY')
m1_stim_cells = (acquisition.Cell & m1_photostims).fetch('KEY')

At this point, the keys of the cells (whole-cell recording sessions) with photostimulation at contralateral ALM, thalamus, and M1 are stored in the variables *contraALM_stim_cells*, *thal_stim_cells* and *m1_stim_cells*, respectively.

In [None]:
# define trial restrictors: good trials with (stim) and without (ctrl) photostimulation
stim_trial_cond = {'trial_is_good': True, 'trial_stim_present': True}
ctrl_trial_cond = {'trial_is_good': True, 'trial_stim_present': False}

In [None]:
# define trial-segmentation setting 
seg_param_key = (analysis.TrialSegmentationSetting & {'event': 'pole_out', 'pre_stim_duration': 1.5, 'post_stim_duration': 3}).fetch1('KEY')

Here, for convenience, we define a function that queries trial-segmented intracellular recordings based on the cell key, the trial restrictor, and the trial-segmentation setting defined above.

In [None]:
def query_segmented_intracellular(cell_key, trial_key, seg_param_key):
    data_keys = (analysis.TrialSegmentedIntracellular & cell_key & seg_param_key &
                 (acquisition.TrialSet.Trial & trial_key))
    return [{**dict(zip(['segmented_mp', 'segmented_mp_wo_spike'],
                        (analysis.TrialSegmentedIntracellular.MembranePotential & k).fetch1(
                            'segmented_mp', 'segmented_mp_wo_spike'))),
             **dict(zip(*(analysis.RealignedEvent.RealignedEventTime & k).fetch(
                 'realigned_trial_event', 'realigned_event_time')))}
            for k in data_keys]

In [None]:
# query trial-segmented data based on the cell and trial restrictors    
trial_segmented_ic = {ic_loc: {'stim': query_segmented_intracellular(ic_loc_key, stim_trial_cond, seg_param_key),
                                'ctrl': query_segmented_intracellular(ic_loc_key, ctrl_trial_cond, seg_param_key)}
                      for ic_loc, ic_loc_key in zip(('thalamus', 'm1', 'contraALM'), (thal_stim_cells, m1_stim_cells, contraALM_stim_cells))}

In [None]:
# get sampling rate
fs = acquisition.IntracellularAcquisition.MembranePotential.fetch('membrane_potential_sampling_rate', limit=1)

In [None]:
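# stack the spike-removed membrane potentials of all trials (truncated to the shortest trial)
# and build a time axis relative to the segmentation event, per stim location and trial condition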
avg_trial_segmented_ic = {ic_key: {'stim':{
                                       'data': np.vstack([k['segmented_mp_wo_spike'][:min(k['segmented_mp_wo_spike'].size
                                                                                          for k in trial_segmented_ic[ic_key]['stim'])] 
                                                          for k in trial_segmented_ic[ic_key]['stim']]),
                                       'timestamps': np.arange(min(k['segmented_mp_wo_spike'].size 
                                                                   for k in trial_segmented_ic[ic_key]['stim'])) / fs - float(seg_param_key['pre_stim_duration'])},
                                    'ctrl':{
                                       'data': np.vstack([k['segmented_mp_wo_spike'][:min(k['segmented_mp_wo_spike'].size
                                                                                          for k in trial_segmented_ic[ic_key]['ctrl'])] 
                                                          for k in trial_segmented_ic[ic_key]['ctrl']]),
                                       'timestamps': np.arange(min(k['segmented_mp_wo_spike'].size 
                                                                   for k in trial_segmented_ic[ic_key]['ctrl'])) / fs - float(seg_param_key['pre_stim_duration'])}
                                   } for ic_key in trial_segmented_ic}

In [None]:
def plot_with_sem(data1, data2, ax):
    for d, c, s in zip((data1, data2), ('b', 'k'), ('b', 'gray')):
        v_mean = d['data'].mean(axis=0)
        v_sem = d['data'].std(axis=0) / np.sqrt(d['data'].shape[0])  # SEM uses each dataset's own trial count
        ax.plot(d['timestamps'], v_mean, c)
        ax.fill_between(d['timestamps'], v_mean - v_sem, v_mean + v_sem, alpha=0.5, facecolor=s)          
    ax.axvline(x=0, linestyle='--', color='k')
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)

In [None]:
# plot
fig4, axs = plt.subplots(len(avg_trial_segmented_ic), 1, figsize=(6, 6))
fig4.subplots_adjust(hspace=0.4)

for avg, ax in zip(avg_trial_segmented_ic.values(), axs):
    plot_with_sem(avg['stim'], avg['ctrl'], ax)
    ax.set_xlim(-0.1, 1.4)

## Reproduce Figure 3b, 6b - Extracellular
The following parts of this example relate to the extracellular recording results of this study, namely the neuronal spiking patterns in response to photostimulation.

Similar to the routine laid out above, we query neuronal spike times, segmented and time-locked to the "delay" period, and categorize them by recording location and photostimulation location:
+ Record at ALM, stimulation at ALM
+ Record at ALM, stimulation at Thalamus
+ Record at Thalamus, stimulation at Thalamus
+ Record at Thalamus, stimulation at ALM

As before, there are two trial-based conditions: good trials without stimulation (control) and with stimulation (stim).
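
As a rough illustration of what event-locked, condition-sorted spike times look like, the following NumPy sketch aligns each trial's spikes to an event and groups the trials into control and stimulation sets. The helper `align_spikes`, the field names, and the spike times are made up for this example and do not come from the pipeline.

```python
import numpy as np

def align_spikes(spike_times, event_time, pre_dur, post_dur):
    """Return spike times relative to `event_time`, kept within [-pre_dur, post_dur]."""
    rel = np.asarray(spike_times, dtype=float) - event_time
    return rel[(rel >= -pre_dur) & (rel <= post_dur)]

# Hypothetical trials: event time and spike times in seconds from trial start
trials = [
    {'stim_present': False, 'event_time': 2.0, 'spike_times': [0.2, 1.0, 2.5, 4.0, 6.0]},
    {'stim_present': True, 'event_time': 3.0, 'spike_times': [1.0, 2.9, 3.1, 7.0]},
]

# Group the aligned spike times by trial condition
segmented = {'ctrl': [], 'stim': []}
for trial in trials:
    cond = 'stim' if trial['stim_present'] else 'ctrl'
    segmented[cond].append(
        align_spikes(trial['spike_times'], trial['event_time'], pre_dur=1.5, post_dur=3.0))
```

Spikes that fall outside the 1.5 s pre-event and 3 s post-event window are dropped, so every trial shares the same time axis regardless of when the event occurred within the trial.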

In [None]:
# get trial-segmented spike times for a single unit in one specified session
def query_unit_segmented_spiketimes(sess_key, unit, trial_key, seg_param_key):
    data_keys = (analysis.TrialSegmentedUnitSpikeTimes & sess_key & {'unit_id': unit} & seg_param_key &
                 (acquisition.TrialSet.Trial & trial_key)).fetch('KEY')
    return pd.DataFrame([dict(**dict(zip(*(analysis.RealignedEvent.RealignedEventTime & k).fetch('realigned_trial_event', 'realigned_event_time'))), 
                 segmented_spike_times=(analysis.TrialSegmentedUnitSpikeTimes & k).fetch1(
                     'segmented_spike_times')) for k in data_keys])

In [None]:
def plot_spike_raster_and_histogram(contra_spike_times, ipsi_spike_times, axes, ax_title='', bin_counts=200):
    # get event timing
    events = ['pole_in', 'pole_out', 'cue_start']
    event_times = np.around([np.hstack([ipsi_spike_times[e], contra_spike_times[e]]).mean() for e in events], 4)
    
    # restructure data for spike raster  
    ipsi_c_trial_idx, ipsi_c_spike_times = zip(*((np.full_like(r, ri), r)
                                            for ri, r in enumerate(r for r in ipsi_spike_times.segmented_spike_times if len(r) != 0)))
    ipsi_c_trial_idx = np.hstack(ipsi_c_trial_idx)
    ipsi_c_spike_times = np.hstack(ipsi_c_spike_times)    
    
    contra_c_trial_idx, contra_c_spike_times = zip(*((np.full_like(r, ri), r)
                                            for ri, r in enumerate(r for r in contra_spike_times.segmented_spike_times if len(r) != 0)))
    contra_c_trial_idx = np.hstack(contra_c_trial_idx)
    contra_c_spike_times = np.hstack(contra_c_spike_times)  
    
    # spiketime histogram
    time_range = (np.hstack([ipsi_spike_times.trial_start, contra_spike_times.trial_start]).min(),
                  np.hstack([ipsi_spike_times.trial_stop, contra_spike_times.trial_stop]).max())

    ipsi_spk_counts, ipsi_edges = np.histogram(np.hstack([r for r in ipsi_spike_times.segmented_spike_times]),
                                       bins=bin_counts,
                                       range=(time_range[0], time_range[-1]))
    ipsi_spk_rates = ipsi_spk_counts / np.diff(ipsi_edges) / ipsi_spike_times.segmented_spike_times.shape[0]
    
    contra_spk_counts, contra_edges = np.histogram(np.hstack([r for r in contra_spike_times.segmented_spike_times]),
                                       bins=bin_counts,
                                       range=(time_range[0], time_range[-1]))
    contra_spk_rates = contra_spk_counts / np.diff(contra_edges) / contra_spike_times.segmented_spike_times.shape[0]
   
    # plot
    # spike raster
    ax_top = axes[0]
    # color convention: blue - contra trials, red - ipsi trials
    ax_top.plot(contra_c_spike_times, contra_c_trial_idx + ipsi_c_trial_idx.max(), '|b')
    ax_top.plot(ipsi_c_spike_times, ipsi_c_trial_idx, '|r')
    # event markers
    for e in event_times:
        ax_top.axvline(x=e, linestyle='--', color='k')
    ax_top.set_xticklabels([])
    ax_top.set_yticklabels([])
    ax_top.set_ylabel(ax_title)
    ax_top.set_xlim(-1.5, 3);
    
    # spike histogram
    ax_bot = axes[1]
    ax_bot.plot(ipsi_edges[1:], ipsi_spk_rates, 'r')
    ax_bot.plot(contra_edges[1:], contra_spk_rates, 'b')
    for e in event_times:
        ax_bot.axvline(x=e, linestyle='--', color='k')
    ax_bot.set_xlim(-1.5, 3);
    
    # Hide the spines
    ax_top.spines['right'].set_visible(False)
    ax_top.spines['top'].set_visible(False)
    ax_top.spines['left'].set_visible(False)
    ax_top.spines['bottom'].set_visible(False)
    ax_bot.spines['right'].set_visible(False)
    ax_bot.spines['top'].set_visible(False)

In [None]:
# blue - correct contra trial (licking-right) ; red - correct ipsi trial (licking left)
correct_contra_trial_stim =  {'trial_is_good': True, 'trial_stim_present': True, 'trial_type': 'lick right'}
correct_ipsi_trial_stim =  {'trial_is_good': True, 'trial_stim_present': True, 'trial_type': 'lick left'}
correct_contra_trial_ctrl =  {'trial_is_good': True, 'trial_stim_present': False, 'trial_type': 'lick right'}
correct_ipsi_trial_ctrl =  {'trial_is_good': True, 'trial_stim_present': False, 'trial_type': 'lick left'}

ec_alm_insert = (acquisition.ProbeInsertion & {'brain_region': 'ALM', 'hemisphere': 'left'})
ec_thal_insert = (acquisition.ProbeInsertion & {'brain_region': 'thalamus', 'hemisphere': 'left'})

alm_photostim = (acquisition.PhotoStimulation & {'brain_region': 'ALM', 'hemisphere': 'left'})
thal_photostim = (acquisition.PhotoStimulation & {'brain_region': 'thalamus', 'hemisphere': 'left'})

# ALM probe with ALM or Thalamus photostim
alm_insert_alm_stim = (acquisition.Session & ec_alm_insert) & (acquisition.Session & alm_photostim).fetch('KEY')
alm_insert_thal_stim = (acquisition.Session & ec_alm_insert) & (acquisition.Session & thal_photostim).fetch('KEY')
# Thalamus probe with ALM or Thalamus photostim
thal_insert_alm_stim = (acquisition.Session & ec_thal_insert) & (acquisition.Session & alm_photostim).fetch('KEY')
thal_insert_thal_stim = (acquisition.Session & ec_thal_insert) & (acquisition.Session & thal_photostim).fetch('KEY')

### Example unit - Fig 3b & 6b
We pick some arbitrary neuronal units for plotting the spike raster and spike histogram
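
The spike histogram for a unit is a trial-averaged firing rate: spike counts per bin, divided by the bin width and by the number of trials. The sketch below shows that calculation on made-up spike times; the function name `spike_rate_histogram` and the toy values are assumptions for illustration only.

```python
import numpy as np

def spike_rate_histogram(trial_spike_times, time_range, n_bins):
    """Trial-averaged firing rate (spikes/s) per time bin.

    `trial_spike_times` is a non-empty list with one array of spike times per trial.
    """
    all_spikes = np.hstack(trial_spike_times)
    counts, edges = np.histogram(all_spikes, bins=n_bins, range=time_range)
    rates = counts / np.diff(edges) / len(trial_spike_times)
    return rates, edges

# Two hypothetical trials, spike times in seconds relative to the event
toy_trials = [np.array([0.1, 0.6]), np.array([0.2])]
rates, edges = spike_rate_histogram(toy_trials, time_range=(0.0, 1.0), n_bins=2)
```

Here the first bin holds two spikes and the second holds one; with 0.5 s bins and two trials this gives 2 and 1 spikes/s. Trials without spikes still count toward the number of trials, which keeps the rate from being overestimated.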

In [None]:
# # Cherry-picking unit and session...
# for idx, sess in enumerate(alm_insert_thal_stim.fetch('KEY')): # pick one session here
#     print([idx, len(acquisition.TrialSet.Trial & sess & correct_contra_trial_stim),
#            len(acquisition.TrialSet.Trial & sess & correct_ipsi_trial_stim)])

In [None]:
# get segmented spike times for one unit - ALM-insert, thal-stim
def make_unit_spiketimes(unit, sess):
    return {k: query_unit_segmented_spiketimes(sess, unit, v, seg_param_key) 
            for k, v in zip(('contra_ctrl', 'ipsi_ctrl', 'contra_stim', 'ipsi_stim'), 
                            (correct_contra_trial_ctrl, correct_ipsi_trial_ctrl, 
                             correct_contra_trial_stim, correct_ipsi_trial_stim))}

In [None]:
# 3 units with ALM insert and Thalamus photostim
sess = alm_insert_thal_stim.fetch('KEY')
alm_insert_thal_stim_unit = [ make_unit_spiketimes(unit=unit_no, sess=sess[session_no]) 
                             for unit_no, session_no in zip((1, 1, 1), (21, 34, 19))]

In [None]:
# spike raster and histogram
bin_counts = 200
fig3, axs = plt.subplots(4, len(alm_insert_thal_stim_unit), figsize=(12, 6))
for u_idx, unit_spiketimes in enumerate(alm_insert_thal_stim_unit):
    plot_spike_raster_and_histogram(unit_spiketimes['contra_ctrl'], unit_spiketimes['ipsi_ctrl'], 
                                    axes = (axs.flatten()[u_idx + 3*0], axs.flatten()[u_idx + 3*1]),
                                    ax_title='ALM-Thal-ctrl', bin_counts=bin_counts)
    plot_spike_raster_and_histogram(unit_spiketimes['contra_stim'], unit_spiketimes['ipsi_stim'], 
                                    axes = (axs.flatten()[u_idx + 3*2], axs.flatten()[u_idx + 3*3]),
                                    ax_title='ALM-Thal-stim', bin_counts=bin_counts)

In [None]:
# # Cherry-picking unit and session...
# for idx, sess in enumerate(thal_insert_alm_stim.fetch('KEY')): # pick one session here
#     print([idx, len(acquisition.TrialSet.Trial & sess & correct_contra_trial_stim),
#            len(acquisition.TrialSet.Trial & sess & correct_ipsi_trial_stim)])

In [None]:
# 3 units with Thalamus insert and ALM photostim
sess = thal_insert_alm_stim.fetch('KEY')
thal_insert_alm_stim_unit = [make_unit_spiketimes(unit=unit_no, sess=sess[session_no])
                             for unit_no, session_no in zip((1, 1, 1), (1, 2, 35))]

In [None]:
# spike raster and histogram
bin_counts = 200
fig3, axs = plt.subplots(4, len(thal_insert_alm_stim_unit), figsize=(12, 6))
for u_idx, unit_spiketimes in enumerate(thal_insert_alm_stim_unit):
    plot_spike_raster_and_histogram(unit_spiketimes['contra_ctrl'], unit_spiketimes['ipsi_ctrl'], 
                                    axes = (axs.flatten()[u_idx + 3*0], axs.flatten()[u_idx + 3*1]),
                                    ax_title='Thal-ALM-ctrl', bin_counts=bin_counts)
    plot_spike_raster_and_histogram(unit_spiketimes['contra_stim'], unit_spiketimes['ipsi_stim'], 
                                    axes = (axs.flatten()[u_idx + 3*2], axs.flatten()[u_idx + 3*3]),
                                    ax_title='Thal-ALM-stim', bin_counts=bin_counts)

## Population - Fig 3e and 6e
Similarly, for the entire population, instead of picking a few representative units for plotting, here we will query all units from all sessions constrained by the recording/stimulation brain regions and trial-condition restrictors defined above.

The routine for plotting the spike histograms is analogous to that above.
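
For the population plots, each row of the data matrix is one unit-trial histogram, and the figure shows the mean across rows with a shaded band of one standard error of the mean (SEM). A minimal sketch of that summary, on made-up numbers, is given below. The helper name `mean_and_sem` is hypothetical, and it uses NumPy's default population standard deviation (`ddof=0`), as `plot_with_sem` in this notebook does.

```python
import numpy as np

def mean_and_sem(data):
    """Column-wise mean and standard error of the mean for a (rows x time bins) matrix."""
    data = np.asarray(data, dtype=float)
    mean = data.mean(axis=0)
    # SEM: standard deviation across rows divided by sqrt(number of rows)
    sem = data.std(axis=0) / np.sqrt(data.shape[0])
    return mean, sem

# Two hypothetical rows (e.g. two unit-trials) with two time bins each
toy = [[1.0, 2.0],
       [3.0, 6.0]]
mean, sem = mean_and_sem(toy)
```

The shaded region in the plots then spans `mean - sem` to `mean + sem` at every time bin.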

In [None]:
# get trial-segmented spike times for all units in all specified sessions
def extract_segmented_spiketimes_histogram(sess_key, trial_key, seg_param_key, time_range=(-1.5, 3), bin_counts=100):
    unit_trial_keys = (analysis.TrialSegmentedUnitSpikeTimes & sess_key & seg_param_key &
                 (acquisition.TrialSet.Trial & trial_key)).fetch('KEY')
    print(f'Found {len(unit_trial_keys)} total trials')
        
    def make_spike_histogram():
        for idx, k in enumerate(unit_trial_keys):
            segmented_spike_times = (analysis.TrialSegmentedUnitSpikeTimes & k).fetch1('segmented_spike_times')
            if segmented_spike_times.size > 0:
                segmented_spike_times = segmented_spike_times[np.logical_and(segmented_spike_times >= time_range[0] 
                                                              , segmented_spike_times <= time_range[-1])]
                spk_counts, spk_edges = np.histogram(segmented_spike_times, bins=bin_counts, range=time_range)
                yield spk_counts / np.diff(spk_edges), spk_edges[1:]
            
    # np.vstack needs a sequence (not a generator) of arrays
    return {'data': np.vstack([x[0] for x in make_spike_histogram()]), 'timestamps': next(make_spike_histogram())[1]}

In [None]:
# get segmented spike times for population - ALM/thal-insert, thal-stim
stim_trial =  {'trial_is_good': True, 'trial_stim_present': True}
ctrl_trial =  {'trial_is_good': True, 'trial_stim_present': False}

In [None]:
population_spikes = {name_k: {'ctrl': extract_segmented_spiketimes_histogram(sess_k, ctrl_trial, seg_param_key,
                                                                   (-0.2, 0.5), 200),
                              'stim': extract_segmented_spiketimes_histogram(sess_k, stim_trial, seg_param_key,
                                                                   (-0.2, 0.5), 200)}
                     for name_k, sess_k in zip(('alm_insert_thal_stim', 'thal_insert_thal_stim', 
                                                'alm_insert_alm_stim', 'thal_insert_alm_stim'),
                                               (alm_insert_thal_stim, thal_insert_thal_stim, 
                                                alm_insert_alm_stim, thal_insert_alm_stim))}

In [None]:
fig36, axs = plt.subplots(2, 2, figsize=(6, 6))
fig36.subplots_adjust(hspace=0.3)
for s, ax in zip(population_spikes.keys(), axs.flatten()):
    plot_with_sem(population_spikes[s]['stim'], population_spikes[s]['ctrl'], ax)
    ax.set_xlim(-0.02, 0.04);
    ax.set_title(s)