In [19]:
import sympy
from sympy import Symbol, Matrix, Function, diff

In [2]:
# Registry of every sympy symbol a density is allowed to use, keyed by the
# symbol's own name (assumption-free Symbols, in this fixed insertion order).
VALID_SYMBOLS = {
    symbol_name: Symbol(symbol_name)
    for symbol_name in ('x', 'alpha', 'beta', 'mu', 'gamma')
}


class ExpFamDensity:
    """Base class for a log-density written in exponential-family form.

    For each symbol ``s`` in ``self.symbols`` a subclass supplies three
    class-level dicts keyed by that symbol:

    * ``_u[s]``   -- column Matrix of statistics in ``s``,
    * ``_phi[s]`` -- column Matrix paired element-wise with ``_u[s]``,
    * ``_fg[s]``  -- the remaining additive terms,

    so that ``phi.T * u + fg`` is the log-density viewed as a function of ``s``.
    Subclasses must also define the class attributes ``x`` (the variable) and
    ``symbols`` (all symbols, drawn from ``VALID_SYMBOLS``).
    """

    # Number of entries returned by the (placeholder) `moments` property.
    N_MOMENTS = 2

    def __init__(self):
        # Every parameterization must describe the same expression.
        self._test()
        assert self.symbols.issubset(VALID_SYMBOLS.values())
        # Default to the parameterization in the variable itself.
        self.wrt = self.x
        self.params = {s for s in self.symbols if s != self.x}

    @property
    def expr(self):
        """The log-density as a sympy expression, built from the current `wrt` parameterization."""
        # phi.T * u is a 1x1 Matrix; unpack its single entry.
        expr, = self.phi.T * self.u
        return expr + self.fg

    @property
    def u(self):
        """Statistics vector for the current `wrt` symbol."""
        return self._u[self.wrt]

    @property
    def phi(self):
        """Coefficient vector paired with `u` for the current `wrt` symbol."""
        return self._phi[self.wrt]

    @property
    def fg(self):
        """Additive terms not captured by `phi.T * u` for the current `wrt` symbol."""
        return self._fg[self.wrt]

    @property
    def moments(self):
        """Placeholder: unevaluated ``Expectation(x**k)`` for k = 1..N_MOMENTS.

        TODO: replace with the density's actual moments.
        """
        e = Function('Expectation')
        return [e(pow(self.x, k)) for k in range(1, self.N_MOMENTS + 1)]

    def parameterize_with_respect_to(self, symbol):
        """Select which symbol `u`, `phi`, `fg` and `expr` are expressed in.

        Raises:
            ValueError: if `symbol` is not one of `self.symbols`. (A subclass of
                `Exception`, so existing ``except Exception`` handlers still work.)
        """
        if symbol not in self.symbols:
            raise ValueError(f'Must parameterize w.r.t. one of the following: {self.symbols}')
        self.wrt = symbol

    def copy(self):
        """Return an independent copy; must be overridden by subclasses."""
        raise NotImplementedError

    def _test(self):
        """Check that every parameterization yields the same log-density.

        Leaves `self.wrt` at the last symbol visited (`__init__` resets it).

        Raises:
            AssertionError: naming the first symbol whose expression differs
                from the first one computed.
        """
        exprs = []
        for s in self.symbols:
            self.parameterize_with_respect_to(s)
            exprs.append((s, self.expr))
        if not exprs:
            return
        ref_symbol, ref_expr = exprs[0]
        for s, e in exprs[1:]:
            assert sympy.simplify(e - ref_expr) == 0, (
                f'Parameterization w.r.t. {s} disagrees with w.r.t. {ref_symbol}: '
                f'{e} != {ref_expr}'
            )
    
    
class LogUnivariateGaussian(ExpFamDensity):
    """Log-density of a univariate Gaussian with mean ``mu`` and precision ``gamma``.

        log N(x | mu, 1/gamma) = 0.5*log(gamma) - 0.5*log(2*pi) - 0.5*gamma*(x - mu)**2

    The class attributes ``x``, ``mu`` and ``gamma`` are sympy symbols; the
    values passed to ``__init__`` are stored separately in ``_mu`` / ``_gamma``.
    """

    x = VALID_SYMBOLS['x']
    mu = VALID_SYMBOLS['mu']
    gamma = VALID_SYMBOLS['gamma']
    symbols = {x, mu, gamma}

    # Statistics vector for each parameterization.
    _u = {
        x: Matrix([x, x**2]),
        mu: Matrix([mu, mu**2]),
        gamma: Matrix([gamma, sympy.log(gamma)])
    }
    # Coefficients paired element-wise with `_u`.
    _phi = {
        x: Matrix([gamma * mu, -.5 * gamma]),
        mu: Matrix([gamma * x, -.5 * gamma]),
        gamma: Matrix([(mu * x) - (.5 * x**2) - (.5 * mu**2), .5])
    }
    # Remaining additive terms, so that phi.T * u + fg is the full log-density.
    _fg = {
        x: .5 * (sympy.log(gamma) - gamma * mu**2 - sympy.log(2 * sympy.pi)),
        mu: .5 * (sympy.log(gamma) - gamma * x**2 - sympy.log(2 * sympy.pi)),
        gamma: -.5 * sympy.log(2 * sympy.pi)
    }

    def __init__(self, mu, gamma):
        """Store the mean value `mu` and precision value `gamma` (not validated)."""
        super().__init__()
        # Private names: `self.mu` / `self.gamma` must keep resolving to the
        # class-level sympy symbols.
        self._mu = mu
        self._gamma = gamma

    def copy(self):
        """Return a new density with the same stored parameter values."""
        # `self.mu` / `self.gamma` are the class-level Symbols; use the stored values.
        return LogUnivariateGaussian(mu=self._mu, gamma=self._gamma)

    def __repr__(self):
        return f'LogUnivariateGaussian(mu={self._mu}, gamma={self._gamma})'
        
        
class LogGamma(ExpFamDensity):
    """Log-density of a Gamma distribution with shape ``alpha`` and rate ``beta``.

        log Gamma(x | alpha, beta) = alpha*log(beta) - log(Gamma(alpha))
                                     + (alpha - 1)*log(x) - beta*x

    The class attributes ``x``, ``alpha`` and ``beta`` are sympy symbols; the
    values passed to ``__init__`` are stored separately in ``_alpha`` / ``_beta``.
    """

    x = VALID_SYMBOLS['x']
    alpha = VALID_SYMBOLS['alpha']
    beta = VALID_SYMBOLS['beta']
    symbols = {x, alpha, beta}

    # Statistics vector for each parameterization.
    _u = {
        x: Matrix([x, sympy.log(x)]),
        alpha: Matrix([alpha, -sympy.log(sympy.gamma(alpha))]),
        beta: Matrix([beta, sympy.log(beta)])
    }
    # Coefficients paired element-wise with `_u`. Note that log(x) is paired
    # with `alpha` (not `alpha - 1`), so the leftover -log(x) lives in `_fg`.
    _phi = {
        x: Matrix([-beta, alpha]),
        alpha: Matrix([sympy.log(x) + sympy.log(beta), 1]),
        beta: Matrix([-x, alpha])
    }
    # Remaining additive terms. The `- log(x)` term is the `- 1` of the
    # `(alpha - 1)*log(x)` factor; it was previously written as `+ 1/x`,
    # which is not part of the Gamma log-density.
    _fg = {
        x: alpha * sympy.log(beta) - sympy.log(sympy.gamma(alpha)) - sympy.log(x),
        alpha: -beta * x - sympy.log(x),
        beta: alpha * sympy.log(x) - sympy.log(sympy.gamma(alpha)) - sympy.log(x)
    }

    def __init__(self, alpha, beta):
        """Store the shape value `alpha` and rate value `beta` (not validated)."""
        super().__init__()
        # Private names: `self.alpha` / `self.beta` must keep resolving to the
        # class-level sympy symbols.
        self._alpha = alpha
        self._beta = beta

    def copy(self):
        """Return a new density with the same stored parameter values."""
        # `self.alpha` / `self.beta` are the class-level Symbols; use the stored values.
        return LogGamma(alpha=self._alpha, beta=self._beta)

    def __repr__(self):
        return f'LogGamma(alpha={self._alpha}, beta={self._beta})'

In [3]:
class Node:
    """A named random variable in the graphical model.

    Args:
        name: Identifier; `Graph` keys nodes by it.
        density: This node's density; its `params` are the valid parent slots.
        observed: Observed values, or None when the node is not observed.
        kind: Free-form tag for the node's role (e.g. 'node', 'variational_node').
    """

    def __init__(self, name: str, density: ExpFamDensity, observed: tuple = None, kind: str = 'node'):
        self.name = name
        self.density = density
        self.observed = observed
        # Previously accepted but silently discarded.
        self.kind = kind
        # param symbol -> parent Node supplying that parameter of `density`.
        self.parents = {}
        # param symbol -> child Node whose `param` this node supplies.
        # NOTE(review): keyed by param, so a node that is the same parameter of
        # several children keeps only the last child registered.
        self.children = {}

    def add_child(self, param, child):
        """Record that `child` depends on this node through `param` (no validation)."""
        self.children[param] = child

    def add_parent(self, param, parent):
        """Make `parent` supply `param` of this node's density and register the back-link.

        Raises:
            ValueError: if `param` is not one of `self.density.params`.
            TypeError: if `parent` is not a `Node`.
            (Both subclass `Exception`, so existing broad handlers still catch them.)
        """
        if param not in self.density.params:
            raise ValueError(f'Parent must be set for one of the following: {self.density.params}')
        if not isinstance(parent, Node):
            raise TypeError('Parent must be of type `Node`')
        self.parents[param] = parent
        parent.add_child(param, self)

    def copy(self):
        """Return a 'variational_node' copy with a copied density and the same parents.

        The copy references the same parent nodes, but those parents are NOT
        re-linked to it: going through `add_parent` would overwrite each
        parent's `children[param]` entry and detach this (original) node.
        NOTE(review): `observed` is not carried over to the copy -- confirm intended.
        """
        n = Node(name=self.name, density=self.density.copy(), kind='variational_node')
        # Parents were already validated when added to `self`.
        n.parents = dict(self.parents)
        return n

    @property
    def parent_to_param(self):
        """Inverse of `parents`: parent Node -> param symbol."""
        return {v: k for k, v in self.parents.items()}

    def __repr__(self):
        return f'{self.density} (name: {self.name})'
        

class Graph:
    """A collection of model nodes plus a copy of each (via `node.copy()`).

    Attributes:
        nodes: name -> original node.
        Q: name -> the result of `copy()` on that node.

    Raises:
        ValueError: if two distinct nodes share a name (previously one was
            silently dropped).
    """

    def __init__(self, nodes: set):
        # Materialize once: the input is walked twice below, and a one-shot
        # iterator would otherwise leave `Q` empty.
        nodes = list(nodes)
        self.nodes = {}
        for node in nodes:
            existing = self.nodes.get(node.name)
            if existing is not None and existing is not node:
                raise ValueError(f'Duplicate node name: {node.name!r}')
            self.nodes[node.name] = node
        self.Q = {name: node.copy() for name, node in self.nodes.items()}

In [4]:
# Instantiate nodes
# `x` holds the observed data; its density's `mu` and `gamma` parameters are
# supplied by the `mu` and `gamma` nodes created below.
x = Node(name='observed', density=LogUnivariateGaussian(0, 1), observed=(1, 1.5, -.5, .25))
# NOTE(review): LogGamma(alpha=0, beta=1) uses a Gamma shape of 0, outside the
# usual alpha > 0 domain (log(gamma(0)) is not finite) -- confirm the intended prior.
gamma = Node(name='gamma', density=LogGamma(0, 1))
mu = Node(name='mu', density=LogUnivariateGaussian(0, 1))

# Add parents, children
# `x.density.mu` / `x.density.gamma` are the class-level sympy symbols, used
# here as the parameter keys linking each parent to `x`.
x.add_parent(param=x.density.mu, parent=mu)
x.add_parent(param=x.density.gamma, parent=gamma)

# Build graph
# Graph also stores a `copy()` of every node in `g.Q`.
g = Graph(nodes={x, gamma, mu})

In [5]:
# Update mu: walk mu's children and their co-parents (next cell)

In [20]:
# For each child of `mu`: `parent` is the parameter symbol under which `mu`
# feeds that child (the `children` dict is keyed by parameter), and
# `co_parents` are the child's remaining parameters.
for parent, child_node in g.nodes['mu'].children.items():
    child_node.density.parameterize_with_respect_to(parent)
    co_parents = {s for s in child_node.density.symbols if s not in {child_node.density.x, parent}}
    # Expressed w.r.t. this parameter, phi should depend only on x and the co-parents.
    assert child_node.density.phi.free_symbols - {child_node.density.x} == co_parents
    # NOTE(review): hard-coded switch to the gamma parameterization, overriding
    # the line above -- confirm this is intended rather than `parent`.
    child_node.density.parameterize_with_respect_to(child_node.density.gamma)
    # moments: {non-zero phi entry: [(symbol, power), ...]} listing each positive
    # power of each symbol found in that entry's additive terms.
    moments = {}
    for expr in child_node.density.phi.values():
        moments[expr] = []
        for s in expr.free_symbols:
            # Add.make_args yields the additive terms even when `expr` is a single
            # product or power (whose `.args` would be factors / base+exponent).
            for term in sympy.Add.make_args(expr):
                # sympy.degree raises PolynomialError for a term that is not
                # polynomial in `s` (e.g. log(s)); differentiating until zero
                # never terminated on such terms.
                power = sympy.degree(term, gen=s)
                if power > 0:
                    moments[expr].append((s, power))
    # Look up each co-parent's Node; the last one stays bound to `cp_node`.
    for cp in co_parents:
        cp_node = child_node.parents[cp]

In [21]:
# Table built by the loop cell above for the last child visited:
# {non-zero phi entry: [(symbol, power), ...]}.
moments

{-0.5*mu**2 + mu*x - 0.5*x**2: [(x, 2), (x, 1), (mu, 2), (mu, 1)],
 0.500000000000000: []}

In [28]:
# `cp_node` is the Node of the last co-parent visited in the loop cell above
# (set iteration order, so which co-parent is not fixed). Its density's
# `moments` property is still a placeholder returning unevaluated
# `Expectation(x**k)` terms.
cp_node.density.moments

[Expectation(x), Expectation(x**2)]

In [8]:
# phi of the child density in its current parameterization.
# NOTE(review): the saved output below (gamma*x, -0.5*gamma) is phi for the `mu`
# parameterization and predates the loop cell (execution count 8 vs 20). After
# that loop the density is parameterized w.r.t. gamma, so Restart & Run All
# will display a different matrix here.
child_node.density.phi

Matrix([
[   gamma*x],
[-0.5*gamma]])