In [None]:
from braivest.analysis.wandb_utils import load_wandb_model
from braivest.utils import load_data
from braivest.model.emgVAE import emgVAE
from braivest.preprocess.dataset_utils import bin_data, find_artifacts
from braivest.analysis.plotting_utils import *
from braivest.analysis.hmm_utils import *

# Explicit imports come after the star imports so these bindings win over anything
# the star imports happen to re-export.
import pickle

import numpy as np
import plotly.express as px
import wandb
import tensorflow as tf
import matplotlib.pyplot as plt

from ssm.hmm import MultiHMM, HMM
import ssm
from pyvis.network import Network
import plotly.graph_objects as go
import seaborn as sns
import matplotlib

# Embed TrueType (Type 42) fonts so text in exported PDF/PS figures stays editable.
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

In [None]:
# Fetch the preprocessed analysis set ("analysis_set:v0") from W&B, then load the
# per-session data stored in it.
# NOTE(review): allow_pickle=True presumably reaches np.load, which can execute code
# embedded in a pickle — only use trusted artifacts.
with wandb.init(project="braivest_tutorial", job_type="download") as run:
    artifact_dir = run.use_artifact("analysis_set:v0").download()
subject0_sess = load_data(artifact_dir, 'sess_datas.npy', allow_pickle=True)

In [None]:
# Restore the trained VAE from its W&B run. The second argument (31) presumably
# selects a checkpoint/epoch — confirm against load_wandb_model's definition.
MODEL_RUN_PATH = "juliahwang/lfp_VAE/v2l9tltt"
MODEL_CHECKPOINT = 31
model = load_wandb_model(MODEL_RUN_PATH, MODEL_CHECKPOINT)

In [None]:
# Fetch the raw LFP recordings ("raw_data:v0") from W&B; only the local download
# directory is needed afterwards.
with wandb.init(project="braivest_tutorial", job_type="download") as run:
    raw_artifact_dir = run.use_artifact("raw_data:v0").download()

In [None]:
# For fitting the HMM we want continuous datasets, so each session's encodings are
# split every time an artifact is detected in that session's raw LFP.
# NOTE(review): assumes subject0_sess[i] holds the data for session subject0_sessions[i],
# and that find_artifacts returns indices aligned with the encoding bins — confirm.
subject0_sessions = [0,2,4,5,6,7,8,9,10,11,12]

encodings_all = []

# Iterate over every session (the first one was previously skipped) and load *that*
# session's LFP (previously session 0's LFP was loaded for every session, so its
# artifact positions were applied to all other sessions' encodings).
for sess_idx, sess_id in enumerate(subject0_sessions):
    lfp = load_data(raw_artifact_dir, "lfp_session{}.npy".format(sess_id), allow_pickle=True)
    artifacts = find_artifacts(lfp)
    encodings_full = model.encode(subject0_sess[sess_idx])
    encodings_split = np.split(encodings_full, artifacts)
    # np.split starts every piece after the first at one of the split indices
    # (presumably the artifact bin); drop each piece's first sample and discard
    # pieces that would be left empty.
    encodings_split_mod = [split[1:] for split in encodings_split if len(split) > 1]
    encodings_all.extend(encodings_split_mod)


In [None]:

# Alternatively, load pre-computed encodings that are already split by artifact.
# This overwrites any encodings_all built by the previous cell.
# NOTE(review): allow_pickle=True lets np.load unpickle objects, which can execute
# code — only use trusted files.
PRECOMPUTED_ENCODINGS_PATH = "subject0_visual11_encodings.npy"
encodings_all = np.load(PRECOMPUTED_ENCODINGS_PATH, allow_pickle=True)

In [None]:
# Pick the number of HMM states by cross-validation: for each candidate count from
# 2 to 14, record the mean and standard deviation of hmm_cross_val's test scores.
scores, scores_std = [], []
for n_clusters in range(2, 15):
    hmm, train_scores, test_scores = hmm_cross_val(n_clusters, encodings_all, n_repeats=3)
    scores += [np.mean(test_scores)]
    scores_std += [np.std(test_scores)]

In [None]:
# Negated mean test score per state count, with the test-score standard deviation
# as error bars.
component_counts = range(2, 15)
neg_scores = np.asarray(scores) * -1
ax = plt.gca()
ax.errorbar(component_counts, neg_scores, yerr=scores_std)
ax.set_xlabel("Number of Components")
ax.set_ylabel("Negative Log Likelihood")

In [None]:
# Relative change in negative log likelihood when going from k-1 to k states,
# plotted at k. The previous version plotted a fraction under a "Percent change"
# label; scale by 100 so the values match the axis label.
nll = -np.asarray(scores)
nll_pct_change = 100 * np.diff(nll) / nll[:-1]
plt.plot(range(3, 15), nll_pct_change)
plt.xlabel("Number of Components")
plt.ylabel("Percent change")

In [None]:
# Fit an HMM with 8 hidden states over 2-dimensional observations (ssm's HMM(K, D))
# using EM with k-means initialisation.
# Seed numpy's global RNG first so the stochastic initialisation is reproducible.
# NOTE(review): assumes ssm's parameter init and its k-means init draw from numpy's
# global RNG — confirm for the installed ssm/sklearn versions.
np.random.seed(0)
hmm = HMM(8, 2)
hmm_lls = hmm.fit(encodings_all, method="em", num_iters=50, init_method="kmeans")

In [None]:
# Load the pre-fitted HMM for consistency with the tutorial results; this replaces
# any HMM fitted in the previous cell. The `with` block closes the file handle,
# which the previous bare open() call leaked.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('subject0_hmm.p', "rb") as hmm_file:
    hmm = pickle.load(hmm_file)


In [None]:
# Infer HMM state labels for the first 100 continuous segments, then plot those
# segments' encodings coloured by state index. get_hmm_labels is used here as
# returning one label array per segment (they are concatenated below).
preview_segments = encodings_all[:100]
sess_labels = get_hmm_labels(hmm, preview_segments)
fig = plot_encodings(
    np.concatenate(preview_segments, axis=0),
    color=np.concatenate(sess_labels),
    x_range=(-6, 3),
)


In [None]:
# Display the encoding plot built in the previous cell (coloured by HMM state index).
fig.show()

In [None]:
# Hand-assigned names for the 8 HMM state indices, and a display colour per name.
legend = {
    0: 'SWS1',
    1: 'AE',
    2: 'REM',
    3: 'T1',
    4: 'SWS2',
    5: 'SWS3',
    6: 'Wake',
    7: 'T2',
}
color_map = {
    'REM': '#2986cc',
    'Wake': '#e67f38',
    'AE': '#f44336',
    'T1': '#d3f758',
    'SWS1': '#93c432',
    'SWS2': '#789837',
    'SWS3': '#38761d',
    'T2': '#ecf132',
}
# Re-plot the same preview segments, coloured by state name instead of index.
state_names = [legend[state_idx] for state_idx in np.concatenate(sess_labels)]
fig = plot_encodings(
    np.concatenate(encodings_all[:100], axis=0),
    color=state_names,
    color_map=color_map,
    x_range=(-6, 3),
)
fig.show()

In [None]:
# Label every segment with the HMM and flatten the result into one label sequence,
# used for the state-duration plot here and the transition graph below.
# NOTE(review): this rebinds `sess_labels` from a per-segment list to a flat array;
# earlier cells that call np.concatenate(sess_labels) will likely fail if re-run
# after this one.
per_segment_labels = get_hmm_labels(hmm, encodings_all)
sess_labels = np.concatenate(per_segment_labels)
colors = [color_map[legend[state_idx]] for state_idx in range(8)]
# Duration plot for state 0 only, drawn in that state's colour.
inferred_durations, fig = plot_state_duration(sess_labels, 0, color=colors[0])
fig.show()

In [None]:
# Draw the transition structure over the 8 HMM states from the learned transition
# matrix, the flat label sequence and the per-state colours; "transition_graph.html"
# is the output path handed to plot_transition_graph.
plot_transition_graph(
    8,
    hmm.transitions.transition_matrix,
    sess_labels,
    colors,
    "transition_graph.html",
)
