In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial, cache
from copy import deepcopy
from pathlib import Path

from typing import Callable, Optional

from tqdm.notebook import trange, tqdm

import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
from matplotlib.ticker import MaxNLocator
import numpy as np

import scipy
import scipy.stats
import scipy.integrate

import torch
import torch.distributions

import pyro

import sbi
import sbi.analysis as analysis
from sbi.inference import SNRE_A, SNRE_B, prepare_for_sbi, simulate_for_sbi, infer
from sbi.utils.get_nn_models import classifier_nn
from sbi.inference.posteriors import MCMCPosterior
from sbi.inference.potentials.likelihood_based_potential import LikelihoodBasedPotential
from sbi.inference.potentials.ratio_based_potential import ratio_estimator_based_potential

import sbibm
from sbibm.algorithms.sbi.utils import wrap_posterior

import cnre

## generate data

In [3]:
task = sbibm.get_task("slcp")
prior = task.get_prior_dist()
simulator = task.get_simulator()
x_o = task.get_observation(1)
theta_o = task.get_true_parameters(1)
transform = task._get_transforms()

  m = torch.stack(


In [4]:
class ClosedFormPotential(sbi.inference.potentials.base_potential.BasePotential):
    def __init__(
        self, 
        fn: Callable,
        prior: torch.distributions.Distribution, 
        x_o: Optional[torch.Tensor] = None, 
        device: str = "cpu"
    ):
        super().__init__(prior, x_o, device)
        self.fn = fn
    
    def __call__(self, theta: torch.Tensor, track_gradients: bool = True) -> torch.Tensor:
        with torch.set_grad_enabled(track_gradients):
            return self.fn(theta)
    
    @property
    def allow_iid_x(self) -> bool:
        raise NotImplementedError

def create_chain(
    starting_theta: torch.Tensor, 
    potential_fn: Callable, 
    transform: Optional = None,
    thin: int = 10,
    num_workers: int = 1,
    device: Optional = None,
    x_shape: Optional = None, 
):
    """starting_theta [num_starting_pts, theta_dim] - we try to make this an array of samples drawn from the prior which then start each chain."""
    if transform is not None:
        starting_theta = transform(starting_theta)
    # delta = pyro.distributions.delta.Delta(starting_theta)
    posterior = MCMCPosterior(
        potential_fn=potential_fn,
        # proposal=starting_theta,  # should be delta, but the support of the prior has been transformed to the reals (so should the point)
        proposal=prior,  # should be delta, but the support of the prior has been transformed to the reals (so should the point)
        theta_transform=torch.distributions.transforms.identity_transform,  # this is the identity when running sbibm
        method="slice_np_vectorized",
        # method="slice",
        thin=thin,
        warmup_steps=0,
        num_chains=starting_theta.shape[0],
        init_strategy="proposal",
        # init_strategy="delta",
        num_workers=num_workers,
        device=device,
        x_shape=x_shape,
    )
    return wrap_posterior(posterior, transform)
    # samples = posterior.sample((self.num_posterior_samples,), x=observation).detach()
    # TODO figure out whether the prior and likelihood need to be transformed.

In [14]:
n = 5
d = theta_o.shape[1]
dd = x_o.shape[1]
starting_theta = torch.rand(1, d)

inference = SNRE_A(prior)
theta, x = simulate_for_sbi(simulator, prior, num_simulations=500)
ratio_estimator = inference.append_simulations(theta, x).train()

class Ratio(torch.nn.Module):
    def __init__(self, device):
        super().__init__()
        self.net = torch.nn.Sequential(torch.nn.Linear(10, 10))
        self.map = torch.randn(dd, d)
        self.net.to(device)
    
    def forward(self, w: list):
        """Inputs to (s)nre classifier must be a list containing raw theta and x."""
        # return -torch.sum(w[0] ** 2, dim=-1)
        theta, x = w
        assert isinstance(theta, torch.Tensor)
        assert len(theta.shape) == 2
        
        print(self.map)
        print(theta)
        print(x)
        t = torch.einsum("ij,bj->bi", self.map, theta)
        return (t - x).pow(2).neg().sum(dim=-1, keepdims=True)

ratio = Ratio("cpu")
    
potential, _ = ratio_estimator_based_potential(
    ratio, 
    prior, 
    x_o=x_o, 
    enable_transform=False
)
p = create_chain(starting_theta, potential)

Running 500 simulations.:   0%|          | 0/500 [00:00<?, ?it/s]

 Neural network successfully converged after 45 epochs.

In [15]:
p_built = inference.build_posterior(mcmc_method="slice_np_vectorized")
p_built = wrap_posterior(p_built, transform)

In [16]:
aaa = torch.rand(2, d)
p.flow.potential_fn.set_x(x_o)
print(p.flow.potential_fn(aaa))

p_built.flow.potential_fn.set_x(x_o)
print(p_built.flow.potential_fn(aaa))

tensor([[ 0.2899,  1.0014, -1.7746, -2.5819,  0.2798],
        [ 1.4384, -0.0464, -0.0704,  0.9918,  1.0043],
        [-2.4651, -0.6654,  0.4301, -1.8725,  0.3934],
        [-0.3969, -2.9779,  0.3740,  1.5982,  0.4220],
        [ 0.5613,  0.1550, -0.2588, -0.2382,  0.5420],
        [ 1.8282,  0.4050, -1.5480, -0.6493,  1.0415],
        [-1.8636, -0.2498,  1.8554,  0.2767, -0.3876],
        [-0.4782,  0.6687, -0.0266, -0.2716, -0.3646]])
tensor([[0.5864, 0.4337, 0.7282, 0.3138, 0.8866],
        [0.6034, 0.6280, 0.4096, 0.2352, 0.1128]])
tensor([[  2.3719,   0.4995,   9.9314,   1.7137, -10.4364,  -1.9068,  -1.2344,
          -0.0974],
        [  2.3719,   0.4995,   9.9314,   1.7137, -10.4364,  -1.9068,  -1.2344,
          -0.0974]])
tensor([-293.9955, -296.8475])
tensor([-9.1700, -9.0978], grad_fn=<AddBackward0>)


In [17]:
from joblib import parallel_backend
with parallel_backend('threading', n_jobs=1):
    p.sample((10,), x=x_o)
# For some reason, this one has shape issues but the one above does not. it could be due to the distribution i'm using (but I don't think so)
# maybe the ratio estiamtor is bad somehow.

# maybe the potential funciton requires a default x and this one doesn't give it? or maybe it needs an option to give an x?

# After many tests ( and the creation of this Ratio class above), I think the issue lies in the log_prob being too high dimensional.
# I think the way they create this potential function requires some scrutiny. Especially the setting of the x_o.

## p.s. don't forget you have a stashed update to the repo. (not sure if it's important.) I think it was created before you did all the merging.

Running vectorized MCMC with 5 chains:   0%|          | 0/100 [00:00<?, ?it/s]

tensor([[ 0.2899,  1.0014, -1.7746, -2.5819,  0.2798],
        [ 1.4384, -0.0464, -0.0704,  0.9918,  1.0043],
        [-2.4651, -0.6654,  0.4301, -1.8725,  0.3934],
        [-0.3969, -2.9779,  0.3740,  1.5982,  0.4220],
        [ 0.5613,  0.1550, -0.2588, -0.2382,  0.5420],
        [ 1.8282,  0.4050, -1.5480, -0.6493,  1.0415],
        [-1.8636, -0.2498,  1.8554,  0.2767, -0.3876],
        [-0.4782,  0.6687, -0.0266, -0.2716, -0.3646]])
tensor([[-1.3040,  1.7884, -1.2323, -2.4065, -2.3238],
        [-2.3774, -2.8776, -0.3343, -0.4853,  2.7386],
        [ 0.6703,  1.0233, -2.6752, -1.6807, -1.7447],
        [ 1.9382, -0.7606, -2.2653,  1.3156, -1.2346],
        [-2.4845, -0.2214,  2.4168,  2.6739, -2.5722]])
tensor([[  2.3719,   0.4995,   9.9314,   1.7137, -10.4364,  -1.9068,  -1.2344,
          -0.0974],
        [  2.3719,   0.4995,   9.9314,   1.7137, -10.4364,  -1.9068,  -1.2344,
          -0.0974],
        [  2.3719,   0.4995,   9.9314,   1.7137, -10.4364,  -1.9068,  -1.2344,
      

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

In [None]:
raise NotImplementedError()

In [None]:
def run(prior, num_observation, observation_seed, **kwargs):
    np.random.seed(observation_seed)
    torch.manual_seed(observation_seed)

    prior = self.get_prior()
    true_parameters = prior(num_samples=1)

    simulator = self.get_simulator()
    observation = simulator(true_parameters)

    if create_reference:
        reference_posterior_samples = self._sample_reference_posterior(
            num_observation=num_observation,
            num_samples=self.num_reference_posterior_samples,
            **kwargs,
        )
        num_unique = torch.unique(reference_posterior_samples, dim=0).shape[0]
        assert num_unique == self.num_reference_posterior_samples
        self._save_reference_posterior_samples(
            num_observation,
            reference_posterior_samples,
        )

Parallel(n_jobs=n_jobs, verbose=50, backend="loky")(
    delayed(run)(num_observation, observation_seed, **kwargs)
    for num_observation, observation_seed in enumerate(
        self.observation_seeds, start=1
    )
)

## setup problem

In [None]:
num_total_samples = 2 ** 13
validation_fraction=2 ** -2
num_training_samples = int((1 - validation_fraction) * num_total_samples)
num_validation_samples = int(validation_fraction * num_total_samples)
num_posterior_samples = 3_000

training_batch_size=2 ** 10
learning_rate=3e-4
stop_after_epochs=2 ** 31 - 1
# max_num_epochs=1_000  # We want the network to see the same number of batches no matter how much data we provide
num_batches_to_see = 8_000
clip_max_norm=None

classifier_kwargs = dict(
    model='resnet', 
    hidden_features=50, 
    num_blocks=2,
    dropout_probability=0.0,
    use_batch_norm=True
)
sample_with = "rejection"

device = "cpu"
root = Path("figures")

In [None]:
kind = "gaussian"
if kind == "parabola":
    dimension = 3
    low = -torch.ones(dimension).mul(3.).sqrt().to(device)
    high = torch.ones(dimension).mul(3.).sqrt().to(device)
    limits = torch.stack([low.cpu(), high.cpu()], dim=-1).numpy()
    prior = sbi.utils.BoxUniform(low=low, high=high, device=device)
    task = cnre.Parabola(scale=0.1)
    simulate = task.simulate
    simulator, prior = prepare_for_sbi(simulate, prior)
    true_theta = torch.ones(dimension).to(device).unsqueeze(0)
    observation = simulator(true_theta)
elif kind == "gaussian":
    dimension = 3
    low = -torch.ones(dimension).mul(3.).sqrt().to(device)
    high = torch.ones(dimension).mul(3.).sqrt().to(device)
    limits = torch.stack([low.cpu(), high.cpu()], dim=-1).numpy()
    prior = sbi.utils.BoxUniform(low=low, high=high, device=device)
    task = cnre.Gaussian(scale=0.1)
    simulate = task.simulate
    simulator, prior = prepare_for_sbi(simulate, prior)
    true_theta = torch.ones(dimension).to(device).unsqueeze(0)
    observation = simulator(true_theta)
elif kind == "slcp":
    tt = sbibm.get_task("slcp")
    num_observation = 1
    simulator = tt.get_simulator()
    prior = tt.get_prior_dist()
    dimension = tt.dim_data
    observation = tt.get_observation(num_observation)
    true_theta = tt.get_true_parameters(num_observation)
else:
    raise NotImplementedError


theta = prior.sample((num_total_samples,)).to(device)
x = simulator(theta).to(device)

In [None]:
get_classifier = classifier_nn(**classifier_kwargs)
classifier = get_classifier(theta[:training_batch_size, ...], x[:training_batch_size, ...])
classifier.to(device)

# optimizer = torch.optim.SGD(classifier.parameters(), lr=lr, momentum=momentum)
optimizer = torch.optim.Adam(classifier.parameters(), lr=learning_rate)

## experiment

In [None]:
dataset = torch.utils.data.TensorDataset(theta, x)
train_set, valid_set = torch.utils.data.random_split(dataset, [num_training_samples, num_validation_samples])
train_loader = torch.utils.data.DataLoader(train_set, training_batch_size, shuffle=True, drop_last=True)
valid_loader = torch.utils.data.DataLoader(valid_set, training_batch_size, drop_last=True)

In [None]:
max_num_epochs = round(num_batches_to_see / len(train_loader))
max_num_epochs

In [None]:
torch.manual_seed(0)
l1 = cnre.loss_bce(classifier, theta[:256], x[:256])
torch.manual_seed(0)
l2 = cnre.loss(classifier, theta[:256], x[:256], 2, gamma=1.0, reuse=True)

# The issue here is that the BCE takes the mean of more terms, my version doesn't do that.
# l2 - l1.reshape(-1, 2).sum(dim=-1) * 0.5
# l2.mean() - l1.mean()
l2 - l1

In [None]:
def doit(num_atoms: int):
    gamma = 1.0
    results = cnre.algorithms.cnre.train(
        classifier, 
        optimizer, 
        max_num_epochs, 
        train_loader, 
        valid_loader, 
        num_atoms,
        gamma=gamma,
    )
    name = f"num atoms {num_atoms}"
    plt.plot(results['valid_losses'])
    plt.title(name)
    classifier.load_state_dict(results["best_network_state_dict"])
    posterior = cnre.get_sbi_posterior(
        ratio_estimator=classifier,
        prior=prior,
        sample_with="rejection", 
        mcmc_method="slice_np",
        mcmc_parameters={},
        rejection_sampling_parameters={},
    )
    samples = posterior.sample((num_posterior_samples,), x=observation.cpu())
    fig, _ = analysis.pairplot(
        samples.cpu().numpy(), 
        figsize=(6,6), 
        points=true_theta.cpu().numpy(),
        title=name,
        # limits=limits,
    )

In [None]:
name = f"ref"
fig, _ = analysis.pairplot(
    tt.get_reference_posterior_samples(num_observation), 
    figsize=(6,6), 
    points=true_theta.cpu().numpy(),
    title=name,
    # limits=limits,
)

### gilles

In [None]:
doit(num_atoms=2)

### ours

In [None]:
doit(num_atoms=10)