In [1]:
from itertools import product
from copy import deepcopy, copy

import numpy as np

In [2]:
class Node():
    '''
        A discrete variable node of a factor graph.

        After init_messages() has been called, the node stores, for every factor it is
        attached to, one message value per value of the node (self.factor2message).
    '''
    def __init__(self, name, vals):
        '''
            name: The name of this node. e.g. "X_1"
            vals: The values this node can take. e.g. [0, 1]
                They are stored as a new sorted list; the argument itself is not modified.
        '''
        self.name = name
        # sorted() builds a fresh list, so a list shared by the caller is never
        # reordered behind its back (list.sort() would mutate it in place).
        self.vals = sorted(vals)
    
    def init_messages(self, adj_mtx_row, factors):
        '''
            Prepare the per-factor bookkeeping of this node.

            adj_mtx_row: The row of the node/factor adjacency matrix that belongs to this
                node; entry j equals 1 when this node is an input of factors[j].
            factors: The list of all factors, in the column order of the adjacency matrix.

            Every stored message value starts at 0.
        '''
        self.factors = []
        self.factor2pos_idx = {}
        self.factor2adj_nodes = {}
        self.factor2message = {}
        for i in np.where(adj_mtx_row == 1)[0]:
            factor = factors[i]
            
            self.factors.append(factor)
            # Position of this node among the inputs of the factor.
            self.factor2pos_idx[factor] = factor.nodes.index(self)
            
            # The other inputs of the factor, and every joint assignment of their values.
            adj_nodes = factor.nodes[:]
            adj_nodes.remove(self)
            adj_vals = list(product(*[adj_node.vals for adj_node in adj_nodes]))
            self.factor2adj_nodes[factor] = (adj_nodes, adj_vals)
            
            self.factor2message[factor] = {val: 0. for val in self.vals}
    
    def message(self, factor, val):
        '''
            Return the stored message value for the given factor and node value.
        '''
        return self.factor2message[factor][val]
    
    def update_messages(self):
        '''
            Compute a new message table without touching the current one.

            For every attached factor and every value of this node, the new entry is the
            sum, over all joint assignments of the factor's other nodes, of the factor value
            multiplied by factor.message(adj_node, adj_val) for each of those other nodes.
            When the factor has no other node, the entry is simply factor(val).

            Returns: a new dict {factor: {val: message value}}. self.factor2message is left
                unchanged; the caller decides when to install the result.
        '''
        # A brand-new nested dict is built on purpose: a shallow copy of
        # self.factor2message would share the inner dicts and overwrite the current
        # messages while the new ones are still being computed.
        new_factor2message = {}
        for factor in self.factors:
            adj_nodes, adj_vals = self.factor2adj_nodes[factor]
            pos_idx = self.factor2pos_idx[factor]
            
            messages = {}
            for val in self.vals:
                if len(adj_nodes) != 0:
                    summation = 0
                    for adj_val in adj_vals:
                        # Re-insert this node's value at its own input position.
                        factor_in = list(adj_val)
                        factor_in.insert(pos_idx, val)
                        prod = factor(tuple(factor_in))
                        for j, adj_node in enumerate(adj_nodes):
                            prod *= factor.message(adj_node, adj_val[j])
                        summation += prod
                    messages[val] = summation
                else:
                    # Single-input factor: it is called with the bare value, not a tuple.
                    messages[val] = factor(val)
            new_factor2message[factor] = messages
        
        return new_factor2message
    
    def get_marginal_distribution(self):
        '''
            Return {val: probability} obtained by multiplying, for each value, the stored
            messages of all attached factors and normalizing so the entries sum to 1.

            If every product is zero (e.g. right after init_messages(), where all messages
            are 0) the normalizer is zero and the entries come out as nan.
        '''
        unnormalized = {}
        for val in self.vals:
            all_messages = [self.message(factor, val) for factor in self.factors]
            unnormalized[val] = np.prod(all_messages)
        
        # The normalizer is the same for every value, so compute it once.
        normalizer = np.sum(list(unnormalized.values()))
        
        return {val: value / normalizer for val, value in unnormalized.items()}

In [3]:
class Factor():
    '''
        A factor of a factor graph: a table that maps value assignments of its input
        nodes to numbers.

        After init_messages() has been called, the factor stores, for every input node,
        one message value per value of that node (self.node2message).
    '''
    def __init__(self, name, nodes, maps, val_check=True):
        '''
            name: The name of this factor. e.g. "phi_1"
            nodes: The list of the input nodes (node objects, not names) of this factor.
                e.g. [X_1, X_2]
            maps: The dictionary of the map btw the value tuples of the inputs and the factor output values. e.g.
                {(0, 0): 0.25, (0, 1): 0.25, (1, 0): 0.25, (1, 1): 0.25}
                A factor with a single input may use the bare values as keys. e.g.
                {0: 0.1, 1: 0.9}
            val_check: If True, verify that the keys of maps agree with the values of the
                input nodes and raise ValueError otherwise.
        '''
        self.name = name
        self.nodes = nodes
        self.maps = maps
        if val_check:
            self.val_check()
    
    def __call__(self, vals):
        '''
            Look up the factor value for one key of maps.

            An example of this method's usage:
                factor((1, 1)), Result: 0.25
                factor(0), Result: 0.25
        '''
        return self.maps[vals]
    
    def init_messages(self):
        '''
            Prepare the per-node bookkeeping of this factor.

            Reads node.factors of every input node, so the nodes must have had their own
            init_messages() called first. Every stored message starts out uniform.
        '''
        self.node2adj_factors = {}
        self.node2message = {}
        for node in self.nodes:            
            # The other factors attached to this node.
            adj_factors = node.factors[:]
            adj_factors.remove(self)
            self.node2adj_factors[node] = adj_factors
            
            self.node2message[node] = \
                {val: 1/len(node.vals) for val in node.vals} # Initialize with uniform distribution.
    
    def message(self, node, val):
        '''
            Return the stored message value for the given node and node value.
        '''
        return self.node2message[node][val]
    
    def update_messages(self):
        '''
            Compute a new message table without touching the current one.

            For every input node and every value of it, the new entry is the product of
            node.message(adj_factor, val) over the node's other factors. A node with no
            other factor keeps its current entries.

            Returns: a new dict {node: {val: message value}}. self.node2message is left
                unchanged; the caller decides when to install the result.
        '''
        # A brand-new nested dict is built on purpose: a shallow copy of
        # self.node2message would share the inner dicts and overwrite the current
        # messages in place.
        new_node2message = {}
        for node in self.nodes:
            adj_factors = self.node2adj_factors[node]
            
            messages = {}
            for val in node.vals:
                if len(adj_factors) != 0:
                    prod = 1
                    for adj_factor in adj_factors:
                        prod *= node.message(adj_factor, val)
                    messages[val] = prod
                else:
                    messages[val] = self.node2message[node][val]
            new_node2message[node] = messages
        
        return new_node2message
    
    def val_check(self):
        '''
            Verify that, for every input node, the values found at that node's position
            in the keys of maps are exactly the node's values.

            Tuple keys must have one entry per input node. Non-tuple keys are treated as
            a single bare value and are only accepted for single-input factors.

            Raises: ValueError when a key has the wrong shape or the values do not match.
        '''
        num_nodes = len(self.nodes)
        for i, node in enumerate(self.nodes):
            vals = set()
            for key in self.maps:
                # Decide by type instead of "try key[i]": indexing would silently slice
                # string values and hide wrong-length tuples.
                if isinstance(key, tuple):
                    if len(key) != num_nodes:
                        raise ValueError(
                            "Factor {!r}: key {!r} has {} entries but the factor has {} nodes."
                            .format(self.name, key, len(key), num_nodes)
                        )
                    vals.add(key[i])
                elif num_nodes == 1:
                    vals.add(key)
                else:
                    raise ValueError(
                        "Factor {!r}: key {!r} must be a tuple with {} entries."
                        .format(self.name, key, num_nodes)
                    )
            vals = sorted(vals)
            
            if node.vals != vals:
                raise ValueError(
                    "Factor {!r}: values {!r} at input position {} do not match node values {!r}."
                    .format(self.name, vals, i, node.vals)
                )

In [4]:
class MarkovRandomField():
    '''
        A factor graph made of variable nodes and factors, with a simple
        message-passing loop over them.
    '''
    def __init__(self, nodes, factors):
        '''
            nodes: The list of all nodes of the graph.
            factors: The list of all factors of the graph.

            Builds the node/factor adjacency matrix, then initializes the messages of
            every node first and of every factor afterwards.
        '''
        self.nodes = nodes
        self.factors = factors
        
        self.num_nodes = len(nodes)
        self.num_factors = len(factors)
        
        # Shape: [num_nodes, num_factors]
        self.adj_mtx = self.get_adj_mtx()
        
        for adj_row, node in zip(self.adj_mtx, self.nodes):
            node.init_messages(adj_row, self.factors)
        
        for factor in self.factors:
            factor.init_messages()
    
    def belief_propagation(self, steps):
        '''
            Run `steps` rounds of message updates. In each round every node installs its
            freshly computed messages, then every factor does the same.
        '''
        for _ in range(steps):
            for node in self.nodes:
                node.factor2message = node.update_messages()
            
            for factor in self.factors:
                factor.node2message = factor.update_messages()
    
    def get_adj_mtx(self):
        '''
            Return a [num_nodes, num_factors] array whose entry (i, j) is 1 when
            nodes[i] is an input of factors[j], and 0 otherwise.
        '''
        adj_mtx = np.zeros((self.num_nodes, self.num_factors))
        
        for col, factor in enumerate(self.factors):
            members = factor.nodes
            for row, node in enumerate(self.nodes):
                if node in members:
                    adj_mtx[row, col] = 1
        
        return adj_mtx

In [5]:
# Four binary variables.
X_1, X_2, X_3, X_4 = (
    Node("X_1", [0, 1]),
    Node("X_2", [0, 1]),
    Node("X_3", [0, 1]),
    Node("X_4", [0, 1]),
)

# Single-input factor on X_1 (keys are bare values).
Phi_1 = Factor(name="Phi_1", nodes=[X_1], maps={0: 0.1, 1: 0.9})

# Pairwise factor; keys are (X_1, X_2) value tuples.
Phi_2 = Factor(
    name="Phi_2",
    nodes=[X_1, X_2],
    maps={
        (0, 0): 0.3, (0, 1): 0.7,
        (1, 0): 0.9, (1, 1): 0.1,
    },
)

# Single-input factor on X_3 (keys are bare values).
Phi_3 = Factor(name="Phi_3", nodes=[X_3], maps={0: 0.8, 1: 0.2})

# Three-input factor; keys are (X_2, X_3, X_4) value tuples.
Phi_4 = Factor(
    name="Phi_4",
    nodes=[X_2, X_3, X_4],
    maps={
        (0, 0, 0): 0.3, (0, 0, 1): 0.7,
        (0, 1, 0): 0.8, (0, 1, 1): 0.2,
        (1, 0, 0): 0.1, (1, 0, 1): 0.9,
        (1, 1, 0): 0.2, (1, 1, 1): 0.8,
    },
)

nodes = [X_1, X_2, X_3, X_4]
factors = [Phi_1, Phi_2, Phi_3, Phi_4]

In [6]:
# Build the graph. The constructor computes the node/factor adjacency matrix and
# initializes the messages of every node and every factor.
mrf = MarkovRandomField(nodes, factors)

In [7]:
# Run 10 rounds of message updates (within each round: all nodes, then all factors).
mrf.belief_propagation(10)

X_1
Phi_1
[<__main__.Node object at 0x0000019A5937D668>]
[]
{0: 0.1, 1: 0.9}
0
{<__main__.Factor object at 0x0000019A5937D780>: {0: 0.0, 1: 0.0}, <__main__.Factor object at 0x0000019A5937D7B8>: {0: 0.0, 1: 0.0}}
X_1
Phi_1
[<__main__.Node object at 0x0000019A5937D668>]
[]
{0: 0.1, 1: 0.9}
1
{<__main__.Factor object at 0x0000019A5937D780>: {0: 0.1, 1: 0.0}, <__main__.Factor object at 0x0000019A5937D7B8>: {0: 0.5, 1: 0.0}}
X_3
Phi_3
[<__main__.Node object at 0x0000019A5937D710>]
[]
{0: 0.8, 1: 0.2}
0
{<__main__.Factor object at 0x0000019A5937D6A0>: {0: 0.0, 1: 0.0}, <__main__.Factor object at 0x0000019A5937D6D8>: {0: 0.0, 1: 0.0}}
X_3
Phi_3
[<__main__.Node object at 0x0000019A5937D710>]
[]
{0: 0.8, 1: 0.2}
1
{<__main__.Factor object at 0x0000019A5937D6A0>: {0: 0.8, 1: 0.0}, <__main__.Factor object at 0x0000019A5937D6D8>: {0: 0.5, 1: 0.0}}
X_1
Phi_1
[<__main__.Node object at 0x0000019A5937D668>]
[]
{0: 0.1, 1: 0.9}
0
{<__main__.Factor object at 0x0000019A5937D780>: {0: 0.1, 1: 0.9}, <__mai

In [10]:
# Marginal distributions of the four variables, in the order X_1, X_2, X_3, X_4.
tuple(node.get_marginal_distribution() for node in (X_1, X_2, X_3, X_4))

{0: 0.35520000000000007, 1: 0.6447999999999999}