In [1]:
# Train a neural network (NN) with Elegy, then run Hamiltonian Monte Carlo (HMC) with Oryx, on Mana.
# Author: Peter Oct 24 2021
# Requirements: 
#!module load system/CUDA/11.0.2 
#!pip install --upgrade jax jaxlib==0.1.68+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html 
#!pip install tensorflow-io oryx elegy

# Imports and configuration. Statement order in this cell matters: the XLA
# environment variables must be set before `import jax`, the backend check must
# run after it, and TensorFlow is only imported for the tf.data input pipeline.
from collections import defaultdict
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid')  # Global seaborn/matplotlib plot style.

import os
# Uncomment to turn off XLA's GPU memory preallocation and/or cap JAX at 10% of
# GPU memory. These presumably only take effect if set before jax initializes,
# i.e. before the `import jax` below — confirm in the JAX GPU memory docs.
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
#os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.10'
import jax
import jax.numpy as jnp
from jax import random
from jax import vmap
from jax import jit
from jax import grad

# Fail fast if JAX is not running on a GPU (e.g. the CUDA module from the
# requirements above was not loaded, or a CPU-only jaxlib is installed).
assert jax.default_backend() == 'gpu'

import oryx  # pip install oryx
import elegy # pip install elegy
import optax
#import flax.linen
import tensorflow_io as tfio # pip install tensorflow-io
import tensorflow as tf # Only used for the tf.data input pipeline. Importing TF next to JAX is discouraged because TF will also try to reserve GPU memory.

# Short aliases for Oryx APIs (distributions, state / probabilistic-programming
# transformations, and the experimental NN / MCMC / optimizer modules) used for
# the HMC part of this project.
tfd = oryx.distributions

state = oryx.core.state
ppl = oryx.core.ppl

inverse = oryx.core.inverse
ildj = oryx.core.ildj
plant = oryx.core.plant
reap = oryx.core.reap
sow = oryx.core.sow
unzip = oryx.core.unzip

nn = oryx.experimental.nn
mcmc = oryx.experimental.mcmc
# Oryx optimizers (modelled on https://github.com/deepmind/optax). Not used by
# the Elegy training cell, which passes an optax optimizer directly; the
# commented lines sketch how an Oryx optimizer state would be initialized.
optimizers = oryx.experimental.optimizers # Oryx based on https://github.com/deepmind/optax
#optimizer = optimizers.adam(1e-4) # Oryx based on https://github.com/deepmind/optax
#opt_state = state.init(optimizer)(opt_key, network, network)


# Train and save NN model.

In [2]:
# Define the dataset: stream inputs and targets from HDF5, min-max scale the
# inputs, split into train / validation, then batch.
f = '/mnt/lts/nfs_fs02/sadow_lab/shared/gcr/data/proposal/data_processed.hdf5'
#f = './data_processed_sample.hdf5'  # Alternative: small local sample file.

# Per-feature statistics for the 7 input features. MIN_VALS / RANGE_VALS drive
# the min-max scaling below; MEAN_VALS is not used by the active pipeline.
MEAN_VALS = np.array([41.094597, 6.4734573, 160.1557, 1.0845096, 1.3110111, 1.1005557, 1.2972884])
MIN_VALS = np.array([20., 4.5, 50., 0.2, 0.2, 0.2, 0.2])
MAX_VALS = np.array([75., 8.5, 250., 2., 2.3, 2., 2.3])
RANGE_VALS = MAX_VALS - MIN_VALS

# Number of rows in the HDF5 file (hard-coded; TODO confirm it still matches the
# length of '/features' if the data file changes) and the fraction used for training.
N_EXAMPLES = 1435308
TRAIN_FRACTION = 0.9
# tf.data's take()/skip() expect an integer count, but np.floor() returns a
# float64, so convert explicitly. Same split point as before: 1291777.
N_TRAIN = int(np.floor(N_EXAMPLES * TRAIN_FRACTION))

x = tfio.IODataset.from_hdf5(f, dataset='/features')
y = tfio.IODataset.from_hdf5(f, dataset='/logp1_flux')
full = tf.data.Dataset.zip((x, y))
# Min-max scale: maps [MIN_VALS, MAX_VALS] onto [0, 1] per feature; targets unchanged.
full = full.map(lambda features, target: ((features - MIN_VALS) / RANGE_VALS, target))

# Split: the first N_TRAIN rows train, the rest validate. Both repeat
# indefinitely, so training must bound epochs via steps_per_epoch / validation_steps.
# NOTE(review): nothing is shuffled; if rows in the file are ordered (e.g. by
# input parameters) this split is not random — consider shuffling before splitting.
train = full.take(N_TRAIN).repeat()
test = full.skip(N_TRAIN).repeat()

# Batch (dropping the ragged final batch) and prefetch to overlap I/O with compute.
BATCH_SIZE = 128
train = train.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
test = test.batch(BATCH_SIZE, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)


class MLP(elegy.Module):
    """Fully connected regressor: Linear(1024)+ReLU -> Linear(1024)+tanh -> Linear(245).

    The output layer has no activation. No input scaling is done here; inputs
    are expected to be scaled by the data pipeline beforehand.
    """

    def call(self, x: jnp.ndarray) -> jnp.ndarray:
        # NOTE(review): keep the creation order of the Linear layers unchanged —
        # elegy presumably derives parameter names from it, and saved models
        # ('my_model') depend on those names. Confirm before reordering.
        hidden = jax.nn.relu(elegy.nn.Linear(1024)(x))
        hidden = jax.nn.tanh(elegy.nn.Linear(1024)(hidden))
        return elegy.nn.Linear(245)(hidden)

# Wrap the MLP in an elegy.Model holding the parameters, the MSE loss and the
# optax optimizer state. (Oryx's optimizers, e.g. `optimizers.rmsprop`, did not
# work with Elegy here, so optax is used directly.)
LEARNING_RATE = 1e-4
model = elegy.Model(
    module=MLP(),
    loss=elegy.losses.MeanSquaredError(),
    #loss=[elegy.losses.MeanSquaredError(), elegy.regularizers.GlobalL2(l=1e-5)],  # Variant with L2 weight penalty.
    optimizer=optax.adam(LEARNING_RATE),
    #run_eagerly=True,  # Debugging option; an earlier note about when it is needed was left unfinished.
)

# Stop once val_loss fails to improve for one epoch, and roll back to the best
# epoch. Without restore_best_weights, when early stopping triggers the model
# keeps the last epoch's weights — by construction an epoch that did not
# improve — and those are what would be saved afterwards.
callbacks = [
    elegy.callbacks.EarlyStopping(monitor="val_loss", patience=1, restore_best_weights=True),
    #elegy.callbacks.TensorBoard("summaries"),
]

history = model.fit(
    x=train,
    epochs=6,
    steps_per_epoch=10000,  # `train` repeats indefinitely; 10000 x 128 rows is roughly one pass over the training split.
    validation_data=test,   # NOTE: the held-out split is used for validation and early stopping.
    validation_steps=1000,
    shuffle=False,  # NOTE(review): row order comes from the tf.data pipeline; presumably ignored for dataset inputs.
    verbose=2,
    callbacks=callbacks,
)

# Save the trained model, then sanity-check that it can make a prediction.
model.save('my_model')  # Creates the folder 'my_model'.

# Probe with one random input in the min-max-scaled feature space ([0, 1) for
# each of the 7 features). Seeded so re-runs are reproducible, and named
# x_probe so the dataset variable `x` is not silently overwritten.
x_probe = np.random.default_rng(0).random(7)
prediction = model.predict(x_probe)
assert prediction.shape == (245,), prediction.shape  # One value per output unit of the final layer.
prediction.shape

2021-10-28 20:13:23.880349: W tensorflow_io/core/kernels/audio_video_mp3_kernels.cc:271] libmp3lame.so.0 or lame functions are not available
2021-10-28 20:13:24.658039: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:39] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.


Epoch 1/6
10000/10000 - 88s - loss: 0.2650 - mean_squared_error_loss: 0.2650 - val_loss: 1.7224e-04 - val_mean_squared_error_loss: 1.7224e-04
Epoch 2/6
10000/10000 - 81s - loss: 1.4207e-04 - mean_squared_error_loss: 1.4207e-04 - val_loss: 1.2932e-04 - val_mean_squared_error_loss: 1.2932e-04
Epoch 3/6
10000/10000 - 79s - loss: 1.2226e-04 - mean_squared_error_loss: 1.2226e-04 - val_loss: 1.3966e-04 - val_mean_squared_error_loss: 1.3966e-04


(245,)

# Use trained model for Hamiltonian Monte Carlo

In [3]:
# HMC sampling with the trained model is done in hmc_gcr.ipynb.