In [None]:
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate as interpolate
import simulators.jla_supernovae.jla_simulator as jla
import ndes.ndes as ndes
import delfi.delfi as delfi
import compression.score.score as score
import distributions.priors as priors
import tensorflow as tf
from scipy.linalg import block_diag
# Only show TensorFlow log messages at ERROR level or above
# (uses the TF 1.x `tf.logging` API).
tf.logging.set_verbosity(tf.logging.ERROR)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
### SET UP THE PRIOR ###

# Prior over theta = (Omega_m, w_0), the interesting parameters:
# a Gaussian with standard deviations (0.4, 0.75) and correlation -0.8,
# truncated to the box [lower, upper].
lower = np.array([0, -1.5])
upper = np.array([0.6, 0])
prior_mean = np.array([0.3, -0.75])
prior_covariance = np.diag(np.array([0.4, 0.75])**2)
prior_covariance[0, 1] = -0.8*0.4*0.75
prior_covariance[1, 0] = prior_covariance[0, 1]
prior = priors.TruncatedGaussian(prior_mean, prior_covariance, lower, upper)

# Prior over eta, the four nuisance parameters:
# an uncorrelated Gaussian truncated to the box [eta_lower, eta_upper].
eta_lower = np.array([-20, 0, 0, -0.5])
eta_upper = np.array([-18, 1, 6, 0.5])
eta_mean = np.array([-19.05, 0.125, 2.6, -0.05])
eta_covariance = np.diag(np.array([0.1, 0.025, 0.25, 0.05])**2)
eta_prior = priors.TruncatedGaussian(eta_mean, eta_covariance, eta_lower, eta_upper)

# Joint prior over (theta, eta). Bounds and means are stacked theta-first,
# and the block-diagonal covariance makes theta and eta independent a priori.
joint_lower, joint_upper, joint_mean = (
    np.concatenate(theta_and_eta)
    for theta_and_eta in ((lower, eta_lower), (upper, eta_upper), (prior_mean, eta_mean))
)
joint_covariance = block_diag(prior_covariance, eta_covariance)
joint_prior = priors.TruncatedGaussian(joint_mean, joint_covariance, joint_lower, joint_upper)

In [None]:
### SET UP FOR SIMULATION CODE ###

JLASimulator = jla.JLA_Model()

# Simulator function: this must be of the form
# simulator(theta, seed, simulator_args, batch) -> simulated data vector
def simulator(theta, seed, simulator_args, batch):
    """Simulate one JLA data vector at `theta` with freshly drawn nuisances.

    A new nuisance vector eta is drawn from the nuisance prior on every call,
    so repeated calls at the same theta sample over the nuisances.

    Parameters
    ----------
    theta : array-like
        Interesting parameters (Omega_m, w_0).
    seed : int
        Seed forwarded to ``JLASimulator.simulation``.
    simulator_args : list
        ``[nuisance_prior]``; element 0 must provide a ``draw()`` method.
    batch :
        Not used in this body. It is accepted only to satisfy the calling
        convention (presumably a batch size supplied by DELFI).

    Returns
    -------
    Whatever ``JLASimulator.simulation`` returns for the concatenated
    parameter vector (theta first, then eta).
    """
    # Distinct name so the module-level `eta_prior` is not shadowed.
    nuisance_prior = simulator_args[0]
    # NOTE(review): `seed` is passed only to the simulation call. This draw is not
    # tied to it, so a given seed does not reproduce the same eta. Confirm
    # whether that is intended.
    eta = nuisance_prior.draw()

    return JLASimulator.simulation(np.concatenate([theta, eta]), seed)

# Arguments for simulator: the nuisance prior, unpacked inside `simulator`
simulator_args = [eta_prior]

In [None]:
### SET UP THE COMPRESSOR ###

# Fiducial parameters
theta_fiducial = np.array([0.20181324,  -0.74762939])
eta_fiducial = np.array([-19.04253368,   0.12566322,   2.64387045, -0.05252869])

# Full fiducial vector: theta first, then eta (same ordering as the joint prior).
fiducial = np.concatenate([theta_fiducial, eta_fiducial])
n_theta = len(theta_fiducial)
# Positions of the nuisance parameters inside `fiducial` (here 2, 3, 4, 5).
nuisance_indices = np.arange(n_theta, n_theta + len(eta_fiducial))

# Expected data (model apparent magnitudes at the fiducial point) and inverse covariance
mu = JLASimulator.apparent_magnitude(fiducial)
Cinv = JLASimulator.Cinv

# Derivatives of the expected apparent magnitudes w.r.t. all six parameters.
# h is 1% of each |fiducial value| (presumably the finite-difference step sizes).
# A zero fiducial value would give a zero step.
h = np.abs(fiducial)*0.01
dmudt = JLASimulator.dmudt(fiducial, h)

# Define compression as the score-MLE of a Gaussian likelihood, with the joint
# prior (mean and covariance) supplied to the compressor.
Compressor = score.Gaussian(len(JLASimulator.data), fiducial, mu = mu, Cinv = Cinv, dmudt = dmudt, prior_mean = joint_mean, prior_covariance = joint_covariance)

# Compute the Fisher matrix
Compressor.compute_fisher()

# Take the theta block of the Fisher matrix *inverse* (not the inverse of the
# theta block), which gives the nuisance-marginalized covariance for theta.
Finv = Compressor.Finv[:n_theta, :n_theta]

# Compressor function: this must have the form compressor(data, args) -> compressed summaries (pseudoMLE)
def compressor(d, compressor_args):
    """Compress data vector `d` to summaries for theta only.

    Calls ``Compressor.projected_scoreMLE`` with the indices of the nuisance
    parameters, so that the returned summaries cover the interesting parameters only.
    ``compressor_args`` is unused (it is None here) and is kept for the calling
    convention.
    """
    return Compressor.projected_scoreMLE(d, nuisance_indices)
compressor_args = None

In [None]:
### Compress the JLA data ###
# Run the observed JLA data vector through the same compressor used on the simulations.
compressed_data = compressor(d=JLASimulator.data, compressor_args=compressor_args)

In [None]:
# Create ensemble of NDEs: one conditional masked autoregressive flow (index 0)
# followed by mixture density networks with 1 to 5 components (index = number
# of components). Construction order is unchanged from the explicit listing.
NDEs = [ndes.ConditionalMaskedAutoregressiveFlow(n_parameters=2, n_data=2, n_hiddens=[50,50], n_mades=5, act_fun=tf.tanh, index=0)] + [
    ndes.MixtureDensityNetwork(n_parameters=2, n_data=2, n_components=n_components, n_hidden=[30,30], activations=[tf.tanh, tf.tanh], index=n_components)
    for n_components in range(1, 6)
]


# Create the DELFI object.
# The raw string avoids the invalid escape sequence '\O' in a plain string
# literal (a DeprecationWarning, and a SyntaxWarning on Python 3.12+); the
# value is unchanged.
DelfiEnsemble = delfi.Delfi(compressed_data, prior, NDEs, Finv = Finv, theta_fiducial = theta_fiducial, 
                       param_limits = [lower, upper],
                       param_names = [r'\Omega_m', 'w_0'], 
                       results_dir = "simulators/jla_supernovae/results_marginal/",
                       input_normalization="fisher")

In [None]:
# Do the Fisher pre-training.
# This runs before any simulations (sequential_training is called in a later cell).
# It presumably trains the NDEs on a Fisher-matrix approximation built from the
# Finv / theta_fiducial passed to Delfi; confirm in delfi.Delfi.fisher_pretraining.
DelfiEnsemble.fisher_pretraining()

In [None]:
# SNL schedule: number of initial samples, batch size for each subsequent
# population, and number of populations.
n_initial, n_batch, n_populations = 100, 100, 11

# Do the SNL training. `patience` is passed straight to DELFI
# (presumably an early-stopping patience for NDE training).
DelfiEnsemble.sequential_training(
    simulator, compressor,
    n_initial, n_batch, n_populations,
    patience=10,
    simulator_args=simulator_args,
)