In [1]:
# use Gibbs to test the loss

In [2]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.math as tm
import numpy as np
import time
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import initializers
from tensorflow.keras import Model
from tensorflow.keras import models
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from sklearn.metrics import accuracy_score
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split

In [3]:
def convert2_zero_one(x):
    
    t = [tf.math.sigmoid(i) for i in x]    
    return t

def cont_bern_log_norm(lam, l_lim=0.49, u_lim=0.51):
    '''
    computes the log normalizing constant of a continuous Bernoulli distribution in a numerically stable way.
    returns the log normalizing constant for lam in (0, l_lim) U (u_lim, 1) and a Taylor approximation in
    [l_lim, u_lim].
    cut_y below might appear useless, but it is important to not evaluate log_norm near 0.5 as tf.where evaluates
    both options, regardless of the value of the condition.
    '''
    
    cut_lam = tf.where(tm.logical_or(tm.less(lam, l_lim), tm.greater(lam, u_lim)), lam, l_lim * tf.ones_like(lam))
    log_norm = tm.log(tm.abs(2.0 * tm.atanh(1 - 2.0 * cut_lam))) - tm.log(tm.abs(1 - 2.0 * cut_lam))
    taylor = tm.log(2.0) + 4.0 / 3.0 * tm.pow(lam - 0.5, 2) + 104.0 / 45.0 * tm.pow(lam - 0.5, 4)
    return tf.where(tm.logical_or(tm.less(lam, l_lim), tm.greater(lam, u_lim)), log_norm, taylor)

In [4]:
class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10, lr=1e-3):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)
        self.optimizer = tf.keras.optimizers.Adam(learning_rate = lr)
        
    def call(self, x):
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)

        final_logits = self.output_layer(x) # initial the weight of output layer
            
        return network
    
    def target_log_prob(self, x, h, y, is_gibbs = False, is_hmc = False):
        
        # get current state
        if is_hmc:
            h_current = tf.split(h, self.hidden_layer_sizes, axis = 1)
        else:    
            h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h]
        h_current = convert2_zero_one(h_current)
        h_previous = [x] + h_current[:-1]
    
        nlog_prob = 0. # negative log probability
        
        for i, (cv, pv, layer) in enumerate(zip(h_current, h_previous, self.fc_layers)):
            
            logits = layer(pv)
            ce = tf.nn.sigmoid_cross_entropy_with_logits(labels = cv, logits = logits)
            if not is_gibbs:
                ce += cont_bern_log_norm(tf.nn.sigmoid(logits))
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob
    
    def gibbs_new_state(self, x, h, y):
        
        '''
            generate a new state for the network node by node in Gibbs setting.
        '''
        
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        
        in_layers = self.fc_layers
        out_layers = self.fc_layers[1:] + [self.output_layer]
        
        prev_vals = [x] + h_current[:-1]
        curr_vals = h_current
        next_vals = h_current[1:] + [y]
        
        for i, (in_layer, out_layer, pv, cv, nv) in enumerate(zip(in_layers, out_layers, prev_vals, curr_vals, next_vals)):

            # node by node
            
            nodes = tf.transpose(cv)
            prob_parents = tm.sigmoid(in_layer(pv))
            
            out_layer_weights = out_layer.get_weights()[0]
            
            next_logits = out_layer(cv)
            
            new_layer = []
            
            for j, node in enumerate(nodes):
                
                # get info for current node (i, j)
                
                prob_parents_j = prob_parents[:, j]
                out_layer_weights_j = out_layer_weights[j]
                
                # calculate logits and logprob for node is 0 or 1
                next_logits_if_node_0 = next_logits[:, :] - node[:, None] * out_layer_weights_j[None, :]
                next_logits_if_node_1 = next_logits[:, :] + (1 - node[:, None]) * out_layer_weights_j[None, :]
                
                #print(next_logits_if_node_0, next_logits_if_node_1)
                
                logprob_children_if_node_0 = -1 * tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(nv, dtype = tf.float32), logits=next_logits_if_node_0), axis = -1)
                
                logprob_children_if_node_1 = -1 * tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(
                    labels=tf.cast(nv, dtype = tf.float32), logits=next_logits_if_node_1), axis = -1)
                
                # calculate prob for node (i, j)
                prob_0 = (1 - prob_parents_j) * tm.exp(logprob_children_if_node_0)
                prob_1 = prob_parents_j * tm.exp(logprob_children_if_node_1)
                prob_j = prob_1 / (prob_1 + prob_0)
            
                # sample new state with prob_j for node (i, j)
                new_node = tfp.distributions.Bernoulli(probs = prob_j).sample() # MAY BE SLOW
                
                # update nodes and logits for following calculation
                new_node_casted = tf.cast(new_node, dtype = "float32")
                next_logits = next_logits_if_node_0 * (1 - new_node_casted)[:, None] \
                            + next_logits_if_node_1 * new_node_casted[:, None] 
                
                # keep track of new node values (in prev/curr/next_vals and h_new)
                new_layer.append(new_node)
           
            new_layer = tf.transpose(new_layer)
            h_current[i] = new_layer
            prev_vals = [x] + h_current[:-1]
            curr_vals = h_current
            next_vals = h_current[1:] + [y]
        
        return h_current
    
    def update_weights(self, x, h, y, is_gibbs = False):
        
        with tf.GradientTape() as tape:
            loss = -1 * tf.reduce_mean(self.target_log_prob(x, h, y, is_gibbs = is_gibbs))
        
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    
    def get_predictions(self, x):

        logits = 0.0
        for layer in self.fc_layers:
            logits = layer(x)
            x = tm.sigmoid(logits)
        
        logits = self.output_layer(x)
        probs = tm.sigmoid(logits)
        labels = tf.cast(tm.greater(probs, 0.5), tf.int32)

        return labels
    
    def get_loss(self, x, y):
        
        logits = 0.0
        for layer in self.fc_layers:
            logits = layer(x)
            x = tm.sigmoid(logits)
            
        logits = self.output_layer(x)
        loss = tf.nn.sigmoid_cross_entropy_with_logits(labels = tf.cast(y, tf.float32), logits = logits)
        
        return tf.reduce_sum(loss, axis = -1)

In [5]:
def gibbs(size, dat_train, epochs, burnin = 500):
    '''
    Gibbs Training
    '''
    # Setting
    # Get train labels and val labels
    target_train = np.concatenate([target for data, target in dat_train.as_numpy_iterator()])
    
    print("Start Gibbs")
    model = StochasticMLP(hidden_layer_sizes = [size], n_outputs=1, lr = 0.01)
    network = [model.call(data) for data, target in dat_train]
    
    # Burnin
    print("Start Gibbs Burning")    
    for i in range(burnin):
    
        if(i % 100 == 0): print("Step %d" % i)

        res = []
        burnin_loss = 0.0
        for bs, (data, target) in enumerate(dat_train):
            res.append(model.gibbs_new_state(data, network[bs], target))
 
        network = res
    
    # Training
    train_losses = []
    train_accs = []
    
    start_time = time.time()
    for epoch in range(epochs):
        
        # train
        for bs, (data, target) in enumerate(dat_train):
        
            model.update_weights(data, network[bs], target, is_gibbs = True)
            network = [model.gibbs_new_state(x, net, y) for (x, y), net in zip(dat_train, network)]
            
        train_loss = 0.0
        for data, target in dat_train:
            train_loss += tf.reduce_mean(model.get_loss(data, target))
        train_loss /= (bs + 1)
        train_losses.append(train_loss)       
        
        train_preds = [model.get_predictions(data) for data, target in dat_train]
        train_acc = accuracy_score(np.concatenate(train_preds), target_train)
        train_accs.append(train_acc) 
        
        print("Epoch %d/%d: - %.4fs/step - train_loss: %.4f - train_acc: %.4f " 
            % (epoch + 1, epochs, (time.time() - start_time) / (epoch + 1), train_loss, train_acc))
    
    y_logits = []
    for data, target in dat_train:
        
        logit = 0.0
        x = data
        for layer in model.fc_layers:
            logit = layer(x)
            x = tm.sigmoid(logit)
            
        logit = model.output_layer(x)
        y_logits.append(logit.numpy())
    
    print(y_logits)
    y_train = [target for data, target in dat_train.as_numpy_iterator()]
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    true_loss = bce(y_train, y_logits).numpy()
    
    return {"train_acc": train_accs, "train_loss": train_losses}, true_loss

In [6]:
np.random.seed(1234)
X, Y = make_moons(200, noise = 0.3)

# Split into test and training data
x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size = 0.2, random_state=73)
y_train = y_train.reshape(-1, 1)
y_val = y_val.reshape(-1, 1)

train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32)

2022-02-15 00:12:12.269401: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
size = 32
epochs = 10
burnin = 100

hist, ce = gibbs(size, train_ds, epochs, burnin)
print(hist['train_loss'][-1].numpy())
print(ce)

Start Gibbs
Start Gibbs Burning
Step 0
Epoch 1/10: - 2.0823s/step - train_loss: 0.6949 - train_acc: 0.5188 
Epoch 2/10: - 2.0978s/step - train_loss: 0.6886 - train_acc: 0.5000 
Epoch 3/10: - 2.1034s/step - train_loss: 0.6831 - train_acc: 0.6312 
Epoch 4/10: - 2.1094s/step - train_loss: 0.6784 - train_acc: 0.6875 
Epoch 5/10: - 2.1054s/step - train_loss: 0.6756 - train_acc: 0.5563 
Epoch 6/10: - 2.1046s/step - train_loss: 0.6717 - train_acc: 0.5875 
Epoch 7/10: - 2.1044s/step - train_loss: 0.6682 - train_acc: 0.7312 
Epoch 8/10: - 2.1044s/step - train_loss: 0.6680 - train_acc: 0.7250 
Epoch 9/10: - 2.1028s/step - train_loss: 0.6685 - train_acc: 0.6687 
Epoch 10/10: - 2.1005s/step - train_loss: 0.6686 - train_acc: 0.6000 
[array([[-0.12767382],
       [-0.02509977],
       [-0.16293673],
       [-0.08931442],
       [-0.2111633 ],
       [-0.18161868],
       [-0.13277508],
       [-0.03981177],
       [ 0.02000271],
       [-0.20347504],
       [-0.11888413],
       [-0.13537873],
     