In [1]:
import os
import pickle
import sys

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import cumtrapz
from scipy.stats import gaussian_kde
import scipy.stats as stats
from getdist import plots, MCSamples
import matplotlib as mpl
from astropy.cosmology import Planck15
from scipy.interpolate import InterpolatedUnivariateSpline
import fsps
import emcee
from scipy.special import hyp2f1

import tensorflow as tf
import tensorflow_probability as tfp
import tqdm
from tqdm import trange
tfb = tfp.bijectors
tfd = tfp.distributions
tfkl = tf.keras.layers
tfpl = tfp.layers
tfk = tf.keras

sys.path.append('/Users/justinalsing/Documents/science/steppz/code')
from plotting import triangle_plot
from utils import *
from priors import *
from affine import *
from ndes import *



Import the relevant models and set up the prior

In [2]:
# Restore the two pretrained regression networks from disk:
#  - an emulator for log10 sSFR,
#  - a network used below as the baseline SFR prior log-probability.
# NOTE(review): restore_filename points at .pkl files, presumably unpickled on
# restore — only restore files from a trusted source.
log10sSFR_emulator = RegressionNetwork(
    restore=True,
    restore_filename='ProspectorAlpha_log10sSFR_emulator.pkl',
)
baseline_SFR_prior_log_prob = RegressionNetwork(
    restore=True,
    restore_filename='ProspectorAlpha_baseline_SFR_prior_logprob.pkl',
)

# Assemble the Prospector-alpha prior from its pieces: the baseline SFR prior,
# the sSFR emulator, the Mizuki sSFR prior, a uniform window on log10 sSFR of
# [-14, -8], and the redshift (volume) prior.
Prior = ProspectorAlphaBaselinePrior(
    baselineSFRprior=baseline_SFR_prior_log_prob,
    log10sSFRemulator=log10sSFR_emulator,
    log10sSFRprior=log10sSFRpriorMizuki,
    log10sSFRuniformlimits=tfd.Uniform(low=-14, high=-8),
    redshift_prior=redshift_volume_prior,
)

In [3]:
# Initialize the two walker ensembles for affine-invariant sampling.
n_walkers = 2000
n_steps = 1000  # NOTE(review): not used in the cells visible here; confirm whether later cells rely on it

# Number of raw baseline-prior draws; after the cuts below at least 2*n_walkers must remain.
n_prior_draws = 30000
# log10 sSFR acceptance window; same limits as the tfd.Uniform passed to the prior in the previous cell.
log10sSFR_min, log10sSFR_max = -14, -8

# Fix TensorFlow's global seed so the initial walker positions are reproducible on Restart & Run All.
# NOTE(review): presumably also makes affine_sample reproducible if it draws from tf.random — confirm.
seed = 0
tf.random.set_seed(seed)

# Baseline prior draws. For each parameter i the bijector maps an unconstrained
# coordinate z to lower[i] + (upper[i] - lower[i]) * NormalCDF(z), i.e. into the
# prior box [Prior.lower[i], Prior.upper[i]].
bijector = tfb.Blockwise([
    tfb.Invert(tfb.Chain([
        tfb.Invert(tfb.NormalCDF()),
        tfb.Scale(1. / (Prior.upper[i] - Prior.lower[i])),
        tfb.Shift(-Prior.lower[i]),
    ]))
    for i in range(Prior.n_sps_parameters)
])
baseline_draws = bijector(Prior.baselinePrior.sample((n_prior_draws, Prior.n_sps_parameters)))

# Reject draws whose emulated log10 sSFR lies outside the window.
# NOTE(review): these columns are the emulator's inputs (presumably the SFH parameters) — confirm ordering.
sfh = tf.gather(baseline_draws, [2, 3, 4, 5, 6, 7, 1, 14], axis=-1)
log10sSFR = tf.squeeze(log10sSFR_emulator(sfh))
in_sSFR_window = (log10sSFR > log10sSFR_min) & (log10sSFR < log10sSFR_max)
baseline_draws = tf.boolean_mask(baseline_draws, in_sSFR_window)

# Convert log10M (column 0) to N = -2.5*log10M + distance_modulus(z).
# Assumes the last column is redshift; it is clamped at 1e-5, presumably to keep
# the distance modulus finite at z = 0.
baseline_draws = baseline_draws.numpy()
baseline_draws[..., 0] = -2.5 * baseline_draws[..., 0] + distance_modulus(tf.math.maximum(1e-5, baseline_draws[..., -1]))

# Keep only draws with a finite prior log-probability. isfinite rejects +/-inf AND NaN,
# so NaN log-probs are not handed to the sampler as starting points.
log_prior = Prior.log_prob(baseline_draws).numpy()
baseline_draws = baseline_draws[np.isfinite(log_prior), :]

# Each ensemble needs n_walkers starting points; fail loudly instead of silently
# producing a short second ensemble.
if baseline_draws.shape[0] < 2 * n_walkers:
    raise ValueError(
        'only {} prior draws survived the cuts, need at least {}; increase n_prior_draws'.format(
            baseline_draws.shape[0], 2 * n_walkers))
baseline_draws = tf.convert_to_tensor(baseline_draws)

# Current state: two disjoint ensembles of n_walkers walkers each.
current_state = [baseline_draws[0:n_walkers, :], baseline_draws[n_walkers:2 * n_walkers, :]]

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


In [4]:
# Run the affine sampler on the prior and save batches of parameter samples to disk.
n_batches = 64
n_samples = 100000  # NOTE(review): not used in the cells visible here; confirm whether later cells rely on it
n_burnin = 1000           # steps run before any samples are saved
n_steps_per_batch = 25    # steps per saved batch
output_dir = '../model_Prospector-alpha/training_data_prior/parameters'

# Create the output directory up front so np.save cannot fail only after the slow burn-in.
os.makedirs(output_dir, exist_ok=True)

# Burn in: only the final walker positions are kept as the new starting state.
# chain appears to be indexed (step, walker, parameter), with both ensembles stacked on axis 1.
chain = affine_sample(Prior.log_prob, n_burnin, current_state)
current_state = [chain[-1, 0:n_walkers, :], chain[-1, n_walkers:, :]]

for batch in range(n_batches):

    chain = affine_sample(Prior.log_prob, n_steps_per_batch, current_state)
    current_state = [chain[-1, 0:n_walkers, :], chain[-1, n_walkers:, :]]

    # Flatten (step, walker) into rows; derive sizes from the array instead of hard-coding them.
    # NOTE(review): the 24/24 progress bars for 25 steps suggest chain[0] is the starting state,
    # i.e. the previous batch's final state, so consecutive files may share 2*n_walkers rows — confirm.
    parameters = Prior.bijector(chain).numpy()
    parameters = parameters.reshape((-1, parameters.shape[-1]))

    # save the parameters
    np.save(os.path.join(output_dir, 'parameters{}.npy'.format(batch)), parameters)



100%|██████████| 999/999 [02:49<00:00,  5.91it/s]
100%|██████████| 24/24 [00:04<00:00,  5.92it/s]
100%|██████████| 24/24 [00:04<00:00,  5.92it/s]
100%|██████████| 24/24 [00:04<00:00,  5.88it/s]
100%|██████████| 24/24 [00:04<00:00,  5.91it/s]
100%|██████████| 24/24 [00:04<00:00,  5.93it/s]
100%|██████████| 24/24 [00:04<00:00,  5.88it/s]
100%|██████████| 24/24 [00:04<00:00,  5.86it/s]
100%|██████████| 24/24 [00:04<00:00,  5.91it/s]
100%|██████████| 24/24 [00:04<00:00,  5.87it/s]
100%|██████████| 24/24 [00:04<00:00,  5.88it/s]
100%|██████████| 24/24 [00:04<00:00,  5.94it/s]
100%|██████████| 24/24 [00:04<00:00,  5.94it/s]
100%|██████████| 24/24 [00:04<00:00,  5.89it/s]
100%|██████████| 24/24 [00:04<00:00,  5.95it/s]
100%|██████████| 24/24 [00:04<00:00,  5.93it/s]
100%|██████████| 24/24 [00:04<00:00,  5.92it/s]
100%|██████████| 24/24 [00:04<00:00,  5.90it/s]
100%|██████████| 24/24 [00:04<00:00,  5.91it/s]
100%|██████████| 24/24 [00:04<00:00,  5.91it/s]
100%|██████████| 24/24 [00:04<00:00,  