In [1]:
import sys
sys.path.insert(0, '/home/jovyan/braindecode/')
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import mne
from mne.channels import make_standard_montage
from mne.viz import plot_topomap

from braindecode.visualization import compute_amplitude_gradients

from decode_tueg import load_exp, DataScaler, TargetScaler, Augmenter

Mon Feb  6 13:02:14 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:85:00.0 Off |                  N/A |
| 22%   44C    P5    20W / 250W |      4MiB / 12288MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
def add_cbar(fig, ax_img, title='', rect=(0.95, 0.1, 0.04, 0.9)):
    """Add a colorbar for ``ax_img`` on dedicated axes placed at ``rect`` in ``fig``.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure that receives the new colorbar axes.
    ax_img : matplotlib.cm.ScalarMappable
        Mappable whose colour mapping the bar displays (e.g. the image returned
        by ``plot_topomap``).
    title : str, optional
        Title drawn above the colorbar. Empty by default, which matches the
        previous behaviour.
    rect : sequence of 4 floats, optional
        ``(left, bottom, width, height)`` of the colorbar axes in figure
        coordinates. Defaults to the previously hard-coded placement.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar, so callers can tweak ticks or labels afterwards.
    """
    # Explicit cax: the colorbar gets its own axes instead of shrinking the plot axes.
    cbar_ax = fig.add_axes(list(rect))
    clb = fig.colorbar(ax_img, cax=cbar_ax)
    clb.ax.set_title(title)
    return clb

In [3]:
def freq_to_bin(bins, freq):
    """Return the index of the entry in ``bins`` closest to ``freq``.

    ``bins`` may be any 1-D array-like of frequencies (e.g. the output of
    ``np.fft.rfftfreq``) — lists are accepted too. Ties resolve to the lower
    index, because ``argmin`` returns the first occurrence.
    """
    return int(np.abs(np.asarray(bins) - freq).argmin())


def freqs_to_bin(bins, freqs):
    """Return the closest-bin index (see ``freq_to_bin``) for every frequency in ``freqs``."""
    bins = np.asarray(bins)  # convert once instead of once per frequency
    return [freq_to_bin(bins, freq) for freq in freqs]

In [18]:
# Where the trained experiment lives and which of its checkpoints to restore (used by load_exp).
exp_dir = '/home/jovyan/experiments/'
checkpoint = 'train_end'

# Settings handed to compute_gradients / compute_amplitude_gradients.
batch_size = 64
n_jobs = 4

# Analyse only the first N recordings; set to None to keep every recording.
n_recordings = 200

In [19]:
# Restore the trained classifier together with the scalers and config returned by load_exp.
run_id = '2023-02-03T22:27:26.056071/20230203/0'
clf, data_scaler, target_scaler, config = load_exp(exp_dir, run_id, checkpoint)

In [20]:
# Name of the pickled windows dataset to analyse.
# Other candidates noted alongside it: non-pathological, pathological.
ds_name = 'transition'

In [21]:
# NOTE(review): pickle.load can execute arbitrary code — only load files produced by this project.
ds_path = f'/home/jovyan/longitudinal/{ds_name}_pre_win.pkl'
with open(ds_path, 'rb') as pkl_file:
    ds = pickle.load(pkl_file)

In [22]:
def _to_montage_name(raw_name):
    """Normalise one dataset channel name to the spelling used for the montage lookup.

    Keeps only the last space-separated token, lower-cases every 'Z'
    (e.g. 'CZ' -> 'Cz'), and turns 'FP1'/'FP2' into 'Fp1'/'Fp2'.
    """
    label = raw_name.split(' ')[-1].replace('Z', 'z')
    return label.replace('P', 'p') if label in ('FP1', 'FP2') else label


names = [_to_montage_name(raw) for raw in ds.datasets[0].windows.info['ch_names']]

In [23]:
# Sampling frequency, read from the info of the first windowed recording.
win_info = ds.datasets[0].windows.info
sfreq = win_info['sfreq']

In [24]:
# Frequency axis (Hz) of a real FFT over one window. Reuses `sfreq` from the cell above
# instead of reading it from the dataset info a second time.
# NOTE(review): assumes a window is (n_channels, n_times) so shape[1] is its length — confirm.
n_window_samples = ds[0][0].shape[1]
freqs = np.fft.rfftfreq(n_window_samples, d=1 / sfreq)

In [25]:
# MNE Info for the renamed channels with standard 10-20 sensor positions attached;
# plot_topomap below receives this Info to place each channel.
montage = make_standard_montage('standard_1020')
info = mne.create_info(ch_names=names, sfreq=sfreq, ch_types='eeg').set_montage(montage)

In [26]:
# Optionally restrict the analysis to the first `n_recordings` recordings (None keeps all).
# split() selects sub-datasets by position, so ids beyond len(ds.datasets) would presumably
# raise; clamp to the number of available recordings first.
if n_recordings is not None:
    n_keep = min(n_recordings, len(ds.datasets))
    ds = ds.split(list(range(n_keep)))['0']

In [27]:
def _pathology_label(split_key):
    """Map one key of ``ds.split('pathological')`` to its display label.

    The original code compared keys against the string '0', so keys arrive as
    strings. An integer column gives '0'/'1'; a boolean column would presumably
    give 'False'/'True'. Both forms (and '0.0'/'1.0') are accepted. Anything
    else raises instead of being silently lumped into 'Pathological'.
    """
    key = str(split_key).strip().lower()
    if key in ('0', '0.0', 'false'):
        return 'Non-pathological'
    if key in ('1', '1.0', 'true'):
        return 'Pathological'
    raise ValueError(
        f"Unexpected value {split_key!r} in the 'pathological' column; "
        "expected 0/1 or False/True."
    )


def compute_gradients(clf, ds, batch_size, n_jobs):
    """Average the amplitude gradients of ``clf.module`` separately per pathology group.

    Parameters
    ----------
    clf : object
        Trained classifier; its ``.module`` is handed to ``compute_amplitude_gradients``.
    ds : dataset
        Must provide ``split('pathological')`` returning ``{group_key: sub_dataset}``.
    batch_size, n_jobs : int
        Forwarded to ``compute_amplitude_gradients``.

    Returns
    -------
    dict
        Maps 'Non-pathological' and/or 'Pathological' to the gradients averaged
        over the first two axes of the ``compute_amplitude_gradients`` output.
        When both groups are present, it also maps
        'Non-pathological – Pathological' to their difference.

    Raises
    ------
    ValueError
        If a split key is not a recognised pathology value, or if two split
        keys map to the same label (one group would otherwise silently
        overwrite the other).
    """
    all_grads = {}
    for split_key, subset in ds.split('pathological').items():
        label = _pathology_label(split_key)
        if label in all_grads:
            raise ValueError(f"Two splits map to label {label!r}; refusing to overwrite.")
        grads = compute_amplitude_gradients(clf.module, subset, batch_size, n_jobs)
        # NOTE(review): presumably (n_filters, n_windows, n_chans, n_freqs); averaging
        # axes 0 and 1 leaves a (n_chans, n_freqs) map — confirm against braindecode.
        all_grads[label] = grads.mean((1, 0))
    if 'Non-pathological' in all_grads and 'Pathological' in all_grads:
        all_grads['Non-pathological – Pathological'] = (
            all_grads['Non-pathological'] - all_grads['Pathological']
        )
    return all_grads

In [None]:
%%time
# Per-group averaged amplitude gradients (the cell magic above reports how long this takes).
all_grads = compute_gradients(clf=clf, ds=ds, batch_size=batch_size, n_jobs=n_jobs)

In [None]:
# Frequency bands in Hz as (low, high) pairs; the edges match the classical
# delta/theta/alpha/beta/gamma split (presumably intended as such).
bands = [
    (low, high)
    for low, high in zip((0, 4, 8, 13, 30), (4, 8, 13, 30, 50))
]

In [None]:
# One figure per frequency band: a topomap of the band-averaged amplitude gradients for
# every entry of `all_grads` (the pathology groups and, if present, their difference).
keys = list(all_grads.keys())
n_panels = len(keys)

for band in bands:
    # NOTE(review): both band edges are inclusive (h + 1 in the slice), so the bin nearest a
    # shared edge (e.g. 4 Hz) contributes to two adjacent bands — confirm this is intended.
    l, h = freqs_to_bin(freqs, band)
    all_band_grads = [all_grads[key][:, l:h + 1].mean(axis=1) for key in keys]

    # squeeze=False keeps ax_arr 2-D even for a single panel; the width reproduces the
    # original 10 x 3 inch layout when there are three panels.
    fig, ax_arr = plt.subplots(1, n_panels, figsize=(10 * n_panels / 3, 3), squeeze=False)
    # Symmetric colour limits shared by all panels of this band, for comparability.
    max_abs = np.abs(all_band_grads).max()

    for band_i, band_grads in enumerate(all_band_grads):
        ax_img, _ = plot_topomap(
            band_grads,
            info,
            size=3,
            names=names,
            show=False,
            axes=ax_arr[0, band_i],
            vlim=(-max_abs, max_abs),
        )
        ax_img.axes.set_title(f'{keys[band_i]}\n')
        if band_i == 0:
            ax_img.axes.set_ylabel('-'.join([str(i) for i in band])+' Hz\n')
    # All panels of a band share the same vlim, so one colorbar per figure is sufficient.
    add_cbar(fig, ax_img)
    # Render this band's figure now rather than accumulating open figures across the loop.
    plt.show()