In [1]:
# Reload edited modules (e.g. the local `gbi` package) automatically before
# each cell runs, so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import pickle
import torch

from gbi.GBI import GBInference
from gbi.distances import mse_dist
import gbi.hh.utils as utils

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the precomputed simulation bank: parameters `theta` and the matching
# summary statistics `x`.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only load files you generated yourself.
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`."""
    with open(path, "rb") as fh:
        return pickle.load(fh)


theta = _load_pickle("data/theta.pkl")
x = _load_pickle("data/summstats.pkl")

In [4]:
nsims = 10**6  # number of simulations kept from the loaded bank

In [5]:
# Keep only the first `nsims` simulations. theta[i] and x[i] are treated as a
# matched (parameter, summary-statistics) pair downstream, so both are truncated
# identically.
theta = theta[:nsims]
x = x[:nsims]
# Guard the pairing: a length mismatch would silently misalign parameters and
# simulations in every later cell.
assert len(theta) == len(x), f"theta ({len(theta)}) and x ({len(x)}) differ in length"

In [6]:
# Sizes of the two parts of the target set: clean (non-augmented) simulations
# and noise-perturbed copies. Both use the full retained simulation count.
n_nonaug_x = n_augmented_x = nsims
# Noise scale, expressed as a multiple of the per-feature standard deviation of x.
noise_level = 2.0

In [7]:
# Build the target set for the distance estimator: the first `n_nonaug_x` clean
# summary statistics followed by `n_augmented_x` noisy copies of rows drawn
# uniformly at random with replacement. The added noise is Gaussian with a
# per-feature standard deviation of `noise_level * x.std(dim=0)`.
# A dedicated, seeded generator makes the augmentation reproducible across
# Restart & Run All without touching torch's global RNG state.
aug_seed = 0
aug_rng = torch.Generator().manual_seed(aug_seed)
aug_idx = torch.randint(x.shape[0], size=(n_augmented_x,), generator=aug_rng)
x_aug = x[aug_idx]
x_aug = x_aug + torch.randn(x_aug.shape, generator=aug_rng) * x.std(dim=0) * noise_level
x_target = torch.cat([x[:n_nonaug_x], x_aug])

In [8]:
# Fetch `true_params` (used to build the prior below) and `labels_params`
# for the full, non-reduced model.
true_params, labels_params = utils.obs_params(reduced_model=False)

# Prior constructed around `true_params`; see gbi.hh.utils.prior for the
# meaning of the prior_* flags.
prior = utils.prior(
    true_params=true_params,
    seed=0,
    prior_log=False,
    prior_uniform=True,
    prior_extent=True,
)

In [9]:
# Inference object using `mse_dist` as its distance, with distance
# precomputation disabled.
inference = GBInference(
    prior,
    mse_dist,
    do_precompute_distances=False,
)

In [10]:
# Register the simulations (theta, x) together with the target set x_target.
inference = inference.append_simulations(
    theta,
    x,
    x_target,
)

In [11]:
# NOTE(review): the positional arguments (3, 50) are unexplained here —
# presumably architecture sizes of the distance estimator; confirm against
# GBInference.initialize_distance_estimator and consider passing them by keyword.
inference.initialize_distance_estimator(
    3,
    50,
)

In [None]:
# Train the distance estimator for at most 50 epochs, printing losses every
# epoch and skipping the loss plot. `stop_after_counter_reaches` is presumably
# an early-stopping patience — confirm in GBInference.train.
distance_net = inference.train(
    max_n_epochs=50,
    stop_after_counter_reaches=20,
    training_batch_size=5_000,
    plot_losses=False,
    print_every_n=1,
)

0: train loss: 2003638272.000000, val loss: 2079056384.000000, 16.1006 seconds per epoch.
1: train loss: 1087765504.000000, val loss: 1240209280.000000, 16.2929 seconds per epoch.
2: train loss: 813593792.000000, val loss: 778252544.000000, 15.9289 seconds per epoch.
3: train loss: 672664768.000000, val loss: 580353408.000000, 16.4091 seconds per epoch.
4: train loss: 460065984.000000, val loss: 494529280.000000, 16.4309 seconds per epoch.
5: train loss: 402335648.000000, val loss: 376760896.000000, 16.2109 seconds per epoch.
6: train loss: 344716864.000000, val loss: 334997472.000000, 16.5731 seconds per epoch.
7: train loss: 284170560.000000, val loss: 296903360.000000, 16.5012 seconds per epoch.
8: train loss: 272512576.000000, val loss: 283559104.000000, 16.5289 seconds per epoch.
9: train loss: 213363856.000000, val loss: 233506448.000000, 16.0687 seconds per epoch.
10: train loss: 208039552.000000, val loss: 201923280.000000, 16.0522 seconds per epoch.
11: train loss: 198913296.0