# Save and load data
Use a prior and a simulator to generate a synthetic dataset of (parameter, simulated observation) pairs. Save one such set for training and a second, independently generated set of the same size for validation.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
# Hide the top and right spines on every figure drawn in this notebook.
matplotlib.rcParams.update({"axes.spines.right": False, "axes.spines.top": False})

In [2]:
import sbi
from sbi.inference import SNPE
from sbi.inference.base import infer
from sbi.analysis import pairplot
import torch

In [3]:
from src.scripts.io import DataLoader

In [4]:
def simulator(thetas, sigma=5, seed=None):
    """Simulate noisy straight lines y = m * x + b + eps, one per parameter set.

    Parameters
    ----------
    thetas : array-like or CPU torch.Tensor, shape (n, 2) or (2,)
        Rows of (slope m, intercept b). A single length-2 vector is treated as n = 1.
    sigma : float, optional
        Standard deviation of the i.i.d. Gaussian noise added to every point (default 5).
    seed : int or None, optional
        Seed for the noise generator. ``None`` (default) draws fresh entropy on every
        call, so repeated calls give different noise; pass an int for reproducible output.

    Returns
    -------
    torch.Tensor, shape (n, 101)
        One simulated line per row of ``thetas``, evaluated on x = 0, 1, ..., 100,
        as a default-dtype (normally float32) tensor.

    Raises
    ------
    ValueError
        If ``thetas`` cannot be viewed as an (n, 2) array.
    """
    # Accept lists, numpy arrays or CPU tensors; promote a single (2,) vector to (1, 2).
    thetas = np.atleast_2d(np.asarray(thetas, dtype=float))
    # ndim check guards against e.g. (k, 2, 2) inputs, which would pass a shape[1]-only test.
    if thetas.ndim != 2 or thetas.shape[1] != 2:
        raise ValueError("Input tensor must have shape (n, 2) where n is the number of parameter sets.")

    n_sets = thetas.shape[0]
    slopes = thetas[:, 0]
    intercepts = thetas[:, 1]
    x = np.linspace(0, 100, 101)

    rs = np.random.RandomState(seed)
    # Noise is drawn as (len(x), n): column i holds the noise for parameter set i.
    eps = rs.normal(loc=0, scale=sigma, size=(len(x), n_sets))

    # Broadcast instead of looping: column i is slopes[i] * x + intercepts[i] + eps[:, i].
    y = x[:, None] * slopes[None, :] + intercepts[None, :] + eps
    # Transpose so each row is one simulation, matching sbi's (batch, features) convention.
    return torch.Tensor(y.T)

In [5]:
# Box-uniform prior over theta = (m, b): slope m in [0, 10], intercept b in [-10, 10].
num_dim = 2

# Float bounds: the prior is continuous, so don't rely on BoxUniform casting integer tensors.
low_bounds = torch.tensor([0.0, -10.0])
high_bounds = torch.tensor([10.0, 10.0])
assert low_bounds.shape == high_bounds.shape == (num_dim,)

prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds)

To create the training set, sample from this prior and run it through the simulator.

In [11]:
# Draw the training set: parameters from the prior, then one simulated line per draw.
N_TRAIN = 10000
# Makes the prior draw reproducible; the simulator's noise comes from its own RNG and is not fixed by this.
torch.manual_seed(0)
params = prior.sample((N_TRAIN,))
xs = simulator(params)
assert params.shape == (N_TRAIN, 2) and xs.shape[0] == N_TRAIN
# Plain-text label: LaTeX inside print() is emitted literally, not rendered.
print('thetas', params, 'xs', xs)


$\theta$s tensor([[ 4.1849, -6.4817],
        [ 7.4023, -8.3961],
        [ 1.0809, -4.3180],
        ...,
        [ 7.9366,  8.7476],
        [ 0.7334, -7.2838],
        [ 8.7519, -0.0754]]) xs tensor([[ -1.1131, -10.7165,  12.7179,  ..., 406.5125, 400.8539, 418.1629],
        [-12.5729,  -7.0795,   8.2704,  ..., 732.9651, 724.2283, 734.0270],
        [ -3.6071,  -5.1995,   6.8589,  ...,  96.1481, 101.5714, 111.1034],
        ...,
        [  6.1125,   9.5420,  16.1666,  ..., 789.9856, 794.8154, 798.9568],
        [ -8.5876,  -5.4256,  -8.3190,  ...,  64.5963,  66.2074,  73.9468],
        [  7.8072,  15.0878,  27.3789,  ..., 854.3221, 866.7562, 870.2589]])


In [12]:
# Bundle the training parameters with their simulations and hand them to DataLoader.save_data_pkl,
# which writes them under '../saveddata/' as 'data_train' (a .pkl file).
data_to_save = dict(thetas=params, xs=xs)
dataloader = DataLoader()
dataloader.save_data_pkl('../saveddata/', 'data_train', data_to_save)

Repeat this to build a validation set of the same size, drawn independently from the prior.

In [13]:
# Draw an independent validation set with as many samples as the training set.
N_VALID = len(params)
# Reproducible prior draw. Keep this seed different from any seed used for the training
# draw, otherwise the two parameter sets would be identical. Simulator noise uses its own RNG.
torch.manual_seed(1)
params_valid = prior.sample((N_VALID,))
xs_valid = simulator(params_valid)
assert params_valid.shape == (N_VALID, 2) and xs_valid.shape[0] == N_VALID
# Plain-text label: LaTeX inside print() is emitted literally, not rendered.
print('thetas', params_valid, 'xs', xs_valid)

$\theta$s tensor([[ 9.2163,  1.7404],
        [ 9.2730, -2.2233],
        [ 5.7325, -7.6186],
        ...,
        [ 9.7974,  0.9499],
        [ 5.0012, -4.6716],
        [ 1.2730, -9.9987]]) xs tensor([[  3.3746,   3.4333,  24.3763,  ..., 906.7070, 915.0146, 926.6005],
        [ -8.5487,   8.1198,  21.0463,  ..., 906.0042, 917.9788, 917.3160],
        [ -7.4844,   4.4250,  10.5513,  ..., 550.0146, 559.8779, 558.8493],
        ...,
        [  2.6021,  14.3608,  21.4350,  ..., 951.8111, 969.6288, 977.5541],
        [ -5.0430,  -3.9875,  11.0748,  ..., 482.7758, 487.1702, 501.3554],
        [ -7.7894, -10.5941,  -9.1570,  ..., 116.4008, 113.7430, 120.2171]])


In [14]:
# Bundle the validation parameters with their simulations and save them next to the
# training data: '../saveddata/' as 'data_validation' (a .pkl file via save_data_pkl).
data_to_save_valid = dict(thetas=params_valid, xs=xs_valid)
dataloader = DataLoader()
dataloader.save_data_pkl('../saveddata/', 'data_validation', data_to_save_valid)

## Load the saved data and run SBI on it