# PyMC4 — Hiding `yield` from the Model Specification API

Please refer to the `README.org` for context and discussion. In this notebook I will outline my proposal.

In [1]:
from datetime import datetime
print("Last run:", datetime.now())

Last run: 2019-07-15 11:51:01.074677


In [2]:
import __future__
import ast
import functools
import inspect
import re
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

## AST Helper Functions

Based on http://code.activestate.com/recipes/578353-code-to-source-and-back/.

The main thing to take away from this cell block is that `uncompile` takes a code object (e.g. `func.__code__`) and returns its source code (along with the filename, compile mode, flags, first line number and private-name prefix, which we're less interested in), and `recompile` takes either:

1. the output of `uncompile`, or
2. a modified AST,

and `compile`s it (using the Python built-in) down to a code object. The output of `recompile` can then be `exec`'ed to define the function.
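
The compile-then-`exec` step can be seen in isolation with nothing but the standard library. The sketch below parses a source string (standing in for what `uncompile` recovers), compiles the AST to a code object, and executes it in an explicit namespace dictionary. The function `add` is purely illustrative and not part of the proposal.

```python
import ast

# Source text standing in for what `uncompile` would recover from a code object.
source = """
def add(a, b):
    return a + b
"""

# Parse to an AST; this is the point where a NodeTransformer could rewrite the tree.
tree = ast.parse(source)

# Compile the (possibly modified) AST down to a module-level code object.
code = compile(tree, filename="<snippet>", mode="exec")

# Executing the module code *defines* the function inside `namespace`.
namespace = {}
exec(code, namespace)
add = namespace["add"]
print(add(2, 3))  # 5
```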

In [3]:
PyCF_MASK = sum(v for k, v in vars(__future__).items() if k.startswith("CO_FUTURE"))


def uncompile(c):
    """uncompile(codeobj) -> [source, filename, mode, flags, firstlineno, privateprefix]."""
    if c.co_flags & inspect.CO_NESTED or c.co_freevars:
        raise NotImplementedError("Nested functions not supported")
    if c.co_name == "<lambda>":
        raise NotImplementedError("Lambda functions not supported")
    if c.co_filename == "<string>":
        raise NotImplementedError("Code without source file not supported")

    filename = inspect.getfile(c)

    try:
        lines, firstlineno = inspect.getsourcelines(c)
    except IOError:
        raise RuntimeError("Source code not available")

    source = "".join(lines)

    # __X is mangled to _ClassName__X in methods. Find this prefix:
    privateprefix = None
    for name in c.co_names:
        m = re.match("^(_[A-Za-z][A-Za-z0-9_]*)__.*$", name)
        if m:
            privateprefix = m.group(1)
            break

    return [source, filename, "exec", c.co_flags & PyCF_MASK, firstlineno, privateprefix]


def recompile(source, filename, mode, flags=0, firstlineno=1, privateprefix=None):
    """Recompile output of uncompile back to a code object. Source may also be preparsed AST."""
    if isinstance(source, ast.AST):
        a = source
    else:
        a = parse_snippet(source, filename, mode, flags, firstlineno)

    node = a.body[0]

    if not isinstance(node, ast.FunctionDef):
        raise RuntimeError("Expecting function AST node")

    c0 = compile(a, filename, mode, flags, True)

    return c0


def parse_snippet(source, filename, mode, flags, firstlineno, privateprefix_ignored=None):
    """Like ast.parse, but accepts indented code snippet with a line number offset."""
    args = filename, mode, flags | ast.PyCF_ONLY_AST, True
    prefix = "\n"
    try:
        a = compile(prefix + source, *args)
    except IndentationError:
        # Already indented? Wrap with dummy compound statement
        prefix = "with 0:\n"
        a = compile(prefix + source, *args)
        # Peel wrapper
        a.body = a.body[0].body
    ast.increment_lineno(a, firstlineno - 2)
    return a

## PyMC4 Backend

Now, let's talk about what the backends need to look like.

First, a helper class to traverse and transform the AST of the user-defined model specification function. Half the magic is in this class: please read the docstring.
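
Before reading the full class, it may help to see the core idea on its own. This is a stripped-down sketch, not the notebook's transformer: it wraps the right-hand side of *every* call assignment in `ast.Yield` (with no check that the callee is a distribution), and `Normal` is a hypothetical stand-in that just returns a tuple of its arguments.

```python
import ast
import inspect


class AddYield(ast.NodeTransformer):
    """Rewrite `x = f(...)` into `x = yield f(...)` for every call assignment."""

    def visit_Assign(self, node):
        if isinstance(node.value, ast.Call):
            node.value = ast.Yield(value=node.value)
        return node


source = """
def model():
    x = Normal(0.0, 1.0)
    y = Normal(x, 1.0)
    return y
"""

tree = AddYield().visit(ast.parse(source))
ast.fix_missing_locations(tree)  # the new Yield nodes need line/column info

# Hypothetical stand-in for a distribution class: it only records its arguments.
namespace = {"Normal": lambda loc, scale: ("Normal", loc, scale)}
exec(compile(tree, "<model>", "exec"), namespace)
model = namespace["model"]

print(inspect.isgeneratorfunction(model))  # True: the added `yield`s made it a generator
gen = model()
print(next(gen))      # ('Normal', 0.0, 1.0)
print(gen.send(0.5))  # ('Normal', 0.5, 1.0)
```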

In [4]:
class FunctionToGenerator(ast.NodeTransformer):
    """
    This subclass traverses the AST of the user-written, decorated,
    model specification and transforms it into a generator for the
    model. Subclassing in this way is the idiomatic way to transform
    an AST.

    Specifically:
    
    1. Add `yield` to distribution assignments and `yield from` to
       sub-model (`Model`) assignments.
       E.g. `x = tfd.Normal(0, 1)` -> `x = yield tfd.Normal(0, 1)`
    2. Rename the model specification function to `_model_generator`.
    3. Remove the @Model decorator. Otherwise, we risk running into
       an infinite recursion.
    """
    def visit_Assign(self, node):
        # If the assigned value is not a function call, return the original node
        # I.e., do nothing to assignments like `N = 10`, `mu = x.mean`,
        # or even `x = yield tfd.Normal(0, 1)`
        if not isinstance(node.value, ast.Call):
            print("Skipping {}".format(node.value))
            return node
                
        # The function that is being called. E.g. `tfd.Normal`
        function = node.value.func
        
        if isinstance(function, ast.Name):
            # `function` is a user-defined function in the global namespace (e.g. "Horseshoe()")
            function_name = function.id
            assigned_value = globals()[function_name]
        elif isinstance(function, ast.Attribute):
            # `function` is a module function (e.g. `tfd.Normal()`)
            # Unroll _all_ levels of ast.Attributes... (E.g. `tfd.distributions.Normal()`)
            names = []
            while isinstance(function, ast.Attribute):
                names.append(function.attr)
                function = function.value
            else:
                module_name = function.id
                names.reverse()
            assigned_value = globals()[module_name]
            for name in names:
                assigned_value = getattr(assigned_value, name)
        else:
            # If it is not something we expect, it is better to fail fast than continue silently.
            msg = "Model contains assignment of {}. Expected assignments of ast.Names or ast.Attributes".format(type(function))
            raise RuntimeError(msg)
            
        if isinstance(assigned_value, tfd.distribution._DistributionMeta):
            # Add a `yield`
            print("Adding `yield` to {}".format(assigned_value))
            new_node = node
            new_node.value = ast.Yield(value=new_node.value)
        elif isinstance(assigned_value, Model):
            # Add a `yield from`, and call the model_generator, not the model itself
            print("Adding `yield from` to {}".format(assigned_value))
            new_node = node
            new_node.value = ast.YieldFrom(value=new_node.value)
            new_node.value.value.func = ast.Attribute(
                value=ast.Name(id=new_node.value.value.func.id, ctx=ast.Load()),
                attr="model_generator",
                ctx=ast.Load())
        else:
            # Must be some other function, like `tf.zeros()` or something
            print("Skipping {}".format(assigned_value))
            return node

        # Fix location (line numbers and column offsets) of replaced node.
        # Recursively visit child nodes.
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(node)
        return new_node
    
    # TODO: edge case of augmented and type-annotated assignments.
    # visit_AugAssign = visit_Assign
    # visit_AnnAssign = visit_Assign
    
    def visit_FunctionDef(self, node):
        new_node = node
        new_node.name = "_model_generator"
        new_node.decorator_list = []

        # Fix location (line numbers and column offsets) of replaced node.
        # Also recursively visit child nodes.
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(node)
        return new_node

And now for the `pm.Model` decorator. Instead of a function, our decorator [will be a class](https://realpython.com/primer-on-python-decorators/#classes-as-decorators). This allows us to have a stateful decorator, where we can store model-related things (e.g. the recompiled code and the generator function) and even define user-facing functions such as `sample` or `observe`. The other half of the magic is in this class: please read the comments and docstrings.
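
The class-as-decorator pattern itself is small. A minimal sketch, independent of TensorFlow and of the AST machinery: `CountingDecorator` is a hypothetical name, and its call counter simply stands in for whatever state (recompiled code, generator function) a real decorator would keep.

```python
import functools


class CountingDecorator:
    """Hypothetical stateful decorator: wraps a function and counts its calls."""

    def __init__(self, func):
        self.func = func
        self.calls = 0
        # Copy __name__, __doc__, etc. from the wrapped function onto the instance.
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        self.calls += 1
        return self.func(*args, **kwargs)


@CountingDecorator
def square(x):
    """Return x squared."""
    return x * x


print(square(3), square(4))  # 9 16
print(square.calls)          # 2
print(square.__name__)       # square
```

Because the decorated name is now an instance, extra methods (like `sample` or `observe` below) can hang off it naturally.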

In [5]:
class Model:
    """ Model decorator. """

    def __init__(self, func):
        self.func = func

        # Introspect wrapped function, instead of the decorator class.
        functools.update_wrapper(self, func)

        # Uncompile wrapped function
        uncompiled = uncompile(func.__code__)

        # Parse AST and modify it
        tree = parse_snippet(*uncompiled)
        tree = FunctionToGenerator().visit(tree)
        uncompiled[0] = tree
        
        # Convert modified AST to readable Python source code.
        # FIXME: needs to be in a file!!
        # source = astor.to_source(tree)
        # uncompiled[1] = source

        # Recompile wrapped function
        self.recompiled = recompile(*uncompiled)
        
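        # Execute recompiled code (defines `_model_generator`) in an explicit
        # namespace dict rather than locals(), whose write-back behaviour is a
        # CPython implementation detail. The wrapped function's globals are used
        # so that names like `tfd` resolve. Refer to
        # http://lucumr.pocoo.org/2011/2/1/exec-in-python/
        namespace = {}
        exec(self.recompiled, func.__globals__, namespace)
        self.model_generator = namespace["_model_generator"]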

    # The following three methods aren't necessary for the rest of the notebook.
    # I just want to point out that these would be natural places to define them.
    # Refer to the "What else can we do in the `Model` decorator?" section (below) for why.
    
    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)
    
    def sample(self, *args, **kwargs):
        raise NotImplementedError("George isn't sure how sampling works.")

    def observe(self, *args, **kwargs):
        raise NotImplementedError("George isn't sure how observing works, either.")

## User-Facing Model Specification API

And now all users need to see is this!

### One-function models

In [6]:
@Model
def linear_regression(x):
    scale = tfd.HalfCauchy(0, 1, name="scale")
    coefs = tfd.Normal(tf.zeros(x.shape[1]), 1, name="coefs")
    predictions = tfd.Normal(tf.linalg.matvec(x, coefs), scale, name="predictions")
    return predictions

Adding `yield` to <class 'tensorflow_probability.python.distributions.half_cauchy.HalfCauchy'>
Adding `yield` to <class 'tensorflow_probability.python.distributions.normal.Normal'>
Adding `yield` to <class 'tensorflow_probability.python.distributions.normal.Normal'>


In [7]:
@Model
def gaussian_process():
    # Works with "nested" distributions too... E.g. `tfd.foo.Distribution()`
    gaussian_process = tfd.gaussian_process.GaussianProcess(kernel=tf.identity(3))
    return gaussian_process

Adding `yield` to <class 'tensorflow_probability.python.distributions.gaussian_process.GaussianProcess'>


In [8]:
@Model
def linear_regression_with_yields(x):
    # Even works with users who _prefer_ to write `yield`...
    # (Notice that every assignment is skipped: no second `yield` is added!)
    scale = yield tfd.HalfCauchy(0, 1, name="scale")
    coefs = yield tfd.Normal(tf.zeros(x.shape[1]), 1, name="coefs")
    predictions = yield tfd.Normal(tf.linalg.matvec(x, coefs), scale, name="predictions")
    return predictions

Skipping <_ast.Yield object at 0x126bbcbe0>
Skipping <_ast.Yield object at 0x126bbcb00>
Skipping <_ast.Yield object at 0x126bdf198>


### Multi-function ("nested") models

In [9]:
# From https://gist.github.com/ferrine/59a63c738e03911eacba515b5be904ad
@Model
def Horseshoe(mu=0., tau=1., s=1., name="horseshoe_scope"):
    with tf.name_scope(name):
        scale = tfd.HalfCauchy(0, s, name="scale")
        noise = tfd.Normal(0, tau, name="noise")
        return scale * noise + mu

@Model
def regularized_linear_regression(x):
    scale = tfd.HalfCauchy(0, 1, name="scale")
    coefs = Horseshoe(tf.zeros(x.shape[1]), name="coefs")
    predictions = tfd.Normal(tf.linalg.matvec(x, coefs), scale, name="predictions")
    return predictions

Adding `yield` to <class 'tensorflow_probability.python.distributions.half_cauchy.HalfCauchy'>
Adding `yield` to <class 'tensorflow_probability.python.distributions.normal.Normal'>
Adding `yield` to <class 'tensorflow_probability.python.distributions.half_cauchy.HalfCauchy'>
Adding `yield from` to <__main__.Model object at 0x126bdf358>
Adding `yield` to <class 'tensorflow_probability.python.distributions.normal.Normal'>

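Nested models work because `yield from` forwards every `send` to the sub-generator and evaluates to the sub-generator's `return` value. A TensorFlow-free sketch with hand-written generators: the tuples stand in for distributions, the `_scale`/`_noise` names are hard-coded rather than produced by `tf.name_scope`, and `drive` is a hypothetical driver that answers every request with the same value.

```python
def horseshoe(mu, name):
    # Each `yield` hands a "distribution" to the driver and receives a value back.
    scale = yield ("HalfCauchy", name + "_scale")
    noise = yield ("Normal", name + "_noise")
    return scale * noise + mu


def regularized_regression():
    scale = yield ("HalfCauchy", "scale")
    # `yield from` forwards sends to the sub-model and evaluates to its return value.
    coefs = yield from horseshoe(0.0, "coefs")
    prediction = yield ("Normal", "predictions", coefs, scale)
    return prediction


def drive(gen, value):
    """Answer every yielded request with `value`; return (requests, model return value)."""
    requests = []
    try:
        request = next(gen)
        while True:
            requests.append(request)
            request = gen.send(value)
    except StopIteration as stop:
        return requests, stop.value


requests, result = drive(regularized_regression(), 2.0)
print([r[1] for r in requests])  # ['scale', 'coefs_scale', 'coefs_noise', 'predictions']
print(requests[-1])              # ('Normal', 'predictions', 4.0, 2.0)
print(result)                    # 2.0
```

The `4.0` in the last request is the sub-model's return value (`2.0 * 2.0 + 0.0`), which reached the outer model through `yield from` without the driver knowing the model was nested.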

### What else can we do in the `Model` decorator?

1. If we define `__call__`, then users can run `predictions = linear_regression(tf.zeros([3, 10]))`. I am unsure what we would want this to return. Note that this will **not** be as straightforward as

    ```python
    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)
    ```

   since (currently) `self.func` is the user-defined function that crashes (just as in @ferrine's example): without the added `yield`s, `coefs` is a `tfd.Normal` object rather than a tensor, so `tf.linalg.matvec(x, coefs)` fails. More bluntly, users will be writing a function that, without the `@Model` decorator, crashes. On the other hand, if we _don't_ implement `__call__`, users will write a function and get back a `Model` object that _cannot be called_, even though they would expect to call it like a function. Tricky situation; food for thought; feedback needed!

2. If we define `sample`, then users can sample from their model via `linear_regression.sample()`.

3. If we define `observe`, then users can provide observations to their model via `linear_regression.observe()` (as suggested by @ferrine and @rpgoldman).

In [10]:
# None of these work yet: the first crashes inside the undecorated user
# function (see point 1 above), and the other two raise NotImplementedError.

'''
predictions = linear_regression(tf.zeros([3, 10]))
linear_regression.sample()
linear_regression.observe()
''';

## PyMC4 Core Engine

We can get the generator in exactly the same way that @ferrine's notebook requires:

In [11]:
linear_regression.model_generator(tf.zeros([3, 10]))

<generator object _model_generator at 0x10baa8990>

In [12]:
regularized_linear_regression.model_generator(tf.zeros([3, 10]))

<generator object _model_generator at 0x126ba7830>

Success!!

In fact, to demonstrate that it's actually the generator we need (and that there aren't subtle bugs along the way), we can interact with the generator in exactly the same way as in @ferrine's notebook.

I've omitted the "One level deeper" section of @ferrine's notebook: that is, recursively interacting with the generator. I haven't tested it, but I expect that it would also work.
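
The driver below relies on two generator features: `send` delivers a value back to the paused `yield`, and the model's `return` value arrives as `StopIteration.value`. Here is a TensorFlow-free sketch of the same loop, where `FakeDist` is a hypothetical stand-in for a TFP distribution (a name plus a fixed `sample()` result) and observed values are looked up by name, much as `interact` does with `state["samples"]`.

```python
class FakeDist:
    """Hypothetical stand-in for a distribution: a name and a fixed draw."""

    def __init__(self, name, draw):
        self.name = name
        self.draw = draw

    def sample(self):
        return self.draw


def model():
    scale = yield FakeDist("scale", 2.0)
    obs = yield FakeDist("obs", scale * 10.0)
    return obs


def run(gen, observed):
    """Send observed values where available, otherwise sample; record what was used."""
    values = {}
    value = None
    try:
        while True:
            dist = gen.send(value)  # the first send must be None
            value = observed[dist.name] if dist.name in observed else dist.sample()
            values[dist.name] = value
    except StopIteration as stop:
        return stop.value, values


print(run(model(), observed={}))            # (20.0, {'scale': 2.0, 'obs': 20.0})
print(run(model(), observed={"obs": 7.0}))  # (7.0, {'scale': 2.0, 'obs': 7.0})
```

The lookup is by exact name, so an observation stored under any other key is silently ignored and the variable is sampled instead.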

In [13]:
# Taken from https://gist.github.com/ferrine/59a63c738e03911eacba515b5be904ad

def interact(gen, state):
    control_flow = gen()
    return_value = None
    while True:
        try:
            dist = control_flow.send(return_value)
            if dist.name in state["dists"]:
                control_flow.throw(RuntimeError(
                    "We found duplicate names in your cool model: {}, "
                    "so far we have other variables in the model, {}".format(
                        dist.name, set(state["dists"].keys()),
                    )
                ))
            if dist.name in state["samples"]:
                return_value = state["samples"][dist.name]
            else:
                return_value = dist.sample()
                state["samples"][dist.name] = return_value
            state["dists"][dist.name] = dist
        except StopIteration as e:
            # The generator's `return` value is carried by the StopIteration.
            return_value = e.value
            break
    return return_value, state

### One-function models

In [14]:
preds, state = interact(lambda: linear_regression.model_generator(tf.zeros([3, 10])),
                        state=dict(dists=dict(), samples=dict()))

In [15]:
print(preds, state, sep="\n\n")

tf.Tensor([174.78996   20.97928   47.357864], shape=(3,), dtype=float32)

{'dists': {'scale': <tfp.distributions.HalfCauchy 'scale' batch_shape=[] event_shape=[] dtype=float32>, 'coefs': <tfp.distributions.Normal 'coefs' batch_shape=[10] event_shape=[] dtype=float32>, 'predictions': <tfp.distributions.Normal 'predictions' batch_shape=[3] event_shape=[] dtype=float32>}, 'samples': {'scale': <tf.Tensor: id=37, shape=(), dtype=float32, numpy=108.2138>, 'coefs': <tf.Tensor: id=65, shape=(10,), dtype=float32, numpy=
array([ 1.3074126 ,  2.6881516 ,  0.06404222, -0.16281821, -0.58692205,
        1.5679766 , -0.5559973 ,  0.8432684 , -0.21408895,  1.4693985 ],
      dtype=float32)>, 'predictions': <tf.Tensor: id=93, shape=(3,), dtype=float32, numpy=array([174.78996 ,  20.97928 ,  47.357864], dtype=float32)>}}


### Multi-function ("nested") models

In [16]:
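# Note: the key "predictions/" does not match the distribution name "predictions",
# so it is not used as an observation and `predictions` is sampled instead
# (both keys appear in the state printed below).
preds, state = interact(lambda: regularized_linear_regression.model_generator(tf.random.normal([3, 10])),
                        state=dict(dists=dict(), samples={"predictions/": tf.zeros(3)}))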

In [17]:
print(preds, state, sep="\n\n")

tf.Tensor([ 60.178547    1.9002857 -78.334015 ], shape=(3,), dtype=float32)

{'dists': {'scale': <tfp.distributions.HalfCauchy 'scale' batch_shape=[] event_shape=[] dtype=float32>, 'coefs_scale': <tfp.distributions.HalfCauchy 'coefs_scale' batch_shape=[] event_shape=[] dtype=float32>, 'coefs_noise': <tfp.distributions.Normal 'coefs_noise' batch_shape=[] event_shape=[] dtype=float32>, 'predictions': <tfp.distributions.Normal 'predictions' batch_shape=[3] event_shape=[] dtype=float32>}, 'samples': {'predictions/': <tf.Tensor: id=101, shape=(3,), dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>, 'scale': <tf.Tensor: id=136, shape=(), dtype=float32, numpy=0.8348854>, 'coefs_scale': <tf.Tensor: id=169, shape=(), dtype=float32, numpy=15.626522>, 'coefs_noise': <tf.Tensor: id=195, shape=(), dtype=float32, numpy=1.3660036>, 'predictions': <tf.Tensor: id=225, shape=(3,), dtype=float32, numpy=array([ 60.178547 ,   1.9002857, -78.334015 ], dtype=float32)>}}


## Discussion and Next Steps

Please refer to the `README.org`.

## Environment

In [18]:
!python --version
!cat requirements.txt

Python 3.6.7 :: Anaconda, Inc.
jupyter==1.0.0
tensorflow==2.0.0-beta1
tfp-nightly
