In [7]:
import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

import torch
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform

from probjax.nn.transformers import Transformer
from probjax.nn.helpers import GaussianFourierEmbedding
from probjax.nn.loss_fn import denoising_score_matching_loss
from probjax.distributions.sde import VESDE
from probjax.distributions import Empirical, Independent

import sbi
import sbi.utils as utils
from sbi.inference import simulate_for_sbi
from sbi.utils.user_input_checks import check_sbi_inputs, process_prior, process_simulator

# Priors

In [3]:
# Chemical elements produced by the emulator, hydrogen included.
labels_out_H = ['C', 'Fe', 'H', 'He', 'Mg', 'N', 'Ne', 'O', 'Si']
# The same elements with hydrogen dropped (these are the tracked outputs).
labels_out = [element for element in labels_out_H if element != 'H']

# Names of the simulator input parameters.
labels_in = [
    'high_mass_slope',
    'log10_N_0',
    'log10_starformation_efficiency',
    'log10_sfr_scale',
    'outflow_feedback_fraction',
    'time',
]

# One (mean, std) row per Gaussian prior.
# NOTE(review): rows presumably follow the order of the first five entries
# of `labels_in` -- confirm.
priors = torch.tensor([
    (-2.30, 0.30),
    (-2.89, 0.30),
    (-0.30, 0.30),
    (0.55, 0.10),
    (0.50, 0.10),
])

# Joint prior: one 1-D Normal(mean, std) per row of `priors`, followed by a
# Uniform(2.0, 12.8) component.
# NOTE(review): the uniform component presumably corresponds to 'time', the
# last entry of `labels_in` -- confirm.
combined_priors = utils.MultipleIndependent(
    [
        *(Normal(mean * torch.ones(1), std * torch.ones(1)) for mean, std in priors),
        Uniform(torch.tensor([2.0]), torch.tensor([12.8])),
    ],
    validate_args=False,
)

# NN Simulator

In [5]:
class Model_Torch(torch.nn.Module):
    """Fully connected network mapping the input parameters to element outputs.

    Layer widths are len(labels_in) -> 100 -> 40 -> len(labels_out_H), with a
    tanh after each of the first two linear layers and no activation on the
    output layer.
    """

    def __init__(self):
        super().__init__()
        n_inputs = len(labels_in)
        n_outputs = len(labels_out_H)
        # Attribute names are kept as l1/l2/l3 because they determine the
        # state-dict keys.
        self.l1 = torch.nn.Linear(n_inputs, 100)
        self.l2 = torch.nn.Linear(100, 40)
        self.l3 = torch.nn.Linear(40, n_outputs)

    def forward(self, x):
        """Apply the three layers to `x` and return the raw output of the last one."""
        hidden = torch.tanh(self.l1(x))
        hidden = torch.tanh(self.l2(hidden))
        return self.l3(hidden)

In [6]:
# Instantiate the emulator network and restore its trained parameters.
model_simulator = Model_Torch()
# --- Load the weights ---
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs; it avoids executing arbitrary pickled code
# and the FutureWarning torch.load emits when the flag is left unset.
# map_location='cpu' keeps the parameters on the CPU, which the simulator's
# later `.numpy()` conversion requires.
model_simulator.load_state_dict(
    torch.load('../data/pytorch_state_dict.pt', map_location='cpu', weights_only=True)
)
# Switch to inference mode; the trailing semicolon suppresses the module repr.
model_simulator.eval();

  model_simulator.load_state_dict(torch.load('../data/pytorch_state_dict.pt'))


# Create data

In [8]:
# ----- set up the simulator -----
def simulator(params):
    """Evaluate the emulator network and return its outputs without hydrogen.

    Args:
        params: torch tensor passed straight to `model_simulator`; either a
            single parameter vector or a batch of them along the first axis.

    Returns:
        numpy array with the same leading shape as the network output and the
        entry at index 2 of the last axis removed.
    """
    # No gradients are needed for simulation, so skip building the graph.
    with torch.no_grad():
        y = model_simulator(params)
    y = y.detach().numpy()

    # Remove H from data, because it is just used for normalization (output
    # with index 2). Deleting along the last axis keeps any batch dimension
    # intact; without `axis`, np.delete would flatten a batched output.
    return np.delete(y, 2, axis=-1)

# Pass the prior and the simulator through sbi's input processing, then draw
# 10000 (theta, x) pairs with theta sampled from the processed prior.
prior, num_parameters, prior_returns_numpy = process_prior(combined_priors)
simulator = process_simulator(simulator, prior, prior_returns_numpy)
theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=10000)

# Join parameters and simulated outputs along axis 1 and append a trailing
# singleton axis before converting to a JAX array.
# NOTE(review): assumes theta and x are both 2-D (num_simulations, n_features)
# -- confirm.
data = jnp.asarray(torch.cat((theta, x), dim=1).reshape(x.shape[0], -1, 1))

  0%|          | 0/10000 [00:00<?, ?it/s]

# Diffusion model
## Set up diffusion process

In [16]:
# VESDE settings
# NOTE(review): T and T_min are not used in this cell; presumably they are the
# end time and the smallest time of the diffusion used further down -- confirm.
T = 1.0
T_min = 1e-2
# Noise-scale bounds handed to the VESDE constructor below.
sigma_min = 1e-3
sigma_max = 15.0

# Empirical distribution of the data, wrapped in Independent(..., 1).
p0 = Independent(Empirical(data), 1)
sde = VESDE(p0, sigma_min=sigma_min, sigma_max=sigma_max)

In [15]:
# Display the event shape of the initial distribution p0 (sanity check).
p0.event_shape

(1,)