In [1]:
import numpy as np
import pandas as pd
import torch
_ = torch.manual_seed(10)
import os
import math
from sbi import utils as utils
import sbi
from sbi import inference
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
import scipy.io as sio

# Torch device that the data tensors and the prior bounds are moved to (via `.to(device)`)
# in later cells.
device = torch.device("cpu")

In [2]:
class CustomPriorDist:
    """Custom sbi prior made of nine independent 1-D ``BoxUniform`` components.

    Column ``k`` of a sample is drawn from attribute ``dist{k + 1}``. Bounds are
    derived from ``a`` and ``b``:

        columns 0, 1, 5, 6 : low=a, high=b
        columns 2, 7       : low=a, high=50 * b
        columns 3, 8       : low=a, high=500 * b
        column  4          : low=b, high=32 * b

    Args:
        a: Lower bound tensor handed to ``utils.BoxUniform``. Only index 0 of each
            component's event is used, so a 1-element tensor is expected (the
            notebook passes ``torch.zeros(1)``).
        b: Upper-bound tensor (also the lower bound of column 4); 1-element like ``a``.
        return_numpy: If True, ``sample`` and ``log_prob`` return NumPy arrays
            instead of tensors.
    """

    def __init__(self, a, b, return_numpy: bool = False):
        self.dist1 = utils.BoxUniform(a, b)
        self.dist2 = utils.BoxUniform(a, b)
        self.dist3 = utils.BoxUniform(a, 50 * b)
        self.dist4 = utils.BoxUniform(a, 500 * b)
        self.dist5 = utils.BoxUniform(b, 32 * b)
        self.dist6 = utils.BoxUniform(a, b)
        self.dist7 = utils.BoxUniform(a, b)
        self.dist8 = utils.BoxUniform(a, 50 * b)
        self.dist9 = utils.BoxUniform(a, 500 * b)
        self.return_numpy = return_numpy

    @property
    def _dists(self):
        """The nine component distributions in column order (dist1 ... dist9)."""
        return (
            self.dist1, self.dist2, self.dist3,
            self.dist4, self.dist5, self.dist6,
            self.dist7, self.dist8, self.dist9,
        )

    def sample(self, sample_shape=torch.Size([])):
        """Draw samples with one column per component.

        Args:
            sample_shape: ``()`` for a single draw or ``(n,)`` for ``n`` draws.
                Higher-rank shapes are also accepted.

        Returns:
            Tensor (NumPy array if ``return_numpy``) of shape ``(n, 9)`` for
            ``sample_shape == (n,)``, ``(1, 9)`` for ``sample_shape == ()``, and
            ``sample_shape + (9,)`` otherwise.
        """
        sample_shape = torch.Size(sample_shape)
        dists = self._dists
        # An unbatched draw keeps the historical (1, 9) shape instead of (9,).
        out_shape = sample_shape if len(sample_shape) > 0 else torch.Size([1])
        # Allocate with the default dtype/device (as the previous code did), then
        # fill column by column. Components are sampled in order dist1..dist9,
        # so a given seed yields the same draws as before.
        samples = torch.empty(*out_shape, len(dists))
        for k, dist in enumerate(dists):
            # Each draw has a trailing 1-element event dim; keep its index 0.
            samples[..., k] = dist.sample(sample_shape)[..., 0]
        return samples.numpy() if self.return_numpy else samples

    def log_prob(self, values):
        """Joint log-density: the sum of the nine per-column component log-densities.

        Args:
            values: Tensor or NumPy array of shape ``(n, 9)``. A single ``(9,)``
                row is treated as ``(1, 9)``.

        Returns:
            Shape ``(n,)`` tensor (NumPy array if ``return_numpy``). For values
            outside a component's bounds, the result is whatever that component's
            ``log_prob`` yields (e.g. ``-inf`` when it does not validate/raise).
        """
        # Convert before touching the shape. For a NumPy array ``.size`` is an
        # int, so the former ``values.size()`` call raised TypeError for numpy input.
        values = torch.as_tensor(values)
        if values.dim() == 1:
            values = values.unsqueeze(0)
        # Vectorised over rows. Column k is sliced as (n, 1) so it matches each
        # component's 1-element event, and is scored by its own component.
        per_column = torch.stack(
            [dist.log_prob(values[:, k:k + 1]) for k, dist in enumerate(self._dists)],
            dim=-1,
        )
        log_probs = per_column.sum(dim=-1)
        return log_probs.numpy() if self.return_numpy else log_probs

In [3]:
# MATLAB file holding the precomputed prior draws, their simulations and the observation.
INITIAL_DATA_FILE = "Neural_method_initial.mat"
initial_data = sio.loadmat(file_name=INITIAL_DATA_FILE)

In [4]:
# Pull the parameter draws, the matching simulations and the observed data set
# out of the loaded .mat dictionary (keys as stored in the file).
theta_np, sims_np, obs_np = (
    initial_data[key] for key in ("part_vals_prior", "prior_pred_sim", "syn4")
)

In [5]:
# Convert each NumPy array to a float32 tensor on the configured device.
theta, x, x_0 = (
    torch.from_numpy(array).to(torch.float32).to(device)
    for array in (theta_np, sims_np, obs_np)
)

In [6]:
# Nine-component uniform prior built from a = 0 and b = 1 (1-element tensors on `device`).
priorDist = CustomPriorDist(
    torch.zeros(1, device=device),
    torch.ones(1, device=device),
)


def simulator(theta):
    """Identity placeholder simulator: returns ``theta`` moved to ``device``.

    Only needed so ``prepare_for_sbi`` can process the custom prior. The
    simulations actually used for training are loaded from the .mat file.
    """
    return theta.to(device)


# Keep only the sbi-processed prior; the wrapped simulator is not used.
_, prior = prepare_for_sbi(simulator, priorDist)

            and / or upper_bound if your prior has bounded support.
                samples...


In [7]:
# Train an SNLE likelihood estimator with the 'nsf' density estimator on the
# precomputed (theta, x) pairs, then wrap it in a posterior.
# NOTE(review): this rebinds `inference`, shadowing the `sbi.inference` module
# imported at the top of the notebook.
inference = sbi.inference.SNLE(prior=prior, density_estimator='nsf')
inference.append_simulations(theta, x)
density_estimator = inference.train()
posterior = inference.build_posterior(density_estimator)

 Neural network successfully converged after 200 epochs.

In [None]:
# Posterior draws conditioned on the observation x_0, produced by MCMC.
# The recorded output shows roughly 4-5 s per iteration, so this cell runs for hours.
NUM_POSTERIOR_SAMPLES = 10000
NUM_MCMC_CHAINS = 10
posterior_samples = posterior.sample(
    (NUM_POSTERIOR_SAMPLES,),
    x=x_0,
    num_chains=NUM_MCMC_CHAINS,
)

Running 10 MCMC chains in 10 batches.:   0%|          | 0/10 [00:00<?, ?it/s]
  0%|          | 0/50 [00:00<?, ?it/s][A
Tuning bracket width...:   0%|          | 0/50 [00:00<?, ?it/s][A
Tuning bracket width...:  20%|██        | 10/50 [00:13<00:55,  1.39s/it][A
Tuning bracket width...:  40%|████      | 20/50 [00:16<00:31,  1.06s/it][A
Tuning bracket width...:  60%|██████    | 30/50 [00:20<00:16,  1.18it/s][A
Tuning bracket width...:  80%|████████  | 40/50 [00:23<00:06,  1.47it/s][A
Tuning bracket width...: 100%|██████████| 50/50 [00:25<00:00,  1.94it/s][A

  0%|          | 0/10 [00:00<?, ?it/s][A
Generating samples:   0%|          | 0/10 [00:00<?, ?it/s][A
Generating samples:   0%|          | 0/10 [00:14<?, ?it/s][A
Generating samples:  50%|█████     | 5/10 [00:18<00:18,  3.62s/it][A
Generating samples:  60%|██████    | 6/10 [00:22<00:15,  3.81s/it][A
Generating samples:  70%|███████   | 7/10 [00:26<00:11,  3.90s/it][A
Generating samples:  80%|████████  | 8/10 [00:30<00:07, 

Generating samples:  20%|██        | 203/1000 [14:29<55:47,  4.20s/it][A
Generating samples:  20%|██        | 204/1000 [14:33<54:52,  4.14s/it][A
Generating samples:  20%|██        | 205/1000 [14:37<55:57,  4.22s/it][A
Generating samples:  21%|██        | 206/1000 [14:41<56:37,  4.28s/it][A
Generating samples:  21%|██        | 207/1000 [14:46<57:05,  4.32s/it][A
Generating samples:  21%|██        | 208/1000 [14:50<57:15,  4.34s/it][A
Generating samples:  21%|██        | 209/1000 [14:54<56:13,  4.27s/it][A
Generating samples:  21%|██        | 210/1000 [14:58<55:32,  4.22s/it][A
Generating samples:  21%|██        | 211/1000 [15:03<55:36,  4.23s/it][A
Generating samples:  21%|██        | 212/1000 [15:07<56:06,  4.27s/it][A
Generating samples:  21%|██▏       | 213/1000 [15:11<56:38,  4.32s/it][A
Generating samples:  21%|██▏       | 214/1000 [15:16<57:11,  4.37s/it][A
Generating samples:  22%|██▏       | 215/1000 [15:20<57:47,  4.42s/it][A
Generating samples:  22%|██▏       | 2

Generating samples:  42%|████▏     | 423/1000 [30:46<40:34,  4.22s/it][A
Generating samples:  42%|████▏     | 424/1000 [30:51<42:27,  4.42s/it][A
Generating samples:  42%|████▎     | 425/1000 [30:56<42:33,  4.44s/it][A
Generating samples:  43%|████▎     | 426/1000 [31:00<41:42,  4.36s/it][A
Generating samples:  43%|████▎     | 427/1000 [31:04<41:07,  4.31s/it][A
Generating samples:  43%|████▎     | 428/1000 [31:09<41:16,  4.33s/it][A
Generating samples:  43%|████▎     | 429/1000 [31:13<41:49,  4.40s/it][A
Generating samples:  43%|████▎     | 430/1000 [31:17<41:13,  4.34s/it][A
Generating samples:  43%|████▎     | 431/1000 [31:21<40:29,  4.27s/it][A
Generating samples:  43%|████▎     | 432/1000 [31:26<40:16,  4.25s/it][A
Generating samples:  43%|████▎     | 433/1000 [31:30<39:35,  4.19s/it][A
Generating samples:  43%|████▎     | 434/1000 [31:34<39:24,  4.18s/it][A
Generating samples:  44%|████▎     | 435/1000 [31:38<39:50,  4.23s/it][A
Generating samples:  44%|████▎     | 4

Generating samples:  64%|██████▍   | 643/1000 [46:31<27:03,  4.55s/it][A
Generating samples:  64%|██████▍   | 644/1000 [46:36<27:07,  4.57s/it][A
Generating samples:  64%|██████▍   | 645/1000 [46:40<26:47,  4.53s/it][A
Generating samples:  65%|██████▍   | 646/1000 [46:45<26:24,  4.48s/it][A
Generating samples:  65%|██████▍   | 647/1000 [46:49<26:09,  4.45s/it][A
Generating samples:  65%|██████▍   | 648/1000 [46:54<27:17,  4.65s/it][A
Generating samples:  65%|██████▍   | 649/1000 [46:59<27:31,  4.70s/it][A
Generating samples:  65%|██████▌   | 650/1000 [47:04<27:23,  4.70s/it][A
Generating samples:  65%|██████▌   | 651/1000 [47:08<26:57,  4.64s/it][A
Generating samples:  65%|██████▌   | 652/1000 [47:13<27:21,  4.72s/it][A
Generating samples:  65%|██████▌   | 653/1000 [47:18<27:33,  4.77s/it][A
Generating samples:  65%|██████▌   | 654/1000 [47:22<26:39,  4.62s/it][A
Generating samples:  66%|██████▌   | 655/1000 [47:27<26:13,  4.56s/it][A
Generating samples:  66%|██████▌   | 6

Generating samples:  86%|████████▌ | 862/1000 [1:03:04<10:22,  4.51s/it][A
Generating samples:  86%|████████▋ | 863/1000 [1:03:09<10:10,  4.46s/it][A
Generating samples:  86%|████████▋ | 864/1000 [1:03:13<10:16,  4.53s/it][A
Generating samples:  86%|████████▋ | 865/1000 [1:03:18<10:14,  4.55s/it][A
Generating samples:  87%|████████▋ | 866/1000 [1:03:22<09:49,  4.40s/it][A
Generating samples:  87%|████████▋ | 867/1000 [1:03:26<09:42,  4.38s/it][A
Generating samples:  87%|████████▋ | 868/1000 [1:03:30<09:30,  4.32s/it][A
Generating samples:  87%|████████▋ | 869/1000 [1:03:35<09:21,  4.28s/it][A
Generating samples:  87%|████████▋ | 870/1000 [1:03:39<09:07,  4.21s/it][A
Generating samples:  87%|████████▋ | 871/1000 [1:03:43<09:08,  4.25s/it][A
Generating samples:  87%|████████▋ | 872/1000 [1:03:47<09:01,  4.23s/it][A
Generating samples:  87%|████████▋ | 873/1000 [1:03:52<09:05,  4.30s/it][A
Generating samples:  87%|████████▋ | 874/1000 [1:03:56<09:04,  4.32s/it][A
Generating s

Generating samples:   6%|▌         | 60/1000 [05:20<1:25:43,  5.47s/it][A
Generating samples:   6%|▌         | 61/1000 [05:25<1:24:40,  5.41s/it][A
Generating samples:   6%|▌         | 62/1000 [05:31<1:25:07,  5.45s/it][A
Generating samples:   6%|▋         | 63/1000 [05:36<1:24:43,  5.43s/it][A
Generating samples:   6%|▋         | 64/1000 [05:42<1:24:32,  5.42s/it][A
Generating samples:   6%|▋         | 65/1000 [05:47<1:23:06,  5.33s/it][A
Generating samples:   7%|▋         | 66/1000 [05:52<1:23:32,  5.37s/it][A
Generating samples:   7%|▋         | 67/1000 [05:58<1:24:00,  5.40s/it][A
Generating samples:   7%|▋         | 68/1000 [06:03<1:22:55,  5.34s/it][A
Generating samples:   7%|▋         | 69/1000 [06:08<1:23:06,  5.36s/it][A
Generating samples:   7%|▋         | 70/1000 [06:13<1:22:26,  5.32s/it][A
Generating samples:   7%|▋         | 71/1000 [06:19<1:23:19,  5.38s/it][A
Generating samples:   7%|▋         | 72/1000 [06:24<1:22:25,  5.33s/it][A
Generating samples:   7%|

In [None]:
from sbi import analysis

# Pairwise marginal plot of the posterior draws; the return value is discarded.
_ = analysis.pairplot(posterior_samples, figsize=(10, 10))

In [None]:
def _to_numpy(tensor):
    """Return ``tensor`` as a NumPy array, whatever its device or autograd state.

    ``Tensor.numpy()`` only accepts CPU tensors that do not require grad. The data
    tensors were moved to the configurable ``device``, so detach and copy to the
    CPU first.
    """
    return tensor.detach().cpu().numpy()


# Save the new posterior draws together with the training data and the observation.
mdic = {
    "theta_new": _to_numpy(posterior_samples),
    "theta_old": _to_numpy(theta),
    "sims_old": _to_numpy(x),
    "observation": _to_numpy(x_0),
}
sio.savemat("SNLE_nsf_syn4_20k.mat", mdic)