In [1]:
import tensorflow as tf
import numpy as np

In [2]:
import tensorflow_probability as tfp

In [3]:
# Load MNIST
(x_dev, y_dev), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

print('There are', len(x_dev), 'training images.')
print('There are', len(x_test), 'test images.')

There are 60000 training images.
There are 10000 test images.


In [4]:
train_ds = tf.data.Dataset.from_tensor_slices((x_dev, y_dev)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

In [5]:
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras import Model


class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)

    
    def call(self, x):
        
        x = Flatten()(x)
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)     
            
        final_logits = self.output_layer(x)
        
        return network
    
    def target_log_prob(self, x, h, y):
        
        x = Flatten()(x)
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        h_previous = [x] + h_current[:-1]
        
        log_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits( # only works for discretize version
                labels=cv, logits=layer(pv))
            log_prob += tf.reduce_sum(ce, axis = -1)
            
        log_prob += tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(y, tf.int32), logits=self.output_layer(h_current[-1]))
            
        return log_prob
    
    def target_log_prob2(self, x, h_current, y):
        
        x = Flatten()(x)
        #h_current_new = [tf.split(h_current[i], [self.hidden_layer_sizes[0]]) for i in range(32)]
        h_current_new = tf.split(h_current, [self.hidden_layer_sizes[0], self.hidden_layer_sizes[1]], axis = 1)
        h_previous = [x] + h_current_new[:-1]
        
        log_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current_new, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits( # only works for discretize version
                labels=cv, logits=layer(pv))
            log_prob += tf.reduce_sum(ce, axis = -1)
            
        log_prob += tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(y, tf.int32), logits=self.output_layer(h_current_new[-1]))
            
        return log_prob
    
    def run_chain(self, x, h, y):
        
        x = Flatten()(x)
        #h_current = h
        h_current = tf.concat([h[0], h[1]], axis=1)
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        print(np.asarray(h_current).shape)
        #print(h)
        
        def tlp(*args):
            return self.target_log_prob2(x, args, y)
        
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn = tlp,
            num_leapfrog_steps = 10,
            step_size = pow(1000, -1/4)),
            num_adaptation_steps=int(1000 * 0.8))
        
        # Run the chain (with burn-in).
        samples = tfp.mcmc.sample_chain(
          num_results=5,
          num_burnin_steps=1000, # set to 1000
          current_state=h_current,
          kernel=adaptive_hmc,
          trace_fn=None)
        
        sample_mean = tf.reduce_mean(samples)
        
        return sample_mean

In [6]:
model = StochasticMLP(hidden_layer_sizes = [100, 50], n_outputs=10)

In [7]:
network = [model.call(images) for images, labels in train_ds]

In [8]:
hmc = [model.run_chain(images, net, labels) for (images, labels), net in zip(train_ds, network)]

(32, 150)


  'TensorFloat-32 matmul/conv are enabled for NVIDIA Ampere+ GPUs. The '


InvalidArgumentError: Inputs to operation AddN of type AddN must have the same size and shape.  Input 0: [32] != input 2: [150] [Op:AddN]