In [1]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from matplotlib import gridspec
from scipy import stats

from one.api import ONE
from brainbox.io.one import SessionLoader

from iblnm.util import EVENT2COLOR, protocol2type, _load_event_times
from iblnm.resp import get_responses

In [13]:
one = ONE()

In [3]:
# eid = '0d118e83-1450-4382-9125-5fabc5c31b88'

In [4]:
# eids = list(one.search(subject='ZFM-08757', datasets=DSETS))
# len(eids)
# eid = eids[-1]

In [17]:
df_sessions = pd.read_parquet('metadata/sessions.pqt').query('remote_photometry == True')
df_sessions = df_sessions[df_sessions['target'].apply(len) > 0]
df_sessions['target_NM'] = df_sessions.apply(lambda x: '-'.join([x['target'][0], x['NM']]), axis='columns')
df_sessions['target_NM'].value_counts()

target_NM
NBM-ACh    139
LC-NE       76
VTA-DA      75
SI-ACh      18
Name: count, dtype: int64

In [18]:
# Pick a target
target_NM = 'VTA-DA'
df_target = df_sessions.query('target_NM == @target_NM')
df_target.groupby(['subject', 'session_type']).count()['session_n']

subject                    session_type
ZFM-08488                  training         7
ZFM-08554                  training         5
ZFM-08689                  habituation      2
                           training         4
ZFM-08818                  habituation      3
                           training        20
ZFM-08827                  habituation      3
                           training        20
ZFM-08828                  habituation      3
                           training         6
photometry_test_subject_B  misc             2
Name: session_n, dtype: int64

In [38]:
# Pick a session
subject = 'ZFM-08828'
session_type = 'training'
session_n = -2
session = df_target.query('(subject == @subject) & (session_type == @session_type)').iloc[session_n]
eid = session['eid']
print(session[['subject', 'target_NM', 'session_type', 'session_n', 'start_time', 'eid']])

# Get responses to task events
session = _load_event_times(session, one)

subject                                    ZFM-08828
target_NM                                     VTA-DA
session_type                                training
session_n                                          8
start_time                2025-05-29T12:40:42.372336
eid             3ebb02ef-9be1-4d44-81fe-9403173b7b48
Name: 2750, dtype: object


In [45]:
# Pick a subject
subject = 'ZFM-08863'
eids = one.search(subject=subject)

# Pick a session
session_i = 21
eid = eids[session_i]

session_details = one.get_details(eid) 
session_type = protocol2type(session_details['task_protocol'])
session_n = session_i + 1 if session_i >= 0 else len(eids) + session_i + 1
session = pd.Series(data={'eid': eid, 'subject': subject, 'session_type': session_type, 'session_n': session_n})
print(session[['subject', 'session_type', 'session_n', 'eid']])

# Get responses to task events
session = _load_event_times(session, one)

IndexError: list index out of range

> [0;32m/tmp/ipykernel_45943/2497906931.py[0m(7)[0;36m<module>[0;34m()[0m
[0;32m      5 [0;31m[0;31m# Pick a session[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      6 [0;31m[0msession_i[0m [0;34m=[0m [0;36m21[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m----> 7 [0;31m[0meid[0m [0;34m=[0m [0meids[0m[0;34m[[0m[0msession_i[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      8 [0;31m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0msession_details[0m [0;34m=[0m [0mone[0m[0;34m.[0m[0mget_details[0m[0;34m([0m[0meid[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


In [37]:
# Get photometry data
signals = one.load_dataset(id=eid, dataset='photometry.signal.pqt')
locations = one.load_dataset(id=eid, dataset='photometryROI.locations.pqt').reset_index()
rois = locations['ROI'].to_list()
photometry = signals[list(rois) + ['name']].set_index(signals['times']).dropna()
print(locations)

  ROI     fiber brain_region
0  G0  fiber_LC           LC


In [38]:
# Pick an ROI
# roi = 'G0'
roi = locations['ROI'].iloc[0]

In [39]:
# Restrict raw photometry to the task period
buffer = 15
loader = SessionLoader(one, eid=eid)
## FIXME: appropriately handle cases with multiple task collections
loader.load_trials(collection='alf/task_00')
timings = [col for col in loader.trials.columns if col.endswith('_times')]
t0 = loader.trials[timings].min().min()
t1 = loader.trials[timings].max().max()
i0 = photometry.index.searchsorted(t0 - buffer)
i1 = photometry.index.searchsorted(t1 + buffer)
photometry = photometry.iloc[i0:i1].copy()

# Pull channels out of df
gcamp = photometry.query('name == "GCaMP"')
iso = photometry.query('name == "Isosbestic"')

In [40]:
events = ['cue', 'movement', 'reward', 'omission']
psths = []
for event in events:
    responses, tpts = get_responses(gcamp[roi], session[f'{event}_times'])
    psths.append(responses)

In [41]:
%matplotlib tk

fig = plt.figure(figsize=(12, 8))
grid = gridspec.GridSpec(2, 3)

ax2 = fig.add_subplot(grid[1, 0])
ax3 = fig.add_subplot(grid[1, 1])
ax4 = fig.add_subplot(grid[1, 2])
ax1 = fig.add_subplot(grid[0, :])

ax1.plot(gcamp[roi])
ax1.plot(iso[roi], color='gray')
ax1.set_ylabel('Signal (a.u.)')
ax1.set_title(f"{session['subject']}, {session['session_type']}, session {session['session_n']}")

colors = [EVENT2COLOR[event] for event in events]
for event, responses, color, ax in zip(events, psths, colors, [ax2, ax3, ax4, ax4]):
    ax.plot(tpts, responses.mean(axis=0), color=color, label=event)
    ax.plot(tpts, responses.mean(axis=0) - stats.sem(responses, axis=0), ls='--', color=color)
    ax.plot(tpts, responses.mean(axis=0) + stats.sem(responses, axis=0), ls='--', color=color)
    ax.axvline(0, ls='--', color='black', alpha=0.5)
    ax.axhline(0, ls='--', color='gray', alpha=0.5)
    ax.set_xlabel('Time (s)')
    if event != 'omission':
        ax.set_title(event.capitalize())
    ax.ticklabel_format(axis='y', style='sci', scilimits=[-2, 2])