## learning kernels for the GLM example. 

we optimize kernels such that
$ K(x_n, x_0) p(\theta_n) / \tilde{p}(\theta_n) \approx 1$. 

Spoiler: it starts to work.


# approach

The above problem doesn't require MDNs at all. 
Once prior, proposal, kernel and simulator are fixed and we drew an artificial dataset $(x_n, \theta_n)$, we're good to play. 
Let's run SNPE as usual, note down the data-sets $(x_n, \theta_n)$, proposal priors and importance weights it produced over rounds, and afterwards play with the kernel on those fixed targets. 

- Remark: results look a lot worse if we convert to Student's t-distributions. Could it be that the kernel shape (squared-exponential in $x$) has to match the proposal-prior shape (squared in $\theta$ for Student's t with df=3)?


### 1. basic squared loss

$ \operatorname{argmin}_K \sum_n \left( 1 - \frac{K(x_n, x_0) p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2 $, i.e. the $\approx 1$ is enforced in absolute terms, by penalizing the squared deviation of the ratio from 1.


In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

def d_kl_gauss(m1, m2, S1, S2):
    """Return the KL divergence KL( N(m1, S1) || N(m2, S2) ) of two Gaussians.

    Parameters
    ----------
    m1, m2 : 1-D numpy arrays of equal length (the means).
    S1, S2 : square numpy arrays (the covariances); both must be
        positive definite for the result to be meaningful.

    Returns
    -------
    0.5 * ( tr(S2^-1 S1) + dm^T S2^-1 dm - d + log|S2| - log|S1| ),
    with dm = m2 - m1 and d = m1.size.
    """
    dm = m2 - m1

    # Solve against S2 instead of forming its explicit inverse.
    out = np.trace(np.linalg.solve(S2, S1))
    out += dm.dot(np.linalg.solve(S2, dm))
    out -= m1.size
    # slogdet returns (sign, log|det|); only the log-magnitude enters the
    # formula (the sign is +1 for any valid covariance matrix).
    out += np.linalg.slogdet(S2)[1] - np.linalg.slogdet(S1)[1]
    return out / 2.


# Load the stored results for every random seed (duration = 250), fit a
# Gaussian to the reference samples, report how many rounds each variant
# completed, and plot the final vanilla posterior against the x_kl one.
seeds = np.arange(90, 110)
duration = 250

# One Gaussian fit to the reference samples per seed, in `seeds` order.
posteriors_s = []

for seed in seeds:
    true_params, labels_params = utils.obs_params()
    obs = utils.obs_data(true_params, seed=seed, duration = duration)
    obs_stats = utils.obs_stats(true_params, seed=seed, duration = duration)

    # Close the .npz archive once the array has been read out of it.
    with np.load('sam_' + str(duration) + '_' + str(seed) + '.npz') as npz:
        sam = npz['arr_0']

    # The .npy file holds a pickled dict inside a 0-d object array; `[()]`
    # unwraps it. NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True.
    # NOTE: unpickling executes arbitrary code -- only load trusted files.
    res  = np.load('check_kernels_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]

    # Mean over axis 1 and np.cov's default rowvar=True both treat the rows
    # of `sam` as variables and its columns as samples.
    ms = sam.mean(axis=1)
    Ss = np.cov(sam)
    posteriors_s.append(dd.Gaussian(m=ms, S=Ss))

    # run with Gaussian proposals

    print('\n seed #' + str(seed) + '\n')
    posteriors = res['posteriors']
    posteriors_k = res['posteriors_k']
    posteriors_k2 = res['posteriors_k2']

    # Drop a trailing None entry so it is neither counted nor plotted.
    if posteriors[-1] is None:
        posteriors.pop()

    print('vanilla #succesful rounds: ', len(posteriors))
    print('x_kl    #succesful rounds: ', len(posteriors_k))
    print('max_ESS #succesful rounds: ', len(posteriors_k2))

    plot_pdf(posteriors[-1],
             pdf2=posteriors_k[-1], 
             resolution=100,
             lims=[-2,2], 
             samples=sam, 
             gt=true_params, 
             figsize=(9,9));


In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

def d_kl_gauss(m1, m2, S1, S2):
    """Return the KL divergence KL( N(m1, S1) || N(m2, S2) ) of two Gaussians.

    Parameters
    ----------
    m1, m2 : 1-D numpy arrays of equal length (the means).
    S1, S2 : square numpy arrays (the covariances); both must be
        positive definite for the result to be meaningful.

    Returns
    -------
    0.5 * ( tr(S2^-1 S1) + dm^T S2^-1 dm - d + log|S2| - log|S1| ),
    with dm = m2 - m1 and d = m1.size.
    """
    dm = m2 - m1

    # Solve against S2 instead of forming its explicit inverse.
    out = np.trace(np.linalg.solve(S2, S1))
    out += dm.dot(np.linalg.solve(S2, dm))
    out -= m1.size
    # slogdet returns (sign, log|det|); only the log-magnitude enters the
    # formula (the sign is +1 for any valid covariance matrix).
    out += np.linalg.slogdet(S2)[1] - np.linalg.slogdet(S1)[1]
    return out / 2.


# Load the stored results for every random seed (duration = 100), fit a
# Gaussian to the reference samples and report how many rounds each
# variant completed.
seeds = np.arange(90, 110)
duration = 100

# One Gaussian fit to the reference samples per seed, in `seeds` order.
posteriors_s = []

for seed in seeds:
    true_params, labels_params = utils.obs_params()
    obs = utils.obs_data(true_params, seed=seed, duration = duration)
    obs_stats = utils.obs_stats(true_params, seed=seed, duration = duration)

    # Close the .npz archive once the array has been read out of it.
    with np.load('sam_' + str(duration) + '_' + str(seed) + '.npz') as npz:
        sam = npz['arr_0']

    # The .npy file holds a pickled dict inside a 0-d object array; `[()]`
    # unwraps it. NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True.
    # NOTE: unpickling executes arbitrary code -- only load trusted files.
    res  = np.load('check_kernels_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]

    # Mean over axis 1 and np.cov's default rowvar=True both treat the rows
    # of `sam` as variables and its columns as samples.
    ms = sam.mean(axis=1)
    Ss = np.cov(sam)
    posteriors_s.append(dd.Gaussian(m=ms, S=Ss))

    # run with Gaussian proposals

    print('\n seed #' + str(seed) + '\n')
    posteriors = res['posteriors']
    posteriors_k = res['posteriors_k']
    posteriors_k2 = res['posteriors_k2']

    # Drop a trailing None entry so it is not counted as a round.
    if posteriors[-1] is None:
        posteriors.pop()

    print('vanilla #succesful rounds: ', len(posteriors))
    print('x_kl    #succesful rounds: ', len(posteriors_k))
    print('max_ESS #succesful rounds: ', len(posteriors_k2))



In [None]:
import seaborn 

# Per-seed, per-variant summary numbers for the stored runs. Axis 1 of every
# array indexes the three variants in the order
# [res['posteriors'], res['posteriors_k'], res['posteriors_k2']].
# `res` and `obs_stats` are still the values left behind by the loading cell;
# they are only used here for the number of rounds and of summary statistics.
dkls = np.zeros((len(seeds), 3))                              # KL of the last kept posterior
dklsa = np.nan * np.ones((len(seeds), 3, res['n_rounds']))    # KL per round (NaN where not filled)
BAM = np.nan * np.ones_like(dklsa)                            # KL of the entry just before a None

ess = np.zeros((len(seeds), 3, res['n_rounds']))
mu_   = np.zeros((len(seeds), 3, res['n_rounds'], obs_stats.size))
sig2e = np.zeros((len(seeds), 3, res['n_rounds'], obs_stats.size))
sig2_ = np.zeros((len(seeds), 3, res['n_rounds'], obs_stats.size))

for i in range(len(seeds)):

    seed = seeds[i]
    # Gaussian fitted to the reference samples of this seed.
    p1 = posteriors_s[i]

    # The .npy file holds a pickled dict inside a 0-d object array; `[()]`
    # unwraps it. NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True.
    # NOTE: unpickling executes arbitrary code -- only load trusted files.
    res  = np.load('check_kernels_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]
    posteriors_all = [res['posteriors'], res['posteriors_k'], res['posteriors_k2']]
    tds_all = [res['tds'], res['tds_k'], res['tds_k2']]

    for j in range(len(posteriors_all)):
        psts =  posteriors_all[j]
        # The range is fixed before the loop; the pops below shorten `psts`.
        for k in range(len(psts)):
            if psts[k] is None:
                # Record the KL of the entry preceding the None separately.
                p2 = psts[k-1].xs[0]
                BAM[i,j,k-1] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)
                # Remove the last two list entries. NOTE(review): this matches
                # (None, its predecessor) only when the None is the final
                # entry; a None earlier in the list would make later psts[k]
                # lookups run past the shortened list.
                psts.pop()
                psts.pop()
            else:
                # Compare against the first component (`xs[0]`) of the posterior.
                p2 = psts[k].xs[0]
                dklsa[i,j,k] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)

                # Normalise into a new array so the loaded data is not modified
                # in place through the reshape view.
                w = tds_all[j][k][2].reshape(-1,1)
                w = w / w.sum()
                # Effective sample size of the normalised weights.
                ess[i,j,k] = 1/np.sum(w**2)

                stats = tds_all[j][k][1]
                # Unweighted variance, weighted mean and weighted variance of
                # every column of `stats`.
                sig2e[i, j, k, :] = np.var( stats, axis = 0)
                mu_[  i, j, k, :] = np.sum( w * stats,                      axis=0).reshape(1,-1)
                sig2_[i, j, k, :] = np.sum( w * (stats-mu_[i, j, k, :])**2, axis=0)

        # KL of the last entry that is still in the list.
        p2 = psts[-1].xs[0]
        dkls[i,j] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)


# Summary figure: final KL per seed (bars), final KL distribution (boxplot),
# KL per round, ESS per round, and weighted summary-statistic variances for
# four selected rounds. The figure is saved as a PDF and then shown.
plt.figure(figsize=(16,13))

# Top-left: final KL divergence per seed, seeds sorted by descending KL of
# variant 0; one bar per variant, offset by 0.2.
plt.subplot(2,2,1)
idx = np.argsort(dkls[:,0])[::-1]
#idx = np.arange(20)
dkls_s = dkls[idx,:]
plt.bar(np.arange(dkls_s.shape[0]),     dkls_s[:,0], 0.3, color='r')
plt.bar(np.arange(dkls_s.shape[0])+0.2, dkls_s[:,1], 0.3, color='b')
plt.bar(np.arange(dkls_s.shape[0])+0.4, dkls_s[:,2], 0.3, color='g')
plt.xlabel('#rnd seed')
plt.ylabel('final KL divergences')
plt.legend(['raw', 'x_kl', 'mESS'])

# Top row, third panel: distribution of the final KL over seeds per variant.
plt.subplot(2,4,3)
plt.boxplot(dkls, labels=['raw', 'x_kl', 'mESS'])

# Top row, fourth panel: KL per round, one line per seed and variant.
plt.subplot(2,4,4)
# Single points at (-1, -1), outside the axis limits set below; they are
# plotted first so the three legend entries pick up one handle per colour.
plt.plot(-1, -1, 'r', linewidth=1.5)
plt.plot(-1, -1, 'b', linewidth=1.5)
plt.plot(-1, -1, 'g', linewidth=1.5)
plt.plot(np.arange(1,dklsa.shape[2]+1), dklsa[:,0,:].T, 'r', linewidth=1.5)
plt.plot(np.arange(1,dklsa.shape[2]+1), dklsa[:,1,:].T, 'b', linewidth=1.5)
plt.plot(np.arange(1,dklsa.shape[2]+1), dklsa[:,2,:].T, 'g', linewidth=1.5)
# Circles mark the KL values stored in BAM (entry just before a None).
plt.plot(np.arange(1,dklsa.shape[2]+1), BAM[:,0,:].T, 'ro', ms=6)
plt.plot(np.arange(1,dklsa.shape[2]+1), BAM[:,1,:].T, 'bo', ms=6)
plt.plot(np.arange(1,dklsa.shape[2]+1), BAM[:,2,:].T, 'go', ms=6)
plt.ylabel('KL divergence')
plt.xlabel('#round')
plt.legend(['raw', 'x_kl', 'mESS'])
plt.axis([0.99, dklsa.shape[2], 0, 1.1*np.nanmax(dklsa)])

# Bottom-left: mean ESS over seeds with a +/- one std band, from round 2 on
# (index 0 of the round axis is skipped). Variants are drawn in reverse order.
plt.subplot(2,2,3)
clrs = ['r', 'b', 'g']
for j_ in np.arange(3)[::-1]:
    m, s = np.mean(ess[:,j_,1:], axis=0), np.std(ess[:,j_,1:], axis=0)
    #plt.plot(np.arange(2,ess.shape[2]+1), m+s, '--', color=clrs[j_], linewidth=1.5)
    #plt.plot(np.arange(2,ess.shape[2]+1), m-s, '--', color=clrs[j_], linewidth=1.5)
    plt.fill_between(np.arange(2,ess.shape[2]+1),
                     m-s,
                     m+s,
                     color=clrs[j_],
                     alpha=0.5
                    )
    plt.plot(np.arange(2,ess.shape[2]+1), m, clrs[j_], linewidth=2.5)    
plt.xlabel('#round')
# NOTE(review): the "20" is hard-coded in the label; the data uses len(seeds).
plt.ylabel('ESS (+/- std over 20 rnd seeds)')


# Bottom-right 2x2 group of a 4x4 grid: weighted variance of each summary
# statistic (sig2_), averaged over seeds, for round indices 0, 1, 4 and 9.
rs = [0, 1, 4, 9]
sp = [11, 12, 15, 16]
for r_ in range(4):
    
    r = rs[r_]
    plt.subplot(4,4,sp[r_])
    
    for j_ in np.arange(3):
        # `s` is only referenced by the commented-out lines below.
        m, s = sig2_[:,j_,r,:].mean(axis=(0,)), sig2_[:,j_,r,:].std(axis=(0,))
        plt.plot(np.arange(len(m))+1, m, color=clrs[j_], linewidth=1.5)   
        #plt.fill_between(np.arange(len(m)),
        #                 m-s,
        #                 m+s,
        #                 color=clrs[j_],
        #                 alpha=0.5
        #                )    
        #plt.plot(m-s, '--', color=clrs[j_], linewidth=1.5)   
        #plt.plot(m+1, '--', color=clrs[j_], linewidth=1.5)   

    # Axis labels only on the bottom row / left column of the 2x2 group.
    if r_ in [2,3] :
        plt.xlabel('# summary statistic')
    if r_ in [0,2]:
        plt.ylabel('empirical variance')
    plt.title('SNPE round #' + str(r+1))

    # NOTE(review): four labels are given but only three lines are drawn per
    # panel; the unweighted variance (sig2e) is computed but never plotted,
    # so 'unweighted' has no matching line.
    if r_==0:
        plt.legend(['raw iws', 'iws + x_kl', 'iws + max_ess', 'unweighted'], loc=2)
        
# Save before show() so the file is written from the fully drawn figure.
plt.savefig('summary_GLM_' + str(len(seeds)) + 'seeds_dur' + str(duration) + '.pdf')
plt.show()

In [None]:
# Display the currently active `duration` (set by whichever loading cell ran last).
duration

In [None]:
# Overlay, one line per seed, the column-wise mean of the `stats` array stored
# for round index `r` of the vanilla run (res['tds']).
r = 0
for seed in seeds:
    # The .npy file holds a pickled dict inside a 0-d object array; `[()]`
    # unwraps it. NumPy >= 1.16.3 refuses to unpickle unless allow_pickle=True.
    # NOTE: unpickling executes arbitrary code -- only load trusted files.
    res  = np.load('check_kernels_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]
    stats = res['tds'][r][1]

    plt.plot(stats.mean(axis=0))