In [1]:
from collections import namedtuple
from sympy import *
from itertools import product

In [17]:
class Node:
    """Base class for the nodes of a factor graph.

    This class defines no ``__init__``: subclasses set ``self.neighbors`` (a list)
    and ``self.required_nbr_type`` (the class every neighbour must be an instance
    of) themselves.
    """

    def add_neighbors(self, nbrs):
        """Append one node, or a list/tuple of nodes, to ``self.neighbors``.

        Raises:
            AssertionError: if any item is not an instance of
                ``self.required_nbr_type``; nothing is appended in that case.
        """
        # A tuple used to be wrapped as a single "neighbour" and then rejected by
        # the type check. Sets are deliberately not accepted: neighbour order is
        # the argument order of a factor's function.
        if isinstance(nbrs, (list, tuple)):
            nbrs = list(nbrs)
        else:
            nbrs = [nbrs]
        assert all(isinstance(n, self.required_nbr_type) for n in nbrs), (
            f'{type(self).__name__} neighbours must be instances of '
            f'{self.required_nbr_type.__name__}')
        self.neighbors.extend(nbrs)

    @property
    def is_child(self):
        """True when the node has exactly one neighbour, i.e. it is a leaf."""
        return len(self.neighbors) == 1


class Variable(Node):
    """Variable node of the factor graph; its neighbours must be ``Factor`` nodes.

    ``vals`` holds the values the variable ranges over. ``symbol`` is what factor
    functions are applied to: the SymPy symbol ``Symbol(name)``, except while a
    concrete value has been plugged in, after which ``reset`` restores it.
    """

    def __init__(self, name, vals):
        self.name = name
        self.vals = vals
        self.required_nbr_type = Factor
        self.neighbors = []      # Factor nodes, filled in via add_neighbors
        self.messages = set()    # Message tuples delivered to this variable
        self.reset()             # binds self.symbol to Symbol(name)

    def reset(self):
        """Point ``symbol`` back at the SymPy symbol named after this variable."""
        self.symbol = Symbol(self.name)

    # Disabled draft of a marginal computation, kept for reference.
    # NOTE(review) before enabling it:
    #   * `raise(f'...')` would itself fail with a TypeError (a str is not an
    #     exception); use e.g. `raise RuntimeError(f'...')`.
    #   * Assigning `self.symbol = val` does not change the already-built `expr`;
    #     the value has to be substituted, e.g. `expr.subs(Symbol(self.name), val)`.
#     def marginal(self, val):
#         if len(self.messages) != len(self.neighbors):
#             raise(f'Variable {self.name} has not yet received all incoming messages')
#         expr = 1
#         for msg in self.messages:
#             expr *= msg.expr
#         self.symbol = val; p_tilde = expr
#         print(expr)  # should print out an int
#         Z = 0
#         for v in self.vals:
#             self.symbol = v
#             Z += expr
#         return p_tilde / Z


class Factor(Node):
    """Factor node of the factor graph; its neighbours must be ``Variable`` nodes.

    The factor is kept symbolic: ``func`` applies an undefined SymPy function
    named after the factor to its neighbours' current symbols.
    """

    def __init__(self, name):
        self.name = name
        self.neighbors = []                 # Variable nodes, in the argument order of `func`
        self.required_nbr_type = Variable
        self._func = Function(name)         # undefined SymPy function called `name`
        self.messages = set()               # Message tuples delivered to this factor

    def add_neighbors(self, nbrs):
        """Connect this factor to one Variable or a list/tuple of Variables, both ways.

        The variables are validated and recorded on this factor first (via
        ``Node.add_neighbors``, which raises AssertionError before appending
        anything if a neighbour has the wrong type). Only then is the factor
        added to each variable's neighbours, so a rejected call leaves no
        one-sided links behind.
        """
        # A bare Variable used to crash in `for nbr in nbrs`; accept it like Node does.
        nbrs = list(nbrs) if isinstance(nbrs, (list, tuple)) else [nbrs]
        super().add_neighbors(nbrs)
        for nbr in nbrs:
            nbr.add_neighbors([self])

    @property
    def func(self):
        """The factor's function applied to its neighbours' current ``symbol``s, in neighbour order."""
        return self._func(*[n.symbol for n in self.neighbors])
    
    
# A message travelling along one edge of the factor graph: its expression
# (a SymPy expression or the number 1) plus the sending and receiving nodes.
Message = namedtuple('Message', 'expr from_node to_node')

In [77]:
# The factor graph is represented by the set of its factor nodes; the variables
# are reachable through each factor's `neighbors`.
g = set()

# Every variable ranges over the same values 0..9 (one shared list).
var_vals = list(range(10))

# Variables
x_1, x_2, x_3, x_4 = (Variable(name, var_vals) for name in ('x_1', 'x_2', 'x_3', 'x_4'))

# Factors
f_a, f_b, f_c = (Factor(name) for name in ('f_a', 'f_b', 'f_c'))

# Message passing is directed towards this variable.
root = x_3

# Connect each factor to the variables it depends on. The call order fixes the
# order of x_2.neighbors (f_c, f_a, f_b).
f_c.add_neighbors([x_2, x_4])
f_a.add_neighbors([x_1, x_2])
f_b.add_neighbors([x_2, x_3])

# Add factors to graph
g.update((f_a, f_b, f_c))

# Collect every node: each factor together with its neighbouring variables.
nodes = set()
for f in g:
    nodes = nodes | {f, *f.neighbors}

In [78]:
def compute_message(from_node, to_node):
    """Compute the sum-product message sent from ``from_node`` to its neighbour ``to_node``.

    * Variable -> factor: the product of the messages the variable has received
      from all neighbours except ``to_node`` (``1`` for a leaf variable).
    * Factor -> variable: the sum, over every joint assignment of the factor's
      other variables, of the factor times the messages those variables sent it
      (the bare factor expression for a leaf factor).

    Returns a ``Message(expr, from_node, to_node)``; delivering it (adding it to
    ``to_node.messages``) is left to the caller.
    """
    if from_node.is_child:
        # Leaf: nothing to multiply in and nothing to sum out.
        if isinstance(from_node, Factor):
            return Message(from_node.func, from_node, to_node)
        return Message(1, from_node, to_node)
    incoming = compute_product_incoming_msgs(from_node, to_node)
    if isinstance(from_node, Factor):
        # The incoming messages depend on the variables being summed out, so they
        # must be multiplied in *inside* the sum. Multiplying the finished sum by
        # them left those variables as free symbols in the message.
        expr = compute_factor_summation(from_node, to_node, incoming=incoming)
    else:
        expr = incoming
    return Message(expr, from_node, to_node)


def compute_factor_summation(from_node, to_node, incoming=1):
    """Sum the factor ``from_node`` over all of its neighbours except ``to_node``.

    For each joint assignment of the summed-out variables (taken from their
    ``vals``), ``from_node.func * incoming`` is evaluated at that assignment,
    and the results are added up.

    Args:
        from_node: the Factor being summed.
        to_node: the neighbouring variable that is not summed out.
        incoming: product of the messages the factor received from the
            summed-out variables, as an expression in their symbols or a plain
            number. The default of 1 gives the bare sum of the factor.

    Returns:
        The accumulated sum (starts from the integer 0).
    """
    sum_over = [n for n in from_node.neighbors if n is not to_node]
    # Remember the symbols before they are overwritten with values below; the
    # incoming messages are written in terms of these symbols.
    summed_symbols = [var.symbol for var in sum_over]
    expr = 0
    try:
        for vals in product(*[var.vals for var in sum_over]):
            # Plug the values into the factor through Variable.symbol ...
            for var, val in zip(sum_over, vals):
                var.symbol = val
            # ... and into the incoming messages by substitution. Plain numbers,
            # such as a leaf variable's message 1, have no `subs` and pass through.
            if hasattr(incoming, 'subs'):
                weight = incoming.subs(dict(zip(summed_symbols, vals)))
            else:
                weight = incoming
            expr += from_node.func * weight
    finally:
        # Restore the symbols even if evaluation fails part-way.
        for var in sum_over:
            var.reset()
    return expr


def compute_product_incoming_msgs(from_node, to_node):
    """Multiply the messages ``from_node`` received from every neighbour except ``to_node``.

    Returns 1 when there are no such messages.
    """
    expr = 1
    for msg in from_node.messages:
        # Sum-product excludes the message that came from the recipient itself.
        if msg.to_node is from_node and msg.from_node is not to_node:
            expr *= msg.expr
    return expr

# Leaves to root (upward pass)

In [79]:
# Leaves-to-root (upward) pass of sum-product; assumes the factor graph is a tree.
# A node may send to its parent only after it has heard from every other
# neighbour (its children); otherwise its message misses part of its subtree.
# Starting from the non-root leaves, each node sends to the single neighbour it
# has not heard from yet. A node is queued as soon as exactly one of its
# neighbours is still unheard. The root only receives in this pass.
heard_from = {n: set() for n in nodes}  # node -> neighbours that messaged it in this pass
queue = [n for n in nodes if n.is_child and n != root]
visited = set()

while queue:
    node = queue.pop(0)
    # By the queueing rule below, exactly one neighbour is still unheard: the
    # one towards the root.
    (parent,) = [n for n in node.neighbors if n not in heard_from[node]]
    parent.messages.add(compute_message(from_node=node, to_node=parent))
    heard_from[parent].add(node)
    visited.add(node)
    if parent != root and len(heard_from[parent]) == len(parent.neighbors) - 1:
        queue.append(parent)

visited.add(root)

# Root to leaves (downward pass)

In [92]:
# TODO: implement the root-to-leaves pass (e.g. a BFS outward from the root)