# Inferring the location and contrast of a blob

In [None]:
%matplotlib inline

import numpy as np
import numpy.random as nr
import scipy.stats as stats

from scipy.special import gammaln
from matplotlib import pyplot as plt

import delfi.generator as dg
import delfi.distribution as dd
from delfi.utils.viz import plot_pdf
import delfi.inference as infer

M = 32  # edge length of the square simulated image, in pixels
seed = 42  # random seed used by this notebook


In [None]:
def expit17(x):
    """Squash ``x`` from the real line onto (-17, 17) with a scaled logistic."""
    denom = 1. + np.exp(-x)
    return 34. / denom - 17

def expit5(x):
    """Squash ``x`` from the real line onto (0.2, 5.05) with a scaled logistic."""
    denom = 1. + np.exp(-x)
    return 4.85 / denom + 0.2
    
def logit17(x):
    """Inverse of ``expit17``: map a value in (-17, 17) back onto the real line."""
    u = (x + 17.) / 34.
    return np.log(u / (1. - u))
    
def logit5(x):
    """Inverse of ``expit5``: map a value in (0.2, 5.05) back onto the real line."""
    u = (x - 0.2) / 4.85
    return np.log(u / (1. - u))

In [None]:
from delfi.simulator.BaseSimulator import BaseSimulator
from delfi.summarystats.Identity import Identity


class Blob(BaseSimulator):
    """Simulator for a noisy M x M image containing a single blob.

    Each pixel has success probability
        p = 0.1 + 0.8 * exp(-0.5 * (r / sigma**2) ** gamma),
    where r is the squared distance from the pixel to the blob centre (xo, yo).
    The observed pixel value is a Binomial(N, p) count divided by N.

    Parameters given to ``gen_single`` live in a transformed, unbounded space:
    xo and yo are mapped onto (-17, 17) by ``expit17`` and gamma onto
    (0.2, 5.05) by ``expit5``. sigma is either fixed in the constructor or
    supplied, untransformed, as a 4th parameter.

    NOTE(review): ``BaseSimulator.__init__`` is never called, so whatever state
    the base class sets up is missing, and random draws use the global
    ``numpy.random`` state rather than a simulator-specific seed -- confirm
    this is intended.
    """

    def __init__(self, M, N, sigma=None):
        """
        Parameters
        ----------
        M : int
            Edge length of the square image, in pixels.
        N : int
            Number of binomial trials per pixel.
        sigma : float or None
            Fixed blob scale (enters as r / sigma**2). If None, sigma is read
            from the 4th entry of the parameters passed to ``gen_single``.
        """
        self.M = M
        self.N = N
        self.sigma = sigma

        # M coordinates spanning [-(M // 2), M // 2]. Written as -(M // 2)
        # rather than -M // 2: the latter parses as (-M) // 2, which rounds
        # towards -inf and shifts the grid off-centre for odd M (for even M
        # both forms are identical).
        half = M // 2
        coords = np.linspace(-half, half, M)
        self.x, self.y = np.meshgrid(coords, coords)

    def gen_single(self, params):
        """Simulate one image.

        Parameters
        ----------
        params : array_like
            ``[xo, yo, gamma]`` if sigma was fixed in the constructor, otherwise
            ``[xo, yo, gamma, sigma]``; xo, yo and gamma are in transformed space.

        Returns
        -------
        dict
            ``{'data': ndarray}`` of shape (1, M * M) holding per-pixel success
            fractions in [0, 1].

        Raises
        ------
        ValueError
            If the number of parameters does not match the sigma mode.
        """
        params = np.asarray(params).ravel()
        n_expected = 4 if self.sigma is None else 3
        # explicit check instead of `assert`, which is stripped under `python -O`
        if params.size != n_expected:
            raise ValueError('Blob.gen_single expects {} parameters, got {}'.format(
                n_expected, params.size))

        # map transformed parameters back onto their bounded ranges
        xo, yo, gamma = expit17(params[0]), expit17(params[1]), expit5(params[2])
        sigma = params[3] if self.sigma is None else self.sigma

        r = (self.x - xo) ** 2 + (self.y - yo) ** 2  # squared distance to the centre
        p = 0.1 + 0.8 * np.exp(-0.5 * (r / sigma ** 2) ** gamma)

        return {'data': nr.binomial(self.N, p).reshape(1, -1) / self.N}

# Prior over the transformed parameters (xo, yo, gamma): a zero-mean Gaussian
# with covariance 1.78 * I. Blob.gen_single maps draws onto the bounded ranges
# via expit17 / expit5. (An alternative that was tried is a Uniform prior placed
# directly on the raw ranges [-16, 16] x [-16, 16] x [0.25, 5].)
p = dd.Gaussian(m=np.zeros(3), S=1.78 * np.eye(3))

# Generator combining the simulator (sigma fixed at 2, N = 255 binomial trials
# per pixel), the prior and identity summary statistics. Uses the notebook-wide
# `seed` rather than a separately hard-coded value so both stay in sync.
g = dg.Default(model=Blob(M, N=255, sigma=2.), prior=p, summary=Identity(), seed=seed)


In [None]:
# Ground-truth parameters: raw location (7.2, 4.4) and gamma = 1.6, expressed
# in the transformed space that Blob.gen_single expects.
pars_true = np.array([logit17(7.2), logit17(4.4), logit5(1.6)])

# Simulate one observation and compute its summary statistics.
obs = g.model.gen_single(pars_true)
obs_stats = g.summary.calc([obs])

# Display the simulated image as an M x M array.
obs_image = obs_stats.reshape(M, M)
plt.imshow(obs_image, interpolation='None')
plt.show()

In [None]:
algo = 'CDELFI'  # 'CDELFI' (SNPE-A) or 'SNPE' (SNPE-B)

# Network architecture: 8 layers in total -- 5 conv ReLU layers, 2 fully
# connected hidden layers and 1 mixture-of-Gaussians output layer.
# NOTE(review): the total parameter count (previously quoted as ~20k) is
# printed further down; confirm it there.
filter_sizes = [3, 3, 3, 2, 2]     # filter edge size of each of the 5 conv layers
n_filters = (16, 16, 32, 32, 32)   # filters per conv layer (16 to 32)
# NOTE(review): 6 pool sizes are given for 5 conv layers -- presumably the last
# entry is unused; confirm against delfi's CNN construction.
pool_sizes = [1, 2, 2, 2, 1, 2]
n_hiddens = [50, 50]               # 2 fully connected hidden layers of 50 units

# Number of training simulations per round.
n_train = 10000

# A single mixture component, i.e. a single-Gaussian posterior approximation.
n_components = 1

# Number of rounds. The first round is always 'amortized' and can be used with
# any other observation covered by the prior.
n_rounds = 5

# Number of extra inputs passed directly to the fully connected layers,
# bypassing the conv layers. 0 = none: the image is the only network input.
n_inputs_hidden = 0

# Learning-schedule parameters.
lr_decay = 0.999
epochs = 10
minibatch = 50

svi = False        # SVI off (large N should make it irrelevant anyway)
reg_lambda = 0.0   # no regularization, to make doubly sure SVI is switched off

# No pilot samples. Per the original note, z-scoring only applies to extra inputs
# fed to the fully connected layers, and there are none here (n_inputs_hidden = 0).
pilot_samples = None

prior_norm = True   # originally noted as harmless
init_norm = False   # unclear how best to normalize initialization through conv and ReLU layers

# NOTE(review): originally annotated "fitting only DIAGONAL covariances" --
# confirm what rank=None selects in delfi.
rank = None

# SNPE-B specific settings
convert_to_T = 3
cbk_feature_layer = 14  # presumably the layer holding the mixture means -- confirm
kernel_loss = 'x_kl'


# SNPE-A specific settings
def sbc_fun(x):
    """Identity map passed to ``inf.run`` as ``sbc_fun``: parameters are used unchanged for SBC."""
    return x


In [None]:
# Build the inference object for the selected algorithm. Both algorithms share
# the generator, the observation and the CNN architecture settings.
common_kwargs = dict(generator=g, obs=obs_stats, prior_norm=prior_norm,
                     pilot_samples=pilot_samples, seed=seed, reg_lambda=reg_lambda,
                     svi=svi, n_components=n_components, n_hiddens=n_hiddens,
                     n_filters=n_filters, n_inputs=(1, M, M), filter_sizes=filter_sizes,
                     pool_sizes=pool_sizes, n_inputs_hidden=n_inputs_hidden,
                     rank=rank, verbose=True)

if algo == 'CDELFI':
    # SNPE-A; uses the configured n_components (previously hard-coded to 1)
    inf = infer.CDELFI(init_norm=init_norm, **common_kwargs)

elif algo == 'SNPE':
    # SNPE-B is run with normalized initialization (overrides the config value)
    init_norm = True
    inf = infer.SNPE(init_norm=init_norm, convert_to_T=convert_to_T, **common_kwargs)

else:
    # fail here instead of leaving `inf` undefined for the following cells
    raise ValueError("unknown algo {!r}; expected 'CDELFI' or 'SNPE'".format(algo))


In [None]:
# Print the network's weight arrays layer by layer: first their shapes, then
# their element counts. Only the odd indices 1, 3, ..., 15 of inf.network.aps
# are visited -- NOTE(review): assumes those are the weights (even ones biases).
def get_shape(i):
    """Return the shape of the value held by the i-th entry of ``inf.network.aps``."""
    value = inf.network.aps[i].get_value()
    return value.shape


weight_shapes = [get_shape(idx) for idx in range(1, 17, 2)]
print(weight_shapes)
print([np.prod(shape) for shape in weight_shapes])

In [None]:
if algo == 'CDELFI':
    # SNPE-A over n_rounds rounds: standardized mixture components, no
    # projection of the proposal, sbc_fun applied for calibration checks.
    project_proposal = False
    stndrd_comps = True

    log, trn_data, posteriors = inf.run(
        n_train=n_train, epochs=epochs, minibatch=minibatch, n_rounds=n_rounds,
        lr_decay=lr_decay, n_components=n_components,
        stndrd_comps=stndrd_comps, project_proposal=project_proposal,
        sbc_fun=sbc_fun)

elif algo == 'SNPE':
    # SNPE-B over n_rounds rounds
    lr = 0.001
    log, trn_data, posteriors = inf.run(
        n_train=n_train, epochs=epochs, minibatch=minibatch, n_rounds=n_rounds,
        lr_decay=lr_decay, lr=lr, kernel_loss=kernel_loss,
        cbk_feature_layer=cbk_feature_layer)

    # Effective sample size of each round's importance weights (entry 2 of
    # that round's training data), computed from the normalized weights.
    for round_data in trn_data:
        weights = round_data[2]
        weights = weights / weights.sum()
        print('ESS', 1. / np.sum(weights ** 2))

# Training loss of the last round (first 1000 recorded values).
plt.plot(log[-1]['loss'][:1000])
plt.show()

In [None]:
# Simulation-based calibration (SBC): for every round, histogram the ranks
# stored in log[r]['sbc'] for each parameter, then plot that round's training
# loss on linear and log-x axes.
L = 100  # NOTE(review): SBC ranks are assumed to lie in 0..L -- confirm against delfi's SBC settings
labels_params = ['xo', 'yo', 'gamma']
rank_bins = np.linspace(0, L + 1, L + 2) - .5  # unit-width bins centred on 0, 1, ..., L
n_params = pars_true.size

for r in range(n_rounds):
    res = log[r]['sbc']
    plt.figure(figsize=(16, 6))
    for i in range(n_params):
        plt.subplot(1, n_params, i + 1)
        # `normed` was deprecated in Matplotlib 2.1 and removed in 3.1;
        # `density=True` gives the same normalization.
        plt.hist(res[:, i], color='r', density=True, bins=rank_bins)
        plt.title(labels_params[i])
    plt.show()

    plt.figure(figsize=(16, 3))
    plt.subplot(1, 2, 1)
    plt.plot(log[r]['loss'])
    plt.subplot(1, 2, 2)
    plt.semilogx(log[r]['loss'])
    plt.show()

In [None]:
# Evaluate the fitted posterior and plot it on the raw (bounded) parameter
# scale. The network's mixture is expressed over transformed parameters, so it
# is wrapped as an MoTG whose bounds match the output ranges of expit17
# ((-17, 17)) and expit5 ((0.2, 5.05)).
def to_bounded_posterior(post):
    """Re-express a delfi mixture posterior as an MoTG over the raw parameter ranges.

    NOTE(review): flags=2 presumably selects delfi's logistic transform onto
    [lower, upper] -- confirm.
    """
    return dd.mixture.MoTG(ms=[x.m for x in post.xs], Ss=[x.S for x in post.xs], a=post.a,
                           flags=[2, 2, 2], upper=[17, 17, 5.05], lower=[-17, -17, 0.2])


plot_lims = [[-10, 10], [-10, 10], [0.25, 5]]  # raw-scale axes

# Recompute from `obs` so this cell does not depend on whatever obs_stats was last set to.
obs_stats = g.summary.calc([obs])
posterior = to_bounded_posterior(inf.predict(obs_stats))

pars_raw = np.array([7.2, 4.4, 1.6])  # raw-scale values behind pars_true

fig, _ = plot_pdf(posterior, lims=plot_lims, gt=pars_raw.reshape(-1),
                  figsize=(16, 16), resolution=100, labels_params=labels_params, ticks=True)

# Second test with a different ground truth. gen_single expects transformed
# parameters (as pars_true is), so map the raw values through logit17 / logit5.
pars_raw2 = np.array([4.1, 5.2, 4.5])
pars_true2 = np.array([logit17(pars_raw2[0]), logit17(pars_raw2[1]), logit5(pars_raw2[2])])
try:
    # separate name so obs_stats keeps referring to the pars_true observation
    obs_stats2 = g.summary.calc([g.model.gen_single(pars_true2)])
    posterior = to_bounded_posterior(inf.predict(obs_stats2))

    # all pairwise marginals of fitted posterior, against the raw-scale truth
    fig, _ = plot_pdf(posterior, lims=plot_lims, gt=pars_raw2,
                      figsize=(16, 16), resolution=100, labels_params=labels_params, ticks=True)
except Exception as err:  # best effort: report the failure instead of aborting the notebook
    print('second test cell broke: {!r}'.format(err))


In [None]:
# Repeated evaluation of the fitted posterior on the raw (bounded) parameter
# scale. The network's mixture is expressed over transformed parameters, so it
# is wrapped as an MoTG whose bounds match the output ranges of expit17
# ((-17, 17)) and expit5 ((0.2, 5.05)).
def to_bounded_posterior(post):
    """Re-express a delfi mixture posterior as an MoTG over the raw parameter ranges.

    NOTE(review): flags=2 presumably selects delfi's logistic transform onto
    [lower, upper] -- confirm.
    """
    return dd.mixture.MoTG(ms=[x.m for x in post.xs], Ss=[x.S for x in post.xs], a=post.a,
                           flags=[2, 2, 2], upper=[17, 17, 5.05], lower=[-17, -17, 0.2])


plot_lims = [[-10, 10], [-10, 10], [0.25, 5]]  # raw-scale axes

# Recompute from `obs` so this cell always evaluates the pars_true observation,
# regardless of whatever obs_stats was last set to.
obs_stats = g.summary.calc([obs])
posterior = to_bounded_posterior(inf.predict(obs_stats))

pars_raw = np.array([7.2, 4.4, 1.6])  # raw-scale values behind pars_true

fig, _ = plot_pdf(posterior, lims=plot_lims, gt=pars_raw.reshape(-1),
                  figsize=(16, 16), resolution=100, labels_params=labels_params, ticks=True)

# Second test with a different ground truth. gen_single expects transformed
# parameters (as pars_true is), so map the raw values through logit17 / logit5.
pars_raw2 = np.array([4.1, 5.2, 4.5])
pars_true2 = np.array([logit17(pars_raw2[0]), logit17(pars_raw2[1]), logit5(pars_raw2[2])])
try:
    # separate name so obs_stats keeps referring to the pars_true observation
    obs_stats2 = g.summary.calc([g.model.gen_single(pars_true2)])
    posterior = to_bounded_posterior(inf.predict(obs_stats2))

    # all pairwise marginals of fitted posterior, against the raw-scale truth
    fig, _ = plot_pdf(posterior, lims=plot_lims, gt=pars_raw2,
                      figsize=(16, 16), resolution=100, labels_params=labels_params, ticks=True)
except Exception as err:  # best effort: report the failure instead of aborting the notebook
    print('second test cell broke: {!r}'.format(err))


# SNPE-A (CDELFI): posterior in the transformed parameter space

In [None]:
# Posterior evaluated directly in the transformed parameter space (no MoTG
# wrapping), so ground truths and plot limits are given in transformed units.
# This draws a fresh noisy observation for pars_true.
# All three axes span [-10, 10]: the transformed gamma truth logit5(1.6) is
# about -0.9, which a [0.25, 5] range would exclude.
lims_transformed = [[-10, 10], [-10, 10], [-10, 10]]

obs_stats = g.summary.calc([g.model.gen_single(pars_true)])
posterior = inf.predict(obs_stats)

# all pairwise marginals of fitted posterior
fig, _ = plot_pdf(posterior, lims=lims_transformed, gt=pars_true.reshape(-1), figsize=(16, 16),
                  resolution=100, labels_params=labels_params, ticks=True)

# Second test: transform the raw ground truth before simulating (gen_single
# expects transformed parameters) and plot that transformed value as gt.
pars_raw2 = np.array([4.1, 5.2, 4.5])
pars_true2 = np.array([logit17(pars_raw2[0]), logit17(pars_raw2[1]), logit5(pars_raw2[2])])
try:
    # separate name so obs_stats keeps referring to the pars_true observation
    obs_stats2 = g.summary.calc([g.model.gen_single(pars_true2)])
    posterior = inf.predict(obs_stats2)

    # all pairwise marginals of fitted posterior
    fig, _ = plot_pdf(posterior, lims=lims_transformed, gt=pars_true2, figsize=(16, 16),
                      resolution=100, labels_params=labels_params, ticks=True)
except Exception as err:  # best effort: report the failure instead of aborting the notebook
    print('second test cell broke: {!r}'.format(err))
