In [None]:
import delfi.distribution as dd
import matplotlib as mpl
import numpy as np
import os
import pandas as pd
import pickle
import time

from delfi.generator import Default
from delfi.utils.viz import plot_pdf

from lfimodels.channelomics.ChannelSingle import ChannelSingle
from lfimodels.channelomics.ChannelSuper import ChannelSuper
from lfimodels.channelomics.ChannelStats import ChannelStats
from matplotlib import pyplot as plt

import sys 
sys.path.append('../../')
from model_comparison.utils import *
from model_comparison.mdns import ClassificationSingleLayerMDN, Trainer, MultivariateMogMDN, PytorchMultivariateMoG

%matplotlib inline

In [None]:
# set ground truth parameters for each channel model
GT = {'k': np.array([9, 25, 0.02, 0.002]),
      'na': np.array([-35, 9, 0.182, 0.124, -50, -75, 5, -65, 6.2, 0.0091, 0.024])}

# parameter labels for each channel model (presumably aligned index-by-index with GT)
LP = {'k': ['qa','tha','Ra','Rb'],
      'na': ['tha','qa','Ra','Rb','thi1','thi2','qi','thinf','qinf','Rg','Rd']}

# per-channel constants; E_channel presumably holds reversal potentials (mV) and
# fact_inward a sign for inward (-1) vs outward (+1) current -- TODO confirm in lfimodels
E_channel = {'k': -86.7, 'na': 50}
fact_inward = {'k': 1, 'na': -1}

# every ground-truth vector needs exactly one label per parameter
assert GT.keys() == LP.keys()
assert all(len(GT[c]) == len(LP[c]) for c in GT)

## Set "k" as the underlying ground-truth model and generate observed data

In [None]:
# choose which channel model generates the "observed" data, plus simulator flags
channel_type = 'k'
cython = True
third_exp_model = True

gt = GT[channel_type]
labels_params = LP[channel_type]
n_params = len(gt)

# prior box spanning 50%..150% of every ground-truth value; each row is sorted
# so column 0 is always the lower bound (this matters for negative values)
prior_lims = np.sort(np.column_stack((0.5 * gt, 1.5 * gt)), axis=1)

In [None]:
# simulator, uniform prior over the box from the previous cell, and the
# summary-statistics extractor for the chosen channel type.
# NOTE: `m` is rebound to the model-index targets in the training-data cell.
m = ChannelSuper(channel_type=channel_type,
                 third_exp_model=third_exp_model,
                 cython=cython)
p = dd.Uniform(lower=prior_lims[:, 0], upper=prior_lims[:, 1])
s = ChannelStats(channel_type=channel_type)

In [None]:
# Simulate the "observed" data: run a single-model simulator once at the
# ground-truth parameters (passed as a one-row batch) and reduce the first
# simulated output to summary statistics.
n_params_obs = len(gt)
m_obs = ChannelSingle(
    channel_type=channel_type,
    n_params=n_params_obs,
    cython=cython,
)
xo = m_obs.gen(gt[np.newaxis, :])
xo_stats = s.calc(xo[0])

## Load training data 

In [None]:
# Pre-simulated training data for both channel models (the file name suggests
# N=10000 samples generated with seed 1).
folder = '../data/'
filename = 'training_data_k_na_N10000seed1.p'
fullpath = os.path.join(folder, filename)

# NOTE: unpickling can execute arbitrary code -- only load trusted, locally generated files.
with open(fullpath, mode='rb') as handle:
    result_dict = pickle.load(handle)

In [None]:
# Unpack the stored training data.
# NOTE(review): this relies on the insertion order of the pickled dict's keys;
# it also rebinds `cython` to the value stored with the data.
params_k, sx_k, gt_k, prior_lims_k, params_na, sx_na, gt_na, prior_lims_na, seed, n_samples, cython = result_dict.values()

ntrain, n_stats = sx_k.shape
assert sx_na.shape[1] == n_stats, 'both models must share the same summary statistics'

# stack the summary statistics of both models: all K rows first, then all Na rows
sx = np.vstack((sx_k, sx_na))

# model index target, one label per row of sx: 0 = K, 1 = Na
# (rebinds `m`, which previously held the ChannelSuper simulator)
m = np.concatenate((np.zeros(len(sx_k)), np.ones(len(sx_na)))).astype(int)
assert len(m) == len(sx)

# shuffle ALL rows of both models jointly; permuting only the first `ntrain`
# rows would keep just the K block and silently drop every Na sample
shuffle_indices = np.arange(len(sx))
np.random.shuffle(shuffle_indices)
sx = sx[shuffle_indices]
m = m[shuffle_indices].tolist()

# normalize; `training_norm` is reused to normalize the observed data identically
sx, training_norm = normalize(sx)

## Set up the model-classification network and train it

In [None]:
# classifier network: summary statistics -> channel model label (0 = K, 1 = Na)
model_models = ClassificationSingleLayerMDN(ndim_input=n_stats, n_hidden=10)
optimizer = torch.optim.Adam(model_models.parameters(), lr=0.01)
trainer = Trainer(model_models, optimizer, classification=True, verbose=True)

n_epochs = 100
n_minibatch = ntrain // 100  # scales with the number of K training samples

# fit on the shuffled, normalized training set
loss_trace = trainer.train(sx, m, n_epochs=n_epochs, n_minibatch=n_minibatch)

In [None]:
# first 100 entries of the classifier's training loss
fig, ax = plt.subplots(figsize=(18, 3))
ax.plot(loss_trace[:100])
ax.set_ylabel('loss')
ax.set_xlabel('iterations');

## Predict underlying model given observed data

In [None]:
# Predict which channel model generated the observed data.
# The observation must be normalized with the classifier's training statistics
# (`training_norm` from the training-data cell). The posterior section below
# rebinds `training_norm` to the statistics of `sx_k` only, so re-run the
# training-data cell first if this cell is re-executed out of order.
sx_obs, training_norm = normalize(xo_stats.squeeze(), training_norm)

obs_tensor = torch.Tensor(sx_obs)
out_act = model_models(Variable(obs_tensor))
softmax = nn.Softmax(dim=0)

# class probabilities; index 0 corresponds to label 0 (the K model)
p_vec = softmax(out_act).data.numpy()
p_k = p_vec[0]
print('P(K | sx) = {:.2f}'.format(p_k))

## Given the predicted underlying model, learn the posterior over its parameters

In [None]:
# Posterior network for the K model: maps summary statistics to a mixture of
# 3 Gaussians (MoG) over the K parameters.
n_params_k = params_k.shape[1]
model_params = MultivariateMogMDN(
    ndim_input=n_stats,
    ndim_output=n_params_k,
    n_hidden=10,
    n_components=3,
)
optimizer = torch.optim.Adam(model_params.parameters(), lr=0.01)
trainer = Trainer(model_params, optimizer, verbose=True)

In [None]:
# Normalize the K-model training statistics on their own. This rebinds
# `training_norm`, which is then used for the observation further below.
(sx_k_normed, training_norm) = normalize(sx_k)

In [None]:
# train the posterior network on (normalized K stats, K parameters) pairs
loss_trace = trainer.train(
    sx_k_normed,
    params_k,
    n_epochs=100,
    n_minibatch=ntrain // 200,
)

In [None]:
# first 100 entries of the posterior network's training loss
fig, ax = plt.subplots(figsize=(18, 3))
ax.plot(loss_trace[:100])
ax.set_ylabel('loss')
ax.set_xlabel('iterations');

In [None]:
# normalize the observation with the K-only statistics from the normalization
# cell above and turn it into a single-row batch for the posterior network
sx_obs, training_norm = normalize(xo_stats.squeeze(), training_norm)
obs_row = sx_obs.reshape(1, -1)
sx_obs_pt = Variable(torch.Tensor(obs_row))

In [None]:
# evaluate the posterior network at the observation and wrap its outputs
# (unpacked positionally) in a multivariate mixture-of-Gaussians object
mog_outputs = model_params(sx_obs_pt)
post1 = PytorchMultivariateMoG(*mog_outputs)

In [None]:
# Mean of the posterior component with the largest mixture weight.
# np.argmax flattens `alphas`; presumably alphas is (1, n_components) for this
# single observation -- TODO confirm.
alphas = post1.alphas.data.numpy()
best_component = np.argmax(alphas)
mus = post1.mus[0, :, best_component].data.numpy()

In [None]:
# show the mixture weights of the posterior MoG as the cell output
post1.alphas

In [None]:
# ground-truth parameter values, for comparison with the posterior mean below
print(list(gt))

In [None]:
# mean of the most probable posterior component
print(list(mus))