In [None]:
import wandb

from braivest.utils import load_wandb_model
from braivest.model.emgVAE import emgVAE
import plotly.express as px
from braivest.preprocess.dataset_utils import load_data
from braivest.analysis.plotting_utils import *
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt
import matplotlib

# Use font type 42 (TrueType) for PDF and PostScript exports, typically so that
# text in saved figures stays editable in vector-graphics tools.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
import pickle
from sklearn.mixture import GaussianMixture
from sklearn import metrics
import seaborn as sns

In [None]:
artifact_dir = "../../train_scripts/artifacts/probe12_subject4_test2:v0"

# NOTE(review): allow_pickle=True is presumably forwarded to np.load, which can
# execute arbitrary code from a crafted file -- only point this at trusted artifacts.

# Train / test splits that the model encodes below.
train, test = (
    load_data(artifact_dir, split_file, allow_pickle=True)
    for split_file in ('train.npy', 'test.npy')
)

# Hypnogram used to colour and score the test encodings; only its first stored
# element is kept.
val_hypno = load_data(artifact_dir, 'hypno.npy', allow_pickle=True)[0]


In [None]:
# Fetch the trained VAE run from Weights & Biases, then encode the test split
# followed by the train split.
model = load_wandb_model("juliahwang/lfp_VAE/bhyp8z3h")
encodings, train_encodings = model.encode(test), model.encode(train)

In [None]:
# First plot: test encodings coloured by the scored hypnogram stage.

hypno_unique = np.unique(val_hypno)
stage_names = ['REM', 'SWS', 'Wake', 'X']
# np.unique returns the codes sorted, so this mapping assumes the sorted codes are,
# in order, REM, SWS, Wake, X -- TODO confirm against the hypnogram encoding.
# Fail with a clear message instead of an IndexError/KeyError when the number of
# distinct codes does not match the number of stage names.
if len(hypno_unique) != len(stage_names):
    raise ValueError(
        f"Expected {len(stage_names)} distinct hypnogram codes, "
        f"found {len(hypno_unique)}: {hypno_unique}"
    )
legend = dict(zip(hypno_unique, stage_names))
color_map = {'REM': "#0000ff", "Wake": "#ff0000", "SWS": "#00ff00", 'X': 'purple'}

# The confusion-matrix cell trims GMM labels to len(val_hypno), so the hypnogram
# may be shorter than the encodings. Trim the encodings the same way so the data
# and the colour list have equal length (plot_encodings presumably hands both to
# plotly, which rejects mismatched lengths). This is a no-op when lengths agree.
n_scored = len(val_hypno)
fig = plot_encodings(
    encodings=encodings[:n_scored],
    color=[legend[code] for code in val_hypno],
    color_map=color_map,
    x_range=(-2, 2),
)
fig.show()

In [None]:

# Fit a 3-component Gaussian mixture on the training encodings, then assign each
# test encoding to a component.
# NOTE: despite the variable name, covariance_type='full' (not 'diag'). The name is
# kept because later cells may refer to it.
# means_init sets each component's starting mean (and so, typically, which index
# ends up on which region). The points are 2-D, so this assumes a 2-D latent space.
# random_state fixes the initialisation randomness so that the fit, and the
# component indices used downstream, are reproducible across kernel restarts.
gmm_diag = GaussianMixture(
    n_components=3,
    covariance_type='full',
    means_init=[[-3, -1], [1, -1], [-1, 1.5]],
    reg_covar=1e-3,
    n_init=20,
    random_state=0,
)
gmm_diag.fit(train_encodings)
labels = gmm_diag.predict(encodings)


In [None]:
# Same test encodings, now coloured by the GMM component each point was assigned to.
# NOTE(review): `labels` are integer component indices. If plot_encodings passes them
# straight to plotly, they will be drawn on a continuous colour scale -- confirm.
fig = plot_encodings(
    encodings,
    x_range=(-6, 3),
    color=labels,
)
fig.show()

In [None]:
# Plot the row-normalised confusion matrix: scored stage (rows) vs GMM component (columns).
# The +3 shift is meant to move the REM/SWS/Wake hypnogram codes onto 0/1/2 so they
# line up with the GMM component indices -- TODO confirm against the hypnogram encoding.
val_hypno_temp = val_hypno + 3
scored_stages = ['REM', 'SWS', 'Wake']

# Request labels 0..2 explicitly rather than slicing [:3, :3] out of the full matrix.
# Positional slicing mislabels rows/columns if one of these labels is absent, or if
# another code sorts below 0 after the shift. Samples whose label falls outside 0..2
# (e.g. the 'X' stage) are excluded, exactly as the slice excluded them.
confusion = metrics.confusion_matrix(
    val_hypno_temp,
    labels[:len(val_hypno_temp)],  # hypnogram may be shorter than the encodings
    labels=list(range(len(scored_stages))),
)
# Normalise each row so cells show the fraction of a scored stage per component.
# A row with no scored samples would divide by zero and render as an empty (NaN) cell.
sums = np.sum(confusion, axis=1)[:, np.newaxis]

cm_fig, cm_ax = plt.subplots()
sns.heatmap(
    confusion / sums,
    annot=True,
    xticklabels=scored_stages,
    yticklabels=scored_stages,
    cmap='Blues',
    ax=cm_ax,
)
cm_ax.set(xlabel='GMM component', ylabel='Scored stage')
cm_ax.tick_params(axis='x', which='both', bottom=False, top=False)
cm_ax.tick_params(axis='y', which='both', left=False)