In [1]:
import numpy as np
%matplotlib inline
from tqdm import tqdm_notebook as tqdm

import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.summarystats as ds

from lfimodels.glm.GLM import GLM
from lfimodels.glm.GLMStats import GLMStats
from delfi.utils.viz import plot_pdf

import lfimodels.glm.utils as utils

In [2]:
# Experiment configuration shared by the cells below.
seed = 43        # default seed for observed data, model, prior and SNPE
prior_prec = 1   # forwarded as `rel_prec` to utils.smoothing_prior

# NOTE(review): an earlier note here read "StudT 10, alpha=0.2, 100 samples".
# That does not match the runs in this notebook (alpha=0.015 and alpha=0,
# N=1000, convert_to_T=None); presumably it described a previous configuration.

len_filter = 9   # filter length passed to utils.obs_params and GLM
nrounds = 5      # number of SNPE rounds per run

# Ground-truth parameters and the observed data / summary stats derived from them.
true_params, labels_params = utils.obs_params(len_filter=len_filter)
obs = utils.obs_data(true_params, seed=seed)
obs_stats = utils.obs_stats(true_params, seed=seed)

verbose = True   # forwarded to infer.SNPE

In [3]:
def run_sim(nrounds, N, alpha=0, convert_to_T=None, override_posteriors=None, seed=seed,
            n_hiddens=None, pilot_samples=100, reg_lambda=0.01,
            minibatch=100, epochs=1000, round_cl=3):
    """Run SNPE on the GLM problem and return its logs, training data and posteriors.

    Builds a fresh GLM simulator, smoothing prior and GLMStats summary, wraps
    them in a ``dg.Default`` generator and runs ``infer.SNPE`` against the
    module-level ``obs_stats``.

    Parameters
    ----------
    nrounds : int
        Number of SNPE rounds (forwarded as ``n_rounds`` to ``SNPE.run``).
    N : int
        Forwarded as ``n_train`` to ``SNPE.run``.
    alpha : float, default 0
        Forwarded as ``prior_mixin`` to ``infer.SNPE``.
    convert_to_T : optional
        Forwarded unchanged to ``infer.SNPE``.
    override_posteriors : optional
        Not supported: the value is not forwarded anywhere. A warning is
        emitted when it is given so the caller does not silently get a
        plain run.
    seed : int
        Seed for the model, prior and SNPE. NOTE: the default is the value of
        the global ``seed`` when this cell was executed, not at call time.
    n_hiddens : list of int, optional
        Hidden-layer sizes forwarded to SNPE; ``None`` means ``[50]``.
    pilot_samples, reg_lambda : optional
        Forwarded to ``infer.SNPE`` (previous hard-coded values 100 and 0.01).
    minibatch, epochs, round_cl : optional
        Forwarded to ``SNPE.run`` (previous hard-coded values 100, 1000, 3).

    Returns
    -------
    dict
        Keys ``'true_params'`` (the module-level value), ``'logs'``, ``'tds'``
        and ``'posteriors'`` (the three values returned by ``SNPE.run``).

    Also reads the globals ``len_filter``, ``prior_prec``, ``obs_stats`` and
    ``verbose``.
    """
    if override_posteriors is not None:
        import warnings
        warnings.warn(
            "run_sim: override_posteriors is not supported and is ignored",
            stacklevel=2,
        )

    # Fresh list per call so a shared mutable default can never leak state.
    hidden_layers = [50] if n_hiddens is None else list(n_hiddens)

    m = GLM(len_filter=len_filter, seed=seed)
    p = utils.smoothing_prior(n_params=m.n_params, rel_prec=prior_prec, seed=seed)
    s = GLMStats(n_summary=m.n_params)
    g = dg.Default(model=m, prior=p, summary=s)

    res = infer.SNPE(g,
                     obs=obs_stats,
                     convert_to_T=convert_to_T,
                     n_hiddens=hidden_layers,
                     seed=seed,
                     pilot_samples=pilot_samples,
                     svi=True,
                     reg_lambda=reg_lambda,
                     prior_mixin=alpha,
                     prior_norm=True,
                     verbose=verbose)

    logs, tds, posteriors = res.run(n_train=N,
                                    n_rounds=nrounds,
                                    minibatch=minibatch,
                                    epochs=epochs,
                                    round_cl=round_cl)

    return {'true_params': true_params, 'logs': logs, 'tds': tds, 'posteriors': posteriors}

In [8]:
# Main run: SNPE with a small prior mixin (alpha=0.015); N is forwarded as n_train.
data = run_sim(nrounds=nrounds, N=1000, alpha=0.015)
posteriors = data['posteriors']

Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





Widget Javascript not detected.  It may not be installed or enabled properly.





In [10]:
# Reference run: same number of rounds and N as `data`, but without prior mixin (alpha=0).
# NOTE(review): run_sim accepts `override_posteriors` but never forwards it to
# SNPE, so passing `posteriors` here would not change this run.
rdata = run_sim(nrounds, 1000, alpha=0, convert_to_T=None)
rposteriors = rdata['posteriors']




















































In [11]:
def print_training_summary(tds, n_head=5, n_extreme=10):
    """Print per-round diagnostics for the training data returned by SNPE.run.

    For each element ``td`` of ``tds`` this prints, in order: the first and
    last ``n_head`` entries of ``td[2]``, the first and last ``n_head`` rows
    of ``td[1]``, the ``n_extreme`` smallest and largest values of ``td[2]``,
    then the mean and the mean of squares of ``td[2]``.

    NOTE(review): presumably ``td[1]`` holds the summary statistics and
    ``td[2]`` the per-sample importance weights of that round -- confirm
    against delfi's SNPE.run.
    """
    for td in tds:
        stats, weights = td[1], td[2]
        print(weights[:n_head])
        print(weights[-n_head:])
        print(stats[:n_head])
        print(stats[-n_head:])

        # Sort once; only the extremes of the sorted array are shown.
        weights_sorted = np.sort(weights)
        print(weights_sorted[:n_extreme])
        print(weights_sorted[-n_extreme:])
        print(np.mean(weights))
        print(np.mean(weights ** 2))


trn_data = data['tds']
print_training_summary(trn_data)


[ 1.  1.  1.  1.  1.]
[ 1.  1.  1.  1.  1.]
[[ 1.00547467  0.58898868  0.07463676 -0.38999564 -0.35683605 -0.59986616
  -1.39471387]
 [ 1.16958915 -0.37505685 -1.26664675 -1.22496111 -0.99797175 -1.05298591
  -0.85154873]
 [-0.69037488  0.53763671 -0.14601087  1.09486321  1.75517427  0.6583371
   0.97327479]
 [-1.12801348  0.03918023  0.32038117  0.84159274  0.98310231  0.95334365
   1.25104674]
 [-0.69037488 -0.62133607  0.42260229  0.82668888  1.29460774  1.278872
   1.02014799]]
[[ 0.95076985  0.12970086 -0.41568852 -0.85396312 -1.07273885 -0.98196579
  -1.07929685]
 [-1.12801348 -0.25368984  0.48339031  1.24321927  1.38780174  0.92008705
   1.11287667]
 [-0.30744111  1.10926216  1.85656025  1.59738477  1.25647008  0.24841989
   0.52957119]
 [-0.74507971  0.85393336  1.8012511   1.1904351   1.42608322  1.37370464
   0.12453665]
 [-0.52626041 -1.38281836 -0.17812545  1.13975912  1.93904515  1.26216398
   0.5784848 ]]
[ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
[ 1.  1.  1.  1.  1.  1. 

In [12]:
# Per-round diagnostics for the reference (alpha=0) run's training data:
# head/tail of td[2] and td[1], extremes of sorted td[2], mean and mean square of td[2].
rtrn_data = rdata['tds']

for i, rtd in enumerate(rtrn_data):
    rweights = rtd[2]
    rstats = rtd[1]
    for chunk in (rweights[:5], rweights[-5:], rstats[:5], rstats[-5:]):
        print(chunk)

    rweights_sorted = np.sort(rweights)
    print(rweights_sorted[:10])
    print(rweights_sorted[-10:])
    print(np.mean(rweights))
    print(np.mean(rweights ** 2))


[ 1.  1.  1.  1.  1.]
[ 1.  1.  1.  1.  1.]
[[ 1.00547467  0.58898868  0.07463676 -0.38999564 -0.35683605 -0.59986616
  -1.39471387]
 [ 1.16958915 -0.37505685 -1.26664675 -1.22496111 -0.99797175 -1.05298591
  -0.85154873]
 [-0.69037488  0.53763671 -0.14601087  1.09486321  1.75517427  0.6583371
   0.97327479]
 [-1.12801348  0.03918023  0.32038117  0.84159274  0.98310231  0.95334365
   1.25104674]
 [-0.69037488 -0.62133607  0.42260229  0.82668888  1.29460774  1.278872
   1.02014799]]
[[ 0.95076985  0.12970086 -0.41568852 -0.85396312 -1.07273885 -0.98196579
  -1.07929685]
 [-1.12801348 -0.25368984  0.48339031  1.24321927  1.38780174  0.92008705
   1.11287667]
 [-0.30744111  1.10926216  1.85656025  1.59738477  1.25647008  0.24841989
   0.52957119]
 [-0.74507971  0.85393336  1.8012511   1.1904351   1.42608322  1.37370464
   0.12453665]
 [-0.52626041 -1.38281836 -0.17812545  1.13975912  1.93904515  1.26216398
   0.5784848 ]]
[ 1.  1.  1.  1.  1.  1.  1.  1.  1.  1.]
[ 1.  1.  1.  1.  1.  1. 