## Learning kernels for the GLM example

We optimize kernels such that
$K(x_n, x_0)\, p(\theta_n) / \tilde{p}(\theta_n) \approx 1$.

Spoiler:
it starts to work.


# approach

The above problem doesn't require MDNs at all. 
Once prior, proposal, kernel and simulator are fixed and we have drawn an artificial dataset $(x_n, \theta_n)$, we are ready to experiment. 
We run SNPE as usual, record the datasets $(x_n, \theta_n)$, proposal priors and importance weights it produces over rounds, and afterwards fit the kernel on those fixed targets. 

- Remark: results look a lot worse if we switch to Student-t distributions. Could it be that the kernel shape (squared-exponential in $x$) has to match the proposal-prior shape (squared in $\theta$ for a Student-t with df=3)?


### 1. Basic squared loss

$\arg\min_K \sum_n \left( 1 - \frac{K(x_n, x_0)\, p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2$, which asks for the ratio to match the target value 1 in absolute terms, not just up to a constant factor. 


In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

def d_kl_gauss(m1, m2, S1, S2):
    """Return KL( N(m1, S1) || N(m2, S2) ) for two multivariate Gaussians.

    Closed form:
        0.5 * [ tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1) - k
                + log|S2| - log|S1| ],
    with k the dimensionality (``m1.size``).

    Parameters
    ----------
    m1, m2 : array_like, shape (k,)
        Means of the first and second Gaussian.
    S1, S2 : array_like, shape (k, k)
        Covariance matrices of the first and second Gaussian.

    Returns
    -------
    float
        The KL divergence. Returns ``nan`` if either covariance has a
        non-positive determinant, i.e. is not a valid covariance.
    """
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)
    S1 = np.asarray(S1, dtype=float)
    S2 = np.asarray(S2, dtype=float)

    sign1, logdet1 = np.linalg.slogdet(S1)
    sign2, logdet2 = np.linalg.slogdet(S2)
    if sign1 <= 0 or sign2 <= 0:
        # Multiplying sign and log|det| together (np.prod(slogdet(...)))
        # turned an invalid covariance into a wrong finite number. NaN is
        # instead skipped by the nan-aware reductions used downstream.
        return np.nan

    dm = m2 - m1

    # solve() rather than an explicit inverse: same result, better conditioned.
    out = np.trace(np.linalg.solve(S2, S1))
    out += dm.dot(np.linalg.solve(S2, dm))
    out -= m1.size
    out += logdet2 - logdet1
    return out / 2.


seeds = np.arange(90, 110)
duration = 300
batch_specifier = '_cp'
fit_specifier = '_5dim' #'_5dim'
path = '' #'a_05/' #'glm_runs_badprior/'

# Gaussian fits (mean, covariance) to the reference posterior samples of each
# seed; the analysis cells compare the SNPE posteriors against these.
posteriors_s = []

for seed in seeds:

    # The results file stores a pickled dict inside a 0-d object array, so
    # numpy >= 1.16.3 needs allow_pickle=True to load it. Pickle can execute
    # code: only load result files you generated yourself.
    res  = np.load(path+'check_higherOrder_kernels' + batch_specifier + fit_specifier + '_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]

    true_params = res['true_params']
    obs_stats   = res['obs_stats']

    # Reference samples. sam.mean(axis=1) and np.cov(sam) both treat rows as
    # variables, i.e. sam is laid out as (n_params, n_samples).
    sam = np.load(path+'sam_' + str(duration) + '_' + str(seed) + batch_specifier + '.npz')['arr_0']
    ms = sam.mean(axis=1)
    Ss = np.cov(sam)
    posteriors_s.append(dd.Gaussian(m=ms, S=Ss))

    # run with Gaussian proposals

    print('\n seed #' + str(seed) + '\n')
    posteriors = res['posteriors']
    posteriors_k = res['posteriors_ki']
    posteriors_k2 = res['posteriors_kh']

    # A trailing None marks a failed final round of the raw SNPE run.
    if posteriors and posteriors[-1] is None:
        posteriors.pop()

    print('vanilla #succesful rounds: ', len(posteriors))
    print('x_kl    #succesful rounds: ', len(posteriors_k))
    print('max_ESS #succesful rounds: ', len(posteriors_k2))

    # Plotting is best-effort: report the failure (with its cause) and move on
    # to the next seed instead of silently swallowing every exception.
    try:
        fig,_ = plot_pdf(posteriors[-1],
                 pdf2=posteriors_k[-1],
                 resolution=100,
                 lims=[-2,2],
                 samples=sam,
                 gt=true_params,
                 figsize=(9,9));
        fig.savefig('finalPosteriors_rawVsKernel' + '_ki' + '_dur' + str(duration) + batch_specifier + fit_specifier + str(seed) +'.pdf')
        fig,_ = plot_pdf(posteriors[-1],
                 pdf2=posteriors_k2[-1],
                 resolution=100,
                 lims=[-2,2],
                 samples=sam,
                 gt=true_params,
                 figsize=(9,9));
        fig.savefig('finalPosteriors_rawVsKernel' + '_kh' + '_dur' + str(duration) + batch_specifier + fit_specifier + str(seed) +'.pdf')

    except Exception as err:
        print('printing broke !', err)


In [None]:
import seaborn 

# Legend labels for the three runs stored per seed under the keys
# 'posteriors', 'posteriors_ki' and 'posteriors_kh'.
labels=['raw', 'x_kl_input', 'x_kl_hidden']
# Longest kernel-parameter vector A over the two kernel runs (last round of
# the most recently loaded seed); sizes the last axis of `invwidths`.
mks = np.maximum(res['logs_ki'][-1]['cbkrnl'].A.size,
                 res['logs_kh'][-1]['cbkrnl'].A.size)

# Containers indexed (seed, variant[, round[, ...]]). Everything except the
# final-KL table starts as NaN so that rounds a run never reached are ignored
# by the nan-aware reductions used when plotting.
dkls = np.zeros((len(seeds), 3))
dklsa = np.full((len(seeds), 3, res['n_rounds']), np.nan)
BAM = np.full_like(dklsa, np.nan)

ess = np.full((len(seeds), 3, res['n_rounds']), np.nan)
mu_ = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)
sig2e = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)
sig2_ = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)

MSEs = np.full((len(seeds), 3, res['n_rounds'], res['true_params'].size, 2), np.nan)
invwidths = np.full((len(seeds), 3, res['n_rounds'], mks), np.nan)
clrs = ['r', 'b', 'g']

# Collect diagnostics per seed (i), run variant (j: 0 = 'posteriors',
# 1 = '_ki', 2 = '_kh') and SNPE round (k) into the arrays allocated above.
for i in range(len(seeds)):

    seed = seeds[i]
    p1 = posteriors_s[i]  # Gaussian fit to this seed's reference samples

    # Pickled dict in a 0-d object array: numpy >= 1.16.3 needs
    # allow_pickle=True. Only load result files you generated yourself.
    res  = np.load(path+'check_higherOrder_kernels' + batch_specifier + fit_specifier + '_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]
    posteriors_all = [res['posteriors'], res['posteriors_ki'], res['posteriors_kh']]
    tds_all = [res['tds'], res['tds_ki'], res['tds_kh']]
    logs_all = [res['logs'], res['logs_ki'], res['logs_kh'] ]

    for j in range(len(posteriors_all)):
        psts =  posteriors_all[j]
        # A None entry marks a failed round: record the KL of the round just
        # before it (also as a BAM marker) and drop entries from the list.
        # NOTE(review): the list is popped while iterating over a range fixed
        # at its original length, and up to two entries are removed (the None
        # and the round before it). This only behaves as intended when the
        # None is the last entry; confirm this bookkeeping is deliberate.
        for k in range(len(psts)):
            if psts[k] is None:

                p2 = psts[k-1].xs[0]

                kl = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)
                BAM[i,j,k-1] = kl
                dklsa[i,j,k-1] = kl

                if len(psts) > 2:
                    psts.pop()
                if len(psts) > 1:
                    psts.pop()

        for k in range(len(psts)):
            p2 = psts[k].xs[0]
            dklsa[i,j,k] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)

            # Normalised importance weights as a column vector; divide into a
            # new array rather than in place so the loaded results stay intact.
            w = tds_all[j][k][2].reshape(-1,1)
            w = w / w.sum()
            ess[i,j,k] = 1/np.sum(w**2)  # effective sample size

            # Unweighted vs. importance-weighted mean/variance of the stats.
            stats = tds_all[j][k][1]
            sig2e[i, j, k, :] = np.var( stats, axis = 0)
            mu_[  i, j, k, :] = np.sum( w * stats,                      axis=0)
            sig2_[i, j, k, :] = np.sum( w * (stats-mu_[i, j, k, :])**2, axis=0)

            # Squared errors of means and of diagonal covariances w.r.t. the reference.
            MSEs[i,j,k,:,0] = (         psts[k].xs[0].m  -         posteriors_s[i].m  )**2
            MSEs[i,j,k,:,1] = ( np.diag(psts[k].xs[0].S) - np.diag(posteriors_s[i].S) )**2

        # Learned kernel parameters A per round (runs without a kernel log None).
        lgs =  logs_all[j]
        for k in range(len(psts)):
            if lgs[k]['cbkrnl'] is not None:
                invwidths[i,j,k,:lgs[k]['cbkrnl'].A.size] = lgs[k]['cbkrnl'].A

        p2 = psts[-1].xs[0]
        #p2 = psts[np.min((len(psts)-1, res['round_cl']-2))].xs[0]
        dkls[i,j] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)

# Euclidean error over parameters, per seed/variant/round (and its seed mean).
RMSEs = np.sqrt(MSEs.sum(axis=3))
mRMSEs = np.nanmean(RMSEs, axis=0) # avg over seeds


plt.figure(figsize=(24,21))
plt.subplot(3,4,1)

# Fraction of learned kernel parameters A_ii above 1e-4 per round, i.e. how
# many kernel dimensions are not (nearly) flat. Only the kernel variants
# j = 1, 2 are plotted; row 0 of pu stays zero.
pu = np.zeros((3, res['n_rounds']))
for j in range(1,3):
    for k in range(res['n_rounds']):
        tmp = invwidths[:,j,k,:].reshape(-1)
        tmp = tmp[~np.isnan(tmp)]
        pu[j,k] = np.nanmean(tmp > 10**(-4))
    plt.plot(pu[j,:], 'o-', color=clrs[j])
plt.xlabel('#round')
plt.ylabel('fraction i : A_ii > 0.0001')
plt.title('usage of kernel (# non-flat kernel dims)')
plt.legend(labels[1:], loc=5)

# Histograms of log(A_ii) over all seeds and rounds, one per kernel variant.
# Clamp A_ii *before* taking the log. A_ii == 0 would otherwise give -inf,
# which plt.hist cannot bin. The previous np.maximum(np.log(tmp), 1e-20)
# clamped the log itself, which also mapped every A_ii < 1 onto 1e-20.
plt.subplot(6,4,2)
tmp = invwidths[:,1,:,:]
tmp = tmp[~np.isnan(tmp)]
plt.hist(np.log(np.maximum(tmp, 1e-20)), bins=20, color=clrs[1]);
plt.yticks([])
plt.ylabel(labels[1])
plt.title('distributions of learned kernel params')

plt.subplot(6,4,6)
tmp = invwidths[:,2,:,:]
tmp = tmp[~np.isnan(tmp)]
plt.hist(np.log(np.maximum(tmp, 1e-20)), bins=20, color=clrs[2]);
plt.yticks([])
plt.ylabel(labels[2])
plt.xlabel( 'log(A_ii)' )


# Per-round error of posterior means (left) and of the diagonal covariance
# entries (right) w.r.t. the reference Gaussian. Dashed lines are individual
# seeds; thick lines are the nan-mean over seeds.
plt.subplot(3,4,3)
mean_lines = []
for j in range(len(posteriors_all)):
    plt.semilogy(range(1, mRMSEs.shape[1]+1), RMSEs[:,j,:,0].T, '--', alpha=0.35, color=clrs[j])
    mean_lines += plt.semilogy(range(1, mRMSEs.shape[1]+1),mRMSEs[j,:,0], color=clrs[j], linewidth=2.5)
plt.title('error in posterior means')
plt.ylabel('Eucl. distance (avg. over seeds)')
# Pass the thick mean lines explicitly. With labels alone the legend attaches
# to the first three artists, which are all red per-seed lines of run 0.
plt.legend(mean_lines, labels, loc=1)
plt.xlabel('#round')

plt.subplot(3,4,4)
for j in range(len(posteriors_all)):
    plt.semilogy(range(1, mRMSEs.shape[1]+1),mRMSEs[j,:,1], color=clrs[j], linewidth=2.5)
    plt.semilogy(range(1, mRMSEs.shape[1]+1), RMSEs[:,j,:,1].T, '--', alpha=0.35, color=clrs[j])
plt.title('error in posterior variances (diag(S))')
plt.xlabel('#round')

# Final KL divergence to the reference posterior, per seed and run variant.
# Seeds are ordered by decreasing raw-run KL, so the x positions are ranks
# after sorting, not seed values.
plt.subplot(3,2,3)
idx = np.argsort(dkls[:,0])[::-1]
#idx = np.arange(len(seeds))
dkls_s = dkls[idx,:]
# Bars are offset by 0.2 but drawn 0.3 wide, so neighbouring bars overlap.
plt.bar(np.arange(dkls_s.shape[0]),     dkls_s[:,0], 0.3, color='r')
plt.bar(np.arange(dkls_s.shape[0])+0.2, dkls_s[:,1], 0.3, color='b')
plt.bar(np.arange(dkls_s.shape[0])+0.4, dkls_s[:,2], 0.3, color='g')
plt.xlabel('#rnd seed')
plt.ylabel('final KL divergences')
plt.legend(labels)
plt.title('distance of posteriors to ground-truth')

# Distribution of the final KL over seeds (skipped when only one seed).
plt.subplot(3,4,7)
if dkls.shape[0] > 1:
    plt.boxplot(dkls, labels=labels)
plt.title('distribution of final D_KL')

# KL divergence to the reference posterior over SNPE rounds. Dashed lines are
# individual seeds, circles (BAM) mark the last round before a failed (None)
# round, and thick lines are the nan-mean over seeds.
plt.subplot(3,4,8)
# Off-screen dummy lines (x=-1 lies outside the axis limits set below) so that
# plt.legend(labels) maps the labels to red/blue/green in this order.
plt.plot(-1, -1, 'r', linewidth=1.5)
plt.plot(-1, -1, 'b', linewidth=1.5)
plt.plot(-1, -1, 'g', linewidth=1.5)
# Reverse order so the raw (red) curves are drawn last, on top.
for j_ in np.arange(3)[::-1]:

    plt.plot(np.arange(1,dklsa.shape[2]+1), dklsa[:,j_,:].T, '--', color=clrs[j_], alpha=0.35)
    plt.plot(np.arange(1,dklsa.shape[2]+1), BAM[:,j_,:].T, 'o', color=clrs[j_], ms=6, alpha=0.35)
    # s (std over seeds) is computed but not used.
    m, s = np.nanmean(dklsa[:,j_,:], axis=0), np.nanstd(dklsa[:,j_,:], axis=0)
    plt.plot(np.arange(1,dklsa.shape[2]+1), m, clrs[j_], linewidth=2.5)    

plt.ylabel('KL divergence')
plt.xlabel('#round')
plt.legend(labels)
plt.axis([0.99, dklsa.shape[2], 0, 1.1*np.nanmax(dklsa)])
plt.title('posterior distance over rounds')

# Effective sample size 1/sum(w^2) of the normalised importance weights, from
# round 2 on (round 1 is skipped via the 1: slice).
plt.subplot(3,2,5)
for j_ in np.arange(3)[::-1]:
    m, s = np.nanmean(ess[:,j_,1:], axis=0), np.nanstd(ess[:,j_,1:], axis=0)
    plt.plot(np.arange(2,ess.shape[2]+1), m, clrs[j_], linewidth=2.5)    
    plt.plot(np.arange(2,ess.shape[2]+1), ess[:,j_,1:].T, '--', alpha=0.35, color=clrs[j_])
plt.xlabel('#round')
plt.ylabel('ESS (avg: thick line)')
plt.title('IS weight variance: effective sample sizes')

# Bottom-right panels: variance of each summary statistic in four selected
# SNPE rounds. Solid lines are the importance-weighted variance (sig2_);
# dashed lines are the unweighted variance (sig2e). Both are averaged over seeds.
if res['n_rounds'] > 9:
    rs = [0, 1, 4, 9]
else:
    rs = [0, 1, 2, 3]
sp = [19, 20, 23, 24]
for r_ in range(4):

    r = rs[r_]
    plt.subplot(6,4,sp[r_])

    for j_ in np.arange(3):
        m = np.nanmean(sig2_[:,j_,r,:], axis=0)
        plt.plot(np.arange(len(m))+1, m, color=clrs[j_], linewidth=1.5)

    # nanmean, not mean: seeds whose run stopped before round r are NaN here,
    # and a plain mean would blank out the whole dashed curve.
    for j_ in np.arange(3):
        plt.plot(np.arange(len(m))+1, np.nanmean(sig2e[:,j_,r,:], axis=0),
                '--', color=clrs[j_], linewidth=1.2)

    if r_ in [2,3] :
        plt.xlabel('# summary statistic')
    if r_ in [0,2]:
        plt.ylabel('empirical variance')
    plt.title('SNPE round #' + str(r+1))

    if r_==0:
        plt.legend(['raw iws', 'iws + x_kl', 'iws + mESS'], loc=1)

plt.savefig('summary_GLM' + batch_specifier + fit_specifier +'_'+ str(len(seeds)) + 'seeds_dur' + str(duration) + '.pdf')
plt.show()

# Figures of importance-weighted summary statistics

In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

def d_kl_gauss(m1, m2, S1, S2):
    """Return KL( N(m1, S1) || N(m2, S2) ) for two multivariate Gaussians.

    Closed form:
        0.5 * [ tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1) - k
                + log|S2| - log|S1| ],
    with k the dimensionality (``m1.size``).

    Parameters
    ----------
    m1, m2 : array_like, shape (k,)
        Means of the first and second Gaussian.
    S1, S2 : array_like, shape (k, k)
        Covariance matrices of the first and second Gaussian.

    Returns
    -------
    float
        The KL divergence. Returns ``nan`` if either covariance has a
        non-positive determinant, i.e. is not a valid covariance.
    """
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)
    S1 = np.asarray(S1, dtype=float)
    S2 = np.asarray(S2, dtype=float)

    sign1, logdet1 = np.linalg.slogdet(S1)
    sign2, logdet2 = np.linalg.slogdet(S2)
    if sign1 <= 0 or sign2 <= 0:
        # Multiplying sign and log|det| together (np.prod(slogdet(...)))
        # turned an invalid covariance into a wrong finite number. NaN is
        # instead skipped by nan-aware reductions.
        return np.nan

    dm = m2 - m1

    # solve() rather than an explicit inverse: same result, better conditioned.
    out = np.trace(np.linalg.solve(S2, S1))
    out += dm.dot(np.linalg.solve(S2, dm))
    out -= m1.size
    out += logdet2 - logdet1
    return out / 2.


seeds = np.arange(90, 110)
duration = 300
batch_specifier = '_cp'
fit_specifier = '_5dim' #'_5dim'
path = '' #'a_05/' #'glm_runs_badprior/'

posteriors_s = []

for seed in seeds:

    # Pickled dict in a 0-d object array: numpy >= 1.16.3 needs
    # allow_pickle=True. Only load result files you generated yourself.
    res  = np.load(path+'check_higherOrder_kernels' + batch_specifier + fit_specifier + '_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]

    true_params = res['true_params']
    obs_stats   = res['obs_stats']

    sam = np.load(path+'sam_' + str(duration) + '_' + str(seed) + batch_specifier + '.npz')['arr_0']
    ms = sam.mean(axis=1)
    Ss = np.cov(sam)
    posteriors_s.append(dd.Gaussian(m=ms, S=Ss))

    # run with Gaussian proposals

    print('\n seed #' + str(seed) + '\n')
    posteriors = res['posteriors']
    posteriors_k = res['posteriors_ki']
    posteriors_k2 = res['posteriors_kh']

    if posteriors and posteriors[-1] is None:
        posteriors.pop()

    print('vanilla #succesful rounds: ', len(posteriors))
    print('x_kl    #succesful rounds: ', len(posteriors_k))
    print('max_ESS #succesful rounds: ', len(posteriors_k2))

    # Final round: fit a Gaussian to the importance-weighted summary
    # statistics (tds[r][1] = stats, tds[r][2] = weights) of the raw run and
    # of each kernel run, then plot raw vs. kernel.
    r = res['n_rounds']-1
    stats, w = res['tds'][r][1], res['tds'][r][2]
    w = w / w.sum()  # normalise a copy; leave the loaded results untouched
    px = dd.Gaussian(m=w.dot(stats),
                     S=np.cov(stats.T, aweights=w))

    # Each kernel variant uses its own training data. The '_kh' figure was
    # previously built from res['tds_ki'] as well (copy-paste slip).
    for kernel_suffix, tds_key in (('_ki', 'tds_ki'), ('_kh', 'tds_kh')):
        tds_k = res[tds_key]
        stats, w = tds_k[r][1], tds_k[r][2]
        w = w / w.sum()
        px_k = dd.Gaussian(m=w.dot(stats),
                           S=np.cov(stats.T, aweights=w))
        fig,_ = plot_pdf(px, pdf2=px_k, samples=stats.T, lims = [-2, 2], resolution=100,
                 figsize=(16,16));
        fig.savefig('finalWeightedStats_rawVsKernel' + kernel_suffix + '_dur' + str(duration) + batch_specifier + fit_specifier + str(seed) +'.pdf')
    

# The one bad seed (99)

In [None]:

# Inspect a single seed round by round: raw run vs. each kernel run.
seed = 99
# Pickled dict in a 0-d object array: numpy >= 1.16.3 needs allow_pickle=True.
# Only load result files you generated yourself.
res  = np.load(path+'check_higherOrder_kernels' + batch_specifier + fit_specifier + '_d' + str(duration) + '_' + str(seed) + '.npy',
               allow_pickle=True)[()]

true_params = res['true_params']
obs_stats   = res['obs_stats']

sam = np.load(path+'sam_' + str(duration) + '_' + str(seed) + batch_specifier + '.npz')['arr_0']
ms = sam.mean(axis=1)
Ss = np.cov(sam)
# NOTE(review): this appends to the global posteriors_s list built by an
# earlier cell; kept as-is for compatibility.
posteriors_s.append(dd.Gaussian(m=ms, S=Ss))

# run with Gaussian proposals

print('\n seed #' + str(seed) + '\n')
posteriors = res['posteriors']
posteriors_k = res['posteriors_ki']
posteriors_k2 = res['posteriors_kh']

if posteriors and posteriors[-1] is None:
    posteriors.pop()

print('vanilla #succesful rounds: ', len(posteriors))
print('x_kl    #succesful rounds: ', len(posteriors_k))
print('max_ESS #succesful rounds: ', len(posteriors_k2))

# First (up to) five rounds, all '_ki' comparisons first, then all '_kh'.
# Runs can stop early and lists can contain None for a failed round, so limit
# the range and skip missing rounds instead of crashing.
for posteriors_cmp in (posteriors_k, posteriors_k2):
    for r in range(min(5, len(posteriors), len(posteriors_cmp))):
        if posteriors[r] is None or posteriors_cmp[r] is None:
            continue
        plot_pdf(posteriors[r],
                 pdf2=posteriors_cmp[r],
                 resolution=100,
                 lims=[-2,2],
                 samples=sam,
                 gt=true_params,
                 figsize=(9,9));


# Previous sweeps

In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

def d_kl_gauss(m1, m2, S1, S2):
    """Return KL( N(m1, S1) || N(m2, S2) ) for two multivariate Gaussians.

    Closed form:
        0.5 * [ tr(S2^-1 S1) + (m2 - m1)^T S2^-1 (m2 - m1) - k
                + log|S2| - log|S1| ],
    with k the dimensionality (``m1.size``).

    Parameters
    ----------
    m1, m2 : array_like, shape (k,)
        Means of the first and second Gaussian.
    S1, S2 : array_like, shape (k, k)
        Covariance matrices of the first and second Gaussian.

    Returns
    -------
    float
        The KL divergence. Returns ``nan`` if either covariance has a
        non-positive determinant, i.e. is not a valid covariance.
    """
    m1 = np.asarray(m1, dtype=float)
    m2 = np.asarray(m2, dtype=float)
    S1 = np.asarray(S1, dtype=float)
    S2 = np.asarray(S2, dtype=float)

    sign1, logdet1 = np.linalg.slogdet(S1)
    sign2, logdet2 = np.linalg.slogdet(S2)
    if sign1 <= 0 or sign2 <= 0:
        # Multiplying sign and log|det| together (np.prod(slogdet(...)))
        # turned an invalid covariance into a wrong finite number. NaN is
        # instead skipped by nan-aware reductions.
        return np.nan

    dm = m2 - m1

    # solve() rather than an explicit inverse: same result, better conditioned.
    out = np.trace(np.linalg.solve(S2, S1))
    out += dm.dot(np.linalg.solve(S2, dm))
    out -= m1.size
    out += logdet2 - logdet1
    return out / 2.


seeds = np.arange(90, 110)
duration = 300
batch_specifier = '_cp'
fit_specifier = '_pilot_convT_small_05'
path = '' #'a_05/' #'glm_runs_badprior/'

# Set to True to plot the final raw vs. '_k' posterior for every seed.
# This replaces an undefined name `bla` that raised a NameError on purpose,
# which the bare except then reported as 'printing broke !' for every seed.
plot_final_posteriors = False

posteriors_s = []

for seed in seeds:

    true_params, labels_params = utils.obs_params()
    # NOTE(review): `obs` is not used below. The call is kept in case
    # utils.obs_data has side effects (e.g. on random state); confirm.
    obs = utils.obs_data(true_params, seed=seed, duration = duration)
    obs_stats = utils.obs_stats(true_params, seed=seed, duration = duration)

    # Pickled dict in a 0-d object array: numpy >= 1.16.3 needs
    # allow_pickle=True. Only load result files you generated yourself.
    res  = np.load(path+'check_kernels' + batch_specifier + fit_specifier + '_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]

    #assert np.all( true_params == res['true_params'] )
    #assert np.all( obs_stats == res['obs_stats'] )

    sam = np.load(path+'sam_' + str(duration) + '_' + str(seed) + batch_specifier + '.npz')['arr_0']
    ms = sam.mean(axis=1)
    Ss = np.cov(sam)
    posteriors_s.append(dd.Gaussian(m=ms, S=Ss))

    # run with Gaussian proposals

    print('\n seed #' + str(seed) + '\n')
    posteriors = res['posteriors']
    posteriors_k = res['posteriors_k']
    posteriors_k2 = res['posteriors_k2']

    if posteriors and posteriors[-1] is None:
        posteriors.pop()

    print('vanilla #succesful rounds: ', len(posteriors))
    print('x_kl    #succesful rounds: ', len(posteriors_k))
    print('max_ESS #succesful rounds: ', len(posteriors_k2))

    if plot_final_posteriors:
        # Best-effort plotting: report the failure and continue.
        try:
            plot_pdf(posteriors[-1],
                     pdf2=posteriors_k[-1],
                     resolution=100,
                     lims=[-2,2],
                     samples=sam,
                     gt=true_params,
                     figsize=(9,9));
        except Exception as err:
            print('printing broke !', err)

import seaborn 

# Containers indexed (seed, variant[, round[, ...]]) for the older sweep.
# Everything except the final-KL table starts as NaN so that rounds a run
# never reached are ignored by the nan-aware reductions used when plotting.
dkls = np.zeros((len(seeds), 3))
dklsa = np.full((len(seeds), 3, res['n_rounds']), np.nan)
BAM = np.full_like(dklsa, np.nan)

ess = np.full((len(seeds), 3, res['n_rounds']), np.nan)
mu_ = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)
sig2e = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)
sig2_ = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)

MSEs = np.full((len(seeds), 3, res['n_rounds'], res['true_params'].size, 2), np.nan)
invwidths = np.full((len(seeds), 3, res['n_rounds'], obs_stats.size), np.nan)
clrs = ['r', 'b', 'g']

# Collect diagnostics per seed (i), run variant (j: 0 = 'posteriors',
# 1 = '_k', 2 = '_k2') and SNPE round (k) into the arrays allocated above.
for i in range(len(seeds)):

    seed = seeds[i]
    p1 = posteriors_s[i]  # Gaussian fit to this seed's reference samples

    # Pickled dict in a 0-d object array: numpy >= 1.16.3 needs
    # allow_pickle=True. Only load result files you generated yourself.
    res  = np.load(path+'check_kernels' + batch_specifier + fit_specifier + '_d' + str(duration) + '_' + str(seed) + '.npy',
                   allow_pickle=True)[()]
    posteriors_all = [res['posteriors'], res['posteriors_k'], res['posteriors_k2']]
    tds_all = [res['tds'], res['tds_k'], res['tds_k2']]
    logs_all = [res['logs'], res['logs_k'], res['logs_k2'] ]

    for j in range(len(posteriors_all)):
        psts =  posteriors_all[j]
        # A None entry marks a failed round: record the KL of the round just
        # before it (also as a BAM marker) and drop entries from the list.
        # NOTE(review): the list is popped while iterating over a range fixed
        # at its original length, and up to two entries are removed. This only
        # behaves as intended when the None is the last entry; confirm.
        for k in range(len(psts)):
            if psts[k] is None:

                p2 = psts[k-1].xs[0]

                kl = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)
                BAM[i,j,k-1] = kl
                dklsa[i,j,k-1] = kl

                if len(psts) > 2:
                    psts.pop()
                if len(psts) > 1:
                    psts.pop()

        for k in range(len(psts)):
            p2 = psts[k].xs[0]
            dklsa[i,j,k] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)

            # Normalised importance weights (column vector), computed into a
            # new array so the loaded results are not modified in place.
            w = tds_all[j][k][2].reshape(-1,1)
            w = w / w.sum()
            ess[i,j,k] = 1/np.sum(w**2)  # effective sample size

            # Unweighted vs. importance-weighted mean/variance of the stats.
            stats = tds_all[j][k][1]
            sig2e[i, j, k, :] = np.var( stats, axis = 0)
            mu_[  i, j, k, :] = np.sum( w * stats,                      axis=0)
            sig2_[i, j, k, :] = np.sum( w * (stats-mu_[i, j, k, :])**2, axis=0)

            # Squared errors of means and of diagonal covariances w.r.t. the reference.
            MSEs[i,j,k,:,0] = (         psts[k].xs[0].m  -         posteriors_s[i].m  )**2
            MSEs[i,j,k,:,1] = ( np.diag(psts[k].xs[0].S) - np.diag(posteriors_s[i].S) )**2

        # Learned kernel parameters A per round (runs without a kernel log None).
        lgs =  logs_all[j]
        for k in range(len(psts)):
            if lgs[k]['cbkrnl'] is not None:
                invwidths[i,j,k,:] = lgs[k]['cbkrnl'].A

        p2 = psts[-1].xs[0]
        #p2 = psts[np.min((len(psts)-1, res['round_cl']-2))].xs[0]
        dkls[i,j] = d_kl_gauss(p1.m, p2.m, p1.S, p2.S)

# Euclidean error over parameters, averaged over seeds: shape (variant, round, 2).
RMSEs = np.nanmean(np.sqrt(MSEs.sum(axis=3)), axis=0) # avg over seeds
#sig2_ = sig2_ / sig2e[:,:,0,:].reshape(20,3,1,10)
#sig2e = sig2e / sig2e[:,:,0,:].reshape(20,3,1,10)

plt.figure(figsize=(24,21))

# Disabled exploratory code (a bare string, never executed).
# NOTE(review): before re-enabling, fix three things. It uses the stale loop
# variable `i` (dkls[i,0], probably meant dkls[:,0]). It passes seed *indices*
# from argmin/argmax where seed values are expected. It loads 'cp.npz' without
# the '_' of batch_specifier and with a stale `seed`.
"""
s_bw = np.argmin(dkls[:,1] - dkls[i,0]), np.argmax(dkls[:,1] - dkls[i,0]) # best and worst
for i in range(len(s_bw)):
    res  = np.load('check_kernels_cp_d' + str(duration) + '_' + str(s_bw[i]) + '.npy')[()]
    plt.subplot(3,4,i+1)
    
    posterior, posterior_k = res['posteriors'][-1], res['posteriors_k'][-1]
    sam = np.load('sam_' + str(duration) + '_' + str(seed) + 'cp.npz')['arr_0']

    plot_pdf(posterior,pdf2=posterior_k, 
             resolution=100, 
             samples=sam, gt=true_params);
"""

# Fraction of learned kernel parameters A_ii above 1e-4 per round, for the
# kernel variants j = 1, 2 (row 0 of pu stays zero).
plt.subplot(3,4,1)

pu = np.zeros((3, res['n_rounds']))
for j in range(1,3):
    for k in range(res['n_rounds']):
        tmp = invwidths[:,j,k,:].reshape(-1)
        tmp = tmp[~np.isnan(tmp)]
        pu[j,k] = np.nanmean(tmp > 10**(-4))
    plt.plot(pu[j,:], 'o-', color=clrs[j])
plt.xlabel('#round')
plt.ylabel('fraction i : A_ii > 0.0001')
plt.title('usage of kernel (# non-flat kernel dims)')
plt.legend(['x_kl', 'mEss'], loc=5)

# Histograms of log(A_ii) over all seeds and rounds, one per kernel variant.
# Clamp A_ii before the log: A_ii == 0 would give -inf, which plt.hist cannot bin.
plt.subplot(6,4,2)
tmp = invwidths[:,1,:,:]
tmp = tmp[~np.isnan(tmp)]
plt.hist(np.log(np.maximum(tmp, 1e-20)), bins=20, color=clrs[1]);
plt.yticks([])
plt.ylabel('x_kl')
plt.title('distributions of learned kernel params')

plt.subplot(6,4,6)
tmp = invwidths[:,2,:,:]
tmp = tmp[~np.isnan(tmp)]
plt.hist(np.log(np.maximum(tmp, 1e-20)), bins=20, color=clrs[2]);
plt.yticks([])
plt.ylabel('mESS')
plt.xlabel( 'log(A_ii)' )


# Per-round error of posterior means (left) and of the diagonal covariance
# entries (right). RMSEs is already averaged over seeds here, giving one line
# per run variant.
plt.subplot(3,4,3)
for j in range(len(posteriors_all)):
    plt.semilogy(range(1, RMSEs.shape[1]+1),RMSEs[j,:,0], color=clrs[j])
plt.title('error in posterior means')
plt.ylabel('Eucl. distance (avg. over seeds)')
plt.xlabel('#round')

plt.subplot(3,4,4)
for j in range(len(posteriors_all)):
    plt.semilogy(range(1, RMSEs.shape[1]+1),RMSEs[j,:,1], color=clrs[j])
plt.title('error in posterior variances (diag(S))')
plt.xlabel('#round')

# Final KL divergence to the reference posterior, per seed and run variant.
# Seeds are ordered by decreasing raw-run KL, so the x positions are ranks
# after sorting, not seed values.
plt.subplot(3,2,3)
idx = np.argsort(dkls[:,0])[::-1]
dkls_s = dkls[idx,:]
# Bars are offset by 0.2 but drawn 0.3 wide, so neighbouring bars overlap.
plt.bar(np.arange(dkls_s.shape[0]),     dkls_s[:,0], 0.3, color='r')
plt.bar(np.arange(dkls_s.shape[0])+0.2, dkls_s[:,1], 0.3, color='b')
plt.bar(np.arange(dkls_s.shape[0])+0.4, dkls_s[:,2], 0.3, color='g')
plt.xlabel('#rnd seed')
plt.ylabel('final KL divergences')
plt.legend(['raw', 'x_kl', 'mESS'])
plt.title('distance of posteriors to ground-truth')

# Distribution of the final KL over seeds (skipped when only one seed).
plt.subplot(3,4,7)
if dkls.shape[0] > 1:
    plt.boxplot(dkls, labels=['raw', 'x_kl', 'mESS'])
plt.title('distribution of final D_KL')

# KL divergence to the reference posterior over SNPE rounds. Dashed lines are
# individual seeds, circles (BAM) mark the last round before a failed (None)
# round, and thick lines are the nan-mean over seeds.
plt.subplot(3,4,8)
# Off-screen dummy lines (x=-1 lies outside the axis limits set below) so that
# the legend maps its labels to red/blue/green in this order.
plt.plot(-1, -1, 'r', linewidth=1.5)
plt.plot(-1, -1, 'b', linewidth=1.5)
plt.plot(-1, -1, 'g', linewidth=1.5)
# Reverse order so the raw (red) curves are drawn last, on top.
for j_ in np.arange(3)[::-1]:

    plt.plot(np.arange(1,dklsa.shape[2]+1), dklsa[:,j_,:].T, '--', color=clrs[j_], alpha=0.35)
    plt.plot(np.arange(1,dklsa.shape[2]+1), BAM[:,j_,:].T, 'o', color=clrs[j_], ms=6, alpha=0.35)
    # s (std over seeds) is computed but not used.
    m, s = np.nanmean(dklsa[:,j_,:], axis=0), np.nanstd(dklsa[:,j_,:], axis=0)
    plt.plot(np.arange(1,dklsa.shape[2]+1), m, clrs[j_], linewidth=2.5)    

plt.ylabel('KL divergence')
plt.xlabel('#round')
plt.legend(['raw', 'x_kl', 'mESS'])
plt.axis([0.99, dklsa.shape[2], 0, 1.1*np.nanmax(dklsa)])
plt.title('posterior distance over rounds')

# Effective sample size 1/sum(w^2) of the normalised importance weights, from
# round 2 on (round 1 is skipped via the 1: slice).
plt.subplot(3,2,5)
for j_ in np.arange(3)[::-1]:
    m, s = np.nanmean(ess[:,j_,1:], axis=0), np.nanstd(ess[:,j_,1:], axis=0)
    #plt.fill_between(np.arange(2,ess.shape[2]+1),
    #                 m-s,
    #                 m+s,
    #                 color=clrs[j_],
    #                 alpha=0.5
    #                )
    plt.plot(np.arange(2,ess.shape[2]+1), m, clrs[j_], linewidth=2.5)    
    plt.plot(np.arange(2,ess.shape[2]+1), ess[:,j_,1:].T, '--', alpha=0.35, color=clrs[j_])
plt.xlabel('#round')
plt.ylabel('ESS (avg: thick line)')
plt.title('IS weight variance: effective sample sizes')

# Bottom-right panels: variance of each summary statistic in four selected
# SNPE rounds. Solid lines are the importance-weighted variance (sig2_);
# dashed lines are the unweighted variance (sig2e). Both are averaged over seeds.
if res['n_rounds'] > 9:
    rs = [0, 1, 4, 9]
else:
    rs = [0, 1, 2, 3]
sp = [19, 20, 23, 24]
for r_ in range(4):

    r = rs[r_]
    plt.subplot(6,4,sp[r_])

    for j_ in np.arange(3):
        m = np.nanmean(sig2_[:,j_,r,:], axis=0)
        plt.plot(np.arange(len(m))+1, m, color=clrs[j_], linewidth=1.5)

    # nanmean, not mean: seeds whose run stopped before round r are NaN here,
    # and a plain mean would blank out the whole dashed curve.
    for j_ in np.arange(3):
        plt.plot(np.arange(len(m))+1, np.nanmean(sig2e[:,j_,r,:], axis=0),
                '--', color=clrs[j_], linewidth=1.2)

    if r_ in [2,3] :
        plt.xlabel('# summary statistic')
    if r_ in [0,2]:
        plt.ylabel('empirical variance')
    plt.title('SNPE round #' + str(r+1))

    if r_==0:
        plt.legend(['raw iws', 'iws + x_kl', 'iws + mESS'], loc=1)

plt.savefig('summary_GLM' + batch_specifier + fit_specifier +'_'+ str(len(seeds)) + 'seeds_dur' + str(duration) + '.pdf')
plt.show()