## learning kernels for the GLM example. 

we optimize kernels such that
$ K(x_n, x_0) p(\theta_n) / \tilde{p}(\theta_n) \approx 1$. 

Spoiler: it starts to work.


# approach

The above problem doesn't require MDNs at all. 
Once prior, proposal, kernel and simulator are fixed and we have drawn an artificial dataset $(x_n, \theta_n)$, we're good to go. 
Let's run SNPE as usual, note down the datasets $(x_n, \theta_n)$, proposal priors and importance weights it produces over rounds, and afterwards play with the kernel on those fixed targets. 

- Remark: results look a lot worse if we convert to Student's t distributions. Could it be that the kernel shape (squared-exponential in $x$) has to match the proposal-prior shape (squared in $\theta$ for Student's t with df=3)?

We try out a bunch of simple squared losses. 

### 1. basic squared loss

argmin $ \sum_n \left( 1 - \frac{K(x_n, x_0) p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2 $, emphasizes absolute deviations of the ratio from $1$. 

### 2. inverse-kernel loss

argmin $ \sum_n \left( \frac{1}{K(x_n,x_0)} - \frac{p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2 $, emphasizes that the kernel should be small where importance weights are large

### 3. log-space loss

argmin $ \sum_n \left( \log(\frac{1}{K(x_n,x_0)}) - \log(\frac{p(\theta_n)}{\tilde{p}(\theta_n)}) \right)^2 = \sum_n \left( \log(\frac{p(\theta_n)K(x_n,x_0)}{\tilde{p}(\theta_n)}) \right)^2$ emphasizes ratios of $\approx 1$.

### 4. inverse weights-dominated loss

argmin $ \sum_n \left( K(x_n,x_0) - \frac{\tilde{p}(\theta_n)}{p(\theta_n)} \right)^2 $, emphasizes that the kernel should be large where importance weights are small
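A quick way to see what each loss emphasizes: the NumPy sketch below (not part of the original notebook) evaluates the per-sample term of each loss for a flat kernel $K = 1$ on three hypothetical weights $w \in \{0.01, 1, 100\}$. With $K = 1$ the basic and inverse-kernel terms coincide and are dominated by the large weight, the log-space term treats $w$ and $1/w$ symmetrically, and the inverse-weight term is dominated by the small weight.

```python
import numpy as np

# Per-sample terms of the four losses, written in terms of the kernel value
# k = K(x_n, x_0) and the importance weight w = p(theta_n) / p_tilde(theta_n).
def basic_loss(k, w):
    return (1.0 - k * w) ** 2

def inverse_kernel_loss(k, w):
    return (1.0 / k - w) ** 2

def log_space_loss(k, w):
    return np.log(k * w) ** 2

def inverse_weight_loss(k, w):
    return (k - 1.0 / w) ** 2

# Hypothetical toy weights: one small, one neutral, one large.
w = np.array([0.01, 1.0, 100.0])
k = np.ones_like(w)  # flat kernel K = 1, i.e. no correction at all

losses = {
    'basic': basic_loss(k, w),
    'inverse-kernel': inverse_kernel_loss(k, w),
    'log-space': log_space_loss(k, w),
    'inverse-weight': inverse_weight_loss(k, w),
}
for name, per_sample in losses.items():
    print(name, per_sample)
```

Away from $K = 1$ the basic and inverse-kernel losses differ: the first measures errors in $K w_n$, the second errors in $1/K$.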

In [None]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.utils.io as io
import delfi.summarystats as ds
import lfimodels.glm.utils as utils
import matplotlib.pyplot as plt
import numpy as np

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

%matplotlib inline

seed = 42
m = GLM(seed=seed)
p = utils.smoothing_prior(n_params=m.n_params, seed=seed)
s = GLMStats(n_summary=m.n_params)
g = dg.Default(model=m, prior=p, summary=s)

true_params, labels_params = utils.obs_params()
obs = utils.obs_data(true_params, seed=seed)
obs_stats = utils.obs_stats(true_params, seed=seed)

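rerun = False  # if False, try loading the reference MCMC samples from disk

try:
    assert not rerun, 'rerun requested'
    sam = np.load('sam.npz')['arr_0']
except (AssertionError, IOError):
    # rerun requested or no cached file: draw reference posterior samples and cache them
    sam = utils.pg_mcmc(true_params, obs)
    np.savez('sam.npz', sam)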
    
seed = 98
g = dg.Default(model=m, prior=p, summary=s)
res = infer.SNPE(g, 
                 obs=obs_stats, 
                 n_hiddens=[50], 
                 seed=seed, 
                 convert_to_T=None, 
                 pilot_samples=0,
                 svi=True,
                 reg_lambda=0.01,
                 prior_norm=False)

logs, tds, posteriors = res.run(n_train=5000, 
                                n_rounds=2, 
                                minibatch=100, 
                                epochs=1000, 
                                round_cl=3)

## SNPE fits over rounds (used as proposal priors in the follow-up)

In [None]:
# run with Gaussian proposals
for r in range(len(tds)):
    posterior = posteriors[r]
    plot_pdf(posterior.xs[0], 
             lims=[-2,2], 
             samples=sam, 
             gt=true_params, 
             figsize=(9,9));

# quick-check: efficacy of importance sampling in this setting
- importance sampling should take the parameter statistics (mean and variance) of the proposal (black lines) and bring them back to those of the prior (cyan lines). 
- if the importance-weighted statistics (red lines) are far off the prior statistics, IS didn't work. 
- if the importance-weighted statistics are not even closer to the prior statistics than those of the proposal, something is broken.
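A self-contained sketch of this check on a toy 1-D problem, assuming a hypothetical Gaussian prior $\mathcal{N}(0, 1)$ and a Gaussian proposal $\mathcal{N}(0.8, 1.5^2)$ in place of the GLM prior and SNPE posterior. It uses the same self-normalized weighted mean and variance as the cell below. The proposal is deliberately wider than the prior so the weights stay bounded; SNPE proposals are usually narrower than the prior, in which case the weights are unbounded and the check becomes noisier.

```python
import numpy as np

def gauss_logpdf(x, mu, sig):
    """Log-density of a univariate Gaussian, evaluated elementwise."""
    return -0.5 * ((x - mu) / sig) ** 2 - np.log(sig * np.sqrt(2.0 * np.pi))

rng = np.random.default_rng(0)
mu_prior, sig_prior = 0.0, 1.0   # hypothetical prior N(0, 1)
mu_prop, sig_prop = 0.8, 1.5     # hypothetical proposal, shifted and wider than the prior

theta = rng.normal(mu_prop, sig_prop, size=(5000, 1))  # samples from the proposal

# importance weights p(theta) / p_tilde(theta), computed in log space, then self-normalized
log_w = gauss_logpdf(theta, mu_prior, sig_prior) - gauss_logpdf(theta, mu_prop, sig_prop)
w = np.exp(log_w - log_w.max())
w = w / np.sum(w)

raw_mean, raw_var = theta.mean(axis=0), theta.var(axis=0)
is_mean = np.sum(w * theta, axis=0)
is_var = np.sum(w * (theta - is_mean) ** 2, axis=0)

print('prior mean/var:', mu_prior, sig_prior ** 2)
print('raw   mean/var:', raw_mean, raw_var)
print('IS    mean/var:', is_mean, is_var)
```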

In [None]:

g_ = dg.Default(model=m, prior=p, summary=s)
th_prior, x_prior = g_.gen(5000)

r = 1 # pick proposal after first round

params = tds[r][0].astype(np.float32)

if r > 0:
    p_prior = g.prior.eval(params,log=False)
    p_proposal = posteriors[r-1].eval(params,log=False)
    iws = ( p_prior / p_proposal ).astype(np.float32)
else:
    iws = np.ones(params.shape[0],dtype=np.float32)

w = (iws).reshape(-1,1)
w = w / np.sum(w)

mu_th =   np.sum( w * tds[r][0],           axis=0).reshape(1,-1)
sig2_th = np.sum( w * (tds[r][0]-mu_th)**2, axis=0)

plt.figure(figsize=(9,6))
plt.subplot(1,2,1)
plt.plot(np.mean(th_prior,axis=0), 'co--')
plt.plot(tds[r][0].mean(axis=0), 'ko-')
plt.plot(mu_th.flatten(), 'ro-')
#plt.plot(true_params.flatten(), 'go--')
plt.legend(['prior E[th]', 'raw E[th]','IS E[th]'])
plt.xlabel('# parameter')
plt.title('raw and importance-sampled parameter means')

plt.subplot(1,2,2)
plt.plot(np.var(th_prior,axis=0), 'co--')
plt.plot(tds[r][0].var(axis=0), 'ko-')
plt.plot(sig2_th.flatten(), 'ro-')
plt.legend(['prior Var[th]', 'raw Var[th]','IS Var[th]'])
plt.xlabel('# parameter')
plt.title('raw and importance-sampled parameter variances')

plt.show()    

## learning a kernel


KernelLayer implements a very simple kernel 

$ K(x_n,x_0) = \exp( - (x_n-x_0)^\top A (x_n - x_0))$, 

where $A =BB^\top$ and $B$ is diagonal.

- simple kernel cannot follow $\frac{\tilde{p}(\theta_n)}{p(\theta_n)} > 1$ (cannot correct for importance weights < 1)
- kernels end up learning $\forall n: K(x_n,x_0) = 1$ (via $A \approx 0$)

KernelLayer_offset implements  

$ K(x_n,x_0) = c \exp( - (x_n-x_0)^\top A (x_n - x_0) )$

where $c = \exp(Z)$ is a positive scaling factor

- adds the prefactor $\exp(Z)$ to the kernel
- the prefactor starts shifting (towards negative $Z$), but still $A \approx 1$

We use offset kernels in the following.
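To double-check the claims above without Theano, here is a NumPy mirror of the two layers (a sketch; it assumes the Theano graph computes exactly this). It confirms that `B*eye` builds $\mathrm{diag}(B)$, so $A = \mathrm{diag}(B^2)$, that the plain kernel never exceeds $1$ (so it cannot track inverse weights above $1$), and that the offset kernel reaches $e^{Z}$ at $x_n = x_0$. The function name `kernel_np` is hypothetical.

```python
import numpy as np

def kernel_np(dx, B, Z=None):
    """NumPy sketch of KernelLayer (Z is None) and KernelLayer_offset (Z given).

    dx : (N, d) array of differences x_n - x_0
    B  : (d,) diagonal entries of B, so that A = B B^T = diag(B**2)
    Z  : optional log-prefactor; the offset kernel is exp(Z) * exp(-dx^T A dx)
    """
    quad = np.sum(dx ** 2 * B ** 2, axis=1)  # dx^T A dx for diagonal A
    return np.exp(-quad if Z is None else Z - quad)

rng = np.random.default_rng(1)
d = 10
dx = rng.normal(size=(5, d))
B = rng.normal(scale=0.5, size=d)

# Same construction as in the layers: (B * eye) is diag(B), so D = diag(B) diag(B)^T
eye = np.eye(d)
D = (B * eye).dot((B * eye).T)
quad_full = np.sum(dx.dot(D) * dx, axis=1)

k_plain = kernel_np(dx, B)
k_offset = kernel_np(dx, B, Z=2.0)
print(k_plain)
print(k_offset)
```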

In [None]:
import lasagne
import numpy as np
import theano
import theano.tensor as T

dtype = theano.config.floatX
ndim_x = 10

class KernelLayer(lasagne.layers.Layer):
    """Squared-exponential kernel K = exp(-dx^T A dx) with A = B B^T, B diagonal.
    Expects inputs dx = x_n - x_0, one row per sample."""
    def __init__(self, incoming, B=lasagne.init.Normal(0.01), **kwargs):
        super(KernelLayer, self).__init__(incoming, **kwargs)
        num_inputs = self.input_shape[1]
        self.eye = T.eye(num_inputs)
        # trainable diagonal entries of B
        self.B = self.add_param(B, (num_inputs, ), name='B')

    def get_output_for(self, input, **kwargs):
        # B*eye is diag(B), so D = diag(B) diag(B)^T = diag(B**2) = A
        D = T.dot(self.B*self.eye, self.B*self.eye.T)
        # row-wise quadratic form dx^T A dx
        inner = (input.dot(D)*input).sum(axis=1)
        return T.exp(-inner)

    def get_output_shape_for(self, input_shape):
        return (input_shape[0],)


class KernelLayer_offset(lasagne.layers.Layer):
    """Offset kernel K = exp(Z) * exp(-dx^T A dx); the prefactor exp(Z) lets K exceed 1."""
    def __init__(self, incoming, B=lasagne.init.Normal(0.01), Z=lasagne.init.Normal(0.01), **kwargs):
        super(KernelLayer_offset, self).__init__(incoming, **kwargs)
        num_inputs = self.input_shape[1]
        self.eye = T.eye(num_inputs)
        self.B = self.add_param(B, (num_inputs, ), name='B')
        # trainable log-prefactor Z (shape (1,))
        self.Z = self.add_param(Z, (1,), name='Z')

    def get_output_for(self, input, **kwargs):
        D = T.dot(self.B*self.eye, self.B*self.eye.T)
        inner = (input.dot(D)*input).sum(axis=1)
        # Z enters additively in log space: K = exp(Z - dx^T A dx)
        return T.exp(-inner + self.Z)

    def get_output_shape_for(self, input_shape):
        return (input_shape[0],)
    
input_var = T.fmatrix('inputs')
target_var = T.fvector('targets')


# 1. basic loss

argmin $ \sum_n \left( 1 - \frac{K(x_n, x_0) p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2 $

- final average errors are $\approx 1$, achieved by the trivial solution $\frac{K(x_n, x_0) p(\theta_n)}{\tilde{p}(\theta_n)} \approx 0$. 
- no obvious trend in kernel value relative to importance-weight size. Kernel too simple?
- generally, the kernels focus on countering the largest weights (offset $Z$ negative)
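One way to read the "average error $\approx 1$" observation: for a flat kernel ($A = 0$, so $K = c = e^{Z}$) the basic loss $\frac{1}{N}\sum_n (1 - c\,w_n)^2$ is a quadratic in $c$ with minimizer $c^* = \sum_n w_n / \sum_n w_n^2$ and minimum $1 - \mathrm{ESS}/N$, where $\mathrm{ESS} = (\sum_n w_n)^2 / \sum_n w_n^2$ is the effective sample size. With heavy-tailed weights the ESS is small, so even the best flat kernel sits near loss $1$ with a negative $Z$. The sketch below checks this on hypothetical log-normal weights with mean $\approx 1$, not on the GLM weights; the helper name is hypothetical.

```python
import numpy as np

def flat_kernel_basic_optimum(w):
    """Best constant kernel K = c = exp(Z) (i.e. A = 0) under the basic loss
    mean_n (1 - c * w_n)^2, a quadratic in c minimized at c = sum(w) / sum(w**2)."""
    c = np.sum(w) / np.sum(w ** 2)
    return c, np.mean((1.0 - c * w) ** 2)

# Hypothetical heavy-tailed importance weights: log-normal with E[w] = 1.
rng = np.random.default_rng(2)
w = np.exp(rng.normal(loc=-2.0, scale=2.0, size=5000))

c, loss = flat_kernel_basic_optimum(w)
ess = np.sum(w) ** 2 / np.sum(w ** 2)  # effective sample size of the weights
print('Z* = log c* =', np.log(c))
print('optimal mean loss:', loss, ' 1 - ESS/N:', 1.0 - ess / w.size)
```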


In [None]:

for r in range(1, len(tds)): # pick best fit
    
    print('round #' + str(r))
    print('')
    
    
    # x - x0
    dx = (tds[r][1] - obs_stats).astype(np.float32)
    # importance weights p(theta) / p_tilde(theta), unnormalized
    params = tds[r][0].astype(np.float32)

    if r > 0:
        p_prior = g.prior.eval(params,log=False)
        p_proposal = posteriors[r-1].eval(params,log=False)
        iws = ( p_prior / p_proposal ).astype(np.float32)
    else:
        iws = np.ones(params.shape[0],dtype=np.float32)
    
    #iws = (tds[r][2].reshape(-1)).astype(np.float32) # weights returned by SNPE are renormalized by default now
    
    # inverse weights (inverses of zero weights are capped at 1e20)
    iiws = np.minimum((1./iws).astype(np.float32), 1e20*np.ones_like(iws))

    l_in = lasagne.layers.InputLayer(shape=(None, ndim_x),input_var=input_var)
    l_dot = KernelLayer_offset(l_in, name='kernel_layer')
    prediction = lasagne.layers.get_output(l_dot)
    loss = (1 - prediction*target_var)**2
    loss = loss.mean()
    params = lasagne.layers.get_all_params(l_dot, trainable=True)
    updates = lasagne.updates.adam(
                loss, params, learning_rate=0.001)
    train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates,
                                on_unused_input='ignore')
    
    print('initial kernel B (A = diag(B**2)):', l_dot.B.get_value())
    train_errs = np.zeros(20000)
    for i in range(train_errs.size):
        train_errs[i], pred = train_fn(dx, iws)
    print('learned kernel B (A = diag(B**2)):', l_dot.B.get_value())

    idx = np.argsort(iws)
    plt.plot(train_errs)
    plt.title('training error')
    plt.show()

    plt.figure(figsize=(16,6))
    krnl  = pred[idx]
    ikrnl = 1./pred[idx]
    plt.subplot(1,3,1)
    plt.semilogy(iiws[idx])
    plt.semilogy(krnl)
    plt.legend(['inv. weights', 'kernel'])
    plt.title('kernel should track inverse weights')

    plt.subplot(1,3,2)
    plt.semilogy(iws[idx])
    plt.semilogy(ikrnl)
    plt.legend(['weights', 'inv. kernel'])
    plt.xlabel('n (sorted by importance weight value)')
    plt.title('inv. kernel should track weights')

    plt.subplot(1,3,3)
    plt.semilogy(iws[idx])
    plt.semilogy(krnl*iws[idx])
    plt.legend(['weights', 'kernel*weights'])
    plt.title('kernel*weights should be flatter than weights')
    plt.show()
    
    print('loss without / with learned kernels', np.mean((iws[idx]-1)**2), np.mean((pred[idx]*iws[idx]-1)**2))
    print('mean and std of raw importance weights', np.mean(iws), np.std(iws))
    print('mean and std of kernel-weighted importance weights', np.mean(iws*pred), np.std(iws*pred))   
    
    
    w = (iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x = np.sum( w * (tds[r][1]-mu_x)**2, axis=0)
    mu_x, sig2_x

    w = (pred*iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x_K =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x_K = np.sum( w * (tds[r][1]-mu_x_K)**2, axis=0)
    mu_x_K, sig2_x_K

    plt.figure(figsize=(9,6))
    plt.subplot(1,2,1)
    plt.plot(np.mean(x_prior,axis=0), 'co--')
    plt.plot(tds[r][1].mean(axis=0), 'ko-')
    plt.plot(mu_x_K.flatten(), 'bo-')
    plt.plot(mu_x.flatten(), 'ro-')
    plt.plot(obs_stats.flatten(), 'go--')
    plt.legend(['E[x]', 'raw E[x]', 'kernel-weighted E_K[x]','pure IS E_IS[x]', 'real x_0'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data means')

    plt.subplot(1,2,2)
    plt.plot(np.var(x_prior,axis=0), 'co--')
    plt.plot(tds[r][1].var(axis=0), 'ko-')
    plt.plot(sig2_x_K.flatten(), 'bo-')
    plt.plot(sig2_x.flatten(), 'ro-')
    plt.legend(['Var[x]', 'raw Var[x]', 'kernel-weighted Var_K[x]','pure IS Var_IS[x]'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data variances')

    plt.show()    

# coordinate-wise loss landscape 

In [None]:
import seaborn

B = l_dot.B.get_value()
A_base=  np.diag(B**2)

Z = l_dot.Z.get_value()

def A_(lambd=None, i=0):
    A = A_base.copy()
    if not lambd is None:
        A[i,i]= lambd
    return A
def K(dx, A, Z):
    return  np.exp(- np.sum(dx.dot(A)* dx,axis=1) + Z )

def loss_map(lambd, i):
    return np.mean( (1 - iws*K(dx, A_(lambd,i), Z))**2 )


lambds = np.exp(np.log(10) * np.linspace(-3, 3, 1000))
l      = np.zeros_like(lambds)              
plt.figure(figsize=(8,16))
for i in range(10):    
    plt.subplot(5,2,i+1)
    for j in range(len(lambds)):
        l[j] = loss_map(lambds[j], i)
    plt.semilogx(lambds, l)
    if A_base[i,i]>np.min(lambds):
        plt.semilogx([A_base[i,i], A_base[i,i]], [1.,1.], 'r*', markersize=10)        
    if i == 0:
        plt.legend(['loss', 'found optimum'], loc=4)
    if i > 7:
        plt.xlabel('1/sig^2')
    if np.mod(i,2) ==0:
        plt.ylabel('loss')

plt.savefig('kernel_learning_GLM_example_basicLoss_offsetKernel_gaussianProposals_coordinatewiseErrorFunction.pdf')        
plt.show()

# compare the NumPy loss at the learned kernel with the last training loss
print(loss_map(None, 0) - train_errs[-1])

## spherical kernel loss landscape

In [None]:
import seaborn

def A_(lambd):
    A = lambd * np.eye(10)    
    return A
def K(dx, A, Z):
    return  np.exp(- np.sum(dx.dot(A)* dx,axis=1) + Z )

def loss_map(lambd, Z):
    return np.mean( (1 - iws*K(dx, A_(lambd), Z))**2 )


lambds = np.exp(np.log(10) * np.linspace(-3, 3, 100))
Zs     = [l_dot.Z.get_value()] #np.linspace(-10, 10, 20)
l      = np.zeros((len(lambds), len(Zs)))            

plt.figure(figsize=(4,4))
for j in range(len(lambds)):
    for z in range(len(Zs)):
        l[j, z] = loss_map(lambds[j], Zs[z])
plt.loglog(lambds, l)
plt.legend(['loss'], loc=4)
plt.xlabel('1/sig^2')
plt.ylabel('loss')

plt.show()


# 2. inverse-kernel loss

argmin $ \sum_n \left( \frac{1}{K(x_n,x_0)} - \frac{p(\theta_n)}{\tilde{p}(\theta_n)} \right)^2 $

- arguably we care more about too-large than about too-small weights: this simple squared loss is dominated by large weights, meaning the kernel (where possible) will try to counter those.


- starts doing something constructive
- the average kernel-weighted importance weight becomes closer to $1$ than the average raw importance weight.
- sometimes also learns a flat kernel. Local optimum?
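For comparison, the same flat-kernel calculation for the inverse-kernel loss: minimizing $\frac{1}{N}\sum_n (1/c - w_n)^2$ over $c$ gives $1/c^* = \bar w$ (the mean weight), so the kernel-weighted weights average exactly $1$, and the leftover loss is the variance of the weights, which for heavy-tailed weights is dominated by the largest ones. Any improvement over a flat kernel therefore has to come from making $K$ small where $w_n$ is large. A sketch on hypothetical log-normal weights with mean $\approx 1$ (helper name hypothetical):

```python
import numpy as np

def flat_kernel_inverse_optimum(w):
    """Best constant kernel K = c under mean_n (1/c - w_n)^2: the optimum is 1/c = mean(w),
    and the remaining loss equals the (population) variance of the weights."""
    c = 1.0 / np.mean(w)
    return c, np.mean((1.0 / c - w) ** 2)

rng = np.random.default_rng(3)
w = np.exp(rng.normal(loc=-2.0, scale=2.0, size=5000))  # hypothetical heavy-tailed weights

c, loss = flat_kernel_inverse_optimum(w)
print('Z* =', np.log(c), ' loss:', loss, ' Var[w]:', np.var(w))
print('mean of kernel-weighted weights:', np.mean(c * w))
```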

In [None]:

for r in range(1, len(tds)):
    
    print('round #' + str(r))
    print('')
    
    
    # x - x0
    dx = (tds[r][1] - obs_stats).astype(np.float32)
    # importance weights p(theta) / p_tilde(theta), unnormalized
    
    params = tds[r][0].astype(np.float32)

    if r > 0:
        p_prior = g.prior.eval(params,log=False)
        p_proposal = posteriors[r-1].eval(params,log=False)
        iws = ( p_prior / p_proposal ).astype(np.float32)
    else:
        iws = np.ones(params.shape[0],dtype=np.float32)
    
    #iws = (tds[r][2].reshape(-1)).astype(np.float32) # weights returned by SNPE are renormalized by default now
    
    # inverse weights (inverses of zero weights are capped at 1e20)
    iiws = np.minimum((1./iws).astype(np.float32), 1e20*np.ones_like(iws))

    l_in = lasagne.layers.InputLayer(shape=(None, ndim_x),input_var=input_var)
    l_dot = KernelLayer_offset(l_in, name='kernel_layer')
    prediction = lasagne.layers.get_output(l_dot)
    loss = (T.inv(prediction)-target_var)**2
    loss = loss.mean()
    params = lasagne.layers.get_all_params(l_dot, trainable=True)
    updates = lasagne.updates.adam(
                loss, params, learning_rate=0.001)
    train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates,
                                on_unused_input='ignore')
    
    print('initial kernel B (A = diag(B**2)):', l_dot.B.get_value())
    train_errs = np.zeros(20000)
    for i in range(train_errs.size):
        train_errs[i], pred = train_fn(dx, iws)
    print('learned kernel B (A = diag(B**2)):', l_dot.B.get_value())

    idx = np.argsort(iws)
    plt.plot(train_errs)
    plt.title('training error')
    plt.show()

    plt.figure(figsize=(16,6))
    krnl  = pred[idx]
    ikrnl = 1./pred[idx]
    plt.subplot(1,3,1)
    plt.semilogy(iiws[idx])
    plt.semilogy(krnl)
    plt.legend(['inv. weights', 'kernel'])
    plt.title('kernel should track inverse weights')

    plt.subplot(1,3,2)
    plt.semilogy(iws[idx])
    plt.semilogy(ikrnl)
    plt.legend(['weights', 'inv. kernel'])
    plt.xlabel('n (sorted by importance weight value)')
    plt.title('inv. kernel should track weights')

    plt.subplot(1,3,3)
    plt.semilogy(iws[idx])
    plt.semilogy(krnl*iws[idx])
    plt.legend(['weights', 'kernel*weights'])
    plt.title('kernel*weights should be flatter than weights')
    plt.show()
    
    print('loss without/with learned kernels', np.mean((iws[idx]-1)**2), np.mean((iws[idx]-1/pred[idx])**2))
    print('mean and std of raw importance weights', np.mean(iws), np.std(iws))    
    print('mean and std of kernel-weighted importance weights', np.mean(iws*pred), np.std(iws*pred))
    
    
    w = (iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x = np.sum( w * (tds[r][1]-mu_x)**2, axis=0)
    mu_x, sig2_x

    w = (pred*iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x_K =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x_K = np.sum( w * (tds[r][1]-mu_x_K)**2, axis=0)
    mu_x_K, sig2_x_K

    plt.figure(figsize=(9,6))
    plt.subplot(1,2,1)
    plt.plot(tds[r][1].mean(axis=0), 'ko-')
    plt.plot(mu_x_K.flatten(), 'bo-')
    plt.plot(mu_x.flatten(), 'ro-')
    plt.plot(obs_stats.flatten(), 'go--')
    plt.legend(['raw E[x]', 'kernel-weighted E_K[x]','pure IS E_IS[x]', 'real x_0'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data means')

    plt.subplot(1,2,2)
    plt.plot(tds[r][1].var(axis=0), 'ko-')
    plt.plot(sig2_x_K.flatten(), 'bo-')
    plt.plot(sig2_x.flatten(), 'ro-')
    plt.legend(['raw Var[x]', 'kernel-weighted Var_K[x]','pure IS Var_IS[x]'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data variances')

    plt.show()    

# 3. log-space loss

argmin $ \sum_n \left( \log(\frac{1}{K(x_n,x_0)}) - \log(\frac{p(\theta_n)}{\tilde{p}(\theta_n)}) \right)^2 = \sum_n  \left( \log(1) - \log(\frac{p(\theta_n)K(x_n,x_0)}{\tilde{p}(\theta_n)}) \right)^2 = \sum_n \left( \log(\frac{p(\theta_n)K(x_n,x_0)}{\tilde{p}(\theta_n)}) \right)^2$

- appears to suffer from most of the weights being smaller than $1$ (negative log-weights): the fit draws the geometric mean of the kernel-weighted weights towards $1$, making the few large weights in the dataset gigantic. Would definitely have to renormalize.
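Because $\log K(x_n,x_0) = Z - \sum_i a_i (x_{n,i} - x_{0,i})^2$ with $a_i = B_i^2$ is linear in $(Z, a)$, the log-space loss is an ordinary least-squares problem and can be solved in closed form. The sketch below swaps in `numpy.linalg.lstsq` for the Adam updates used in the notebook and, as a simplifying assumption, does not enforce $a_i \ge 0$. For a flat kernel the solution is $Z^* = -\overline{\log w}$, which sets the geometric mean of $K w_n$ to $1$; by the AM-GM inequality the arithmetic mean is then at least $1$, so with mostly sub-$1$ weights the few large ones get inflated. The data and the helper `fit_log_space_kernel` are hypothetical.

```python
import numpy as np

def fit_log_space_kernel(dx, w, flat=False):
    """Least-squares minimizer of mean_n (log K(x_n, x_0) + log w_n)^2 for the offset kernel
    log K = Z - sum_i a_i dx_i^2, where a_i plays the role of B_i^2.

    Assumption: a is unconstrained here, so entries may come out negative."""
    n = dx.shape[0]
    design = np.ones((n, 1)) if flat else np.column_stack([np.ones(n), -dx ** 2])
    coef, *_ = np.linalg.lstsq(design, -np.log(w), rcond=None)
    return coef[0], coef[1:]

rng = np.random.default_rng(4)
dx = rng.normal(size=(500, 3))

# 1) synthetic check: if w_n is exactly 1 / K_true(x_n), the fit recovers K_true
Z_true, a_true = -1.0, np.array([0.5, 0.1, 2.0])
w_exact = np.exp(-(Z_true - (dx ** 2).dot(a_true)))
Z_hat, a_hat = fit_log_space_kernel(dx, w_exact)

# 2) flat kernel on hypothetical log-normal weights, most of them below 1
w = np.exp(rng.normal(loc=-1.0, scale=2.0, size=500))
Z_flat, _ = fit_log_space_kernel(dx, w, flat=True)
kw = np.exp(Z_flat) * w
print('recovered Z, a:', Z_hat, a_hat)
print('geometric mean of K*w:', np.exp(np.mean(np.log(kw))))
print('arithmetic mean of K*w:', np.mean(kw))
```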

In [None]:

for r in range(1, len(tds)):
    
    print('round #' + str(r))
    print('')
    
    
    # x - x0
    dx = (tds[r][1] - obs_stats).astype(np.float32)
    # importance weights p(theta) / p_tilde(theta), unnormalized
    params = tds[r][0].astype(np.float32)

    if r > 0:
        p_prior = g.prior.eval(params,log=False)
        p_proposal = posteriors[r-1].eval(params,log=False)
        iws = ( p_prior / p_proposal ).astype(np.float32)
    else:
        iws = np.ones(params.shape[0],dtype=np.float32)
    
        
    # inverse weights (inverses of zero weights are capped at 1e20)
    iiws = np.minimum((1./iws).astype(np.float32), 1e20*np.ones_like(iws))

    l_in = lasagne.layers.InputLayer(shape=(None, ndim_x),input_var=input_var)
    l_dot = KernelLayer_offset(l_in, name='kernel_layer')
    prediction = lasagne.layers.get_output(l_dot)
    loss = (T.log(prediction)+T.log(target_var))**2
    loss = loss.mean()
    params = lasagne.layers.get_all_params(l_dot, trainable=True)
    updates = lasagne.updates.adam(
                loss, params, learning_rate=0.001)
    train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates,
                                on_unused_input='ignore')
    
    
    print('initial kernel B (A = diag(B**2)):', l_dot.B.get_value())
    train_errs = np.zeros(20000)
    for i in range(train_errs.size):
        train_errs[i], pred = train_fn(dx, iws)
    print('learned kernel B (A = diag(B**2)):', l_dot.B.get_value())

    idx = np.argsort(iws)
    plt.plot(train_errs)
    plt.title('training error')
    plt.show()

    plt.figure(figsize=(16,6))
    krnl  = pred[idx]
    ikrnl = 1./pred[idx]
    plt.subplot(1,3,1)
    plt.semilogy(iiws[idx])
    plt.semilogy(krnl)
    plt.legend(['inv. weights', 'kernel'])
    plt.title('kernel should track inverse weights')

    plt.subplot(1,3,2)
    plt.semilogy(iws[idx])
    plt.semilogy(ikrnl)
    plt.legend(['weights', 'inv. kernel'])
    plt.xlabel('n (sorted by importance weight value)')
    plt.title('inv. kernel should track weights')

    plt.subplot(1,3,3)
    plt.semilogy(iws[idx])
    plt.semilogy(krnl*iws[idx])
    plt.legend(['weights', 'kernel*weights'])
    plt.title('kernel*weights should be flatter than weights')
    plt.show()
    
    print('loss with/without learned kernels', np.mean(np.log(pred[idx]*iws[idx])**2), np.mean(np.log(iws[idx])**2))
    print('mean and std of raw importance weights', np.mean(iws), np.std(iws))    
    print('mean and std of kernel-weighted importance weights', np.mean(iws*pred), np.std(iws*pred))
    
    
    w = (iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x = np.sum( w * (tds[r][1]-mu_x)**2, axis=0)
    mu_x, sig2_x

    w = (pred*iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x_K =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x_K = np.sum( w * (tds[r][1]-mu_x_K)**2, axis=0)
    mu_x_K, sig2_x_K

    plt.figure(figsize=(9,6))
    plt.subplot(1,2,1)
    plt.plot(tds[r][1].mean(axis=0), 'ko-')
    plt.plot(mu_x_K.flatten(), 'bo-')
    plt.plot(mu_x.flatten(), 'ro-')
    plt.plot(obs_stats.flatten(), 'go--')
    plt.legend(['raw E[x]', 'kernel-weighted E_K[x]','pure IS E_IS[x]', 'real x_0'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data means')

    plt.subplot(1,2,2)
    plt.plot(tds[r][1].var(axis=0), 'ko-')
    plt.plot(sig2_x_K.flatten(), 'bo-')
    plt.plot(sig2_x.flatten(), 'ro-')
    plt.legend(['raw Var[x]', 'kernel-weighted Var_K[x]','pure IS Var_IS[x]'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data variances')

    plt.show()    

# 4. inverse weights-dominated loss

argmin $ \sum_n \left( K(x_n,x_0) - \frac{\tilde{p}(\theta_n)}{p(\theta_n)} \right)^2 $, i.e. squared error on inverse importance weights

- the squared error is dominated by small weights (large $\frac{\tilde{p}(\theta_n)}{p(\theta_n)}$), leading to large positive offsets $Z$ (yielding $K(x_n,x_0)$ values around $10^6$)
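The flat-kernel version of this loss, $\frac{1}{N}\sum_n (c - 1/w_n)^2$, is minimized by $c^* = \frac{1}{N}\sum_n 1/w_n$, the mean inverse weight. That mean is driven by the smallest weights, which pushes $Z = \log c^*$ to large positive values. The sketch below, on hypothetical log-normal weights with mean $\approx 1$, measures how much of $c^*$ comes from the 1% smallest weights.

```python
import numpy as np

rng = np.random.default_rng(5)
w = np.exp(rng.normal(loc=-2.0, scale=2.0, size=5000))  # hypothetical heavy-tailed weights
iw = 1.0 / w                                             # inverse weights p_tilde / p

# Best constant kernel K = c under mean_n (c - 1/w_n)^2 is the mean inverse weight.
c = np.mean(iw)

# Share of c contributed by the 1% smallest weights (i.e. the largest inverse weights).
top = np.sort(iw)[-(w.size // 100):]
share = np.sum(top) / np.sum(iw)
print('c* =', c, ' Z* = log c* =', np.log(c))
print('share of c* from the 1% smallest weights:', share)
```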

In [None]:

for r in range(1, len(tds)):
    
    print('round #' + str(r))
    print('')
    
    
    # x - x0
    dx = (tds[r][1] - obs_stats).astype(np.float32)
    # importance weights p(theta) / p_tilde(theta), unnormalized
    
    params = tds[r][0].astype(np.float32)

    if r > 0:
        p_prior = g.prior.eval(params,log=False)
        p_proposal = posteriors[r-1].eval(params,log=False)
        iws = ( p_prior / p_proposal ).astype(np.float32)
    else:
        iws = np.ones(params.shape[0],dtype=np.float32)
    
    #iws = (tds[r][2].reshape(-1)).astype(np.float32) # weights returned by SNPE are renormalized by default now
    
    # inverse weights (inverses of zero weights are capped at 1e20)
    iiws = np.minimum((1./iws).astype(np.float32), 1e20*np.ones_like(iws))

    l_in = lasagne.layers.InputLayer(shape=(None, ndim_x),input_var=input_var)
    l_dot = KernelLayer_offset(l_in, name='kernel_layer')
    prediction = lasagne.layers.get_output(l_dot)
    loss = (prediction-target_var)**2
    loss = loss.mean()
    params = lasagne.layers.get_all_params(l_dot, trainable=True)
    updates = lasagne.updates.adam(
                loss, params, learning_rate=0.001)
    train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates,
                                on_unused_input='ignore')
    
    print('initial kernel B (A = diag(B**2)):', l_dot.B.get_value())
    train_errs = np.zeros(20000)
    for i in range(train_errs.size):
        train_errs[i], pred = train_fn(dx, iiws)
    print('learned kernel B (A = diag(B**2)):', l_dot.B.get_value())

    idx = np.argsort(iws)
    

    plt.figure(figsize=(16,6))
    plt.subplot(1,3,1)
    
    plt.plot(train_errs)
    plt.title('training error')
    plt.xlabel('gradient step')
    
    #plt.semilogy(iiws[idx])
    #plt.semilogy(pred[idx])
    #plt.legend(['inv. weights', 'kernel'])
    #plt.title('kernel should track inverse weights')

    plt.subplot(1,3,2)
    plt.semilogy(1./pred[idx])
    plt.semilogy(iws[idx], linewidth=2.5)
    plt.legend(['inv. kernel', 'weights'])
    plt.xlabel('n (sorted by importance weight value)')
    plt.title('inv. kernel should track weights')

    plt.subplot(1,3,3)
    plt.semilogy(pred[idx]*iws[idx])
    plt.semilogy(iws[idx], linewidth=2.5)
    plt.legend(['kernel*weights', 'weights'])
    plt.title('kernel*weights should be flatter than weights')
    plt.xlabel('n (sorted by importance weight value)')
    
    plt.show()
    
    print('loss without/with learned kernels', np.mean((iws[idx]-1)**2), np.mean((iws[idx]-1/pred[idx])**2))
    print('mean and std of raw importance weights', np.mean(iws), np.std(iws))    
    print('mean and std of kernel-weighted importance weights', np.mean(iws*pred), np.std(iws*pred))
    
    
    
    w = (iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x = np.sum( w * (tds[r][1]-mu_x)**2, axis=0)
    mu_x, sig2_x

    w = (pred*iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x_K =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x_K = np.sum( w * (tds[r][1]-mu_x_K)**2, axis=0)
    mu_x_K, sig2_x_K

    plt.figure(figsize=(9,6))
    plt.subplot(1,2,1)
    plt.plot(tds[r][1].mean(axis=0), 'ko-')
    plt.plot(mu_x_K.flatten(), 'bo-')
    plt.plot(mu_x.flatten(), 'ro-')
    plt.plot(obs_stats.flatten(), 'go--')
    plt.legend(['raw E[x]', 'kernel-weighted E_K[x]','pure IS E_IS[x]', 'real x_0'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data means')

    plt.subplot(1,2,2)
    plt.plot(tds[r][1].var(axis=0), 'ko-')
    plt.plot(sig2_x_K.flatten(), 'bo-')
    plt.plot(sig2_x.flatten(), 'ro-')
    plt.legend(['raw Var[x]', 'kernel-weighted Var_K[x]','pure IS Var_IS[x]'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data variances')

    plt.show()

# pick one for plotting

In [None]:

for r in range(1, 2):
    
    print('round #' + str(r))
    print('')
    
    
    # x - x0
    dx = (tds[r][1] - obs_stats).astype(np.float32)
    # importance weights p(theta) / p_tilde(theta), unnormalized
    
    params = tds[r][0].astype(np.float32)

    if r > 0:
        p_prior = g.prior.eval(params,log=False)
        p_proposal = posteriors[r-1].eval(params,log=False)
        iws = ( p_prior / p_proposal ).astype(np.float32)
    else:
        iws = np.ones(params.shape[0],dtype=np.float32)
    
    #iws = (tds[r][2].reshape(-1)).astype(np.float32) # weights returned by SNPE are renormalized by default now
    
    # inverse weights (inverses of zero weights are capped at 1e20)
    iiws = np.minimum((1./iws).astype(np.float32), 1e20*np.ones_like(iws))

    l_in = lasagne.layers.InputLayer(shape=(None, ndim_x),input_var=input_var)
    l_dot = KernelLayer_offset(l_in, name='kernel_layer')
    prediction = lasagne.layers.get_output(l_dot)
    loss = (prediction-target_var)**2
    loss = loss.mean()
    params = lasagne.layers.get_all_params(l_dot, trainable=True)
    updates = lasagne.updates.adam(
                loss, params, learning_rate=0.001)
    train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates,
                                on_unused_input='ignore')
    
    print('initial kernel B (A = diag(B**2)):', l_dot.B.get_value())
    train_errs = np.zeros(20000)
    for i in range(train_errs.size):
        train_errs[i], pred = train_fn(dx, iiws)
    print('learned kernel B (A = diag(B**2)):', l_dot.B.get_value())

    idx = np.argsort(iws)
    

    plt.figure(figsize=(16,6))
    plt.subplot(1,3,1)
    
    plt.plot(train_errs)
    plt.title('training error')
    plt.xlabel('gradient step')
    
    #plt.semilogy(iiws[idx])
    #plt.semilogy(pred[idx])
    #plt.legend(['inv. weights', 'kernel'])
    #plt.title('kernel should track inverse weights')

    plt.subplot(1,3,2)
    plt.semilogy(1./pred[idx])
    plt.semilogy(iws[idx], linewidth=2.5)
    plt.legend(['inv. kernel', 'weights'])
    plt.xlabel('n (sorted by importance weight value)')
    plt.title('inv. kernel should track weights')

    plt.subplot(1,3,3)
    plt.semilogy(pred[idx]*iws[idx])
    plt.semilogy(iws[idx], linewidth=2.5)
    plt.legend(['kernel*weights', 'weights'])
    plt.title('kernel*weights should be flatter than weights')
    plt.xlabel('n (sorted by importance weight value)')
    
    plt.savefig('kernel_learning_GLM_example_r1_inverseWeightedLoss_gaussianProposals.pdf')
    plt.show()
    
    print('loss without/with learned kernels', np.mean((iws[idx]-1)**2), np.mean((iws[idx]-1/pred[idx])**2))
    print('mean and std of raw importance weights', np.mean(iws), np.std(iws))    
    print('mean and std of kernel-weighted importance weights', np.mean(iws*pred), np.std(iws*pred))
    
    
    
    w = (iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x = np.sum( w * (tds[r][1]-mu_x)**2, axis=0)
    mu_x, sig2_x

    w = (pred*iws).reshape(-1,1)
    w = w / np.sum(w)

    mu_x_K =   np.sum( w * tds[r][1],           axis=0).reshape(1,-1)
    sig2_x_K = np.sum( w * (tds[r][1]-mu_x_K)**2, axis=0)
    mu_x_K, sig2_x_K

    plt.figure(figsize=(9,6))
    plt.subplot(1,2,1)
    plt.plot(tds[r][1].mean(axis=0), 'ko-')
    plt.plot(mu_x_K.flatten(), 'bo-')
    plt.plot(mu_x.flatten(), 'ro-')
    plt.plot(obs_stats.flatten(), 'go--')
    plt.legend(['raw E[x]', 'kernel-weighted E_K[x]','pure IS E_IS[x]', 'real x_0'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data means')

    plt.subplot(1,2,2)
    plt.plot(tds[r][1].var(axis=0), 'ko-')
    plt.plot(sig2_x_K.flatten(), 'bo-')
    plt.plot(sig2_x.flatten(), 'ro-')
    plt.legend(['raw Var[x]', 'kernel-weighted Var_K[x]','pure IS Var_IS[x]'])
    plt.xlabel('# summary statistic')
    plt.title('raw and importance sampled data variances')

    plt.show()

# sanity check: draw new data not returned from SNPE logs
- we might be overlooking something in the SNPE code and how it handles (Z-scores, shifts, stores, ...) its logged datasets.

In [None]:
g = dg.Default(model=m, prior=p, summary=s)
g.proposal = posteriors[0]

th, x = g.gen(5000)

# x - x0
dx = (x - obs_stats).astype(np.float32)
# importance weights p(theta) / p_tilde(theta), unnormalized
params = th.astype(np.float32)

p_prior = g.prior.eval(params,log=False)
p_proposal = g.proposal.eval(params,log=False)
iws = ( p_prior / p_proposal ).astype(np.float32)

#iws = (tds[r][2].reshape(-1)).astype(np.float32) # weights returned by SNPE are renormalized by default now

# inverse weights (inverses of zero weights are capped at 1e20)
iiws = np.minimum((1./iws).astype(np.float32), 1e20*np.ones_like(iws))

l_in = lasagne.layers.InputLayer(shape=(None, ndim_x),input_var=input_var)
l_dot = KernelLayer_offset(l_in, name='kernel_layer')
prediction = lasagne.layers.get_output(l_dot)
loss = (1 - prediction*target_var)**2
loss = loss.mean()
params = lasagne.layers.get_all_params(l_dot, trainable=True)
updates = lasagne.updates.adam(
            loss, params, learning_rate=0.001)
train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates,
                            on_unused_input='ignore')

print('initial kernel B (A = diag(B**2)):', l_dot.B.get_value())
train_errs = np.zeros(20000)
for i in range(train_errs.size):
    train_errs[i], pred = train_fn(dx, iws)
print('learned kernel B (A = diag(B**2)):', l_dot.B.get_value())

idx = np.argsort(iws)
plt.plot(train_errs)
plt.title('training error')
plt.show()

plt.figure(figsize=(16,6))
krnl  = pred[idx]
ikrnl = 1./pred[idx]
plt.subplot(1,3,1)
plt.semilogy(iiws[idx])
plt.semilogy(krnl)
plt.legend(['inv. weights', 'kernel'])
plt.title('kernel should track inverse weights')

plt.subplot(1,3,2)
plt.semilogy(iws[idx])
plt.semilogy(ikrnl)
plt.legend(['weights', 'inv. kernel'])
plt.xlabel('n (sorted by importance weight value)')
plt.title('inv. kernel should track weights')

plt.subplot(1,3,3)
plt.semilogy(iws[idx])
plt.semilogy(krnl*iws[idx])
plt.legend(['weights', 'kernel*weights'])
plt.title('kernel*weights should be flatter than weights')
plt.show()

print('loss without / with learned kernels', np.mean((iws[idx]-1)**2), np.mean((pred[idx]*iws[idx]-1)**2))
print('mean and std of raw importance weights', np.mean(iws), np.std(iws))
print('mean and std of kernel-weighted importance weights', np.mean(iws*pred), np.std(iws*pred))   


w = (iws).reshape(-1,1)
w = w / np.sum(w)

mu_x =   np.sum( w *  x,           axis=0).reshape(1,-1)
sig2_x = np.sum( w * (x-mu_x)**2, axis=0)
mu_x, sig2_x

w = (pred*iws).reshape(-1,1)
w = w / np.sum(w)

mu_x_K =   np.sum( w *  x,           axis=0).reshape(1,-1)
sig2_x_K = np.sum( w * (x-mu_x_K)**2, axis=0)
mu_x_K, sig2_x_K

plt.figure(figsize=(9,6))
plt.subplot(1,2,1)
plt.plot(np.mean(x_prior,axis=0), 'co--')
plt.plot(x.mean(axis=0), 'ko-')
plt.plot(mu_x_K.flatten(), 'bo-')
plt.plot(mu_x.flatten(), 'ro-')
plt.plot(obs_stats.flatten(), 'go--')
plt.legend(['E[x]', 'raw E[x]', 'kernel-weighted E_K[x]','pure IS E_IS[x]', 'real x_0'])
plt.xlabel('# summary statistic')
plt.title('raw and importance sampled data means')

plt.subplot(1,2,2)
plt.plot(np.var(x_prior,axis=0), 'co--')
plt.plot(x.var(axis=0), 'ko-')
plt.plot(sig2_x_K.flatten(), 'bo-')
plt.plot(sig2_x.flatten(), 'ro-')
plt.legend(['Var[x]', 'raw Var[x]', 'kernel-weighted Var_K[x]','pure IS Var_IS[x]'])
plt.xlabel('# summary statistic')
plt.title('raw and importance sampled data variances')

plt.show()    


import seaborn

B = l_dot.B.get_value()
A_base=  np.diag(B**2)

Z = l_dot.Z.get_value()

def A_(lambd=None, i=0):
    A = A_base.copy()
    if not lambd is None:
        A[i,i]= lambd
    return A
def K(dx, A, Z):
    return  np.exp(- np.sum(dx.dot(A)* dx,axis=1) + Z )

def loss_map(lambd, i):
    return np.mean( (1 - iws*K(dx, A_(lambd,i), Z))**2 )


lambds = np.exp(np.log(10) * np.linspace(-3, 3, 1000))
l      = np.zeros_like(lambds)              
plt.figure(figsize=(8,16))
for i in range(10):    
    plt.subplot(5,2,i+1)
    for j in range(len(lambds)):
        l[j] = loss_map(lambds[j], i)
    plt.semilogx(lambds, l)
    if A_base[i,i]>np.min(lambds):
        plt.semilogx([A_base[i,i], A_base[i,i]], [1.,1.], 'r*', markersize=10)        
    if i == 0:
        plt.legend(['loss', 'found optimum'], loc=4)
    if i > 7:
        plt.xlabel('1/sig^2')
    if np.mod(i,2) ==0:
        plt.ylabel('loss')

#plt.savefig('kernel_learning_GLM_example_basicLoss_offsetKernel_gaussianProposals_coordinatewiseErrorFunction.pdf')        
plt.show()

# compare the NumPy loss at the learned kernel with the last training loss
print(loss_map(None, 0) - train_errs[-1])