In [1]:
import ast
import inspect
import sys
import textwrap

import numpy
import pystan

def parse_model(python_function):
    """Translate a model written in Python syntax into Stan program text.

    The model is the body of a Python function, for example::

        def coin_model():
            'data'
            x: int(lower=0, upper=1)[10]
            'parameters'
            theta: real(lower=0, upper=1)
            'model'
            theta =~ uniform(0, 1)
            for i in range(1, 11):
                x[i] =~ bernoulli(theta)

    Translation rules:

    * A bare string statement (``'data'``, ``'model'``, ...) opens a Stan
      block with that name and closes the previous one. A docstring is such
      a statement, so model functions must not have one.
    * ``name: type(k=v, ...)[dims]`` becomes ``type<k=v, ...> name[dims];``.
    * ``lhs =~ dist(...)`` -- which Python parses as ``lhs = ~dist(...)`` --
      becomes the sampling statement ``lhs ~ dist(...);``; a plain
      ``lhs = rhs`` becomes ``lhs = rhs;``.
    * ``range(a, b)`` becomes the Stan range ``a:(b-1)`` (the same values).
    * ``for`` loops with several statements get a ``{ ... }`` body.
    * Binary ``+ - * / **`` become parenthesised Stan ``+ - * / ^``; unary
      ``-``/``+`` are supported, e.g. ``real(lower=-1)``.
    * Any other syntax is rendered as ``//`` comments describing its AST,
      so untranslated constructs are visible in the output.

    Args:
        python_function: The model function, or its source code as a
            string. The function body is only parsed, never executed.

    Returns:
        str: The Stan program, without a trailing newline.

    Raises:
        AssertionError: On constructs the translator rejects outright
            (annotated declarations with a value, chained assignments, a
            module containing more than one top-level statement).
        ValueError: If ``range`` is not called with exactly two arguments.
    """

    def _string_statement(stmt):
        """Return the text of a bare string-literal statement, else None."""
        if not isinstance(stmt, ast.Expr):
            return None
        value = stmt.value
        if isinstance(value, ast.Constant) and isinstance(value.value, str):
            return value.value
        if type(value).__name__ == 'Str':  # parsers before Python 3.8
            return value.s
        return None

    class PythonAstVisitor(ast.NodeVisitor):
        """Render a Python AST as Stan code; every visit_* returns a string."""

        # Python operator node type -> Stan operator symbol.
        BINARY_OPS = {ast.Add: '+', ast.Sub: '-', ast.Mult: '*',
                      ast.Div: '/', ast.Pow: '^'}
        UNARY_OPS = {ast.USub: '-', ast.UAdd: '+'}

        def __init__(self):
            self.indent = 0    # statement nesting depth, two spaces per level
            self.label = None  # field name generic_visit shows for a child node

        def _pad(self):
            """Leading whitespace for a statement at the current depth."""
            return '  ' * self.indent

        def generic_visit(self, node):
            # Fallback for unsupported nodes: dump the node's fields as Stan
            # line comments so untranslated syntax is visible in the output.
            kind = type(node).__name__
            if self.label is None:
                result = self._pad() + '//missing visit method for {}\n'.format(kind)
            else:
                result = self._pad() + '//{} = {}\n'.format(self.label, kind)
            self.indent += 1
            for name, value in ast.iter_fields(node):
                if isinstance(value, ast.AST):
                    self.label = name
                    result += self.visit(value)
                elif isinstance(value, list):
                    for i, item in enumerate(value):
                        label = '{}[{:d}]'.format(name, i)
                        if isinstance(item, ast.AST):
                            self.label = label
                            result += self.visit(item)
                        else:
                            # Plain values in lists, e.g. ast.Global.names.
                            result += self._pad() + '//{} = {}\n'.format(label, item)
                else:
                    result += self._pad() + '//{} = {}\n'.format(name, value)
            self.indent -= 1
            self.label = None
            return result

        def visit_AnnAssign(self, node):
            # 'x: int(lower=0, upper=1)[10]' -> 'int<lower=0, upper=1> x[10];'
            assert node.value is None and node.simple == 1
            identifier = self.visit(node.target)
            type_ast, dims = node.annotation, ''
            if isinstance(type_ast, ast.Subscript):
                # visit() handles both the pre-3.9 ast.Index wrapper and the
                # bare slice expression of 3.9+. It also accepts symbolic
                # sizes such as [N], not only integer literals.
                dims = '[{}]'.format(self.visit(type_ast.slice))
                type_ast = type_ast.value
            if isinstance(type_ast, ast.Call):
                typ = self.visit(type_ast.func)
                if type_ast.keywords:
                    constraints = ', '.join(self.visit(kw) for kw in type_ast.keywords)
                    typ += '<{}>'.format(constraints)
            else:
                typ = self.visit(type_ast)
            return '{}{} {}{};\n'.format(self._pad(), typ, identifier, dims)

        def visit_Assign(self, node):
            # 'theta =~ uniform(0, 1)' parses as 'theta = ~uniform(0, 1)' and
            # becomes the sampling statement 'theta ~ uniform(0, 1);'.
            assert len(node.targets) == 1
            lhs = self.visit(node.targets[0])
            value = node.value
            if isinstance(value, ast.UnaryOp) and isinstance(value.op, ast.Invert):
                op, rhs = '~', self.visit(value.operand)
            else:
                op, rhs = '=', self.visit(value)
            return '{}{} {} {};\n'.format(self._pad(), lhs, op, rhs)

        def visit_BinOp(self, node):
            symbol = self.BINARY_OPS.get(type(node.op))
            if symbol is None:
                return self.generic_visit(node)
            # Parenthesise so Python's grouping survives Stan's precedence.
            return '({} {} {})'.format(self.visit(node.left), symbol,
                                       self.visit(node.right))

        def visit_UnaryOp(self, node):
            # '~' is only meaningful as the right side of an assignment,
            # where visit_Assign handles it; unknown operators fall back.
            symbol = self.UNARY_OPS.get(type(node.op))
            if symbol is None:
                return self.generic_visit(node)
            return symbol + self.visit(node.operand)

        def visit_Call(self, node):
            args = [self.visit(arg) for arg in node.args]
            if isinstance(node.func, ast.Name) and node.func.id == 'range':
                # range(a, b) yields a..b-1, which is the Stan range a:(b-1).
                if len(args) != 2:
                    raise ValueError(
                        'range() needs explicit (lower, upper) bounds; '
                        'got {:d} argument(s)'.format(len(args)))
                return '{}:({}-1)'.format(args[0], args[1])
            return '{}({})'.format(self.visit(node.func), ', '.join(args))

        def visit_Constant(self, node):
            # Python 3.8+ parses every literal as ast.Constant.
            if isinstance(node.value, str):
                return '"{}"'.format(node.value)  # Stan string, e.g. print("...")
            return str(node.value)

        def visit_Expr(self, node):
            # Expression statement, e.g. 'print(theta)' -> 'print(theta);'
            return '{}{};\n'.format(self._pad(), self.visit(node.value))

        def visit_For(self, node):
            pad = self._pad()
            header = '{}for ({} in {})'.format(pad, self.visit(node.target),
                                               self.visit(node.iter))
            self.indent += 1
            body = ''.join([self.visit(stmt) for stmt in node.body])
            self.indent -= 1
            if len(node.body) == 1:
                return header + '\n' + body
            # Stan needs braces around a multi-statement loop body.
            return header + ' {\n' + body + pad + '}\n'

        def visit_FunctionDef(self, node):
            # Each bare string statement opens a new Stan block and closes
            # the previous one; other statements are indented inside it.
            result = ''
            self.indent = 1
            for stmt in node.body:
                block_name = _string_statement(stmt)
                if block_name is None:
                    result += self.visit(stmt)
                    continue
                if result:
                    result += '}\n'
                result += block_name + ' {\n'
            if result:
                result += '}'
            self.indent = 0
            return result

        def visit_Index(self, node):
            # Subscript wrapper produced by Python < 3.9.
            return self.visit(node.value)

        def visit_keyword(self, node):
            # Type constraint such as 'lower=0'.
            return '{}={}'.format(node.arg, self.visit(node.value))

        def visit_Module(self, node):
            assert len(node.body) == 1
            return self.visit(node.body[0])

        def visit_Name(self, node):
            return node.id

        def visit_Num(self, node):
            # Numeric literal on Python < 3.8 (ast.Num).
            return str(node.n)

        def visit_Subscript(self, node):
            return '{}[{}]'.format(self.visit(node.value), self.visit(node.slice))

        def visit_Tuple(self, node):
            # Multi-index 'x[i, j]' or multi-dimensional size '[N, M]'.
            return ', '.join(self.visit(elt) for elt in node.elts)

    if isinstance(python_function, str):
        source = python_function
    else:
        source = inspect.getsource(python_function)
    # Dedent so functions defined inside an indented scope still parse.
    tree = ast.parse(textwrap.dedent(source))
    return PythonAstVisitor().visit(tree)

In [3]:
# Model specification written in Python syntax and translated to Stan by
# parse_model. The function is never called: parse_model only reads its
# source text. Bare string statements start Stan blocks, so do not give this
# function a docstring -- it would be emitted as a block name.
def coin_model():
    'data'
    # Ten observations, each an int constrained to 0 or 1
    # (presumably coin flips, going by the model name).
    x: int(lower=0, upper=1)[10]

    'parameters'
    # Bernoulli success probability for every x[i], constrained to [0, 1].
    theta: real(lower=0, upper=1)
    
    'model'
    # '=~' parses as '= ~', which parse_model turns into a Stan '~' statement.
    theta =~ uniform(0, 1)
    # range(1, 11) covers 1..10, matching Stan's 1-based indexing of x[10].
    for i in range(1, 11):
        x[i] =~ bernoulli(theta)

coin_code = parse_model(coin_model)

print(coin_code)

data {
  int<lower=0, upper=1> x[10];
}
parameters {
  real<lower=0, upper=1> theta;
}
model {
  theta ~ uniform(0, 1);
  for (i in 1:(11-1))
    x[i] ~ bernoulli(theta);
}


In [4]:
# Ten 0/1 observations for the `x` array declared in the model's data block;
# the list length must match its declared size [10].
data = {'x': [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]}

# Fix the sampler seed so Restart & Run All reproduces the posterior
# summaries below; without it every run draws different samples.
SEED = 42

fit = pystan.stan(model_code=coin_code, data=data, iter=1000, seed=SEED)

# Posterior draws of theta taken from the fit.
samples = fit.extract()['theta']

INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_323e2ca3758f46e04e26613ad95488f2 NOW.


In [5]:
# Summarise the posterior draws of theta. numpy.std uses its default
# ddof=0, i.e. the population standard deviation of the draws.
posterior_mean = numpy.mean(samples)
posterior_stddev = numpy.std(samples)
print("Posterior mean:", posterior_mean)
print("Posterior stddev:", posterior_stddev)

Posterior mean: 0.25057515479367953
Posterior stddev: 0.11763560383571034
