<!-- 1) Extract per-spike information from each pickle file -->
<!-- Fields recorded for every spike: 'Patient ID', 'EMU ID', 'Filename' (pickle filename), 'Spike Time (s)', 'Side' (spike side), 'YASA' (YASA stage) -->

In [1]:
import os
import re
import pickle
import numpy as np
import pandas as pd
from multiprocessing import Pool, Manager # Import Pool and Manager for multiprocessing

In [None]:


# Input locations
pkl_roots = [
    # Pickle data (SN2, SN2_zero_left, SN2_zero_right, YASA) for the 2017-2022 data
    '/mnt/sauce/littlab/users/jurikim/ied_yesno/colab/Jack/add_144/pickle_outputs/',
    # '/mnt/sauce/littlab/users/jurikim/ied_yesno/colab/Jack/add_144/pickle_outputs_2324_updates/' #pickle data(SN2, SN2_zero_left, SN2_zero_right, YASA) for the 2023-2024 data
]
# De-identified table from Erin
deid_csv = '/mnt/sauce/littlab/users/jurikim/ied_yesno/add_2023/deid_with_redcap_Apr3025.csv'
# Destination of the per-spike table
output_csv = '/mnt/sauce/littlab/users/jurikim/ied_yesno/colab/Jack/add_144/outputs/spike_summary_1.csv'

# Detection parameters
THRESHOLD = 0.43                    # a sample counts as a spike when SN2 exceeds this value
SAMPLES_PER_SEC = int(1 / 0.0625)   # 16 samples per second (0.0625 s per sample)
SKIP_FRONT = 8                      # samples trimmed from the start of every series
SKIP_BACK = 9                       # samples trimmed from the end of every series
NUM_PROCESSES = 16                  # number of worker processes in the pool

# Admission ID -> patient ID lookup built from the de-identified table.
# Admission IDs are cast to str so they can be matched against IDs parsed from filenames.
deid_df = pd.read_csv(deid_csv).assign(
    admission_id=lambda frame: frame['admission_id'].astype(str)
)
# The workers only read this mapping; nothing writes to it after this point.
ADMISSION_TO_PATIENT = {
    admission: patient
    for admission, patient in zip(deid_df['admission_id'], deid_df['patient_id'])
}

# Spike clustering: threshold crossings of SN2, one spike per window, with side assignment
def cluster_spikes_with_max(SN2, SN2_right, SN2_left, threshold, samples_per_sec=None):
    """Detect spikes in ``SN2`` and label each one as left or right.

    The series is scanned sample by sample. Whenever ``SN2[i]`` exceeds
    ``threshold`` a spike is recorded at ``i`` and the scan jumps ahead by one
    full window (``samples_per_sec`` samples), so at most one spike is
    reported per window. The side is chosen by comparing the maxima of
    ``SN2_right`` and ``SN2_left`` inside the window centred on the spike.

    Args:
        SN2: Sequence of scores that is thresholded.
        SN2_right: Sequence used as the right-side score, same indexing as ``SN2``.
        SN2_left: Sequence used as the left-side score, same indexing as ``SN2``.
        threshold: A sample is a spike when its ``SN2`` value is strictly greater.
        samples_per_sec: Window length in samples. Defaults to the module-level
            ``SAMPLES_PER_SEC`` when omitted, which keeps existing calls unchanged.

    Returns:
        List of ``(index, side)`` tuples in increasing index order, where
        ``side`` is ``'R'`` when the right maximum is strictly larger and
        ``'L'`` otherwise (ties go to ``'L'``).

    Raises:
        ValueError: If ``samples_per_sec`` is smaller than 2 (the window would
            be empty and the scan could not advance).
    """
    if samples_per_sec is None:
        samples_per_sec = SAMPLES_PER_SEC
    if samples_per_sec < 2:
        raise ValueError("samples_per_sec must be at least 2")

    half_window = samples_per_sec // 2
    n_samples = len(SN2)
    spike_info = []

    # Start half a window in and stop half a window before the end, so the
    # window [i - half_window, i + half_window) always lies inside the series.
    i = half_window
    while i + half_window < n_samples:
        if SN2[i] > threshold:
            win_start = i - half_window
            win_end = i + half_window
            max_left = np.max(SN2_left[win_start:win_end])
            max_right = np.max(SN2_right[win_start:win_end])
            side = 'R' if max_right > max_left else 'L'
            spike_info.append((i, side))
            # Skip a full window so one event is not counted more than once.
            i += samples_per_sec
        else:
            i += 1
    return spike_info

# Pool worker: turns one PKL file into a list of per-spike rows
def process_pkl_file(pkl_path):
    """Extract the spike rows contained in a single PKL file.

    Args:
        pkl_path: Path of the pickle file to read.

    Returns:
        A list with one dict per detected spike (possibly empty), or ``None``
        when the filename does not start with an EMU ID, when one of the
        required keys is missing from the pickle, or when any exception is
        raised while loading or processing the file.
    """
    filename = os.path.basename(pkl_path)

    # The EMU ID must be the leading token of the filename.
    emu_match = re.match(r'(EMU\d+)', filename)
    if emu_match is None:
        print(f"[SKIP] Cannot extract EMU ID from {filename}")
        return None
    emu_id = emu_match.group(1)

    # Admissions missing from the de-identified table are labelled 'UNKNOWN'.
    patient_id = ADMISSION_TO_PATIENT.get(emu_id, 'UNKNOWN')

    try:
        with open(pkl_path, 'rb') as handle:
            data = pickle.load(handle)

        required_keys = ('SN2', 'SN2_zero_left', 'SN2_zero_right', 'YASA')
        if any(key not in data for key in required_keys):
            return None

        # Drop SKIP_FRONT samples from the start and SKIP_BACK from the end of every series.
        trimmed = slice(SKIP_FRONT, -SKIP_BACK)
        SN2 = data['SN2'][trimmed]
        # NOTE(review): the "zero_left" series is used as the right-side score and
        # vice versa; presumably zeroing one side leaves the response of the other
        # side -- confirm against the code that writes these pickles.
        SN2_right = data['SN2_zero_left'][trimmed]
        SN2_left = data['SN2_zero_right'][trimmed]
        YASA = data['YASA'][trimmed]

        detected = cluster_spikes_with_max(SN2, SN2_right, SN2_left, THRESHOLD)

        rows = [
            {
                'Patient ID': patient_id,
                'EMU ID': emu_id,
                'Filename': filename,
                # Add SKIP_FRONT back so the time refers to the untrimmed series.
                'Spike Time (s)': round((idx + SKIP_FRONT) * 0.0625, 2),
                'Side': side,
                'YASA': YASA[idx] if idx < len(YASA) else np.nan,
            }
            for idx, side in detected
        ]

    except Exception as e:
        # Failures are dropped without a message on purpose, to avoid
        # excessive printing from the worker processes.
        return None

    return rows

# Ordering key for PKL paths: EMU ID, then day, clip number and start second
def extract_sort_key(path):
    """Return a tuple usable as a sort key for a PKL path.

    Only the basename is inspected. It is expected to end with
    ``EMU<digits>_[Day<digits>_]<clip>_<start>_<number>.pkl``, where ``<start>``
    and the trailing ``<number>`` are decimals; the trailing number is not
    part of the key.

    Returns:
        ``(emu_id, day, clip, start_sec)`` with ``day`` equal to 0 when the
        ``Day`` part is absent, or ``('ZZZ', 999, 999, inf)`` for names that
        do not match, which places them after every matching name.
    """
    parsed = re.search(
        r'(EMU\d+)_(?:Day(\d+)_)?(\d+)_(\d+\.\d+)_\d+\.\d+\.pkl$',
        os.path.basename(path),
    )
    if parsed is None:
        return ('ZZZ', 999, 999, float('inf'))

    emu_id, day_text, clip_text, start_text = parsed.groups()
    day = int(day_text) if day_text else 0
    return (emu_id, day, int(clip_text), float(start_text))


def number_spikes(rows):
    """Build the final per-spike table from the rows returned by the workers.

    Rows are ordered by 'EMU ID', then 'Filename', then 'Spike Time (s)', and a
    1-based 'Spike #' is assigned sequentially within each EMU ID. The
    'EMU ID' column is only needed for ordering and numbering and is dropped
    from the result.

    Args:
        rows: List of row dicts as produced by ``process_pkl_file``.

    Returns:
        DataFrame with exactly one row per input row. For empty input an empty
        DataFrame that still carries the output columns is returned, so the
        saved CSV keeps a header.
    """
    if not rows:
        return pd.DataFrame(
            columns=['Patient ID', 'Filename', 'Spike Time (s)', 'Side', 'YASA', 'Spike #']
        )

    table = pd.DataFrame(rows)
    # NOTE(review): 'Filename' is compared as a plain string here, so e.g.
    # 'Day10' sorts before 'Day2' -- confirm this is the intended numbering order.
    table = table.sort_values(by=['EMU ID', 'Filename', 'Spike Time (s)'])
    table = table.reset_index(drop=True)

    # cumcount() numbers the rows inside each EMU ID in their sorted order.
    # Numbering on the frame itself (rather than appending numbered copies to
    # the raw row list) keeps each spike in the output exactly once.
    table['Spike #'] = table.groupby('EMU ID').cumcount() + 1

    return table.drop(columns=['EMU ID'], errors='ignore')


# --- MAIN EXECUTION BLOCK ---
if __name__ == '__main__':
    # --- PKL File Path Collection ---
    all_pkl_paths = []
    print("🔍 Collecting PKL file paths from both root directories...")
    for pkl_root in pkl_roots:
        if not os.path.isdir(pkl_root):
            print(f"[WARNING] Directory does not exist. Skipping: {pkl_root}")
            continue

        for root, _, files in os.walk(pkl_root):
            all_pkl_paths.extend(
                os.path.join(root, file) for file in files if file.endswith('.pkl')
            )

    print(f" Collected a total of {len(all_pkl_paths)} PKL files.")

    # Submit the files in a deterministic order.
    all_pkl_paths.sort(key=extract_sort_key)

    # --- Multiprocessing Pool Execution ---
    print(f" Starting spike extraction using {NUM_PROCESSES} CPU cores...")

    with Pool(processes=NUM_PROCESSES) as pool:
        # pool.map returns results in the order the inputs were submitted.
        results = pool.map(process_pkl_file, all_pkl_paths)

    # --- Combine Results and Finalize Spike Numbering ---
    # Flatten the per-file lists; skipped or failed files come back as None.
    all_rows = []
    for result_list in results:
        if result_list:
            all_rows.extend(result_list)

    print(f"Gathered {len(all_rows)} raw spike entries.")

    df = number_spikes(all_rows)

    print(f"\n✨ Total number of extracted spike data points: {len(df)}")

    # Save to CSV
    try:
        df.to_csv(output_csv, index=False)
        print(f" Results successfully saved to '{output_csv}'.")
    except Exception as e:
        print(f"[ERROR] CSV save failed: {e}")

<!-- 2) Per-patient summary: spike counts, laterality (left/right), and wake/sleep (YASA) spike counts and durations -->

In [None]:
# Summarize spikes per patient and add recording durations to the output
import pandas as pd
import numpy as np
import os
import pickle

# Per-spike input table, per-patient output table, and the folder holding the pickles
input_csv = '/mnt/sauce/littlab/users/jurikim/ied_yesno/colab/Jack/add_144/outputs/spike_summary_1.csv'
output_csv = '/mnt/sauce/littlab/users/jurikim/ied_yesno/colab/Jack/add_144/outputs/spike_summary_2.csv'
pkl_base_dir = '/mnt/sauce/littlab/users/jurikim/ied_yesno/colab/Jack/add_144/pickle_outputs/'

# Read the per-spike table
df = pd.read_csv(input_csv)

# Recover the admission (EMU) ID from the filename and normalise column types
df['admission_id'] = df['Filename'].str.extract(r'(EMU\d+)')
df['Spike #'] = df['Spike #'].astype(int)
for text_column in ('Side', 'YASA'):
    # Missing labels become empty strings so they match neither side nor stage
    df[text_column] = df[text_column].fillna('')

# YASA labels counted as sleep; 'W' is counted as wake
sleep_stages = ['N1', 'N2', 'N3', 'R']

# One summary dict per patient
summary_rows = []

for patient_id, group in df.groupby('Patient ID'):
    admission_ids = sorted(group['admission_id'].unique())

    # Durations are accumulated over every pickle found in
    # <pkl_base_dir>/<admission_id>/ for each admission of this patient.
    total_duration = 0
    wake_duration = 0
    sleep_duration = 0

    for adm_id in admission_ids:
        pkl_dir = os.path.join(pkl_base_dir, adm_id)
        if not os.path.exists(pkl_dir):
            continue

        for pkl_name in os.listdir(pkl_dir):
            if not pkl_name.endswith('.pkl'):
                continue

            pkl_path = os.path.join(pkl_dir, pkl_name)
            try:
                with open(pkl_path, 'rb') as handle:
                    data = pickle.load(handle)

                yasa = np.array(data.get('YASA', []))
                time = np.array(data.get('Time', []))

                # The last 'Time' entry is taken as the length of the recording in seconds
                if len(time) > 0:
                    total_duration += time[-1]

                # Every YASA sample stands for 0.0625 s
                wake_duration += (yasa == 'W').sum() * 0.0625
                sleep_samples = sum((yasa == stage).sum() for stage in sleep_stages)
                sleep_duration += sleep_samples * 0.0625
            except Exception as err:
                print(f"Could not load {pkl_path}: {err}")

    # Spike counts come from the per-spike table; durations from the pickles
    summary_rows.append({
        'patient_id': patient_id,
        'admission_id': ','.join(admission_ids),
        'Total_spikes': len(group),
        'Left_spikes': (group['Side'] == 'L').sum(),
        'Right_spikes': (group['Side'] == 'R').sum(),
        'Wake_spikes': (group['YASA'] == 'W').sum(),
        'Sleep_spikes': group['YASA'].isin(sleep_stages).sum(),
        'Total_duration(sec)': total_duration,
        'Wake_duration(sec)': wake_duration,
        'Sleep_duration(sec)': sleep_duration
    })

# Assemble the per-patient table and write it out
summary_df = pd.DataFrame(summary_rows)
summary_df.to_csv(output_csv, index=False)
print(f"Spike summary saved to: {output_csv}")

