In [1]:
!pip install -U tensorflow-probability==0.16.0 --user
!pip install -U tensorflow==2.8.0 --user
!pip install tensorflow_datasets --user



In [2]:
# Ask the cluster's module system (Lmod) to load the ML stack.
# NOTE(review): the recorded output of this cell reports "tensorflow_probability"
# and "tensorflow_datasets" as unknown modules, so only the first command did
# anything; those two packages are installed by the pip cell above instead.
# NOTE(review): each `!` command presumably runs in its own subshell, so a module
# loaded here may not change the environment of the running kernel — confirm.
!module load tensorflow
!module load tensorflow_probability
!module load tensorflow_datasets


The following have been reloaded with a version change:
  1) cudnn/8.2.0 => cudnn/8.3.2     3) tensorflow/2.6.0 => tensorflow/2.9.0
  2) gcc/10.3.0 => gcc/11.2.0

[1;31mLmod has detected the following error: [0m The following module(s) are
unknown: "tensorflow_probability"

Please check the spelling or version number. Also try "module spider ..."
It is also possible your cache file is out-of-date; it may help to try:
  $ module --ignore-cache load "tensorflow_probability"

Also make sure that all modulefiles written in TCL start with the string
#%Module



[1;31mLmod has detected the following error: [0m The following module(s) are
unknown: "tensorflow_datasets"

Please check the spelling or version number. Also try "module spider ..."
It is also possible your cache file is out-of-date; it may help to try:
  $ module --ignore-cache load "tensorflow_datasets"

Also make sure that all modulefiles written in TCL start with the string
#%Module





In [3]:
import numpy as np
import os,re
import sklearn.datasets as skd
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers, Input
import time
import matplotlib.pyplot as plt

import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions

import ffjord_terms

In [4]:
class MLP_ODE(keras.Model):
    """Multi-layer NN ode_fn.

    A fully connected network that maps the concatenation of
    ``[time, state, transformed condition]`` to a vector of size ``num_output``.

    Args:
        num_hidden: Width of every hidden ``Dense`` layer.
        num_layers: Total number of ``Dense`` layers in the main network
            (``num_layers - 1`` tanh layers followed by one linear output layer).
        num_output: Dimension of the state, and therefore of the output.
        num_cond: Number of conditional features. When greater than 1 the
            condition is first passed through a small separate network.
        name: Name given to the Keras model.
    """
    def __init__(self, num_hidden, num_layers, num_output,num_cond=2,name='mlp_ode'):
        # Forward `name` so the requested model name is not silently dropped.
        super(MLP_ODE, self).__init__(name=name)
        self._num_hidden = num_hidden
        self._num_output = num_output
        self._num_layers = num_layers
        self._num_cond = num_cond
        self._modules = []

        #Fully connected layers with tanh activation and linear output
        #Time is part of the inputs, hence the extra 1 in the input width
        input_width = 1 + self._num_output + self._num_cond
        self._modules.append(Input(shape=(input_width,)))
        for _ in range(self._num_layers - 1):
            self._modules.append(layers.Dense(self._num_hidden,activation='tanh'))

        self._modules.append(layers.Dense(self._num_output,activation=None))
        self._model = keras.Sequential(self._modules)

        if self._num_cond > 1:
            #In more dimensions, it is useful to feed the conditional inputs
            #through an independent network before concatenating them
            self._cond_model = keras.Sequential(
                [
                    Input(shape=(self._num_cond,)),
                    layers.Dense(self._num_hidden,activation='relu'),
                    layers.Dense(self._num_cond,activation=None),
                ])

    @tf.function
    def call(self, t, data,conditional_input=None):
        """Evaluate the network at time ``t`` for the batch ``data``.

        Args:
            t: Time value; it is broadcast to one column per row of ``data``.
            data: Batch of states; the first axis is treated as the batch axis.
            conditional_input: Conditional features for every row of ``data``.
                Although it defaults to ``None``, both branches below consume
                it, so a value must be supplied.

        Returns:
            The output of the main network for the concatenated inputs.
        """
        if self._num_cond==1:
            #No network for a single feature
            cond_transform=tf.cast(conditional_input,dtype=tf.float32)
        else:
            cond_transform = self._cond_model(conditional_input)

        # Prefer the static batch size; fall back to the dynamic one so the
        # function also works when the batch dimension is unknown at trace time.
        batch = data.shape[0]
        if batch is None:
            batch = tf.shape(data)[0]
        t = t*tf.ones([batch,1])
        inputs = tf.concat([t, data,cond_transform], -1)
        return self._model(inputs)

def make_bijector_kwargs(bijector, name_to_kwargs):
    """Build the (possibly nested) kwargs dict for a bijector.

    Workaround used to route the conditional information to every bijector
    layer by name.

    Args:
        bijector: Object with a ``name`` attribute. If it also exposes a
            ``bijectors`` attribute it is treated as a composite and each
            child is resolved recursively.
        name_to_kwargs: Mapping from a regular expression to the kwargs that
            should be used for bijectors whose name matches it (``re.match``,
            i.e. anchored at the start of the name).

    Returns:
        For a composite: a dict keyed by child name. For a leaf: the kwargs of
        the first matching pattern, or an empty dict when nothing matches.
    """
    if not hasattr(bijector, 'bijectors'):
        # Leaf bijector: the first pattern (in mapping order) that matches wins.
        matching = (
            leaf_kwargs
            for pattern, leaf_kwargs in name_to_kwargs.items()
            if re.match(pattern, bijector.name)
        )
        return next(matching, {})

    # Composite bijector: recurse into every child.
    nested = {}
    for child in bijector.bijectors:
        nested[child.name] = make_bijector_kwargs(child, name_to_kwargs)
    return nested

def save_model(model,name="ffjord",checkpoint_dir = './checkpoints'):
    """Save the weights of ``model`` to ``<checkpoint_dir>/<name>``.

    Args:
        model: Object exposing ``save_weights(filepath, save_format=...)``.
        name: File name prefix of the checkpoint inside ``checkpoint_dir``.
        checkpoint_dir: Directory for the checkpoint; created if missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(checkpoint_dir, exist_ok=True)
    # `save_format` belongs to save_weights; it used to be passed to
    # str.format by mistake, where it was silently ignored.
    model.save_weights('{}/{}'.format(checkpoint_dir,name), save_format='tf')

def load_model(model,name="ffjord",checkpoint_dir = './checkpoints'):
    """Restore the weights of ``model`` from ``<checkpoint_dir>/<name>``.

    Args:
        model: Object exposing ``load_weights(filepath)`` whose result has an
            ``expect_partial()`` method.
        name: File name prefix of the checkpoint inside ``checkpoint_dir``.
        checkpoint_dir: Directory that holds the checkpoint.
    """
    # The former `save_format='tf'` keyword was passed to str.format, which
    # ignores unused keywords, so it never had any effect and is dropped.
    status = model.load_weights('{}/{}'.format(checkpoint_dir,name))
    # Mark a partial restore as intentional on the returned status object.
    status.expect_partial()
    
        
class FFJORD(keras.Model):
    """Keras model wrapping a chain of conditional FFJORD bijectors.

    Each entry of ``stacked_mlps`` becomes the state-time-derivative function
    of one ``ffjord_terms.FFJORD`` bijector. The bijectors are chained and
    applied to a standard multivariate normal base distribution.

    Args:
        stacked_mlps: Iterable of callables used as ``state_time_derivative_fn``,
            one per stacked bijector.
        batch_size: Number of rows taken from every batch in the train/test steps.
        num_output: Dimension of the modelled data; the remaining columns of a
            batch are treated as conditional features.
        trace_type: ``'hutchinson'`` or ``'exact'``; selects the trace
            estimator taken from ``ffjord_terms``.
        name: Name given to the Keras model.

    Raises:
        ValueError: If ``trace_type`` is not one of the supported values.
    """
    def __init__(self, stacked_mlps, batch_size,num_output,trace_type='hutchinson',name='FFJORD'): #editing
        # Forward `name` so the requested model name is not silently dropped.
        super(FFJORD, self).__init__(name=name)
        self._num_output=num_output
        self._batch_size = batch_size
        ode_solve_fn = tfp.math.ode.DormandPrince(atol=1e-5).solve
        #Select the estimator for the trace of the Jacobian
        if trace_type=='hutchinson':
            trace_augmentation_fn = ffjord_terms.trace_jacobian_hutchinson
        elif trace_type == 'exact':
            trace_augmentation_fn = ffjord_terms.trace_jacobian_exact
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("Invalid trace estimator")

        self.bijectors = []
        for imlp,mlp in enumerate(stacked_mlps):
            ffjord = ffjord_terms.FFJORD(
                state_time_derivative_fn=mlp,
                ode_solve_fn=ode_solve_fn,
                trace_augmentation_fn=trace_augmentation_fn,
                name='bijector{}'.format(imlp), #Bijectors need to be named to receive conditional inputs
                jacobian_factor = 0.1,
                kinetic_factor = 0.3
            )
            self.bijectors.append(ffjord)

        #Reverse the bijector order
        self.chain = tfb.Chain(list(reversed(self.bijectors)))

        self.loss_tracker = keras.metrics.Mean(name="loss")
        #Determine the base distribution... may need to be switched to 1D normal, not sure
        self.base_distribution = tfp.distributions.MultivariateNormalDiag(
            loc=self._num_output*[0.0], scale_diag=self._num_output*[1.0]
        )

        self.flow=self.Transform()
        # Expose the variables of the transformed distribution on the model.
        self._variables = self.flow.variables

    @property
    def metrics(self):
        """List of the model's metrics.
        We make sure the loss tracker is listed as part of `model.metrics`
        so that `fit()` and `evaluate()` are able to `reset()` the loss tracker
        at the start of each epoch and at the start of an `evaluate()` call.
        """
        return [self.loss_tracker]

    @tf.function
    def call(self, inputs, conditional_input=None):
        """Apply the forward transform of the chained bijectors to ``inputs``.

        Args:
            inputs: Batch passed to ``bijector.forward``.
            conditional_input: Conditional features forwarded to every bijector
                whose name matches ``'bijector.'``.
        """
        kwargs = make_bijector_kwargs(self.flow.bijector,{'bijector.': {'conditional_input':conditional_input }})
        return self.flow.bijector.forward(inputs,**kwargs)

    def Transform(self):
        """Return the base distribution transformed by the bijector chain."""
        return tfd.TransformedDistribution(distribution=self.base_distribution, bijector=self.chain)

    @tf.function
    def log_loss(self,_x,_c):
        """Negative mean log-likelihood of ``_x`` given ``_c`` plus regularization.

        Args:
            _x: Batch of data points.
            _c: Conditional features for ``_x``.

        Returns:
            The sum of the mean negative log-probability and the accumulated
            regularization terms of all bijectors.
        """
        loss = -tf.reduce_mean(self.flow.log_prob(
            _x,
            bijector_kwargs=make_bijector_kwargs(
                self.flow.bijector, {'bijector.': {'conditional_input': _c}})
        ))

        # NOTE(review): the accumulator has the shape of `_x`, so the returned
        # loss is only a scalar if `_regularization_loss` broadcasts that way —
        # confirm against ffjord_terms.
        regularization_loss = tf.zeros_like(_x)
        stacked = len(self.bijectors)
        current_positions = _x

        # Walk the bijectors from last to first, feeding each one the positions
        # returned by the previous step.
        for index in range(stacked - 1, -1, -1):
            kwargs = make_bijector_kwargs(self.bijectors[index], {'bijector.': {'conditional_input': _c}})
            current_positions, current_loss = self.bijectors[index]._regularization_loss(current_positions, **kwargs)
            regularization_loss = regularization_loss + current_loss

        loss = loss + regularization_loss

        return loss

    @tf.function
    def conditional_prob(self,_x,_c):
        """Return the probability of ``_x`` under the flow conditioned on ``_c``."""
        prob = self.flow.prob(
            _x,
            bijector_kwargs=make_bijector_kwargs(
                self.flow.bijector, {'bijector.': {'conditional_input': _c}})
        )

        return prob

    def _split_values(self, values):
        """Split a batch into ``(data, cond)`` and pin their static shapes."""
        #Full shape needs to be given when using tf.dataset
        # NOTE(review): a batch with fewer than `batch_size` rows would
        # contradict the static shape declared here — confirm how the final
        # partial batch of an epoch is handled.
        data = values[:self._batch_size,:self._num_output]
        cond = values[:self._batch_size,self._num_output:]
        data.set_shape((self._batch_size,self._num_output))
        cond.set_shape((self._batch_size,cond.shape[1]))
        return data, cond

    @tf.function()
    def train_step(self, values):
        """Run one optimization step on a batch of ``[data | cond]`` columns."""
        data, cond = self._split_values(values)

        with tf.GradientTape() as tape:
            loss = self.log_loss(data,cond)

        g = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(g, self.trainable_variables))
        self.loss_tracker.update_state(loss)

        return {"loss": self.loss_tracker.result()}

    @tf.function
    def test_step(self, values):
        """Evaluate the loss on a batch of ``[data | cond]`` columns."""
        data, cond = self._split_values(values)

        loss = self.log_loss(data,cond)
        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

In [None]:
if __name__ == '__main__':

    # Hyper-parameters of the run
    LR = 5e-3
    NUM_EPOCHS = 50
    STACKED_FFJORDS = 1 #Number of stacked transformations
    NUM_LAYERS = 4 #Hidden layers per bijector
    NUM_OUTPUT = 28*28 #Output dimension
    NUM_HIDDEN = 4*NUM_OUTPUT #Hidden layer node size
    NUM_COND = 1 #Number of conditional dimensions
    BATCH_SIZE = 256

    #MNIST data
    ds_train, ds_info = tfds.load(
        'mnist',
        split='train',
        shuffle_files=True,
        as_supervised=True,
        with_info=True
    )

    def normalize_img(image, label):
        """Normalizes images: `uint8` -> `float32`, flattened to one vector."""
        return tf.reshape(tf.cast(image, tf.float32) / 255., [-1]), label

    ds_train = ds_train.map(
        normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_train = ds_train.cache()
    ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
    ds_train = ds_train.batch(BATCH_SIZE)
    ds_train = ds_train.prefetch(tf.data.AUTOTUNE)

    ds_numpy = tfds.as_numpy(ds_train)
    # Materialize ONE pass over the dataset. The pipeline reshuffles on every
    # iteration, so reading images and labels in two separate passes would pair
    # each image with the label of a different example.
    batches = list(ds_numpy)
    samples = np.concatenate([images for images, _ in batches], axis=0)
    conditions = np.concatenate([labels for _, labels in batches], axis=0)
    conditions = np.reshape(conditions, [-1, 1]).astype(np.float32)
    # Layout expected by FFJORD.train_step: data columns first, condition last.
    samples = np.concatenate([samples,conditions],-1)

    #Stack of bijectors
    stacked_mlps = []
    for _ in range(STACKED_FFJORDS):
        mlp_model = MLP_ODE(NUM_HIDDEN, NUM_LAYERS, NUM_OUTPUT,NUM_COND)
        stacked_mlps.append(mlp_model)

    #Create the model
    model = FFJORD(stacked_mlps,BATCH_SIZE,NUM_OUTPUT,trace_type='hutchinson')
    model.compile(optimizer=keras.optimizers.Adam(learning_rate=LR))

    history = model.fit(
        samples,
        batch_size=BATCH_SIZE,
        epochs=NUM_EPOCHS,
        verbose=1,
    )

2022-06-16 10:22:04.497662: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-06-16 10:22:06.155704: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38419 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:03:00.0, compute capability: 8.0
2022-06-16 10:22:06.156364: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 38419 MB memory:  -> device: 1, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:41:00.0, compute capability: 8.0
2022-06-16 10:22:06.156979: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/devi

Epoch 1/50
Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))


Instructions for updating:
back_prop=False is deprecated. Consider using tf.stop_gradient instead.
Instead of:
results = tf.while_loop(c, b, vars, back_prop=False)
Use:
results = tf.nest.map_structure(tf.stop_gradient, tf.while_loop(c, b, vars))
2022-06-16 10:22:37.437429: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.


  6/235 [..............................] - ETA: 28:11 - loss: 3718.6479

In [None]:
#Persist the trained weights so they can be reused for inference
save_model(model, name="ffjord", checkpoint_dir='./checkpoints')

#Build a second model with the same architecture but the exact trace
#estimator, then restore the saved weights into it
new_model = FFJORD(
    stacked_mlps,
    BATCH_SIZE,
    NUM_OUTPUT,
    trace_type='exact',
    name='loaded_model',
)
load_model(new_model, name="ffjord", checkpoint_dir='./checkpoints')

In [None]:
# NOTE(review): DATASET_SIZE is not defined in any earlier cell of this
# notebook, so this line raises NameError on a fresh kernel — TODO define it
# (e.g. in the configuration block next to BATCH_SIZE) with the intended value.
NSAMPLES = DATASET_SIZE * 10
#Sample the learned distribution

# Draw NSAMPLES points from the base distribution of the reloaded model.
base_samples = new_model.base_distribution.sample(NSAMPLES)

In [None]:
#Push the base samples through the flow, conditioning every sample on the
#value 1.0. The conditioning array is built once and reused; the kwargs dict
#that used to be computed here was never passed to anything and is removed.
cond_ones = np.ones((NSAMPLES,1),dtype=np.float32)
hopefully_works = new_model.call(base_samples, cond_ones)
print(hopefully_works)

In [None]:
#Inputs to operation AddN of type AddN must have the same size and shape.  Input 0: [10,1] != input 1: [1000,1] [Op:AddN]

In [None]:
#Record the positions of the base samples while they are pushed through each
#stacked bijector in steps of `step_size`
current_positions = tf.convert_to_tensor([base_samples])
all_positions = current_positions
step_size = 0.05

#The conditioning array does not change between bijectors, so build it once
conditional_ones = np.ones((NSAMPLES,1),dtype=np.float32)

# NOTE(review): this cell uses `model` while the neighbouring cells use
# `new_model`; both were built from the same `stacked_mlps` — confirm intent.
for index in range(STACKED_FFJORDS):
    kwargs = make_bijector_kwargs(model.bijectors[index], {'bijector.': {'conditional_input': conditional_ones}})
    current_positions = model.bijectors[index]._forward_timesteps(current_positions, step_size, **kwargs)
    all_positions = tf.concat([all_positions, current_positions], 0)

In [None]:
# Display the last entry along the first axis of `current_positions`.
current_positions[-1]

In [None]:
# Histogram of the last entry of `current_positions` with automatic binning.
# NOTE(review): if that entry is 2-D, plt.hist treats every column as a
# separate dataset — confirm this is the intended plot for NUM_OUTPUT > 1.
plt.hist(current_positions[-1].numpy(), bins='auto')

In [None]:
#Number of recorded time steps: 1/step_size per stacked bijector plus the start
num_steps = STACKED_FFJORDS * (1 / step_size) + 1

#One step index per recorded position: every step index repeated NSAMPLES times
x = np.arange(num_steps)
y = np.repeat(x, NSAMPLES)

#Flatten the recorded positions into one long vector
all_positions = np.reshape(all_positions, [-1])

#Report both lengths so they can be compared
print(np.size(y))
print(np.size(all_positions))

In [None]:
# NOTE(review): `mcolors` is not used in any visible cell; if it is needed,
# the import belongs in the import cell at the top of the notebook.
import matplotlib.colors as mcolors
# 2-D histogram of position (horizontal axis) against step index (vertical
# axis), with int(num_steps) bins along both axes.
plt.hist2d(all_positions, y, bins=int(num_steps))

In [None]:
# Histogram of the samples drawn from the base distribution, automatic binning.
plt.hist(base_samples.numpy(), bins='auto')

In [None]:
# Histogram (50 bins) of the first column of the training array, i.e. the
# first entry of every flattened image.
plt.hist(samples[:,0], bins=50)

In [None]:
# Display the bijectors held by the trained model's chain.
model.chain.bijectors