In [None]:
import mne
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# iEEG epoch files, keyed by patient ID (value: path to the saved epochs file).
iEEG_high_epochs_files = {
    "patient_X": "file path",  # epochs with a high SWR density
}

iEEG_low_epochs_files = {
    "patient_X": "file path",  # epochs containing no SWRs
}

In [None]:
# MEG epoch files, keyed by patient ID (value: path to the saved epochs file).
MEG_high_epochs_files = {
    "patient_X": "file path",  # epochs with a high SWR density
}

MEG_low_epochs_files = {
    "patient_X": "file path",  # epochs containing no SWRs
}

In [None]:
# Morlet time-frequency parameters.
# Frequencies of interest: 30, 40, ..., 200 Hz (arange's stop is exclusive, hence the +1).
frequencies = np.arange(30, 200 + 1, 10)
# n_cycles grows linearly with frequency, so n_cycles / f = 1/5 s is the same
# at every frequency (a fixed 200 ms time window).
n_cycles = np.divide(frequencies, 5)

In [None]:
#average TRF of high SWR epochs across all iEEG hippocampal contacts for all patients
iEEG_baseline_powers_high = {}
for patient_id, epoch_file in iEEG_high_epochs_files.items():
    #load epochs
    epochs = mne.read_epochs(epoch_file, preload=True)
    #compute TRF average
    power, itc = epochs.compute_tfr(method="morlet", freqs=frequencies, n_cycles=n_cycles, decim=1, return_itc=True, average=True)
    #apply baseline correction
    power.apply_baseline((-0.7, -0.5), mode='logratio')
    #average across channels
    avg_data = np.mean(power.data, axis=0, keepdims=True)
    power._data = avg_data
    power.pick_channels([power.ch_names[0]])
    iEEG_baseline_powers_high[patient_id] = power

summed_power_high = None
for power in iEEG_baseline_powers_high.values():
    if summed_power_high is None:
        summed_power_high = power
    else:
        summed_power_high = summed_power_high.__add__(power)




In [None]:
# Grand-average TFR of no-SWR epochs across all iEEG hippocampal contacts, over patients.
# Per patient: Morlet power -> log-ratio baseline correction -> mean over contacts,
# kept as a single-channel TFR. Patients whose corrected power is entirely NaN are skipped.
# NOTE: `summed_power_low` is the SUM of the per-patient averages; divide by
# len(iEEG_baseline_powers_low) to obtain the across-patient mean.
iEEG_baseline_powers_low = {}
for patient_id, epoch_file in iEEG_low_epochs_files.items():
    # NOTE(review): no channel-type pick here - assumes the file only holds
    # hippocampal contacts; confirm.
    epochs = mne.read_epochs(epoch_file, preload=True, verbose=False)
    # Only power is used below, so ITC is not computed.
    power = epochs.compute_tfr(
        method="morlet",
        freqs=frequencies,
        n_cycles=n_cycles,
        decim=1,
        return_itc=False,
        average=True,
    )
    # Log-ratio baseline correction against the -0.7 to -0.5 s window.
    power.apply_baseline((-0.7, -0.5), mode='logratio')
    # Same guard as the MEG cells: nothing usable for this patient.
    if np.isnan(power.data).all():
        continue
    # Average across contacts; nanmean stops a single NaN contact from turning
    # the whole patient map (and therefore the grand sum) into NaN.
    power._data = np.nanmean(power.data, axis=0, keepdims=True)
    power.pick_channels([power.ch_names[0]])
    # Give every patient the same channel name: TFR arithmetic (`+`) checks that
    # channel names match, and each patient's first contact name differs.
    mne.rename_channels(power.info, {power.ch_names[0]: 'MeanChannel'})
    iEEG_baseline_powers_low[patient_id] = power

summed_power_low = None
for power in iEEG_baseline_powers_low.values():
    # Copy the first term so the running sum never aliases a stored per-patient TFR.
    if summed_power_low is None:
        summed_power_low = power.copy()
    else:
        summed_power_low = summed_power_low + power

In [None]:
# Grand-average TFR of high-SWR epochs across all MEG sensors, over patients.
# Per patient: Morlet power -> log-ratio baseline correction -> mean over sensors,
# stored as a single 'MeanChannel' pseudo-channel. Patients whose corrected power
# is entirely NaN are skipped (and not listed in valid_patient_ids_meg_high).
# NOTE: `summed_power_meg_high` is the SUM over valid patients; divide by
# len(valid_patient_ids_meg_high) to obtain the across-patient mean.
MEG_baseline_powers_high = {}
valid_patient_ids_meg_high = []
for patient_id, epoch_file in MEG_high_epochs_files.items():
    # Load and keep MEG channels only.
    # NOTE(review): pick_types(meg=True) uses ref_meg="auto", which also keeps
    # reference sensors when compensation is present - confirm that is intended.
    epochs = mne.read_epochs(epoch_file, preload=True, verbose=False)
    epochs.pick_types(meg=True)
    # Average TFR; only power is used below, so ITC is not computed.
    power = epochs.compute_tfr(
        method="morlet",
        freqs=frequencies,
        n_cycles=n_cycles,
        decim=1,
        return_itc=False,
        average=True,
    )
    # Log-ratio baseline correction against the -0.7 to -0.5 s window.
    power.apply_baseline((-0.7, -0.5), mode='logratio')
    if np.isnan(power.data).all():
        continue
    # Average across sensors, ignoring sensors that are NaN after the log-ratio
    # (a plain mean would make the whole patient map - and the grand sum - NaN).
    # The per-sensor TFR is not needed afterwards, so it is reused rather than copied.
    power._data = np.nanmean(power.data, axis=0, keepdims=True)
    # Replace the multi-sensor info with one pseudo-channel, typed 'eeg' so MNE
    # can handle and plot it (MEG channel types are handled differently).
    power.info = mne.create_info(
        ch_names=['MeanChannel'],
        sfreq=power.info['sfreq'],
        ch_types='eeg',
    )
    MEG_baseline_powers_high[patient_id] = power
    valid_patient_ids_meg_high.append(patient_id)

summed_power_meg_high = None
for patient_id in valid_patient_ids_meg_high:
    power = MEG_baseline_powers_high[patient_id]
    # Copy the first term so the running sum never aliases a stored per-patient TFR.
    if summed_power_meg_high is None:
        summed_power_meg_high = power.copy()
    else:
        summed_power_meg_high = summed_power_meg_high + power


In [None]:
# Grand-average TFR of no-SWR epochs across all MEG sensors, over patients.
# Per patient: Morlet power -> log-ratio baseline correction -> mean over sensors,
# stored as a single 'MeanChannel' pseudo-channel. Patients whose corrected power
# is entirely NaN are skipped (and not listed in valid_patient_ids_meg_low).
# NOTE: `summed_power_meg_low` is the SUM over valid patients; divide by
# len(valid_patient_ids_meg_low) to obtain the across-patient mean.
MEG_baseline_powers_low = {}
valid_patient_ids_meg_low = []
for patient_id, epoch_file in MEG_low_epochs_files.items():
    # NOTE(review): pick_types(meg=True) uses ref_meg="auto", which also keeps
    # reference sensors when compensation is present - confirm that is intended.
    epochs = mne.read_epochs(epoch_file, preload=True, verbose=False)
    epochs.pick_types(meg=True)
    # Only power is used below, so ITC is not computed.
    power = epochs.compute_tfr(
        method="morlet",
        freqs=frequencies,
        n_cycles=n_cycles,
        decim=1,
        return_itc=False,
        average=True,
    )
    # Log-ratio baseline correction against the -0.7 to -0.5 s window.
    power.apply_baseline((-0.7, -0.5), mode='logratio')
    if np.isnan(power.data).all():
        continue
    # Average across sensors, ignoring sensors that are NaN after the log-ratio;
    # the per-sensor TFR is not needed afterwards, so it is reused, not copied.
    power._data = np.nanmean(power.data, axis=0, keepdims=True)
    # One 'eeg'-typed pseudo-channel so MNE can handle it like a generic channel.
    power.info = mne.create_info(
        ch_names=['MeanChannel'],
        sfreq=power.info['sfreq'],
        ch_types='eeg',
    )
    MEG_baseline_powers_low[patient_id] = power
    valid_patient_ids_meg_low.append(patient_id)

summed_power_meg_low = None
for patient_id in valid_patient_ids_meg_low:
    power = MEG_baseline_powers_low[patient_id]
    # Copy the first term so the running sum never aliases a stored per-patient TFR.
    if summed_power_meg_low is None:
        summed_power_meg_low = power.copy()
    else:
        summed_power_meg_low = summed_power_meg_low + power

In [None]:
# Plot the results: across-patient mean of the baseline-corrected TFRs.
# Top row: iEEG hippocampal contacts; bottom row: all MEG sensors.
# Left column: high-SWR epochs; right column: epochs without SWRs.


def _plot_tfr_panel(ax, summed_tfr, n_patients, title, letter, vlim=0.4):
    """Draw one grand-average time-frequency panel on ``ax``.

    ``summed_tfr`` is a single-channel TFR holding the SUM of per-patient
    averages (or None when no patient contributed). It is divided by
    ``n_patients`` so the panel shows the across-patient mean log-ratio power.
    Each panel uses its own ``times``/``freqs``, so iEEG and MEG panels are
    correct even if their epoch time ranges differ.

    Returns the image artist, or None when there is nothing to draw.
    """
    if summed_tfr is None or n_patients == 0:
        im = None
        ax.text(0.5, 0.5, 'no data', transform=ax.transAxes, ha='center', va='center')
    else:
        grand_avg = summed_tfr.data[0] / n_patients
        times = np.asarray(summed_tfr.times, dtype=float)
        freqs = np.asarray(summed_tfr.freqs, dtype=float)
        # imshow's extent sets the OUTER edges of the image; pad by half a step so
        # each pixel is centred on its time sample / frequency bin.
        # Assumes evenly spaced axes (true for np.arange frequencies).
        half_t = (times[1] - times[0]) / 2 if times.size > 1 else 0.5
        half_f = (freqs[1] - freqs[0]) / 2 if freqs.size > 1 else 0.5
        im = ax.imshow(grand_avg, aspect='auto', origin='lower', cmap='RdBu_r',
                       extent=[times[0] - half_t, times[-1] + half_t,
                               freqs[0] - half_f, freqs[-1] + half_f],
                       vmin=-vlim, vmax=vlim)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.text(-0.15, 1.05, letter, transform=ax.transAxes, fontsize=22, fontweight='bold')
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Frequency (Hz)')
    ax.set_xlim(-0.5, 0.5)
    return im


fig, axs = plt.subplots(2, 2, figsize=(12, 8))

# (axes, summed TFR, number of patients in the sum, title, panel letter)
panels = [
    # A: high-SWR epochs, iEEG hippocampal contacts
    (axs[0, 0], summed_power_high, len(iEEG_baseline_powers_high), 'High SWR epochs', 'A'),
    # B: no-SWR epochs, iEEG hippocampal contacts
    (axs[0, 1], summed_power_low, len(iEEG_baseline_powers_low), 'No SWR epochs', 'B'),
    # C: high-SWR epochs, all MEG sensors
    (axs[1, 0], summed_power_meg_high, len(MEG_baseline_powers_high), 'High SWR epochs', 'C'),
    # D: no-SWR epochs, all MEG sensors
    (axs[1, 1], summed_power_meg_low, len(MEG_baseline_powers_low), 'No SWR epochs', 'D'),
]
for ax, summed_tfr, n_patients, title, letter in panels:
    im = _plot_tfr_panel(ax, summed_tfr, n_patients, title, letter)
    if im is not None:
        fig.colorbar(im, ax=ax, label='Power (log ratio)')

plt.show()



