In [1]:
# General imports 
import os 
import sys
import logging
import platform
from os.path import join as pjoin
import re
import numpy as np
import pandas as pd
import scipy.stats as stats 
from skimage import measure
import matplotlib.pyplot as plt
from pathlib import Path
# Pynwb imports
from hdmf_zarr import NWBZarrIO
from nwbwidgets import nwb2widget

sys.path.insert(0,'/code/src')
from bci.loaders import load
from bci.thresholds.thresholds import align_thresholds
from bci.trials.align import indep_roll
from bci.dataviz import traces

In [2]:
logging.basicConfig(
    format="%(asctime)s | %(message)s ",
    datefmt="%d/%m/%Y %I:%M:%S %p",
    level=logging.INFO,
)

In [3]:
# set data path
platstring = platform.platform()
system = platform.system()
if system == "Darwin":
    # macOS
    data_dir = "/Volumes/Brain2025/"
elif system == "Windows":
    # Windows (replace with the drive letter of USB drive)
    data_dir = "E:/"
elif "amzn" in platstring:
    # then on CodeOcean
    data_dir = "/data/"
else:
    # then your own linux platform
    # EDIT location where you mounted hard drive
    data_dir = "/media/$USERNAME/Brain2025/"
    
print('data directory set to', data_dir)

data directory set to /data/


In [4]:
# Load metadata csv file
metadata = pd.read_csv(os.path.join(data_dir, 'bci_task_metadata', 'bci_metadata.csv'))
# Get all mice available
subject_ids = np.sort(metadata['subject_id'].unique())
# Select one mice
n_subjects = len(subject_ids)
subject_id = 772414#subject_ids[1]#754303
# Select one subject metadata, sorted by 'session_number'
this_mouse_metadata = metadata[metadata['subject_id']==subject_id].sort_values(by='session_number')
# Pick one session for this mouse
session_name = this_mouse_metadata.name.values[3]
print('Selected subject is', subject_id)
print('Selected session is', session_name)

Selected subject is 772414
Selected session is single-plane-ophys_772414_2025-02-10_11-15-26_processed_2025-08-04_23-06-21


In [5]:
# Read data in nwb file
nwbfile = load.load_nwb_session_file(session_name)
epoch_table = nwbfile.intervals["epochs"].to_dataframe()
dff_traces = nwbfile.processing["processed"].data_interfaces["dff"].roi_response_series["dff"].data
roi_table = nwbfile.processing["processed"].data_interfaces["image_segmentation"].plane_segmentations["roi_table"].to_dataframe()
frame_rate = nwbfile.imaging_planes["processed"].imaging_rate
bci_trials = nwbfile.stimulus["Trials"].to_dataframe()
thresholds = load.load_session_thresh_file(session_name)

BCI data directory: /data/brain-computer-interface

Session directory: /data/brain-computer-interface/single-plane-ophys_772414_2025-02-10_11-15-26_processed_2025-08-04_23-06-21

NWB file: single-plane-ophys_772414_2025-02-10_11-15-26_behavior_nwb
NWB path: /data/brain-computer-interface/single-plane-ophys_772414_2025-02-10_11-15-26_processed_2025-08-04_23-06-21/single-plane-ophys_772414_2025-02-10_11-15-26_behavior_nwb


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


All threshold files for mouse 772414: ['single-plane-ophys_772414_2025-02-10', 'single-plane-ophys_772414_2025-02-06', 'single-plane-ophys_772414_2025-01-27']

Found threshold file at: /data/bci-thresholds/single-plane-ophys_772414_2025-02-10


In [6]:
# Add threshold information to the trials
bci_trials = align_thresholds(bci_trials=bci_trials, thresholds=thresholds)
# Get correct trials
correct_bci_trials = bci_trials[bci_trials['hit']==True]
# Select trials where there is information about threshold
correct_bci_trials = correct_bci_trials[correct_bci_trials['low'].notna()]
correct_bci_trials.dropna(inplace=True,subset=['start_time', 'stop_time','threshold_crossing_times'])
correct_bci_trials=correct_bci_trials.reset_index()

# select relevant information
BCI_epochs = epoch_table[epoch_table.stimulus_name.str.contains('BCI')]
start_bci_epoch = BCI_epochs.loc[BCI_epochs.index[0]].start_frame
stop_bci_epoch = BCI_epochs.loc[BCI_epochs.index[0]].stop_frame
start_bci_trial = correct_bci_trials['start_frame']-start_bci_epoch
stop_bci_trial = correct_bci_trials['stop_frame']-start_bci_epoch
thrcrossframe_bci_trial = np.round(correct_bci_trials['threshold_crossing_times']*frame_rate).astype(int)
zaber_steps = np.round(np.array(correct_bci_trials['zaber_step_times'].tolist())*frame_rate)
go_cue_bci = np.round(correct_bci_trials['go_cue']*frame_rate).astype(int)
reward_time = np.round(correct_bci_trials['reward_time']*frame_rate).astype(int)

low_thres =correct_bci_trials.low
high_thres =correct_bci_trials.high

total difference in dataframes: 3


In [None]:
# Select relevant epoch
dff_bci = dff_traces[start_bci_epoch:stop_bci_epoch, :]
dff_bci = dff_bci.T# Transpose so rows are ROI IDs

# Remove ROIs with traces that are NaNs (note - this takes a few seconds)
valid_trace_ids = [i for i in range(dff_traces.shape[1]) if np.isnan(dff_traces[0, i])==False]
# Limit ROI table to non-NaN traces
roi_table2 = roi_table.loc[valid_trace_ids]

# Find the likely somatic ROIs
soma_probability = 0.005 # Emperically determined threshold - just trust us
# Limit to valid somatic ROIs
valid_rois = roi_table2[roi_table2.soma_probability>soma_probability]
target_roi_idx = bci_trials['closest_roi'].unique()
print(f"CN: {target_roi_idx}")
if len(target_roi_idx)>1:
    raise ValueError("More than one CN during BCI epoch")
target_roi_idx=target_roi_idx[0]
if not(target_roi_idx in valid_rois.index):
    valid_rois = pd.concat((valid_rois, roi_table2.loc[[target_roi_idx], :]), axis=0)
    valid_rois = valid_rois.sort_index()

# Select valid rois
dff_bci_valid = dff_bci[valid_rois.index.values, :]
roi_original_idx = valid_rois.reset_index()['id']
#cn_new_idx = roi_original_idx[roi_original_idx==target_roi_idx].index[0]

# Smooth dff
smoothing_window = 10
smooth_dff_valid = np.full(dff_bci_valid.shape,np.nan)
kernel = np.ones(smoothing_window) / smoothing_window
for itr,trial in enumerate (dff_bci_valid):
    smooth_dff_valid[itr] = np.convolve(trial, kernel, mode='same')

# Organize data by trials
n_rois = smooth_dff_valid.shape[0]
n_trials = len(start_bci_trial)
max_tr_duration = np.max(stop_bci_trial-start_bci_trial)
dff_by_trial = np.full((n_rois,n_trials,max_tr_duration*2),np.nan)
for itr,(ist,istp) in enumerate(zip(start_bci_trial,stop_bci_trial)):

    dff_by_trial[:,itr,:int(istp-ist)] = smooth_dff_valid[:,ist:istp]

# Let's align on threshold_crossing_times
frames_before = int(np.max(thrcrossframe_bci_trial.values))
shifts = thrcrossframe_bci_trial.values -frames_before
dff_bci_alignon_thr = indep_roll(dff_by_trial,-shifts,axis=-1)

# get when the threshold changes
idx_threshold_change = np.insert(np.where(np.insert(np.diff(high_thres), 0, 0)),0,0)
idx_threshold_change=np.append(idx_threshold_change,n_trials)
# idx_threshold_change


CN: [1075]


: 

In [8]:
save=True

In [None]:
colors=['g','c']*5
for i_roi, roi_idx in enumerate(roi_original_idx):
    logging.info(i_roi)
    cn = False
    if roi_idx == target_roi_idx:
        cn= True

    roi_dff_bci_alignon_thr = dff_bci_alignon_thr[i_roi]
    aligned_zaber_steps = zaber_steps-shifts.reshape(-1,1).astype(int)

    fig=traces.plot_lick_spout_steps_dff(roi_dff_bci_alignon_thr,zaber_steps=aligned_zaber_steps,frames_before=frames_before,idx_threshold_change=idx_threshold_change,
                                    subject_id=subject_id,session_name=session_name,i_roi=roi_idx,cn=cn,colors=colors,save=save, savepath='/scratch',save_format='jpg')
    plt.close('all')


02/09/2025 03:30:47 AM | 0 
02/09/2025 03:30:48 AM | 1 
02/09/2025 03:30:48 AM | 2 
02/09/2025 03:30:48 AM | 3 
02/09/2025 03:30:49 AM | 4 
02/09/2025 03:30:49 AM | 5 
02/09/2025 03:30:49 AM | 6 
02/09/2025 03:30:50 AM | 7 
02/09/2025 03:30:50 AM | 8 
02/09/2025 03:30:50 AM | 9 
02/09/2025 03:30:51 AM | 10 
02/09/2025 03:30:51 AM | 11 
02/09/2025 03:30:51 AM | 12 
02/09/2025 03:30:52 AM | 13 
02/09/2025 03:30:52 AM | 14 
02/09/2025 03:30:52 AM | 15 
02/09/2025 03:30:53 AM | 16 
02/09/2025 03:30:53 AM | 17 
02/09/2025 03:30:53 AM | 18 
02/09/2025 03:30:54 AM | 19 
02/09/2025 03:30:54 AM | 20 
02/09/2025 03:30:54 AM | 21 
02/09/2025 03:30:55 AM | 22 
02/09/2025 03:30:55 AM | 23 
02/09/2025 03:30:55 AM | 24 
02/09/2025 03:30:56 AM | 25 
02/09/2025 03:30:56 AM | 26 
02/09/2025 03:30:56 AM | 27 
02/09/2025 03:30:57 AM | 28 
02/09/2025 03:30:57 AM | 29 
02/09/2025 03:30:57 AM | 30 
02/09/2025 03:30:58 AM | 31 
02/09/2025 03:30:58 AM | 32 
02/09/2025 03:30:58 AM | 33 
02/09/2025 03:30:59 AM |

In [None]:
ntrials = dff_signal.shape[0]
max_act = np.nanmax(dff_signal)+0.555
dff_signal + np.arange(0,int(ntrials*max_act),max_act).reshape(-1,1) 

NameError: name 'dff_signal' is not defined

In [None]:
ntrials*max_act

np.float64(73.26482465956362)

In [None]:
(ntrials*max_act)/ntrials

np.float64(1.4652964931912722)

In [None]:
np.arange(0,ntrials*max_act,max_act)

array([ 0.        ,  1.46529649,  2.93059299,  4.39588948,  5.86118597,
        7.32648247,  8.79177896, 10.25707545, 11.72237195, 13.18766844,
       14.65296493, 16.11826143, 17.58355792, 19.04885441, 20.5141509 ,
       21.9794474 , 23.44474389, 24.91004038, 26.37533688, 27.84063337,
       29.30592986, 30.77122636, 32.23652285, 33.70181934, 35.16711584,
       36.63241233, 38.09770882, 39.56300532, 41.02830181, 42.4935983 ,
       43.9588948 , 45.42419129, 46.88948778, 48.35478428, 49.82008077,
       51.28537726, 52.75067375, 54.21597025, 55.68126674, 57.14656323,
       58.61185973, 60.07715622, 61.54245271, 63.00774921, 64.4730457 ,
       65.93834219, 67.40363869, 68.86893518, 70.33423167, 71.79952817,
       73.26482466])