In [6]:
from collections import deque, defaultdict
from functools import reduce
from itertools import product, chain
from math import sin
import operator
import pickle

import sympy
from sympy import Function, Symbol, Max

# Load the expected (reference) messages

In [7]:
# Reference set of message reprs used to validate the sum-product pass below.
# NOTE: pickle.load can execute arbitrary code -- only load this file from a trusted source.
with open('expected_messages.pkl', 'rb') as expected_file:  # closes the handle deterministically
    EXPECTED_MESSAGES = pickle.load(expected_file)

# Define required objects

In [41]:
class Node:
    """Base class for factor-graph nodes (variables and factors).

    Node defines no ``__init__``; subclasses set ``neighbors``,
    ``required_nbr_type``, ``messages`` and ``max_messages`` themselves.
    """

    def add_neighbors(self, nbrs):
        """Append a single node, or a list of nodes, to ``self.neighbors``.

        Each new neighbor must be an instance of ``self.required_nbr_type``;
        this is enforced with ``assert``.
        """
        new_nbrs = nbrs if isinstance(nbrs, list) else [nbrs]
        assert all(isinstance(nbr, self.required_nbr_type) for nbr in new_nbrs)
        self.neighbors.extend(new_nbrs)

    def clear_messages(self):
        """Forget every received sum-product and max-sum message."""
        self.max_messages = set()
        self.messages = set()

    @property
    def is_child(self):
        """True when the node has exactly one neighbor, i.e. it is a leaf."""
        return len(self.neighbors) == 1


class Variable(Node):
    """A discrete variable node in the factor graph.

    Args:
        name: Variable name; also the name of its sympy ``Symbol``.
        vals: Values the variable can take (its support).
    """

    def __init__(self, name, vals):
        self.name, self.vals = name, vals
        self.symbol = Symbol(name)
        # Variables may only be connected to factors.
        self.required_nbr_type = Factor
        self.neighbors = []
        self.messages, self.max_messages = set(), set()

    def repr_messages(self):
        """Return the set of ``repr`` strings of the received sum-product messages."""
        return set(map(repr, self.messages))

    def marginal(self, val):
        """Return the marginal probability of ``val`` as a sympy expression.

        The product of all incoming sum-product messages is evaluated at every
        value of the support. The product for ``val`` is then divided by the sum
        of all of them. The result may still contain unevaluated factor functions,
        so evaluate it with the ``Graph`` (e.g. ``g(x.marginal(v))``).

        Raises:
            Exception: if not every neighbor has sent a message yet, or if
                ``val`` is outside ``self.vals``.
        """
        if len(self.messages) != len(self.neighbors):
            raise Exception(f'Variable {self.name} has not yet received all incoming messages')
        if val not in self.vals:
            raise Exception(f'{val} not in this variable\'s support: {self.vals}')
        weights = []
        for candidate in self.vals:
            weight = 1
            for msg in self.messages:
                weight *= msg.expr.subs({self.symbol: candidate})
            weights.append(weight)
            if candidate == val:
                p_tilde = weight  # unnormalized probability of `val`
        return p_tilde / sum(weights)


class Factor(Node):
    """A factor (potential) node connecting one or more variables.

    Its symbolic form is an undefined sympy ``Function`` named ``name`` applied
    to the neighbors' symbols, in the order the neighbors were added.
    """

    def __init__(self, name):
        self.name = name
        self.neighbors = []
        # Factors may only be connected to variables.
        self.required_nbr_type = Variable
        self.func = Function(name)
        self.messages = set()
        self.max_messages = set()

    def add_neighbors(self, nbrs):
        """Connect this factor to one variable or a list of variables.

        The link is made in both directions: each variable also gets this
        factor appended to its own ``neighbors``.
        """
        # Accept a single neighbor like Node.add_neighbors does; iterating a bare
        # Variable would otherwise raise TypeError.
        if not isinstance(nbrs, list):
            nbrs = [nbrs]
        # Run the type check (Node.add_neighbors) before touching the variables, so
        # a bad argument cannot leave some variables linked back to this factor.
        super().add_neighbors(nbrs)
        for nbr in nbrs:
            nbr.add_neighbors([self])

    @property
    def expr(self):
        """The factor applied to its neighbors' symbols, e.g. ``f_a(x_1, x_2)``."""
        return self.func(*[n.symbol for n in self.neighbors])
    

class Message:
    """A message sent along one edge of the factor graph.

    Holds the message expression (a sympy expression or a plain constant)
    together with its sending and receiving nodes.
    """

    def __init__(self, expr, from_node, to_node):
        self.expr, self.from_node, self.to_node = expr, from_node, to_node

    def __repr__(self):
        # Format: (expression, sender name, receiver name)
        return f'({self.expr}, {self.from_node.name}, {self.to_node.name})'
    
    
class MaxMessage(Message):
    """Marker subclass of ``Message`` for max-sum messages; adds no behavior."""


class Graph:
    """Registry of factors and the callables used to evaluate them numerically."""

    def __init__(self):
        # Maps each Factor to a metadata dict of the form {'lmbda_expr': callable or None}.
        self.factors = {}

    def add(self, fac, lmbda_expr=None):
        """Register ``fac``, optionally with the callable that implements it."""
        self.factors[fac] = {'lmbda_expr': lmbda_expr}

    def __call__(self, expr, subs=None):
        """Substitute each registered factor's callable into ``expr``, then apply ``subs``.

        Factors registered without a callable are left symbolic.
        """
        result = expr
        for fac, metadata in self.factors.items():
            impl = metadata.get('lmbda_expr')
            if impl:
                result = result.replace(fac.func, impl)
        return result.subs(subs or {})

    @property
    def factor_args(self):
        """Map each factor's sympy function to the symbols it is applied to."""
        return dict((fac.func, fac.expr.args) for fac in self.factors)

    @property
    def nodes(self):
        """Set of all factors plus every variable attached to one of them."""
        # Left fold of set unions, starting from an empty set.
        return reduce(operator.or_, (set([fac] + fac.neighbors) for fac in self.factors), set())

    def joint(self, subs=None):
        """Product of every factor, evaluated through this graph with optional ``subs``."""
        product_expr = reduce(operator.mul, (fac.expr for fac in self.factors))
        return self(product_expr, subs)


def clear_messages(nodes):
    """Reset the sum-product and max-sum messages of every node in ``nodes``.

    Called only for its side effects; returns None.
    """
    # A plain loop: a comprehension here would build a throwaway set of Nones.
    for node in nodes:
        node.clear_messages()


def repr_messages(nodes):
    """Return the set of ``repr`` strings of all sum-product messages held by ``nodes``."""
    return {repr(msg) for node in nodes for msg in node.messages}

# Instantiate graph, define variables, add factors

In [42]:
g = Graph()

# Every variable shares the same support: the integers 0..9.
var_vals = list(range(10))

# Variables
x_1, x_2, x_3, x_4 = (
    Variable('x_1', var_vals),
    Variable('x_2', var_vals),
    Variable('x_3', var_vals),
    Variable('x_4', var_vals),
)

# Factors
f_a, f_b, f_c = Factor('f_a'), Factor('f_b'), Factor('f_c')

# Root of the schedule (excluded from the initial leaf queue)
root = x_3

# Wire up the graph. Neighbor order fixes each factor's argument order
# (e.g. f_c(x_2, x_4)). Because f_c is linked first, x_2's neighbors are [f_c, f_a, f_b].
f_c.add_neighbors([x_2, x_4])
f_a.add_neighbors([x_1, x_2])
f_b.add_neighbors([x_2, x_3])

# Numeric implementations that Graph.__call__ substitutes for each factor.
# f_c ignores its first argument (x_2), and exp(log(z)) simplifies to z.
g.add(f_a, lambda x, y: sympy.exp(sympy.tanh(x + y / 10)))
g.add(f_b, lambda x, y: sympy.exp(sympy.sin(x - y / 10)))
g.add(f_c, lambda x, y: sympy.exp(sympy.log(sympy.log(y + 1) + 1)))

# Define message passing helpers

In [43]:
def compute_message(from_node, to_node):
    """Build the sum-product message from ``from_node`` to ``to_node``.

    A leaf sends its own factor expression (factor leaf) or 1 (variable leaf).
    Any other node multiplies all of its received messages together. A factor
    additionally multiplies by its expression summed over every neighbor except
    ``to_node``.

    NOTE(review): the received messages (including any already received from
    ``to_node``) are multiplied *outside* the sum. Textbook sum-product excludes
    ``to_node``'s message and sums the product. Changing this alters the message
    reprs, so check it against expected_messages.pkl before changing it.
    """
    if not from_node.is_child:
        expr = reduce_incoming_msgs(from_node)
        if isinstance(from_node, Factor):
            expr *= aggregate_over_nbrs(from_node.expr, from_node, to_node)
        return Message(expr, from_node, to_node)
    leaf_expr = from_node.expr if isinstance(from_node, Factor) else 1
    return Message(leaf_expr, from_node, to_node)


def compute_max_message(from_node, to_node):
    """Build the max-sum (log-domain) message from ``from_node`` to ``to_node``.

    A leaf sends ``log`` of its factor expression (factor leaf) or ``log(1) = 0``
    (variable leaf). Any other node adds up its received *max-sum* messages. A
    factor then adds its log-expression and applies ``aggregate_over_nbrs`` with
    ``sympy_max`` over every neighbor except ``to_node``.

    Returns:
        MaxMessage: in every case, including leaves.
    """
    if from_node.is_child:
        expr = from_node.expr if isinstance(from_node, Factor) else 1
        # Leaves also return a MaxMessage, so every entry in `max_messages` has one type.
        return MaxMessage(sympy.log(expr), from_node, to_node)
    # Max-sum combines the incoming max-sum messages. `from_node.messages` holds
    # the sum-product messages, which belong to the other algorithm.
    expr = reduce(operator.add, (msg.expr for msg in from_node.max_messages))
    if isinstance(from_node, Factor):
        expr += sympy.log(from_node.expr)
        # NOTE(review): aggregate_over_nbrs currently ignores its `expr` argument and
        # maximizes over `from_node.expr` alone, so this log-sum never reaches the
        # returned message. back_track relies on the Max terms being bare factor
        # applications, so the two must be fixed together.
        expr = aggregate_over_nbrs(expr, from_node, to_node, agg_func=sympy_max)
    return MaxMessage(expr, from_node, to_node)


def aggregate_over_nbrs(expr, from_node, to_node, agg_func=sum):
    """Aggregate ``from_node.expr`` over all neighbors except ``to_node``.

    Every joint assignment of the other neighbors' values (taken from their
    ``vals``) is substituted into ``from_node.expr``. The resulting list is
    reduced with ``agg_func``: ``sum`` for sum-product, ``sympy_max`` for max-sum.

    NOTE(review): the ``expr`` argument is currently unused; substitution always
    uses ``from_node.expr``. compute_max_message passes a log-sum here that is
    therefore dropped, and back_track depends on that behavior.
    """
    others = [nbr for nbr in from_node.neighbors if nbr.name != to_node.name]
    assignments = product(*(nbr.vals for nbr in others))
    terms = [
        from_node.expr.subs(dict(zip((nbr.symbol for nbr in others), values)))
        for values in assignments
    ]
    return agg_func(terms)


def reduce_incoming_msgs(from_node, reduce_func=operator.mul):
    """Fold ``reduce_func`` over the expressions of ``from_node.messages``.

    Raises TypeError (from ``reduce``) when no messages have been received.
    """
    exprs = [msg.expr for msg in from_node.messages]
    return reduce(reduce_func, exprs)


def sympy_max(iterable):
    """Apply ``sympy.Max`` to all items of ``iterable``; used as a max-sum ``agg_func``."""
    args = tuple(iterable)
    return Max(*args)


def get_argmaxes(expr):
    """Return, for each ``Max`` found in ``expr``, an argument that attains it.

    Each ``Max`` node is evaluated with the notebook-global graph ``g``. The last
    of its arguments whose evaluated value equals the evaluated maximum is kept.
    The kept arguments are returned as a tuple.

    NOTE(review): depends on the global ``g``. If ``expr.args`` mixes ``Max`` and
    non-``Max`` terms, the recursion calls ``.args`` on a plain tuple and fails.
    If no argument matches a maximum exactly, the final unpacking raises
    ValueError.
    """
    def _get_argmaxes(expr):
        # Normalize a single Max into a 1-tuple so the branch below handles both cases.
        if isinstance(expr, Max):
            expr = tuple([expr])
        if isinstance(expr, tuple) and all(isinstance(e, Max) for e in expr):
            for e in expr:
                max_val = g(e)
                yield [term for term in e.args if g(term) == max_val][-1:]  # "randomly" choose the last one
        else:
            # Otherwise descend into the sub-expressions.
            yield from _get_argmaxes(expr.args)
    # Each yielded list holds at most one term; when all hold exactly one,
    # zip(*...) produces a single tuple of those terms.
    argmaxes, = zip(*_get_argmaxes(expr))
    return argmaxes

# Message passing

- Pass messages from the children (leaf nodes) to the root, then from the root back to the children, via the sum-product algorithm.
- Additionally, compute max-messages via the max-sum algorithm, and recover the maximizing configuration of the variables via backtracking.

In [44]:
clear_messages(g.nodes)

# The schedule starts from every leaf except the root.
queue = deque(filter(lambda n: n.is_child and n != root, g.nodes))
visited = deque()
forward = True
# Maps a variable's symbol to the maximizing value(s) recorded for it by back-tracking.
max_states = defaultdict(list)


def back_track(node):
    """Extend ``max_states`` with maximizing values of ``node``'s neighbors.

    If ``node`` is the root, first record every root value that maximizes the
    evaluated sum of its max-sum messages. Then, for each value recorded for
    ``node``, ``get_argmaxes`` picks the factor applications attaining the
    ``Max`` terms of the max-message at that value. For each one, the value of
    its other variable is appended to ``max_states``.

    Mutates the notebook-global ``max_states``; reads the globals ``g`` and ``root``.
    """
    # Score of each value of `node`: the sum of its incoming max-sum messages.
    max_msg = sum(m.expr for m in node.max_messages)
    max_exprs = {v: max_msg.subs(node.symbol, v) for v in node.vals}
    maxes = {v: g(expr) for v, expr in max_exprs.items()}
    if node == root:
        max_val = max(maxes.values())
        # Keep every tied maximizer of the root.
        max_states[node.symbol] += [v for v, mx in maxes.items() if mx == max_val]
    for v in max_states[node.symbol]:
        argmaxes = get_argmaxes(max_exprs[v])
        for fac in argmaxes:
            # `fac` is expected to be a factor application such as f(a, b); keep the
            # (symbol, value) pair for the argument that is not `node`. The
            # single-element unpack assumes each factor has exactly one other
            # variable -- TODO confirm before adding higher-arity factors.
            setting, = [(s, v) for s, v in zip(g.factor_args[fac.func], fac.args) if s != node.symbol]
            # NOTE(review): this rebinds the outer loop variable `v`. It is harmless here
            # because `v` is not read again before the next outer iteration.
            phi_n, v = setting
            max_states[phi_n].append(v)


# Pass messages forward and back; run max-sum in parallel
# Forward pass: FIFO from the leaves, sending both sum-product and max-sum messages.
# Backward pass: the forward visit order is reused as the queue and popped from the
# right, so nodes are processed in reverse order. Only sum-product messages are sent,
# and max-sum argmaxes are back-tracked at each variable.
# NOTE(review): `visited` is a deque and `g.nodes` is a set, so `visited != g.nodes`
# is always True; the loop effectively runs `while queue`.
while queue and visited != g.nodes:
    node = queue.popleft() if forward else queue.pop()
    if not forward and isinstance(node, Variable) and node.max_messages:  # Back-track for max-sum
        back_track(node)
    # Send to every neighbor not yet processed in the current pass.
    for nbr in filter(lambda n: n not in visited, node.neighbors):
        msg = compute_message(from_node=node, to_node=nbr)
        nbr.messages.add(msg)
        if forward:
            # Max-sum messages are only computed on the forward pass.
            max_msg = compute_max_message(from_node=node, to_node=nbr)
            nbr.max_messages.add(max_msg)
        if nbr not in queue:
            queue.append(nbr)
    visited.append(node)
    if forward and not queue:
        # Forward pass done: reuse its visit order (popped from the right) for the backward pass.
        forward = False
        queue = visited
        visited = deque()

In [45]:
# The sum-product messages must match the reference set loaded from expected_messages.pkl.
assert repr_messages(g.nodes) == EXPECTED_MESSAGES

# Print maximizing configurations

In [46]:
# The i-th maximizing configuration pairs the i-th recorded value of every variable.
# NOTE(review): zip() silently truncates if variables recorded different numbers of values.
configs = [dict(zip(max_states.keys(), states)) for states in zip(*max_states.values())]

for c in configs:
    max_val = float(g.joint(subs=c))  # joint (product of all factors) at this configuration
    print(f'Config: {c} | Max val: {max_val}')

Config: {x_3: 4, x_2: 2, x_1: 9, x_4: 9} | Max val: 24.392582884438887


# Compute marginals

In [47]:
# Marginal distribution of x_2 over its support; the probabilities should sum to one.
total = 0
for v in var_vals:
    m = float(g(x_2.marginal(v)))
    total += m
    print(f'{v}: {m:1.3}')

# Compare with a tolerance: each marginal is rounded to a float on its own,
# so their sum need not be exactly 1.0.
assert abs(total - 1) < 1e-9, f'marginals sum to {total}, not 1'

0: 0.0442
1: 0.112
2: 0.175
3: 0.119
4: 0.0489
5: 0.0273
6: 0.0383
7: 0.0961
8: 0.183
9: 0.156
