In [None]:
import pandas as pd
import os
import json
import hdmf_zarr
from hdmf.common import DynamicTable
import numpy as np
from typing import List, Dict, Optional, Tuple
import glob 

def get_sessions_of_interest(summary_path: str, experiment_filter: str = 'NPUltra_psychedelics', 
                           upload_filter: str = 'yes') -> List[str]:
    """
    Return the sessions in the recording summary that match both filters.

    The spreadsheet is read with pandas and rows are kept only when the
    'experiment' column equals `experiment_filter` AND the 'uploaded to CO'
    column equals `upload_filter`.

    Args:
        summary_path: Path to NPUltra_recording_summary.xlsx
        experiment_filter: Value required in the 'experiment' column
        upload_filter: Value required in the 'uploaded to CO' column

    Returns:
        Values of the 'session' column for the matching rows, in sheet order
    """
    summary = pd.read_excel(summary_path)

    # Build each condition separately so the combined mask reads clearly.
    is_experiment = summary['experiment'] == experiment_filter
    is_uploaded = summary['uploaded to CO'] == upload_filter

    matching_sessions = summary.loc[is_experiment & is_uploaded, 'session'].tolist()
    print(f"Found {len(matching_sessions)} sessions matching criteria")

    return matching_sessions

def extract_analysis_table(session_folder: str) -> Optional[pd.DataFrame]:
    """
    Extract the [postprocessed] units table from the session folder. This table should be
    renamed to analysis_table to avoid confusion with the existing units table in the NWB file.

    The table is looked up in <session_folder>/processed_data/units. A file qualifies when its
    name contains 'units_epoch' but not 'units_epochs'. Candidates are sorted by name so the
    same file is chosen on every run (os.listdir order is arbitrary).

    Args:
        session_folder: Path to session directory

    Returns:
        DataFrame with analysis data, or None if the units folder or a matching file is missing
    """
    units_table_path = os.path.join(session_folder, "processed_data", "units")

    # isdir (rather than exists) so that a stray file named "units" is not passed to os.listdir
    if not os.path.isdir(units_table_path):
        print(f"Units folder does not exist in {session_folder}")
        return None

    # Find units table pickle (exclude units_epochs); sorted for a deterministic choice
    units_files = sorted(
        f for f in os.listdir(units_table_path)
        if 'units_epoch' in f and 'units_epochs' not in f
    )

    if not units_files:
        print(f"No units table files found in {units_table_path}")
        return None

    if len(units_files) > 1:
        print(f"Multiple units table files found in {units_table_path}; using {units_files[0]}")

    units_table_file = os.path.join(units_table_path, units_files[0])
    # NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
    analysis_table = pd.read_pickle(units_table_file)
    print(f"Loaded units table from {units_table_file}")

    return analysis_table

def extract_stimulus_table(session_folder: str) -> Optional[pd.DataFrame]:
    """
    Build the stimulus table for a session from opto_stim_df.csv and session.json.

    Trial timing, stimulus name and emission location are copied from
    <session_folder>/processed_data/stim/opto_stim_df.csv. When
    <session_folder>/metadata/session.json exists, the wavelength and laser power of the
    first stimulus epoch that carries a light_source_config are written to every row.

    Args:
        session_folder: Path to session directory

    Returns:
        DataFrame with columns start_time, stop_time, power, stim_name, emission_location
        and wavelength, or None if the stim folder or opto_stim_df.csv is missing
    """
    stim_dir = os.path.join(session_folder, "processed_data", "stim")
    if not os.path.exists(stim_dir):
        print(f"Stimulus folder does not exist in {session_folder}")
        return None

    opto_csv = os.path.join(stim_dir, "opto_stim_df.csv")
    if not os.path.exists(opto_csv):
        print(f"No stimulus table files found in {stim_dir}")
        return None

    trials = pd.read_csv(opto_csv)
    print(f"Loaded stimulus table from {opto_csv}")

    # Start from an empty frame so the output column order is fixed up front.
    stim_table = pd.DataFrame(columns=['start_time', 'stop_time', 'power', 'stim_name', 
                                     'emission_location', 'wavelength'])

    # (output column, source column in opto_stim_df.csv)
    column_sources = (
        ('start_time', 'stim_on'),
        ('stop_time', 'stim_off'),
        ('stim_name', 'epoch_label'),
        ('emission_location', 'probe'),
    )
    for target, source in column_sources:
        stim_table[target] = trials[source]

    # Without session metadata, power and wavelength stay unfilled.
    session_json = os.path.join(session_folder, "metadata", "session.json")
    if not os.path.exists(session_json):
        return stim_table

    with open(session_json, 'r') as handle:
        session_info = json.load(handle)

    # Only the first epoch with a light_source_config is used; later ones are treated as redundant.
    configs = (entry.get('light_source_config') for entry in session_info.get('stimulus_epochs', []))
    first_config = next((config for config in configs if config is not None), None)

    if first_config is None:
        wavelength, power = None, None
    else:
        wavelength = first_config.get('wavelength')
        power = first_config.get('laser_power')

    stim_table['wavelength'] = wavelength
    stim_table['power'] = power

    return stim_table

def create_epoch_table(session_folder: str) -> Optional[pd.DataFrame]:
    """
    Create epoch table from the stimulus epoch CSV files of a session.

    Every CSV in <session_folder>/processed_data/stim is read and its trials are grouped by
    'epoch_label'. Each epoch starts at its earliest 'stim_on'. It ends at its latest
    'stim_off' when the label contains 'OptoTagging', otherwise at its latest 'stim_on'
    (only the OptoTagging files carry a 'stim_off' column).

    Args:
        session_folder: Path to session directory

    Returns:
        DataFrame with columns stim_name, start_time and stop_time sorted by start_time,
        or None if the stim folder is missing or contains no CSV files

    Raises:
        ValueError: If a CSV lacks the 'stim_on' or 'epoch_label' column
    """
    epoch_path = os.path.join(session_folder, "processed_data", "stim")

    if not os.path.exists(epoch_path):
        print(f"Stimulus folder does not exist in {session_folder}")
        return None

    # Sorted so the files are processed in the same order on every run
    epoch_files = sorted(f for f in os.listdir(epoch_path) if f.endswith('.csv'))

    if not epoch_files:
        print(f"No CSV files found in {epoch_path}")
        return None

    required_columns = {'stim_on', 'epoch_label'}
    wanted_columns = required_columns | {'stim_off'}

    # Each CSV holds trial-by-trial timing for one stimulus epoch type
    all_epoch_data = []
    for epoch_file in epoch_files:
        file_path = os.path.join(epoch_path, epoch_file)

        # A callable usecols keeps whichever wanted columns exist, so a missing stim_off
        # does not need to be handled by catching an exception and re-reading the file.
        temp_df = pd.read_csv(file_path, usecols=lambda column: column in wanted_columns)

        missing = required_columns - set(temp_df.columns)
        if missing:
            raise ValueError(f"{file_path} is missing required column(s): {sorted(missing)}")

        if 'stim_off' not in temp_df.columns:
            # Float NaN (not pd.NA) keeps the column numeric, which avoids the pandas
            # FutureWarning about concatenating all-NA object columns.
            temp_df['stim_off'] = np.nan

        all_epoch_data.append(temp_df)

    # Combine trial-by-trial information across all stimulus epochs
    combined_df = pd.concat(all_epoch_data, ignore_index=True)

    # Named aggregation replaces groupby.apply, which is deprecated when the applied
    # function operates on the grouping column.
    epoch_summary = combined_df.groupby('epoch_label').agg(
        start_time=('stim_on', 'min'),
        last_stim_on=('stim_on', 'max'),
        last_stim_off=('stim_off', 'max'),
    ).reset_index()

    # str() guards against non-string labels (e.g. numeric epoch labels)
    epoch_summary['stop_time'] = [
        last_off if 'OptoTagging' in str(label) else last_on
        for label, last_on, last_off in zip(
            epoch_summary['epoch_label'],
            epoch_summary['last_stim_on'],
            epoch_summary['last_stim_off'],
        )
    ]

    epoch_table = epoch_summary.rename(columns={'epoch_label': 'stim_name'})
    epoch_table = epoch_table[['stim_name', 'start_time', 'stop_time']]

    # Sort epochs by start_time
    epoch_table = epoch_table.sort_values(by='start_time').reset_index(drop=True)

    return epoch_table

def add_injection_epoch(epoch_table: pd.DataFrame, session_folder: str) -> pd.DataFrame:
    """
    Add injection epoch between OptoTagging_0 and Spontaneous_1.

    The epoch name is derived from the notes of the first 'OptoTagging' stimulus epoch in
    <session_folder>/metadata/session.json: 'Saline_Injection' or 'Psilocybin_Injection' when
    the notes mention it, otherwise the generic 'Injection'. The new epoch spans from 0.01
    after the OptoTagging_0 stop_time to 0.01 before the Spontaneous_1 start_time.

    Args:
        epoch_table: Existing epoch table with stim_name, start_time and stop_time columns
        session_folder: Path to session directory

    Returns:
        Epoch table with the injection row inserted after OptoTagging_0 (index reset). The
        input table is returned unchanged when the metadata file, the OptoTagging_0 row or
        the Spontaneous_1 row is missing.
    """
    metadata_path = os.path.join(session_folder, "metadata", "session.json")

    if not os.path.exists(metadata_path):
        print(f"Metadata file does not exist in {metadata_path}")
        return epoch_table

    with open(metadata_path, 'r') as f:
        metadata = json.load(f)

    # Determine injection type from OptoTagging stimulus epoch note
    injection_name = "Injection"  # Default if missing info in metadata
    for epoch in metadata.get('stimulus_epochs', []):
        if epoch.get('stimulus_name') == 'OptoTagging':
            # `or ''` also covers a JSON null: .get('notes', '') would return None there
            # and the substring checks below would raise TypeError.
            notes = epoch.get('notes') or ''
            if 'Saline' in notes:
                injection_name = "Saline_Injection"
            elif 'Psilocybin' in notes:
                injection_name = "Psilocybin_Injection"
            break

    # Work with row positions rather than index labels so the insertion is correct even
    # when epoch_table does not carry a default 0..n-1 index.
    stim_names = epoch_table['stim_name'].tolist()
    opto_positions = [pos for pos, name in enumerate(stim_names) if name == 'OptoTagging_0']
    spont_positions = [pos for pos, name in enumerate(stim_names) if name == 'Spontaneous_1']

    if not opto_positions or not spont_positions:
        return epoch_table

    # If a label occurs more than once, the last occurrence supplies the timing and the row
    # is inserted after the first OptoTagging_0.
    opto_0_stop = epoch_table['stop_time'].iloc[opto_positions[-1]]
    spont_1_start = epoch_table['start_time'].iloc[spont_positions[-1]]

    # Create injection epoch with time buffer
    injection_start = opto_0_stop + 0.01
    injection_end = spont_1_start - 0.01

    new_row = pd.DataFrame({
        'stim_name': [injection_name],
        'start_time': [injection_start],
        'stop_time': [injection_end]
    })

    insert_pos = opto_positions[0] + 1
    epoch_table = pd.concat([
        epoch_table.iloc[:insert_pos],  # Rows up to and including OptoTagging_0
        new_row,
        epoch_table.iloc[insert_pos:]  # Remaining rows
    ], ignore_index=True)
    print(f"Added {injection_name} epoch from {injection_start} to {injection_end}")

    return epoch_table

def convert_dataframe_to_dynamic_table(df: pd.DataFrame, table_name: str) -> DynamicTable:
    """
    Wrap a pandas DataFrame in a DynamicTable so it can be stored in an NWB file.

    The table is named after `table_name`, lower-cased and with spaces replaced by
    underscores. The log line reports `table_name` exactly as it was passed in.

    Args:
        df: Input DataFrame
        table_name: Name for the dynamic table

    Returns:
        DynamicTable built from `df`
    """
    # Normalize the name first: lowercase, underscores instead of spaces.
    normalized_name = table_name.lower().replace(' ', '_')

    table = DynamicTable.from_dataframe(df=df, name=normalized_name)

    print(f"Created DynamicTable '{table_name}' with {len(df)} rows and {len(df.columns)} columns")
    return table

def modify_nwb_file(original_path: str, new_path: str, analysis_table: Optional[pd.DataFrame] = None,
                   stim_table: Optional[pd.DataFrame] = None, epoch_table: Optional[pd.DataFrame] = None) -> None:
    """
    Modify existing NWB file with new tables and save to new location.
    The analysis and stimulus DataFrames are converted to DynamicTables; the epoch
    DataFrame is added row by row through nwbfile.add_epoch.

    Args:
        original_path: Path to input NWB file
        new_path: Path for output NWB file
        analysis_table: Analysis table to add (skipped when None)
        stim_table: Stimulus table to add (skipped when None)
        epoch_table: Epoch table to add (skipped when None); needs start_time and stop_time
            columns, and stim_name is used as the epoch tag when present

    Raises:
        ValueError: If new_path points at the same location as original_path
        Exception: Any error raised while adding a table is logged and re-raised
    """
    # Opening the output in 'w' mode on top of the source would destroy the input.
    if os.path.abspath(new_path) == os.path.abspath(original_path):
        raise ValueError(f"new_path must differ from original_path: {original_path}")

    # Create directory for new file (dirname is '' for a bare filename; makedirs('') fails)
    output_dir = os.path.dirname(new_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # The context manager closes the source IO even when adding a table or exporting fails.
    with hdmf_zarr.NWBZarrIO(original_path, mode="r") as io:
        nwbfile = io.read()

        # Add analysis table
        if analysis_table is not None:
            print("Processing analysis table...")
            try:
                analysis_dynamic_table = convert_dataframe_to_dynamic_table(
                    df=analysis_table,
                    table_name="analysis_table"
                )
                nwbfile.add_analysis(analysis_dynamic_table)
                print(f"Added analysis table with {len(analysis_table)} rows")
            except Exception as e:
                print(f"Failed to add analysis table: {e}")
                raise

        # Add stimulus table
        if stim_table is not None:
            print("Processing stimulus table...")
            try:
                stimulus_dynamic_table = convert_dataframe_to_dynamic_table(
                    df=stim_table,
                    table_name="stimulus_table"
                )
                nwbfile.add_stimulus(stimulus_dynamic_table)
                print(f"Added stimulus table with {len(stim_table)} rows")
            except Exception as e:
                print(f"Failed to add stimulus table: {e}")
                raise

        # Add epoch table
        if epoch_table is not None:
            print("Processing epoch table...")
            try:  # Constructs epoch table row by row from DataFrame
                for idx, row in epoch_table.iterrows():
                    # Tag each epoch with its stimulus name when one is available
                    tags = [str(row['stim_name'])] if 'stim_name' in row and pd.notna(row['stim_name']) else []
                    nwbfile.add_epoch(
                        start_time=float(row['start_time']),
                        stop_time=float(row['stop_time']),
                        tags=tags
                    )
                print(f"Added {len(epoch_table)} epochs")
            except Exception as e:
                print(f"Failed to add epoch table: {e}")
                raise

        # Write modified NWB to new file; the source IO must still be open for the export
        with hdmf_zarr.NWBZarrIO(new_path, mode='w') as export_io:
            export_io.export(src_io=io, nwbfile=nwbfile)

    print(f"Saved to: {new_path}")

def process_single_session(session_name: str, base_path: str, output_path: str) -> None:
    """
    Process a single session through the complete pipeline.

    Steps: extract the analysis, stimulus and epoch tables from the session folder, add the
    injection epoch, locate the session's original NWB file and write a modified copy with
    the same file name into `output_path`.

    The NWB file is searched for in sub-directories of the session folder whose name contains
    the session name with dashes removed (e.g. '2024-05-16_714789' -> '20240516_714789').
    The first match in sorted order is used. The function logs a message and returns without
    writing anything when the session folder, a matching directory or an NWB file is missing.

    Args:
        session_name: Name of the session to process
        base_path: Base path containing session folders
        output_path: Output directory for modified NWB files
    """
    session_folder = os.path.join(base_path, session_name)
    print(f"\nProcessing session: {session_name}")

    if not os.path.exists(session_folder):
        print(f"Session folder {session_folder} does not exist")
        return

    # Extract all tables
    analysis_table = extract_analysis_table(session_folder)
    stim_table = extract_stimulus_table(session_folder)
    epoch_table = create_epoch_table(session_folder)

    # Add injection epoch if epoch table exists
    if epoch_table is not None:
        epoch_table = add_injection_epoch(epoch_table, session_folder)

    # Strip dashes from the WHOLE session name (indexing [0] would keep only its first character)
    cleaned_session_name = session_name.replace('-', '')

    # Within the session folder, search for directories that contain the cleaned session name
    session_dirs = sorted(
        path for path in glob.glob(os.path.join(session_folder, f"*{cleaned_session_name}*"))
        if os.path.isdir(path)
    )
    if not session_dirs:
        print(f"No directory containing '{cleaned_session_name}' found in {session_folder}")
        return

    # Collect full paths of everything ending in .nwb inside those directories
    nwb_files = []
    for session_dir in session_dirs:
        nwb_files.extend(sorted(glob.glob(os.path.join(session_dir, "*.nwb"))))

    if not nwb_files:
        print(f"No NWB file containing '{session_name}' found in session directories")
        return

    original_nwb = nwb_files[0]  # Take the first match
    print(f"Found NWB file: {original_nwb}")

    # Derive the output name from the chosen input so both always refer to the same file
    new_nwb = os.path.join(output_path, os.path.basename(original_nwb))

    # Modify NWB file
    modify_nwb_file(
        original_path=original_nwb,
        new_path=new_nwb,
        analysis_table=analysis_table,
        stim_table=stim_table,
        epoch_table=epoch_table
    )

In [None]:
# Example usage: 
filtered_session_list = get_sessions_of_interest(
    summary_path = "/Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/NPUltra_recording_summary.xlsx",
    experiment_filter = "NPUltra_psychedelics", 
    upload_filter = 'yes'
    )

# Process one session as an example; index 2 requires at least three matching sessions.
# base_path must not start with a space: os.path.join would build a path beginning with
# ' /Volumes', and the session folder would not be found.
process_single_session(
    session_name = filtered_session_list[2],
    base_path = '/Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/',
    output_path = '/Volumes/scratch/suyee.lee'
)

Found 12 sessions matching criteria

Processing session: 2024-05-16_714789
Loaded units table from /Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/2024-05-16_714789/processed_data/units/2024-05-16_714789_1_units_epoch.pkl
Loaded stimulus table from /Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/2024-05-16_714789/processed_data/stim/opto_stim_df.csv


  combined_df = pd.concat(all_epoch_data, ignore_index=True)
  epoch_summary = combined_df.groupby('epoch_label').apply(


Added Saline_Injection epoch from 2719.7499900000003 to 2781.6479499999996
Found NWB file: /Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/2024-05-16_714789/20240516_714789_1/ecephys_714789_2024-05-16_13-16-59_experiment1_recording1.nwb


  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."


Processing analysis table...
Created DynamicTable 'analysis_table' with 910 rows and 54 columns
Added analysis table with 910 rows
Processing stimulus table...
Created DynamicTable 'stimulus_table' with 600 rows and 6 columns
Added stimulus table with 600 rows
Processing epoch table...
Added 13 epochs


  _init_array_metadata(


Saved to: /Volumes/scratch/suyee.lee/ecephys_714789_2024-05-16_13-16-59_experiment1_recording1.nwb


In [None]:
# Check outputs of new nwb file 
# new_nwb is a local variable inside process_single_session, so the path of the file to
# inspect has to be set here for this cell to run on its own.
new_nwb = '/Volumes/scratch/suyee.lee/ecephys_714789_2024-05-16_13-16-59_experiment1_recording1.nwb'

with hdmf_zarr.NWBZarrIO(new_nwb, mode='r') as io:
    new_nwbfile = io.read()
    print("Loaded new NWB file:", new_nwb)
    print("Processing modules:", list(new_nwbfile.processing.keys()))
    print("Stimulus presentations and epochs:", list(new_nwbfile.stimulus.keys()) if hasattr(new_nwbfile, 'stimulus') else "None")
    print("Epochs:", new_nwbfile.epochs)  # Should show the added epochs
    print("Trials:", new_nwbfile.trials)  # The pipeline above adds no trials, so None is expected


Loaded new NWB file: /Volumes/scratch/suyee.lee/ecephys_714789_2024-05-16_13-16-59_experiment1_recording1.nwb
Processing modules: ['ecephys']
Stimulus presentations and epochs: ['stimulus_table']
Epochs: epochs pynwb.epoch.TimeIntervals at 0x17964502048
Fields:
  colnames: ['start_time' 'stop_time' 'tags']
  columns: (
    start_time <class 'hdmf.common.table.VectorData'>,
    stop_time <class 'hdmf.common.table.VectorData'>,
    tags_index <class 'hdmf.common.table.VectorIndex'>,
    tags <class 'hdmf.common.table.VectorData'>
  )
  description: experimental epochs
  id: id <class 'hdmf.common.table.ElementIdentifiers'>

Trials: None
