<a href="https://colab.research.google.com/github/jecampagne/ML-toys/blob/main/Test_SBI_Pytorch_UserExo_v0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [102]:
!python --version   # >= 3.7

Python 3.7.15


In [103]:
!pip install --upgrade --quiet sbi

In [3]:
# Replace the preinstalled matplotlib 3.2 with version 3.5.0.
!pip uninstall -y matplotlib
!pip install -q matplotlib==3.5.0
# The runtime must be restarted after this cell so that the new version is picked up;
# then continue with the next cell.

Found existing installation: matplotlib 3.2.2
Uninstalling matplotlib-3.2.2:
  Successfully uninstalled matplotlib-3.2.2
[K     |████████████████████████████████| 11.2 MB 4.4 MB/s 
[K     |████████████████████████████████| 42 kB 1.0 MB/s 
[K     |████████████████████████████████| 960 kB 44.4 MB/s 
[?25h

In [104]:
import matplotlib as mpl
import matplotlib.pyplot as plt

print(mpl.__version__)

3.5.0


In [105]:
import torch
import numpy as np

from sbi.inference import SNPE, SNLE, SNRE, prepare_for_sbi, simulate_for_sbi
from sbi.utils.get_nn_models import posterior_nn
from sbi import utils as utils

In [106]:
# Pick the torch device. As stated in the sbi code, the GPU brings no speed-up here.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
device

'cpu'

In [107]:
import pyro


In [108]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [109]:
cd /content/gdrive/MyDrive/

/content/gdrive/MyDrive


In [110]:
mkdir SBIimg

mkdir: cannot create directory ‘SBIimg’: File exists


In [111]:
cd /content/gdrive/MyDrive/SBIimg

/content/gdrive/MyDrive/SBIimg


In [112]:
pwd

'/content/gdrive/MyDrive/SBIimg'

In [113]:
import arviz as az
import matplotlib.patches as mpatches

def overplot_lines(axes, xs, reverse=False, **kwargs):
    """
    Draw reference lines on a corner-style grid of axes.

    For every value in ``xs`` a vertical line is drawn on the matching
    diagonal panel and on the panels of the same column; a horizontal line is
    drawn on the panels of the matching row.

    Parameters
    ----------
    axes : 2-D grid of axes
        Object indexed as ``axes[row, col]`` whose items provide ``axvline``
        and ``axhline`` (for example the array of axes of a corner plot).
    xs : array_like[ndim]
        The values where the lines should be plotted, one per plotted
        dimension. An entry can be ``None`` to omit the lines for that
        dimension.
    reverse : bool
        Set to ``True`` if the corner plot itself was drawn reversed; the
        panel indices are then mirrored (``ndim - k - 1`` instead of ``k``).
    **kwargs
        Any remaining keyword arguments are passed to both ``axvline`` and
        ``axhline``.
    """
    ndim = len(xs)

    def _grid_index(k):
        # Position of dimension k along either grid axis.
        return ndim - k - 1 if reverse else k

    for col_dim, col_value in enumerate(xs):
        col = _grid_index(col_dim)

        # Diagonal panel: vertical line only.
        if col_value is not None:
            axes[col, col].axvline(col_value, **kwargs)

        # Off-diagonal panels sharing this column.
        for row_dim in range(col_dim + 1, ndim):
            row = _grid_index(row_dim)
            row_value = xs[row_dim]
            if col_value is not None:
                axes[row, col].axvline(col_value, **kwargs)
            if row_value is not None:
                axes[row, col].axhline(row_value, **kwargs)

def plot_params_kde(samples,hdi_probs=[0.393, 0.865, 0.989], 
                    patName=None, fname=None, pcut=None, reference_values=None, 
                    reference_color='k', label_size=10,labeller=None, limits=None,
                   var_names=None, point_estimate="median", figsize=(8,8)):
    """
    Draw a corner-style KDE pair plot of ``samples`` with ``az.plot_pair``.

    Parameters
    ----------
    samples : dict
        Maps a variable name to its 1-D array of samples.
    hdi_probs : list of float
        HDI probabilities of the contours of the 2-D KDE panels
        (the defaults are the 1, 2 and 3 sigma contours).
    patName : str, optional
        If given, used as the figure title.
    fname : str, optional
        If given, the figure is saved to this path and closed; otherwise it
        is shown.
    pcut : sequence of two floats, optional
        ``(low, up)`` percentiles, e.g. ``(0.5, 99.5)``. For each variable
        only the samples strictly inside these percentiles are kept, and all
        variables are then truncated to the shortest remaining length.
    reference_values : dict, optional
        Reference value of each variable; drawn by ``az.plot_pair`` and
        overplotted as lines. The values are used in dict order, which is
        assumed to be the order of the plotted variables.
    reference_color : color
        Color of the reference markers and lines.
    label_size : int
        Passed to ``az.plot_pair`` as ``textsize``.
    labeller, var_names, point_estimate, figsize
        Passed unchanged to ``az.plot_pair``.
    limits : list, optional
        ``[[min_1, max_1], ..., [min_N, max_N]]``, one pair per plotted
        variable; applied to the axes whether or not ``reference_values`` is
        given.
    """
    if pcut is not None:
        low, up = pcut[0], pcut[1]
        # keep only data strictly inside the [low, up] percentiles, e.g. 0.5, 99.5
        samples = {
            name: value[(value > np.percentile(value, low)) & (value < np.percentile(value, up))]
            for name, value in samples.items()
        }
        lengths = [len(value) for value in samples.values()]
        len_min = np.min(lengths)
        len_max = np.max(lengths)
        # the cut removes a different number of points per variable; warn when
        # the truncation below drops more than 1% of the longest chain
        if (len_max-len_min)>0.01*len_max:
            print(f"Warning: pcut leads to min/max spls size = {len_min}/{len_max}")
        samples = {name: value[:len_min] for name, value in samples.items()}

    axs = az.plot_pair(
        samples,
        var_names=var_names,
        kind="kde",
        labeller=labeller,
        figsize=figsize,
        marginal_kwargs={"plot_kwargs": {"linewidth": 2, "c": "b"}},
        kde_kwargs={
            "hdi_probs": hdi_probs,
            # these kwargs end up in ax.contour, whose keyword is "linewidths"
            # (plural); "linewidth" is ignored with a warning
            "contour_kwargs": {"colors": ('r', 'green', 'blue'), "linewidths": 2},
            # fully transparent fill: only the contour lines are visible
            "contourf_kwargs": {"alpha": 0},
        },
        point_estimate_kwargs={"lw": 2, "c": "b"},
        marginals=True, textsize=label_size, point_estimate=point_estimate,
        reference_values=reference_values, reference_values_kwargs={"c": reference_color},
    )

    if reference_values is not None:
        overplot_lines(axs, list(reference_values.values()), color=reference_color)

    # Axis limits are independent of the reference values.
    if limits is not None:
        n_vars = axs.shape[0]
        # one [min, max] pair is needed per plotted variable (row/column of the grid)
        assert len(limits) == n_vars, "wrong number of limits"
        for i in range(n_vars):
            for j in range(i + 1):
                # the x axis of column j shows variable j; off-diagonal panels
                # additionally show variable i on their y axis
                axs[i, j].set_xlim(limits[j])
                if j != i:
                    axs[i, j].set_ylim(limits[i])

    plt.tight_layout()

    if patName is not None:
        fig = axs[0, 0].get_figure()
        fig.suptitle(patName)

    if fname is not None:
        plt.savefig(fname)
        plt.close()
    else:
        plt.show()




In [114]:
# Shared uniform prior: every parameter is drawn independently from [-3, 3]
theta_dim = 3
prior = utils.BoxUniform(
    low=torch.ones(theta_dim) * -3,
    high=torch.ones(theta_dim) * 3,
)
# Matching plot limits, one [min, max] pair per parameter
limits = [[-3, 3] for _ in range(theta_dim)]

In [115]:
import pyro.distributions as dist

In [116]:
# Fixed measurement positions: 10 evenly spaced points from 0 to 1 (both included)
tMes = torch.linspace(0,1,10)

In [125]:
def _calc_vars(theta):
  #fixed positions (tMes) & fixed independant error
  mu = theta[...,0] + theta[...,1]*tMes + theta[...,2]*tMes*tMes
  sigma = 0.05
  return mu, sigma
def simulator(theta):
  """Draw one noisy observation vector for the parameters ``theta``.

  The observation is sampled from a Normal distribution whose mean and
  scale are given by ``_calc_vars(theta)``.
  """
  loc, scale = _calc_vars(theta)
  noise_model = dist.Normal(loc=loc, scale=scale)
  return noise_model.sample()



In [133]:
true_theta = np.array([0.7, -1.9, 1.5])
pyro.set_rng_seed(0)

In [134]:
# Draw one observation x_o from the simulator at the true underlying parameters true_theta;
# [None] adds a leading axis to the simulated vector.
x_o = simulator(torch.tensor(true_theta))[None]

In [135]:
x_o

tensor([[0.7770, 0.4927, 0.2429, 0.2618, 0.0976, 0.0375, 0.1202, 0.1715, 0.1603,
         0.2798]])

In [136]:
#adapt/check the prior & simulator for SBI
simulator, prior = prepare_for_sbi(simulator, prior)

# 1 round Optimisation

In [137]:
keys = ["t"+str(i) for i in range(theta_dim)]
truth = dict(zip(keys,true_theta))


In [139]:
# 1 pass / method => 13min in total
# Number of simulations drawn from the prior for each method
num_sim = 10_000
for name in ["SNPE","SNLE","SNRE"]:
  print("Inference :",name)
  # Every method must bind its own fresh object to `inference`: binding SNRE to
  # another name would silently re-train and sample the SNLE object left over
  # from the previous iteration.
  if name == "SNPE":
    inference = SNPE(prior=prior) #"SNPE" as SNPE_C 
  elif name == "SNLE":
    inference = SNLE(prior=prior) #"SNLE" as SNLE_A 
  elif name == "SNRE":
    inference = SNRE(prior=prior) #"SNRE" as SNRE_B

  #simulate
  theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=num_sim)
  inference = inference.append_simulations(theta, x)
  #train the inference network/flow parameters
  density_estimator = inference.train()
  #build the posterior estimate p(theta|X)
  posterior = inference.build_posterior(density_estimator)
  #Sample the posterior with the constraint x=x_o
  if isinstance(inference, (SNLE, SNRE)):
    # the SNLE/SNRE posteriors are sampled with vectorized slice-sampling MCMC
    spls = posterior.sample((10_000,), x=x_o, num_chains=100,  method="slice_np_vectorized")
  else:
    spls = posterior.sample((10_000,), x=x_o)  

  #plot: one column of samples per parameter
  values = [spls[:,i]for i in range(theta_dim)]
  data = dict(zip(keys,values))
  plot_params_kde(data,var_names=keys, figsize=(8,8), limits=None,
    point_estimate=None, reference_values=truth, reference_color='r',
    patName=name, fname='./user_'+name+'_1obs_1round_limitsOff.pdf');
  

Inference : SNPE


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]

 Neural network successfully converged after 71 epochs.

Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


Inference : SNLE


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]

 Neural network successfully converged after 81 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


Inference : SNRE


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]

 Neural network successfully converged after 55 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


# Multi-round inference
Not necessary for this example, and it takes about 1 h; it is included only to show how to proceed.

In [56]:
# multi rounds: the first round simulates from the prior, the following rounds simulate
# parameter sets sampled from the posterior obtained in the previous round.
num_rounds = 3 # => ~ 1h for the all 3 methods
# The specific observation we want to focus the inference on is x_o (single)
num_simu = 10_000


# NOTE(review): prepare_for_sbi was already applied in an earlier cell; presumably
# repeated here so that this section can be run on its own -- confirm.
simulator, prior = prepare_for_sbi(simulator, prior)
for name in ["SNPE", "SNLE","SNRE"]:
  print("Inference :",name)
  # Every method must bind its own fresh object to `inference`: binding SNRE to
  # another name would silently keep training the SNLE object left over from
  # the previous iteration.
  if name == "SNPE":
    inference = SNPE(prior=prior, device=device) #"SNPE" as SNPE_C 
  elif name == "SNLE":
    inference = SNLE(prior=prior, device=device) #"SNLE" as SNLE_A 
  elif name == "SNRE":
    inference = SNRE(prior=prior, device=device) #"SNRE" as SNRE_B

  posteriors = []
  proposal = prior

  for i in range(num_rounds):
    if i==0:
      # first round: simulate from the prior
      theta, x = simulate_for_sbi(simulator, proposal, num_simulations=num_simu)
    else:
      # later rounds: draw theta from the previous posterior (conditioned on x_o)
      if isinstance(inference,SNPE):
        theta = proposal.sample((num_simu,))
      else:
        theta = proposal.sample((num_simu,), method="slice_np_vectorized", num_chains=100)
      x = simulator(theta)

    # In `SNLE` and `SNRE`, you should not pass the `proposal` to `.append_simulations()`
    if isinstance(inference,SNPE):
      density_estimator = inference.append_simulations(
          theta, x, proposal=proposal
      ).train()
    else:
      density_estimator = inference.append_simulations(
          theta, x
      ).train()

    posterior = inference.build_posterior(density_estimator)
    posteriors.append(posterior)  
    # the posterior focused on x_o becomes the proposal of the next round
    proposal = posterior.set_default_x(x_o)

  # sample with tuned posterior
  if isinstance(inference, (SNLE, SNRE)):
    spls = posterior.sample((10_000,), x=x_o, num_chains=100,  method="slice_np_vectorized")
  else:
    spls = posterior.sample((10_000,), x=x_o)  


  # one column of samples per parameter, saved to disk and plotted
  values = [spls[:,i]for i in range(theta_dim)]
  data = dict(zip(keys,values))
  np.save("./user_"+name+"_data.npy",np.array(values))

  plot_params_kde(data,var_names=keys, figsize=(8,8), limits=None,
      point_estimate=None, reference_values=truth, reference_color='r',
      patName=name, fname='./user_'+name+'_1obs_'+str(num_rounds)+'rounds_limitsOff.pdf');


Inference : SNPE


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]

 Neural network successfully converged after 113 epochs.

Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

Using SNPE-C with atomic loss
 Neural network successfully converged after 42 epochs.

Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

Using SNPE-C with atomic loss
 Neural network successfully converged after 28 epochs.

Drawing 10000 posterior samples:   0%|          | 0/10000 [00:00<?, ?it/s]

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


Inference : SNLE


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]

 Neural network successfully converged after 78 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

 Neural network successfully converged after 33 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

 Neural network successfully converged after 41 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


Inference : SNRE


Running 10000 simulations.:   0%|          | 0/10000 [00:00<?, ?it/s]

 Neural network successfully converged after 48 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

 Neural network successfully converged after 25 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

 Neural network successfully converged after 34 epochs.

Running vectorized MCMC with 100 chains:   0%|          | 0/110000 [00:00<?, ?it/s]

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


# Direct sampling with the true likelihood

## Try with Pyro: takes too long 

In [169]:
def wrapper_log_posterior(theta):
  """Potential function for NUTS: minus the unnormalised log posterior at ``x_o``.

  Parameters
  ----------
  theta : dict
      Holds the parameter tensor under the key ``'t'``.

  Returns
  -------
  tensor
      ``-(log p(x_o | theta) + log prior(theta))``, summed over all
      observation points and parameters.
  """
  theta = theta['t']
  # Reuse the observation model of the simulator instead of a private copy,
  # so that the likelihood cannot drift from what was simulated.
  mu, sigma = _calc_vars(theta)
  log_like = dist.Normal(mu, sigma).log_prob(x_o).sum()
  log_prior_value = prior.log_prob(theta).sum()
  return -(log_like + log_prior_value)


In [170]:
wrapper_log_posterior({'t': torch.zeros(size=(1,theta_dim))})

tensor(211.1931)

In [171]:
kernel = pyro.infer.NUTS(
    potential_fn=wrapper_log_posterior
)


In [172]:
# Pyro MCMC driver for the NUTS kernel: a single chain, 10_000 samples after 500 warm-up steps.
num_chains = 1
mcmc = pyro.infer.MCMC(
        kernel,
        num_samples=10_000,
        warmup_steps=500,
        num_chains=num_chains,
        # start at theta = 0: one row per chain, one column per parameter
        initial_params={'t': torch.zeros(size=(num_chains,theta_dim))}
)

In [173]:
mcmc.run()

Sample: 100%|██████████| 10500/10500 [04:21, 40.16it/s, step size=7.44e-02, acc. prob=0.954]


In [146]:
mcmc.summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
    t[0,0]      0.72      0.08      0.72      0.60      0.85   2451.57      1.00
    t[0,1]     -2.11      0.35     -2.11     -2.71     -1.55   2299.28      1.00
    t[0,2]      1.70      0.34      1.70      1.16      2.28   2467.42      1.00

Number of divergences: 281


In [174]:
spls_0 = mcmc.get_samples()['t'].numpy().squeeze()
keys = ["t"+str(i) for i in range(theta_dim)]
values = [spls_0[:,i]for i in range(theta_dim)]
data_0 = dict(zip(keys,values))


In [175]:
plot_params_kde(data_0, figsize=(5,5),limits=None,
  point_estimate=None, reference_values=truth, reference_color='r',
  fname='./user0_true_posterior_pyro.pdf');

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)


## Use JAX/Numpyro for fun

In [149]:
! pip install -q numpyro

[?25l[K     |█▏                              | 10 kB 21.5 MB/s eta 0:00:01[K     |██▎                             | 20 kB 5.5 MB/s eta 0:00:01[K     |███▍                            | 30 kB 7.8 MB/s eta 0:00:01[K     |████▌                           | 40 kB 3.3 MB/s eta 0:00:01[K     |█████▋                          | 51 kB 3.5 MB/s eta 0:00:01[K     |██████▊                         | 61 kB 4.1 MB/s eta 0:00:01[K     |███████▉                        | 71 kB 4.3 MB/s eta 0:00:01[K     |█████████                       | 81 kB 4.9 MB/s eta 0:00:01[K     |██████████                      | 92 kB 4.9 MB/s eta 0:00:01[K     |███████████▏                    | 102 kB 4.0 MB/s eta 0:00:01[K     |████████████▎                   | 112 kB 4.0 MB/s eta 0:00:01[K     |█████████████▍                  | 122 kB 4.0 MB/s eta 0:00:01[K     |██████████████▌                 | 133 kB 4.0 MB/s eta 0:00:01[K     |███████████████▊                | 143 kB 4.0 MB/s eta 0:00:01[K    

In [150]:
import jax
import jax.numpy as jnp
import numpyro

In [176]:
seed = 1234
rng, model_rng, hmc_rng = jax.random.split(jax.random.PRNGKey(seed), num=3)

In [177]:
j_xo = jnp.array(x_o.numpy())

In [178]:
j_xo

DeviceArray([[0.7770498 , 0.49273592, 0.24291235, 0.2617549 , 0.09762572,
              0.03747765, 0.12016737, 0.17153093, 0.1603334 , 0.27983278]],            dtype=float32)

In [179]:
def SmoothedBoxPrior(theta_dim=5, lower=0.0, upper=1.0, sigma=0.1, variance=False):
    """Build a box prior with smooth (Gaussian) edges.

    Parameters
    ----------
    theta_dim : int
        Number of parameters; only used by ``sample``.
    lower, upper : float
        Bounds of the box, with ``lower < upper``.
    sigma : float, optional
        Width of the Gaussian fall-off outside the box. Pass ``None`` (or 0)
        when ``variance`` is given instead.
    variance : float, optional
        Variance of the Gaussian fall-off; alternative to ``sigma``. The
        default ``False`` means "not given".

    Returns
    -------
    log_prob, sample : callables
        ``log_prob(theta)`` is the element-wise (unnormalised) log density and
        ``sample(rng, num_samples)`` draws from the hard uniform box.

    Raises
    ------
    AssertionError
        If the bounds are not ordered, if both or neither of ``sigma`` and
        ``variance`` are given, or if the given one is not positive.
    """
    assert jnp.all(lower < upper), "lower must be less than upper"

    # None, False and 0 all mean "not given". Exactly one of the two widths
    # must be given; only that one is validated, which makes the `variance`
    # path usable (validating sigma unconditionally made it unreachable).
    sigma_given = bool(sigma)
    variance_given = bool(variance)
    assert sigma_given != variance_given, "specify only one of sigma and variance"

    if variance_given:
        assert jnp.all(variance > 0), "variance must be greater than zero"
    else:
        assert jnp.all(sigma > 0), "sigma must be greater than zero"
        variance = sigma ** 2

    _center = (upper + lower) / 2.0
    _range = (upper - lower) / 2.0

    def log_prob(theta):
        """Inspired by SmoothedBoxPrior From GPyTorch
        If theta is inside the bounds, return constant.
        If theta is outside the bounds, return log prob from sharp normal
        Can accomplish this saying the distance from the edges of the theta range
        is sampled from a normal distribution (clipped at zero to not go negative)
        """
        _theta_dist = jnp.clip(jnp.abs(theta - _center) - _range, 0, None)
        return -0.5 * (_theta_dist ** 2 / variance + jnp.log(2 * jnp.pi * variance))

    def sample(rng, num_samples: int = 1):
        """
        Samples are taken from a hard uniform distribution between the bounds
        """
        return jax.random.uniform(
            rng, shape=(num_samples, theta_dim), minval=lower, maxval=upper
        )

    return log_prob, sample

In [180]:
# set up prior
log_prior, sample_prior = SmoothedBoxPrior(
    theta_dim=theta_dim, lower=-3.0, upper=3.0, sigma=0.02
)


In [181]:
j_tMes = jnp.array(tMes.numpy())

In [182]:
def _jax_calc_vars(theta: jnp.array):
    #fixed positions (tMes) & fixed independant error
    mu = theta[...,0] + theta[...,1]*j_tMes + theta[...,2]*j_tMes*j_tMes
    sigma = 0.05
    return mu, sigma

In [189]:
?jax.scipy.stats.norm.logpdf

In [190]:
def jax_log_likelihood(x: jnp.ndarray, theta: jnp.ndarray):
    """
    Log likelihood of the data ``x`` given the parameters ``theta``.

    The observation model is a Normal distribution whose mean and scale come
    from ``_jax_calc_vars(theta)``; the point-wise log densities are summed
    into a single scalar.
    """
    loc, scale = _jax_calc_vars(theta)
    pointwise = jax.scipy.stats.norm.logpdf(x, loc=loc, scale=scale)
    return pointwise.sum()

def jax_wrapper_log_posterior(theta):
    """Minus the unnormalised log posterior of ``theta`` for the observation ``j_xo``."""
    log_like = jax_log_likelihood(j_xo, theta)
    log_prior_total = log_prior(theta).sum()
    return -(log_like + log_prior_total)


In [191]:
# Number of chains used by the MCMC driver below. It must be defined under the
# name that is actually used (`num_chains`); otherwise the cell only runs when
# a variable of that name is left over from the Pyro section.
num_chains = 1
# NUTS kernel driven directly by the potential (minus log posterior), with a dense mass matrix
kernel = numpyro.infer.NUTS(
    potential_fn=jax_wrapper_log_posterior, dense_mass=True
)

# 10_000 samples after 500 warm-up steps
mcmc = numpyro.infer.MCMC(
        kernel,
        num_samples=10_000,
        num_warmup=500,
        chain_method='vectorized',
        num_chains=num_chains,
        progress_bar=True,
)

In [192]:
mcmc.run(hmc_rng, init_params=jnp.zeros_like(true_theta))

sample: 100%|██████████| 10500/10500 [00:19<00:00, 552.33it/s, 7 steps of size 7.18e-01. acc. prob=0.93]


In [193]:
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
Param:0[0]      0.73      0.04      0.73      0.66      0.79   7357.54      1.00
Param:0[1]     -2.12      0.18     -2.12     -2.41     -1.82   8991.09      1.00
Param:0[2]      1.70      0.17      1.70      1.42      1.98   9405.11      1.00

Number of divergences: 0


In [194]:
spls_0 = mcmc.get_samples()
values = [spls_0[:,i]for i in range(theta_dim)]
data_0 = dict(zip(keys,values))


In [195]:
spls_0.shape

(10000, 3)

In [196]:
plot_params_kde(data_0, figsize=(8,8),limits=None,
  point_estimate=None, reference_values=truth, reference_color='r',
  fname='./user0_true_posterior_numpyro.pdf');

  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
  qcs = ax.contour(x_x, y_y, density, **contour_kwargs)
