In [None]:
# Example: train a neural network (NN) with Elegy, then run Hamiltonian Monte Carlo (HMC) with Oryx on Mana.
# Author: Peter Oct 22 2021
# Requirements: 
#!module load system/CUDA/11.0.2 
#!pip install --upgrade jax jaxlib==0.1.68+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html 
#!pip install tensorflow-io oryx elegy

In [None]:
from collections import defaultdict
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
#os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.10'
import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

assert jax.default_backend() == 'gpu'

import oryx  # pip install oryx
import elegy # pip install elegy
import optax
#import flax.linen
import tensorflow_io as tfio # pip install tensorflow-io
#import tensorflow as tf # Recommended not to import this with jax because will also try to grab memory.


# Short aliases for the Oryx APIs used throughout this notebook.
from oryx import distributions as tfd  # TFP-style distributions as exposed by Oryx.

# Core program transformations (state, probabilistic programming, inversion, harvesting).
from oryx.core import state, ppl
from oryx.core import inverse, ildj, plant, reap, sow, unzip

# Experimental building blocks: neural nets, MCMC kernels, optimizers.
from oryx.experimental import nn, mcmc
from oryx.experimental import optimizers  # Oryx based on https://github.com/deepmind/optax
#optimizer = optimizers.adam(1e-4) # Oryx based on https://github.com/deepmind/optax
#opt_state = state.init(optimizer)(opt_key, network, network)

# Define dataset
#f = '/mnt/lts/nfs_fs02/sadow_lab/shared/gcr/data/proposal/data_processed.hdf5'
# f = './data_processed_sample.hdf5'
# x = tfio.IODataset.from_hdf5(f, dataset='/features')
# y = tfio.IODataset.from_hdf5(f, dataset='/flux')
# full = tf.data.Dataset.zip((x, y)).repeat()
#full = tf.data.TextLineDataset('sample_data/mnist_train_small.csv')
features, labels = (np.random.sample((100,7)), np.random.sample((100,245)))
#full = tf.data.Dataset.from_tensor_slices((features,labels)).batch(4).repeat()
#full = full.take(100)
#BATCH_SIZE = 128
#fullbatched = full.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
#train = full.take(np.floor(1435308/BATCH_SIZE *.6))
#test = full.skip(np.floor(1435308/BATCH_SIZE *.9))

In [None]:
class MLP(elegy.Module):
    """Small perceptron mapping inputs to 245 outputs: Linear(5) -> ReLU -> Linear(245).

    Adapted from the Elegy documentation example; the linear submodules are
    created inside `call`, in the same order as they are applied.
    """

    def call(self, x: jnp.ndarray) -> jnp.ndarray:
        hidden = jax.nn.relu(elegy.nn.Linear(5)(x))
        return elegy.nn.Linear(245)(hidden)

# Build, briefly train, then round-trip the model through disk.
model = elegy.Model(
    module=MLP(),
    loss=[
        elegy.losses.MeanSquaredError(),
        #elegy.regularizers.GlobalL2(l=1e-5),
        ],
    #metrics=elegy.metrics.BinaryAccuracy(),
    optimizer=optax.adam(1e-4),
    #optimizer=optimizers.rmsprop(1e-3), # Oryx based on optax.
    # Eager mode is needed so that model.predict can be called from inside ppmodel (HMC cell).
    run_eagerly=True,
)

history = model.fit(
    x=features,
    y=labels,
    epochs=1,
    steps_per_epoch=10,
    shuffle=False,
    verbose=1,
    #callbacks=[elegy.callbacks.TensorBoard("summaries")]
)

MODEL_DIR = 'my_model'
model.save(MODEL_DIR)  # creates folder at MODEL_DIR
del model  # deletes the existing model
# returns a model identical to the previous one
model = elegy.load(MODEL_DIR)
#print(model.states.rng)

# Test ability to make a prediction on a single (seeded) 7-D input.
x = np.random.default_rng(1).random(7)
prediction = model.predict(x)
# The final layer of MLP is Linear(245), so the last axis must have 245 entries.
assert prediction.shape[-1] == 245, prediction.shape
prediction.shape

In [None]:
model.run_eagerly = True  # Settable attribute. Required to be true so model.predict can run inside ppmodel / log_posterior.

# Observation model (only one observed data point in this example).
Y_OBS = 2.         # Observed data value.
NOISE_SCALE = 1.   # Assumed std. dev. of Gaussian observation noise — TODO: calibrate.
N_HMC_STEPS = 100  # Number of HMC samples to draw.

# Prior over the 7-D NN input: independent standard normals.
PRIOR = tfd.MultivariateNormalDiag(jnp.zeros(7), jnp.ones(7))


def ppmodel(key):
  """Probabilistic program: sample x from the prior, apply the NN, score against Y_OBS.

  `x` is registered with Oryx under the name 'x'. Returns the unnormalized
  Gaussian likelihood exp(-0.5 * ((yhat - Y_OBS) / NOISE_SCALE)**2),
  elementwise over the NN outputs.

  NOTE: `ppl.joint_log_prob(ppmodel)` only scores values drawn through
  `ppl.random_variable`. This return value is therefore NOT part of that
  density, which is why HMC below targets `log_posterior` instead.
  """
  x = ppl.random_variable(PRIOR, name='x')(key)
  yhat = model.predict(x)  # Reminder: model should be in eager mode.
  # Negative exponent: the likelihood must shrink as the prediction error grows.
  return jnp.exp(-0.5 * ((yhat - Y_OBS) / NOISE_SCALE)**2)


def log_posterior(x):
  """Unnormalized log posterior log p(x) + log p(Y_OBS | x) for one 7-D state `x`.

  Returns a scalar, as required by the HMC kernel.
  """
  log_prior = PRIOR.log_prob(x)
  yhat = model.predict(x)
  # Gaussian log-likelihood summed over every NN output, so the result is a scalar.
  log_likelihood = -0.5 * jnp.sum(((yhat - Y_OBS) / NOISE_SCALE)**2)
  return log_prior + log_likelihood


start = 10*jnp.ones(7)  # Starting state for MCMC (far from the prior mean of 0).
# NOTE(review): model.predict is traced under jit here — confirm Elegy's eager predict is jit-traceable.
samples = jit(mcmc.sample_chain(mcmc.hmc(log_posterior), N_HMC_STEPS))(random.PRNGKey(0), start)

# Plot 7-D samples from the posterior, projected onto the first two dimensions.
fig, ax = plt.subplots()
ax.scatter(samples[:, 0], samples[:, 1], alpha=0.5)
ax.set(title='HMC posterior samples (dims 0 and 1 of 7)', xlabel='x[0]', ylabel='x[1]')
plt.show()