In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.math as tm
import numpy as np
import random
import time
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import initializers
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from sklearn.metrics import accuracy_score

In [2]:
x_train = np.array([[0, 0],
           [0, 1],
           [1, 0],
           [1, 1]])
y_train = np.array([[0],
           [1],
           [1],
           [0]])
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(4)

2021-10-14 17:50:29.989718: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [22]:
class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)
        
    def call(self, x):
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)

        final_logits = self.output_layer(x) # initial the weight of output layer
            
        return network
    
    def target_log_prob(self, x, h, y):
        
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h]
        h_previous = [x] + h_current[:-1]
    
        nlog_prob = 0. # negative log probability
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            logits = layer(pv)
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=cv, logits=logits)
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob
    
    def gibbs_new_state(self, x, h, y):
        
        '''
            generate a new state for the network node by node in Gibbs setting.
        '''
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        
        in_layers = self.fc_layers
        out_layers = self.fc_layers[1:] + [self.output_layer]
        
        prev_vals = [x] + h_current[:-1]
        curr_vals = h_current
        next_vals = h_current[1:] + [y]
        
        for i, (in_layer, out_layer, pv, cv, nv) in enumerate(zip(in_layers, out_layers, prev_vals, curr_vals, next_vals)):

            # node by node
            
            nodes = tf.transpose(cv)
            prob_parents = tm.sigmoid(in_layer(pv))
            
            out_layer_weights = out_layer.get_weights()[0]
            
            next_logits = out_layer(cv)
            
            new_layer = []
            
            for j, node in enumerate(nodes):
                
                # get info for current node (i, j)
                
                prob_parents_j = prob_parents[:, j]
                out_layer_weights_j = out_layer_weights[j]
                
                # calculate logits and logprob for node is 0 or 1
                next_logits_if_node_0 = next_logits[:, :] - node[:, None] * out_layer_weights_j[None, :]
                next_logits_if_node_1 = next_logits[:, :] + (1 - node[:, None]) * out_layer_weights_j[None, :]
                
                #print(next_logits_if_node_0, next_logits_if_node_1)
                
                logprob_children_if_node_0 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(nv, dtype = tf.float32), logits=next_logits_if_node_0), axis = -1)
                
                logprob_children_if_node_1 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(nv, dtype = tf.float32), logits=next_logits_if_node_1), axis = -1)
                
                # calculate prob for node (i, j)
                prob_0 = (1 - prob_parents_j) * tm.exp(logprob_children_if_node_0)
                prob_1 = prob_parents_j * tm.exp(logprob_children_if_node_1)
                prob_j = prob_1 / (prob_1 + prob_0)
            
                # sample new state with prob_j for node (i, j)
                new_node = tfp.distributions.Bernoulli(probs = prob_j).sample() # MAY BE SLOW
                
                # update nodes and logits for following calculation
                new_node_casted = tf.cast(new_node, dtype = "float32")
                next_logits = next_logits_if_node_0 * (1 - new_node_casted)[:, None] \
                            + next_logits_if_node_1 * new_node_casted[:, None] 
                
                # keep track of new node values (in prev/curr/next_vals and h_new)
                new_layer.append(new_node)
           
            new_layer = tf.transpose(new_layer)
            h_current[i] = new_layer
            prev_vals = [x] + h_current[:-1]
            curr_vals = h_current
            next_vals = h_current[1:] + [y]
        
        return h_current
    
    def update_weights(self, x, h, y, lr = 0.1):
        
        optimizer = tf.keras.optimizers.SGD(learning_rate = lr)
        with tf.GradientTape() as tape:
            loss = -1 * tf.reduce_mean(self.target_log_prob(x, h, y))
        
        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
    
    def get_predictions(self, x):

        logits = 0.0
        for layer in self.fc_layers:
            logits = layer(x)
            x = tm.sigmoid(logits)
        
        logits = self.output_layer(x)
        probs = tm.sigmoid(logits)
        #print(probs)
        labels = tf.cast(tm.greater(probs, 0.5), tf.int32)

        return labels

In [23]:
model = StochasticMLP(hidden_layer_sizes = [2], n_outputs=1)
network = [model.call(x) for x, y in train_ds]

In [24]:
# Set weight
w_0 = np.array([[1, -1], [1, -1]], dtype = "float32")
b_0 = np.array([-0.5, 1], dtype = "float32")
l_0 = [w_0, b_0]

w_1 = np.array([[1], [1]], dtype = "float32")
b_1 = np.array([-1], dtype = "float32")
l_1 = [w_1, b_1]

model.fc_layers[0].set_weights(l_0)
model.output_layer.set_weights(l_1)

In [25]:
network = [model.call(images) for images, labels in train_ds]
network

[[<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
  array([[0, 1],
         [1, 1],
         [1, 0],
         [0, 0]], dtype=int32)>]]

In [26]:
# Gibbs Burnin
burnin = 500

for i in range(burnin):
    
    if(i % 100 == 0): print("Step %d" % i)
    network = [model.gibbs_new_state(images, net, labels) for (images, labels), net in zip(train_ds, network)]

Step 0
Step 100
Step 200
Step 300
Step 400


In [27]:
# Training
epochs = 1500
loss_ls = []
acc_ls = []
start_time = time.time()

for epoch in range(epochs):
    
    loss = 0.0
    acc = 0.0
    for bs, (x, y) in enumerate(train_ds):
        
        # only one mini-batch
        model.update_weights(x, network[bs], y, 0.1)
        network = [model.gibbs_new_state(x, net, y) for (x, y), net in zip(train_ds, network)]
        loss += -1 * tf.reduce_mean(model.target_log_prob(x, network[bs], y))
    
    preds = [model.get_predictions(images) for images, labels in train_ds]
    train_acc = accuracy_score(np.concatenate(preds), y_train)
    loss_ls.append(loss)
    acc_ls.append(train_acc)
    
    print("Epoch %d/%d: - %.4fs/step - loss: %.4f - accuracy: %.4f" 
          % (epoch + 1, epochs, (time.time() - start_time) / (epoch + 1), loss, train_acc))

print("Time of HMC: ", time.time() - start_time)

Epoch 1/1500: - 0.0297s/step - loss: 2.2921 - accuracy: 0.5000
Epoch 2/1500: - 0.0292s/step - loss: 2.2524 - accuracy: 0.5000
Epoch 3/1500: - 0.0281s/step - loss: 1.7860 - accuracy: 0.5000
Epoch 4/1500: - 0.0266s/step - loss: 2.0150 - accuracy: 0.5000
Epoch 5/1500: - 0.0252s/step - loss: 2.8227 - accuracy: 0.5000
Epoch 6/1500: - 0.0243s/step - loss: 1.6086 - accuracy: 0.5000
Epoch 7/1500: - 0.0235s/step - loss: 2.3664 - accuracy: 0.5000
Epoch 8/1500: - 0.0229s/step - loss: 1.5967 - accuracy: 0.7500
Epoch 9/1500: - 0.0225s/step - loss: 2.0020 - accuracy: 0.7500
Epoch 10/1500: - 0.0221s/step - loss: 2.1339 - accuracy: 1.0000
Epoch 11/1500: - 0.0221s/step - loss: 2.5579 - accuracy: 1.0000
Epoch 12/1500: - 0.0223s/step - loss: 2.2627 - accuracy: 0.5000
Epoch 13/1500: - 0.0226s/step - loss: 2.2205 - accuracy: 0.5000
Epoch 14/1500: - 0.0230s/step - loss: 1.7753 - accuracy: 0.5000
Epoch 15/1500: - 0.0232s/step - loss: 1.9041 - accuracy: 0.5000
Epoch 16/1500: - 0.0233s/step - loss: 2.4827 - ac

Epoch 132/1500: - 0.0230s/step - loss: 2.0294 - accuracy: 0.5000
Epoch 133/1500: - 0.0230s/step - loss: 1.6797 - accuracy: 0.5000
Epoch 134/1500: - 0.0230s/step - loss: 2.0618 - accuracy: 0.5000
Epoch 135/1500: - 0.0230s/step - loss: 2.2272 - accuracy: 0.5000
Epoch 136/1500: - 0.0229s/step - loss: 1.6405 - accuracy: 0.5000
Epoch 137/1500: - 0.0229s/step - loss: 1.9143 - accuracy: 0.5000
Epoch 138/1500: - 0.0229s/step - loss: 1.6015 - accuracy: 0.5000
Epoch 139/1500: - 0.0229s/step - loss: 2.2152 - accuracy: 0.5000
Epoch 140/1500: - 0.0228s/step - loss: 1.8016 - accuracy: 0.5000
Epoch 141/1500: - 0.0228s/step - loss: 1.8033 - accuracy: 0.5000
Epoch 142/1500: - 0.0228s/step - loss: 1.5752 - accuracy: 0.5000
Epoch 143/1500: - 0.0228s/step - loss: 2.2257 - accuracy: 0.5000
Epoch 144/1500: - 0.0229s/step - loss: 1.4603 - accuracy: 0.5000
Epoch 145/1500: - 0.0229s/step - loss: 2.2281 - accuracy: 0.5000
Epoch 146/1500: - 0.0229s/step - loss: 1.7156 - accuracy: 0.5000
Epoch 147/1500: - 0.0229s

Epoch 261/1500: - 0.0223s/step - loss: 1.8069 - accuracy: 0.5000
Epoch 262/1500: - 0.0223s/step - loss: 2.3474 - accuracy: 0.5000
Epoch 263/1500: - 0.0223s/step - loss: 1.7276 - accuracy: 0.5000
Epoch 264/1500: - 0.0223s/step - loss: 1.7206 - accuracy: 0.5000
Epoch 265/1500: - 0.0223s/step - loss: 1.7087 - accuracy: 0.5000
Epoch 266/1500: - 0.0222s/step - loss: 1.9821 - accuracy: 0.5000
Epoch 267/1500: - 0.0222s/step - loss: 1.6382 - accuracy: 0.5000
Epoch 268/1500: - 0.0222s/step - loss: 1.5744 - accuracy: 0.5000
Epoch 269/1500: - 0.0222s/step - loss: 1.7559 - accuracy: 0.5000
Epoch 270/1500: - 0.0222s/step - loss: 1.7463 - accuracy: 0.5000
Epoch 271/1500: - 0.0222s/step - loss: 1.7369 - accuracy: 0.5000
Epoch 272/1500: - 0.0222s/step - loss: 2.6837 - accuracy: 0.5000
Epoch 273/1500: - 0.0222s/step - loss: 1.9544 - accuracy: 0.5000
Epoch 274/1500: - 0.0222s/step - loss: 1.4598 - accuracy: 0.5000
Epoch 275/1500: - 0.0222s/step - loss: 2.5067 - accuracy: 0.5000
Epoch 276/1500: - 0.0222s

Epoch 393/1500: - 0.0224s/step - loss: 1.8250 - accuracy: 0.5000
Epoch 394/1500: - 0.0224s/step - loss: 1.7043 - accuracy: 0.5000
Epoch 395/1500: - 0.0224s/step - loss: 1.7167 - accuracy: 0.5000
Epoch 396/1500: - 0.0224s/step - loss: 1.7824 - accuracy: 0.5000
Epoch 397/1500: - 0.0224s/step - loss: 2.1751 - accuracy: 0.2500
Epoch 398/1500: - 0.0224s/step - loss: 2.0007 - accuracy: 0.2500
Epoch 399/1500: - 0.0224s/step - loss: 1.5035 - accuracy: 0.5000
Epoch 400/1500: - 0.0224s/step - loss: 1.8299 - accuracy: 0.2500
Epoch 401/1500: - 0.0224s/step - loss: 1.7277 - accuracy: 0.2500
Epoch 402/1500: - 0.0224s/step - loss: 1.6986 - accuracy: 0.2500
Epoch 403/1500: - 0.0224s/step - loss: 1.6849 - accuracy: 0.2500
Epoch 404/1500: - 0.0224s/step - loss: 1.7951 - accuracy: 0.7500
Epoch 405/1500: - 0.0225s/step - loss: 1.6850 - accuracy: 0.2500
Epoch 406/1500: - 0.0225s/step - loss: 1.5864 - accuracy: 0.7500
Epoch 407/1500: - 0.0225s/step - loss: 2.4266 - accuracy: 0.7500
Epoch 408/1500: - 0.0225s

Epoch 529/1500: - 0.0227s/step - loss: 1.9719 - accuracy: 0.7500
Epoch 530/1500: - 0.0227s/step - loss: 1.4797 - accuracy: 0.7500
Epoch 531/1500: - 0.0227s/step - loss: 1.8373 - accuracy: 0.5000
Epoch 532/1500: - 0.0227s/step - loss: 2.4977 - accuracy: 0.7500
Epoch 533/1500: - 0.0227s/step - loss: 2.4093 - accuracy: 0.7500
Epoch 534/1500: - 0.0227s/step - loss: 1.9528 - accuracy: 0.7500
Epoch 535/1500: - 0.0227s/step - loss: 1.7023 - accuracy: 0.7500
Epoch 536/1500: - 0.0227s/step - loss: 1.8394 - accuracy: 0.7500
Epoch 537/1500: - 0.0227s/step - loss: 1.6669 - accuracy: 0.7500
Epoch 538/1500: - 0.0227s/step - loss: 2.2809 - accuracy: 0.7500
Epoch 539/1500: - 0.0227s/step - loss: 1.8309 - accuracy: 0.7500
Epoch 540/1500: - 0.0227s/step - loss: 1.8172 - accuracy: 0.7500
Epoch 541/1500: - 0.0227s/step - loss: 2.4038 - accuracy: 0.7500
Epoch 542/1500: - 0.0227s/step - loss: 1.6748 - accuracy: 0.7500
Epoch 543/1500: - 0.0227s/step - loss: 2.0576 - accuracy: 0.7500
Epoch 544/1500: - 0.0227s

Epoch 660/1500: - 0.0221s/step - loss: 1.9408 - accuracy: 0.5000
Epoch 661/1500: - 0.0221s/step - loss: 1.4588 - accuracy: 0.5000
Epoch 662/1500: - 0.0221s/step - loss: 1.9840 - accuracy: 0.5000
Epoch 663/1500: - 0.0221s/step - loss: 2.0955 - accuracy: 0.5000
Epoch 664/1500: - 0.0221s/step - loss: 1.4343 - accuracy: 0.2500
Epoch 665/1500: - 0.0221s/step - loss: 1.8342 - accuracy: 0.2500
Epoch 666/1500: - 0.0221s/step - loss: 1.5727 - accuracy: 0.2500
Epoch 667/1500: - 0.0221s/step - loss: 1.6756 - accuracy: 0.2500
Epoch 668/1500: - 0.0221s/step - loss: 1.4535 - accuracy: 0.2500
Epoch 669/1500: - 0.0221s/step - loss: 1.8286 - accuracy: 0.2500
Epoch 670/1500: - 0.0221s/step - loss: 1.8394 - accuracy: 0.2500
Epoch 671/1500: - 0.0221s/step - loss: 1.7125 - accuracy: 0.2500
Epoch 672/1500: - 0.0221s/step - loss: 2.0106 - accuracy: 0.2500
Epoch 673/1500: - 0.0221s/step - loss: 1.6615 - accuracy: 0.2500
Epoch 674/1500: - 0.0221s/step - loss: 1.6968 - accuracy: 0.5000
Epoch 675/1500: - 0.0221s

Epoch 792/1500: - 0.0217s/step - loss: 1.4801 - accuracy: 0.5000
Epoch 793/1500: - 0.0217s/step - loss: 1.5677 - accuracy: 0.5000
Epoch 794/1500: - 0.0217s/step - loss: 1.7013 - accuracy: 0.5000
Epoch 795/1500: - 0.0217s/step - loss: 1.7153 - accuracy: 0.5000
Epoch 796/1500: - 0.0217s/step - loss: 2.7757 - accuracy: 0.7500
Epoch 797/1500: - 0.0217s/step - loss: 1.6154 - accuracy: 0.7500
Epoch 798/1500: - 0.0217s/step - loss: 1.9249 - accuracy: 0.5000
Epoch 799/1500: - 0.0217s/step - loss: 1.4780 - accuracy: 0.5000
Epoch 800/1500: - 0.0217s/step - loss: 1.9867 - accuracy: 0.7500
Epoch 801/1500: - 0.0217s/step - loss: 1.6885 - accuracy: 0.5000
Epoch 802/1500: - 0.0217s/step - loss: 1.5857 - accuracy: 0.5000
Epoch 803/1500: - 0.0217s/step - loss: 1.5642 - accuracy: 0.5000
Epoch 804/1500: - 0.0217s/step - loss: 1.7879 - accuracy: 0.5000
Epoch 805/1500: - 0.0217s/step - loss: 1.9614 - accuracy: 0.7500
Epoch 806/1500: - 0.0217s/step - loss: 2.0114 - accuracy: 0.5000
Epoch 807/1500: - 0.0217s

Epoch 928/1500: - 0.0218s/step - loss: 1.3856 - accuracy: 0.5000
Epoch 929/1500: - 0.0218s/step - loss: 1.2689 - accuracy: 0.5000
Epoch 930/1500: - 0.0218s/step - loss: 1.2576 - accuracy: 0.5000
Epoch 931/1500: - 0.0218s/step - loss: 1.8397 - accuracy: 0.5000
Epoch 932/1500: - 0.0218s/step - loss: 1.4837 - accuracy: 0.7500
Epoch 933/1500: - 0.0218s/step - loss: 1.3552 - accuracy: 0.2500
Epoch 934/1500: - 0.0218s/step - loss: 1.3645 - accuracy: 0.2500
Epoch 935/1500: - 0.0218s/step - loss: 1.6070 - accuracy: 0.2500
Epoch 936/1500: - 0.0218s/step - loss: 1.5862 - accuracy: 0.5000
Epoch 937/1500: - 0.0218s/step - loss: 1.3412 - accuracy: 0.5000
Epoch 938/1500: - 0.0218s/step - loss: 1.3470 - accuracy: 0.5000
Epoch 939/1500: - 0.0218s/step - loss: 1.9172 - accuracy: 0.5000
Epoch 940/1500: - 0.0218s/step - loss: 1.6612 - accuracy: 0.5000
Epoch 941/1500: - 0.0218s/step - loss: 1.7417 - accuracy: 0.5000
Epoch 942/1500: - 0.0218s/step - loss: 1.8571 - accuracy: 0.5000
Epoch 943/1500: - 0.0218s

Epoch 1063/1500: - 0.0217s/step - loss: 1.3389 - accuracy: 0.5000
Epoch 1064/1500: - 0.0217s/step - loss: 1.6748 - accuracy: 0.5000
Epoch 1065/1500: - 0.0217s/step - loss: 2.1326 - accuracy: 0.5000
Epoch 1066/1500: - 0.0217s/step - loss: 1.3490 - accuracy: 0.5000
Epoch 1067/1500: - 0.0217s/step - loss: 2.0335 - accuracy: 0.5000
Epoch 1068/1500: - 0.0217s/step - loss: 1.5064 - accuracy: 0.5000
Epoch 1069/1500: - 0.0217s/step - loss: 1.6801 - accuracy: 0.5000
Epoch 1070/1500: - 0.0217s/step - loss: 2.6174 - accuracy: 0.5000
Epoch 1071/1500: - 0.0217s/step - loss: 2.3560 - accuracy: 0.2500
Epoch 1072/1500: - 0.0217s/step - loss: 2.2617 - accuracy: 0.5000
Epoch 1073/1500: - 0.0217s/step - loss: 1.9684 - accuracy: 0.5000
Epoch 1074/1500: - 0.0217s/step - loss: 2.4236 - accuracy: 0.5000
Epoch 1075/1500: - 0.0217s/step - loss: 1.8206 - accuracy: 0.2500
Epoch 1076/1500: - 0.0217s/step - loss: 1.3546 - accuracy: 0.5000
Epoch 1077/1500: - 0.0217s/step - loss: 1.3453 - accuracy: 0.5000
Epoch 1078

Epoch 1197/1500: - 0.0218s/step - loss: 1.5275 - accuracy: 0.5000
Epoch 1198/1500: - 0.0218s/step - loss: 1.6822 - accuracy: 0.5000
Epoch 1199/1500: - 0.0218s/step - loss: 1.7923 - accuracy: 0.5000
Epoch 1200/1500: - 0.0218s/step - loss: 1.8246 - accuracy: 0.5000
Epoch 1201/1500: - 0.0218s/step - loss: 2.0871 - accuracy: 0.5000
Epoch 1202/1500: - 0.0218s/step - loss: 1.6940 - accuracy: 0.5000
Epoch 1203/1500: - 0.0218s/step - loss: 1.5747 - accuracy: 0.5000
Epoch 1204/1500: - 0.0218s/step - loss: 1.6478 - accuracy: 0.5000
Epoch 1205/1500: - 0.0218s/step - loss: 1.8734 - accuracy: 0.5000
Epoch 1206/1500: - 0.0218s/step - loss: 1.7196 - accuracy: 0.5000
Epoch 1207/1500: - 0.0218s/step - loss: 1.9239 - accuracy: 0.5000
Epoch 1208/1500: - 0.0218s/step - loss: 1.6850 - accuracy: 0.5000
Epoch 1209/1500: - 0.0218s/step - loss: 1.7202 - accuracy: 0.5000
Epoch 1210/1500: - 0.0218s/step - loss: 2.1867 - accuracy: 0.5000
Epoch 1211/1500: - 0.0217s/step - loss: 1.4192 - accuracy: 0.5000
Epoch 1212

Epoch 1331/1500: - 0.0218s/step - loss: 1.4404 - accuracy: 0.5000
Epoch 1332/1500: - 0.0218s/step - loss: 2.5043 - accuracy: 0.5000
Epoch 1333/1500: - 0.0218s/step - loss: 1.9245 - accuracy: 0.5000
Epoch 1334/1500: - 0.0218s/step - loss: 1.6420 - accuracy: 0.5000
Epoch 1335/1500: - 0.0218s/step - loss: 1.6997 - accuracy: 0.7500
Epoch 1336/1500: - 0.0218s/step - loss: 1.9077 - accuracy: 0.7500
Epoch 1337/1500: - 0.0218s/step - loss: 1.7111 - accuracy: 0.5000
Epoch 1338/1500: - 0.0218s/step - loss: 1.8670 - accuracy: 0.5000
Epoch 1339/1500: - 0.0218s/step - loss: 2.1733 - accuracy: 0.5000
Epoch 1340/1500: - 0.0218s/step - loss: 1.9273 - accuracy: 0.5000
Epoch 1341/1500: - 0.0218s/step - loss: 2.2882 - accuracy: 0.5000
Epoch 1342/1500: - 0.0218s/step - loss: 1.9111 - accuracy: 0.5000
Epoch 1343/1500: - 0.0218s/step - loss: 1.7011 - accuracy: 0.5000
Epoch 1344/1500: - 0.0218s/step - loss: 2.5576 - accuracy: 0.5000
Epoch 1345/1500: - 0.0218s/step - loss: 1.9000 - accuracy: 0.5000
Epoch 1346

Epoch 1457/1500: - 0.0217s/step - loss: 1.3707 - accuracy: 0.5000
Epoch 1458/1500: - 0.0217s/step - loss: 1.4919 - accuracy: 0.5000
Epoch 1459/1500: - 0.0217s/step - loss: 1.7107 - accuracy: 0.5000
Epoch 1460/1500: - 0.0217s/step - loss: 2.2446 - accuracy: 0.5000
Epoch 1461/1500: - 0.0217s/step - loss: 1.5835 - accuracy: 0.5000
Epoch 1462/1500: - 0.0217s/step - loss: 2.0425 - accuracy: 0.2500
Epoch 1463/1500: - 0.0217s/step - loss: 1.5743 - accuracy: 0.2500
Epoch 1464/1500: - 0.0217s/step - loss: 1.7864 - accuracy: 0.2500
Epoch 1465/1500: - 0.0217s/step - loss: 2.0261 - accuracy: 0.2500
Epoch 1466/1500: - 0.0217s/step - loss: 1.8952 - accuracy: 0.5000
Epoch 1467/1500: - 0.0217s/step - loss: 1.6940 - accuracy: 0.5000
Epoch 1468/1500: - 0.0217s/step - loss: 2.7373 - accuracy: 0.2500
Epoch 1469/1500: - 0.0217s/step - loss: 1.4118 - accuracy: 0.2500
Epoch 1470/1500: - 0.0217s/step - loss: 1.6963 - accuracy: 0.2500
Epoch 1471/1500: - 0.0217s/step - loss: 1.3726 - accuracy: 0.2500
Epoch 1472

In [28]:
model2 = StochasticMLP(hidden_layer_sizes = [30], n_outputs=1)
network2 = [model2.call(images) for images, labels in train_ds]

In [29]:
burnin = 500

for i in range(burnin):
    
    if(i % 100 == 0): print("Step %d" % i)
    network2 = [model2.gibbs_new_state(images, net, labels) for (images, labels), net in zip(train_ds, network2)]

Step 0
Step 100
Step 200
Step 300
Step 400


In [30]:
epochs = 1500
loss_ls2 = []
acc_ls2 = []
start_time = time.time()

for epoch in range(epochs):
    
    loss = 0.0
    acc = 0.0
    for bs, (x, y) in enumerate(train_ds):
        
        # only one mini-batch
        model2.update_weights(x, network2[bs], y, 0.1)
        network2 = [model2.gibbs_new_state(x, net, y) for (x, y), net in zip(train_ds, network2)]
        loss += -1 * tf.reduce_mean(model2.target_log_prob(x, network2[bs], y))
    
    preds = [model2.get_predictions(images) for images, labels in train_ds]
    train_acc = accuracy_score(np.concatenate(preds), y_train)
    loss_ls2.append(loss)
    acc_ls2.append(train_acc)
    
    print("Epoch %d/%d: - %.4fs/step - loss: %.4f - accuracy: %.4f" 
          % (epoch + 1, epochs, (time.time() - start_time) / (epoch + 1), loss, train_acc))

print("Time of HMC: ", time.time() - start_time)

Epoch 1/1500: - 0.1170s/step - loss: 21.8262 - accuracy: 0.5000
Epoch 2/1500: - 0.1069s/step - loss: 21.7924 - accuracy: 0.5000
Epoch 3/1500: - 0.1027s/step - loss: 21.3837 - accuracy: 0.2500
Epoch 4/1500: - 0.1019s/step - loss: 21.7313 - accuracy: 0.2500
Epoch 5/1500: - 0.0993s/step - loss: 21.2275 - accuracy: 0.7500
Epoch 6/1500: - 0.0971s/step - loss: 21.6323 - accuracy: 0.5000
Epoch 7/1500: - 0.0972s/step - loss: 21.6699 - accuracy: 0.5000
Epoch 8/1500: - 0.0971s/step - loss: 21.9053 - accuracy: 0.5000
Epoch 9/1500: - 0.0965s/step - loss: 21.5302 - accuracy: 0.5000
Epoch 10/1500: - 0.0962s/step - loss: 21.4342 - accuracy: 0.5000
Epoch 11/1500: - 0.0961s/step - loss: 21.8704 - accuracy: 0.5000
Epoch 12/1500: - 0.0967s/step - loss: 21.4405 - accuracy: 0.5000
Epoch 13/1500: - 0.0981s/step - loss: 21.2260 - accuracy: 0.5000
Epoch 14/1500: - 0.0976s/step - loss: 21.4604 - accuracy: 0.5000
Epoch 15/1500: - 0.0968s/step - loss: 21.2719 - accuracy: 0.2500
Epoch 16/1500: - 0.0970s/step - lo

Epoch 128/1500: - 0.0956s/step - loss: 20.2570 - accuracy: 0.5000
Epoch 129/1500: - 0.0956s/step - loss: 20.3392 - accuracy: 0.5000
Epoch 130/1500: - 0.0956s/step - loss: 20.5031 - accuracy: 0.5000
Epoch 131/1500: - 0.0956s/step - loss: 20.9058 - accuracy: 0.5000
Epoch 132/1500: - 0.0956s/step - loss: 22.2201 - accuracy: 0.2500
Epoch 133/1500: - 0.0955s/step - loss: 20.8196 - accuracy: 0.5000
Epoch 134/1500: - 0.0955s/step - loss: 20.4197 - accuracy: 0.5000
Epoch 135/1500: - 0.0955s/step - loss: 20.2463 - accuracy: 0.5000
Epoch 136/1500: - 0.0955s/step - loss: 20.1691 - accuracy: 0.5000
Epoch 137/1500: - 0.0955s/step - loss: 21.3789 - accuracy: 0.5000
Epoch 138/1500: - 0.0954s/step - loss: 21.8214 - accuracy: 0.5000
Epoch 139/1500: - 0.0954s/step - loss: 19.8590 - accuracy: 0.5000
Epoch 140/1500: - 0.0953s/step - loss: 20.8529 - accuracy: 0.5000
Epoch 141/1500: - 0.0953s/step - loss: 20.1057 - accuracy: 0.5000
Epoch 142/1500: - 0.0952s/step - loss: 21.0615 - accuracy: 0.2500
Epoch 143/

Epoch 253/1500: - 0.0922s/step - loss: 21.6701 - accuracy: 0.5000
Epoch 254/1500: - 0.0921s/step - loss: 20.2877 - accuracy: 0.5000
Epoch 255/1500: - 0.0921s/step - loss: 20.7583 - accuracy: 0.5000
Epoch 256/1500: - 0.0921s/step - loss: 21.2738 - accuracy: 0.5000
Epoch 257/1500: - 0.0921s/step - loss: 20.7975 - accuracy: 0.7500
Epoch 258/1500: - 0.0921s/step - loss: 20.0929 - accuracy: 0.7500
Epoch 259/1500: - 0.0921s/step - loss: 19.6457 - accuracy: 0.5000
Epoch 260/1500: - 0.0920s/step - loss: 20.6524 - accuracy: 0.5000
Epoch 261/1500: - 0.0920s/step - loss: 19.9399 - accuracy: 0.5000
Epoch 262/1500: - 0.0920s/step - loss: 20.2795 - accuracy: 0.5000
Epoch 263/1500: - 0.0920s/step - loss: 20.1590 - accuracy: 0.5000
Epoch 264/1500: - 0.0920s/step - loss: 21.4912 - accuracy: 0.5000
Epoch 265/1500: - 0.0920s/step - loss: 20.2264 - accuracy: 0.5000
Epoch 266/1500: - 0.0919s/step - loss: 19.5342 - accuracy: 0.5000
Epoch 267/1500: - 0.0919s/step - loss: 21.0409 - accuracy: 0.5000
Epoch 268/

Epoch 379/1500: - 0.0906s/step - loss: 20.5633 - accuracy: 0.5000
Epoch 380/1500: - 0.0906s/step - loss: 19.6685 - accuracy: 0.7500
Epoch 381/1500: - 0.0906s/step - loss: 19.8192 - accuracy: 0.5000
Epoch 382/1500: - 0.0906s/step - loss: 20.6885 - accuracy: 0.2500
Epoch 383/1500: - 0.0906s/step - loss: 20.8116 - accuracy: 0.5000
Epoch 384/1500: - 0.0906s/step - loss: 21.8307 - accuracy: 0.5000
Epoch 385/1500: - 0.0906s/step - loss: 20.6148 - accuracy: 0.5000
Epoch 386/1500: - 0.0906s/step - loss: 21.0798 - accuracy: 0.5000
Epoch 387/1500: - 0.0906s/step - loss: 19.4908 - accuracy: 0.5000
Epoch 388/1500: - 0.0906s/step - loss: 21.3997 - accuracy: 0.7500
Epoch 389/1500: - 0.0906s/step - loss: 19.5065 - accuracy: 0.5000
Epoch 390/1500: - 0.0905s/step - loss: 19.9610 - accuracy: 0.5000
Epoch 391/1500: - 0.0905s/step - loss: 20.4478 - accuracy: 0.5000
Epoch 392/1500: - 0.0905s/step - loss: 21.2405 - accuracy: 0.5000
Epoch 393/1500: - 0.0905s/step - loss: 20.2455 - accuracy: 0.5000
Epoch 394/

Epoch 505/1500: - 0.0899s/step - loss: 20.2481 - accuracy: 0.5000
Epoch 506/1500: - 0.0899s/step - loss: 21.0660 - accuracy: 0.2500
Epoch 507/1500: - 0.0899s/step - loss: 21.2024 - accuracy: 0.2500
Epoch 508/1500: - 0.0899s/step - loss: 21.7332 - accuracy: 0.5000
Epoch 509/1500: - 0.0899s/step - loss: 20.5280 - accuracy: 0.2500
Epoch 510/1500: - 0.0899s/step - loss: 19.2434 - accuracy: 0.5000
Epoch 511/1500: - 0.0899s/step - loss: 19.1360 - accuracy: 0.2500
Epoch 512/1500: - 0.0899s/step - loss: 19.9153 - accuracy: 0.2500
Epoch 513/1500: - 0.0899s/step - loss: 19.4757 - accuracy: 0.5000
Epoch 514/1500: - 0.0899s/step - loss: 21.2004 - accuracy: 0.5000
Epoch 515/1500: - 0.0899s/step - loss: 19.8807 - accuracy: 0.2500
Epoch 516/1500: - 0.0898s/step - loss: 17.5406 - accuracy: 0.5000
Epoch 517/1500: - 0.0898s/step - loss: 20.0835 - accuracy: 0.5000
Epoch 518/1500: - 0.0898s/step - loss: 19.3238 - accuracy: 0.5000
Epoch 519/1500: - 0.0898s/step - loss: 20.7288 - accuracy: 0.5000
Epoch 520/

Epoch 631/1500: - 0.0894s/step - loss: 18.9297 - accuracy: 0.2500
Epoch 632/1500: - 0.0894s/step - loss: 20.7412 - accuracy: 0.2500
Epoch 633/1500: - 0.0894s/step - loss: 19.4489 - accuracy: 0.5000
Epoch 634/1500: - 0.0894s/step - loss: 20.8218 - accuracy: 0.5000
Epoch 635/1500: - 0.0894s/step - loss: 19.2418 - accuracy: 0.5000
Epoch 636/1500: - 0.0894s/step - loss: 20.0003 - accuracy: 0.7500
Epoch 637/1500: - 0.0894s/step - loss: 20.5488 - accuracy: 0.7500
Epoch 638/1500: - 0.0894s/step - loss: 19.6214 - accuracy: 0.7500
Epoch 639/1500: - 0.0894s/step - loss: 19.9956 - accuracy: 0.5000
Epoch 640/1500: - 0.0894s/step - loss: 19.4676 - accuracy: 0.2500
Epoch 641/1500: - 0.0894s/step - loss: 19.9668 - accuracy: 0.7500
Epoch 642/1500: - 0.0894s/step - loss: 18.9069 - accuracy: 0.5000
Epoch 643/1500: - 0.0894s/step - loss: 20.0962 - accuracy: 0.7500
Epoch 644/1500: - 0.0894s/step - loss: 19.9233 - accuracy: 0.2500
Epoch 645/1500: - 0.0894s/step - loss: 19.5541 - accuracy: 0.2500
Epoch 646/

Epoch 756/1500: - 0.0897s/step - loss: 19.0479 - accuracy: 0.5000
Epoch 757/1500: - 0.0898s/step - loss: 20.1569 - accuracy: 0.5000
Epoch 758/1500: - 0.0897s/step - loss: 20.4008 - accuracy: 0.7500
Epoch 759/1500: - 0.0898s/step - loss: 18.6221 - accuracy: 0.7500
Epoch 760/1500: - 0.0898s/step - loss: 17.7174 - accuracy: 0.7500
Epoch 761/1500: - 0.0898s/step - loss: 18.9884 - accuracy: 0.7500
Epoch 762/1500: - 0.0898s/step - loss: 19.4922 - accuracy: 0.7500
Epoch 763/1500: - 0.0898s/step - loss: 20.5752 - accuracy: 0.7500
Epoch 764/1500: - 0.0898s/step - loss: 20.4146 - accuracy: 0.5000
Epoch 765/1500: - 0.0898s/step - loss: 18.5193 - accuracy: 0.2500
Epoch 766/1500: - 0.0898s/step - loss: 19.3831 - accuracy: 0.2500
Epoch 767/1500: - 0.0898s/step - loss: 19.0064 - accuracy: 0.5000
Epoch 768/1500: - 0.0898s/step - loss: 20.1111 - accuracy: 0.5000
Epoch 769/1500: - 0.0898s/step - loss: 19.7174 - accuracy: 0.5000
Epoch 770/1500: - 0.0898s/step - loss: 19.3410 - accuracy: 0.5000
Epoch 771/

Epoch 881/1500: - 0.0903s/step - loss: 19.8852 - accuracy: 0.5000
Epoch 882/1500: - 0.0903s/step - loss: 18.5356 - accuracy: 0.5000
Epoch 883/1500: - 0.0903s/step - loss: 17.3227 - accuracy: 0.5000
Epoch 884/1500: - 0.0903s/step - loss: 19.5014 - accuracy: 0.5000
Epoch 885/1500: - 0.0903s/step - loss: 19.0966 - accuracy: 0.5000
Epoch 886/1500: - 0.0903s/step - loss: 17.8608 - accuracy: 0.5000
Epoch 887/1500: - 0.0903s/step - loss: 19.5069 - accuracy: 0.5000
Epoch 888/1500: - 0.0903s/step - loss: 17.2041 - accuracy: 0.5000
Epoch 889/1500: - 0.0903s/step - loss: 18.3853 - accuracy: 0.5000
Epoch 890/1500: - 0.0903s/step - loss: 17.4352 - accuracy: 0.5000
Epoch 891/1500: - 0.0903s/step - loss: 20.3126 - accuracy: 0.7500
Epoch 892/1500: - 0.0903s/step - loss: 17.9865 - accuracy: 0.2500
Epoch 893/1500: - 0.0903s/step - loss: 17.0079 - accuracy: 0.2500
Epoch 894/1500: - 0.0903s/step - loss: 19.1710 - accuracy: 0.5000
Epoch 895/1500: - 0.0904s/step - loss: 17.3048 - accuracy: 0.7500
Epoch 896/

Epoch 1007/1500: - 0.0900s/step - loss: 17.4129 - accuracy: 0.5000
Epoch 1008/1500: - 0.0900s/step - loss: 19.4418 - accuracy: 0.5000
Epoch 1009/1500: - 0.0900s/step - loss: 17.2780 - accuracy: 0.7500
Epoch 1010/1500: - 0.0900s/step - loss: 19.1941 - accuracy: 0.7500
Epoch 1011/1500: - 0.0900s/step - loss: 18.6282 - accuracy: 0.7500
Epoch 1012/1500: - 0.0900s/step - loss: 18.0544 - accuracy: 0.7500
Epoch 1013/1500: - 0.0900s/step - loss: 19.1910 - accuracy: 0.5000
Epoch 1014/1500: - 0.0900s/step - loss: 19.6901 - accuracy: 0.5000
Epoch 1015/1500: - 0.0900s/step - loss: 18.5222 - accuracy: 0.5000
Epoch 1016/1500: - 0.0900s/step - loss: 20.0553 - accuracy: 0.5000
Epoch 1017/1500: - 0.0900s/step - loss: 19.4049 - accuracy: 0.5000
Epoch 1018/1500: - 0.0900s/step - loss: 17.1115 - accuracy: 0.2500
Epoch 1019/1500: - 0.0900s/step - loss: 17.9292 - accuracy: 0.2500
Epoch 1020/1500: - 0.0900s/step - loss: 18.3972 - accuracy: 0.5000
Epoch 1021/1500: - 0.0900s/step - loss: 17.2824 - accuracy: 0.

Epoch 1130/1500: - 0.0903s/step - loss: 18.0307 - accuracy: 0.5000
Epoch 1131/1500: - 0.0903s/step - loss: 19.4406 - accuracy: 0.5000
Epoch 1132/1500: - 0.0903s/step - loss: 17.7497 - accuracy: 0.5000
Epoch 1133/1500: - 0.0903s/step - loss: 22.4866 - accuracy: 0.7500
Epoch 1134/1500: - 0.0903s/step - loss: 19.9421 - accuracy: 0.7500
Epoch 1135/1500: - 0.0903s/step - loss: 19.2033 - accuracy: 0.5000
Epoch 1136/1500: - 0.0903s/step - loss: 21.5713 - accuracy: 0.2500
Epoch 1137/1500: - 0.0903s/step - loss: 18.7638 - accuracy: 0.5000
Epoch 1138/1500: - 0.0903s/step - loss: 19.2953 - accuracy: 0.7500
Epoch 1139/1500: - 0.0903s/step - loss: 19.6477 - accuracy: 0.7500
Epoch 1140/1500: - 0.0903s/step - loss: 18.3522 - accuracy: 0.7500
Epoch 1141/1500: - 0.0903s/step - loss: 18.2316 - accuracy: 0.7500
Epoch 1142/1500: - 0.0903s/step - loss: 18.1672 - accuracy: 0.7500
Epoch 1143/1500: - 0.0903s/step - loss: 19.8666 - accuracy: 0.7500
Epoch 1144/1500: - 0.0903s/step - loss: 18.3434 - accuracy: 0.

Epoch 1253/1500: - 0.0901s/step - loss: 16.4697 - accuracy: 0.5000
Epoch 1254/1500: - 0.0901s/step - loss: 17.0758 - accuracy: 0.5000
Epoch 1255/1500: - 0.0901s/step - loss: 19.8685 - accuracy: 0.5000
Epoch 1256/1500: - 0.0901s/step - loss: 17.7554 - accuracy: 0.5000
Epoch 1257/1500: - 0.0901s/step - loss: 18.0798 - accuracy: 0.5000
Epoch 1258/1500: - 0.0901s/step - loss: 18.9291 - accuracy: 0.5000
Epoch 1259/1500: - 0.0901s/step - loss: 19.4850 - accuracy: 0.5000
Epoch 1260/1500: - 0.0901s/step - loss: 19.4951 - accuracy: 0.5000
Epoch 1261/1500: - 0.0901s/step - loss: 18.3813 - accuracy: 0.5000
Epoch 1262/1500: - 0.0901s/step - loss: 18.6757 - accuracy: 0.5000
Epoch 1263/1500: - 0.0901s/step - loss: 19.8244 - accuracy: 0.5000
Epoch 1264/1500: - 0.0901s/step - loss: 17.9950 - accuracy: 0.5000
Epoch 1265/1500: - 0.0901s/step - loss: 18.0181 - accuracy: 0.5000
Epoch 1266/1500: - 0.0901s/step - loss: 20.0885 - accuracy: 0.5000
Epoch 1267/1500: - 0.0901s/step - loss: 19.2845 - accuracy: 0.

Epoch 1376/1500: - 0.0902s/step - loss: 18.2034 - accuracy: 0.5000
Epoch 1377/1500: - 0.0902s/step - loss: 18.8822 - accuracy: 0.5000
Epoch 1378/1500: - 0.0902s/step - loss: 19.3849 - accuracy: 0.5000
Epoch 1379/1500: - 0.0902s/step - loss: 17.7625 - accuracy: 0.5000
Epoch 1380/1500: - 0.0902s/step - loss: 18.7809 - accuracy: 0.5000
Epoch 1381/1500: - 0.0902s/step - loss: 18.7122 - accuracy: 0.5000
Epoch 1382/1500: - 0.0902s/step - loss: 19.2552 - accuracy: 0.5000
Epoch 1383/1500: - 0.0902s/step - loss: 18.4986 - accuracy: 0.5000
Epoch 1384/1500: - 0.0902s/step - loss: 17.7517 - accuracy: 0.5000
Epoch 1385/1500: - 0.0902s/step - loss: 19.0605 - accuracy: 0.5000
Epoch 1386/1500: - 0.0902s/step - loss: 17.7344 - accuracy: 0.5000
Epoch 1387/1500: - 0.0902s/step - loss: 19.2295 - accuracy: 0.5000
Epoch 1388/1500: - 0.0902s/step - loss: 19.0198 - accuracy: 0.5000
Epoch 1389/1500: - 0.0902s/step - loss: 18.5919 - accuracy: 0.5000
Epoch 1390/1500: - 0.0902s/step - loss: 18.2850 - accuracy: 0.

Epoch 1499/1500: - 0.0900s/step - loss: 15.7985 - accuracy: 0.5000
Epoch 1500/1500: - 0.0900s/step - loss: 18.1045 - accuracy: 0.5000
Time of HMC:  135.0693519115448
