In [None]:
#SWR analysis

import pandas as pd
import matplotlib.pyplot as plt
import glob
import re
import ast
import numpy as np
from itertools import combinations
import scipy.stats as stats

In [None]:
# Import the results of the SWR detection algorithm.
# List of paths to the per-patient detection CSV files; the loading cell derives each
# patient ID from the file name by removing the '_SWRs.csv' suffix.
file_paths = []

In [None]:
# Dictionary of recording durations in seconds, keyed by patient ID
recording_durations = {"PatientX" : 110}

# EEG sampling frequency, in samples per second (it is multiplied by the recording
# duration in seconds to obtain the number of samples)
eeg_fs = 1000

In [None]:
# Dictionary mapping each patient ID to the list of its hippocampal contacts;
# the names are matched against the 'Channel' column of the detection results
hippocampus_electrodes = {"PatientX" : ['EEG001']}

In [None]:
# Read every detection CSV into a DataFrame, keyed by the patient ID taken from the file name
data = {}
for file_path in file_paths:
    # keep only the part after the last '/' and drop the fixed '_SWRs.csv' suffix
    file_name = file_path.rsplit('/', 1)[-1]
    patient_id = file_name.replace('_SWRs.csv', '')
    data[patient_id] = pd.read_csv(file_path)

In [None]:
#function to compute the Jaccard index
def get_jaccard_index(SWR_array):
    """Return the mean pairwise Jaccard index between the rows of an event mask.

    Parameters
    ----------
    SWR_array : array-like, shape (n_channels, n_samples)
        One row per channel; non-zero entries mark the samples that belong to a
        detected event. The values are converted to boolean before comparison.

    Returns
    -------
    float
        Mean of ``intersection / union`` over every unordered pair of rows.
        A pair whose union is empty (no event on either row) contributes 0.
        With fewer than two rows there is no pair to compare and ``np.nan``
        is returned.
    """
    # Boolean mask so that & and | are logical operations for any numeric dtype
    # (bitwise operators raise on float input and miscount non-0/1 integers).
    event_mask = np.asarray(SWR_array).astype(bool)
    num_channels = event_mask.shape[0]

    # No channel pair exists: return NaN explicitly instead of relying on
    # np.mean([]), which yields the same value but emits a RuntimeWarning.
    if num_channels < 2:
        return np.nan

    jaccard_results = []
    for channel1, channel2 in combinations(range(num_channels), 2):
        SWR1 = event_mask[channel1, :]
        SWR2 = event_mask[channel2, :]

        intersection = np.sum(SWR1 & SWR2)
        union = np.sum(SWR1 | SWR2)

        # two silent channels share nothing: score the pair as 0 rather than 0/0
        jaccard_results.append(intersection / union if union > 0 else 0)

    return np.mean(jaccard_results)

In [None]:
# Calculate, over the hippocampal contacts of every patient:
#   - the SWR incidence rate (events per second) of each contact with at least one event,
#   - the duration and the peak amplitude of every detected ripple,
#   - the average pairwise Jaccard index of the per-contact SWR masks (one value per patient).
hippocampus_all_incidence_rates = []
hippocampus_all_durations = []
hippocampus_all_peak_amplitudes = []
hippocampus_all_jaccard_indices = []

# A decimal number with an optional exponent, so that values written in scientific
# notation (e.g. 5.0e-05) are read as one number instead of a mantissa and an exponent.
_NUMBER_PATTERN = re.compile(r'[-+]?(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?')
# A "dtype=float32"-style suffix; its digits must not be mistaken for an amplitude.
_DTYPE_PATTERN = re.compile(r'dtype\s*=\s*[A-Za-z_]\w*')


def _parse_peak_amplitudes(ripple_amp_raw):
    """Return the maximum value of every 'array(...)' chunk found in the raw rippleAmp string."""
    peaks = []
    for array_str in ripple_amp_raw.split('array'):
        array_str = _DTYPE_PATTERN.sub('', array_str)
        array_values = [float(val) for val in _NUMBER_PATTERN.findall(array_str)]
        # chunks without numbers (e.g. the leading '[') carry no ripple
        if array_values:
            peaks.append(max(array_values))
    return peaks


for patient_id, electrodes in hippocampus_electrodes.items():
    df = data[patient_id]
    df = df[df['Channel'].isin(electrodes)]
    if df.empty:
        print(f"No hippocampal channels found for {patient_id}")
        continue

    # NOTE(review): eval() executes whatever expression the CSV cell contains; this is
    # only acceptable for detection files produced by a trusted pipeline.
    SWR_times = (df['rippleTime'].apply(lambda x: eval(x, {'array': np.array}))).tolist()
    duration = recording_durations[patient_id]
    eeg_no_samples = int(duration * eeg_fs)
    no_channels = len(df)

    # Binary mask with one row per hippocampal contact and one column per EEG sample
    SWR_array = np.zeros((no_channels, eeg_no_samples), dtype=int)
    for channel, times in enumerate(SWR_times):
        if not times or all(len(t) == 0 for t in times):
            print(f"Skipping empty SWR_times for channel index {channel} in patient {patient_id}")
            continue

        channel_SWR_times = np.hstack(times)
        channel_SWR_idx = np.round((channel_SWR_times * eeg_fs), 0).astype(int)

        # A time stamp at or beyond the recording end would raise an IndexError and a
        # negative one would silently wrap around to the end of the row: drop both.
        in_range = (channel_SWR_idx >= 0) & (channel_SWR_idx < eeg_no_samples)
        if not in_range.all():
            print(f"Ignoring {int((~in_range).sum())} SWR sample(s) outside the recording "
                  f"for channel index {channel} in patient {patient_id}")
            channel_SWR_idx = channel_SWR_idx[in_range]
        SWR_array[channel, channel_SWR_idx] = 1

    jaccard_index = get_jaccard_index(SWR_array)
    hippocampus_all_jaccard_indices.append(jaccard_index)

    for index, row in df.iterrows():
        # contacts without any event contribute nothing, not even a zero rate
        if row['nEvents'] > 0:
            # accept both space-separated and comma-separated duration lists
            duration_list = [float(x) for x in row['durations'].strip('[]').replace(',', ' ').split()]
            peak_amplitudes = _parse_peak_amplitudes(row['rippleAmp'])

            hippocampus_all_durations.extend(duration_list)
            hippocampus_all_peak_amplitudes.extend(peak_amplitudes)
            hippocampus_all_incidence_rates.append(row['nEvents'] / duration)



In [None]:
# Median and standard deviation of every pooled hippocampal SWR measure

# patients whose Jaccard index is NaN are left out of the Jaccard statistics
jaccard = [value for value in hippocampus_all_jaccard_indices if not np.isnan(value)]

# SWR rate per contact
median_rate = np.median(hippocampus_all_incidence_rates)
std_rate = np.std(hippocampus_all_incidence_rates)

# ripple duration
median_duration = np.median(hippocampus_all_durations)
std_duration = np.std(hippocampus_all_durations)

# ripple peak amplitude
median_amp = np.median(hippocampus_all_peak_amplitudes)
std_amplitude = np.std(hippocampus_all_peak_amplitudes)

# Jaccard index per patient
median_jaccard = np.median(jaccard)
std_jaccard = np.std(jaccard)


In [None]:
# Plot histograms of the pooled hippocampal SWR measures as a 2 x 2 panel figure;
# the red dashed line in every panel marks the median.
fig, axes = plt.subplots(2, 2, figsize=(12, 8))


def _plot_percent_histogram(ax, values, panel_letter, xlabel, median_value, median_label=None):
    """Draw a 20-bin histogram of `values` in percent of observations and mark `median_value`."""
    n_values = len(values)
    # Every observation weighs 100/n so that the bars add up to 100 %; an empty
    # input gets no weights, which avoids a division by zero.
    weights = np.full(n_values, 100. / n_values) if n_values else None
    ax.hist(values, bins=20, color='steelblue', edgecolor='black', weights=weights)
    ax.text(-0.2, 1.05, panel_letter, transform=ax.transAxes, fontsize=22, fontweight='bold')
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Frequency (%)")
    ax.axvline(median_value, color='red', linestyle='--', label=median_label)


# Histogram A: Ripple Rate per Channel
_plot_percent_histogram(axes[0, 0], hippocampus_all_incidence_rates, 'A', "SWR rate (Hz)", median_rate)

# Histogram B: Duration of Ripples
# convert durations from seconds to milliseconds
durations = [d * 1000 for d in hippocampus_all_durations]
median_duration_ms = np.median(durations)
# NOTE(review): from here on `median_duration` holds milliseconds, while `std_duration`
# stays in seconds; kept for cells that may rely on it, prefer `median_duration_ms`.
median_duration = median_duration_ms
_plot_percent_histogram(axes[0, 1], durations, 'B', "Duration (ms)", median_duration_ms,
                        median_label=f'Median = {median_duration_ms:.1f} ms')

# Histogram C: Peak Amplitude
_plot_percent_histogram(axes[1, 0], hippocampus_all_peak_amplitudes, 'C', "Peak Amplitude (Z-score)", median_amp)

# Histogram D: Jaccard Index per Patient (patient counts, not percentages)
# a NaN Jaccard index means the patient has fewer than two hippocampal contacts: exclude it
jaccard = [j for j in hippocampus_all_jaccard_indices if not np.isnan(j)]
axes[1, 1].hist(jaccard, bins=8, color='steelblue', edgecolor='black')
axes[1, 1].text(-0.2, 1.05, 'D', transform=axes[1, 1].transAxes, fontsize=22, fontweight='bold')
axes[1, 1].set_xlabel("Jaccard index")
axes[1, 1].set_ylabel("Number of patients")
# the label previously referenced the undefined name `mean_jaccard` (NameError)
axes[1, 1].axvline(x=median_jaccard, ymin=0, ymax=1, color='red', linestyle='--', label=f'Median = {median_jaccard:.2f}')

plt.show()