In [14]:
import sympy
from sympy import Symbol, Matrix

In [112]:
class Distribution:
    """Base class for a symbolic expression split as ``phi.T * u + fg``.

    Subclasses must provide:

    * ``symbols`` -- a sequence of objects with a ``.name`` (sympy Symbols in
      practice) that the expression can be written "with respect to";
    * ``u`` and ``phi`` -- k x 1 sympy Matrices, so ``phi.T * u`` is 1 x 1;
    * ``fg`` -- the remaining scalar term;

    where ``u``, ``phi`` and ``fg`` depend on the currently selected
    ``self.wrt``. Presumably this mirrors the exponential-family / VMP form
    ``ln p = phi^T u + f + g`` -- TODO confirm with the author. Whichever view
    is selected, ``expr`` is expected to be the same expression; ``test``
    verifies that.
    """

    @property
    def expr(self):
        """Scalar expression ``phi.T * u + fg`` for the current ``wrt`` view."""
        # phi.T (1 x k) times u (k x 1) is a 1 x 1 Matrix; unpack its single entry.
        scalar, = self.phi.T * self.u
        return scalar + self.fg

    @property
    def symbol_names(self):
        """Names of ``self.symbols``, in declaration order."""
        return [s.name for s in self.symbols]

    def set_with_respect_to(self, name):
        """Select which symbol the ``u``/``phi``/``fg`` views are written in.

        Raises:
            ValueError: if ``name`` is not one of ``symbol_names``. ValueError
                subclasses Exception, so ``except Exception`` still catches it.
        """
        if name not in self.symbol_names:
            raise ValueError(f'{name} not a valid parameter of this distribution')
        self.wrt = name

    def test(self):
        """Check that every ``wrt`` view yields the same expression.

        Each view is compared to the first one with ``sympy.simplify``. The
        caller's ``wrt`` selection (if any) is restored afterwards, so this
        check has no lasting effect on which view the instance exposes.

        Raises:
            AssertionError: if some view differs from the first one. Raised
                explicitly so the check is not stripped under ``python -O``.
        """
        names = self.symbol_names
        previous = getattr(self, 'wrt', None)
        exprs = []
        try:
            for name in names:
                self.set_with_respect_to(name)
                exprs.append(self.expr)
        finally:
            # Undo the temporary selections made by the loop above.
            if previous is not None:
                self.wrt = previous
        if not exprs:
            return
        reference = exprs[0]
        # The reference trivially matches itself, so only compare the rest.
        mismatched = [
            name
            for name, candidate in zip(names[1:], exprs[1:])
            if sympy.simplify(candidate - reference) != 0
        ]
        if mismatched:
            raise AssertionError(
                f'expressions w.r.t. {mismatched} differ from the one w.r.t. {names[0]!r}'
            )
    
    
class LogUnivariateGaussian(Distribution):
    """Log-density of a univariate Gaussian in ``x`` with mean ``mu`` and precision ``tau``.

    The log-density is written as ``phi.T * u + fg`` three ways -- with
    respect to ``x``, ``mu`` and ``tau`` -- and every view expands to

        tau*mu*x - tau*x**2/2 - tau*mu**2/2 + log(tau)/2 - log(2*pi)/2

    i.e. ``log(tau)/2 - log(2*pi)/2 - tau*(x - mu)**2/2``. The constructor
    calls ``self.test()`` to check that the views agree.

    Coefficients use the exact ``sympy.Rational(1, 2)`` instead of the Python
    float ``0.5`` so the symbolic expressions stay exact and simplify cleanly.
    """

    x = Symbol('x')
    mu = Symbol('mu')
    tau = Symbol('tau')
    symbols = [x, mu, tau]

    # Exact one-half; Python floats would turn these into inexact Float coefficients.
    _half = sympy.Rational(1, 2)

    # Sufficient statistics of the selected symbol.
    _u = {
        'x': Matrix([x, x**2]),
        'mu': Matrix([mu, mu**2]),
        'tau': Matrix([tau, sympy.log(tau)])
    }
    # Coefficients multiplying the corresponding ``_u`` entries.
    _phi = {
        'x': Matrix([tau * mu, -_half * tau]),
        'mu': Matrix([tau * x, -_half * tau]),
        'tau': Matrix([(mu * x) - (_half * x**2) - (_half * mu**2), _half])
    }
    # Remaining terms not captured by ``phi.T * u``.
    _fg = {
        'x': _half * (sympy.log(tau) - tau * mu**2 - sympy.log(2 * sympy.pi)),
        'mu': _half * (sympy.log(tau) - tau * x**2 - sympy.log(2 * sympy.pi)),
        'tau': -_half * sympy.log(2 * sympy.pi)
    }

    def __init__(self, mu, tau, wrt='x'):
        """Build the distribution and verify that its three views agree.

        Args:
            mu: value for the mean; stored in ``param_values`` (not substituted).
            tau: value for the precision; stored in ``param_values``.
            wrt: symbol name ('x', 'mu' or 'tau') the ``u``/``phi``/``fg`` views use.

        Raises:
            The exception raised by ``set_with_respect_to`` if ``wrt`` is not
            one of the symbol names.
        """
        # Previously these arguments were silently discarded. Keyed by Symbol so
        # callers can write ``self.expr.subs(self.param_values)``.
        self.param_values = {self.mu: mu, self.tau: tau}
        self.test()
        # Use the validating setter: a bad name used to surface only later, as a
        # KeyError inside the u/phi/fg properties.
        self.set_with_respect_to(wrt)

    @property
    def u(self):
        """Sufficient-statistic vector for the current ``wrt`` view."""
        return self._u[self.wrt]

    @property
    def phi(self):
        """Coefficient vector paired with ``u`` for the current ``wrt`` view."""
        return self._phi[self.wrt]

    @property
    def fg(self):
        """Remaining scalar term for the current ``wrt`` view."""
        return self._fg[self.wrt]

In [113]:
# Standard normal (mu=0, tau=1); the constructor also runs the consistency self-check.
lug = LogUnivariateGaussian(mu=0, tau=1, wrt='x')