In [None]:
# Standard library
import ast
import importlib
import os

# Third-party
import matplotlib as mpl
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
from scipy.stats import pearsonr, linregress

# Local detection code. The module's file name starts with a digit, which is not
# a valid Python identifier, so ``from 1_swr_detection_algorithm import ...`` is a
# SyntaxError. Load the module by its string name instead; the file must be
# importable from sys.path (e.g. sit next to this notebook).
detect_SWR = importlib.import_module("1_swr_detection_algorithm").detect_SWR



In [None]:
# Input/output paths. Everything is resolved relative to base_path, which is a
# placeholder here and must be set to the project's data directory.
base_path = "path"
# Folder with all virtual time series; the detection loop scans it for files
# whose names end in "_virtual_ts.fif".
virtual_path = os.path.join(base_path, "virtual_time_series") 
# Folders for the IED boolean arrays and the iEEG SWR tables.
# NOTE(review): ied_dir and swr_dir are not referenced by the visible cells; the
# per-subject file paths are taken from ied_file_dict / swr_file_dict instead.
ied_dir = os.path.join(base_path, "IED_boolean_array")
swr_dir = os.path.join(base_path, "SWR_times")
# CSV written by the detection loop (one row per virtual vertex / iEEG channel pair).
output_csv = os.path.join(base_path, "SWR_virtual_vs_ieeg.csv")

In [None]:
# Dictionary of the closest iEEG channel(s) for each voxel.
# Key: virtual file name without the "_virtual_ts.fif" suffix; value: list of
# channel names that the detection loop iterates over.
closest_channels_dict = {"patient_X_vertexY": ["EEG00Z"]}

# Dictionary of IED files (if needed) per patient: subject key -> path that is
# opened with np.load and then indexed by channel name.
# NOTE: the detection loop skips every file whose subject has no entry here, so
# leaving this dictionary empty means nothing is processed.
ied_file_dict = {}

# Dictionary of detected SWRs on iEEG contacts per patient: subject key -> CSV
# path. The loop reads the columns "Channel", "nEvents", "rippleTime" and, when
# present, "durations".
swr_file_dict ={"patient_X": "file path"}

# Dictionary of data dropout (if needed): subject key -> value passed to
# detect_SWR as its data_dropout argument. The loop indexes this dictionary
# directly, so every processed subject needs an entry (an empty list is fine).
data_dropout_dict ={"patient_X": []}

In [None]:
# Ripple detection parameters; all four are forwarded unchanged to detect_SWR as
# keyword arguments of the same name.
# NOTE(review): units are defined by detect_SWR, which is not visible here —
# the annotations below are assumptions to confirm against that module.
# presumably the width of a Gaussian smoothing kernel — TODO confirm unit
gaussFilt = 15 
# presumably a z-score threshold on the ripple-band signal — TODO confirm
zThresh = 2 
# presumably [min, max] accepted event duration in seconds — TODO confirm
durTresh = [0.01, 0.1 ]
# presumably the band-pass limits in Hz — TODO confirm
rippleBand = [80, 150] 


In [None]:
# Run detect_SWR on every virtual-electrode time series and pair the outcome with
# the SWRs already detected on the nearest iEEG channel(s). One result row is
# produced per (virtual vertex, iEEG channel) pair.
VIRTUAL_SUFFIX = "_virtual_ts.fif"

results = []
current_subject = None   # subject whose IED arrays are currently loaded
ied_bool_arrays = None
swr_subject = None       # subject whose SWR table is currently loaded
swr_df = None

# os.listdir returns entries in arbitrary order. Sorting keeps the row order of
# the result table reproducible (later cells address rows by position) and keeps
# files of one subject adjacent so the per-subject caches below are effective.
for filename in sorted(os.listdir(virtual_path)):
    if not filename.endswith(VIRTUAL_SUFFIX):
        continue

    filepath = os.path.join(virtual_path, filename)

    # base key (file name without suffix) is the key into closest_channels_dict
    base_key = filename[:-len(VIRTUAL_SUFFIX)]

    # subject key selects the IED / SWR / dropout entries: the first four
    # "_"-separated tokens when the name contains "part", otherwise the first two
    subject_key_parts = base_key.split("_vertex")[0].split("_")
    n_tokens = 4 if "part" in base_key else 2
    subject_key = "_".join(subject_key_parts[:n_tokens])

    # load the IED arrays once per subject; files without an entry are skipped
    if subject_key != current_subject:
        if subject_key not in ied_file_dict:
            print(f"no IED file for {subject_key}")
            continue
        ied_bool_arrays = np.load(ied_file_dict[subject_key])
        current_subject = subject_key

    # load the SWR table once per subject instead of re-reading and re-parsing
    # the CSV for every vertex file
    if subject_key not in swr_file_dict:
        print(f"no SWR file for {subject_key}")
        continue
    if subject_key != swr_subject:
        swr_df = pd.read_csv(swr_file_dict[subject_key])
        # NOTE(review): eval executes arbitrary expressions found in the CSV. Only
        # use trusted files; if the column holds plain list literals,
        # ast.literal_eval would be the safe replacement.
        swr_df['rippleTime'] = (
            swr_df['rippleTime']
            .apply(eval)
            .apply(lambda lst: [np.array(seg) for seg in lst])
            .tolist()
        )
        swr_subject = subject_key

    # a vertex without a channel mapping cannot be compared; report and skip it
    # rather than aborting the whole run with a KeyError
    nearest_channels = closest_channels_dict.get(base_key)
    if not nearest_channels:
        print(f"no closest channel for {base_key}")
        continue

    # load virtual electrode series
    raw = mne.io.read_raw_fif(filepath, preload=True)
    raw.set_channel_types({raw.ch_names[0]: 'misc'})
    data = raw.get_data()
    sfreq = raw.info['sfreq']
    # data dropout is optional ("if needed"): default to no dropout
    data_dropout = data_dropout_dict.get(subject_key, [])

    # loop through the nearest channels
    for ch in nearest_channels:
        ied_bool = ied_bool_arrays[ch]
        output = detect_SWR(data, sfreq, ied_bool, data_dropout, rippleBand=rippleBand, gaussFilt=gaussFilt,
                    zThresh=zThresh, durTresh=durTresh)
        n_ve = output["nEvents"]
        ripple_times = output["rippleTime"]

        # iEEG detections for this channel; a channel missing from the table
        # counts as zero events (previously the durations lookup raised
        # IndexError in that case whenever the column existed)
        row = swr_df[swr_df["Channel"] == ch]
        if row.empty:
            n_ieeg, swr_times_ieeg, ieeg_durations = 0, [], []
        else:
            n_ieeg = int(row["nEvents"].values[0])
            swr_times_ieeg = row["rippleTime"].values[0]
            ieeg_durations = row['durations'].values[0] if 'durations' in row.columns else []

        # results
        results.append({
                "subject": subject_key,
                "vertex": base_key.split("_vertex")[1],
                "channel": ch,
                "virtual_SWRs": n_ve,
                "iEEG_SWRs": n_ieeg,
                "vitual_ripple_times": ripple_times,
                "iEEG_ripple_times": swr_times_ieeg,
                "iEEG_ripple_durations": ieeg_durations
            })


df = pd.DataFrame(results)
df.to_csv(output_csv, index=False)
print(f"Saved results to: {output_csv}")

# The following cells parse the ripple-time / duration columns from their CSV
# string form. Reload the table so that `df` has that form after a fresh
# "Restart & Run All", not only when the CSV is loaded by hand.
if results:
    df = pd.read_csv(output_csv)



    

In [None]:
# Pearson correlation between the number of events detected on the iEEG channel
# and on the matching virtual electrode (one point per row of df)
ieeg_counts = df["iEEG_SWRs"]
virtual_counts = df["virtual_SWRs"]
r, p = pearsonr(ieeg_counts, virtual_counts)
print(f"Overall Pearson Correlation (all data): r = {r:.2f}, p = {p:.3e}")

# scatter plot of the two counts
fig, ax = plt.subplots(figsize=(6,6))
marker_style = dict(facecolors='steelblue', edgecolors='steelblue', s=30, linewidth=1.5)
ax.scatter(ieeg_counts, virtual_counts, **marker_style)

# dashed identity line (y = x), drawn 5 % beyond the largest count on either axis
max_val = 1.05 * max(ieeg_counts.max(), virtual_counts.max())
identity_ends = [0, max_val]
ax.plot(identity_ends, identity_ends, color='black', linestyle='--', linewidth=1.5)
ax.set_xlabel("iEEG SWR Count")
ax.set_ylabel("Virtual Electrode SWR Count")




In [None]:
# Jaccard index between each virtual time series and its corresponding actual
# iEEG electrode: shared constants and the time-to-sample conversion helper


eeg_fs = 1000 # sampling frequency used to convert times to sample indices
recording_durations = {"patient_X" : 100} # recording length per subject, for all patients

def ripple_times_to_array(ripple_times, eeg_fs, eeg_num_samples):
    """Convert ripple time points into a 0/1 array over the recording.

    Each time is multiplied by ``eeg_fs`` and rounded to the nearest sample
    index; the corresponding sample is set to 1. Indices that fall outside
    ``[0, eeg_num_samples)`` are dropped.

    Parameters
    ----------
    ripple_times : array-like or None
        Time points to mark. ``None`` or an empty input yields an all-zero array.
    eeg_fs : number
        Factor converting a time into a sample index.
    eeg_num_samples : int
        Length of the returned array.

    Returns
    -------
    numpy.ndarray
        Integer array of length ``eeg_num_samples`` holding 1 at marked samples.
    """
    event_mask = np.zeros(eeg_num_samples, dtype=int)

    # nothing to mark: hand back the all-zero array
    if ripple_times is None or np.size(ripple_times) == 0:
        return event_mask

    sample_idx = np.round(np.array(ripple_times) * eeg_fs).astype(int)
    inside_recording = (sample_idx >= 0) & (sample_idx < eeg_num_samples)
    event_mask[sample_idx[inside_recording]] = 1
    return event_mask

def _flatten_ripple_times(cell):
    """Return all ripple time points stored in one table cell as a 1-D array.

    The cell is either the CSV string form of a list of arrays (when ``df`` was
    read back from disk) or the list of arrays itself (when ``df`` is still the
    in-memory result of the detection loop). Missing values and empty lists —
    including the string ``"[]"``, which is truthy and used to reach
    ``np.concatenate([])`` and raise ValueError — give an empty array.
    """
    if isinstance(cell, str):
        # NOTE(review): eval executes whatever expression the CSV contains; only
        # use it on files produced by this notebook.
        cell = eval(cell, {'array': np.array})
    if cell is None or (isinstance(cell, float) and np.isnan(cell)):
        return np.array([])
    segments = [np.atleast_1d(seg) for seg in cell]
    if not segments:
        return np.array([])
    return np.concatenate(segments)


# Jaccard index (intersection over union of marked samples) for every row of df
jaccard_results = []
for idx, row in df.iterrows():
    subject = row['subject']
    eeg_duration_seconds = recording_durations[subject]
    eeg_num_samples = int(eeg_fs * eeg_duration_seconds)

    virtual_times = _flatten_ripple_times(row['vitual_ripple_times'])
    ieeg_times = _flatten_ripple_times(row['iEEG_ripple_times'])

    swr_array_virtual = ripple_times_to_array(virtual_times, eeg_fs, eeg_num_samples)
    swr_array_ieeg = ripple_times_to_array(ieeg_times, eeg_fs, eeg_num_samples)

    # compute the actual Jaccard index; defined as 0 when neither array has events
    intersection = np.sum(swr_array_virtual & swr_array_ieeg)
    union = np.sum(swr_array_virtual | swr_array_ieeg)

    jaccard_index = intersection / union if union > 0 else 0
    jaccard_results.append(jaccard_index)

    print(f"Subject {subject}, channel {row['channel']}: Jaccard Index = {jaccard_index:.3f}")


df['jaccard_index'] = jaccard_results
# NOTE(review): 'path' is a placeholder output file name — set the real destination.
df.to_csv('path', index=False)



In [None]:
# permutation test to get statistical significance of the Jaccard indices
n_shuffles = 1000  # number of surrogate arrays generated per row

# function to generate random arrays - same length, same number of events, with the same real durations
def generate_shuffled_array(real_durations, eeg_fs, eeg_duration_seconds, eeg_num_samples, rng=None):
    """Build a surrogate 0/1 event array with randomly placed events.

    For every entry of ``real_durations`` a start time is drawn uniformly from
    ``[0, eeg_duration_seconds)`` and the samples from that start up to
    ``start + duration`` are set to 1.

    Parameters
    ----------
    real_durations : array-like or None
        Event durations, in the same time unit as ``eeg_duration_seconds``.
        ``None`` or an empty input yields an all-zero array. Non-finite entries
        (NaN/inf) are skipped; they previously crashed the ``int()`` conversion.
    eeg_fs : number
        Factor converting a time into a sample index.
    eeg_duration_seconds : number
        Upper bound of the uniform start-time draw.
    eeg_num_samples : int
        Length of the returned array.
    rng : optional
        Object with a ``uniform(low, high)`` method, e.g.
        ``np.random.default_rng(seed)``. Defaults to the global ``np.random``
        state, which reproduces the previous behaviour (including
        ``np.random.seed``).

    Returns
    -------
    numpy.ndarray
        Integer array of length ``eeg_num_samples``.

    Notes
    -----
    Events that start close to the end are clipped at ``eeg_num_samples`` and
    events may overlap, so the number of marked samples can be smaller than
    the summed durations.
    """
    if rng is None:
        rng = np.random
    swr_array = np.zeros(eeg_num_samples, dtype=int)
    if real_durations is None:
        return swr_array

    for duration in np.ravel(np.asarray(real_durations, dtype=float)):
        if not np.isfinite(duration):
            continue
        random_start = rng.uniform(0, eeg_duration_seconds)
        start_idx = int(random_start * eeg_fs)
        # clip so the slice never runs past the end of the recording
        end_idx = min(int((random_start + duration) * eeg_fs), eeg_num_samples)
        swr_array[start_idx:end_idx] = 1
    return swr_array

# Permutation test: for every row, compare the observed Jaccard index with the
# indices obtained when the iEEG events are replaced by randomly placed events
# of the same durations.

# Fix the global NumPy random state so percentiles and p-values are reproducible
# across "Restart & Run All"; change the seed to draw a different set of surrogates.
PERMUTATION_SEED = 0
np.random.seed(PERMUTATION_SEED)

shuffle_results = []
all_percentiles = []
all_pvalues =[]
for idx, row in df.iterrows():
    subject = row['subject']
    eeg_duration_seconds = recording_durations[subject]
    eeg_num_samples = int(eeg_fs * eeg_duration_seconds)

    # Virtual ripple times: accept both the CSV string form and in-memory lists.
    # An empty list — also the truthy string "[]" — must not reach np.concatenate,
    # which raises ValueError on an empty sequence.
    virtual_cell = row['vitual_ripple_times']
    if isinstance(virtual_cell, str):
        # NOTE(review): eval executes whatever expression the CSV contains; only
        # use it on files produced by this notebook.
        virtual_cell = eval(virtual_cell, {'array': np.array})
    if virtual_cell is None or isinstance(virtual_cell, float) or len(virtual_cell) == 0:
        virtual_times = []
    else:
        virtual_times = np.concatenate(virtual_cell)

    # iEEG ripple durations: a string such as "[0.01 0.02]" (commas tolerated),
    # a missing value, or an already numeric sequence
    durations_cell = row['iEEG_ripple_durations']
    if isinstance(durations_cell, str):
        tokens = durations_cell.replace('[', ' ').replace(']', ' ').replace(',', ' ').split()
        ripple_durations = np.array([float(tok) for tok in tokens])
    elif durations_cell is None or (isinstance(durations_cell, float) and np.isnan(durations_cell)):
        ripple_durations = np.array([])
    else:
        ripple_durations = np.atleast_1d(np.asarray(durations_cell, dtype=float))

    # only the virtual array is needed here: the iEEG side is replaced by surrogates
    swr_array_virtual = ripple_times_to_array(virtual_times, eeg_fs, eeg_num_samples)

    # get the actual/true Jaccard index
    true_jaccard = row['jaccard_index']

    # shuffle - generate n_shuffles surrogate iEEG arrays
    shuffled_jaccards = []
    for _ in range(n_shuffles):
        swr_array_shuffled_ieeg = generate_shuffled_array(
            ripple_durations,
            eeg_fs,
            eeg_duration_seconds,
            eeg_num_samples
        )

        # each time compute the Jaccard index (0 when neither array has events)
        intersection = np.sum(swr_array_virtual & swr_array_shuffled_ieeg)
        union = np.sum(swr_array_virtual | swr_array_shuffled_ieeg)

        jaccard_shuffle = intersection / union if union > 0 else 0
        shuffled_jaccards.append(jaccard_shuffle)

    # percentile: share of surrogates strictly below the true index;
    # p-value: share of surrogates at or above it, with the +1 correction so
    # that p can never be exactly zero
    shuffled_values = np.asarray(shuffled_jaccards)
    percentile = (np.sum(shuffled_values < true_jaccard) / n_shuffles) * 100
    p_value = (np.sum(shuffled_values >= true_jaccard) + 1) / (n_shuffles + 1)

    shuffle_results.append(shuffled_jaccards)
    all_percentiles.append(percentile)
    all_pvalues.append(p_value)

    # print the results
    print(f"Subject {subject}, channel {row['channel']}: True Jaccard = {true_jaccard:.3f}, Percentile = {percentile:.1f}%, p-value = {p_value:.4f}")


# Save percentile and p-values
df['percentile'] = all_percentiles
df['p_value'] = all_pvalues
# NOTE(review): 'path' is a placeholder output file name — set the real destination.
df.to_csv('path', index=False)




In [None]:
# Histograms for visualization: the surrogate Jaccard distribution of two
# selected rows (panels A and B) and the distribution of all p-values (panel C)
selected_indices = [11, 53 ]  # row positions in df for specific contacts/voxels
custom_titles = []            # optional panel titles, one per selected index
panel_labels = ['A', 'B']

fig, axes = plt.subplots(1, 3, figsize=(12, 8))
# custom_titles is deliberately NOT part of the zip: zip stops at its shortest
# argument, so an empty title list used to skip both panels without any error.
for position, (ax, idx, panel_label) in enumerate(zip(axes[:2], selected_indices, panel_labels)):
    # a selected row that does not exist would raise IndexError; hide the panel
    if not 0 <= idx < len(shuffle_results):
        print(f"row {idx} not available (only {len(shuffle_results)} rows) - panel {panel_label} skipped")
        ax.set_visible(False)
        continue

    shuffled_jaccards = shuffle_results[idx]
    true_jaccard = df.loc[idx, 'jaccard_index']
    p_value = df.loc[idx, 'p_value']

    # fall back to "subject channel" when no custom title was given for this panel
    if position < len(custom_titles):
        custom_title = custom_titles[position]
    else:
        custom_title = f"{df.loc[idx, 'subject']} {df.loc[idx, 'channel']}"

    # histogram of the surrogate indices with the observed index as a red line
    ax.hist(shuffled_jaccards, bins=20, color='steelblue', edgecolor='black')
    ax.axvline(true_jaccard, color='red', linestyle='--', linewidth=2)
    ax.set_title(f'{custom_title}\n$p$ = {p_value:.4f}', fontsize=14, fontweight='bold')
    ax.set_xlabel('Jaccard Index', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.text(-0.2, 1.05, panel_label, transform=ax.transAxes, fontsize=22, fontweight='bold')



# Histogram of all p-values, with the 0.05 threshold as a red line
ax = axes[2]
ax.hist(all_pvalues, bins=20, color='steelblue', edgecolor='black')
ax.axvline(0.05, color='red', linestyle='--', linewidth=2)
ax.set_title('All p-values\n$p$ threshold = 0.05', fontsize=14, fontweight='bold')
ax.set_xlabel('p-value', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.text(-0.2, 1.05, 'C', transform=ax.transAxes, fontsize=22, fontweight='bold')
