In [None]:
import numpy as np
import torch
import sys, os

sys.path.append("../")
from vi_rnn.saving import load_model
from evaluation.eval_kl_pse import eval_kl_pse as eval_VAE
from vi_rnn.datasets import Basic_dataset
from scipy.stats import median_abs_deviation as mad
from matplotlib.colors import colorConverter as cc
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
# Build the evaluation dataset from the smoothed EEG recording.
# The recording is cast to float32, and its transpose is handed to Basic_dataset
# twice (presumably as train and eval data; TODO confirm against Basic_dataset).
task_params = dict(name="EEG", dur=50, n_trials=500)
eval_data = np.float32(
    np.load("../data/eeg/EEG_data_smoothed.npy")
)
task = Basic_dataset(task_params, eval_data.T, eval_data.T)

In [None]:
# Load and evaluate every rank-3 model in the sweep directory.
# For each saved model ("<name>_vae_params.pkl"), store the first two values
# returned by eval_VAE:
#   `klx_bin` (KL divergence, shown as D_stsp in the figure)
#   `psH` (Hellinger distance, shown as D_H in the figure)
directory = "../models/sweep_eeg_rs/"
suffix = "_vae_params.pkl"

data_kl = []
data_ph = []
# sorted(): os.listdir order is arbitrary, so sort for a reproducible evaluation/log order
for filename in sorted(os.listdir(directory)):
    if not filename.endswith(suffix):
        continue
    model_name = filename.removesuffix(suffix)
    print(model_name)
    # Only the model is needed. The other returned values are discarded so they
    # do not overwrite the dataset's `task_params` global defined earlier.
    vae, _, _ = load_model(
        os.path.join(directory, model_name), load_encoder=False, backward_compat=True
    )
    klx_bin, psH, _ = eval_VAE(
        vae,
        task,
        smoothing=20,
        cut_off=2400,
        freq_cut_off=-1,
        smooth_at_eval=True,
    )
    data_kl.append(klx_bin)
    data_ph.append(psH)

In [None]:
# Sanity check: the rank-3 sweep directory should have yielded exactly 20 models.
n_models = len(data_kl)
print(n_models)
assert n_models == 20

In [None]:
# Rank-3 sweep summary: median and median absolute deviation of each metric,
# printed Hellinger distance first, then KL divergence.
for metric_values in (data_ph, data_kl):
    print(np.median(metric_values), mad(metric_values))

In [None]:
# Number of trainable parameters of the rank-3 model.
# Terms, in the order they appear in `n_params`:
#   weights (N * dz * 2) + biases (N) + output biases (dx) + output weights (dz * dx)
#   + Cholesky factor of the latent covariance (n_el(dz)) + time constant (1)
#   + observation variance (dx) + initial covariance (n_el(dz)) + initial mean (dz)


def n_el(n):
    """Return the number of elements in an n x n triangular matrix (incl. diagonal).

    Computed with exact integer arithmetic, n * (n + 1) // 2, so the result
    is exact for any n. The previous float division lost precision for large n.
    """
    return int(n * (n + 1) // 2)


dz = 3
dx = 64
N = 512
n_params = N * dz * 2 + N + dx + dz * dx + n_el(dz) + 1 + dx + n_el(dz) + dz
print(n_params)

Note that the full-rank EEG models are not included in the public repo, to keep it lightweight. Run

   `train_scripts/eeg/train_EEG_full_rank.py`

to obtain the models, or contact the authors.

In [None]:
# Load and evaluate the full-rank models with 30 units (these are untracked data).
# `klx_bin` (KL divergence) and `psH` (Hellinger distance) are collected per model.
directory = "../data_untracked/additional_eeg/sweep_eeg_fullrank/"
suffix = "_vae_params.pkl"

if not os.path.isdir(directory):
    raise FileNotFoundError(
        f"Full-rank EEG models not found in {directory!r}; run "
        "train_scripts/eeg/train_EEG_full_rank.py to obtain them."
    )

data_kl_FR = []
data_ph_FR = []
# sorted(): os.listdir order is arbitrary, so sort for a reproducible evaluation/log order
for filename in sorted(os.listdir(directory)):
    if not filename.endswith(suffix):
        continue
    model_name = filename.removesuffix(suffix)
    print(model_name)
    # Only the model is needed. The other returned values are discarded so they
    # do not overwrite globals such as the dataset's `task_params`.
    vae, _, _, _ = load_model(
        os.path.join(directory, model_name), load_encoder=False
    )
    klx_bin, psH, _ = eval_VAE(
        vae,
        task,
        smoothing=20,
        cut_off=2400,
        freq_cut_off=-1,
        smooth_at_eval=True,
    )
    data_kl_FR.append(klx_bin)
    data_ph_FR.append(psH)

assert data_kl_FR, f"no '*{suffix}' model files found in {directory!r}"

In [None]:
# Number of trainable parameters of the full-rank model with N = 30 units (dz = N).
# Same terms as for the rank-3 model:
#   weights + biases + output biases + output weights + Cholesky latent covariance
#   + time constant + observation variance + initial covariance + initial mean
# The only difference is that the recurrent weights contribute N * dz instead of N * dz * 2.
# `n_el` is the triangular-matrix element count defined in the rank-3 parameter cell.
dz = 30
dx = 64
N = 30
n_params = N * dz + N + dx + dz * dx + n_el(dz) + 1 + dx + n_el(dz) + dz
print(n_params)

In [None]:
# Full-rank (N=30) sweep summary: median and median absolute deviation.
# (Previously this cell printed the rank-3 lists data_ph / data_kl by mistake.)

# Hellinger distance
print(np.median(data_ph_FR), mad(data_ph_FR))

# KL divergence
print(np.median(data_kl_FR), mad(data_kl_FR))

In [None]:
# Load and evaluate the full-rank models with 128 units ("Full Rank (N=128)" in
# the figure). `klx_bin` (KL divergence) and `psH` (Hellinger distance) are
# collected per model.
directory = "../data_untracked/additional_eeg/eeg_rank128/"
suffix = "_vae_params.pkl"

if not os.path.isdir(directory):
    raise FileNotFoundError(
        f"Rank-128 EEG models not found in {directory!r}; they are untracked "
        "and must be trained locally or obtained from the authors."
    )

data_kl_FR128 = []
data_ph_FR128 = []
# sorted(): os.listdir order is arbitrary, so sort for a reproducible evaluation/log order
for filename in sorted(os.listdir(directory)):
    if not filename.endswith(suffix):
        continue
    model_name = filename.removesuffix(suffix)
    print(model_name)
    # Only the model is needed. The other returned values are discarded so they
    # do not overwrite globals such as the dataset's `task_params`.
    vae, _, _, _ = load_model(
        os.path.join(directory, model_name), load_encoder=False
    )
    klx_bin, psH, _ = eval_VAE(
        vae,
        task,
        smoothing=20,
        cut_off=2400,
        freq_cut_off=-1,
        smooth_at_eval=True,
    )
    data_kl_FR128.append(klx_bin)
    data_ph_FR128.append(psH)

assert data_kl_FR128, f"no '*{suffix}' model files found in {directory!r}"

In [None]:
# Number of trainable parameters of the full-rank model with N = 128 units (dz = N).
# Same terms as for the rank-3 model:
#   weights + biases + output biases + output weights + Cholesky latent covariance
#   + time constant + observation variance + initial covariance + initial mean
# The only difference is that the recurrent weights contribute N * dz instead of N * dz * 2.
# `n_el` is the triangular-matrix element count defined in the rank-3 parameter cell.
dz = 128
dx = 64
N = 128
n_params = N * dz + N + dx + dz * dx + n_el(dz) + 1 + dx + n_el(dz) + dz
print(n_params)

In [None]:
# Full-rank (N=128) sweep summary: median and median absolute deviation,
# printed Hellinger distance first, then KL divergence.
for metric_values in (data_ph_FR128, data_kl_FR128):
    print(np.median(metric_values), mad(metric_values))

In [None]:
# Box plots of both evaluation metrics, one box per model family:
#   left panel:  KL divergence (y label D_stsp)
#   right panel: Hellinger distance (y label D_H)
# Lower is better in both panels (down arrow in the y labels).
alpha = 0.5  # fill transparency of the boxes

c2 = "darkturquoise"
c1 = "thistle"
c3 = "lightpink"

# One entry per model family: (KL values, Hellinger values, colour, legend label).
# The list index doubles as the box's x-position.
box_groups = [
    (data_kl, data_ph, c2, "Rank 3 (N=512)"),
    (data_kl_FR, data_ph_FR, c1, "Full Rank (N=30)"),
    (data_kl_FR128, data_ph_FR128, c3, "Full Rank (N=128)"),
]


def box_style(color):
    """Return Axes.boxplot keyword arguments that draw one box entirely in `color`.

    The box gets a translucent fill (module-level `alpha`). Caps, whiskers,
    median and fliers all use the same solid colour.
    """
    return dict(
        widths=0.6,
        patch_artist=True,  # draw the box as a Patch so it can take a facecolor
        boxprops=dict(facecolor=cc.to_rgba(color, alpha=alpha), color=color),
        capprops=dict(color=color),
        whiskerprops=dict(color=color),
        medianprops=dict(color=color),
        flierprops={
            "marker": "o",
            "markersize": 1,
            "markerfacecolor": color,
            "markeredgecolor": color,
        },
    )


fig, ax = plt.subplots(1, 2, figsize=(3.3, 1), dpi=300)

for position, (kl_values, ph_values, color, label) in enumerate(box_groups):
    ax[0].boxplot([kl_values], positions=[position], **box_style(color))
    # Only the Hellinger panel carries labels; the legend below is built from it.
    ax[1].boxplot([ph_values], positions=[position], label=label, **box_style(color))

# Axes cosmetics: no x ticks or bottom spine, fixed limits, compact y tick labels
for axis in ax:
    axis.set_xticks([])
    axis.tick_params(axis="x", length=0)
    axis.set_xlim(-1, 5)
    axis.spines[["bottom"]].set_visible(False)
ax[1].set_yticks([0.0, 0.1, 0.2])
ax[1].set_yticklabels(["0", ".1", ".2"])
ax[0].set_ylim(1, 4)
ax[1].set_ylim(0, 0.3)
# Raw strings for the TeX parts: "\d" in a plain string is an invalid escape sequence
ax[0].set_ylabel(r"$D_{stsp}$" "\n" r"$\downarrow$", rotation=0, labelpad=10)
ax[1].set_ylabel(r"$D_{H}$" "\n" r"$\downarrow$", rotation=0, labelpad=5)

# Legend: take handles from the labelled panel explicitly rather than via plt.gca()
handles, labels = ax[1].get_legend_handles_labels()
legend = ax[1].legend(
    handles,
    labels,
    loc=1,
    bbox_to_anchor=(1.9, 1),
    fontsize=6,
    handlelength=0.55,
    handleheight=0.5,
)

# Colour each legend entry's text like its boxes
for text, (_, _, color, _) in zip(legend.get_texts(), box_groups):
    text.set_color(color)

fig.tight_layout(pad=0.3)
fig.savefig("../figures/FR_boxplot.svg", bbox_inches="tight")
fig.savefig("../figures/EEG_FR.png", bbox_inches="tight")