In [None]:
import re
import sys
import os
import glob

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import pandas as pd
import tensorflow as tf

sys.path.append("../../ndsvae")
import ndsvae as ndsv

sys.path.append("..")
import util
import plotutils as pu


%matplotlib inline

In [None]:
# Identifiers of the dataset variant, the model and the model configuration;
# they are interpolated into the run paths used further down.
conn = "linw"
preproc = "dicer"
modelname = "AB"
config = "ns_3_mreg_3_msub_0_nf_32"

# Pick the run to analyse via the project helper.
# NOTE(review): the meaning of the [0,1] and "hcp" arguments is defined by
# util.select_run_fc and is not visible here - confirm.
dataset_name = f"hcp100_{conn}_{preproc}"
run = util.select_run_fc(dataset_name, modelname, config, [0,1], "hcp")

In [None]:
# Training history of the selected run (provides the `epoch`, `loss` and
# `loss_test` columns plotted below).
hist_path = f"../run/hcp/hcp100_{conn}_{preproc}/model{modelname}/{config}/run{run:02d}/fit/hist.csv"
hist = pd.read_csv(hist_path)

In [None]:
# Display the training history table loaded from hist.csv.
hist

In [None]:
# Loss curves from the training history. Each curve is labelled with the name
# of the hist.csv column it shows, so the two can be told apart in the legend.
plt.figure(figsize=(14,8), dpi=100)

plt.plot(hist.epoch, hist.loss, label="loss")
plt.plot(hist.epoch, hist.loss_test, label="loss_test")
# Fixed y-window; anything outside 1400-1600 is not shown.
plt.ylim(1400, 1600)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()

# Optional vertical marker at a chosen epoch (requires `epoch` to be defined):
#  plt.axvline(epoch, ls='--', color='0.5')
plt.grid()

## ELBO for train and test samples

In [None]:
# Directory of the selected run, and the dataset the run was fitted on.
direc = f"../run/hcp/hcp100_{conn}_{preproc}/model{modelname}/{config}/run{run:02d}"
ds = ndsv.Dataset.from_file(f"../run/hcp/hcp100_{conn}_{preproc}/dataset.npz")

# Train/test flags saved next to the fitted model; indexed below as [subject, region].
train_mask = np.load(f"{direc}/fit/train_mask.npy")

# Build the model for this configuration and restore its weights from the run.
model = util.get_model(modelname, config, ds)
weights_path = os.path.join(direc, "fit/model")
model.load_weights(weights_path)

# Unshuffled batches of ds.nreg elements, built with an all-True
# (nsub, nreg) mask.
# NOTE(review): presumably this yields one batch per subject covering both
# train and test entries - confirm against ndsv.training._prep_training_dataset.
batch_size = ds.nreg
full_mask = np.ones((ds.nsub, ds.nreg), dtype=bool)
dataset = ndsv.training._prep_training_dataset(
    ds, batch_size, model.training_mode, model.upsample_factor,
    mask=full_mask, shuffle=False,
)

In [None]:
# Evaluate the ELBO batch by batch. Batch number k is written to row k of
# `elbos`, whose shape is (nsub, nreg, nsamples).
nsamples = 8
elbos = np.zeros((ds.nsub, ds.nreg, nsamples))
for batch_idx, batch in enumerate(dataset):
    # Progress indicator
    print(batch_idx, end=' ', flush=True)

    # The return value of loss() is discarded; the values of interest are read
    # from the `model.elbo` attribute right after the call.
    model.loss(batch, nsamples=nsamples, betax=1.0, betap=1.0)
    elbos[batch_idx] = model.elbo.numpy()

# Average over the last (sample) axis -> one value per (subject, region).
elbo = elbos.mean(axis=2)

In [None]:
# Long-format table with one row per (subject, region) pair, in row-major order
# (subject index varies slowest), holding the train/test flag and the
# sample-averaged ELBO of that datapoint. Built from flattened arrays instead of
# a Python double loop.
sub_idx, reg_idx = np.indices((ds.nsub, ds.nreg))
df = pd.DataFrame({
    "sub": sub_idx.ravel(),
    "reg": reg_idx.ravel(),
    # Cast to bool so that `~df.train` further down is a logical negation even
    # if the mask file was stored with an integer dtype.
    "train": np.asarray(train_mask)[:ds.nsub, :ds.nreg].astype(bool).ravel(),
    "elbo": np.asarray(elbo)[:ds.nsub, :ds.nreg].ravel(),
})

## Figure

In [None]:
# Font sizes (in points) used for the figure styling below.
SMALL_SIZE = 8
MEDIUM_SIZE = 10
BIGGER_SIZE = 12

# Global matplotlib styling for the figure: Arial as the sans-serif face,
# and one shared size for all text except the figure title.
plt.rcParams.update({
    'font.family': "sans-serif",
    'font.sans-serif': "Arial",
    'font.size': SMALL_SIZE,              # default text size
    'axes.titlesize': SMALL_SIZE,         # axes title
    'axes.labelsize': SMALL_SIZE,         # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,        # x tick labels
    'ytick.labelsize': SMALL_SIZE,        # y tick labels
    'legend.fontsize': SMALL_SIZE,        # legend text
    'figure.titlesize': BIGGER_SIZE,      # figure title
})

In [None]:
# Violin plot of the per-datapoint ELBO, train entries versus test entries.
n_train = int(np.sum(df.train))
n_test = len(df) - n_train

plt.figure(figsize=(3,2), dpi=150)

# Violins are placed along x in the sequence given by `order`. Left to itself,
# seaborn sorts the boolean levels as False, True, which puts the test violin
# at x=0 - underneath the "Train set" tick label set below. Pin the order so
# that position 0 is train and position 1 is test, matching the labels.
sns.violinplot(data=df, y='elbo', x='train', order=[True, False], zorder=10)
plt.grid(axis='y')
plt.ylabel("Datapoint ELBO")
plt.xlabel("")
plt.xticks([0,1], [f"Train set\n(n = {n_train})", f"Test set\n(n = {n_test})"])

# Draw the grid lines behind the violins.
plt.gca().set_axisbelow(True)
# Project helper from plotutils; by its name it keeps only the bottom and left
# spines - TODO confirm.
pu.bottomleft_spines(plt.gca())
plt.tight_layout()

# savefig does not create missing directories, so make sure the target exists.
os.makedirs("img", exist_ok=True)
plt.savefig("img/Fig_HCP-overfitting.pdf")

In [None]:
# Mean datapoint ELBO over the train entries and over the test entries, in that order.
is_train = df.train
np.mean(df.elbo[is_train]), np.mean(df.elbo[~is_train])