In [1]:
import logging
import mne
import os
import datetime
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots


## Functions

In [57]:
def plot_PSD_channels(df1, df2, names):
    """Plot two per-channel PSD tables side by side with a shared y-axis.

    Parameters
    ----------
    df1, df2 : pandas.DataFrame
        Wide tables with a 'Frequency (Hz)' column and one column per channel
        (the layout returned by ``get_PSD``). ``df1`` goes in the left panel,
        ``df2`` in the right panel.
    names : sequence of str
        Two labels; panel ``i`` is titled ``f"{names[i]} signals"``.

    Returns
    -------
    plotly.graph_objects.Figure
        A 1x2 subplot figure titled "Normalized PSD". Both x-axes span from the
        lowest frequency found in either table to the smaller of the two
        maximum frequencies, i.e. the upper part of the band both tables cover.
    """
    freq_col = 'Frequency (Hz)'
    fig_subplots = make_subplots(rows=1, cols=2, shared_yaxes=True,
                                 subplot_titles=[f"{name} signals" for name in names])

    for panel_col, df in enumerate([df1, df2], start=1):
        channel_cols = df.columns[df.columns != freq_col]
        panel = px.line(df, x=freq_col, y=channel_cols)
        for trace in panel.data:
            # Tie each channel's traces in both panels to one legend group so a
            # legend click toggles the channel everywhere; only the left panel
            # contributes legend entries so channels are not listed twice.
            trace.update(legendgroup=trace.name, showlegend=(panel_col == 1))
            fig_subplots.add_trace(trace, row=1, col=panel_col)

    # Use the function's own inputs (not notebook globals) for the axis range.
    freqs1 = df1[freq_col].to_numpy()
    freqs2 = df2[freq_col].to_numpy()
    min_f = min(np.min(freqs1), np.min(freqs2))
    # Smaller of the two maxima: the highest frequency present in both tables
    # (e.g. the lower Nyquist limit when the two recordings differ in rate).
    max_f = min(np.max(freqs1), np.max(freqs2))

    fig_subplots.update_layout(title="Normalized PSD", height=400)
    # update_xaxes touches every subplot's x-axis; layout.xaxis_range alone
    # would only restrict the left panel.
    fig_subplots.update_xaxes(range=[min_f, max_f])

    return fig_subplots


In [59]:
def get_PSD(edf_path:str, length_segment: float=3.0):
    """Compute a per-channel, L2-normalized Welch PSD for every signal in an EDF file.

    Parameters
    ----------
    edf_path : str
        Path to the EDF file, opened with ``pyedflib.EdfReader``.
    length_segment : float, default 3.0
        Welch segment length in seconds; ``nperseg = int(length_segment * srate)``,
        so the frequency spacing is about ``1 / length_segment`` Hz.

    Returns
    -------
    pandas.DataFrame
        One row per Welch frequency bin. The first column, 'Frequency (Hz)',
        holds the bin frequencies; every other column is one signal (named by
        its EDF label) whose PSD has been divided by its own L2 norm across
        frequencies. A signal whose PSD is identically zero has a zero norm and
        comes out as an all-NaN column.
    """
    import pyedflib
    from scipy.signal import welch

    edf = pyedflib.EdfReader(edf_path)
    try:
        labels = edf.getSignalLabels()
        # NOTE(review): the first signal's rate is used for every channel, so all
        # signals are assumed to share one sampling rate (np.array below also needs
        # equal-length signals). Also confirm whether getSampleFrequencies() already
        # returns Hz in the installed pyedflib; if so, dividing by
        # datarecord_duration is only correct for 1-second data records.
        srate = edf.getSampleFrequencies()[0] / edf.datarecord_duration
        signals = np.array([edf.readSignal(i) for i in range(len(labels))])
    finally:
        # Release the file handle even if reading a signal fails.
        edf.close()

    f, welchpow = welch(signals, fs=srate, nperseg=int(length_segment * srate), axis=1)

    # Scale each channel's PSD to unit L2 norm over frequency. Flat-zero channels
    # give 0/0 -> NaN, which is expected, so the RuntimeWarning is silenced.
    norms = np.sqrt(np.sum(welchpow**2, axis=1, keepdims=True))
    with np.errstate(divide='ignore', invalid='ignore'):
        welchpow = welchpow / norms

    df = pd.DataFrame(welchpow.T, columns=labels)
    df.insert(0, 'Frequency (Hz)', f)
    return df

## Downsampling check: normalized PSD of the original vs. downsampled clip

In [60]:
import pyedflib  # get_PSD imports pyedflib itself; this line only surfaces a missing install early

# NOTE(review): absolute cluster path -- consider moving it to a config-cell constant.
out_edf_dn = '/scratch/mcesped/Results/tmp_seeg/sub-097_ses-007_task-full_rec-dn_run-01_clip-01_ieeg.edf'
# Normalized Welch PSD of every channel in the "rec-dn" clip.
dn_chns = get_PSD(edf_path=out_edf_dn, length_segment=3.0)
dn_chns.head(5)

Unnamed: 0,Frequency (Hz),Patient Event,LOFr1,LOFr2,LOFr3,LOFr4,LOFr5,LOFr6,LOFr7,LOFr8,...,DC11,DC12,DC13,DC14,DC15,DC16,TRIG,OSAT,PR,Pleth
0,0.0,,0.12059,0.123793,0.117794,0.121683,0.092262,0.115292,0.137477,0.133593,...,0.009108,0.012234,0.017791,0.014107,0.011009,0.015593,0.00019,,,0.001496
1,0.333333,,0.696886,0.735497,0.794546,0.759633,0.858451,0.794265,0.765808,0.740604,...,0.051005,0.065937,0.094972,0.087363,0.05931,0.074247,0.000908,,,0.004501
2,0.666667,,0.502645,0.493691,0.438539,0.458823,0.396737,0.388258,0.435754,0.470508,...,0.052021,0.060545,0.08603,0.093689,0.074754,0.068992,0.000255,,,0.001529
3,1.0,,0.309788,0.286729,0.23264,0.258562,0.170977,0.22041,0.262665,0.291375,...,0.047496,0.054964,0.078639,0.080097,0.079554,0.067554,0.000161,,,0.001295
4,1.333333,,0.227648,0.191405,0.164937,0.17954,0.128576,0.178869,0.177319,0.179978,...,0.048446,0.051969,0.073932,0.075587,0.06742,0.067088,8.2e-05,,,0.000764


In [61]:
import pyedflib  # get_PSD imports pyedflib itself; this line only surfaces a missing install early

# NOTE(review): absolute cluster path -- consider moving it to a config-cell constant.
edf_orig = '/scratch/mcesped/Results/tmp_seeg/sub-097_ses-007_task-full_rec-clip_run-01_clip-01_ieeg.edf'
# Normalized Welch PSD of every channel in the "rec-clip" (original) clip.
orig_chns = get_PSD(edf_path=edf_orig, length_segment=3.0)
orig_chns.head(5)

Unnamed: 0,Frequency (Hz),Patient Event,LOFr1,LOFr2,LOFr3,LOFr4,LOFr5,LOFr6,LOFr7,LOFr8,...,DC11,DC12,DC13,DC14,DC15,DC16,TRIG,OSAT,PR,Pleth
0,0.0,,0.12062,0.123816,0.117915,0.12176,0.092202,0.1153,0.137437,0.133518,...,0.007575,0.010003,0.007979,0.005964,0.005272,0.00687,,,,
1,0.333333,,0.696882,0.735476,0.794402,0.759242,0.858134,0.794042,0.765541,0.740162,...,0.043107,0.05521,0.04109,0.038926,0.02738,0.032523,,,,
2,0.666667,,0.502661,0.493565,0.438326,0.458579,0.396631,0.388031,0.435508,0.470275,...,0.043733,0.049775,0.039018,0.040171,0.032832,0.031408,,,,
3,1.0,,0.309631,0.286586,0.232539,0.258367,0.170886,0.220211,0.262388,0.291072,...,0.040517,0.044316,0.03518,0.033457,0.033535,0.030358,,,,
4,1.333333,,0.227369,0.191149,0.16474,0.17928,0.128493,0.178661,0.177134,0.17982,...,0.040692,0.042742,0.032283,0.032395,0.027798,0.029618,,,,


In [None]:
# Side-by-side normalized PSDs: original clip in the left panel, "rec-dn" clip in the right.
fig = plot_PSD_channels(df1=orig_chns, df2=dn_chns, names=["Original", "Dn"])
fig.show()