# Notebook for training a DDPM on the BCI Challenge @ NER 2015

This notebook trains a DDPM (denoising diffusion probabilistic model) from scratch to generate synthetic EEG trials.
The generated trials can then be plotted and compared to the real data.
The data is provided in the `data` folder.

In [None]:
import logging

import matplotlib.pyplot as plt
import mne
import numpy as np
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
from pandas import read_csv

from ntd.datasets import NER_BCI
from ntd.train_diffusion_model import training_and_eval_pipeline
from ntd.utils.plotting_utils import (
    basic_plotting,
    plot_overlapping_signal,
    plot_sd,
)
from ntd.utils.utils import l2_distances

# Configure the root logger so that INFO-level messages are displayed.
logging.basicConfig(level="INFO")


## Create the config

To change the network architecture or the optimization settings, edit the entries of the `overrides` list passed to `compose` below.

In [None]:
# Compose the Hydra config named "config" from ../conf, apply the notebook-specific
# overrides listed below, and print the resulting YAML.
data_path = "../data"
with initialize(config_path="../conf", version_base=None):
    cfg = compose(
        overrides=[
            # Experiment identification / logging.
            "base.experiment=ner_example",
            "base.tag=unconditional_wn",
            "base.wandb_mode=disabled",
            # Dataset location (taken from `data_path` above).
            "dataset.filepath=" + data_path,
            # `null` sets the value to None in Hydra's override grammar.
            "base.save_path=null",
            # Optimizer and diffusion settings.
            "optimizer.num_epochs=1",
            "optimizer.lr=0.0004",
            "diffusion=diffusion_quad_50",
            # The leading `+` adds a config group that is not in the defaults list.
            "+experiments/generate_samples=generate_samples",
        ],
        config_name="config",
    )
    print(OmegaConf.to_yaml(cfg))


## Train the model and generate samples

In [None]:
# Train the diffusion model described by `cfg` and draw samples from it.
diffusion_model, samples = training_and_eval_pipeline(cfg)
# `Tensor.numpy()` only works for CPU tensors that do not require grad. Detach
# and move to host first so this also works if the pipeline returns samples on
# a GPU (for tensors already on the CPU both calls are no-ops).
samples_numpy = samples.detach().cpu().numpy()


In [None]:
# Load the real trials for the patient and data path configured in `cfg.dataset`.
raw_numpy = NER_BCI(cfg.dataset.patient_id, filepath=cfg.dataset.filepath).data_array

# The cells below index real and generated arrays with the same trial / channel /
# time indices, so their shapes must agree. Raise explicitly (instead of a bare
# `assert`, which is stripped under `python -O`) and report both shapes.
if samples_numpy.shape != raw_numpy.shape:
    raise ValueError(
        f"Generated samples have shape {samples_numpy.shape}, "
        f"but the real data has shape {raw_numpy.shape}."
    )
num_trials, num_channels, sig_length = raw_numpy.shape


## Plot some random real and generated samples

In [None]:
# Pick one trial index at random, then plot the real trial and the generated
# sample at that index, with channels stacked vertically.
rand_id = np.random.randint(len(samples_numpy))
print(f"Plotting trial index {rand_id}")

sfreq = 200  # sampling rate in Hz (260 samples <-> 1.3 s)
offset = -1.1
# Shift channel i by i * offset so the channel traces do not overlap.
channel_shift = offset * np.arange(num_channels)[:, np.newaxis]

for title, trials, color in (
    ("real", raw_numpy, "dimgrey"),
    ("generated", samples_numpy, "black"),
):
    fig, ax = plt.subplots()
    plot_overlapping_signal(
        fig,
        ax,
        sig=trials[rand_id] + channel_shift,
        colors=[color],
    )
    # Derive the time axis from the signal length instead of hard-coding 260 / 1.3.
    basic_plotting(
        fig,
        ax,
        y_ticks=[],
        x_lim=(0, sig_length),
        x_ticks=[0, sig_length],
        x_ticklabels=[0, sig_length / sfreq],
        x_label="time [s]",
    )
    ax.set_title(title)
    fig.tight_layout()
    plt.show()


## Plot the power spectral density

For all 56 channels up until 60Hz.
Red is generated, black is real.
Pointwise median and 25% / 75% percentiles are shown.

In [None]:
# Power spectral density per channel via `plot_sd`: real trials in black, generated
# samples in "C3" (red), one panel per channel on a 10-column grid.
# NOTE(review): if `x_ss` slices frequency bins, then nperseg=260 at fs=200 gives
# a bin spacing of 200/260 ~= 0.77 Hz, so slice(0, 60) covers roughly 0-45 Hz
# rather than 0-60 Hz -- confirm against plot_sd.
n_cols = 10
# Ceiling division: exactly enough rows for all channels, with no extra empty row
# when the channel count is a multiple of n_cols.
n_rows = -(-num_channels // n_cols)
# squeeze=False keeps `axs` 2-D even when there is only a single row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(45, 35), squeeze=False)
for idx, ax in enumerate(axs.flat):
    if idx >= num_channels:
        # Hide leftover panels in the last row instead of drawing empty axes.
        ax.set_visible(False)
        continue
    plot_sd(
        fig=fig,
        ax=ax,
        arr_one=raw_numpy[:, idx, :],
        arr_two=samples_numpy[:, idx, :],
        fs=200,
        nperseg=260,
        agg_function=np.median,
        with_quantiles=True,
        x_ss=slice(0, 60),
        color_one="black",
        color_two="C3",
    )
plt.show()


## Plot the evoked potentials

For all 56 channels. Red is generated, black is real. The line is the mean over trials; the shaded band spans the 10% to 90% quantiles across trials.

In [None]:
# Per-channel trial statistics over time: mean (line) and 10%-90% quantile band
# (shaded); real data in black, generated samples in "C3" (red).
n_cols = 5
# Ceiling division: no extra empty row when num_channels is a multiple of n_cols.
n_rows = -(-num_channels // n_cols)
time_axis = np.arange(sig_length)
# squeeze=False keeps `axs` 2-D even when there is only a single row.
fig, axs = plt.subplots(n_rows, n_cols, figsize=(45, 45), squeeze=False)
for data, color in ((raw_numpy, "black"), (samples_numpy, "C3")):
    # Compute the statistics for all channels at once:
    # quantiles -> (2, channels, time), mean -> (channels, time).
    lower, upper = np.quantile(data, [0.1, 0.9], axis=0)
    mean = data.mean(axis=0)
    for idx, ax in enumerate(axs.flat[:num_channels]):
        ax.fill_between(time_axis, lower[idx], upper[idx], color=color, alpha=0.2)
        ax.plot(time_axis, mean[idx], color=color)
# Hide leftover panels in the last row instead of drawing empty axes.
for ax in axs.flat[num_channels:]:
    ax.set_visible(False)
plt.show()


## Plot the topomaps

Trial-averaged scalp maps at 150, 200, 250 and 300 ms, first for the real data and then for the generated data.

In [None]:
# Topographic maps of the trial-averaged signal for real and generated data.
# The channel-location CSV lives in the configured dataset directory, so build
# its path from `cfg.dataset.filepath` instead of hard-coding "../data".
csv_path = f"{cfg.dataset.filepath}/CorrectedChannelsLocation.csv"
chan_info = read_csv(csv_path)

montage = mne.channels.make_standard_montage("standard_1020")
info = mne.create_info(ch_names=list(chan_info["Labels"]), sfreq=200, ch_types="eeg")
info.set_montage(montage)
# 0.15, 0.20, 0.25, 0.30 s. linspace avoids the floating-point end-point
# ambiguity of np.arange with a non-integer step.
times = np.linspace(0.15, 0.30, 4)
# Use a dedicated loop variable so the `samples` tensor returned by training is
# not overwritten.
for title, data in (("real", raw_numpy), ("generated", samples_numpy)):
    # Average over trials -> (channels, time), the layout EvokedArray expects.
    evoked = mne.EvokedArray(data.mean(axis=0), info)
    topo_fig = evoked.plot_topomap(
        times,
        ch_type="eeg",
        scalings=1.0,
        vlim=(-0.5, 0.5),
        image_interp="cubic",
        colorbar=False,
        res=300,
        size=1.5,
        show=False,
    )
    topo_fig.suptitle(title)
    plt.show()
