In [1]:
import sympy
from sympy import Symbol, Matrix

In [189]:
class ExpFamDensity:
    
    def __init__(self):
        self._test()
        self.wrt = self.x
        self.params = {s for s in self.symbols if s != self.x}
    
    @property
    def expr(self):
        expr, = self.phi.T * self.u
        return expr + self.fg
    
    @property
    def symbol_names(self):
        return {s.name for s in self.symbols}
        
    @property
    def u(self):
        return self._u[self.wrt]
        
    @property
    def phi(self):
        return self._phi[self.wrt]
        
    @property
    def fg(self):
        return self._fg[self.wrt]
    
    def parameterize_with_respect_to(self, symbol):
        if symbol not in self.symbols:
            raise Exception(f'Must parameterize w.r.t. one of the following an object of type `sympy.Symbol`s: {symbol_names}')
        self.wrt = symbol
        
    def copy(self):
        raise NotImplementedError
        
    def _test(self):
        exprs = []
        for s in self.symbols:
            self.parameterize_with_respect_to(s)
            exprs.append(self.expr)
        assert all([sympy.simplify(e - exprs[0]) == 0 for e in exprs])
    
    
class LogUnivariateGaussian(ExpFamDensity):
    
    x = Symbol('x')
    mu = Symbol('mu')
    gamma = Symbol('gamma')
    symbols = [x, mu, gamma]
    
    _u = {
        x: Matrix([x, x**2]),
        mu: Matrix([mu, mu**2]),
        gamma: Matrix([gamma, sympy.log(gamma)])
    }
    _phi = {
        x: Matrix([gamma * mu, -.5 * gamma]),
        mu: Matrix([gamma * x, -.5 * gamma]),
        gamma: Matrix([(mu * x) - (.5 * x**2) - (.5 * mu**2), .5])
    }
    _fg = {
        x: .5 * (sympy.log(gamma) - gamma * mu**2 - sympy.log(2 * sympy.pi)),
        mu: .5 * (sympy.log(gamma) - gamma * x**2 - sympy.log(2 * sympy.pi)),
        gamma: -.5 * sympy.log(2 * sympy.pi)
    }
    
    def __init__(self, mu, gamma):
        super().__init__()
        self.mu = mu
        self.gamma = gamma
        
    def copy(self):
        return LogUnivariateGaussian(mu=self.mu, gamma=self.gamma)
    
    def __repr__(self):
        return f'LogUnivariateGaussian(mu={self.mu}, gamma={self.gamma})'
        
        
class LogGamma(ExpFamDensity):
    
    x = Symbol('x')
    alpha = Symbol('alpha')
    beta = Symbol('beta')
    symbols = [x, alpha, beta]
    
    _u = {
        x: Matrix([x, sympy.log(x)]),
        alpha: Matrix([alpha, -sympy.log(sympy.gamma(alpha))]),
        beta: Matrix([beta, sympy.log(beta)])
    }
    _phi = {
        x: Matrix([-beta, alpha]),
        alpha: Matrix([sympy.log(x) + sympy.log(beta), 1]),
        beta: Matrix([-x, alpha])
    }
    _fg = {
        x: alpha * sympy.log(beta) - sympy.log(sympy.gamma(alpha)) + 1 / x,
        alpha: -beta * x + 1 / x,
        beta: alpha * sympy.log(x) - sympy.log(sympy.gamma(alpha)) + 1 / x
    }
    
    def __init__(self, alpha, beta):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        
    def copy(self):
        return LogGamma(alpha=self.alpha, beta=self.beta)
    
    def __repr__(self):
        return f'LogGamma(alpha={self.alpha}, beta={self.beta})'

In [190]:
class Node:
    
    def __init__(self, density: ExpFamDensity, observed: tuple = None, kind: str = 'node'):
        self.density = density
        self.observed = observed
        self.parents = []
        self.children = []
        
    def add_parents(self, *parents):
        self.parents.extend(parents)
        
    def add_children(self, *children):
        self.children.extend(children)
        
    def copy(self):
        n = Node(density=self.density.copy(), kind='variational_node')
        n.add_parents(*self.parents)
        n.add_children(*self.children)
        return n
    
    def __repr__(self):
        return f"{self.name}: {self.density}"
        

class Graph:
    
    def __init__(self, nodes: dict):
        self.nodes = {name: node for name, node in nodes.items()}
        self.Q = {name: node.copy() for name, node in nodes.items()}

In [191]:
x.density.symbols

[x, mu, gamma]

In [192]:
# Instantiate nodes
x = Node(density=LogUnivariateGaussian(0, 1), observed=(1, 1.5, -.5, .25))
gamma = Node(density=LogGamma(0, 1))
mu = Node(density=LogUnivariateGaussian(0, 1))

# Add parents, children
x.add_parents(mu, gamma)
gamma.add_children(x)
mu.add_children(x)

# Build graph
g = Graph(nodes={'x': x, 'gamma': gamma, 'mu': mu})

In [193]:
# Update mu

In [194]:
g.nodes['mu'].children[0].density

LogUnivariateGaussian(mu=0, gamma=1)

In [195]:
g.nodes['mu'].density.params

{gamma, mu}

In [196]:
g.nodes['mu'].children[0].density.x

x

In [96]:
g.nodes['mu'].children[0].density.parameterize_with_respect_to('mu')

In [109]:
{s.name for s in g.nodes['mu'].children[0].density.phi.free_symbols}

{'gamma', 'mu'}