In [292]:
from copy import copy

import sympy
from sympy import Symbol, Matrix, Function, diff

In [298]:
VALID_SYMBOLS = {
    'x': Symbol('x'),
    'alpha': Symbol('alpha'),
    'beta': Symbol('beta'),
    'mu': Symbol('mu'),
    'gamma': Symbol('gamma'),
}


class ExpFamDensity:
    
    N_MOMENTS = 2
    
    def __init__(self):
        self._test()
        assert self.symbols.issubset(VALID_SYMBOLS.values())
        self.wrt = self.x
        self.params = {s for s in self.symbols if s != self.x}
    
    @property
    def expr(self):
        expr, = self.phi.T * self.u
        return expr + self.fg
        
    @property
    def u(self):
        return self._u[self.wrt]
        
    @property
    def phi(self):
        return self._phi[self.wrt]
    
    @phi.setter
    def phi(self, value):
        self._phi[self.wrt] = value
        
    @property
    def fg(self):
        return self._fg[self.wrt]
    
    @property
    def moments(self):
        return [Symbol(f'{self.__class__.__name__}Moment_{k}') for k in range(1, self.N_MOMENTS + 1)]  # temporary! fill me in!
    
    def parameterize_with_respect_to(self, symbol):
        if symbol not in self.symbols:
            raise Exception(f'Must parameterize w.r.t. one of the following: {self.symbols}')
        self.wrt = symbol
        
    def copy(self):
        raise NotImplementedError
        
    def compose_subs_dict(self):
        return {p: getattr(self, f'_{p.name}') for p in self.params}
        
    def _test(self):
        exprs = []
        for s in self.symbols:
            self.parameterize_with_respect_to(s)
            exprs.append(self.expr)
        assert all([sympy.simplify(e - exprs[0]) == 0 for e in exprs])
    
    
class LogUnivariateGaussian(ExpFamDensity):
    
    x = VALID_SYMBOLS['x']
    mu = VALID_SYMBOLS['mu']
    gamma = VALID_SYMBOLS['gamma']
    symbols = {x, mu, gamma}
    
    _u = {
        x: Matrix([x, x**2]),
        mu: Matrix([mu, mu**2]),
        gamma: Matrix([gamma, sympy.log(gamma)])
    }
    _phi = {
        x: Matrix([gamma * mu, -.5 * gamma]),
        mu: Matrix([gamma * x, -.5 * gamma]),
        gamma: Matrix([(mu * x) - (.5 * x**2) - (.5 * mu**2), .5])
    }
    _fg = {
        x: .5 * (sympy.log(gamma) - gamma * mu**2 - sympy.log(2 * sympy.pi)),
        mu: .5 * (sympy.log(gamma) - gamma * x**2 - sympy.log(2 * sympy.pi)),
        gamma: -.5 * sympy.log(2 * sympy.pi)
    }
    
    def __init__(self, mu, gamma):
        super().__init__()
        self._mu = mu
        self._gamma = gamma
        
    def copy(self):
        return LogUnivariateGaussian(mu=self.mu, gamma=self.gamma)
        
    def __repr__(self):
        return f'LogUnivariateGaussian(mu={self.mu}, gamma={self.gamma})'
        
        
class LogGamma(ExpFamDensity):
    
    x = VALID_SYMBOLS['x']
    alpha = VALID_SYMBOLS['alpha']
    beta = VALID_SYMBOLS['beta']
    symbols = {x, alpha, beta}
    
    _u = {
        x: Matrix([x, sympy.log(x)]),
        alpha: Matrix([alpha, -sympy.log(sympy.gamma(alpha))]),
        beta: Matrix([beta, sympy.log(beta)])
    }
    _phi = {
        x: Matrix([-beta, alpha]),
        alpha: Matrix([sympy.log(x) + sympy.log(beta), 1]),
        beta: Matrix([-x, alpha])
    }
    _fg = {
        x: alpha * sympy.log(beta) - sympy.log(sympy.gamma(alpha)) + 1 / x,
        alpha: -beta * x + 1 / x,
        beta: alpha * sympy.log(x) - sympy.log(sympy.gamma(alpha)) + 1 / x
    }
    
    def __init__(self, alpha, beta):
        super().__init__()
        self._alpha = alpha
        self._beta = beta
        
    def copy(self):
        return LogGamma(alpha=self.alpha, beta=self.beta)
    
    def __repr__(self):
        return f'LogGamma(alpha={self.alpha}, beta={self.beta})'

In [299]:
class Node:
    
    def __init__(self, name: str, density: ExpFamDensity, observed: tuple = None, kind: str = 'node'):
        self.name = name
        self.density = density
        self.observed = observed
        self.kind = kind
        self.parents = {}
        self.children = {}
        
    def add_child(self, param, child):
        self.children[param] = child
        
    def add_parent(self, param, parent):
        if param not in self.density.params:
            raise Exception(f'Parent must be set for one of the following: {self.density.params}')
        if not isinstance(parent, Node):
            raise Exception(f'Parent must be of type `Node`')
        self.parents[param] = parent
        parent.add_child(param, self)
        
    def copy(self):
        return Node(name=self.name, density=self.density.copy(), observed=self.observed, kind='variational_node')
    
    @property
    def parent_to_param(self):
        return {v: k for k, v in self.parents.items()}
    
    def __repr__(self):
        return f'{self.density} (name: {self.name})'
        

class Graph:
    
    def __init__(self, nodes: set):
        self.nodes = {node.name: node for node in nodes}
#         self.Q = {node: node.copy() for node in nodes}
        self.parents = {node.name: node for node in nodes if not node.parents}
        
        
def compute_required_moments(expr):
    moments = set()
    for e in expr.values():
        for s in e.free_symbols:
            for arg in e.args:
                d = copy(arg)
                count = 0
                while d != 0:
                    d = diff(d, s)
                    count += 1
                if count > 1:
                    moments.add((s, count - 1))
    return moments


def compute_moment_substitutions(x, child_node, moments):
    subs = {}
    for sym, moment in moments:
        node = child_node if sym == x else child_node.parents[sym]
        if node.observed:
            subs[sym] = tuple(child_node.observed)
        else:
            subs[pow(sym, moment)] = node.density.moments[moment - 1]
    subs_ = []
    # If the child is observed, we must send a message w.r.t. each observed value to the parent
    if child_node.observed:
        vals = subs[x]
        for val in vals:
            subs = dict(subs)
            subs[x] = val
            subs_.append(subs)
    else:
        subs_ = [subs]
    return subs_

In [301]:
# Instantiate nodes
obs = Node(name='observed', density=LogUnivariateGaussian(0, 1), observed=(1, 2.5, -.5, .25))
gamma = Node(name='gamma', density=LogGamma(0, 1))
mu = Node(name='mu', density=LogUnivariateGaussian(0, 1))

# Add parents, children
obs.add_parent(param=obs.density.mu, parent=mu)
obs.add_parent(param=obs.density.gamma, parent=gamma)

# Build graph
g = Graph(nodes={obs, gamma, mu})

In [302]:
# Update mu

# NB: the observed variable in each random variable is denoted x

In [303]:
count = 0


while count < 10:
    for parent_name, parent_node in g.parents.items():
        
        # Initialize child-to-parent message
        parent_node_subs = parent_node.density.compose_subs_dict()
        parent_node_phi = parent_node.density.phi.subs(parent_node_subs)

        # Compute message counterparts
        for parent, child_node in g.nodes[parent_name].children.items():
            child_node.density.parameterize_with_respect_to(parent)
            co_parents = {s for s in child_node.density.symbols if s not in {x, parent}}
            x = child_node.density.x
            assert child_node.density.phi.free_symbols - {x} == co_parents

            # Compute the co-parent moments that we require, to be substituted into this child's natural parameter vector
            moments = compute_required_moments(child_node.density.phi)
            assert co_parents | {x} == {sym for sym, moment in moments}

            # Compute substitutions required for each individual child-to-parent message
            subs = compute_moment_substitutions(x, child_node, moments)

            # Compute sum message
            message = child_node.density.phi * 0
            for s in subs:
                message += child_node.density.phi.subs(s)
                
            # Set new message
            parent_node.density.phi = parent_node_phi + message

            count += 1
            
            print(f'count: {count} | parent_node: {parent_node} | child_node: {child_node}')
            print(f'new_phi: {parent_node_phi + message}\n')

count: 1 | parent_node: LogGamma(alpha=alpha, beta=beta) (name: gamma) | child_node: LogUnivariateGaussian(mu=mu, gamma=gamma) (name: observed)
new_phi: Matrix([[3.25*LogUnivariateGaussianMoment_1 - 2.0*LogUnivariateGaussianMoment_2 - 4.78125], [2.00000000000000]])

count: 2 | parent_node: LogUnivariateGaussian(mu=mu, gamma=gamma) (name: mu) | child_node: LogUnivariateGaussian(mu=mu, gamma=gamma) (name: observed)
new_phi: Matrix([[3.25*LogGammaMoment_1], [-2.0*LogGammaMoment_1 - 0.5]])

count: 3 | parent_node: LogGamma(alpha=alpha, beta=beta) (name: gamma) | child_node: LogUnivariateGaussian(mu=mu, gamma=gamma) (name: observed)
new_phi: Matrix([[6.5*LogUnivariateGaussianMoment_1 - 4.0*LogUnivariateGaussianMoment_2 - 8.5625], [4.00000000000000]])

count: 4 | parent_node: LogUnivariateGaussian(mu=mu, gamma=gamma) (name: mu) | child_node: LogUnivariateGaussian(mu=mu, gamma=gamma) (name: observed)
new_phi: Matrix([[6.5*LogGammaMoment_1], [-4.0*LogGammaMoment_1 - 0.5]])

count: 5 | parent_n