In [1]:
import sys
sys.path.insert(0, '/home/jovyan/braindecode/')
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import mne
from mne.channels import make_standard_montage
from mne.viz import plot_topomap

from braindecode.visualization import compute_amplitude_gradients

from decode_tueg import load_exp, DataScaler, TargetScaler, Augmenter

In [2]:
def add_cbar(fig, ax_img, rect=(0.95, 0.1, 0.04, 0.9), label='uV'):
    """Attach a colorbar for ``ax_img`` on a manually placed axes in ``fig``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the new colorbar axes.
    ax_img : mappable
        Image/mappable whose colour mapping the colorbar shows (here the image
        returned by ``mne.viz.plot_topomap``).
    rect : sequence of 4 floats
        ``(left, bottom, width, height)`` of the colorbar axes, in figure
        coordinates. The default reproduces the original fixed placement.
    label : str
        Title drawn above the colorbar.
        NOTE(review): the default 'uV' is kept from the original, but the
        values plotted in this notebook are amplitude gradients — confirm the
        unit is appropriate.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar, so callers can adjust it further.
    """
    cbar_ax = fig.add_axes(list(rect))
    clb = fig.colorbar(ax_img, cax=cbar_ax)
    clb.ax.set_title(label)
    return clb

In [3]:
# EEG channel labels (10-20 naming) used both for the MNE Info and as topomap
# annotations. Assumed to match the channel order of the model input — TODO confirm.
names = 'A1 A2 C3 C4 Cz F3 F4 F7 F8 Fp1 Fp2 Fz O1 O2 P3 P4 Pz T3 T4 T5 T6'.split()

In [4]:
# Channel Info with sensor positions from MNE's standard 10-20 montage,
# needed by plot_topomap to place each channel on the head outline.
info = mne.create_info(names, sfreq=100, ch_types='eeg')
montage = make_standard_montage('standard_1020')
info = info.set_montage(montage)

In [5]:
# Run configuration.
exp_dir = '/home/jovyan/experiments/'  # root directory passed to load_exp
checkpoint = 'train_end'               # checkpoint name passed to load_exp
batch_size = 64                        # batch size for compute_amplitude_gradients
n_recordings = None                    # int -> only use a subset of ds (see below); None -> all

In [6]:
# Restore the trained model (`clf`), the data/target scalers and the run
# config of one experiment using the project helper `load_exp`.
exp_name = '2022-10-17T16:24:10.124445/20220429/0'
clf, data_scaler, target_scaler, config = load_exp(exp_dir, exp_name, checkpoint)

In [7]:
# Dataset variant to analyse. The original comment listed 'non-pathological'
# and 'pathological' as alternatives — presumably other pickles in
# /home/jovyan/longitudinal/; TODO confirm the exact names.
ds_name = 'transition'

In [8]:
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
ds_path = f'/home/jovyan/longitudinal/{ds_name}_pre_win.pkl'
with open(ds_path, 'rb') as pkl_file:
    ds = pickle.load(pkl_file)

In [9]:
# Optional quick-run mode: split off the recordings with indices
# 0 .. n_recordings-1 (returned under split key '0') and continue with those.
if n_recordings is not None:
    selected_ids = list(range(n_recordings))
    ds = ds.split(selected_ids)['0']

In [None]:
# Per-channel amplitude gradients for each pathology group, plus their
# difference. Averaging over axes (3, 1, 0) is assumed to collapse
# frequency, window and model-output axes, leaving one value per channel —
# TODO confirm against the installed braindecode version.
#
# Split keys are presumably the stringified values of the 'pathological'
# description column: '0'/'1' for an int column, 'False'/'True' for a bool
# column. Map both explicitly rather than treating "anything but '0'" as
# pathological, which would let groups silently overwrite each other.
pathology_labels = {
    '0': 'Non-pathological', 'False': 'Non-pathological',
    '1': 'Pathological', 'True': 'Pathological',
}
all_grads = {}
for split_key, split_ds in ds.split('pathological').items():
    if split_key not in pathology_labels:
        raise ValueError(f"Unexpected 'pathological' split key: {split_key!r}")
    grads = compute_amplitude_gradients(clf.module, split_ds, batch_size)
    avg_grads = grads.mean((3, 1, 0))
    all_grads[pathology_labels[split_key]] = avg_grads

# Both groups are needed for the difference map; fail with a clear message
# if the dataset only contains one of them.
missing_groups = {'Non-pathological', 'Pathological'} - all_grads.keys()
if missing_groups:
    raise KeyError(f'No recordings found for group(s): {sorted(missing_groups)}')
all_grads['Non-pathological – Pathological'] = all_grads['Non-pathological'] - all_grads['Pathological']

In [None]:
# Shared symmetric colour range (±max_abs) over every panel, so all
# topomaps use the same scale.
grads_df = pd.DataFrame(all_grads)
vmin = grads_df.min().min()
vmax = grads_df.max().max()
max_abs = np.abs([vmin, vmax]).max()

In [None]:
# One topomap per entry of `all_grads`, all on the same symmetric colour
# scale (±max_abs), so a single colorbar applies to every panel. The panel
# count follows `all_grads` instead of a hard-coded 3 (previously: empty axes
# with fewer entries, IndexError with more); 3 entries give the same 12x3 figure.
n_panels = len(all_grads)
fig, ax_arr = plt.subplots(1, n_panels, figsize=(4 * n_panels, 3), squeeze=False)
ax_arr = ax_arr[0]  # squeeze=False always returns 2-D; keep the single row
for ax, (title, grad) in zip(ax_arr, all_grads.items()):
    ax_img, contours = plot_topomap(
        grad,
        info,
        size=5,
        names=names,
        show=False,
        axes=ax,
        vlim=(-max_abs, max_abs),
    )
    ax_img.axes.set_title(title)
# All panels share vlim, so the last image's colour mapping is valid for all.
add_cbar(fig, ax_img)