In [49]:
# Train a neural network with Elegy, then run HMC with Oryx on Mana
# Author: Peter Oct 24 2021
# Requirements: 
#!module load system/CUDA/11.0.2 
#!pip install --upgrade jax jaxlib==0.1.68+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html 
#!pip install tensorflow-io oryx elegy

from collections import defaultdict
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
#os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.10'
import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

assert jax.default_backend() == 'gpu'

import oryx  # pip install oryx
import elegy # pip install elegy
import optax
#import flax.linen
import tensorflow_io as tfio # pip install tensorflow-io
import tensorflow as tf # Recommended not to import this with jax because will also try to grab memory.

tfd = oryx.distributions
state = oryx.core.state
ppl = oryx.core.ppl
inverse = oryx.core.inverse
ildj = oryx.core.ildj
plant = oryx.core.plant
reap = oryx.core.reap
sow = oryx.core.sow
unzip = oryx.core.unzip
nn = oryx.experimental.nn
mcmc = oryx.experimental.mcmc
optimizers = oryx.experimental.optimizers # Oryx based on https://github.com/deepmind/optax

def load_data_ams(filename='../data/BR2461.dat'):
    """Load the AMS binned flux measurements (data file from Claudio).

    The file is whitespace-delimited text with one row per rigidity bin and
    the columns: Rigidity1 (lower edge), Rigidity2 (upper edge), Flux, Error.

    Args:
        filename: path to the data file. Defaults to the BR2461 dataset used
            by the current experiment.

    Returns:
        bins: 1d array of bin edges -- the lower edge of every row followed by
            the upper edge of the last row (length ``n_bins + 1``).
        observed: 1d array of observed fluxes (length ``n_bins``).
        uncertainty: 1d array of flux errors (length ``n_bins``).
        alpha, cmf: model parameters that are fixed for the current experiment.

    Raises:
        ValueError: if the file is empty or has fewer than 4 columns.
    """
    alpha, cmf = 69.19, 5.17  # These are fixed for the current experiment.
    # ndmin=2 keeps a single-row file two-dimensional so column indexing works.
    dataset_ams = np.loadtxt(filename, ndmin=2)  # Rigidity1, Rigidity2, Flux, Error
    if dataset_ams.size == 0 or dataset_ams.shape[1] < 4:
        raise ValueError(
            f'{filename}: expected at least one row with 4 columns '
            f'(Rigidity1, Rigidity2, Flux, Error), got shape {dataset_ams.shape}')
    r1, r2 = dataset_ams[:, 0], dataset_ams[:, 1]
    # Contiguous bins: all lower edges plus the final upper edge.
    bins = np.concatenate([r1, r2[-1:]])
    observed = dataset_ams[:, 2]  # Observed Flux
    uncertainty = dataset_ams[:, 3]
    return bins, observed, uncertainty, alpha, cmf

def load_preprocessed_data_ams():
    """Load AMS data along with hardcoded auxiliary vectors for ppmodel.

    Returns:
        xloc: sorted jnp array of the integer lattice range(245) merged with
            the AMS bin edges (the grid the NN output is interpolated onto).
        iloc: list of indices into xloc at which the AMS bin edges sit.
        observed, uncertainty, alpha, cmf: passed through from load_data_ams.

    Raises:
        ValueError: if the hardcoded iloc does not point at the bin edges of
            the loaded data (e.g. the data file changed).
    """
    bins, observed, uncertainty, alpha, cmf = load_data_ams()
    # Merge the integer lattice range(245) with the AMS bin edges.
    xloc = jnp.sort(jnp.concatenate([jnp.arange(245), bins]))
    # Indices of the AMS bin edges inside xloc. They are hardcoded rather than
    # computed with jnp.where(xloc % 1 > 0) as a workaround for
    # https://github.com/google/jax/issues/4309. They were obtained once from
    #   np.sort(np.concatenate([np.arange(245), bins + 1e-12]))
    # where the epsilon distinguishes an edge from an equal lattice point.
    iloc = [  2,   3,   4,   5,   6,   7,   9,  10,  11,  12,  14,  15,
             17,  18,  19,  21,  22,  24,  26,  27,  29,  31,  33,  35,
             37,  39,  41,  43,  45,  48,  50,  53,  55,  58,  61,  64,
             68,  71,  75,  78,  82,  87,  91,  96, 101, 106]
    # Guard against the data file changing underneath the hardcoded indices:
    # every entry must land on the corresponding bin edge.
    if len(iloc) != len(bins) or not np.allclose(np.asarray(xloc)[iloc], bins):
        raise ValueError('Hardcoded iloc does not match the AMS bin edges; '
                         'recompute it for this data file.')
    return xloc, iloc, observed, uncertainty, alpha, cmf

def define_ppmodel(model, xloc, iloc, observed, uncertainty, alpha_norm, cmf_norm):
    '''
    Define a probabilistic program model from the NN and the AMS data.

    Inputs:
        model = trained elegy NN; its output is undone with exp(y) - 1
            (a log1p transform) and treated as flux on the lattice range(245)
        xloc = sorted 1d grid of the lattice range(245) merged with the AMS
            bin edges
        iloc = indices into xloc of the AMS bin edges; bin i spans
            xloc[iloc[i]] .. xloc[iloc[i+1]]
        observed = observed flux per bin (len(iloc) - 1 entries)
        uncertainty = flux error per bin (same length as observed)
        alpha_norm, cmf_norm = min-max scaled values of the two fixed inputs,
            prepended to the 5 sampled inputs to form the 7-d NN input
    Returns:
        ppmodel(key): jax probabilistic program
    '''
    num_bins = len(iloc) - 1
    lattice = jnp.arange(245)  # Rigidity lattice the NN predicts on.

    def ppmodel(key):
        """
        Define probabilistic programming model as a random sample from input space,
        then application of NN, then computation of likelihood.
        """
        # NN is trained on min-max scaled data, so inputs should be in [0, 1].
        # A bounded multivariate Uniform with validate_args made every HMC
        # proposal get rejected, hence the product of 1-d Uniforms.
        prior = tfd.Sample(tfd.Uniform(0., 1.), sample_shape=(5,))
        x = ppl.random_variable(prior, name='x')(key)  # Sample from 5-dimensional space.
        x = jnp.concatenate([jnp.array([alpha_norm, cmf_norm]), x])  # Create 7d input to NN.

        # Apply NN and get predicted flux on 245 lattice points.
        # NOTE(review): x is unbatched (shape (7,)); confirm model.predict
        # accepts that and returns shape (245,).
        yhat = model.predict(x)  # Reminder: model should be in eager mode.
        yhat = jnp.exp(yhat) - 1.  # Undo log1p transform of target output.
        # Interpolate to get predicted flux at both lattice and bin points.
        yloc = jnp.interp(xloc, lattice, yhat)

        # Integrate over bin regions and compare to observed to get likelihood.
        loglikelihood = 0.0
        for i in range(num_bins):
            lo, hi = iloc[i], iloc[i + 1]
            # Trapezoid rule over [edge_i, edge_{i+1}]; the slice ends at hi + 1
            # so the upper edge itself (index hi) is included.
            predicted = jnp.trapz(yloc[lo:hi + 1], x=xloc[lo:hi + 1])
            # Gaussian log-likelihood of the bin: -0.5 * chi^2. The sign matters:
            # a larger misfit must give a lower likelihood.
            chi = (predicted - observed[i]) / uncertainty[i]
            loglikelihood += -0.5 * chi ** 2
        # NOTE(review): ppl.joint_log_prob sums the log-probs of named random
        # variables only, so this return value may not enter the HMC target.
        # Consider modelling the data as a named random_variable conditioned
        # on `observed` -- confirm against the oryx docs.
        return jnp.exp(loglikelihood)
    return ppmodel

def remove_consecutive_duplicates(X):
    '''
    Remove consecutive duplicate rows from an array of MCMC samples.

    A repeated row means the MCMC proposal was rejected and the chain stayed
    put. The first row of every run of identical rows is kept.

    Input:
        X = 2d array, where each row is a sample
    Returns:
        rval = 2d array with consecutive duplicate rows removed
    '''
    X = np.asarray(X)
    if len(X) == 0:
        return X
    # repeats[i] is True when row i + 1 equals row i (length n - 1).
    repeats = np.all(X[:-1, :] == X[1:, :], axis=1)
    # Always keep row 0; drop every later row equal to its predecessor.
    keep = np.concatenate([[True], ~repeats])
    return X[keep, :]

# Restore the trained NN (7 input parameters -> predicted flux at the rigidity
# values range(245)) saved under the path 'my_model'.
model = elegy.load('my_model')
# ppmodel calls model.predict inside the probabilistic program, which needs
# eager execution; run_eagerly is a settable attribute.
model.run_eagerly = True

# Observation data plus the auxiliary integration grid used by ppmodel.
(xloc, iloc, observed,
 uncertainty, alpha, cmf) = load_preprocessed_data_ams()


In [None]:
%%time


# Hyperparameters. The counts must be Python ints: JAX loop lengths and array
# shapes must be integers, and literals such as 1e6 are floats.
num_samples = 1_000_000
num_leapfrog_steps = 10_000
step_size = 1e-6
seed = 2

# Run Hamiltonian Monte Carlo
start = 0.5 * jnp.ones(5)  # Starting state for MCMC: centre of the [0, 1]^5 prior box.
alpha_norm = (alpha - 20.) / 55.  # Min-max scaling (min 20, range 55).
cmf_norm = (cmf - 4.5) / 4.  # Min-max scaling (min 4.5, range 4).
ppmodel = define_ppmodel(model, xloc, iloc, observed, uncertainty, alpha_norm, cmf_norm)
# NOTE(review): ppl.joint_log_prob(ppmodel) yields a function of a dict of named
# random-variable values (e.g. {'x': ...}), while `start` is a bare array --
# confirm that mcmc.hmc passes states in the form that function expects.
hmc = mcmc.hmc(ppl.joint_log_prob(ppmodel),
               num_leapfrog_steps=num_leapfrog_steps,  # Too low and samples will be correlated.
               step_size=step_size)  # Too high and proposals won't be accepted.
# Experiments:
# N=samples, L=num_leapfrog_steps, e=step_size, p=repeated_states, T=Time.
# L=1e3,e=1e-4,p=1/3. L=1e3,e=1e-5,p=0.02. L=1e3,e=1e-6,p=0.0,T=1e2sec but bad mixing.
# L=1e4,e=1e-6,N=1e5, p=0.02,T=14min20s.
unnormalized_samples = jit(mcmc.sample_chain(hmc, num_samples))(random.PRNGKey(seed), start)  # use log_prob doesn't work because function too complex.
# Copy the device array to a host numpy array (portable replacement for .to_py()).
unnormalized_samples = np.asarray(unnormalized_samples)

# De-normalize samples: the sampled coordinates are min-max scaled to [0, 1],
# so map each of the 5 columns back to its physical range.
# Full 7-d ranges (the first two inputs are fixed, not sampled):
#MIN_VALS = np.array([20., 4.5, 50., 0.2, 0.2, 0.2, 0.2])
#MAX_VALS = np.array([75., 8.5, 250., 2., 2.3, 2., 2.3])
MIN_VALS = np.array([50., 0.2, 0.2, 0.2, 0.2])
MAX_VALS = np.array([250., 2., 2.3, 2., 2.3])
RANGE_VALS = MAX_VALS - MIN_VALS
all_samples = unnormalized_samples * RANGE_VALS + MIN_VALS

# Remove rejected samples (a rejection repeats the previous state).
samples = remove_consecutive_duplicates(all_samples)
print(f'Acceptance rate: {len(samples)/len(all_samples)}. Decrease step_size to increase rate.')

# Throw out the first 1% of samples to ensure proper mixing (burn-in).
# Slice bounds must be ints; np.floor returns a float, which numpy rejects.
samples = samples[int(0.01 * len(samples)):, :]

# Save samples, creating the output directory if it does not exist yet.
os.makedirs('results', exist_ok=True)
np.savetxt(fname=f'results/samples_{num_samples}_{num_leapfrog_steps}_{step_size}.csv', X=samples)

# Histogram each parameter's marginal in its own panel, side by side, with the
# x-axis pinned to that parameter's physical range.
parameter_names = ['cpa', 'pwr1par', 'pwr2par', 'pwr1perr', 'pwr2perr']
plt.figure(1, figsize=(24, 4))
for i, label in enumerate(parameter_names):
    plt.subplot(1, 5, i + 1)
    plt.hist(samples[:, i])
    plt.xlabel(label)
    plt.xlim((MIN_VALS[i], MAX_VALS[i]))
plt.savefig(fname=f'results/marginals_{num_samples}_{num_leapfrog_steps}_{step_size}.pdf', bbox_inches='tight')