In [6]:
from collections import deque
from functools import reduce
from itertools import product
from math import log
import operator
import pickle

from sympy import Function, Symbol, log, Max

# Load correct messages

In [7]:
EXPECTED_MESSAGES = pickle.load(open('expected_messages.pkl', 'rb'))

In [8]:
class Node:
        
    def add_neighbors(self, nbrs):
        if not isinstance(nbrs, list):
            nbrs = [nbrs]
        assert all([isinstance(n, self.required_nbr_type) for n in nbrs])
        self.neighbors.extend(nbrs)
  
    def clear_messages(self):
        self.messages = set()
        self.max_messages = set()
    
    @property
    def is_child(self):
        return len(self.neighbors) == 1


class Variable(Node):
    
    def __init__(self, name, vals):
        self.name = name
        self.neighbors = []
        self.required_nbr_type = Factor
        self.symbol = Symbol(name)
        self.vals = vals
        self.messages = set()
        self.max_messages = set()
        
    def repr_messages(self):
        return {repr(m) for m in self.messages}
        
    def marginal(self, val):
        if len(self.messages) != len(self.neighbors):
            raise Expcetion(f'Variable {self.name} has not yet received all incoming messages')
        if val not in self.vals:
            raise Exception(f'{val} not in this variable\'s support: {self.vals}')
        Z = 0  # Compute normalization constant
        for v in self.vals:
            expr = 1
            for msg in self.messages:
                expr *= msg.expr.subs({self.symbol: v})
            Z += expr
            if v == val:  # Compute numerator
                p_tilde = expr
        return p_tilde / Z


class Factor(Node):
    
    def __init__(self, name):
        self.name = name
        self.neighbors = []
        self.required_nbr_type = Variable
        self.func = Function(name)
        self.messages = set()
        self.max_messages = set()
        
    def add_neighbors(self, nbrs):
        for nbr in nbrs:
            nbr.add_neighbors([self])
        super().add_neighbors(nbrs)
        
    @property
    def expr(self):
        return self.func(*[n.symbol for n in self.neighbors])
    

class Message:
    
    def __init__(self, expr, from_node, to_node):
        self.expr = expr
        self.from_node = from_node
        self.to_node = to_node
    
    def __repr__(self):
        return f'({self.expr}, {self.from_node.name}, {self.to_node.name})'
    
    
class MaxMessage(Message):
    pass


def clear_messages(nodes):
    return {n.clear_messages() for n in nodes}


def repr_messages(nodes):
    msgs = set()
    for n in nodes:
        msgs = msgs | {repr(m) for m in n.messages}
    return msgs

In [9]:
g = set()

var_vals = [v for v in range(10)]

# Variables
x_1 = Variable('x_1', var_vals)
x_2 = Variable('x_2', var_vals)
x_3 = Variable('x_3', var_vals)
x_4 = Variable('x_4', var_vals)

# Factors
f_a = Factor('f_a')
f_b = Factor('f_b')
f_c = Factor('f_c')

# Define root
root = x_3

# Define factors
f_c.add_neighbors([x_2, x_4])
f_a.add_neighbors([x_1, x_2])
f_b.add_neighbors([x_2, x_3])

# Add factors to graph
g.add(f_a)
g.add(f_b)
g.add(f_c)

# Enumerate all nodes
nodes = set()
for f in g:
    nodes = nodes | set([f] + f.neighbors)

In [52]:
def compute_message(from_node, to_node):
    if from_node.is_child:
        expr = from_node.expr if isinstance(from_node, Factor) else 1
        return Message(expr, from_node, to_node)
    expr = reduce_incoming_msgs(from_node)
    if isinstance(from_node, Factor):
        expr *= aggregate_over_nbrs(from_node.expr, from_node, to_node)
    return Message(expr, from_node, to_node)


def compute_max_message(from_node, to_node):
    if from_node.is_child:
        expr = from_node.expr if isinstance(from_node, Factor) else 1
        expr = log(expr)
        return Message(expr, from_node, to_node)
    expr = reduce_incoming_msgs(from_node, reduce_func=lambda x, y: x + y)
    if isinstance(from_node, Factor):
        expr += log(from_node.expr)
        expr = aggregate_over_nbrs(expr, from_node, to_node, agg_func=sympy_max)
    return MaxMessage(expr, from_node, to_node)


def aggregate_over_nbrs(expr, from_node, to_node, agg_func=sum):
    sum_over = [n for n in from_node.neighbors if n.name != to_node.name]
    exprs = []
    for vals in product(*[n.vals for n in sum_over]):
        subs = {var.symbol: val for var, val in zip(sum_over, vals)}
        exprs.append(from_node.expr.subs(subs))
    return agg_func(exprs)


def reduce_incoming_msgs(from_node, reduce_func=operator.mul):
    return reduce(reduce_func, (msg.expr for msg in from_node.messages))


def sympy_max(iterable):
    return Max(*iterable)

# Pass messages to from children to root, then root to children

In [53]:
clear_messages(nodes)

queue = deque([n for n in nodes if n.is_child and n != root])
visited = deque()
forward = True

while queue and visited != nodes:
    node = queue.popleft() if forward else queue.pop()
    for nbr in filter(lambda n: n not in visited, node.neighbors):
        msg = compute_message(from_node=node, to_node=nbr)
        nbr.messages.add(msg)
        if forward:
            max_msg = compute_max_message(from_node=node, to_node=nbr)
            nbr.max_messages.add(max_msg)
        else:  # Back-track for max-sum
            pass
        if nbr not in queue:
            queue.append(nbr)
    visited.append(node)
    if forward and not queue:
        forward = False
        queue = visited
        visited = deque()

In [54]:
assert EXPECTED_MESSAGES == repr_messages(nodes)

In [55]:
# list(zip(f_a.expr.subs(x_1.symbol, 0).args, f_a.expr.args))

In [56]:
x_3.max_messages

{(Max(f_b(0, x_3), f_b(1, x_3), f_b(2, x_3), f_b(3, x_3), f_b(4, x_3), f_b(5, x_3), f_b(6, x_3), f_b(7, x_3), f_b(8, x_3), f_b(9, x_3)), f_b, x_3)}

In [45]:
vals = {v: sum(m.expr for m in x_3.max_messages).subs(x_3.symbol, v) for v in x_3.vals}

In [59]:
vals[0].args

# there's a bunch of digging and possible duplicate computation to do, but you can get these states out

(f_b(0, 0),
 f_b(1, 0),
 f_b(2, 0),
 f_b(3, 0),
 f_b(4, 0),
 f_b(5, 0),
 f_b(6, 0),
 f_b(7, 0),
 f_b(8, 0),
 f_b(9, 0))

In [28]:
vals

{0: Max(f_b(0, 0), f_b(1, 0), f_b(2, 0), f_b(3, 0), f_b(4, 0), f_b(5, 0), f_b(6, 0), f_b(7, 0), f_b(8, 0), f_b(9, 0)),
 1: Max(f_b(0, 1), f_b(1, 1), f_b(2, 1), f_b(3, 1), f_b(4, 1), f_b(5, 1), f_b(6, 1), f_b(7, 1), f_b(8, 1), f_b(9, 1)),
 2: Max(f_b(0, 2), f_b(1, 2), f_b(2, 2), f_b(3, 2), f_b(4, 2), f_b(5, 2), f_b(6, 2), f_b(7, 2), f_b(8, 2), f_b(9, 2)),
 3: Max(f_b(0, 3), f_b(1, 3), f_b(2, 3), f_b(3, 3), f_b(4, 3), f_b(5, 3), f_b(6, 3), f_b(7, 3), f_b(8, 3), f_b(9, 3)),
 4: Max(f_b(0, 4), f_b(1, 4), f_b(2, 4), f_b(3, 4), f_b(4, 4), f_b(5, 4), f_b(6, 4), f_b(7, 4), f_b(8, 4), f_b(9, 4)),
 5: Max(f_b(0, 5), f_b(1, 5), f_b(2, 5), f_b(3, 5), f_b(4, 5), f_b(5, 5), f_b(6, 5), f_b(7, 5), f_b(8, 5), f_b(9, 5)),
 6: Max(f_b(0, 6), f_b(1, 6), f_b(2, 6), f_b(3, 6), f_b(4, 6), f_b(5, 6), f_b(6, 6), f_b(7, 6), f_b(8, 6), f_b(9, 6)),
 7: Max(f_b(0, 7), f_b(1, 7), f_b(2, 7), f_b(3, 7), f_b(4, 7), f_b(5, 7), f_b(6, 7), f_b(7, 7), f_b(8, 7), f_b(9, 7)),
 8: Max(f_b(0, 8), f_b(1, 8), f_b(2, 8), f_b(3, 

# Compute marginal

In [624]:
x_2.marginal(2)

(f_a(0, 2) + f_a(1, 2) + f_a(2, 2) + f_a(3, 2) + f_a(4, 2) + f_a(5, 2) + f_a(6, 2) + f_a(7, 2) + f_a(8, 2) + f_a(9, 2))**2*(f_b(2, 0) + f_b(2, 1) + f_b(2, 2) + f_b(2, 3) + f_b(2, 4) + f_b(2, 5) + f_b(2, 6) + f_b(2, 7) + f_b(2, 8) + f_b(2, 9))*(f_c(2, 0) + f_c(2, 1) + f_c(2, 2) + f_c(2, 3) + f_c(2, 4) + f_c(2, 5) + f_c(2, 6) + f_c(2, 7) + f_c(2, 8) + f_c(2, 9))**2/((f_a(0, 0) + f_a(1, 0) + f_a(2, 0) + f_a(3, 0) + f_a(4, 0) + f_a(5, 0) + f_a(6, 0) + f_a(7, 0) + f_a(8, 0) + f_a(9, 0))**2*(f_b(0, 0) + f_b(0, 1) + f_b(0, 2) + f_b(0, 3) + f_b(0, 4) + f_b(0, 5) + f_b(0, 6) + f_b(0, 7) + f_b(0, 8) + f_b(0, 9))*(f_c(0, 0) + f_c(0, 1) + f_c(0, 2) + f_c(0, 3) + f_c(0, 4) + f_c(0, 5) + f_c(0, 6) + f_c(0, 7) + f_c(0, 8) + f_c(0, 9))**2 + (f_a(0, 1) + f_a(1, 1) + f_a(2, 1) + f_a(3, 1) + f_a(4, 1) + f_a(5, 1) + f_a(6, 1) + f_a(7, 1) + f_a(8, 1) + f_a(9, 1))**2*(f_b(1, 0) + f_b(1, 1) + f_b(1, 2) + f_b(1, 3) + f_b(1, 4) + f_b(1, 5) + f_b(1, 6) + f_b(1, 7) + f_b(1, 8) + f_b(1, 9))*(f_c(1, 0) + f_c(1, 1)

## Substitute in factor funcs

In [625]:
x_2.marginal(2).replace(f_a.func, lambda x, y: x ** y)

81225*(f_b(2, 0) + f_b(2, 1) + f_b(2, 2) + f_b(2, 3) + f_b(2, 4) + f_b(2, 5) + f_b(2, 6) + f_b(2, 7) + f_b(2, 8) + f_b(2, 9))*(f_c(2, 0) + f_c(2, 1) + f_c(2, 2) + f_c(2, 3) + f_c(2, 4) + f_c(2, 5) + f_c(2, 6) + f_c(2, 7) + f_c(2, 8) + f_c(2, 9))**2/(100*(f_b(0, 0) + f_b(0, 1) + f_b(0, 2) + f_b(0, 3) + f_b(0, 4) + f_b(0, 5) + f_b(0, 6) + f_b(0, 7) + f_b(0, 8) + f_b(0, 9))*(f_c(0, 0) + f_c(0, 1) + f_c(0, 2) + f_c(0, 3) + f_c(0, 4) + f_c(0, 5) + f_c(0, 6) + f_c(0, 7) + f_c(0, 8) + f_c(0, 9))**2 + 2025*(f_b(1, 0) + f_b(1, 1) + f_b(1, 2) + f_b(1, 3) + f_b(1, 4) + f_b(1, 5) + f_b(1, 6) + f_b(1, 7) + f_b(1, 8) + f_b(1, 9))*(f_c(1, 0) + f_c(1, 1) + f_c(1, 2) + f_c(1, 3) + f_c(1, 4) + f_c(1, 5) + f_c(1, 6) + f_c(1, 7) + f_c(1, 8) + f_c(1, 9))**2 + 81225*(f_b(2, 0) + f_b(2, 1) + f_b(2, 2) + f_b(2, 3) + f_b(2, 4) + f_b(2, 5) + f_b(2, 6) + f_b(2, 7) + f_b(2, 8) + f_b(2, 9))*(f_c(2, 0) + f_c(2, 1) + f_c(2, 2) + f_c(2, 3) + f_c(2, 4) + f_c(2, 5) + f_c(2, 6) + f_c(2, 7) + f_c(2, 8) + f_c(2, 9))**2 + 