## learning kernels for the GLM example. 

We optimize the kernel $K$ such that the kernel-weighted importance weights are close to one:
$ K(x_n, x_0) p(\theta_n) / \tilde{p}(\theta_n) \approx 1$. 

Spoiler: it starts to work.


# approach

The above problem doesn't require MDNs at all. 
Once the prior, proposal, kernel and simulator are fixed and we have drawn an artificial dataset $(x_n, \theta_n)$, we have everything we need to experiment with the kernel. 
Let's run SNPE as usual, note down the datasets $(x_n, \theta_n)$, proposal priors and importance weights it produced over the rounds, and afterwards experiment with the kernel on those fixed targets. 

- Remark: results look a lot worse if we convert to Student's t-distributions. Could it be that the kernel shape (squared-exponential in $x$) has to match the proposal-prior shape (squared in $\theta$ for Student's t with df=3)?


### 1. basic squared loss

argmin $ \sum_n \left( 1 - \frac{K(x_n, x_0) p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2 $, i.e. a loss that penalizes the absolute deviation of the kernel-weighted importance weights from $1$. 


In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

# Experiment driver: for every seed, fit SNPE three times on the GLM example
# (kernel_loss=None, 'x_kl' and 'basic') and store everything in one file.
seeds = np.arange(90, 110)
duration = 100

rerun = True  # if False, will try loading the cached MCMC samples from disk

# Training settings, shared by all three SNPE runs of every seed.
n_train = 5000
n_rounds = 10
minibatch = 100
epochs = 500
round_cl = 999

# Inference-object settings, shared by all three SNPE runs of every seed.
n_hiddens = [50]
convert_to_T = None
pilot_samples = 0
svi = True
reg_lambda = 0.01
prior_norm = False

for seed in seeds:
    true_params, labels_params = utils.obs_params()
    obs = utils.obs_data(true_params, seed=seed, duration=duration)
    obs_stats = utils.obs_stats(true_params, seed=seed, duration=duration)

    # Reference samples from utils.pg_mcmc, cached on disk per (duration, seed).
    sam_file = 'sam_' + str(duration) + '_' + str(seed) + '.npz'
    sam = None
    if not rerun:
        try:
            sam = np.load(sam_file)['arr_0']
        except Exception:
            # Best effort: any problem with the cache (missing or unreadable
            # file) falls back to recomputing. Unlike a bare `except`, this
            # lets KeyboardInterrupt / SystemExit propagate.
            sam = None
    if sam is None:
        sam = utils.pg_mcmc(true_params, obs)
        np.savez(sam_file, sam)

    # --- Run 1: SNPE with kernel_loss=None ---------------------------------
    # Model, prior, summary statistics and generator are rebuilt before each
    # run (presumably so that every run starts from identically seeded
    # objects -- confirm against delfi).
    m = GLM(seed=seed, duration=duration)
    p = utils.smoothing_prior(n_params=m.n_params, seed=seed)
    s = GLMStats(n_summary=m.n_params)
    g = dg.Default(model=m, prior=p, summary=s)

    res = infer.SNPE(g,
                     obs=obs_stats,
                     n_hiddens=n_hiddens,
                     seed=seed,
                     convert_to_T=convert_to_T,
                     pilot_samples=pilot_samples,
                     svi=svi,
                     reg_lambda=reg_lambda,
                     prior_norm=prior_norm)

    logs, tds, posteriors = res.run(n_train=n_train,
                                    n_rounds=n_rounds,
                                    minibatch=minibatch,
                                    epochs=epochs,
                                    round_cl=round_cl,
                                    kernel_loss=None)

    # --- Run 2: SNPE with kernel_loss='x_kl' -------------------------------
    m = GLM(seed=seed, duration=duration)
    p = utils.smoothing_prior(n_params=m.n_params, seed=seed)
    s = GLMStats(n_summary=m.n_params)
    g = dg.Default(model=m, prior=p, summary=s)
    res_k = infer.SNPE(g,
                       obs=obs_stats,
                       n_hiddens=n_hiddens,
                       seed=seed,
                       convert_to_T=convert_to_T,
                       pilot_samples=pilot_samples,
                       svi=svi,
                       reg_lambda=reg_lambda,
                       prior_norm=prior_norm)

    logs_k, tds_k, posteriors_k = res_k.run(n_train=n_train,
                                            n_rounds=n_rounds,
                                            minibatch=minibatch,
                                            epochs=epochs,
                                            round_cl=round_cl,
                                            kernel_loss='x_kl')

    # --- Run 3: SNPE with kernel_loss='basic' ------------------------------
    m = GLM(seed=seed, duration=duration)
    p = utils.smoothing_prior(n_params=m.n_params, seed=seed)
    s = GLMStats(n_summary=m.n_params)
    g = dg.Default(model=m, prior=p, summary=s)
    res_k2 = infer.SNPE(g,
                        obs=obs_stats,
                        n_hiddens=n_hiddens,
                        seed=seed,
                        convert_to_T=convert_to_T,
                        pilot_samples=pilot_samples,
                        svi=svi,
                        reg_lambda=reg_lambda,
                        prior_norm=prior_norm)

    logs_k2, tds_k2, posteriors_k2 = res_k2.run(n_train=n_train,
                                                n_rounds=n_rounds,
                                                minibatch=minibatch,
                                                epochs=epochs,
                                                round_cl=round_cl,
                                                kernel_loss='basic')

    # Persist settings and results of all three runs for this seed.
    # The settings are taken from the variables actually passed to SNPE above
    # (previously 'epochs' was mistakenly stored as `minibatch`).
    # NOTE: np.save stores a dict as a pickled 0-d object array; read it back
    # with np.load(..., allow_pickle=True).item().
    np.save('check_kernels_d' + str(duration) + '_' + str(seed),
            {'seed': seed,
             'duration': duration,
             'n_train': n_train,
             'n_rounds': n_rounds,
             'minibatch': minibatch,
             'epochs': epochs,

             'n_hiddens': n_hiddens,
             'convert_to_T': convert_to_T,
             'pilot_samples': pilot_samples,
             'svi': svi,
             'reg_lambda': reg_lambda,
             'prior_norm': prior_norm,
             'round_cl': round_cl,

             'obs_stats': obs_stats,
             'true_params': true_params,

             'logs': logs,
             'logs_k': logs_k,
             'logs_k2': logs_k2,
             'tds': tds,
             'tds_k': tds_k,
             'tds_k2': tds_k2,
             'posteriors': posteriors,
             'posteriors_k': posteriors_k,
             'posteriors_k2': posteriors_k2,
             })

In [None]:
# Run with Gaussian proposals: for every round, overlay the posterior of the
# kernel_loss=None run with that of the kernel_loss='x_kl' run, together with
# the MCMC samples and the ground-truth parameters.
n_plot_rounds = len(tds_k)
for round_idx in range(n_plot_rounds):
    post_plain = posteriors[round_idx]
    post_kernel = posteriors_k[round_idx]
    plot_pdf(post_plain,
             pdf2=post_kernel,
             lims=[-2, 2],
             samples=sam,
             gt=true_params,
             figsize=(9, 9))

In [None]:
# Run with Gaussian proposals: for every round, overlay the posterior of the
# kernel_loss=None run with that of the kernel_loss='basic' run, together with
# the MCMC samples and the ground-truth parameters.
n_plot_rounds = len(tds_k)
for round_idx in range(n_plot_rounds):
    post_plain = posteriors[round_idx]
    post_kernel = posteriors_k2[round_idx]
    plot_pdf(post_plain,
             pdf2=post_kernel,
             lims=[-2, 2],
             samples=sam,
             gt=true_params,
             figsize=(9, 9))