In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension and use mode 2, so edits to imported modules
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import logging
from functools import partial
from copy import deepcopy
from toolz import compose, valmap, keyfilter, identity, merge
from itertools import combinations
import pickle

from joblib import Parallel, delayed, parallel_backend

import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import torch

import sbi
import sbi.utils
import sbi.inference
import sbi.analysis
import sbibm

import swyft

import tmnre
import tmnre.metrics
import tmnre.benchmark
import tmnre.coverage.oned
from tmnre.nn.resnet import make_resenet_tail
from tmnre.marginalize import filter_marginals_by_dim

In [3]:
# Keep matplotlib's logger at WARNING so it does not flood the INFO-level
# root logger configured right below.
logging.getLogger('matplotlib').setLevel(logging.WARNING)
log = logging.getLogger()
log.setLevel(logging.INFO)

# Seed the global NumPy and PyTorch generators with the same fixed value.
# Calling them inside a loop also keeps their return values out of the cell
# output.
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(28)

In [4]:
# Cache switch read by the cells below: when True the marginal posteriors are
# recomputed and pickled to disk; when False the expensive cells are skipped
# and the previously pickled marginals are loaded instead.
SAVE = False

In [5]:
# --- Experiment configuration ------------------------------------------------
TASK_NAME = "eggbox"
NUM_OBS = 1                   # index of the sbibm observation to condition on
DIM = 10                      # `dim` argument handed to the task
SEQDIM = 2                    # dimensionality of the larger marginals
N_SIMULATIONS = 10_000        # total budget, shared by all marginals
N_POSTERIOR_SAMPLES = 25_000  # samples drawn from every final posterior
N_JOBS = 12                   # joblib worker count

task = sbibm.get_task(TASK_NAME, dim=DIM)

# Index tuples of every 1-d and 2-d marginal of the task parameters.
marginal_1d_inds = [(i,) for i in range(task.dim_parameters)]
marginal_2d_inds = list(combinations(range(task.dim_parameters), 2))

# Split the budget evenly (floor division) over all marginals.
n_marginals = len(marginal_1d_inds) + len(marginal_2d_inds)
N_SIMS_PER_MARGINAL = N_SIMULATIONS // n_marginals
print(N_SIMS_PER_MARGINAL)

# Reference parameters and observation for the chosen observation index.
theta0 = task.get_true_parameters(NUM_OBS).squeeze()
obs0 = task.get_observation(NUM_OBS).squeeze()
print(theta0, obs0, sep="\n")

181
tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
        0.2500])
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000])


In [6]:
def get_marginal_simulator():
    """Build a simulator that only exposes the leading parameters of ``task``.

    The returned callable receives the first ``p`` parameters, fills the
    remaining ``task.dim_parameters - p`` entries with fresh ``torch.rand``
    draws (uniform on [0, 1)) and forwards the completed parameter tensor to
    the task simulator.
    NOTE(review): this presumably matches the task prior on the hidden
    parameters -- confirm against the sbibm task definition.

    Returns:
        callable: ``simulate(parameters)`` accepting a tensor of shape
        ``(p,)`` or ``(*batch, p)`` and returning the task simulator's output
        for the completed ``(..., task.dim_parameters)`` tensor.
    """
    simulator = task.get_simulator()
    dim = task.dim_parameters

    def simulate(parameters):
        *batch_shape, p = parameters.shape
        # Unpack the batch dims: ``torch.rand`` takes either all-int varargs
        # or one sequence, so passing the list next to an int fails for
        # batched input. With an empty batch shape this is ``torch.rand(dim - p)``.
        extra_params = torch.rand(*batch_shape, dim - p)
        total_parameters = torch.cat([parameters, extra_params], dim=-1)
        return simulator(total_parameters)

    return simulate

def get_snre_posterior_samples(
    n_simulations, n_posterior_samples, prior, sim, num_rounds=10, num_workers=4
):
    """Run multi-round SNRE_A and sample the final posterior at ``obs0``.

    The simulation budget is split evenly (floor division) over the rounds.
    Round 0 simulates from ``prior``; each later round simulates from the
    previous round's posterior conditioned on the module-level ``obs0``.

    Args:
        n_simulations: total simulation budget across all rounds.
        n_posterior_samples: number of samples drawn from the last posterior.
        prior: sbi-compatible prior, also used as the first proposal.
        sim: sbi-compatible simulator.
        num_rounds: number of sequential rounds (default 10).
        num_workers: workers handed to ``simulate_for_sbi`` (default 4).

    Returns:
        Whatever ``posterior.sample`` returns for a sample shape of
        ``(n_posterior_samples,)`` at the observation ``obs0``.

    Raises:
        ValueError: if ``num_rounds`` is smaller than 1, or the budget does
            not provide at least one simulation per round.
    """
    if num_rounds < 1:
        raise ValueError(f"num_rounds must be >= 1, got {num_rounds}")
    n_per_round = n_simulations // num_rounds
    if n_per_round < 1:
        raise ValueError(
            f"n_simulations={n_simulations} is too small for num_rounds={num_rounds}"
        )

    # The observation is constant across rounds, so build it once.
    x_o = torch.atleast_2d(obs0)

    inference = sbi.inference.SNRE_A(prior=prior)
    proposal = prior
    for r in range(num_rounds):
        theta, x = sbi.inference.simulate_for_sbi(
            sim, proposal, num_simulations=n_per_round, num_workers=num_workers
        )
        # `from_round` tells sbi which round produced this batch.
        density_estimator = inference.append_simulations(theta, x, from_round=r).train()
        posterior = inference.build_posterior(density_estimator)
        # The next round simulates from the current posterior at the observation.
        proposal = posterior.set_default_x(x_o)

    return posterior.sample((int(n_posterior_samples),), x=x_o)

In [7]:
# Box-uniform prior on [0, 1]^SEQDIM for the parameters that are inferred.
prior_2d_min, prior_2d_max = torch.zeros(SEQDIM), torch.ones(SEQDIM)
sbi_2d_prior = sbi.utils.torchutils.BoxUniform(low=prior_2d_min, high=prior_2d_max)

# Let sbi wrap the marginal simulator and the prior into the pair it expects.
seq_2d_sim, seq_2d_prior = sbi.inference.prepare_for_sbi(
    get_marginal_simulator(), sbi_2d_prior
)

In [8]:
# Only recompute the 2-d marginals when SAVE is on; otherwise this cell does
# nothing. Every job gets the same prior and simulator -- the index tuple is
# only used to label the corresponding result.
if SAVE:
    with parallel_backend("loky", inner_max_num_threads=1):
        result = Parallel(n_jobs=N_JOBS)(
            delayed(get_snre_posterior_samples)(
                N_SIMS_PER_MARGINAL, N_POSTERIOR_SAMPLES, seq_2d_prior, seq_2d_sim
            )
            for _ in marginal_2d_inds
        )
    marginals_2d = dict(zip(marginal_2d_inds, result))

In [9]:
# Box-uniform prior on [0, 1]^(SEQDIM - 1) for the lower-dimensional marginals.
prior_1d_min, prior_1d_max = torch.zeros(SEQDIM - 1), torch.ones(SEQDIM - 1)
sbi_1d_prior = sbi.utils.torchutils.BoxUniform(low=prior_1d_min, high=prior_1d_max)

# Let sbi wrap the marginal simulator and the prior into the pair it expects.
seq_1d_sim, seq_1d_prior = sbi.inference.prepare_for_sbi(
    get_marginal_simulator(), sbi_1d_prior
)

In [10]:
# Only recompute the 1-d marginals when SAVE is on; otherwise this cell does
# nothing. Every job gets the same prior and simulator -- the index tuple is
# only used to label the corresponding result.
if SAVE:
    with parallel_backend("loky", inner_max_num_threads=1):
        result = Parallel(n_jobs=N_JOBS)(
            delayed(get_snre_posterior_samples)(
                N_SIMS_PER_MARGINAL, N_POSTERIOR_SAMPLES, seq_1d_prior, seq_1d_sim
            )
            for _ in marginal_1d_inds
        )
    marginals_1d = dict(zip(marginal_1d_inds, result))

In [11]:
# On-disk cache holding the 1-d and 2-d marginal samples, keyed by index tuple.
path = f"eggbox-seq-marg-{SEQDIM:02d}-marginals.pickle"
if not SAVE:
    # Reuse a previous run. Only unpickle files this notebook produced itself:
    # pickle.load can execute arbitrary code from an untrusted file.
    with open(path, "rb") as f:
        marginals = pickle.load(f)
else:
    with open(path, "wb") as f:
        # Later mappings win on key clashes, as with toolz.merge.
        marginals = {**marginals_1d, **marginals_2d}
        pickle.dump(marginals, f)