In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import numpy as np
import random
import time
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import initializers
from tensorflow.keras import Model
from tensorflow.keras.layers import Flatten, Dense
from sklearn.metrics import accuracy_score

In [2]:
x_train = np.array([[0, 0],
           [0, 1],
           [1, 0],
           [1, 1]])
y_train = np.array([[0],
           [1],
           [1],
           [0]])
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(4)

In [255]:
# MLP model
class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)
    
    def call(self, x):
        
        #x = Flatten()(x)
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)

        final_logits = self.output_layer(x) # initial the weight of output layer
            
        return network
    
    def target_log_prob(self, x, h, y):
        
        h_current = tf.math.sigmoid([tf.cast(h_i, dtype=tf.float32) for h_i in h])
        h_current = [h_current[0]]
        h_previous = [x] + h_current[:-1]
    
        nlog_prob = 0. # negative log probability
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=cv, logits=layer(pv))
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob

    def target_log_prob2(self, x, h, y):

        h_current = tf.math.sigmoid(tf.split(h, self.hidden_layer_sizes, axis = 1))
        h_current = [h_current[0]]
        h_previous = [x] + h_current[:-1]
        
        nlog_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=cv, logits=layer(pv))
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob
    
    def generate_hmc_kernel(self, x, y, step_size = pow(1000, -1/4)):
        
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn = lambda v: self.target_log_prob2(x, v, y),
            num_leapfrog_steps = 2,
            step_size = step_size),
            num_adaptation_steps=int(100*0.8))
        
        return adaptive_hmc
    
    # new proposing-state method with HamiltonianMonteCarlo
    def propose_new_state_hamiltonian(self, x, h, y):
    
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        h_current = h_current[0]

        # initialize the HMC transition kernel
        
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn = lambda v: self.target_log_prob2(x, v, y),
            num_leapfrog_steps = 2,
            step_size = pow(1000, -1/4)),
            num_adaptation_steps=int(100*0.8))

        # run the chain (with burn-in)
        num_results = 1
        num_burnin_steps = 100

        samples = tfp.mcmc.sample_chain(
            num_results = num_results,
            num_burnin_steps = num_burnin_steps,
            current_state = h_current, # may need to be reshaped
            kernel = adaptive_hmc,
            trace_fn = None)

        h_new = tf.split(samples[0], self.hidden_layer_sizes, axis = 1)

        return(h_new)
    
    def update_weights(self, x, h, y, lr = 0.1):
        
        optimizer = tf.keras.optimizers.SGD(learning_rate = lr)
        with tf.GradientTape() as tape:
            loss = -1 * tf.reduce_mean(self.target_log_prob(x, h, y))
        
        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))

    def get_predictions(self, x):

        logits = 0.0
        for layer in self.fc_layers:
            logits = layer(x)
            x = tf.math.sigmoid(logits)
        
        logits = self.output_layer(x)
        probs = tf.math.sigmoid(logits)
        #print(probs)
        labels = tf.cast(tf.math.greater(probs, 0.5), tf.int32)

        return labels

In [256]:
model = StochasticMLP(hidden_layer_sizes = [2], n_outputs=1)

In [257]:
network = [model.call(images) for images, labels in train_ds]

In [258]:
# Set weight
w_0 = np.array([[1, -1], [1, -1]], dtype = "float32")
b_0 = np.array([-0.5, 1], dtype = "float32")
l_0 = [w_0, b_0]

w_1 = np.array([[1], [1]], dtype = "float32")
b_1 = np.array([-1], dtype = "float32")
l_1 = [w_1, b_1]

model.fc_layers[0].set_weights(l_0)
model.output_layer.set_weights(l_1)

In [259]:
network = [model.call(images) for images, labels in train_ds]
network

[[<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
  array([[0, 0],
         [0, 0],
         [0, 1],
         [1, 0]], dtype=int32)>]]

In [260]:
kernels = [model.generate_hmc_kernel(images, labels) for images, labels in train_ds]

In [261]:
burnin = 100
for i in range(burnin):
    
    network_new = []
    kernels_new = []
    
    for (images, labels), net, hmc_kernel in zip(train_ds, network, kernels):
        net_current = net
        net_current = [tf.cast(net_i, dtype=tf.float32) for net_i in net_current]
        net_current = net_current[0]
        
        num_results = 1
        num_burnin_steps = 0

        samples = tfp.mcmc.sample_chain(
            num_results = num_results,
            num_burnin_steps = num_burnin_steps,
            current_state = net_current, # may need to be reshaped
            kernel = hmc_kernel,
            #trace_fn = lambda _, pkr: pkr.inner_results.accepted_results.new_step_size,
            trace_fn = None,
            return_final_kernel_results = True)
        
        #print(samples)
        #print("__________________________________")
        print(samples[2][4].numpy())
        new_step_size = samples[2][4].numpy()
        net_new = tf.split(samples[0][0], [2], axis = 1)   
        network_new.append(net_new)
        
        # build new kernel
        ker_new = model.generate_hmc_kernel(images, labels, new_step_size)
        kernels_new.append(ker_new)
            
    network = network_new
    kernels = kernels_new

0.17960621
0.18140228
0.1832163
0.18504846
0.18689895
0.18876794
0.19065562
0.19256218
0.1944878
0.19643266
0.198397
0.20038097
0.20238477
0.20440862
0.2064527
0.20851722
0.21060239
0.21270841
0.2148355
0.21698385
0.21915369
0.22134522
0.22355866
0.22579426
0.2280522
0.23033272
0.23263605
0.2349624
0.23731202
0.23968513
0.24208198
0.2445028
0.24694782
0.2494173
0.25191146
0.25443056
0.25697488
0.2595446
0.26214007
0.26476148
0.2674091
0.2700832
0.27278402
0.27551186
0.27826697
0.28104964
0.28386015
0.28669876
0.28956574
0.2924614
0.29538602
0.29833987
0.30132326
0.3043365
0.30737984
0.31045362
0.31355816
0.31669375
0.3198607
0.3230593
0.3262899
0.3295528
0.3328483
0.33617678
0.33953854
0.34293392
0.34636325
0.34982687
0.35332513
0.35685837
0.36042696
0.36403123
0.36767152
0.37134823
0.37506172
0.37881234
0.38260046
0.38642645
0.3902907
0.39419362
0.39813554
0.4021169
0.40613806
0.41019943
0.41430143
0.41844442
0.42262888
0.42685518
0.43112373
0.43543497
0.43978932
0.44418722
0.44862908

In [264]:
epochs = 500

start_time = time.time()
for epoch in range(epochs):
    
    loss = 0.0
    acc = 0.0
    for bs, (images, labels) in enumerate(train_ds):
        
        # only one mini-batch
        model.update_weights(images, network[bs], labels, 0.1)
        #network_new = [model.propose_new_state_hamiltonian(images, net, labels) for (images, labels), net in 
        #               zip(train_ds, network)]
        network_new = []
        kernels_new = []
        for (images, labels), net, hmc_kernel in zip(train_ds, network, kernels):
            net_current = net
            net_current = [tf.cast(net_i, dtype=tf.float32) for net_i in net_current]
            net_current = net_current[0]
        
            num_results = 1
            num_burnin_steps = 0

            samples = tfp.mcmc.sample_chain(
                num_results = num_results,
                num_burnin_steps = num_burnin_steps,
                current_state = net_current, # may need to be reshaped
                kernel = hmc_kernel,
                trace_fn = None,
                return_final_kernel_results = True)

            print(samples[2][4].numpy())
            new_step_size = samples[2][4].numpy()
            net_new = tf.split(samples[0][0], [2], axis = 1)   
            network_new.append(net_new)
            
            ker_new = model.generate_hmc_kernel(images, labels, new_step_size)
            kernels_new.append(ker_new)
            
        network = network_new
        kernels = kernels_new
        
        loss += -1 * tf.reduce_mean(model.target_log_prob(images, network[bs], labels))
    
    preds = [model.get_predictions(images) for images, labels in train_ds]
    train_acc = accuracy_score(np.array(preds[0]), y_train)
    print("Epoch %d/%d: - %.4fs/step - loss: %.4f - accuracy: %.4f" 
          % (epoch + 1, epochs, (time.time() - start_time) / (epoch + 1), loss, train_acc))

3995.9446
Epoch 1/500: - 0.0845s/step - loss: 0.7127 - accuracy: 0.5000
3956.3809
Epoch 2/500: - 0.0630s/step - loss: 0.7127 - accuracy: 0.5000
3917.2087
Epoch 3/500: - 0.0559s/step - loss: 0.7126 - accuracy: 0.5000
3956.3809
Epoch 4/500: - 0.0525s/step - loss: 0.7126 - accuracy: 0.5000
3917.2087
Epoch 5/500: - 0.0520s/step - loss: 0.7126 - accuracy: 0.5000
3956.3809
Epoch 6/500: - 0.0534s/step - loss: 0.7126 - accuracy: 0.5000
3995.9446
Epoch 7/500: - 0.0521s/step - loss: 0.7126 - accuracy: 0.5000
4035.904
Epoch 8/500: - 0.0506s/step - loss: 0.7125 - accuracy: 0.5000
4076.263
Epoch 9/500: - 0.0495s/step - loss: 0.7125 - accuracy: 0.5000
4117.0254
Epoch 10/500: - 0.0495s/step - loss: 0.7125 - accuracy: 0.5000
4158.196
Epoch 11/500: - 0.0493s/step - loss: 0.7125 - accuracy: 0.5000
4199.778
Epoch 12/500: - 0.0497s/step - loss: 0.7124 - accuracy: 0.5000
4241.7754
Epoch 13/500: - 0.0491s/step - loss: 0.7124 - accuracy: 0.5000
4284.193
Epoch 14/500: - 0.0484s/step - loss: 0.7124 - accuracy:

Epoch 115/500: - 0.0533s/step - loss: 0.7170 - accuracy: 0.5000
10283.702
Epoch 116/500: - 0.0534s/step - loss: 0.7167 - accuracy: 0.5000
10386.539
Epoch 117/500: - 0.0534s/step - loss: 0.7165 - accuracy: 0.5000
10490.404
Epoch 118/500: - 0.0535s/step - loss: 0.7163 - accuracy: 0.5000
10595.309
Epoch 119/500: - 0.0536s/step - loss: 0.7162 - accuracy: 0.5000
10701.262
Epoch 120/500: - 0.0536s/step - loss: 0.7160 - accuracy: 0.5000
10808.274
Epoch 121/500: - 0.0537s/step - loss: 0.7158 - accuracy: 0.5000
10916.357
Epoch 122/500: - 0.0538s/step - loss: 0.7157 - accuracy: 0.5000
11025.5205
Epoch 123/500: - 0.0539s/step - loss: 0.7156 - accuracy: 0.5000
11135.775
Epoch 124/500: - 0.0539s/step - loss: 0.7155 - accuracy: 0.5000
11247.133
Epoch 125/500: - 0.0539s/step - loss: 0.7154 - accuracy: 0.5000
11359.6045
Epoch 126/500: - 0.0539s/step - loss: 0.7153 - accuracy: 0.5000
11473.2
Epoch 127/500: - 0.0540s/step - loss: 0.7152 - accuracy: 0.5000
11587.932
Epoch 128/500: - 0.0541s/step - loss: 

30725.582
Epoch 230/500: - 0.0526s/step - loss: 0.7119 - accuracy: 0.5000
31032.838
Epoch 231/500: - 0.0526s/step - loss: 0.7119 - accuracy: 0.5000
31343.166
Epoch 232/500: - 0.0525s/step - loss: 0.7119 - accuracy: 0.5000
31656.598
Epoch 233/500: - 0.0525s/step - loss: 0.7119 - accuracy: 0.5000
31973.164
Epoch 234/500: - 0.0524s/step - loss: 0.7118 - accuracy: 0.5000
32292.895
Epoch 235/500: - 0.0524s/step - loss: 0.7118 - accuracy: 0.5000
32615.822
Epoch 236/500: - 0.0524s/step - loss: 0.7118 - accuracy: 0.5000
32941.98
Epoch 237/500: - 0.0523s/step - loss: 0.7118 - accuracy: 0.5000
33271.4
Epoch 238/500: - 0.0523s/step - loss: 0.7118 - accuracy: 0.5000
33604.113
Epoch 239/500: - 0.0523s/step - loss: 0.7117 - accuracy: 0.5000
33940.152
Epoch 240/500: - 0.0522s/step - loss: 0.7117 - accuracy: 0.5000
34279.555
Epoch 241/500: - 0.0522s/step - loss: 0.7117 - accuracy: 0.5000
34622.35
Epoch 242/500: - 0.0522s/step - loss: 0.7117 - accuracy: 0.5000
34968.574
Epoch 243/500: - 0.0521s/step - 

85625.11
Epoch 345/500: - 0.0530s/step - loss: 0.7097 - accuracy: 0.5000
86481.36
Epoch 346/500: - 0.0530s/step - loss: 0.7097 - accuracy: 0.5000
87346.17
Epoch 347/500: - 0.0530s/step - loss: 0.7097 - accuracy: 0.5000
88219.63
Epoch 348/500: - 0.0530s/step - loss: 0.7097 - accuracy: 0.5000
89101.83
Epoch 349/500: - 0.0530s/step - loss: 0.7096 - accuracy: 0.5000
89992.84
Epoch 350/500: - 0.0531s/step - loss: 0.7096 - accuracy: 0.5000
90892.77
Epoch 351/500: - 0.0531s/step - loss: 0.7096 - accuracy: 0.5000
91801.7
Epoch 352/500: - 0.0532s/step - loss: 0.7096 - accuracy: 0.5000
92719.72
Epoch 353/500: - 0.0532s/step - loss: 0.7096 - accuracy: 0.5000
93646.914
Epoch 354/500: - 0.0532s/step - loss: 0.7096 - accuracy: 0.5000
94583.38
Epoch 355/500: - 0.0532s/step - loss: 0.7095 - accuracy: 0.5000
95529.22
Epoch 356/500: - 0.0533s/step - loss: 0.7095 - accuracy: 0.5000
96484.51
Epoch 357/500: - 0.0533s/step - loss: 0.7095 - accuracy: 0.5000
97449.35
Epoch 358/500: - 0.0534s/step - loss: 0.70

250789.3
Epoch 459/500: - 0.0548s/step - loss: 0.7080 - accuracy: 0.5000
253297.19
Epoch 460/500: - 0.0548s/step - loss: 0.7080 - accuracy: 0.5000
255830.16
Epoch 461/500: - 0.0548s/step - loss: 0.7080 - accuracy: 0.5000
258388.45
Epoch 462/500: - 0.0548s/step - loss: 0.7080 - accuracy: 0.5000
260972.33
Epoch 463/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
258388.45
Epoch 464/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
260972.33
Epoch 465/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
263582.06
Epoch 466/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
266217.88
Epoch 467/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
268880.06
Epoch 468/500: - 0.0548s/step - loss: 0.7079 - accuracy: 0.5000
271568.88
Epoch 469/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
274284.56
Epoch 470/500: - 0.0547s/step - loss: 0.7079 - accuracy: 0.5000
277027.4
Epoch 471/500: - 0.0547s/step - loss: 0.7078 - accuracy: 0.5000
279797.7
Epoch 472/500: - 0.0547s/step -

In [126]:
pow(1000,-0.25)

0.1778279410038923