In [1]:
import tensorflow as tf
import numpy as np

In [2]:
import tensorflow_probability as tfp

In [3]:
# Load MNIST
(x_dev, y_dev), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

print('There are', len(x_dev), 'training images.')
print('There are', len(x_test), 'test images.')

There are 60000 training images.
There are 10000 test images.


In [4]:
train_ds = tf.data.Dataset.from_tensor_slices((x_dev, y_dev)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

In [9]:
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras import Model


class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)

    
    def call(self, x):
        
        x = Flatten()(x)
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)     
            
        final_logits = self.output_layer(x)
        
        return network
    
    def target_log_prob(self, x, h, y):
        
        x = Flatten()(x)
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        print(h_current)
        #print(np.asarray(h_current).shape)
        h_previous = [x] + h_current[:-1]
        
        nlog_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits( # only works for discretize version
                labels=cv, logits=layer(pv))
            nlog_prob += tf.reduce_sum(ce, axis = -1)
            
        nlog_prob += tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(y, tf.int32), logits=self.output_layer(h_current[-1]))
            
        return -1 * nlog_prob
    
    def target_log_prob2(self, x, h, y):
        
        x = Flatten()(x)
        #print(h)
        h_current = tf.split(h, self.hidden_layer_sizes, axis = 1)
        #print(h_current)
        h_previous = [x] + h_current[:-1]
        
        nlog_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits( # only works for discretize version
                labels=cv, logits=layer(pv))
            nlog_prob += tf.reduce_sum(ce, axis = -1)
            
        nlog_prob += tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(y, tf.int32), logits=self.output_layer(h_current[-1]))

        #print(log_prob)
        return -1 * nlog_prob
    
    def run_chain(self, x, h, y):
        
        #print(h)
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        h_current = tf.concat([h_current[0], h_current[1]], axis=1)
        print(h_current)
        
        #def tlp(args):
        #    return self.target_log_prob2(x, args, y)
        
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(tfp.mcmc.HamiltonianMonteCarlo(
            #target_log_prob_fn = tlp,
            target_log_prob_fn = lambda v: self.target_log_prob2(x, v, y),
            num_leapfrog_steps = 1,
            step_size = pow(1000, -1/4)),
            num_adaptation_steps=int(20 * 0.8))
        
        #print(tf.size(tlp(h_current)))
        
        # Run the chain (with burn-in).
        samples = tfp.mcmc.sample_chain(
          num_results=2, # need one step
          num_burnin_steps=10, # set to 1000
          current_state=h_current,
          kernel=adaptive_hmc,
          trace_fn=None)
        
        print(samples)
        
        return samples[0]

In [10]:
model = StochasticMLP(hidden_layer_sizes = [100, 50], n_outputs=10)

In [7]:
network = [model.call(images) for images, labels in train_ds]

In [11]:
hmc = [model.run_chain(images, net, labels) for (images, labels), net in zip(train_ds, network)]

tf.Tensor(
[[1. 1. 0. ... 1. 1. 1.]
 [1. 1. 1. ... 0. 1. 1.]
 [0. 1. 1. ... 0. 0. 0.]
 ...
 [1. 0. 1. ... 0. 1. 1.]
 [1. 1. 0. ... 0. 1. 0.]
 [1. 1. 1. ... 0. 1. 1.]], shape=(32, 150), dtype=float32)
tf.Tensor(
[[[-5.6951137   7.8013015   4.3768682  ...  0.60175574  1.3609412
    1.1187177 ]
  [ 1.          1.          1.         ...  0.          1.
    1.        ]
  [ 3.522678    4.6131873   7.1041045  ... -0.5670848  -0.06283997
   -0.10579166]
  ...
  [ 1.4024974   8.207304    1.7385392  ... -0.50878274  1.1038109
    0.66525245]
  [ 1.          1.          0.         ...  0.          1.
    0.        ]
  [ 1.          1.          1.         ...  0.          1.
    1.        ]]

 [[-5.6951137   7.8013015   4.3768682  ...  0.60175574  1.3609412
    1.1187177 ]
  [ 1.          1.          1.         ...  0.          1.
    1.        ]
  [ 3.522678    4.6131873   7.1041045  ... -0.5670848  -0.06283997
   -0.10579166]
  ...
  [ 1.4024974   8.207304    1.7385392  ... -0.50878274  1.10381

tf.Tensor(
[[[ 4.276373    6.0086646   3.6891496  ...  0.46253294  1.4644811
    0.86294615]
  [-0.25884122  2.4641366   5.0003066  ...  1.59483    -0.2860292
   -0.08633502]
  [ 0.          1.          0.         ...  0.          0.
    1.        ]
  ...
  [ 0.          1.          1.         ...  1.          0.
    1.        ]
  [ 1.          0.          1.         ...  0.          1.
    1.        ]
  [-9.168439   11.983639    7.174659   ...  0.3023031   1.1199961
    0.77323556]]

 [[ 4.276373    6.0086646   3.6891496  ...  0.46253294  1.4644811
    0.86294615]
  [-0.25884122  2.4641366   5.0003066  ...  1.59483    -0.2860292
   -0.08633502]
  [ 0.          1.          0.         ...  0.          0.
    1.        ]
  ...
  [ 0.          1.          1.         ...  1.          0.
    1.        ]
  [ 1.          0.          1.         ...  0.          1.
    1.        ]
  [-9.168439   11.983639    7.174659   ...  0.3023031   1.1199961
    0.77323556]]], shape=(2, 32, 150), dtype=floa

tf.Tensor(
[[[ 0.0000000e+00  1.0000000e+00  1.0000000e+00 ...  1.0000000e+00
    1.0000000e+00  0.0000000e+00]
  [ 5.0339389e-01  7.3274393e+00 -8.3706589e+00 ...  6.3243657e-01
   -1.5038556e-01  4.9910467e-02]
  [-5.9624052e+00  1.4227279e+01  1.3738582e+00 ...  4.3008214e-01
    1.3686094e+00  1.0152711e+00]
  ...
  [-6.9021549e+00  3.7514753e+00  1.2696905e+01 ... -8.4362328e-03
   -3.7718289e-02  8.7844020e-01]
  [ 5.5981917e+00 -1.3666780e+00  4.4991570e+00 ...  1.5272335e+00
    7.7192152e-01  1.8631378e+00]
  [ 0.0000000e+00  1.0000000e+00  0.0000000e+00 ...  1.0000000e+00
    1.0000000e+00  0.0000000e+00]]

 [[ 6.9831705e+00  1.2002772e+01  3.4656868e+00 ...  7.6104140e-01
    1.0722396e+00  3.8312137e-02]
  [ 5.0339389e-01  7.3274393e+00 -8.3706589e+00 ...  6.3243657e-01
   -1.5038556e-01  4.9910467e-02]
  [-5.9624052e+00  1.4227279e+01  1.3738582e+00 ...  4.3008214e-01
    1.3686094e+00  1.0152711e+00]
  ...
  [-6.9021549e+00  3.7514753e+00  1.2696905e+01 ... -8.4362328e-03

tf.Tensor(
[[[-4.4208055  -0.3509407   7.298955   ...  0.9963965   0.7738724
    0.7769941 ]
  [ 1.9654737   5.342252    5.351407   ... -0.16805458  0.12162067
    0.74014944]
  [ 0.          1.          0.         ...  1.          0.
    0.        ]
  ...
  [-1.4054703   2.4294024   3.9655204  ...  0.10626383  1.0314095
    1.3001223 ]
  [-3.5056324  10.360201   -0.7208296  ... -0.09551713  1.0954963
   -0.17053857]
  [-3.3097036   7.7549095   7.076102   ...  1.2880497   1.3899064
    0.8660152 ]]

 [[-4.4208055  -0.3509407   7.298955   ...  0.9963965   0.7738724
    0.7769941 ]
  [ 1.9654737   5.342252    5.351407   ... -0.16805458  0.12162067
    0.74014944]
  [ 0.          1.          0.         ...  1.          0.
    0.        ]
  ...
  [-1.4054703   2.4294024   3.9655204  ...  0.10626383  1.0314095
    1.3001223 ]
  [-3.5056324  10.360201   -0.7208296  ... -0.09551713  1.0954963
   -0.17053857]
  [-3.3097036   7.7549095   7.076102   ...  1.2880497   1.3899064
    0.8660152 ]]], 

tf.Tensor(
[[[ 1.2850587   8.23685    -0.31204173 ...  0.27800316 -0.0527844
    1.024491  ]
  [ 0.58518267 -0.4905175   4.7918763  ...  0.4170843   1.0034487
    0.5226351 ]
  [ 1.          1.          1.         ...  0.          1.
    1.        ]
  ...
  [ 1.          0.          1.         ...  0.          1.
    1.        ]
  [-6.1805935   7.5791626   5.227957   ...  0.71059173  1.406439
   -0.10009822]
  [ 0.          0.          0.         ...  1.          1.
    1.        ]]

 [[ 1.2850587   8.23685    -0.31204173 ...  0.27800316 -0.0527844
    1.024491  ]
  [ 0.58518267 -0.4905175   4.7918763  ...  0.4170843   1.0034487
    0.5226351 ]
  [ 1.          1.          1.         ...  0.          1.
    1.        ]
  ...
  [ 1.          0.          1.         ...  0.          1.
    1.        ]
  [-6.1805935   7.5791626   5.227957   ...  0.71059173  1.406439
   -0.10009822]
  [ 0.          0.          0.         ...  1.          1.
    1.        ]]], shape=(2, 32, 150), dtype=float3

tf.Tensor(
[[[-9.574444    1.3303851   2.476551   ... -0.651796    0.7664803
    0.7519889 ]
  [ 1.0659534  12.60506     7.696648   ...  0.02066525  0.7992706
   -0.1982836 ]
  [-7.87576     7.038083   -3.8140771  ... -0.2151734   1.4086924
    0.50018823]
  ...
  [ 0.          1.          1.         ...  0.          1.
    1.        ]
  [-6.2660394  14.994463    0.7016987  ...  0.03488117  1.5362352
    0.80348784]
  [ 1.          1.          1.         ...  1.          1.
    1.        ]]

 [[-9.574444    1.3303851   2.476551   ... -0.651796    0.7664803
    0.7519889 ]
  [ 1.0659534  12.60506     7.696648   ...  0.02066525  0.7992706
   -0.1982836 ]
  [-7.87576     7.038083   -3.8140771  ... -0.2151734   1.4086924
    0.50018823]
  ...
  [ 0.          1.          1.         ...  0.          1.
    1.        ]
  [-6.2660394  14.994463    0.7016987  ...  0.03488117  1.5362352
    0.80348784]
  [-5.1809096  13.005457    3.022835   ...  1.2794899   0.8445095
    0.60677314]]], shape=(2,

KeyboardInterrupt: 