In [None]:
import numpy as np
import torch
import sys, os

sys.path.append("../")

from vi_rnn.saving import load_model
from vi_rnn.evaluation import eval_VAE
from vi_rnn.datasets import Basic_dataset

from scipy.stats import median_abs_deviation as mad

In [None]:
# Build the evaluation task from the smoothed EEG recording, cast to float32.
# NOTE(review): the same array is passed for both data arguments of Basic_dataset —
# presumably training and evaluation data; confirm against Basic_dataset's signature.
task_params = {"name": "EEG", "dur": 50, "n_trials": 500}
eeg_path = "../data/eeg/EEG_data_smoothed.npy"
eval_data = np.float32(np.load(eeg_path))
task = Basic_dataset(task_params, eval_data, eval_data)

In [None]:
# Load every trained model in the sweep directory and evaluate it on `task`.
# A model is identified by a file ending in "_vae_params.pkl"; the prefix is the
# model name handed to load_model.
# NOTE(review): load_model presumably unpickles these files — only point this at
# trusted model directories.
directory = "../models/sweep_eeg/"
param_suffix = "_vae_params.pkl"

model_names = []  # model name for each entry of data_kl / data_ph
data_kl = []  # per-model KL divergence (klx_bin returned by eval_VAE)
data_ph = []  # per-model Hellinger distance (psH returned by eval_VAE)

# os.listdir returns entries in arbitrary order; sort so models are evaluated,
# printed and stored in the same order on every run and machine.
for filename in sorted(os.listdir(directory)):
    if not filename.endswith(param_suffix):
        continue
    model_name = filename.removesuffix(param_suffix)
    print(model_name)
    # Bind the model's task params to a separate name so the dataset-level
    # `task_params` from the previous cell is not overwritten by the last model's.
    vae, params, model_task_params, training_params = load_model(
        directory + model_name, load_encoder=False
    )
    klx_bin, psH, _ = eval_VAE(
        vae,
        task,
        smoothing=20,
        cut_off=2400,
        freq_cut_off=-1,
        sim_obs_noise=1,
        sim_latent_noise=True,
        smooth_at_eval=True,
    )
    model_names.append(model_name)
    data_kl.append(klx_bin)
    data_ph.append(psH)

In [None]:
# Sanity check: the sweep should contain exactly 20 models, each with both metrics.
n_expected_models = 20
assert len(data_kl) == len(data_ph) == n_expected_models, (
    f"expected {n_expected_models} models, got "
    f"{len(data_kl)} KL / {len(data_ph)} Hellinger results"
)

In [None]:
# Hellinger distance across models: median, then median absolute deviation
# (scipy's default scale=1.0, i.e. the unscaled MAD).
hellinger_summary = (np.median(data_ph), mad(data_ph))
print(*hellinger_summary)

In [None]:
# KL divergence across models: median, then median absolute deviation
# (scipy's default scale=1.0, i.e. the unscaled MAD).
kl_median = np.median(data_kl)
kl_mad = mad(data_kl)
print(kl_median, kl_mad)

In [None]:
# Print number of parameters
# Weights + Biases + Out biases + Out weights + Cholesky latent covariance + Time constant
# + Observation variance + Initial covariance + Initial mean


def n_el(n):
    """Return the number of elements in a triangular n x n matrix (diagonal included).

    This is the n-th triangular number n * (n + 1) / 2. It is computed with
    integer floor division so the result stays exact for large integer n,
    where float division (`/`) would round.

    Args:
        n: matrix side length (non-negative integer).

    Returns:
        int: number of entries on and below (equivalently, above) the diagonal.
    """
    # n * (n + 1) is always even, so floor division by 2 is exact.
    return int(n * (n + 1) // 2)


dz = 3  # sizes the latent covariance and initial mean terms below
dx = 64  # sizes the output-bias, output-weight and observation-variance terms
N = 512  # sizes the weight and bias terms (presumably the number of hidden units)

# One entry per term of the breakdown comment above, in the same order.
param_counts = {
    "weights": N * dz * 2,
    "biases": N,
    "out_biases": dx,
    "out_weights": dz * dx,
    "latent_cov_cholesky": n_el(dz),
    "time_constant": 1,
    "obs_variance": dx,
    "init_cov": n_el(dz),
    "init_mean": dz,
}
n_params = sum(param_counts.values())
print(n_params)