## This notebook plots time-frequency representations (TFRs) of EEG data using Plotly.

In [None]:
import numpy as np
import pandas as pd
import mne
import plotly.graph_objects as go
from plotly.subplots import make_subplots

In [None]:
from plotly.io import templates

# Every figure built later in the notebook picks up the stock 'plotly' theme.
templates.default = 'plotly'

# Left as the cell's final expression so Jupyter echoes the template registry
# (useful for seeing which theme names are available).
templates

In [None]:
from pathlib import Path

path = Path('../data/asd/')
# Frequencies (Hz) for the TFR: 41 evenly spaced points from 4 to 12 Hz.
freqs = np.linspace(4, 12, 41)

# Index file with one row per recording; the columns used below are
# 'fn' (FIF file name relative to `path`), 'target' and 'sfreq'.
info = pd.read_csv(path / 'path_file.csv')
eegs = {i: mne.io.read_raw_fif(path / fn, verbose=False) for i, fn in info['fn'].items()}

# Use index *labels* rather than positions so these always match the keys of `eegs`
# (identical to positions for the default RangeIndex, but robust if the index changes).
typical = info.index[info['target'] == 'typical'].to_numpy()
asd = info.index[info['target'] == 'asd'].to_numpy()

# Channel list taken from the first recording; the plotting cells reuse it for all
# recordings. NOTE(review): assumes every recording has the same channels — confirm.
ch_names = eegs[info.index[0]].ch_names

# Explicit check (not `assert`, which is stripped under `python -O`) that the sampling
# rate recorded in the CSV agrees with each FIF header; report the offending files.
mismatched = [info.at[i, 'fn'] for i, sfreq in info['sfreq'].items()
              if eegs[i].info['sfreq'] != sfreq]
if mismatched:
    raise ValueError(f'sfreq in path_file.csv disagrees with the FIF header for: {mismatched}')

In [None]:
def compute_tfr(data, freqs, sfreq=125):
    """Return multitaper time-frequency power for one continuous data segment.

    ``mne.time_frequency.tfr_array_multitaper`` expects epoched input, so the
    segment is wrapped as a single "epoch" via a new leading axis. With
    ``output='avg_power'`` the power is averaged over that epoch axis.

    Parameters
    ----------
    data : ndarray
        Presumably (n_channels, n_times), as returned by ``raw[picks, start:stop]``.
        TODO confirm at call sites.
    freqs : array-like of float
        Frequencies (Hz) at which power is estimated.
    sfreq : float
        Sampling rate in Hz; defaults to 125.

    Returns
    -------
    ndarray
        Per the MNE docs for ``avg_power``, shaped (n_channels, n_freqs, n_times).
    """
    single_epoch = np.expand_dims(data, axis=0)
    return mne.time_frequency.tfr_array_multitaper(
        single_epoch,
        sfreq=sfreq,
        freqs=freqs,
        output='avg_power',
    )


# def plot_tfr(raw, channels, freqs, tmin=0, tmax=None):
#     data, times = raw[channels, tmin:tmax]
#     tfr = compute_tfr(data, freqs, raw.info['sfreq'])
#     fig = go.Figure(
#         data=[go.Heatmap(x=times, y=freqs, z=tfr[0])],
#         layout=go.Layout(title='TFR Representation',
#                          title_xanchor='center',
#                          title_x= 0.5,
#                          xaxis_title='Time (seconds)',
#                          yaxis_title='Frequncy (Hz)')
#     )
#     if not isinstance(channels, str):
#         fig.update_layout(
#             updatemenus=[dict(type='dropdown',
#                             buttons=[
#                                 dict(args=[{'z': [tfr[i]]},
#                                            {'title': f'TFR Representation ({ch_name})'}],
#                                     label=ch_name,
#                                     method='update')
#                                 for i, ch_name in enumerate(channels)]
#             )]
#         )
#     return fig


# def plot_signal(raw, channels, tmin=0, tmax=None):
#     data, times = raw[channels, tmin:tmax]
#     fig = go.Figure(
#         data=[go.Scatter(x=times, y=data[0], mode='lines')],
#         layout=go.Layout(title='EEG Signal',
#                          title_xanchor='center',
#                          title_x= 0.5,
#                          xaxis_title='Time (seconds)',
#                          yaxis_title='Power')
#     )
#     if not isinstance(channels, str):
#         fig.update_layout(
#             updatemenus=[dict(type='dropdown',
#                             buttons=[
#                                 dict(args=[{'y': [data[i]]},
#                                            {'title': f'EEG Signal ({ch_name})'}],
#                                     label=ch_name,
#                                     method='update')
#                                 for i, ch_name in enumerate(channels)]
#             )]
#         )
#     return fig


def plot_channels(raw, tmin=0, tmax=None, channels=None):
    """Plot each channel of a recording in its own row, sharing one time axis.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to plot.
    tmin, tmax : int | None
        Start/stop *sample indices* (not seconds), used as the slice
        ``raw[channels, tmin:tmax]``; ``tmax=None`` runs to the end.
    channels : str | list of str | None
        Channel name(s) to plot, one subplot row each, in the given order.
        Defaults to ``raw.ch_names``, so the figure always matches the recording
        passed in instead of depending on a module-level channel list.

    Returns
    -------
    plotly.graph_objects.Figure
        One line trace per channel, each labelled with its channel name.
    """
    if channels is None:
        channels = raw.ch_names
    elif isinstance(channels, str):
        # A bare name would otherwise be iterated character by character below.
        channels = [channels]

    data, times = raw[channels, tmin:tmax]
    n_rows = len(channels)
    fig = make_subplots(rows=n_rows,
                        cols=1,
                        shared_xaxes=True,
                        vertical_spacing=0.005)

    for row, (ch_name, signal) in enumerate(zip(channels, data), start=1):
        fig.add_trace(go.Scatter(x=times, y=signal, mode='lines', name=ch_name), row=row, col=1)
        fig.update_yaxes(title_text=ch_name, row=row, col=1)

    fig.update_layout(title='EEG Signal',
                      title_xanchor='center',
                      title_x=0.5,
                      height=1500)

    # The x-axis is shared, so only the bottom row gets a title.
    fig.update_xaxes(title_text='Time (s)', row=n_rows, col=1)

    return fig

def plot_combined(raw, channels, freqs, tmin=0, tmax=None):
    """Plot a TFR heatmap, the raw signal and its band-passed version, stacked.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to plot.
    channels : str | list of str
        Channel name(s) passed to ``raw[...]``. For a list, the first channel is
        shown initially and a dropdown switches all three panels between channels.
    freqs : array-like of float
        Frequencies (Hz) for the multitaper TFR. Their minimum and maximum also
        define the pass band of the "Filtered signal" panel.
    tmin, tmax : int | None
        Start/stop *sample indices* (not seconds), used as the slice
        ``raw[channels, tmin:tmax]``; ``tmax=None`` runs to the end.

    Returns
    -------
    plotly.graph_objects.Figure
        Three vertically stacked panels sharing the time axis.
    """
    # Accept lists/tuples too: .min()/.max() below need an ndarray.
    freqs = np.asarray(freqs)
    data, times = raw[channels, tmin:tmax]
    sfreq = raw.info['sfreq']
    # NOTE(review): raw[...] returns values in the units stored in the Raw object
    # (SI volts for EEG by MNE convention). The 'μV' labels below assume the stored
    # values are already microvolts; confirm against how the FIF files were created.
    filtered_data = mne.filter.filter_data(data, sfreq, freqs.min(), freqs.max(), verbose=False)
    # One TFR per channel (indexed tfr[i]) so the dropdown can swap heatmaps.
    tfr = compute_tfr(data, freqs, sfreq)
    fig = make_subplots(rows=3,
                        cols=1,
                        shared_xaxes=True,
                        subplot_titles=('TFR Representation', 'Raw Signal', 'Filtered signal'),
                        vertical_spacing=0.035)

    fig.add_trace(go.Heatmap(x=times, y=freqs, z=tfr[0], name='spectrogram'), row=1, col=1)
    fig.add_trace(go.Scatter(x=times, y=data[0], mode='lines', name='raw signal'), row=2, col=1)
    fig.add_trace(go.Scatter(x=times, y=filtered_data[0], mode='lines', name='filtered signal'), row=3, col=1)

    fig.update_layout(height=1000)

    fig.update_yaxes(title_text='Frequency (Hz)', row=1, col=1)
    fig.update_yaxes(title_text='Amplitude (μV)', row=2, col=1)
    fig.update_yaxes(title_text='Amplitude (μV)', row=3, col=1)
    fig.update_xaxes(title_text='Time (s)', row=3, col=1)

    if not isinstance(channels, str):
        # 'restyle' applies list entries positionally to the traces in the order
        # they were added: [heatmap, raw scatter, filtered scatter].
        buttons = [
            dict(args=[{'y': [freqs, data[i], filtered_data[i]], 'z': [tfr[i], None, None]}],
                 label=ch_name,
                 method='restyle')
            for i, ch_name in enumerate(channels)
        ]
        fig.update_layout(updatemenus=[dict(type='dropdown', buttons=buttons)])
    return fig

In [None]:
# First recording labelled 'typical': TFR + raw + filtered panels, first 1000 samples
# (tmax is a sample index, not seconds).
plot_combined(eegs[typical[0]], channels=ch_names, freqs=freqs, tmin=0, tmax=1000)

In [None]:
# Second recording labelled 'asd' (position 1), same channels and sample window.
plot_combined(eegs[asd[1]], channels=ch_names, freqs=freqs, tmin=0, tmax=1000)

In [None]:
# Stacked per-channel traces of the first 'typical' recording over its full length.
plot_channels(eegs[typical[0]], tmin=0, tmax=None)

In [None]:
# Stacked per-channel traces of the first 'asd' recording over its full length.
plot_channels(eegs[asd[0]], tmin=0, tmax=None)