# In [None]:
import pynwb
from pathlib import Path
import pandas as pd
import glob
import sys

# Insert at the beginning of sys.path to prioritize this location
sys.path.insert(0, r'G...')

from spike_utils import get_good_units_with_structure, maxInterval, get_burst_times_dict, count_bursts_and_spikes_after_stim, add_metadata_to_df, add_network_burst_counts, get_session_network_bursts
from allensdk.brain_observatory.ecephys.dynamic_gating_ecephys_session import DynamicGatingEcephysSession

# Root folder: NWB files are looked up under <base_path>/data/<subject>/ and
# the analysis results folder is created directly beneath it.
base_path = Path('...')

# Master stimulus-trials table; filtered per session (session_id,
# no_reward_epoch) to pick the stimulus presentations that are analysed.
_stim_table_csv = '.../metadata_tables/master_stim_trials_table.csv'
stim_table = pd.read_csv(_stim_table_csv)

# Sessions to analyse. Each entry is "<subject>_<session>" and is split on "_"
# to build the NWB file-name pattern; the order here is the processing order.
sessions_info = [
    "607660_20220607", "607660_20220609",
    "626279_20220928", "633232_20221110",
    "638387_20221201", "615048_20220812",
    "615048_20220810", "607660_20220608",
    "626279_20220926", "633232_20221108",
]

# Single-unit burst-detection criteria; all durations are in seconds.
burst_params = dict(
    max_begin_ISI=0.004,        # 4 ms
    max_end_ISI=0.02,           # 20 ms
    min_IBI=0.1,                # 100 ms
    min_burst_duration=0.008,   # 8 ms
    min_spikes_in_burst=3,
    pre_burst_silence=0.1,      # 100 ms
)

# Network-burst settings handed to add_network_burst_counts.
OVERLAP_THRESHOLD = 0.02   # 2% of units must burst simultaneously
WINDOW_SIZE = 1            # sliding window spanning 1 bin
BIN_DURATION = 1.0         # bin width: 1 second

import traceback  # local import for this cell: only used to report per-session failures

# Brain structures whose good units are analysed. The value is the same for
# every session, so it is defined once outside the loop.
structures = ('LGd', 'VISp', 'VISl')

# One counts DataFrame per successfully processed session.
all_counts_df = []

# For each session: locate its NWB file, select change stimuli that fall in
# no-reward epochs, detect single-unit and network bursts, and collect the
# per-stimulus counts. Failures are reported and the session is skipped
# (best-effort batch run).
for session_info in sessions_info:
    try:
        # session_info is "<subject>_<session>"; files are expected at
        # <base_path>/data/sub-<subject>/sub-<subject>_ses-<session>*.nwb
        session_parts = session_info.split('_')
        subject_id = f"sub-{session_parts[0]}"
        session_id = session_parts[1]

        session_pattern = f"{subject_id}_ses-{session_id}*.nwb"
        nwb_path_pattern = str(base_path / 'data' / subject_id / session_pattern)
        print(nwb_path_pattern)

        # glob returns matches in arbitrary order; sort so the chosen file is
        # reproducible across runs and machines.
        matching_files = sorted(glob.glob(nwb_path_pattern))
        print(matching_files)
        if not matching_files:
            print(f"No NWB file found for session {session_info}, skipping...")
            continue
        if len(matching_files) > 1:
            print(
                f"Warning: {len(matching_files)} NWB files match session "
                f"{session_info}; using {matching_files[0]}"
            )

        # glob already returns the full path (the pattern was built from
        # base_path). Joining it onto base_path again would duplicate the
        # directory prefix whenever base_path is relative, so use it as-is.
        session = matching_files[0]
        print(session)
        nwb_path = Path(session)

        print(f"Processing {subject_id} - {session}")

        # Open the NWB file read-only and wrap it in the AllenSDK session object;
        # all per-session work happens while the file is open.
        with pynwb.NWBHDF5IO(str(nwb_path), mode='r', load_namespaces=True) as nwb_file_asset:
            nwb_file = nwb_file_asset.read()
            dynamic_gating_session = DynamicGatingEcephysSession.from_nwb(nwb_file)

            # ecephys session ID links this file to rows of the master stim table.
            session_ecephys_id = dynamic_gating_session.metadata['ecephys_session_id']
            print(f"Session ID: {session_ecephys_id}")

            # Rows of the master table for this session that lie in a no-reward epoch.
            filtered_stim_table_no_reward = stim_table[
                (stim_table['session_id'] == session_ecephys_id)
                & stim_table['no_reward_epoch'].eq(True)
            ]

            print(f"Found {len(filtered_stim_table_no_reward)} no-reward stimuli for this session")

            if filtered_stim_table_no_reward.empty:
                print(f"No no-reward stimuli found for session {session_info}, skipping...")
                continue

            stim_presentations = dynamic_gating_session.stimulus_presentations
            stimulus_names = stim_presentations['stimulus_name']

            # Sessions name the image stimulus differently: prefer the exact
            # 'dynamic_routing_image_set', else anything containing 'image-set'.
            if 'dynamic_routing_image_set' in stimulus_names.values:
                stims = stim_presentations[stimulus_names == 'dynamic_routing_image_set']
            else:
                image_set_mask = stimulus_names.str.contains('image-set', case=False)
                if not image_set_mask.any():
                    print(f"Warning: No recognized image set stimulus found in session {session_info}")
                    print(f"Available names: {stimulus_names.unique()}")
                    continue
                stims = stim_presentations[image_set_mask]

            # Keep only presentations whose index appears as a no-reward
            # stimulus_presentations_id in the master table.
            valid_stim_ids = filtered_stim_table_no_reward['stimulus_presentations_id'].values
            stims_filtered = stims[stims.index.isin(valid_stim_ids)]

            # Change times: presentations flagged both 'active' and 'is_change'.
            change_times_filtered = stims_filtered[
                stims_filtered['active'] & stims_filtered['is_change']
            ]['start_time'].values

            print(f"Using {len(change_times_filtered)} change stimuli from no-reward epochs")

            if len(change_times_filtered) == 0:
                print(f"No change times found in no-reward epochs for session {session_info}, skipping...")
                continue

            # Good units in the selected structures, their spike times, and
            # their structure labels.
            good_units, spike_times_dict, structure_dict = get_good_units_with_structure(
                dynamic_gating_session, structures)

            burst_times_dict = get_burst_times_dict(spike_times_dict, burst_params)

            # Spike and burst counts following each change time.
            counts_df = count_bursts_and_spikes_after_stim(
                burst_times_dict,
                spike_times_dict,
                structure_dict,
                change_times_filtered,
            )

            # Append network-burst counts to the same DataFrame.
            counts_df = add_network_burst_counts(
                counts_df,
                burst_times_dict,
                structure_dict,
                window_duration=1.0,
                overlap_threshold=OVERLAP_THRESHOLD,
                window_size=WINDOW_SIZE,
                bin_duration=BIN_DURATION,
            )

            counts_df = add_metadata_to_df(
                counts_df,
                dynamic_gating_session.metadata,
                session=session,
                subject_id=subject_id,
            )

            # Tag rows so they can be distinguished from other epoch/stimulus analyses.
            counts_df['epoch_type'] = 'no_reward'
            counts_df['stimulus_type'] = 'change'

            all_counts_df.append(counts_df)

    except Exception as e:
        # Best-effort batch: report the failure with its traceback, then
        # continue with the next session instead of aborting the whole run.
        print(f"Error processing {session_info}: {str(e)}")
        traceback.print_exc()
        continue

# Merge the per-session results into one table and write it as a CSV in a
# results folder under base_path; report when nothing was collected.
if not all_counts_df:
    print("No data was successfully processed!")
else:
    final_df = pd.concat(all_counts_df, ignore_index=True)

    output_path = base_path / 'analysis_results_no_reward_change_epochs'
    output_path.mkdir(exist_ok=True)
    results_csv = output_path / 'all_sessions_burst_analysis_no_reward_changes_natural.csv'
    final_df.to_csv(results_csv, index=False)

    n_processed = len(all_counts_df)
    n_requested = len(sessions_info)
    print("Analysis complete!")
    print(f"Total sessions processed: {n_processed} out of {n_requested}")
    print(f"Total rows in final DataFrame: {len(final_df)}")