In [139]:
import numpy as np

from Chempy.parameter import ModelParameters

import sbi.utils as utils
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.analysis import pairplot

import torch
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform

import time as t
import pickle

# Load the data

In [2]:
# ------ Load & prepare the data ------

# --- Load in training data ---
path_training = '../ChempyMulti/tutorial_data/TNG_Training_Data.npz'
training_data = np.load(path_training, mmap_mode='r')

elements = training_data['elements']
train_x = training_data['params']
train_y = training_data['abundances']


# ---  Load in the validation data ---
path_test = '../ChempyMulti/tutorial_data/TNG_Test_Data.npz'
val_data = np.load(path_test, mmap_mode='r')

val_x = val_data['params']
val_y = val_data['abundances']


# --- Clean the data ---
# Chempy sometimes returns zeros or infinite values, which need to removed
def clean_data(x, y):
    # Remove all zeros from the training data
    index = np.where((y == 0).all(axis=1))[0]
    x = np.delete(x, index, axis=0)
    y = np.delete(y, index, axis=0)

    # Remove all infinite values from the training data
    index = np.where(np.isfinite(y).all(axis=1))[0]
    x = x[index]
    y = y[index]

    return x, y


train_x, train_y = clean_data(train_x, train_y)
val_x, val_y     = clean_data(val_x, val_y)

# convert to torch tensors
train_x = torch.tensor(train_x, dtype=torch.float32)
train_y = torch.tensor(train_y, dtype=torch.float32)
val_x = torch.tensor(val_x, dtype=torch.float32)
val_y = torch.tensor(val_y, dtype=torch.float32)

In [101]:
a = ModelParameters()
labels = [a.to_optimize[i] for i in range(len(a.to_optimize))] + ['time']
priors = torch.tensor([[a.priors[opt][0], a.priors[opt][1]] for opt in a.to_optimize])

# Define the NN

In [134]:
if torch.backends.mps.is_available():
    print("using mps")
    device = torch.device("cpu")
else:
    print("using cpu")
    device = torch.device("cpu")


class Model_Torch(torch.nn.Module):
    def __init__(self):
        super(Model_Torch, self).__init__()
        self.l1 = torch.nn.Linear(train_x.shape[1], 100)
        self.l2 = torch.nn.Linear(100, 40)
        self.l3 = torch.nn.Linear(40, train_y.shape[1])

    def forward(self, x):
        # Normalize the input
        """x1 = (abs(x[:,0:-1]-priors[:,0]))/(priors[:,1]*10)
        x2 = x[:,-1]/torch.tensor([13.8])
        x = torch.cat((x1, x2.reshape(-1,1)),1)"""
        
        x = torch.tanh(self.l1(x))
        x = torch.tanh(self.l2(x))
        x = self.l3(x)
        return x

model = Model_Torch()
model.to(device)

using mps


Model_Torch(
  (l1): Linear(in_features=6, out_features=100, bias=True)
  (l2): Linear(in_features=100, out_features=40, bias=True)
  (l3): Linear(in_features=40, out_features=9, bias=True)
)

# Train the model

In [135]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.MSELoss()

# shuffle the data
index = np.arange(train_x.shape[0])
np.random.shuffle(index)
train_x = train_x[index]
train_y = train_y[index]

# --- Train the neural network ---
epochs = 15
batch_size = 64
for epoch in range(epochs):
    start_epoch = t.time()
    for i in range(0, train_x.shape[0], batch_size):
        optimizer.zero_grad()
        
        # Get the batch
        x_batch = train_x[i:i+batch_size].detach().clone().to(device).requires_grad_(True)
        y_batch = train_y[i:i+batch_size].detach().clone().to(device).requires_grad_(True)

        # Forward pass
        y_pred = model(x_batch)

        # Compute Loss
        loss = loss_fn(y_pred, y_batch)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
    # Validation loss
    y_pred = model(val_x)
    val_loss = loss_fn(y_pred, val_y)
        
    end_epoch = t.time()
    epoch_time = end_epoch - start_epoch
    
    print(f'Epoch {epoch+1}/{epochs} in {round(epoch_time,1)}s, Loss: {round(loss.item(),6)} | Val Loss: {round(val_loss.item(),6)}')

Epoch 1/15 in 2.5s, Loss: 0.000933 | Val Loss: 0.001334
Epoch 2/15 in 2.6s, Loss: 0.00058 | Val Loss: 0.000586
Epoch 3/15 in 2.6s, Loss: 0.000372 | Val Loss: 0.000482
Epoch 4/15 in 2.7s, Loss: 0.000264 | Val Loss: 0.000417
Epoch 5/15 in 2.6s, Loss: 0.000224 | Val Loss: 0.000355
Epoch 6/15 in 2.6s, Loss: 0.000206 | Val Loss: 0.000297
Epoch 7/15 in 2.6s, Loss: 0.000177 | Val Loss: 0.000262
Epoch 8/15 in 2.6s, Loss: 0.000153 | Val Loss: 0.000242
Epoch 9/15 in 2.7s, Loss: 0.000137 | Val Loss: 0.000228
Epoch 10/15 in 2.7s, Loss: 0.000125 | Val Loss: 0.000217
Epoch 11/15 in 2.6s, Loss: 0.000117 | Val Loss: 0.000209
Epoch 12/15 in 2.6s, Loss: 0.000112 | Val Loss: 0.000201
Epoch 13/15 in 2.6s, Loss: 0.000106 | Val Loss: 0.000195
Epoch 14/15 in 2.6s, Loss: 0.000102 | Val Loss: 0.00019
Epoch 15/15 in 2.7s, Loss: 9.9e-05 | Val Loss: 0.000185


In [None]:
# --- Save the model ---
torch.save(model.state_dict(), 'data/pytorch_state_dict.pt')

# Train SBI

In [None]:
# --- Load the model ---
model = Model_Torch()
model.load_state_dict(torch.load('data/pytorch_state_dict.pt'))
model.to(device)
model.eval()

In [141]:
combined_priors = utils.MultipleIndependent(
    [Normal(p[0]*torch.ones(1), p[1]*torch.ones(1)) for p in priors] +
    [Uniform(torch.tensor([2.0]), torch.tensor([12.8]))],
    validate_args=False)

In [142]:
def simulator_NN_torch(params):
    pred = model(params)
    return pred

In [154]:
simulator, prior = prepare_for_sbi(simulator_NN_torch, combined_priors)
inference = SNPE(prior=prior)

start = t.time()

theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=100)
density_estimator = inference.append_simulations(theta, x).train()
posterior = inference.build_posterior(density_estimator)

end = t.time()
comp_time = end - start
print(f'Time taken to train the posterior with {len(train_y)} samples: '
      f'{np.floor(comp_time/60).astype("int")}min {np.floor(comp_time%60).astype("int")}s')

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [9]], which is output 0 of StdBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

In [151]:
theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=100)

In [152]:
theta

tensor([[-2.3841e+00, -3.2961e+00, -2.2128e-01,  5.3151e-01,  4.1123e-01,
          8.0533e+00],
        [-2.6094e+00, -2.9897e+00, -1.8209e-01,  5.8454e-01,  5.1260e-01,
          1.1587e+01],
        [-2.1794e+00, -3.0984e+00,  3.1821e-02,  5.8867e-01,  4.6967e-01,
          4.7040e+00],
        [-2.0311e+00, -2.4871e+00, -5.4967e-02,  4.1474e-01,  5.0354e-01,
          1.0812e+01],
        [-2.6379e+00, -2.9906e+00, -3.7574e-01,  7.3419e-01,  4.7525e-01,
          2.2424e+00],
        [-2.1731e+00, -2.8508e+00, -4.8300e-01,  5.3187e-01,  5.7189e-01,
          5.0941e+00],
        [-2.7868e+00, -2.4215e+00,  9.8619e-02,  3.8069e-01,  4.5382e-01,
          9.9786e+00],
        [-2.3006e+00, -3.1132e+00, -6.0857e-01,  4.1632e-01,  4.9090e-01,
          5.9860e+00],
        [-2.3414e+00, -3.2068e+00, -2.6241e-01,  5.9485e-01,  4.9363e-01,
          7.1465e+00],
        [-2.0368e+00, -2.7116e+00, -1.7991e-01,  5.5132e-01,  5.1690e-01,
          4.8675e+00],
        [-2.2917e+00, -2.8780e