In [174]:
from sympy import *
from itertools import product

In [191]:
class Node:
        
    def add_neighbors(self, nbrs):
        if not isinstance(nbrs, list):
            nbrs = [nbrs]
        assert all([isinstance(n, self.required_nbr_type) for n in nbrs])
        self.neighbors.extend(nbrs)
    
    @property
    def is_child(self):
        return len(self.neighbors) == 1


class Variable(Node):
    
    def __init__(self, name, vals):
        self.name = name
        self.neighbors = []
        self.required_nbr_type = Factor
        self.symbol = Symbol(name)
        self.vals = vals
        self.messages = set()
        
    def reset(self):
        self.symbol = Symbol(self.name)


class Factor(Node):
    
    def __init__(self, name):
        self.name = name
        self.neighbors = []
        self.required_nbr_type = Variable
        self._func = Function(name)
        self.messages = set()
        
    def add_neighbors(self, nbrs):
        for nbr in nbrs:
            nbr.add_neighbors([self])
        super().add_neighbors(nbrs)
        
    @property
    def func(self):
        return self._func(*[n.symbol for n in self.neighbors])

In [227]:
g = set()

var_vals = range(10)

# Variables
x_1 = Variable('x_1', var_vals)
x_2 = Variable('x_2', var_vals)
x_3 = Variable('x_3', var_vals)
x_4 = Variable('x_4', var_vals)

# Factors
f_a = Factor('f_a')
f_b = Factor('f_b')
f_c = Factor('f_c')

# Define root
root = x_3

# Define factors
f_c.add_neighbors([x_2, x_4])
f_a.add_neighbors([x_1, x_2])
f_b.add_neighbors([x_2, x_3])

# Add factors to graph
g.add(f_a)
g.add(f_b)
g.add(f_c)

# Enumerate all nodes
nodes = set()
for f in g:
    nodes = nodes | set([f] + f.neighbors)

In [228]:
def compute_message(from_node, to_node):
    if from_node.is_child and isinstance(from_node, Variable):
        return 1
    if from_node.is_child and isinstance(from_node, Factor):
        return from_node.func
    if isinstance(from_node, Factor):
        factor_summation = compute_factor_summation(from_node, to_node)
        # compute product over incoming nodes
    if isinstance(from_node, Variable):
        pass
        # compute product over incoming nodes

        
def compute_factor_summation(from_node, to_node):
    sum_over = [n for n in f_a.neighbors if n.name != to_node.name]
    expr = 0
    for vals in product(*[n.vals for n in sum_over]):
        for var, val in zip(sum_over, vals):
            var.symbol = val
        expr += from_node.func
    for var in sum_over:
        var.reset()
    return expr   

In [229]:
queue = [n for n in nodes if n.is_child and n != root]
visited = set()

while queue and visited != nodes:
    node = queue.pop(0)
    for nbr in node.neighbors:
        msg = compute_message(from_node=node, to_node=nbr)
        nbr.messages.add(msg)
        if nbr not in visited and nbr not in queue:
            queue.append(nbr)
    visited.add(node)

In [233]:
f_c.messages

{1, None}