In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.ndimage import label, binary_dilation, binary_erosion
from scipy.ndimage import find_objects
import os.path as op
import mne
import random

In [None]:
# Per-patient path to the CSV file with the SWR detection results.
# The file is read with pandas; its 'Channel' and 'rippleTime' columns are used.
detected_SWRs = {
    "PatientX": "file path",
}

In [None]:
# Per-patient list of hippocampal contact names.
# The names are matched against the 'Channel' column of the detection results
# and are also used as channel picks when the epochs are built.
hippocampus_electrodes = {
    "PatientX": ["EEG001"],
}

In [None]:
# Per-patient recording duration.
# It is multiplied by eeg_fs to obtain the number of samples in the recording,
# i.e. it is expressed in seconds when eeg_fs is in samples per second.
recording_durations = {
    "PatientX": 100,
}

In [None]:
# Per-patient path to the raw recording (read with mne.io.read_raw_fif).
raw_files = {
    "PatientX": "file path",
}

In [None]:
# Analysis parameters
eeg_fs = 1000       # EEG sampling frequency (samples per second)
min_gap = 50        # minimum gap between included SWR events, in ms
no_of_epochs = 30   # desired number of epochs for each event type

# Directory in which the epoch files are saved
epochs_save_path = ""

In [None]:

def _flatten_swr_times(times):
    """Return every SWR time stamp of one channel as a flat 1-D float array.

    Handles the values that evaluating a 'rippleTime' cell may yield: ``None``,
    a scalar, a NumPy array, or a (possibly empty or ragged) sequence of
    scalars/arrays. An input without any time stamp gives an empty array.
    """
    if times is None:
        return np.array([], dtype=float)
    if isinstance(times, np.ndarray) or np.isscalar(times):
        return np.asarray(times, dtype=float).ravel()
    pieces = [np.ravel(np.asarray(t, dtype=float)) for t in times]
    if not pieces:
        return np.array([], dtype=float)
    return np.concatenate(pieces)


# Width (in samples) of the flat structuring element used for the
# dilation + erosion that merges neighbouring samples into one region.
min_duration_samples = 10
structure_element = np.ones(min_duration_samples)

# min_gap is expressed in ms while events are compared in samples; convert once
# so the constraint stays correct for sampling rates other than 1000 Hz.
min_gap_samples = int(round(min_gap * eeg_fs / 1000))

# Explicit column list so the regions table keeps its columns even when no SWR
# region is found (an empty column-less frame cannot be sorted by column).
region_columns = ["start_sample", "end_sample", "peak_sample",
                  "peak_time_sec", "peak_value", "duration_ms"]

for patient_id, electrodes in hippocampus_electrodes.items():
    # --- load the detections of the hippocampal contacts ---------------------
    df = pd.read_csv(detected_SWRs[patient_id])
    df = df[df['Channel'].isin(electrodes)]
    if df.empty:
        print(f"No hippocampal channels found for {patient_id}")
        continue
    # SECURITY: eval() executes whatever expression the CSV cell contains;
    # only use detection files that come from a trusted source.
    SWR_times = (df['rippleTime'].apply(lambda x: eval(x, {'array': np.array}))).tolist()

    duration = recording_durations[patient_id]
    eeg_no_samples = int(duration * eeg_fs)
    no_channels = len(df)

    # Binary array (rows of df x samples): 1 marks a sample at which an SWR
    # time stamp was detected on that row's channel.
    SWR_array = np.zeros((no_channels, eeg_no_samples), dtype=int)
    for channel, times in enumerate(SWR_times):
        channel_SWR_times = _flatten_swr_times(times)
        channel_SWR_times = channel_SWR_times[np.isfinite(channel_SWR_times)]
        if channel_SWR_times.size == 0:
            print(f"Skipping empty SWR_times for channel index {channel} in patient {patient_id}")
            continue
        channel_SWR_idx = np.round(channel_SWR_times * eeg_fs).astype(int)
        # Drop indices outside the recording: too-large ones would raise an
        # IndexError and negative ones would silently wrap around.
        in_range = (channel_SWR_idx >= 0) & (channel_SWR_idx < eeg_no_samples)
        if not in_range.all():
            print(f"{patient_id}: Dropping {np.count_nonzero(~in_range)} SWR time points "
                  f"outside the recording for channel index {channel}.")
        SWR_array[channel, channel_SWR_idx[in_range]] = 1

    # Number of contacts with a detection at each sample.
    SWR_sum = np.sum(SWR_array, axis=0)

    # --- regions with SWRs ----------------------------------------------------
    valid_regions = (SWR_sum >= 1)
    # Dilate then erode to merge neighbouring regions.
    valid_regions = binary_dilation(valid_regions, structure=structure_element)
    valid_regions = binary_erosion(valid_regions, structure=structure_element)
    # Label the connected regions and get one slice per region.
    labeled_regions, num_regions = label(valid_regions)
    slices = find_objects(labeled_regions)

    # Per region: extent, and the sample where most contacts detect an SWR.
    region_peaks = []
    for region in slices:
        start = region[0].start
        end = region[0].stop
        region_signal = SWR_sum[start:end]
        peak_relative = np.argmax(region_signal)
        peak_sample = start + peak_relative
        peak_time_sec = peak_sample / eeg_fs
        peak_value = SWR_sum[peak_sample]

        region_peaks.append({
            "start_sample": start,
            "end_sample": end,
            "peak_sample": peak_sample,
            "peak_time_sec": peak_time_sec,
            "peak_value": peak_value,
            "duration_ms": (end - start) / eeg_fs * 1000
        })

    print(f"{patient_id}: Found {len(region_peaks)} SWR events.")
    df_regions = pd.DataFrame(region_peaks, columns=region_columns)

    # --- regions without SWRs -------------------------------------------------
    valid_no_regions = (SWR_sum == 0)
    # Same dilation + erosion as for the SWR regions.
    valid_no_regions = binary_dilation(valid_no_regions, structure=structure_element)
    valid_no_regions = binary_erosion(valid_no_regions, structure=structure_element)

    labeled_low_regions, num_low_regions = label(valid_no_regions)
    low_slices = find_objects(labeled_low_regions)

    low_regions = []
    for low_region in low_slices:
        start = low_region[0].start
        end = low_region[0].stop
        duration_ms = (end - start) / eeg_fs * 1000
        midpoint_sample = (start + end) // 2

        low_regions.append({
            "start_sample": start,
            "end_sample": end,
            "midpoint_sample": midpoint_sample,
            "duration_ms": duration_ms})

    print(f"{patient_id}: Found {len(low_regions)} no SWR events.")

    # --- events for SWRs (code 1) ---------------------------------------------
    if len(df_regions) < no_of_epochs:
        print(f"{patient_id}: Less than {no_of_epochs} SWRs regions detected. Using all {len(df_regions)} available regions.")
    else:
        # Keep the regions in which the most contacts agree.
        df_regions = df_regions.sort_values("peak_value", ascending=False).head(no_of_epochs)

    df_regions = df_regions.sort_values("peak_sample")  # chronological order
    events = []
    last_event = -min_gap_samples  # lets an event at sample 0 pass the gap test
    used_swr_regions = 0
    for row in df_regions.itertuples():
        peak_sample = row.peak_sample
        if peak_sample >= last_event + min_gap_samples:
            events.append([peak_sample, 0, 1])
            last_event = peak_sample
            used_swr_regions += 1

    if used_swr_regions < no_of_epochs:
        print(f"{patient_id}: Only {used_swr_regions} SWR epochs could be used after applying min_gap constraint.")

    # --- events for no SWRs (code 2) ------------------------------------------
    if len(low_regions) < no_of_epochs:
        print(f"{patient_id}: Less than {no_of_epochs} low-SWR regions detected. Using all {len(low_regions)} available regions.")
        random_low_regions = low_regions
    else:
        # NOTE: the draw is unseeded; call random.seed(...) beforehand if the
        # selection has to be reproducible across runs.
        random_low_regions = random.sample(low_regions, no_of_epochs)
        random_low_regions = sorted(random_low_regions, key=lambda x: x["midpoint_sample"])

    for region in random_low_regions:
        midpoint_sample = region["midpoint_sample"]
        events.append([midpoint_sample, 0, 2])

    events = np.array(events, dtype=int)
    if len(events) == 0:
        print(f"{patient_id}: No valid events found for epoching.")
        continue

    # --- raw data (loaded only once we know there is something to epoch) ------
    raw = mne.io.read_raw_fif(raw_files[patient_id], preload=True)
    # Notch filter (in place) at the listed frequencies.
    notch_freqs = [50, 100, 150]
    raw = raw.notch_filter(freqs=notch_freqs, notch_widths=4)
    # Re-wrap the filtered data in a RawArray; no intermediate deep copy of the
    # recording is needed for that.
    data = raw.get_data()
    info = raw.info
    data_elect = mne.io.RawArray(data, info)

    # --- MNE epochs -----------------------------------------------------------
    # Set picks to hippocampus_electrodes[patient_id] for iEEG electrodes,
    # or to ['mag', 'grad'] for MEG sensors.
    epochs_kwargs = dict(tmin=-1, tmax=1, picks=hippocampus_electrodes[patient_id], preload=True)

    # mne.Epochs raises by default when a requested event code has no events,
    # so only request the codes that are actually present.
    has_swr = bool(np.any(events[:, 2] == 1))
    has_no_swr = bool(np.any(events[:, 2] == 2))

    high_epochs = None
    low_epochs = None

    if has_swr:
        high_epochs = mne.Epochs(data_elect, events, event_id=1, **epochs_kwargs)
        high_epoch_filename = op.join(epochs_save_path, f"{patient_id}_high_SWR_epochs.fif")
        high_epochs.save(high_epoch_filename, overwrite=True)
    else:
        print(f"{patient_id}: No SWR events available; skipping high SWR epochs.")

    if has_no_swr:
        low_epochs = mne.Epochs(data_elect, events, event_id=2, **epochs_kwargs)
        low_epoch_filename = op.join(epochs_save_path, f"{patient_id}_low_SWR_epochs.fif")
        low_epochs.save(low_epoch_filename, overwrite=True)
    else:
        print(f"{patient_id}: No no-SWR events available; skipping low SWR epochs.")

    # events is non-empty here, so at least one code is present.
    all_event_ids = [code for code, present in ((1, has_swr), (2, has_no_swr)) if present]
    all_epochs = mne.Epochs(data_elect, events, event_id=all_event_ids, **epochs_kwargs)
    all_epoch_filename = op.join(epochs_save_path, f"{patient_id}_all_SWR_epochs.fif")
    all_epochs.save(all_epoch_filename, overwrite=True)
