In [1]:
import pandas as pd
import os
import json
import logging
from pathlib import Path

# Module-wide logging: the root logger emits INFO and above, and every
# function below reports through this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# Core Loading Functions
def load_units_table(session_folder):
    """Load the units table pickle from ``<session_folder>/processed_data/units``.

    The units table is the file whose name contains ``'units_epoch'`` but not
    ``'units_epochs'``. If several files match, the alphabetically first one is
    used and a warning is logged.

    Args:
        session_folder: Path to a single session directory.

    Returns:
        The unpickled object (presumably a ``pandas.DataFrame`` — downstream
        code calls ``.copy()``/``.columns`` on it), or ``None`` if the folder or
        file is missing or the pickle cannot be read.
    """
    units_table_path = os.path.join(session_folder, "processed_data", "units")

    if not os.path.exists(units_table_path):
        logger.warning(f"Units folder does not exist in {session_folder}")
        return None

    # Sort so the choice is deterministic (os.listdir order is arbitrary) and
    # skip hidden files such as macOS "._*" AppleDouble companions, which carry
    # the real file's name but are not pickles.
    units_files = sorted(
        f for f in os.listdir(units_table_path)
        if not f.startswith('.')
        and 'units_epoch' in f
        and 'units_epochs' not in f
    )

    if not units_files:
        logger.warning(f"No units table files found in {units_table_path}")
        return None
    if len(units_files) > 1:
        logger.warning(
            f"Multiple units table files found in {units_table_path}: "
            f"{units_files}; using {units_files[0]}"
        )

    units_table_file = os.path.join(units_table_path, units_files[0])
    try:
        # NOTE(review): read_pickle can execute arbitrary code while loading;
        # only point this at trusted storage.
        units_table = pd.read_pickle(units_table_file)
    except Exception as e:
        logger.error(f"Error loading units table from {units_table_file}: {e}")
        return None

    logger.info(f"Loaded units table from {units_table_file}")
    return units_table

def load_stimulus_table(session_folder):
    """Build the stimulus table from the ``opto_stim_df`` CSV and session metadata.

    Reads ``<session_folder>/processed_data/stim/*opto_stim_df*.csv`` (the
    alphabetically first match, ignoring hidden files) and renames its columns:
    ``stim_on`` -> ``start_time``, ``stim_off`` -> ``stop_time``,
    ``epoch_label`` -> ``stim_name``, ``probe`` -> ``emission_location``.
    ``wavelength`` and ``power`` are filled with the single values returned by
    ``load_stimulus_metadata`` (``None`` if unavailable).

    Args:
        session_folder: Path to a single session directory.

    Returns:
        The stimulus ``DataFrame``, or ``None`` if the folder/file is missing or
        the table cannot be built.
    """
    stim_path = os.path.join(session_folder, "processed_data", "stim")

    if not os.path.exists(stim_path):
        logger.warning(f"Stimulus folder does not exist in {session_folder}")
        return None

    # The file is parsed with read_csv, so only consider CSVs; sort for a
    # deterministic pick and skip hidden (e.g. macOS "._*") companion files.
    opto_files = sorted(
        f for f in os.listdir(stim_path)
        if 'opto_stim_df' in f and f.endswith('.csv') and not f.startswith('.')
    )

    if not opto_files:
        logger.warning(f"No opto_stim_df files found in {stim_path}")
        return None

    stim_table_path = os.path.join(stim_path, opto_files[0])
    try:
        opto_stim_df = pd.read_csv(stim_table_path)
        logger.info(f"Loaded stimulus table from {stim_table_path}")

        stim_table = pd.DataFrame({
            'start_time': opto_stim_df['stim_on'],
            'stop_time': opto_stim_df['stim_off'],
            'stim_name': opto_stim_df['epoch_label'],
            'emission_location': opto_stim_df['probe'],
        })

        # Session-level constants broadcast to every row
        wavelength, power = load_stimulus_metadata(session_folder)
        stim_table['wavelength'] = wavelength
        stim_table['power'] = power

        return stim_table

    except Exception as e:
        logger.error(f"Error creating stimulus table: {e}")
        return None

def load_stimulus_metadata(session_folder):
    """Extract laser wavelength and power from ``metadata/session.json``.

    Scans ``stimulus_epochs`` in order and returns the ``wavelength`` and
    ``laser_power`` of the first light-source config found. ``light_source_config``
    may be a single dict or a list of dicts (possibly empty) depending on the
    metadata schema version; both forms are accepted, and empty or missing
    configs are skipped instead of aborting the search.

    Args:
        session_folder: Path to a single session directory.

    Returns:
        ``(wavelength, power)``. Either value may be ``None`` if its key is absent,
        and ``(None, None)`` is returned if the file is missing or unreadable or
        no config is found.
    """
    metadata_path = os.path.join(session_folder, "metadata", "session.json")

    if not os.path.exists(metadata_path):
        logger.warning(f"Metadata file does not exist: {metadata_path}")
        return None, None

    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logger.error(f"Error loading stimulus metadata: {e}")
        return None, None

    epochs = metadata.get('stimulus_epochs') if isinstance(metadata, dict) else None
    for epoch in epochs or []:
        if not isinstance(epoch, dict):
            continue
        configs = epoch.get('light_source_config')
        if isinstance(configs, dict):
            configs = [configs]
        for light_config in configs or []:
            if isinstance(light_config, dict):
                return light_config.get('wavelength'), light_config.get('laser_power')

    return None, None

def load_epoch_table(session_folder):
    """Build an epoch table from the stimulus CSVs in ``processed_data/stim``.

    Every visible ``*.csv`` in that folder is read. A file must have
    ``stim_on`` and ``epoch_label`` columns; ``stim_off`` is optional. Files
    that cannot be read or lack the required columns are skipped with a
    warning rather than discarding all epochs. For each distinct
    ``epoch_label`` one row is produced:

    * ``start_time``: earliest ``stim_on`` of the epoch;
    * ``stop_time``: latest ``stim_off`` if the label contains
      ``'OptoTagging'``, otherwise latest ``stim_on``.

    Rows are sorted by ``start_time`` and then passed to
    ``add_injection_epochs``.

    Args:
        session_folder: Path to a single session directory.

    Returns:
        ``DataFrame`` with columns ``stim_name``, ``start_time``, ``stop_time``,
        or ``None`` if the stim folder is missing or building the table fails.
    """
    stim_path = os.path.join(session_folder, "processed_data", "stim")

    if not os.path.exists(stim_path):
        logger.warning(f"Stimulus folder does not exist in {session_folder}")
        return None

    wanted_columns = {'stim_on', 'stim_off', 'epoch_label'}
    required_columns = {'stim_on', 'epoch_label'}

    try:
        rows = []

        # Sorted for reproducible processing; hidden files (e.g. macOS "._*"
        # AppleDouble companions) are binary and would break read_csv.
        stim_files = sorted(
            f for f in os.listdir(stim_path)
            if f.endswith('.csv') and not f.startswith('.')
        )
        for stim_file in stim_files:
            file_path = os.path.join(stim_path, stim_file)

            try:
                temp_df = pd.read_csv(file_path, usecols=lambda col: col in wanted_columns)
            except (OSError, ValueError) as e:  # parser/decode/empty-file errors are ValueErrors
                logger.warning(f"Skipping unreadable stimulus file {file_path}: {e}")
                continue

            if not required_columns.issubset(temp_df.columns):
                logger.warning(f"Skipping {file_path}: missing 'stim_on' or 'epoch_label' column")
                continue
            if 'stim_off' not in temp_df.columns:
                temp_df['stim_off'] = pd.NA

            # groupby(sort=False) keeps first-appearance order (like unique())
            # and drops rows whose epoch_label is NaN.
            for epoch_label, epoch_data in temp_df.groupby('epoch_label', sort=False):
                # str() guards against non-string labels
                if 'OptoTagging' in str(epoch_label):
                    stop_time = epoch_data['stim_off'].max()
                else:
                    stop_time = epoch_data['stim_on'].max()

                rows.append({
                    'stim_name': epoch_label,
                    'start_time': epoch_data['stim_on'].min(),
                    'stop_time': stop_time,
                })

        epoch_table = pd.DataFrame(rows, columns=['stim_name', 'start_time', 'stop_time'])
        epoch_table = epoch_table.sort_values(by='start_time', kind='stable').reset_index(drop=True)

        epoch_table = add_injection_epochs(epoch_table, session_folder)

        logger.info(f"Created epoch table with {len(epoch_table)} epochs")
        return epoch_table

    except Exception as e:
        logger.error(f"Error creating epoch table: {e}")
        return None

def add_injection_epochs(epoch_table, session_folder, padding=0.01):
    """Insert an injection epoch between ``OptoTagging_0`` and ``Spontaneous_1``.

    The injection type comes from ``metadata/session.json``: the first
    ``stimulus_epochs`` entry with ``stimulus_name == 'OptoTagging'`` whose
    ``notes`` mention ``'Saline injection'`` or ``'Psilocybin injection'``.
    The new row is named ``'Saline_Inj'`` (if the notes contain ``'Saline'``)
    or ``'Psilocybin_Inj'``. It runs from ``padding`` after ``OptoTagging_0``'s
    ``stop_time`` to ``padding`` before ``Spontaneous_1``'s ``start_time``, and is
    placed in the row right after ``OptoTagging_0``. If a label occurs more than
    once, its first occurrence is used.

    Args:
        epoch_table: ``DataFrame`` with ``stim_name``, ``start_time`` and
            ``stop_time`` columns.
        session_folder: Path to a single session directory.
        padding: Gap left on each side of the injection epoch, in the table's
            time units (default ``0.01``).

    Returns:
        A new ``DataFrame`` with the injection row inserted. The input table is
        returned unchanged if the metadata is missing, there is no injection
        note, either anchor epoch is absent, or an error occurs.
    """
    metadata_path = os.path.join(session_folder, "metadata", "session.json")

    if not os.path.exists(metadata_path):
        return epoch_table

    try:
        with open(metadata_path, 'r', encoding='utf-8') as f:
            metadata = json.load(f)

        injection_name = None
        for epoch in metadata.get('stimulus_epochs') or []:
            if epoch.get('stimulus_name') != 'OptoTagging':
                continue
            # 'notes' may be present but null in the JSON
            notes = epoch.get('notes') or ''
            if 'Saline injection' in notes or 'Psilocybin injection' in notes:
                injection_name = 'Saline_Inj' if 'Saline' in notes else 'Psilocybin_Inj'
                break  # Only the first OptoTagging epoch with an injection note is used

        if injection_name is None:
            return epoch_table

        names = list(epoch_table['stim_name'])
        if 'OptoTagging_0' not in names or 'Spontaneous_1' not in names:
            return epoch_table

        # Positional lookups, so this does not depend on the index being a RangeIndex.
        opto_pos = names.index('OptoTagging_0')
        spont_pos = names.index('Spontaneous_1')
        injection_start = epoch_table['stop_time'].iloc[opto_pos] + padding
        injection_end = epoch_table['start_time'].iloc[spont_pos] - padding

        new_row = pd.DataFrame({
            'stim_name': [injection_name],
            'start_time': [injection_start],
            'stop_time': [injection_end],
        })

        insert_pos = opto_pos + 1
        epoch_table = pd.concat([
            epoch_table.iloc[:insert_pos],
            new_row,
            epoch_table.iloc[insert_pos:],
        ], ignore_index=True)

        logger.info(f"Added {injection_name} epoch from {injection_start} to {injection_end}")

    except Exception as e:
        logger.error(f"Error adding injection epochs: {e}")

    return epoch_table

# Table Processing Functions
def prepare_analysis_table(units_table):
    """Return a copy of *units_table* without its ``probe_config`` column.

    ``None`` is passed straight through. The caller's frame is never modified.
    """
    if units_table is None:
        return None

    analysis_table = units_table.copy()

    # Strip probe_config from the copy only
    if 'probe_config' in analysis_table:
        del analysis_table['probe_config']
        logger.info("Removed probe_config column from analysis table")

    return analysis_table

def prepare_stimulus_table(stimulus_table):
    """Return an independent copy of *stimulus_table*, or ``None`` if it is ``None``."""
    if stimulus_table is None:
        return None
    return stimulus_table.copy()

def prepare_epoch_table(epoch_table):
    """Return an independent copy of *epoch_table*, or ``None`` if it is ``None``."""
    if epoch_table is None:
        return None
    return epoch_table.copy()

# Session Processing Functions
def process_session_tables(session, base_path):
    """Load and prepare the units, stimulus and epoch tables of one session.

    Args:
        session: Session folder name, joined onto *base_path*.
        base_path: Directory containing the session folders.

    Returns:
        A dict with keys ``'session'``, ``'analysis_table'``,
        ``'stimulus_table'`` and ``'epoch_table'`` (each table may individually
        be ``None``), or ``None`` if the session folder does not exist.
    """
    session_folder = os.path.join(base_path, session)

    if not os.path.exists(session_folder):
        logger.error(f"Session folder {session_folder} does not exist")
        return None

    # Load everything first, then prepare, so log messages appear in that order
    raw_units = load_units_table(session_folder)
    raw_stimuli = load_stimulus_table(session_folder)
    raw_epochs = load_epoch_table(session_folder)

    session_data = dict(
        session=session,
        analysis_table=prepare_analysis_table(raw_units),
        stimulus_table=prepare_stimulus_table(raw_stimuli),
        epoch_table=prepare_epoch_table(raw_epochs),
    )

    logger.info(f"Processed tables for session: {session}")
    return session_data

def process_multiple_sessions(session_list, base_path):
    """Run ``process_session_tables`` for each session and collect the successes.

    Args:
        session_list: Session folder names, processed in order.
        base_path: Directory containing the session folders.

    Returns:
        List of per-session dicts, in input order. Sessions that fail are
        logged as warnings and omitted.
    """
    processed = []

    for session in session_list:
        logger.info(f"Processing session: {session}")
        session_data = process_session_tables(session, base_path)

        if not session_data:
            logger.warning(f"Failed to process session: {session}")
            continue
        processed.append(session_data)

    logger.info(f"Successfully processed {len(processed)} out of {len(session_list)} sessions")
    return processed

# Example usage with your existing data
def run_analysis_pipeline(filtered_sessions, base_path="/Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/"):
    """Process every session listed in the ``session`` column of *filtered_sessions*.

    Args:
        filtered_sessions: ``DataFrame`` with a ``'session'`` column, e.g. the
            filtered recording summary.
        base_path: Directory containing the session folders.

    Returns:
        The list produced by ``process_multiple_sessions``.
    """
    sessions = filtered_sessions['session'].tolist()
    return process_multiple_sessions(session_list=sessions, base_path=base_path)

In [4]:
# Workflow: load the recording summary, select the sessions of interest, and
# build the analysis/stimulus/epoch tables for them.
base_path = "/Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/"

# Recording-session summary spreadsheet (stored in base_path)
recording_summary = os.path.join(base_path, "NPUltra_recording_summary.xlsx")
recording_summary_table = pd.read_excel(recording_summary)

# Keep sessions with experiment == NPUltra_psychedelics and uploaded to CO == yes
filtered_sessions = recording_summary_table[
    (recording_summary_table['experiment'] == 'NPUltra_psychedelics')
    & (recording_summary_table['uploaded to CO'] == 'yes')
]
# A bare `filtered_sessions` expression here would not be displayed (only a
# cell's last expression is), so report the selection explicitly.
print(f"{len(filtered_sessions)} sessions selected")

# Set to True to process every selected session instead of only the first one.
PROCESS_ALL_SESSIONS = False
session_subset = filtered_sessions['session'].tolist()[:1]
if PROCESS_ALL_SESSIONS:
    results = run_analysis_pipeline(filtered_sessions, base_path)
else:
    results = process_multiple_sessions(session_subset, base_path)

# Report the shape of each table that was successfully loaded
table_labels = (
    ('analysis_table', 'Analysis'),
    ('stimulus_table', 'Stimulus'),
    ('epoch_table', 'Epoch'),
)
for session_data in results:
    print(f"Session: {session_data['session']}")
    for key, label in table_labels:
        table = session_data[key]
        if table is not None:
            print(f"{label} table shape: {table.shape}")

INFO:__main__:Processing session: 2024-05-14_714527
INFO:__main__:Loaded units table from /Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/2024-05-14_714527/processed_data/units/2024-05-14_714527_1_units_epoch.pkl
INFO:__main__:Loaded stimulus table from /Volumes/scratch/andrew.shelton/NPUltra_data/raw_npultra_data/2024-05-14_714527/processed_data/stim/opto_stim_df.csv
INFO:__main__:Added Saline_Inj epoch from 2714.53929 to 3058.6834599999997
INFO:__main__:Created epoch table with 13 epochs
INFO:__main__:Removed probe_config column from analysis table
INFO:__main__:Processed tables for session: 2024-05-14_714527
INFO:__main__:Successfully processed 1 out of 1 sessions


Session: 2024-05-14_714527
Analysis table shape: (350, 53)
Stimulus table shape: (600, 6)
Epoch table shape: (13, 3)
