In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.math as tm
import numpy as np
import random
import time
import matplotlib.pyplot as plt

from tensorflow.keras import Model
from tensorflow.keras.layers import Dense
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

In [2]:
np.random.seed(1234)
X, Y = make_moons(100, noise = 0.3)

# Split into test and training data
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=73)
y_train = y_train.reshape(-1, 1)
y_test = y_test.reshape(-1, 1)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

2021-10-13 13:55:57.226609: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)
        
    def call(self, x):
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)

        final_logits = self.output_layer(x) # initial the weight of output layer
            
        return network
    
    def target_log_prob(self, x, h, y):
        
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h]
        h_previous = [x] + h_current[:-1]
    
        nlog_prob = 0. # negative log probability
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            logits = layer(pv)
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=cv, logits=logits)
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob
    
    def gibbs_new_state(self, x, h, y):
        
        '''
            generate a new state for the network node by node in Gibbs setting.
        '''
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        
        in_layers = self.fc_layers
        out_layers = self.fc_layers[1:] + [self.output_layer]
        
        prev_vals = [x] + h_current[:-1]
        curr_vals = h_current
        next_vals = h_current[1:] + [y]
        
        for i, (in_layer, out_layer, pv, cv, nv) in enumerate(zip(in_layers, out_layers, prev_vals, curr_vals, next_vals)):

            # node by node
            
            nodes = tf.transpose(cv)
            prob_parents = tm.sigmoid(in_layer(pv))
            
            out_layer_weights = out_layer.get_weights()[0]
            
            next_logits = out_layer(cv)
            
            new_layer = []
            
            for j, node in enumerate(nodes):
                
                # get info for current node (i, j)
                
                prob_parents_j = prob_parents[:, j]
                out_layer_weights_j = out_layer_weights[j]
                
                # calculate logits and logprob for node is 0 or 1
                next_logits_if_node_0 = next_logits[:, :] - node[:, None] * out_layer_weights_j[None, :]
                next_logits_if_node_1 = next_logits[:, :] + (1 - node[:, None]) * out_layer_weights_j[None, :]
                
                #print(next_logits_if_node_0, next_logits_if_node_1)
                
                logprob_children_if_node_0 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(nv, dtype = tf.float32), logits=next_logits_if_node_0), axis = -1)
                
                logprob_children_if_node_1 = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(nv, dtype = tf.float32), logits=next_logits_if_node_1), axis = -1)
                
                # calculate prob for node (i, j)
                prob_0 = (1 - prob_parents_j) * tm.exp(logprob_children_if_node_0)
                prob_1 = prob_parents_j * tm.exp(logprob_children_if_node_1)
                prob_j = prob_1 / (prob_1 + prob_0)
            
                # sample new state with prob_j for node (i, j)
                new_node = tfp.distributions.Bernoulli(probs = prob_j).sample() # MAY BE SLOW
                
                # update nodes and logits for following calculation
                new_node_casted = tf.cast(new_node, dtype = "float32")
                next_logits = next_logits_if_node_0 * (1 - new_node_casted)[:, None] \
                            + next_logits_if_node_1 * new_node_casted[:, None] 
                
                # keep track of new node values (in prev/curr/next_vals and h_new)
                new_layer.append(new_node)
           
            new_layer = tf.transpose(new_layer)
            h_current[i] = new_layer
            prev_vals = [x] + h_current[:-1]
            curr_vals = h_current
            next_vals = h_current[1:] + [y]
        
        return h_current
    
    def update_weights(self, x, h, y, lr = 0.1):
        
        optimizer = tf.keras.optimizers.Adam(learning_rate = lr)
        with tf.GradientTape() as tape:
            loss = -1 * tf.reduce_mean(self.target_log_prob(x, h, y))
        
        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
    
    def get_predictions(self, x):

        logits = 0.0
        for layer in self.fc_layers:
            logits = layer(x)
            x = tm.sigmoid(logits)
        
        logits = self.output_layer(x)
        probs = tm.sigmoid(logits)
        #print(probs)
        labels = tf.cast(tm.greater(probs, 0.5), tf.int32)

        return labels
    
    

In [4]:
model = StochasticMLP(hidden_layer_sizes = [32], n_outputs=1)
network = [model.call(x) for x, y in train_ds]

In [5]:
# Gibbs Burnin
burnin = 50

for i in range(burnin):
    
    if(i % 10 == 0): print("Step %d" % i)
    network = [model.gibbs_new_state(images, net, labels) for (images, labels), net in zip(train_ds, network)]

Step 0
Step 10
Step 20
Step 30
Step 40


In [6]:
# Training
epochs = 50
loss_ls = []
acc_ls = []
start_time = time.time()

for epoch in range(epochs):
    
    loss = 0.0
    acc = 0.0
    for bs, (x, y) in enumerate(train_ds):
        
        # only one mini-batch
        model.update_weights(x, network[bs], y, 0.1)
        network = [model.gibbs_new_state(x, net, y) for (x, y), net in zip(train_ds, network)]
        loss += -1 * tf.reduce_mean(model.target_log_prob(x, network[bs], y))
    
    preds = [model.get_predictions(images) for images, labels in train_ds]
    train_acc = accuracy_score(np.concatenate(preds), y_train)
    loss_ls.append(loss)
    acc_ls.append(train_acc)
    
    print("Epoch %d/%d: - %.4fs/step - loss: %.4f - accuracy: %.4f" 
          % (epoch + 1, epochs, (time.time() - start_time) / (epoch + 1), loss, train_acc))

print("Time of Gibbs: ", time.time() - start_time)

Epoch 1/50: - 0.8365s/step - loss: 67.2969 - accuracy: 0.5500
Epoch 2/50: - 0.8042s/step - loss: 66.4549 - accuracy: 0.5500
Epoch 3/50: - 0.7865s/step - loss: 66.1563 - accuracy: 0.5500
Epoch 4/50: - 0.7777s/step - loss: 64.2668 - accuracy: 0.5500
Epoch 5/50: - 0.7729s/step - loss: 64.9815 - accuracy: 0.5500
Epoch 6/50: - 0.7720s/step - loss: 64.7212 - accuracy: 0.4750
Epoch 7/50: - 0.7785s/step - loss: 62.7816 - accuracy: 0.7875
Epoch 8/50: - 0.7750s/step - loss: 62.7981 - accuracy: 0.5000
Epoch 9/50: - 0.7721s/step - loss: 61.4250 - accuracy: 0.5500
Epoch 10/50: - 0.7697s/step - loss: 62.3493 - accuracy: 0.7750
Epoch 11/50: - 0.7682s/step - loss: 61.5713 - accuracy: 0.5500
Epoch 12/50: - 0.7671s/step - loss: 61.2463 - accuracy: 0.6875
Epoch 13/50: - 0.7660s/step - loss: 61.5908 - accuracy: 0.8000
Epoch 14/50: - 0.7652s/step - loss: 60.6586 - accuracy: 0.7125
Epoch 15/50: - 0.7645s/step - loss: 61.6575 - accuracy: 0.6375
Epoch 16/50: - 0.7641s/step - loss: 59.5514 - accuracy: 0.8000
E

In [8]:
# print plot and save data
fig, ax = plt.subplots()

ax.plot(list(range(epochs)), acc_ls, label = 'HMC')
ax.legend()
fig.savefig('make_moon_gibbs_acc_1000.png')
plt.close(fig)

In [10]:
with open('make_moon_gibbs_loss_1000.npy', 'wb') as f:
    np.save(f, np.array(acc_ls))
    np.save(f, np.array(loss_ls))