In [1]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.math as tm
import numpy as np
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import initializers
from tensorflow.keras import Model
from tensorflow.keras import models
from tensorflow.keras.layers import Dense

In [2]:
def convert2_zero_one(x):
    
    t = [tf.math.sigmoid(i) for i in x]    
    return t

def cont_bern_log_norm(lam, l_lim=0.49, u_lim=0.51):
    '''
    computes the log normalizing constant of a continuous Bernoulli distribution in a numerically stable way.
    returns the log normalizing constant for lam in (0, l_lim) U (u_lim, 1) and a Taylor approximation in
    [l_lim, u_lim].
    cut_y below might appear useless, but it is important to not evaluate log_norm near 0.5 as tf.where evaluates
    both options, regardless of the value of the condition.
    '''
    
    cut_lam = tf.where(tm.logical_or(tm.less(lam, l_lim), tm.greater(lam, u_lim)), lam, l_lim * tf.ones_like(lam))
    log_norm = tm.log(tm.abs(2.0 * tm.atanh(1 - 2.0 * cut_lam))) - tm.log(tm.abs(1 - 2.0 * cut_lam))
    taylor = tm.log(2.0) + 4.0 / 3.0 * tm.pow(lam - 0.5, 2) + 104.0 / 45.0 * tm.pow(lam - 0.5, 4)
    return tf.where(tm.logical_or(tm.less(lam, l_lim), tm.greater(lam, u_lim)), log_norm, taylor)

def bin_bits(x, n_bits):
    
    binx = []
    for i in range(n_bits):
        binx = [x % 2] + binx
        x //= 2
    return binx

def get_ls(x, n_bits):
    
    n = 2**n_bits
    l = n//(2**(x+1))
    false_ls = []
    true_ls = []
    
    for i in range(0, n//l, 2):
        false_ls.extend(list(range(i*l, (i+1)*l)))
        true_ls.extend(list(range((i+1)*l, (i+2)*l)))
    
    return false_ls, true_ls 

In [29]:
class LBN(Model):
    
    def __init__(self, hidden_layer_sizes = [4, 4], n_outputs = 1, learning_rate = 0.01):
        '''
        initialize the loopy belief network and the optimizer.
        '''
        
        super(LBN, self).__init__()
        self.hidden_layer_sizes = hidden_layer_sizes
        self.hidden_layers = [Dense(layer_size) for layer_size in hidden_layer_sizes]
        self.n_outputs = n_outputs
        self.output_layer = Dense(n_outputs)
        self.optimizer = tf.keras.optimizers.SGD(learning_rate = learning_rate)
        
        # initial CPTs for all the nodes (hidden nodes and output node)
        # Each node in first hidden layer needs 2 cells (They have evidence from input nodes.)
        # Nodes in other layers need 2^(#nodes of last layer) cells
        # With N hidden layers, cpts have N+1 layers (include output layer) 
        # p_trues have N (only for hidden layers) 
        
        all_layer_sizes = [0] + hidden_layer_sizes + [n_outputs]
        self.cpts = [np.zeros([layer_size, 2**(last_layer_size + 1)], dtype = "float") for layer_size, 
                         last_layer_size in zip(all_layer_sizes[1:], all_layer_sizes[:-1])]
        
        self.p_trues = [np.zeros([layer_size, 2], dtype = "float") for layer_size 
                            in hidden_layer_sizes]
        
    def call(self, x):
        '''
        initialize the weights of the NN and calculate CPTs
        '''
        
        all_layers = self.hidden_layers + [self.output_layer]
        input_nodes = x
        
        # initialize weights
        for layer in all_layers:
            logits = layer(x)
            x = tm.sigmoid(logits)
        
        # initialize cpts
        # 0 = false, 1 = true
        probs = tm.sigmoid(all_layers[0](input_nodes))[0].numpy()
        self.cpts[0][:, 0] = 1 - probs
        self.cpts[0][:, 1] = probs
        
        for i, (last_layer_size, layer) in enumerate(zip(self.hidden_layer_sizes, all_layers[1:])):
            
            n_cells = 2**last_layer_size
            for j in range(n_cells):
                nodes = np.array([bin_bits(j, last_layer_size)])
                probs = tm.sigmoid(layer(nodes))[0].numpy()
                self.cpts[i + 1][:, j] = 1 - probs
                self.cpts[i + 1][:, j + n_cells] = probs

    def mp(self, y):
        '''
        message passing, including forward passing and backward passing
        '''
                
        # forward passing
        last_layer_sizes = [0] + self.hidden_layer_sizes
        n_layer = len(self.hidden_layers)
        
        for i, last_layer_size in enumerate(last_layer_sizes):
            
            # first hidden layer 
            if i == 0: 
                self.p_trues[i] = self.cpts[i]
            else:
                n_cells = 2**last_layer_size
                for j in range(n_cells):
                    nodes = bin_bits(j, last_layer_size)
                    messages = [self.p_trues[i - 1][k, flag] for k, flag in enumerate(nodes)]
                    tot_message = np.prod(messages)
                    self.cpts[i][:, j] *= tot_message
                    self.cpts[i][:, j + n_cells] *= tot_message
            
                # output layer
                if i == n_layer:
                    # get message from the observation
                    if y == 0: self.cpts[i] = self.cpts[i][:, :n_cells]
                    if y == 1: self.cpts[i] = self.cpts[i][:, n_cells:]        
                else:
                    p_true = np.sum(self.cpts[i][:, :n_cells], axis = 1)
                    p_false = np.sum(self.cpts[i][:, n_cells:], axis = 1)
                    self.p_trues[i][:, 0] = p_false / (p_true + p_false)
                    self.p_trues[i][:, 1] = p_true / (p_true + p_false)
                
        
        # backward passing
        
        curr_sizes = self.hidden_layer_sizes[::-1]
        #next_sizes = [self.n_outputs] + self.hidden_layer_sizes[:0:-1]
        #print(curr_sizes)
        
        
        for i, curr_size in enumerate(curr_sizes):
            
            # current layer number
            curr_id = n_layer - i - 1 
            
            for j in range(curr_size):
                
                # get messages from next layer
                f_cells, t_cells = get_ls(j, curr_size)
                
                if i != 0: # not the last hidden layer
                    ff_cells = [x + 2**curr_size for x in f_cells]
                    f_cells.extend(ff_cells)
                    
                    tt_cells = [x + 2**curr_size for x in t_cells]
                    t_cells.extend(tt_cells)
                
                p_true = np.sum(self.cpts[curr_id + 1][:, t_cells], axis = 1) / self.p_trues[curr_id][j, 1]
                p_false = np.sum(self.cpts[curr_id + 1][:, f_cells], axis = 1) / self.p_trues[curr_id][j, 0]
                message_true = np.prod(p_true)
                message_false = np.prod(p_false)
                
                n_cells = self.cpts[i].shape[1] // 2
                self.cpts[i][j, :n_cells] *= message_false
                self.cpts[i][j, n_cells:] *= message_true
        
        

In [30]:
nn = LBN(hidden_layer_sizes = [2, 2], n_outputs = 1, learning_rate = 0.01)
print(nn.cpts[0])

[[0. 0.]
 [0. 0.]]


In [31]:
x = np.array([[1,0]])
y = np.array([1])

In [32]:
nn.call(x)

In [33]:
nn.cpts

ListWrapper([array([[0.32601833, 0.67398167],
       [0.46811044, 0.53188956]]), array([[0.5       , 0.35928935, 0.25800598, 0.16317272, 0.5       ,
        0.64071065, 0.74199402, 0.83682728],
       [0.5       , 0.63740325, 0.7709558 , 0.8554284 , 0.5       ,
        0.36259675, 0.22904417, 0.1445716 ]]), array([[0.5       , 0.52978426, 0.277807  , 0.30235946, 0.5       ,
        0.47021574, 0.722193  , 0.69764054]])])

In [34]:
nn.p_trues

ListWrapper([array([[0., 0.],
       [0., 0.]]), array([[0., 0.],
       [0., 0.]])])

In [35]:
nn.mp(y)

[2, 2]


In [27]:
nn.cpts

ListWrapper([array([[0.26640116, 0.26368794],
       [0.29639798, 0.1478194 ]]), array([[0.40944883, 0.16058854, 0.75907189, 0.26993928, 0.77564417,
        0.169211  , 0.675426  , 0.1336017 ],
       [3.70483624, 1.2103338 , 5.67742828, 1.83858099, 1.2725838 ,
        0.36099619, 1.51722998, 0.42664102]]), array([[0.08358126, 0.14718558, 0.09277467, 0.19680356]])])

In [None]:
a=np.array([[1,2,3,4],[5,6,7,8],[9,10,11,12]])
b=[1,3]
c=[0,2]
np.sum(a[:, b], axis = 1) / 10

In [None]:
a