In [13]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.math as tm
import numpy as np
import random
import time
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import initializers
from tensorflow.keras import Model
from tensorflow.keras.layers import Flatten, Dense
from sklearn.metrics import accuracy_score

In [2]:
# Load MNIST
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Select binary data
label_sub = [0,1]
x_train_sub = [x.reshape(-1) for x, y in zip(x_train, y_train) if y in label_sub]
y_train_sub = [y.reshape(-1) for y in y_train if y in label_sub]
x_test_sub = [x.reshape(-1) for x, y in zip(x_test, y_test) if y in label_sub]
y_test_sub = [y.reshape(-1) for y in y_test if y in label_sub]


print('There are', len(x_train_sub), 'training images.')
print('There are', len(x_test_sub), 'test images.')

There are 12665 training images.
There are 2115 test images.


In [15]:
x_train_sub = x_train_sub[:100]
y_train_sub = y_train_sub[:100]

print('There are', len(x_train_sub), 'training images.')

There are 100 training images.


In [18]:
x_train_sub[0].shape

(784,)

In [16]:
train_ds = tf.data.Dataset.from_tensor_slices((x_train_sub, y_train_sub)).shuffle(10000).batch(32)
test_ds = tf.data.Dataset.from_tensor_slices((x_test_sub, y_test_sub)).batch(32)

In [None]:
# Standard BP
model_bp = keras.Sequential(
    [
        keras.Input(shape=(784,)),
        layers.Dense(32, activation = "sigmoid"),
        layers.Dense(1, activation = "sigmoid")
    ]
)

batch_size = 32
epochs = 200
opt = tf.keras.optimizers.SGD(learning_rate=.1)
st = time.time()
model_bp.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy", "AUC"])
history = model_bp.fit(train_ds, batch_size=batch_size, epochs=epochs)
print(time.time() - st)

In [11]:
def convert2_zero_one(x):
    
    t = [tf.math.sigmoid(i) for i in x]    
    return t

def cont_bern_log_norm(lam, l_lim=0.49, u_lim=0.51):
    '''
    computes the log normalizing constant of a continuous Bernoulli distribution in a numerically stable way.
    returns the log normalizing constant for lam in (0, l_lim) U (u_lim, 1) and a Taylor approximation in
    [l_lim, u_lim].
    cut_y below might appear useless, but it is important to not evaluate log_norm near 0.5 as tf.where evaluates
    both options, regardless of the value of the condition.
    '''
    
    cut_lam = tf.where(tm.logical_or(tm.less(lam, l_lim), tm.greater(lam, u_lim)), lam, l_lim * tf.ones_like(lam))
    log_norm = tm.log(tm.abs(2.0 * tm.atanh(1 - 2.0 * cut_lam))) - tm.log(tm.abs(1 - 2.0 * cut_lam))
    taylor = tm.log(2.0) + 4.0 / 3.0 * tm.pow(lam - 0.5, 2) + 104.0 / 45.0 * tm.pow(lam - 0.5, 4)
    return tf.where(tm.logical_or(tm.less(lam, l_lim), tm.greater(lam, u_lim)), log_norm, taylor)

In [70]:
# MLP model
class StochasticMLP(Model):
    
    def __init__(self, hidden_layer_sizes=[100], n_outputs=10):
        super(StochasticMLP, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.fc_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.output_layer = Dense(n_outputs)
    
    def call(self, x):
        
        #x = Flatten()(x)
        
        network = []
        
        for i, layer in enumerate(self.fc_layers):
            
            logits = layer(x)
            x = tfp.distributions.Bernoulli(logits=logits).sample()
            network.append(x)

        final_logits = self.output_layer(x) # initial the weight of output layer
            
        return network
    
    def target_log_prob(self, x, h, y):
        
        h_current = convert2_zero_one([tf.cast(h_i, dtype=tf.float32) for h_i in h])
        h_previous = [x] + h_current[:-1]
    
        nlog_prob = 0. # negative log probability
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            logits = layer(pv)
            print(cont_bern_log_norm(tf.nn.sigmoid(logits)))
            
            ce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=cv, logits=logits)
            
            ce += cont_bern_log_norm(tf.nn.sigmoid(logits))
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob

    def target_log_prob2(self, x, h, y):

        h_current = convert2_zero_one(tf.split(h, self.hidden_layer_sizes, axis = 1))
        h_previous = [x] + h_current[:-1]
        
        nlog_prob = 0.
        
        for i, (cv, pv, layer) in enumerate(
            zip(h_current, h_previous, self.fc_layers)):
            
            logits = layer(pv)
            ce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=cv, logits=logits)
            
            ce += cont_bern_log_norm(tf.nn.sigmoid(logits))
            
            nlog_prob += tf.reduce_sum(ce, axis = -1)
        
        fce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.cast(y, tf.float32), logits=self.output_layer(h_current[-1]))
        nlog_prob += tf.reduce_sum(fce, axis = -1)
            
        return -1 * nlog_prob
    
    def generate_hmc_kernel(self, x, y, step_size = pow(1000, -1/4)):
        
        adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(tfp.mcmc.HamiltonianMonteCarlo(
            target_log_prob_fn = lambda v: self.target_log_prob2(x, v, y),
            num_leapfrog_steps = 2,
            step_size = step_size),
            num_adaptation_steps=int(100 * 0.8))
        
        return adaptive_hmc
    
    # new proposing-state method with HamiltonianMonteCarlo
    def propose_new_state_hamiltonian(self, x, h, y, hmc_kernel, is_update_kernel = True):
    
        h_current = h
        h_current = [tf.cast(h_i, dtype=tf.float32) for h_i in h_current]
        h_current = tf.concat(h_current, axis = 1)

        # run the chain (with burn-in)
        num_burnin_steps = 0
        num_results = 1

        samples = tfp.mcmc.sample_chain(
            num_results = num_results,
            num_burnin_steps = num_burnin_steps,
            current_state = h_current, # may need to be reshaped
            kernel = hmc_kernel,
            trace_fn = None,
            return_final_kernel_results = True)
    
        # Generate new states of chains
        #h_state = rerange(samples[0][0])
        h_state = samples[0][0]
        h_new = tf.split(h_state, self.hidden_layer_sizes, axis = 1) 
        
        # Update the kernel if necesssary
        if is_update_kernel:
            new_step_size = samples[2].new_step_size.numpy()
            ker_new = self.generate_hmc_kernel(x, y, new_step_size)
            return(h_new, ker_new)
        else:
            return h_new
    
    def update_weights(self, x, h, y, lr = 0.1):
        
        optimizer = tf.keras.optimizers.SGD(learning_rate = lr)
        with tf.GradientTape() as tape:
            loss = -1 * tf.reduce_mean(self.target_log_prob(x, h, y))
        
        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))

    def get_predictions(self, x):

        logits = 0.0
        for layer in self.fc_layers:
            logits = layer(x)
            x = tm.sigmoid(logits)
        
        logits = self.output_layer(x)
        probs = tm.sigmoid(logits)
        #print(probs)
        labels = tf.cast(tm.greater(probs, 0.5), tf.int32)

        return labels

In [71]:
model = StochasticMLP(hidden_layer_sizes = [32], n_outputs = 1)

In [72]:
network = [model.call(images) for images, labels in train_ds]

In [73]:
kernels = [model.generate_hmc_kernel(images, labels) for images, labels in train_ds]

In [74]:
for bs, (x, y) in enumerate(train_ds):
    if bs == 0:
        xx = x
        yy = y
        break

In [75]:
model.target_log_prob(xx, network[0], yy)

tf.Tensor(
[[      inf       inf       inf ...       inf       inf       inf]
 [      inf 1.2329103       inf ...       inf       inf       inf]
 [      inf       inf       inf ...       inf       inf       inf]
 ...
 [      inf       inf       inf ...       inf       inf       inf]
 [      inf       inf 1.9193465 ...       inf       inf       inf]
 [      inf       inf       inf ...       inf       inf       inf]], shape=(32, 32), dtype=float32)


<tf.Tensor: shape=(32,), dtype=float32, numpy=
array([-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
       -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf,
       -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
      dtype=float32)>

In [29]:
# Burn-in
burnin = 50

for i in range(burnin):

    if(i % 10 == 0): print("Step %d" % i)
    network_new = []
    kernels_new = []
    
    res = [model.propose_new_state_hamiltonian(x, net, y, ker) 
           for (x, y), net, ker in zip(train_ds, network, kernels)]
    
    network_new, kernels_new = zip(*res)
    network = network_new
    kernels = kernels_new

Step 0
Step 10
Step 20
Step 30
Step 40


In [33]:
for bs, (x, y) in enumerate(train_ds):
    print(model.target_log_prob(x, network[bs], y))

tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor([nan nan nan nan], shape=(4,), dtype=float32)


In [None]:
# plot values of nodes
import matplotlib.pyplot as plt

#units = list(np.random.randint(100, size = 10))

for j in range(200):

    print(j)
    k = j % 10
    if(k == 0):
        fig, ax = plt.subplots(nrows=5, ncols=2, figsize=(14, 20))
    
    ax[k // 2, k % 2].plot(np.arange(20000), [net_values[i][j] for i in range(20000)])
    ax[k // 2, k % 2].set_title('Unit %i' % j)
    ax[k // 2, k % 2].set_ylim([-6, 6])

    if(k == 9):
        plt.tight_layout()
        plt.savefig('plots/100_50_50/merge_20k_%d' % (j//10))
        plt.close()

In [32]:
# Training
epochs = 200
loss_ls = []
acc_ls = []
start_time = time.time()

for epoch in range(epochs):
    
    loss = 0.0
    acc = 0.0
    for bs, (x, y) in enumerate(train_ds):
        
        # only one mini-batch
        model.update_weights(x, network[bs], y, 0.1)
        res = [model.propose_new_state_hamiltonian(x, net, y, ker, is_update_kernel = False) \
                   for (x, y), net, ker in zip(train_ds, network, kernels)]
        network = res
        loss += -1 * tf.reduce_sum(model.target_log_prob(x, network[bs], y))
    
    preds = [model.get_predictions(images) for images, labels in train_ds]
    train_acc = accuracy_score(np.concatenate(preds), y_train_sub)
    loss_ls.append(loss / len(x_train_sub))
    acc_ls.append(train_acc)
    
    print("Epoch %d/%d: - %.4fs/step - loss: %.4f - accuracy: %.4f" 
          % (epoch + 1, epochs, (time.time() - start_time) / (epoch + 1), loss / len(x_train_sub), train_acc))

print("Time of HMC: ", time.time() - start_time)

tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor([nan nan nan nan], shape=(4,), dtype=float32)
Epoch 1/200: - 0.4906s/step - loss: nan - accuracy: 0.4300
tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor(
[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan], shape=(32,), dtype=float32)
tf.Tensor(
[nan nan nan nan na

KeyboardInterrupt: 