In [1]:
import pickle
import torch

from gbi.GBI import GBInference
from gbi.distances import mse_dist
import gbi.hh.utils as utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; only use this on trusted, locally generated data.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Simulation parameters and their summary statistics (presumably row-aligned,
# since later cells slice both by the same row count).
theta = _load_pickle("data/theta.pkl")
x = _load_pickle("data/summstats.pkl")

In [3]:
# Number of leading simulations (rows of `theta` / `x`) to keep for training.
nsims = int(1e4)

In [4]:
# Keep only the first `nsims` simulations. Slicing silently returns fewer rows
# when the dataset is smaller than `nsims`, and later cells assume exactly
# `nsims` row-aligned simulations (e.g. `n_nonaug_x = nsims`), so verify that
# invariant here instead of failing obscurely further down.
theta = theta[:nsims]
x = x[:nsims]
assert len(theta) == len(x) == nsims, (
    f"expected {nsims} aligned simulations, "
    f"got len(theta)={len(theta)}, len(x)={len(x)}"
)

In [5]:
# Target-set composition: `n_nonaug_x` observed summary statistics plus
# `n_augmented_x` noise-perturbed copies; the noise scale is `noise_level`
# times each feature's standard deviation across simulations.
n_nonaug_x = n_augmented_x = nsims
noise_level = 2.0

In [6]:
# Build the training targets: the first `n_nonaug_x` simulated summary
# statistics plus `n_augmented_x` rows resampled (with replacement) from `x`
# and perturbed with Gaussian noise whose per-feature scale is
# `noise_level * x.std(dim=0)`.
# A dedicated, seeded generator makes this augmentation reproducible under
# Restart & Run All; nothing earlier in the notebook seeds torch's global RNG.
aug_seed = 0
aug_generator = torch.Generator().manual_seed(aug_seed)

aug_idx = torch.randint(x.shape[0], size=(n_augmented_x,), generator=aug_generator)
x_aug = x[aug_idx]
x_aug = x_aug + torch.randn(x_aug.shape, generator=aug_generator) * x.std(dim=0) * noise_level
x_target = torch.cat([x[:n_nonaug_x], x_aug])

In [7]:
# Parameter values and their labels from `utils.obs_params`
# (presumably the observed/ground-truth parameters of the full model).
true_params, labels_params = utils.obs_params(reduced_model=False)

# Prior options gathered in one place so they are easy to adjust.
prior_options = dict(
    prior_uniform=True,
    prior_extent=True,
    prior_log=False,
    seed=0,
)
prior = utils.prior(true_params=true_params, **prior_options)

In [12]:
# `GBInference` object built from the prior, using `mse_dist` as its distance.
# NOTE(review): `do_precompute_distances=True` presumably computes the
# simulation-to-target distances up front — confirm in `gbi.GBI`.
inference = GBInference(
    prior,
    mse_dist,
    do_precompute_distances=True,
)

In [13]:
# Hand the parameters, their summary statistics and the (partly augmented)
# targets to the inference object.
# NOTE(review): if `append_simulations` accumulates data, re-running this cell
# without re-creating `inference` first would add the data twice.
inference = inference.append_simulations(
    theta,
    x,
    x_target,
)

In [14]:
# NOTE(review): the meanings of these two positional arguments are assumed
# (network depth / width) — confirm against
# `GBInference.initialize_distance_estimator`.
distance_net_depth = 3
distance_net_width = 50
inference.initialize_distance_estimator(distance_net_depth, distance_net_width)

In [15]:
# Training settings for the distance estimator.
# NOTE(review): the recorded run with these settings logged train/val losses on
# the order of 1e10 and was stopped manually (KeyboardInterrupt), so
# `distance_net` was never assigned; check whether the summary statistics need
# rescaling before re-running.
train_settings = dict(
    training_batch_size=5_000,
    max_n_epochs=50,
    stop_after_counter_reaches=20,
    print_every_n=1,
    plot_losses=False,
)
distance_net = inference.train(**train_settings)

0: train loss: 21103493120.000000, val loss: 21395851264.000000
1: train loss: 22961315840.000000, val loss: 21381824512.000000
2: train loss: 22647048192.000000, val loss: 21365981184.000000
3: train loss: 22572337152.000000, val loss: 21347833856.000000
4: train loss: 21458485248.000000, val loss: 21326751744.000000
5: train loss: 20096442368.000000, val loss: 21301848064.000000
6: train loss: 19771578368.000000, val loss: 21271871488.000000
7: train loss: 21761746944.000000, val loss: 21234984960.000000
8: train loss: 22722154496.000000, val loss: 21188538368.000000
9: train loss: 21683525632.000000, val loss: 21129003008.000000



KeyboardInterrupt

