In this notebook, we look at the most common inference algorithm for graphical models: Belief Propagation (BP), also known as the Sum-product algorithm. BP is a powerful tool for inference on graphical models such as Bayesian Networks and Markov Random Fields. It passes local messages between random variables and factors, and from these messages we can read off the marginal distribution of each random variable and the joint distribution of the variables in each factor's scope. On tree-structured graphs, like the example in this notebook, these results are exact; on graphs with cycles (loopy BP) they are, in general, approximations.
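For reference, the messages passed by the classes below follow the standard sum-product updates. Here $\partial x$ (notation introduced only for this summary) is the set of factors that contain variable $x$, $\partial f$ is the set of variables in the scope of factor $f$, and $\mu$ denotes a message:

$$
\mu_{f \to x}(x) = \sum_{\mathbf{x}_{\partial f \setminus \{x\}}} f(\mathbf{x}_{\partial f}) \prod_{y \in \partial f \setminus \{x\}} \mu_{y \to f}(y)
$$

$$
\mu_{x \to f}(x) = \prod_{g \in \partial x \setminus \{f\}} \mu_{g \to x}(x)
$$

$$
P(x) \propto \prod_{f \in \partial x} \mu_{f \to x}(x), \qquad
P(\mathbf{x}_{\partial f}) \propto f(\mathbf{x}_{\partial f}) \prod_{y \in \partial f} \mu_{y \to f}(y)
$$

The `Node` class stores its incoming messages $\mu_{f \to x}$ and the `Factor` class stores its incoming messages $\mu_{x \to f}$. A one-node factor sends $f(x)$ directly (the sum is empty), and a node with no other factors keeps its initial uniform message, which is a constant and does not change the normalized results.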

In [1]:
from itertools import product
from copy import deepcopy, copy

import numpy as np
import matplotlib.pyplot as plt

### Node Declaration
The first step in implementing a graphical model is to declare the Node class. A Node instance represents one random variable of the given system. Hence, the Node class has methods that deal with the messages coming into the node from its adjacent factors: `init_messages`, `message`, `update_messages`, and `marginal_distribution`.
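As a small sketch of what `update_messages` computes for a single factor, the snippet below vectorizes the factor-to-node message for the pairwise factor `Phi_2` defined later in this notebook, using NumPy instead of loops over value tuples (the `[x_1, x_2]` array layout is our own choice). The message from `X_1` to `Phi_2` equals `Phi_1`, because `Phi_1` is the only other factor attached to `X_1`.

```python
import numpy as np

# Phi_2 from the test section, stored as a table indexed [x_1, x_2].
phi_2 = np.array([[0.3, 0.7],
                  [0.9, 0.1]])

# Node-to-factor message from X_1 to Phi_2: X_1's only other factor is Phi_1.
msg_x1_to_phi2 = np.array([0.1, 0.9])

# Factor-to-node message: for each x_2, sum over x_1 of phi_2(x_1, x_2) * mu_{X_1 -> Phi_2}(x_1).
msg_phi2_to_x2 = phi_2.T @ msg_x1_to_phi2
print(msg_phi2_to_x2)
```

The result, approximately `[0.84, 0.16]`, agrees with the marginal of `X_2` reported at the end of the notebook.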

In [2]:
class Node():
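    def __init__(self, name, vals):
        '''
        
        Args:
            name: The name of this random variable, e.g. "X_1"
            vals: The list of values that this random variable can take, e.g. [0, 1]
            
        '''
        self.name = name
        self.vals = sorted(vals)  # Sorted copy, so the caller's list is not modified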
    
    def init_messages(self, adj_mtx_row, factors):
        '''
        
        Initializes the messages from the adjacent factors to this node.

        Args:
            adj_mtx_row: The row of the node-factor adjacency matrix for this node; element j is 1
                if factors[j] contains this node and 0 otherwise
            factors: The list of all factors in the graph, in the same order as the adjacency-matrix columns
            
        '''
        self.factors = []
        self.factor2pos_idx = {}
        self.factor2adj_nodes = {}
        self.factor2message = {}
        for i in np.where(adj_mtx_row == 1)[0]:
            factor = factors[i]
            
            self.factors.append(factor)
            self.factor2pos_idx[factor] = factor.nodes.index(self)            
            
            adj_nodes = factor.nodes[:]
            adj_nodes.remove(self)
            adj_vals = list(product(*[adj_node.vals for adj_node in adj_nodes]))
            self.factor2adj_nodes[factor] = (adj_nodes, adj_vals)
            
            # Placeholder values; the first BP step overwrites them before they are read.
            self.factor2message[factor] = {val: 0. for val in self.vals}
    
    def message(self, factor, val):
        '''
        
        Returns the message sent from the given factor to this node, evaluated at the given value.

        Args:
            factor: The adjacent Factor instance that sends the message
            val: The value of this random variable
            
        '''
        return self.factor2message[factor][val]
    
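    def update_messages(self):
        '''
        
        Computes the factor-to-node messages for one step of the BP algorithm and returns them.
        The stored messages are not modified; the caller decides when to replace them.
        
        '''
        # Copy the inner dictionaries too, so writing new values does not alter the current messages.
        new_factor2message = {factor: dict(msg) for factor, msg in self.factor2message.items()}
        for val in self.vals:
            for factor in self.factors:
                adj_nodes, adj_vals = self.factor2adj_nodes[factor]
                if len(adj_nodes) != 0:
                    # Sum over all values of the other nodes in the factor's scope of the
                    # factor value times the messages those nodes send to the factor.
                    summation = 0
                    for adj_val in adj_vals:
                        factor_in = list(adj_val)
                        factor_in.insert(self.factor2pos_idx[factor], val)
                        prod = factor(tuple(factor_in))
                        for j, adj_node in enumerate(adj_nodes):
                            prod *= factor.message(adj_node, adj_val[j])
                        summation += prod
                    new_factor2message[factor][val] = summation
                else:
                    # One-node factor: the message is just the factor value.
                    new_factor2message[factor][val] = factor(val)
        
        return new_factor2message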
    
    def marginal_distribution(self):
        '''
        
        Returns the marginal distribution of the corresponding random variable, computed as the
        normalized product of all incoming factor-to-node messages. The result reflects the
        messages after running the BP algorithm.
        
        '''
        P = {}
        for val in self.vals:
            all_messages = [self.message(factor, val) for factor in self.factors]
            P[val] = np.prod(all_messages)
        
        Z = np.sum(list(P.values()))  # Normalization constant
        return {val: p / Z for val, p in P.items()}

### Factor Declaration
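The next step is to declare the Factor class. This is very similar to the Node declaration above. A Factor instance represents a nonnegative potential function over the random variables in its scope (equivalently, -log of the potential is its contribution to the system energy). Thus, the instance holds the list of its nodes and a table mapping their values to the factor value, and, like a Node, it stores the messages coming into it, here from the nodes in its scope. The methods of the Factor class are: `init_messages`, `message`, `update_messages`, `val_check`, and `joint_distribution`.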
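The Factor-side update is the node-to-factor message: the product of the messages a node receives from all of its other factors. A minimal sketch with a hypothetical helper `variable_to_factor` (not part of the notebook's classes), using NumPy arrays instead of dictionaries and illustrative message values for `X_2`, which is adjacent to `Phi_2` and `Phi_4` in the example:

```python
import numpy as np

def variable_to_factor(incoming, target):
    """Hypothetical helper: node-to-factor message mu_{x -> target}.

    incoming maps each adjacent factor's name to the message (np.ndarray over
    the node's values) that the factor sends to the node; it must not be empty.
    """
    n_vals = len(next(iter(incoming.values())))
    out = np.ones(n_vals)  # Empty product: an all-ones message
    for name, msg in incoming.items():
        if name != target:
            out = out * msg  # Multiply in every factor except the target
    return out

# X_2 is adjacent to Phi_2 and Phi_4; these message values are illustrative.
incoming_to_x2 = {"Phi_2": np.array([0.84, 0.16]), "Phi_4": np.array([0.5, 0.5])}
to_phi_4 = variable_to_factor(incoming_to_x2, "Phi_4")  # only Phi_2's message remains
to_phi_2 = variable_to_factor(incoming_to_x2, "Phi_2")  # only Phi_4's message remains
print(to_phi_4, to_phi_2)
```

If a node has no other factors, this helper returns the all-ones message, whereas the notebook's `Factor.update_messages` keeps the initial uniform message; both are constants, so the normalized results are the same.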

In [3]:
class Factor():
    def __init__(self, name, nodes, maps, val_check=True):
        '''
        
        Args:
            name: The name of this factor, e.g. "Phi_1"
            nodes: The list of Node instances in the scope of this factor, e.g. [X_1, X_2]
            maps: The dictionary mapping the value tuples of the nodes (or single values for a one-node factor)
                to the factor values, e.g. {(0, 0): 0.25, (0, 1): 0.25, (1, 0): 0.25, (1, 1): 0.25} or {0: 0.5, 1: 0.5}
            val_check: If True, check that the values used in maps match the values of each node
                
        '''
        self.name = name
        self.nodes = nodes
        self.maps = maps
        if val_check:
            self.val_check()
    
    def __call__(self, vals):
        '''
        
        Returns the factor value for the given node values. Examples:
            with maps {(0, 0): 0.25, ...}: factor((1, 1)) returns 0.25
            with maps {0: 0.5, 1: 0.5}: factor(0) returns 0.5
                
        '''
        return self.maps[vals]
    
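    def init_messages(self):
        '''
        
        Initializes the messages from the nodes in this factor's scope to this factor.
        Must be called after Node.init_messages, because it reads node.factors.
        
        '''
        self.node2adj_factors = {}
        self.node2message = {}
        for node in self.nodes:
            # The other factors adjacent to this node (all of its factors except this one).
            adj_factors = node.factors[:]
            adj_factors.remove(self)
            self.node2adj_factors[node] = adj_factors
            
            # Initialize as a uniform distribution.
            self.node2message[node] = {val: 1/len(node.vals) for val in node.vals}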
    
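    def message(self, node, val):
        '''Returns the message sent from the given node to this factor, evaluated at the given value.'''
        return self.node2message[node][val]
    
    def update_messages(self):
        '''
        
        Computes the node-to-factor messages for one step of the BP algorithm and returns them.
        The stored messages are not modified; the caller decides when to replace them.
        
        '''
        # Copy the inner dictionaries too, so writing new values does not alter the current messages.
        new_node2message = {node: dict(msg) for node, msg in self.node2message.items()}
        for node in self.nodes:
            for val in node.vals:
                adj_factors = self.node2adj_factors[node]
                if len(adj_factors) != 0:
                    # Product of the messages from the node's other factors.
                    prod = 1
                    for adj_factor in adj_factors:
                        prod *= node.message(adj_factor, val)
                    new_node2message[node][val] = prod
                else:
                    # No other factors: keep the (uniform) initial message.
                    new_node2message[node][val] = self.node2message[node][val]
        
        return new_node2message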
    
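    def val_check(self):
        '''Raises ValueError if the values appearing in maps do not match the values of each node.'''
        for i, node in enumerate(self.nodes):
            vals = set()
            for key in self.maps.keys():
                if isinstance(key, tuple):
                    vals.add(key[i])
                else:
                    vals.add(key)  # One-node factor: keys are single values
            vals = sorted(vals)
            
            if node.vals != vals:
                raise ValueError(
                    f"{self.name}: values {vals} in maps do not match the values {node.vals} of {node.name}")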
                
    def joint_distribution(self):
        '''
        
        Returns the joint distribution of the random variables associated with this factor, computed as
        the normalized product of the factor values and the incoming node-to-factor messages.
        The result reflects the messages after running the BP algorithm.
        
        '''
        P = {}
        for vals, factor_val in self.maps.items():
            if len(self.nodes) == 1:
                all_messages = [self.message(self.nodes[0], vals)]
            else:
                all_messages = [self.message(node, val) for node, val in zip(self.nodes, vals)]
            P[vals] = factor_val * np.prod(all_messages)
        
        Z = np.sum(list(P.values()))  # Normalization constant
        return {vals: p / Z for vals, p in P.items()}

### Markov Random Field (MRF) Declaration
The final step is to declare the MRF class. An instance of the MRF class holds its nodes and factors and, through them, the node and factor messages that the Belief Propagation algorithm updates. The methods of the MRF class are: `belief_propagation` and `get_adj_mtx`.
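To make the node-factor adjacency matrix concrete, here is the matrix that `get_adj_mtx` should produce for the example used later (rows `X_1`..`X_4`, columns `Phi_1`..`Phi_4`), rebuilt from the factor scopes by name. The name-based dictionaries are only for this illustration; the class itself compares Node objects and stores floats from `np.zeros`.

```python
import numpy as np

# Scopes of the factors from the test section, written with names only.
node_names = ["X_1", "X_2", "X_3", "X_4"]
factor_scopes = {
    "Phi_1": ["X_1"],
    "Phi_2": ["X_1", "X_2"],
    "Phi_3": ["X_3"],
    "Phi_4": ["X_2", "X_3", "X_4"],
}

# Entry (i, j) is 1 when node i is in the scope of factor j, else 0.
adj = np.array([[1 if name in scope else 0 for scope in factor_scopes.values()]
                for name in node_names])
print(adj)
```

Each row is what `Node.init_messages` receives as `adj_mtx_row`, so `X_2`, for example, learns that it is attached to `Phi_2` and `Phi_4`.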

In [4]:
class MarkovRandomField():
    def __init__(self, nodes, factors):
        self.nodes = nodes
        self.factors = factors
        
        self.num_nodes = len(self.nodes)
        self.num_factors = len(self.factors)
        
        self.adj_mtx = self.get_adj_mtx() # [num_nodes, num_factors]
        for row, node in zip(self.adj_mtx, self.nodes):
            node.init_messages(row, self.factors)
        for factor in self.factors:
            factor.init_messages()
    
    def belief_propagation(self, steps):
        '''
        
        Runs the given number of BP steps. In each step, every node first recomputes its incoming
        factor-to-node messages, and then every factor recomputes its incoming node-to-factor messages.
        
        '''
        for _ in range(steps):
            for node in self.nodes:
                node.factor2message = node.update_messages()
            for factor in self.factors:
                factor.node2message = factor.update_messages()
    
    def get_adj_mtx(self):
        adj_mtx = np.zeros([self.num_nodes, self.num_factors])
        
        for i, node in enumerate(self.nodes):
            for j, factor in enumerate(self.factors):
                adj_mtx[i, j] = 1 if node in factor.nodes else 0
        
        return adj_mtx

### Belief Propagation Test
To test the BP algorithm scripts, we first need to create a factor graph. In this notebook, we use the example from the following post: [Graphical Model Tutorial 01 by Hyungcheol Noh](https://hcnoh.github.io/2020-01-26-graphical-model-01)

Note that this example is based on a Bayesian Network.

![](./img_BP_MRF/01.png)

![](./img_BP_MRF/02.png)

The first step of our test is to convert this Bayesian Network into the corresponding Markov Random Field. This is a very simple procedure: each conditional probability table of the Bayesian Network becomes a factor over the child variable and its parents. The procedure is explained in: [Graphical Model Tutorial 02 by Hyungcheol Noh](https://hcnoh.github.io/2020-01-26-graphical-model-02)

![](./img_BP_MRF/03.png)

![](./img_BP_MRF/04.png)
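Judging from the factor tables defined below (each table sums to 1 over its last variable), the Bayesian Network factorizes into conditional probability tables, and each table becomes one factor of the MRF:

$$
P(x_1, x_2, x_3, x_4) = P(x_1)\,P(x_2 \mid x_1)\,P(x_3)\,P(x_4 \mid x_2, x_3) = \Phi_1(x_1)\,\Phi_2(x_1, x_2)\,\Phi_3(x_3)\,\Phi_4(x_2, x_3, x_4)
$$

Because every factor is a conditional probability table, this product is already normalized, so the partition function of the MRF is $Z = 1$.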

Create the nodes:

In [5]:
X_1 = Node("X_1", [0, 1])
X_2 = Node("X_2", [0, 1])
X_3 = Node("X_3", [0, 1])
X_4 = Node("X_4", [0, 1])

Create the factors:

In [6]:
Phi_1 = Factor("Phi_1", [X_1],
   {0: 0.1, 1: 0.9}
)
Phi_2 = Factor("Phi_2", [X_1, X_2],
   {
       (0, 0): 0.3,
       (0, 1): 0.7,
       (1, 0): 0.9,
       (1, 1): 0.1
   }
)
Phi_3 = Factor("Phi_3", [X_3],
   {0: 0.8, 1: 0.2}
)
Phi_4 = Factor("Phi_4", [X_2, X_3, X_4],
   {
       (0, 0, 0): 0.3,
       (0, 0, 1): 0.7,
       (0, 1, 0): 0.8,
       (0, 1, 1): 0.2,
       (1, 0, 0): 0.1,
       (1, 0, 1): 0.9,
       (1, 1, 0): 0.2,
       (1, 1, 1): 0.8
   }
)
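Since the factors come from a Bayesian Network, each table should sum to 1 over its last variable for every setting of the others. A quick sanity check, assuming (as the tables suggest) that the last position of each key is the child variable; `sums_to_one` is a hypothetical helper:

```python
import math
from itertools import product

# Copies of the tables above; the last position of each key is assumed to be the child variable.
phi_1 = {0: 0.1, 1: 0.9}
phi_2 = {(0, 0): 0.3, (0, 1): 0.7, (1, 0): 0.9, (1, 1): 0.1}
phi_3 = {0: 0.8, 1: 0.2}
phi_4 = {(0, 0, 0): 0.3, (0, 0, 1): 0.7, (0, 1, 0): 0.8, (0, 1, 1): 0.2,
         (1, 0, 0): 0.1, (1, 0, 1): 0.9, (1, 1, 0): 0.2, (1, 1, 1): 0.8}

def sums_to_one(table, n_parents, vals=(0, 1)):
    """Check that the table sums to 1 over the child for every setting of the parents."""
    if n_parents == 0:
        return math.isclose(sum(table[v] for v in vals), 1.0)
    return all(
        math.isclose(sum(table[parents + (child,)] for child in vals), 1.0)
        for parents in product(vals, repeat=n_parents)
    )

checks = [sums_to_one(phi_1, 0), sums_to_one(phi_2, 1), sums_to_one(phi_3, 0), sums_to_one(phi_4, 2)]
print(checks)
```

When all four checks pass, the product of the factors is a normalized joint distribution.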

Declare an MRF instance:

In [7]:
nodes = [X_1, X_2, X_3, X_4]
factors = [Phi_1, Phi_2, Phi_3, Phi_4]

mrf = MarkovRandomField(nodes, factors)

Run the Belief Propagation process for 10 steps:

In [8]:
mrf.belief_propagation(10)

Check the marginal distribution of each node:

In [9]:
X_1.marginal_distribution(), X_2.marginal_distribution(), X_3.marginal_distribution(), X_4.marginal_distribution()

({0: 0.1, 1: 0.9},
 {0: 0.8400000000000001, 1: 0.16},
 {0: 0.8, 1: 0.20000000000000007},
 {0: 0.35520000000000007, 1: 0.6447999999999999})
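Because this factor graph is a tree, BP should reproduce the exact marginals. A brute-force check by enumerating all assignments (feasible only for tiny models, since the cost grows exponentially with the number of variables), using plain dictionaries copied from the factor cells above:

```python
from itertools import product

# Copies of the factor tables, keyed in the same order as each factor's nodes.
phi_1 = {0: 0.1, 1: 0.9}
phi_2 = {(0, 0): 0.3, (0, 1): 0.7, (1, 0): 0.9, (1, 1): 0.1}
phi_3 = {0: 0.8, 1: 0.2}
phi_4 = {(0, 0, 0): 0.3, (0, 0, 1): 0.7, (0, 1, 0): 0.8, (0, 1, 1): 0.2,
         (1, 0, 0): 0.1, (1, 0, 1): 0.9, (1, 1, 0): 0.2, (1, 1, 1): 0.8}

def unnormalized_joint(x1, x2, x3, x4):
    """Product of all four factors for one full assignment."""
    return phi_1[x1] * phi_2[(x1, x2)] * phi_3[x3] * phi_4[(x2, x3, x4)]

# Enumerate all 2**4 assignments of (x1, x2, x3, x4).
assignments = list(product([0, 1], repeat=4))
Z = sum(unnormalized_joint(*a) for a in assignments)

# Marginal of variable i: sum the joint over assignments with a[i] == v, then normalize.
marginals = [
    {v: sum(unnormalized_joint(*a) for a in assignments if a[i] == v) / Z for v in (0, 1)}
    for i in range(4)
]
for i, p in enumerate(marginals, start=1):
    print(f"X_{i}: P(0)={p[0]:.4f}, P(1)={p[1]:.4f}")
```

The printed values agree with the BP marginals above up to floating-point rounding.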

Check the joint distribution of the random variables in the scope of each factor:

In [10]:
Phi_1.joint_distribution(), Phi_2.joint_distribution(), Phi_3.joint_distribution(), Phi_4.joint_distribution()

({0: 0.1, 1: 0.9},
 {(0, 0): 0.03,
  (0, 1): 0.06999999999999999,
  (1, 0): 0.81,
  (1, 1): 0.09000000000000001},
 {0: 0.8, 1: 0.20000000000000007},
 {(0, 0, 0): 0.20159999999999997,
  (0, 0, 1): 0.4704,
  (0, 1, 0): 0.13440000000000002,
  (0, 1, 1): 0.033600000000000005,
  (1, 0, 0): 0.012799999999999997,
  (1, 0, 1): 0.11519999999999998,
  (1, 1, 0): 0.006399999999999999,
  (1, 1, 1): 0.025599999999999994})
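The factor-level joints can be checked the same way. A sketch using `np.einsum` to build the full joint as a four-dimensional array (our own array layout, one axis per variable) and then sum out the variables outside each factor's scope:

```python
import numpy as np

phi_1 = np.array([0.1, 0.9])                      # indexed [x1]
phi_2 = np.array([[0.3, 0.7], [0.9, 0.1]])        # indexed [x1, x2]
phi_3 = np.array([0.8, 0.2])                      # indexed [x3]
phi_4 = np.array([[[0.3, 0.7], [0.8, 0.2]],
                  [[0.1, 0.9], [0.2, 0.8]]])      # indexed [x2, x3, x4]

# Full joint over (x1, x2, x3, x4) as the product of all factors, then normalized.
joint = np.einsum("a,ab,c,bcd->abcd", phi_1, phi_2, phi_3, phi_4)
joint /= joint.sum()

# Joints over the factor scopes: sum out the variables outside each scope.
p_x1_x2 = joint.sum(axis=(2, 3))      # scope of Phi_2
p_x2_x3_x4 = joint.sum(axis=0)        # scope of Phi_4
print(np.round(p_x1_x2, 4))
print(np.round(p_x2_x3_x4, 4))
```

These arrays agree, entry by entry, with the dictionaries returned by `Phi_2.joint_distribution()` and `Phi_4.joint_distribution()` above.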