# PyMC4 — Hiding `yield` from the Model Specification API

Please refer to the `README.org` for context and discussion. In this notebook I will outline my proposal.

In [1]:
from datetime import datetime
print("Last run:", datetime.now())

Last run: 2019-07-10 14:13:35.934874


In [2]:
import __future__
import ast
import functools
import inspect
import re
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

## AST Helper Functions

Based on http://code.activestate.com/recipes/578353-code-to-source-and-back/.

The main thing to take away from this cell block is that `uncompile` takes a Python code object and returns its source code (along with a few other things that we're less interested in), and `recompile` takes either

1. the output of `uncompile`, or
2. a modified AST,

and `compile`s (the Python built-in) it down to a code object.

The output of `recompile` can then be `exec`'ed or `eval`'ed.

In [3]:
PyCF_MASK = sum(v for k, v in vars(__future__).items() if k.startswith("CO_FUTURE"))


def uncompile(c):
    """uncompile(codeobj) -> [source, filename, mode, flags, firstlineno, privateprefix]."""
    if c.co_flags & inspect.CO_NESTED or c.co_freevars:
        raise NotImplementedError("Nested functions not supported")
    if c.co_name == "<lambda>":
        raise NotImplementedError("Lambda functions not supported")
    if c.co_filename == "<string>":
        raise NotImplementedError("Code without source file not supported")

    filename = inspect.getfile(c)

    try:
        lines, firstlineno = inspect.getsourcelines(c)
    except IOError:
        raise RuntimeError("Source code not available")

    source = "".join(lines)

    # __X is mangled to _ClassName__X in methods. Find this prefix:
    privateprefix = None
    for name in c.co_names:
        m = re.match("^(_[A-Za-z][A-Za-z0-9_]*)__.*$", name)
        if m:
            privateprefix = m.group(1)
            break

    return [source, filename, "exec", c.co_flags & PyCF_MASK, firstlineno, privateprefix]


def recompile(source, filename, mode, flags=0, firstlineno=1, privateprefix=None):
    """Recompile output of uncompile back to a code object. source may also be preparsed AST."""
    if isinstance(source, ast.AST):
        a = source
    else:
        a = parse_snippet(source, filename, mode, flags, firstlineno)

    node = a.body[0]

    if not isinstance(node, ast.FunctionDef):
        raise RuntimeError("Expecting function AST node")

    c0 = compile(a, filename, mode, flags, True)

    return c0


def parse_snippet(source, filename, mode, flags, firstlineno, privateprefix_ignored=None):
    """Like ast.parse, but accepts indented code snippet with a line number offset."""
    args = filename, mode, flags | ast.PyCF_ONLY_AST, True
    prefix = "\n"
    try:
        a = compile(prefix + source, *args)
    except IndentationError:
        # Already indented? Wrap with dummy compound statement
        prefix = "with 0:\n"
        a = compile(prefix + source, *args)
        # Peel wrapper
        a.body = a.body[0].body
    ast.increment_lineno(a, firstlineno - 2)
    return a
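
The round trip that `uncompile` and `recompile` perform can be sketched with the standard library alone. This is a minimal illustration, not the helpers above: the source is given as a string rather than recovered with `inspect`, and the names `add_one` and `<sketch>` are made up for the example.

```python
import ast

# Source text standing in for what `uncompile` would recover from a function.
source = """
def add_one(x):
    y = x + 1
    return y
"""

# Source -> AST. At this point the tree can be inspected or modified.
tree = ast.parse(source, filename="<sketch>", mode="exec")
func_node = tree.body[0]
assert isinstance(func_node, ast.FunctionDef)

# AST -> code object, using the same built-in `compile` that `recompile` calls.
code = compile(tree, filename="<sketch>", mode="exec")

# Executing the code object runs the `def` statement, which binds
# `add_one` in the namespace we pass in.
namespace = {}
exec(code, namespace)
add_one = namespace["add_one"]
print(add_one(41))
```

Note that `exec` only defines the function; nothing inside its body runs until it is called.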

## PyMC4 Backend

Now, let's talk about what the backends need to look like.

First, a helper class to traverse and transform the AST of the user-defined model specification function. Half the magic is in this class: please read the docstring.

In [4]:
class FunctionToGenerator(ast.NodeTransformer):
    """
    This subclass traverses the AST of the user-written, decorated,
    model specification and transforms it into a generator for the
    model. Subclassing in this way is the idiomatic way to transform
    an AST.

    Specifically:
    
    1. Add `yield` keywords to all assignments
       E.g. `x = tfd.Normal(0, 1)` -> `x = yield tfd.Normal(0, 1)`
    2. Rename the model specification function to
       `_pm_compiled_model_generator`. This is done out of an
       abundance of caution more than anything.
    3. Remove the @Model decorator. Otherwise, we risk running into
       an infinite recursion.
    """
    def visit_Assign(self, node):
        # TODO: AugAssign and AnnAssign nodes, for completeness.
        # https://greentreesnakes.readthedocs.io/en/latest/nodes.html#AugAssign
        # https://greentreesnakes.readthedocs.io/en/latest/nodes.html#AnnAssign
        new_node = node
        new_node.value = ast.Yield(value=new_node.value)

        # Tie up loose ends in the AST.
        # FIXME: I may be cargo-culting what I've read in docs and tutorials.
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(node)
        return new_node
    
    def visit_FunctionDef(self, node):
        new_node = node
        new_node.name = "_pm_compiled_model_generator"
        new_node.decorator_list = []

        # FIXME: Some more cargo-culting.
        ast.copy_location(new_node, node)
        ast.fix_missing_locations(new_node)
        self.generic_visit(node)
        return new_node
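
To see the effect of the assignment rewrite in isolation, here is a small sketch that applies the same idea to a toy function with no TensorFlow involved. The class name `AssignToYield` and the function `toy_model` are hypothetical and exist only for this illustration.

```python
import ast

source = """
def toy_model():
    a = 1
    b = a + 10
    return b
"""


class AssignToYield(ast.NodeTransformer):
    """Wrap the right-hand side of every plain assignment in a `yield`."""

    def visit_Assign(self, node):
        self.generic_visit(node)
        node.value = ast.Yield(value=node.value)
        return node


tree = AssignToYield().visit(ast.parse(source))
ast.fix_missing_locations(tree)  # the new Yield nodes need line/column info

namespace = {}
exec(compile(tree, "<sketch>", "exec"), namespace)

# `toy_model` is now a generator function: calling it runs nothing yet.
gen = namespace["toy_model"]()

first = next(gen)          # value of the first right-hand side
second = gen.send(first)   # send it back so that `a` is bound to it
try:
    gen.send(second)       # bind `b`, then run to the `return`
except StopIteration as stop:
    result = stop.value    # the generator's return value
```

Whoever drives the generator decides what each assignment target is bound to, which is exactly the hook the backend needs.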

And now for the `pm.Model` decorator. Instead of a function, our decorator [will be a class](https://realpython.com/primer-on-python-decorators/#classes-as-decorators). This allows us to have a stateful decorator, where we can store model-related things (e.g. the AST and the generator) and even define user-facing functions such as `sample` or `observe`. The other half of the magic is in this class: please read the comments and docstrings.

In [5]:
class Model:
    """ pm.Model decorator. """

    def __init__(self, func):
        self.func = func

        # Introspect wrapped function, instead of the decorator class.
        functools.update_wrapper(self, func)

        # Uncompile wrapped function
        uncompiled = uncompile(func.__code__)

        # Parse AST and modify it
        tree = parse_snippet(*uncompiled)
        tree = FunctionToGenerator().visit(tree)
        uncompiled[0] = tree

        # Recompile wrapped function
        self.recompiled = recompile(*uncompiled)    
        
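        # Execute recompiled code (defines `_pm_compiled_model_generator`)
        # in an explicit namespace and assign it to an attribute.
        # Refer to http://lucumr.pocoo.org/2011/2/1/exec-in-python/
        # An explicit dict is used because writes that `exec` makes to
        # `locals()` are not guaranteed to be visible afterwards. The
        # wrapped function's own globals are passed so that names such
        # as `tfd` resolve as they did in the original function.
        namespace = {}
        exec(self.recompiled, func.__globals__, namespace)
        self.model_generator = namespace["_pm_compiled_model_generator"]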

    """
    The following three functions aren't necessary for the rest of the notebook.
    I just want to point out that these would be natural places to define these functions.
    Refer to the "User-Facing API" section (below) for why.
    """
        
    def __call__(self, *args, **kwargs):
        # Could be something like what we have already:
        # https://github.com/pymc-devs/pymc4/blob/master/pymc4/coroutine_model.py#L63
        raise NotImplementedError("Evaluate model, as in `coroutine_model.py`.")

    def sample(self, *args, **kwargs):
        raise NotImplementedError("George isn't sure how sampling works.")

    def observe(self, *args, **kwargs):
        raise NotImplementedError("George isn't sure how observing works, either.")
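
The class-as-decorator pattern that `Model` relies on can be seen in isolation in the sketch below. `CountCalls`, `reset`, and `double` are hypothetical names chosen for the example; the only thing it shares with `Model` is the shape: state stored in `__init__`, metadata copied with `functools.update_wrapper`, and extra user-facing methods on the decorated object.

```python
import functools


class CountCalls:
    """Hypothetical stateful decorator: counts calls to the wrapped function."""

    def __init__(self, func):
        self.func = func
        self.num_calls = 0
        # Copy __name__, __doc__, etc. from `func` onto this instance.
        functools.update_wrapper(self, func)

    def __call__(self, *args, **kwargs):
        self.num_calls += 1
        return self.func(*args, **kwargs)

    def reset(self):
        """Extra user-facing method, analogous to `sample` or `observe`."""
        self.num_calls = 0


@CountCalls
def double(x):
    """Return twice x."""
    return 2 * x


double(1)
double(2)
print(double.__name__, double.num_calls)
```

After decoration, `double` is an instance of `CountCalls`, not a function, which is why it is only callable if `__call__` is defined.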

## User-Facing Model Specification API

Now all that users need to see is this:

In [6]:
@Model
def linear_regression(x):
    scale = tfd.HalfCauchy(0, 1)
    coefs = tfd.Normal(tf.zeros(x.shape[1]), 1)
    predictions = tfd.Normal(tf.linalg.matvec(x, coefs), scale)
    return predictions

### What else can we do in the `Model` decorator?

1. If we define `__call__`, then users can run `predictions = linear_regression(tf.zeros([3, 10]))`. I am unsure what we would want this to return. Note that this will **not** be as straightforward as

    ```python
    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)
    ```
   since (currently) `self.func` is the user-defined function that crashes (just as in @ferrine's example). More bluntly, users will be writing a function that, without the `@Model` decorator, crashes. On the other hand, if we _don't_ implement `__call__`, users will write a function and get back a `Model` object that _cannot be called_, as you would expect a function to be. Tricky situation; food for thought; feedback needed!

2. If we define `sample`, then users can sample from their model via `linear_regression.sample()`.

3. If we define `observe`, then users can provide observations to their model via `linear_regression.observe()` (as suggested by @ferrine and @rpgoldman).

In [7]:
# Each of these statements raises a NotImplementedError.
# Execution stops at the first one, so only its message is shown.

predictions = linear_regression(tf.zeros([3, 10]))
linear_regression.sample()
linear_regression.observe()

NotImplementedError: Evaluate model, as in `coroutine_model.py`.

## PyMC4 Core Engine

We can get the generator in exactly the same way that @ferrine's notebook requires:

In [8]:
linear_regression.model_generator(tf.zeros([3, 10]))

<generator object _pm_compiled_model_generator at 0x107a5c5c8>

Success!!

In fact, to demonstrate that it's actually the generator we need (and that there aren't subtle bugs along the way), we can interact with the generator in exactly the same way as in @ferrine's notebook.

I've omitted the "One level deeper" section of @ferrine's notebook: that is, recursively interacting with the generator. I haven't tested it out, but I expect that it would also work.

In [9]:
# Taken from https://gist.github.com/ferrine/59a63c738e03911eacba515b5be904ad

def interact(gen, state):
    control_flow = gen()
    return_value = None
    while True:
        try:
            dist = control_flow.send(return_value)
            if dist.name in state["dists"]:
                control_flow.throw(RuntimeError(
                    "We found duplicate names in your cool model: {}, "
                    "so far we have other variables in the model, {}".format(
                        dist.name, set(state["dists"].keys()),
                    )
                ))
            if dist.name in state["samples"]:
                return_value = state["samples"][dist.name]
            else:
                return_value = dist.sample()
                state["samples"][dist.name] = return_value
            state["dists"][dist.name] = dist
        except StopIteration as e:
            if e.args:
                return_value = e.args[0]
            else:
                return_value = None
            break
    return return_value, state
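
The core of `interact` is the generator `send` protocol, which the following stripped-down sketch shows without TensorFlow. `ToyNormal` is a hypothetical stand-in for a TFP distribution (only a `name` and a `sample()` method are assumed), and `toy_model` is a hand-written version of what the AST transformation would produce.

```python
import random


class ToyNormal:
    """Hypothetical stand-in for a TFP distribution."""

    def __init__(self, name, loc, scale):
        self.name, self.loc, self.scale = name, loc, scale

    def sample(self):
        return random.gauss(self.loc, self.scale)


def toy_model():
    # Each `yield` hands a distribution to the driver and receives a value.
    mu = yield ToyNormal("mu", 0.0, 1.0)
    y = yield ToyNormal("y", mu, 1.0)
    return y


gen = toy_model()
samples = {}
value = None
while True:
    try:
        dist = gen.send(value)    # the first send must be None
        value = dist.sample()
        samples[dist.name] = value
    except StopIteration as stop:
        result = stop.value       # same object as stop.args[0] when present
        break
```

The model's `return` value arrives as the `value` of the `StopIteration` exception, which is why `interact` reads `e.args[0]` in its `except` branch.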

In [10]:
preds, state = interact(lambda: linear_regression.model_generator(tf.zeros([3, 10])),
                        state=dict(dists=dict(), samples=dict()))

In [11]:
preds

<tf.Tensor 'Normal_1/sample/Reshape:0' shape=(3,) dtype=float32>

In [12]:
state

{'dists': {'HalfCauchy/': <tfp.distributions.HalfCauchy 'HalfCauchy/' batch_shape=[] event_shape=[] dtype=float32>,
  'Normal/': <tfp.distributions.Normal 'Normal/' batch_shape=[10] event_shape=[] dtype=float32>,
  'Normal_1/': <tfp.distributions.Normal 'Normal_1/' batch_shape=[3] event_shape=[] dtype=float32>},
 'samples': {'HalfCauchy/': <tf.Tensor 'HalfCauchy/sample/Reshape:0' shape=() dtype=float32>,
  'Normal/': <tf.Tensor 'Normal/sample/Reshape:0' shape=(10,) dtype=float32>,
  'Normal_1/': <tf.Tensor 'Normal_1/sample/Reshape:0' shape=(3,) dtype=float32>}}

## Discussion and Next Steps

Please refer to the `README.org`.

## Environment

In [13]:
!python --version
!cat requirements.txt

Python 3.6.7 :: Anaconda, Inc.
jupyter==1.0.0
tensorflow==1.14.0
tensorflow-probability==0.7.0
