In [338]:
from collections import deque
from functools import reduce
from itertools import product
import operator

from sympy import Symbol

In [339]:
class Node:
        
    def add_neighbors(self, nbrs):
        if not isinstance(nbrs, list):
            nbrs = [nbrs]
        assert all([isinstance(n, self.required_nbr_type) for n in nbrs])
        self.neighbors.extend(nbrs)
  
    def clear_messages(self):
        self.messages = set()
    
    @property
    def is_child(self):
        return len(self.neighbors) == 1


class Variable(Node):
    
    def __init__(self, name, vals):
        self.name = name
        self.neighbors = []
        self.required_nbr_type = Factor
        self.symbol = Symbol(name)
        self.vals = vals
        self.messages = set()
        
    def reset(self):
        self.symbol = Symbol(self.name)
        
    def pprint_messages(self):
        return {m.pprint() for m in self.messages}
        
    def marginal(self, val):
        if len(self.messages) != len(self.neighbors):
            raise(f'Variable {self.name} has not yet received all incoming messages')
        expr = 1
        for msg in self.messages:
            expr *= msg.expr.subs({self.symbol: val})
        return expr
#         self.symbol = val; p_tilde = expr
#         print(expr)  # should print out an int
#         Z = 0
#         for v in self.vals:
#             self.symbol = v
#             Z += expr
#         return p_tilde / Z


class Factor(Node):
    
    def __init__(self, name):
        self.name = name
        self.neighbors = []
        self.required_nbr_type = Variable
        self._func = Function(name)
        self.messages = set()
        
    def add_neighbors(self, nbrs):
        for nbr in nbrs:
            nbr.add_neighbors([self])
        super().add_neighbors(nbrs)
        
    def reset(self):
        self.messages = set()
        
    @property
    def func(self):
        return self._func(*[n.symbol for n in self.neighbors])
    

class Message:
    
    def __init__(self, expr, from_node, to_node):
        self.expr = expr
        self.from_node = from_node
        self.to_node = to_node
        
    def pprint(self):
        return (self.expr, self.from_node.name, self.to_node.name)


def clear_messages(nodes):
    return {n.clear_messages() for n in nodes}


def pprint_messages(nodes):
    msgs = set()
    for n in nodes:
        msgs = msgs | {m.pprint() for m in n.messages}
    return msgs

In [340]:
g = set()

var_vals = [v for v in range(10)]

# Variables
x_1 = Variable('x_1', var_vals)
x_2 = Variable('x_2', var_vals)
x_3 = Variable('x_3', var_vals)
x_4 = Variable('x_4', var_vals)

# Factors
f_a = Factor('f_a')
f_b = Factor('f_b')
f_c = Factor('f_c')

# Define root
root = x_3

# Define factors
f_c.add_neighbors([x_2, x_4])
f_a.add_neighbors([x_1, x_2])
f_b.add_neighbors([x_2, x_3])

# Add factors to graph
g.add(f_a)
g.add(f_b)
g.add(f_c)

# Enumerate all nodes
nodes = set()
for f in g:
    nodes = nodes | set([f] + f.neighbors)

In [350]:
def reduce_mul(iterable):
    return reduce(operator.mul, iterable)


def compute_message(from_node, to_node):
    if from_node.is_child:
        if isinstance(from_node, Factor):
            return Message(from_node.func, from_node, to_node)
        else:
            return Message(1, from_node, to_node)
    expr = compute_product_incoming_msgs(from_node, to_node)
    if isinstance(from_node, Factor):
        expr *= compute_factor_summation(from_node, to_node)
    return Message(expr, from_node, to_node)

        
def compute_factor_summation(from_node, to_node):
    sum_over = [n for n in from_node.neighbors if n.name != to_node.name]
    expr = 0
    for vals in product(*[n.vals for n in sum_over]):
        subs = {var.symbol: val for var, val in zip(sum_over, vals)}
        expr += from_node.func.subs(subs)
    return expr


def compute_product_incoming_msgs(from_node, to_node):
#     return reduce_mul()
    expr = 1
    for msg in filter(lambda m: m.to_node == from_node, from_node.messages):
        expr *= msg.expr
    return expr

# Children to root

In [351]:
clear_messages(nodes)

queue = deque([n for n in nodes if n.is_child and n != root])
visited = deque()
forward = True

while queue and visited != nodes:
    node = queue.popleft() if forward else queue.pop()
    for nbr in filter(lambda n: n not in visited, node.neighbors):
        msg = compute_message(from_node=node, to_node=nbr)
        nbr.messages.add(msg)
        if nbr not in queue:
            queue.append(nbr)
    visited.append(node)
    if forward and not queue:
        forward = False
        queue = visited
        visited = deque()

In [352]:
# EXPECTED_MESSAGES = pprint_messages(nodes)

In [353]:
assert pprint_messages(nodes) == EXPECTED_MESSAGES

In [349]:
x_2.marginal(2.3)

(f_a(0, 2.3) + f_a(1, 2.3) + f_a(2, 2.3) + f_a(3, 2.3) + f_a(4, 2.3) + f_a(5, 2.3) + f_a(6, 2.3) + f_a(7, 2.3) + f_a(8, 2.3) + f_a(9, 2.3))**2*(f_b(2.3, 0) + f_b(2.3, 1) + f_b(2.3, 2) + f_b(2.3, 3) + f_b(2.3, 4) + f_b(2.3, 5) + f_b(2.3, 6) + f_b(2.3, 7) + f_b(2.3, 8) + f_b(2.3, 9))*(f_c(2.3, 0) + f_c(2.3, 1) + f_c(2.3, 2) + f_c(2.3, 3) + f_c(2.3, 4) + f_c(2.3, 5) + f_c(2.3, 6) + f_c(2.3, 7) + f_c(2.3, 8) + f_c(2.3, 9))**2

In [122]:
a, b = var('a b')
c = Function('c')

In [123]:
f = c(a, b)

In [125]:
f.subs({a: 7})

c(7, b)

In [120]:
f

c(a, b)

In [121]:
f.evalf(subs={a: 7, b:9, c:lambda x, y: x * y})

SympifyError: Sympify of expression 'could not parse '<function <lambda> at 0x7f31d6550b70>'' failed, because of exception being raised:
SyntaxError: invalid syntax (<string>, line 1)

In [103]:
a = 7

In [106]:
f

c(7, b)

# Root to children

In [92]:
# implement BFS