In [1]:
import tensorflow as tf
import numpy as np

In [2]:
import tensorflow_probability as tfp

In [3]:
# Load MNIST
(x_dev, y_dev), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

print('There are', len(x_dev), 'training images.')
print('There are', len(x_test), 'test images.')

There are 60000 training images.
There are 10000 test images.


In [4]:
train_ds = tf.data.Dataset.from_tensor_slices((x_dev, y_dev)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

In [17]:
from tensorflow.keras.layers import Flatten, Dense
from tensorflow.keras import Model


class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)

    
    def call(self, x):
        
        x = Flatten()(x)
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)     
            
        final_logits = self.output_layer(x)
        
        return network
    
    def target_log_prob(self, x, h, y):
        
        x = Flatten()(x)
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        print(h_current)
        #print(np.asarray(h_current).shape)
        h_previous = [x] + h_current[:-1]
        
        nlog_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits( # only works for discretize version
                labels=cv, logits=layer(pv))
            nlog_prob += tf.reduce_sum(ce, axis = -1)
            
        nlog_prob += tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(y, tf.int32), logits=self.output_layer(h_current[-1]))
            
        return -1 * nlog_prob
    
    def target_log_prob2(self, x, h, y):
        
        x = Flatten()(x)
        #print(h)
        h_current = tf.split(h, self.hidden_layer_sizes, axis = 1)
        #print(h_current)
        h_previous = [x] + h_current[:-1]
        
        nlog_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits( # only works for discretize version
                labels=cv, logits=layer(pv))
            nlog_prob += tf.reduce_sum(ce, axis = -1)
            
        nlog_prob += tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.cast(y, tf.int32), logits=self.output_layer(h_current[-1]))

        #print(log_prob)
        return -1 * nlog_prob
    
    def run_chain(self, x, h, y):
        
        #print(h)
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        h_current = tf.concat([h_current[0], h_current[1]], axis=1)
        #print(h_current)
        
        #def tlp(args):
        #    return self.target_log_prob2(x, args, y)
        
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn = lambda v: self.target_log_prob2(x, v, y),
            num_leapfrog_steps = 2,
            step_size = 0.1), #pow(1000, -1/4)),
            num_adaptation_steps = int(100 * 0.8))
        
        #print(tf.size(tlp(h_current)))
        
        # Run the chain (with burn-in).
        samples = tfp.mcmc.sample_chain(
          num_results=1, # need one step
          num_burnin_steps=100, # set to 1000
          current_state=h_current,
          kernel=adaptive_hmc,
          trace_fn = None,
          #trace_fn = lambda _, pkr: pkr.inner_results.accepted_results.step_size)
          return_final_kernel_results = True)
        
        print(samples)
        # is_accepted = tf.reduce_mean(tf.cast(is_accepted, dtype=tf.float32))
        new_state = tf.math.sign(tf.math.sign(samples[0]) - 1) + 1
        h_new = tf.split(new_state, self.hidden_layer_sizes, axis = 1)
        
        #print(h_new)
        #print(is_accepted)
        #print(samples)
        
        return h_new

In [18]:
model = StochasticMLP(hidden_layer_sizes = [100, 50], n_outputs=10)

In [12]:
network = [model.call(images) for images, labels in train_ds]

In [19]:
hmc = [model.run_chain(images, net, labels) for (images, labels), net in zip(train_ds, network)]

CheckpointableStatesAndTrace(
  all_states=<tf.Tensor: shape=(1, 32, 150), dtype=float32, numpy=
    array([[[ 2.49343834e+01, -3.70349464e+01,  5.96246605e+01, ...,
              1.35633886e-01, -7.43007660e-01,  4.56144857e+00],
            [-1.18630180e+01, -1.71547661e+01,  2.02893143e+01, ...,
              1.37810886e+00, -8.39252710e-01,  4.92110825e+00],
            [ 1.34826088e+00, -1.32906685e+01,  3.29797134e+01, ...,
              1.67644775e+00,  1.04760993e+00,  5.42276669e+00],
            ...,
            [-7.92915225e-01, -1.89870090e+01,  4.76548271e+01, ...,
              4.88576055e-01,  2.92850518e+00,  5.71255445e+00],
            [ 3.32595520e+01,  3.42668457e+01,  5.35454750e+01, ...,
              2.24132204e+00,  6.78587914e+00,  1.99496672e-01],
            [ 1.96138611e+01, -5.03597717e+01,  4.40457916e+01, ...,
             -2.46609950e+00, -2.59272754e-04, -3.38173479e-01]]],
          dtype=float32)>,
  trace=(),
  final_kernel_results=SimpleStepSizeAdap

InvalidArgumentError: Determined shape must either match input shape along split_dim exactly if fully specified, or be less than the size of the input along split_dim if not fully specified.  Got: 150 [Op:SplitV] name: split

In [9]:
a = tf.constant([[1.08,2.0,-3.0],[-4,0,-1]])
b = tf.math.sign(tf.math.sign(a)-1)+1
b

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 0.],
       [0., 0., 0.]], dtype=float32)>

In [64]:
pow(1000, -1/4)

0.1778279410038923