
# Useful program transformations for PPLs 
*@neerajprad, @jpchen, @xiaoyan0*

This notebook contains example probabilistic inference workflows, as well as a complete example written in both JAX and PyTorch for comparison. The code snippets are taken from the leapfrog integrator used within the No-U-Turn Sampler (NUTS) [1], a very popular MCMC inference algorithm.

This code does not depend on any particular PPL, which makes it easier to compare the two frameworks, particularly on features that will be important for any PyTorch-based PPL. Examples in Section 1 are small PyTorch snippets for illustration, and Sections 2 and 3 further exemplify these points by comparing complete implementations of a leapfrog integrator in JAX and PyTorch.

1. Control Flow Examples
   - Composition with grad
   - Stochastic Control Flow
   - Composition with Looping Primitives
   - Composition with JIT
2. JAX NUTS example: Leapfrog integrator
3. PyTorch NUTS example: Leapfrog integrator
4. Helpful References



---


## 1. Control Flow Examples

In [None]:
%reset -sf

# Install a nightly build for access to torch.vmap
!pip install --upgrade --pre torch==1.9.0.dev20210219 torchvision torchaudio -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html

import torch
import torch.distributions as dist
from torch.autograd import grad

print('pytorch version: ', torch.__version__)

Looking in links: https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
Requirement already up-to-date: torch==1.9.0.dev20210219 in /usr/local/lib/python3.6/dist-packages (1.9.0.dev20210219)
Requirement already up-to-date: torchvision in /usr/local/lib/python3.6/dist-packages (0.9.0.dev20210219)
Requirement already up-to-date: torchaudio in /usr/local/lib/python3.6/dist-packages (0.8.0.dev20210219)
pytorch version:  1.9.0.dev20210219


### a) Composition with grad
At the core of the algorithm is a symplectic integrator that repeatedly takes gradients of the potential energy [1]. A `vmap` or `jit` applied at the outermost loop would need to be able to handle this.

In [None]:
# This is one of the operations in the leapfrog integrator in [1]. The full
# integrator is in the complete examples in Secs. 2 and 3.

def grad_fn(node, params):
    # compute gradients of `node` (a log-density value) w.r.t. `params`
    grads = torch.autograd.grad(node, params)
    # update node with the gradient w.r.t. the first parameter (loc)
    new_node = node + grads[0]
    return new_node

params = [torch.tensor(0., requires_grad=True), torch.tensor(1., requires_grad=True)]
node = dist.Normal(*params).log_prob(torch.tensor(0.3))
grad_fn(node, params)

tensor(-0.6639, grad_fn=<AddBackward0>)
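For comparison, the same computation can be written with a *functional* gradient: `grad(f)` returns a new function instead of reading gradient state off tensors. The sketch below is a minimal forward-mode (dual-number) implementation in plain Python; `Dual`, `grad`, and `normal_log_prob` are hypothetical helpers, not PyTorch APIs, and only the operations needed here are supported. Because `grad` is just a higher-order function, it can be called repeatedly inside an outer loop, which is the pattern the integrator relies on.

```python
import math


class Dual:
    """Number of the form val + eps * e with e**2 == 0 (forward-mode autodiff)."""

    def __init__(self, val, eps=0.0):
        self.val, self.eps = val, eps

    @staticmethod
    def lift(x):
        return x if isinstance(x, Dual) else Dual(float(x))

    def __add__(self, other):
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.eps + o.eps)

    __radd__ = __add__

    def __sub__(self, other):
        o = Dual.lift(other)
        return Dual(self.val - o.val, self.eps - o.eps)

    def __rsub__(self, other):
        return Dual.lift(other) - self

    def __mul__(self, other):
        o = Dual.lift(other)
        return Dual(self.val * o.val, self.eps * o.val + self.val * o.eps)

    __rmul__ = __mul__

    def __truediv__(self, other):
        o = Dual.lift(other)
        return Dual(self.val / o.val,
                    (self.eps * o.val - self.val * o.eps) / o.val ** 2)

    def __pow__(self, n):  # integer powers only
        return Dual(self.val ** n, n * self.val ** (n - 1) * self.eps)


def grad(f):
    """Return a function that computes df/dx at a scalar x."""
    return lambda x: f(Dual(x, 1.0)).eps


def normal_log_prob(x, loc, scale):
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2 * math.pi))


def log_prob_at(loc):
    # log N(0.3 | loc, 1), as a function of loc only
    return normal_log_prob(0.3, loc, 1.0)


# Same computation as `grad_fn` above: log-density plus its gradient w.r.t. loc.
new_node = log_prob_at(0.0) + grad(log_prob_at)(0.0)

# `grad` composes with an ordinary Python loop: gradient ascent on loc.
loc = 0.0
for _ in range(50):
    loc = loc + 0.5 * grad(log_prob_at)(loc)
```

`new_node` reproduces the PyTorch output above (about -0.6639), and the loop drives `loc` to the data point 0.3.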

### b) Stochastic Control Flow
Many inference workflows need control flow that is determined pseudorandomly. JAX handles this through a local, functional [RNG](https://jax.readthedocs.io/en/latest/jax.random.html): random keys are passed around and split explicitly instead of living in a global state. In the example below, both `jit` and `vmap` would need to handle the stochastic `accept > 0` condition. This is easier in JAX because there is no global RNG, so the JIT can trace through the key splitting. Supporting this in PyTorch will be hard without a corresponding functional random number generator.

In [None]:
# (this is the Metropolis-Hastings correction step in [1])
def accept_or_reject(v):
  # toy acceptance logit; in practice it is computed from the current
  # and proposed states
  accept_logit = dist.Normal(0., 1.).log_prob(v)
  # accept with probability sigmoid(accept_logit), drawn from the global RNG
  accept = dist.Bernoulli(logits=accept_logit).sample()
  if accept > 0:
    return torch.tensor(1.)
  return torch.tensor(0.)
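To illustrate what a functional RNG buys us, here is a minimal plain-Python sketch in the spirit of `jax.random` (not the algorithm `jax.random` actually uses): keys are explicit values that are split rather than mutated, so a draw depends only on the key passed in. `split`, `uniform`, and the SHA-256-based key derivation are hypothetical choices made for this sketch. Because the accept/reject decision is a pure function of `(key, v)`, a tracer or batching transform could in principle treat it like any other deterministic computation.

```python
import hashlib
import math
import random


def split(key, num=2):
    """Derive `num` new integer keys from `key` without touching global state."""
    return [
        int.from_bytes(hashlib.sha256(f"{key}:{i}".encode()).digest()[:8], "big")
        for i in range(num)
    ]


def uniform(key):
    """Deterministic draw in [0, 1): the same key always yields the same value."""
    return random.Random(key).random()


def accept_or_reject(key, v):
    # Toy acceptance logit, mirroring the PyTorch cell above: log N(v | 0, 1).
    accept_logit = -0.5 * v * v - 0.5 * math.log(2 * math.pi)
    accept_prob = 1.0 / (1.0 + math.exp(-accept_logit))
    return 1.0 if uniform(key) < accept_prob else 0.0


root_key = 0
keys = split(root_key, 5)
decisions = [accept_or_reject(k, 0.3) for k in keys]
```

Re-running with the same keys reproduces exactly the same control-flow decisions, which is the property a JIT or `vmap` needs in order to reason about the branch.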

### c) Composition with Looping Primitives

In Bean Machine, we update a set of nodes at each iteration. Since each update is conditional on the dependency graph, we do this in a Python for-loop, which `jit` and `vmap` would also need to handle.

In [None]:
# looping primitives for node updates with conditionals
# outermost loop, composing all of the functions above

def update_nodes(nodes):
  # loop through, update some nodes given a condition
  for n in nodes:
    if accept_or_reject(n.val).item() > 0:
      # update node value (technically also the nodes in the Markov blanket)
      # TODO: do this in a non-mutating way
      proposed_value = dist.Normal(0., 1.).sample()
      n.val = proposed_value
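One way to address the TODO above is to return new node values instead of assigning to `n.val`. The sketch below is plain Python and hypothetical: the `Node` type and toy acceptance rule are not Bean Machine APIs, and it ignores the Markov blanket. It also uses a local, explicitly seeded `random.Random` instead of global RNG state, so the same seed reproduces the same updates.

```python
import math
import random
from typing import NamedTuple, Tuple


class Node(NamedTuple):
    name: str
    val: float


def accept_or_reject(rng: random.Random, v: float) -> bool:
    # Toy rule: accept with probability sigmoid(log N(v | 0, 1)).
    logit = -0.5 * v * v - 0.5 * math.log(2 * math.pi)
    return rng.random() < 1.0 / (1.0 + math.exp(-logit))


def update_nodes(nodes: Tuple[Node, ...], seed: int) -> Tuple[Node, ...]:
    rng = random.Random(seed)  # local generator; global RNG state is untouched
    new_nodes = []
    for n in nodes:
        if accept_or_reject(rng, n.val):
            # _replace returns a new Node; the input node is not modified
            n = n._replace(val=rng.gauss(0.0, 1.0))
        new_nodes.append(n)
    return tuple(new_nodes)


nodes = (Node("a", 0.0), Node("b", 1.5), Node("c", -0.7))
updated = update_nodes(nodes, seed=42)
```

Since the inputs are never mutated, the function can be re-run or compared against a previous state freely, which is what a tracing or batching transform would expect.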

### d) Composition with JIT
A functional `grad` that composes with JIT and actually inlines the backward operations (for optimization, as JAX does) would be really useful. The issue is that we need to set `requires_grad` to `True` in `_potential_grad` (Section 3), and the tracer complains about that. In the example below, a functional version of `grad` would mean that the JIT does not have to insert a constant with `requires_grad` set to `True`.

In [None]:
def fn(x):
  # It would be nice to have a functional grad variant that does not require us
  # to set `requires_grad` to True.
  x.requires_grad_(True)
  y = x**3
  grads = torch.autograd.grad(y, x)
  x.requires_grad_(False)
  # autograd.grad returns a tuple with one gradient per input
  return x + grads[0]


torch.jit.trace(fn, torch.tensor(2.))

RuntimeError: ignored
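To make the request concrete, here is a toy plain-Python model of what composing a functional `grad` with a tracing JIT could look like. `Expr`, `grad`, and `jit` are hypothetical and support only `+`, `*`, and integer powers. `jit` traces the function once into an expression graph; when `grad` is called on a traced value, it builds the derivative as ordinary graph nodes, so the backward computation is inlined into the same graph and no `requires_grad` flag is ever toggled. This is symbolic differentiation on the traced graph, not how PyTorch or JAX implement autodiff.

```python
class Expr:
    """A node in a traced computation graph."""

    def __init__(self, op, args=(), value=None):
        self.op, self.args, self.value = op, args, value

    def __add__(self, other):
        return Expr("add", (self, lift(other)))

    __radd__ = __add__

    def __mul__(self, other):
        return Expr("mul", (self, lift(other)))

    __rmul__ = __mul__

    def __pow__(self, n):  # n must be a Python int
        return Expr("pow", (self,), n)


def lift(x):
    return x if isinstance(x, Expr) else Expr("const", value=float(x))


def diff(e, wrt):
    """Build d(e)/d(wrt) out of new graph nodes."""
    if e is wrt:
        return lift(1.0)
    if e.op in ("const", "var"):
        return lift(0.0)
    if e.op == "add":
        return diff(e.args[0], wrt) + diff(e.args[1], wrt)
    if e.op == "mul":
        a, b = e.args
        return diff(a, wrt) * b + a * diff(b, wrt)
    if e.op == "pow":
        (a,) = e.args
        return e.value * a ** (e.value - 1) * diff(a, wrt)
    raise ValueError(f"unknown op {e.op}")


def evaluate(e, env):
    if e.op == "var":
        return env[e]
    if e.op == "const":
        return e.value
    if e.op == "add":
        return evaluate(e.args[0], env) + evaluate(e.args[1], env)
    if e.op == "mul":
        return evaluate(e.args[0], env) * evaluate(e.args[1], env)
    if e.op == "pow":
        return evaluate(e.args[0], env) ** e.value
    raise ValueError(f"unknown op {e.op}")


def grad(f):
    def df(x):
        if isinstance(x, Expr):  # under tracing: inline the derivative graph
            return diff(f(x), x)
        v = Expr("var")  # eager call: trace f, differentiate, then evaluate
        return evaluate(diff(f(v), v), {v: x})
    return df


def jit(f):
    x = Expr("var")
    graph = f(x)  # traced once; grad calls are already inlined in `graph`
    return lambda value: evaluate(graph, {x: value})


def fn(x):
    # Functional counterpart of `fn` above: no requires_grad toggling.
    return x + grad(lambda t: t ** 3)(x)


compiled_fn = jit(fn)
```

Here `compiled_fn(2.0)` and the eager `fn(2.0)` both give 2 + 3 * 2**2 = 14.0; the difference is that the compiled version evaluates a single graph that already contains the derivative nodes.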

## 2. JAX NUTS example: Leapfrog integrator

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
print('jax version: ', jax.__version__)

jax version:  0.2.9


In [None]:
from collections import namedtuple
import jax
from jax import grad, jit, random, value_and_grad, lax
from jax.flatten_util import ravel_pytree
import jax.numpy as np
from jax.tree_util import register_pytree_node, tree_multimap

In [None]:
# (q, p) -> (position (param value), momentum)

IntegratorState = namedtuple("IntegratorState", ["q", "p", "potential_energy", "q_grad"])

# Register IntegratorState as a pytree node so that JAX transformations
# (jit, vmap, lax.while_loop) can flatten and rebuild it
# (https://jax.readthedocs.io/en/latest/pytrees.html)
register_pytree_node(
    IntegratorState,
    lambda xs: (tuple(xs), None),
    lambda _, xs: IntegratorState(*xs)
)


def leapfrog(potential_fn, kinetic_fn):
    r"""
    Second order symplectic integrator that uses the leapfrog algorithm
    for position `q` and momentum `p`.

    :param potential_fn: Python callable that computes the potential energy
        given input parameters. The input parameters to `potential_fn` can be
        any python collection type.
    :param kinetic_fn: Python callable that returns the kinetic energy given
        inverse mass matrix and momentum.
    :return: a pair of (`init_fn`, `update_fn`).
    """
    def init_fn(q, p):
        """
        :param q: Position of the particle.
        :param p: Momentum of the particle.
        :return: initial state for the integrator.
        """
        potential_energy, q_grad = value_and_grad(potential_fn)(q)
        return IntegratorState(q, p, potential_energy, q_grad)

    def update_fn(step_size, inverse_mass_matrix, state):
        """
        :param float step_size: Size of a single step.
        :param inverse_mass_matrix: Inverse of mass matrix, which is used to
            calculate kinetic energy.
        :param state: Current state of the integrator.
        :return: new state for the integrator.
        """
        q, p, _, q_grad = state
        # maps a function over a pytree, returning a new pytree
        p = tree_multimap(lambda p, q_grad: p - 0.5 * step_size * q_grad, p, q_grad)  # p(n+1/2)
        p_grad = grad(kinetic_fn, argnums=1)(inverse_mass_matrix, p)
        q = tree_multimap(lambda q, p_grad: q + step_size * p_grad, q, p_grad)  # q(n+1)
        potential_energy, q_grad = value_and_grad(potential_fn)(q)
        p = tree_multimap(lambda p, q_grad: p - 0.5 * step_size * q_grad, p, q_grad)  # p(n+1)
        return IntegratorState(q, p, potential_energy, q_grad)

    return init_fn, update_fn


def kinetic_fn(inverse_mass_matrix, p):
    # flattens the pytree
    p, _ = ravel_pytree(p)

    if inverse_mass_matrix.ndim == 2:
        v = np.matmul(inverse_mass_matrix, p)
    elif inverse_mass_matrix.ndim == 1:
        v = np.multiply(inverse_mass_matrix, p)
    else:
        raise ValueError("inverse_mass_matrix must be 1D (diagonal) or 2D (dense)")

    return 0.5 * np.dot(v, p)
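The integrator's key properties, approximate energy conservation and exact time reversibility, can be checked without any framework. The plain-Python sketch below mirrors `update_fn` for a diagonal inverse mass matrix and the Gaussian potential used in the next cell (mean 1, standard deviation 2), with an analytic gradient in place of `value_and_grad`; the function names and the initial state are hypothetical. Reversibility and volume preservation are what make the Metropolis-Hastings correction in HMC/NUTS valid.

```python
MEAN, STD = 1.0, 2.0


def potential_energy(q):
    return 0.5 * sum(((qi - MEAN) / STD) ** 2 for qi in q)


def potential_grad(q):
    return [(qi - MEAN) / STD ** 2 for qi in q]


def kinetic_energy(inv_mass_diag, p):
    return 0.5 * sum(m * pi * pi for m, pi in zip(inv_mass_diag, p))


def leapfrog_step(inv_mass_diag, step_size, q, p):
    g = potential_grad(q)
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, g)]  # p(n+1/2)
    # dK/dp = M^{-1} p for a diagonal inverse mass matrix
    q = [qi + step_size * m * pi for qi, m, pi in zip(q, inv_mass_diag, p)]  # q(n+1)
    g = potential_grad(q)
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, g)]  # p(n+1)
    return q, p


def integrate(inv_mass_diag, step_size, q, p, num_steps):
    for _ in range(num_steps):
        q, p = leapfrog_step(inv_mass_diag, step_size, q, p)
    return q, p


inv_mass = [1.0, 1.0]
q0, p0 = [0.0, 0.0], [0.5, -1.0]
h0 = potential_energy(q0) + kinetic_energy(inv_mass, p0)

q1, p1 = integrate(inv_mass, 0.01, q0, p0, 1000)
h1 = potential_energy(q1) + kinetic_energy(inv_mass, p1)

# Flip the momentum and integrate back: we should recover the starting point.
q_back, p_back = integrate(inv_mass, 0.01, q1, [-pi for pi in p1], 1000)
```

The total energy `h1` stays close to `h0`, and `q_back` matches `q0` up to floating-point round-off, with `p_back` equal to `-p0`.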


Note that JAX provides some great utilities for operating on [pytrees](https://jax.readthedocs.io/en/latest/jax.tree_util.html?highlight=pytree#module-jax.tree_util), which can be any Python container type that supports flattening and unflattening, e.g. `tree_multimap` above. This lets us write generic code without imposing assumptions on the client code: the parameters passed to `potential_fn` could be a dict, a list, or a single array, and the integrator code remains the same.
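A plain-Python sketch of the idea behind these utilities, restricted to dicts, lists, and tuples with scalar leaves: `tree_flatten`, `tree_unflatten`, and `tree_map` here are hypothetical stand-ins, not the `jax.tree_util` implementations. Dict entries are ordered by sorted key, mirroring `sorted(p)` in the PyTorch `_kinetic_grad`. The same half-step momentum update from `update_fn` then works unchanged on a dict of lists and on a plain list.

```python
def _flatten_children(children):
    leaves, defs = [], []
    for child in children:
        child_leaves, child_def = tree_flatten(child)
        leaves.extend(child_leaves)
        defs.append(child_def)
    return leaves, defs


def tree_flatten(tree):
    """Return (leaves, treedef) for nested dicts/lists/tuples with scalar leaves."""
    if isinstance(tree, dict):
        keys = sorted(tree)
        leaves, defs = _flatten_children(tree[k] for k in keys)
        return leaves, ("dict", keys, defs)
    if isinstance(tree, (list, tuple)):
        leaves, defs = _flatten_children(tree)
        return leaves, ("list" if isinstance(tree, list) else "tuple", defs)
    return [tree], ("leaf",)


def tree_unflatten(treedef, leaves):
    it = iter(leaves)

    def build(d):
        if d[0] == "leaf":
            return next(it)
        if d[0] == "dict":
            return {k: build(cd) for k, cd in zip(d[1], d[2])}
        children = [build(cd) for cd in d[1]]
        return children if d[0] == "list" else tuple(children)

    return build(treedef)


def tree_map(f, tree, *rest):
    """Apply f leaf-wise across trees that share the structure of `tree`."""
    leaves, treedef = tree_flatten(tree)
    rest_leaves = [tree_flatten(r)[0] for r in rest]
    return tree_unflatten(treedef, [f(*xs) for xs in zip(leaves, *rest_leaves)])


def half_step(p, g, step_size=0.1):
    return p - 0.5 * step_size * g  # p(n+1/2), as in update_fn


# The same update works for a dict of lists...
p_dict = tree_map(half_step, {"z": [1.0, 2.0], "w": 3.0}, {"z": [2.0, 4.0], "w": 1.0})
# ...and for a plain list.
p_list = tree_map(half_step, [1.0, 2.0], [2.0, 4.0])
```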

In [None]:
D = 1000    

true_mean, true_std = np.ones(D), np.ones(D) * 2.

def potential_fn(q):
    """
    Negative log density (up to an additive constant) of a normal
    distribution with mean `true_mean` and standard deviation `true_std`.
    """
    return 0.5 * np.sum(((q['z'] - true_mean) / true_std) ** 2)


# U-turn termination condition
# For demonstration purposes only: this won't result in a correct MCMC proposal
def is_u_turning(q_i, q_f, p_f):
  return np.less(np.dot((q_f['z'] - q_i['z']), p_f['z']), 0)


# Run leapfrog until the termination condition is met.
# lax.while_loop iterates while its condition is True, so we negate the U-turn check.
def get_final_state(ke, pe, m_inv, step_size, q_i, p_i):
    lf_init, lf_update = leapfrog(pe, ke)
    lf_init_state = lf_init(q_i, p_i)
    q_f, p_f, _, _ = lax.while_loop(lambda x: np.logical_not(is_u_turning(q_i, x[0], x[1])),
                                    lambda x: lf_update(step_size, m_inv, x),
                                    lf_init_state)
    return (q_f, p_f)
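For reference, the termination criterion in the NUTS paper [1] checks both ends of the trajectory: with positions $q^-, q^+$ and momenta $p^-, p^+$ at the leftmost and rightmost states, the trajectory stops growing when $(q^+ - q^-) \cdot p^- < 0$ or $(q^+ - q^-) \cdot p^+ < 0$. Below is a minimal plain-Python sketch of just that check, assuming an identity mass matrix and flat lists as positions (`dot` and `is_u_turning_two_sided` are hypothetical names); a full NUTS implementation also needs the tree-doubling and sampling machinery from [1].

```python
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def is_u_turning_two_sided(q_minus, q_plus, p_minus, p_plus):
    """True once either end of the trajectory starts moving back toward the other end."""
    dq = [qp - qm for qp, qm in zip(q_plus, q_minus)]
    return dot(dq, p_minus) < 0 or dot(dq, p_plus) < 0
```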




### jit and grad composition

Note that we are JIT-compiling the integrator, which includes the gradient computation.

In [None]:
q_i = {'z': np.zeros(D)}
p_i = lambda i: {'z': random.normal(random.PRNGKey(i), (D,))}
inv_mass_matrix = np.eye(D)
step_size = 0.001
  
fn = jit(get_final_state, static_argnums=(0, 1))
timefn = lambda i: fn(kinetic_fn, potential_fn, inv_mass_matrix, 
                      step_size, q_i, p_i(i))

# Time a single run (-n 1 -r 1) of 10 calls so that the reported time includes
# JIT compilation; with repeated runs, the best time would exclude it.
%timeit -n 1 -r 1 [timefn(i) for i in range(10)]

1 loop, best of 1: 20.4 ms per loop


### Let's add vmap to do this in parallel

Note that this requires:
 - composition of `vmap` with `jit` for `get_final_state`,
 - composition of `vmap` with the control-flow primitive `lax.while_loop` in `get_final_state`, and
 - composition of `vmap` and `jit` with `grad` in `leapfrog`.
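The second requirement is subtle because each batch member may need a different number of iterations. As we understand JAX's batching rule for `lax.while_loop`, when the loop predicate is batched the loop keeps running until the predicate is false for every member, and members that have already finished are held fixed with a select, so the total number of iterations is set by the slowest member. A plain-Python sketch of those semantics (`batched_while_loop` is a hypothetical helper, not a JAX API):

```python
def batched_while_loop(cond_fun, body_fun, states):
    """Iterate while any member's condition holds; finished members stay unchanged.

    Returns the final states and the number of iterations the batch ran.
    """
    trips = 0
    active = [cond_fun(s) for s in states]
    while any(active):
        # A vectorized implementation would apply body_fun to every member
        # and then select; skipping inactive members yields the same states.
        states = [body_fun(s) if a else s for s, a in zip(states, active)]
        active = [cond_fun(s) for s in states]
        trips += 1
    return states, trips


# Each state is (steps_taken, steps_needed); members need 3, 5, and 1 steps.
final, trips = batched_while_loop(
    lambda s: s[0] < s[1],
    lambda s: (s[0] + 1, s[1]),
    [(0, 3), (0, 5), (0, 1)],
)
```

Here the batch runs 5 iterations even though one member finishes after 1, which suggests why a vmapped trajectory loop can cost as much as its longest trajectory.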

In [None]:
# Draw K in parallel

K = 50
q_i = {'z': random.normal(random.PRNGKey(1), (K, D))}
p_i = {'z': random.normal(random.PRNGKey(2), (K, D))}
jax.vmap(lambda z: fn(kinetic_fn, potential_fn, inv_mass_matrix, 
                      step_size, *z))((q_i, p_i))

({'z': DeviceArray([[-0.6775011 ,  1.2218331 ,  0.34814987, ..., -0.1360297 ,
                 0.09302761, -0.8675127 ],
               [ 0.5821763 ,  0.11254374, -1.0764872 , ..., -0.10364341,
                -0.84956676,  1.1980356 ],
               [ 1.1932558 ,  0.69470936, -1.1416658 , ..., -1.3943284 ,
                 0.06114601, -0.891614  ],
               ...,
               [ 1.4670084 , -0.27074087, -0.5646535 , ..., -0.85916924,
                -0.7521342 , -1.5170535 ],
               [-1.9934872 ,  0.6114559 ,  0.26292798, ..., -0.65967107,
                -1.3601958 , -0.3978807 ],
               [-0.65449744, -0.30205557,  0.6561341 , ..., -0.41573793,
                 0.952913  ,  0.05202565]], dtype=float32)},
 {'z': DeviceArray([[-0.5898306 , -0.86168784, -0.6559259 , ..., -0.7063382 ,
                 0.5879463 , -0.7263582 ],
               [-0.8124092 ,  1.1634269 ,  0.67753744, ...,  0.0677776 ,
                -0.38750023,  0.1552683 ],
               [-0.53802

## 3. PyTorch NUTS example: Leapfrog integrator

In [None]:
def leapfrog(q, p, potential_fn, inverse_mass_matrix, step_size, num_steps=1, q_grads=None):
    r"""
    Second order symplectic integrator that uses the velocity leapfrog algorithm.

    :param dict q: dictionary of sample site names and their current values
        (type :class:`~torch.Tensor`).
    :param dict p: dictionary of sample site names and corresponding momenta
        (type :class:`~torch.Tensor`).
    :param callable potential_fn: function that returns potential energy given q
        for each sample site. The negative gradient of the function with respect
        to ``q`` determines the rate of change of the corresponding sites'
        momenta ``p``.
    :param torch.Tensor inverse_mass_matrix: a tensor :math:`M^{-1}` which is used
        to calculate kinetic energy: :math:`E_{kinetic} = \frac{1}{2}p^T M^{-1} p`.
        Here :math:`M^{-1}` can be a 1D tensor (diagonal matrix) or a 2D tensor (dense matrix).
    :param float step_size: step size for each time step iteration.
    :param int num_steps: number of discrete time steps over which to integrate.
    :param torch.Tensor q_grads: optional gradients of potential energy at current ``q``.
    :return tuple (q_next, p_next, q_grads, potential_energy): next position and momenta,
        together with the potential energy and its gradient w.r.t. ``q_next``.
    """
    q_next = q.copy()
    p_next = p.copy()
    for _ in range(num_steps):
        q_next, p_next, q_grads, potential_energy = _single_step(q_next,
                                                                 p_next,
                                                                 potential_fn,
                                                                 inverse_mass_matrix,
                                                                 step_size,
                                                                 q_grads)
    return q_next, p_next, q_grads, potential_energy


def _single_step(q, p, potential_fn, inverse_mass_matrix, step_size, q_grads=None):
    r"""
    Single step leapfrog that modifies the `q`, `p` dicts in place.
    """

    q_grads = _potential_grad(potential_fn, q)[0] if q_grads is None else q_grads

    for site_name in p:
        p[site_name] = p[site_name] + 0.5 * step_size * (-q_grads[site_name])  # p(n+1/2)

    p_grads = _kinetic_grad(inverse_mass_matrix, p)
    for site_name in q:
        q[site_name] = q[site_name] + step_size * p_grads[site_name]  # q(n+1)

    q_grads, potential_energy = _potential_grad(potential_fn, q)
    for site_name in p:
        p[site_name] = p[site_name] + 0.5 * step_size * (-q_grads[site_name])  # p(n+1)

    return q, p, q_grads, potential_energy


def _potential_grad(potential_fn, q):
    q_keys, q_nodes = zip(*q.items())
    for node in q_nodes:
        node.requires_grad_(True)
    potential_energy = potential_fn(q)
    grads = torch.autograd.grad(potential_energy, q_nodes)
    for node in q_nodes:
        node.requires_grad_(False)
    return dict(zip(q_keys, grads)), potential_energy.detach()


def _kinetic_grad(inverse_mass_matrix, p):
    p_flat = torch.cat([p[site_name].reshape(-1) for site_name in sorted(p)])
    if inverse_mass_matrix.dim() == 1:
        grads_flat = inverse_mass_matrix * p_flat
    else:
        grads_flat = inverse_mass_matrix.matmul(p_flat)

    # unpacking
    grads = {}
    pos = 0
    for site_name in sorted(p):
        next_pos = pos + p[site_name].numel()
        grads[site_name] = grads_flat[pos:next_pos].reshape(p[site_name].shape)
        pos = next_pos
    assert pos == grads_flat.size(0)
    return grads

In [None]:
D = 1000

true_mean, true_std = 1, 2.

def potential_fn(params):
    return 0.5 * torch.sum(((params['z'] - true_mean) / true_std) ** 2)

# U-turn termination condition
# For demonstration purposes only: this won't result in a correct MCMC proposal
def is_u_turning(q_i, q_f, p_f):
  return torch.dot((q_f['z'] - q_i['z']), p_f['z']) < 0.


# Run leapfrog until termination condition is met
def get_final_state(pe, m_inv, step_size, q_i, p_i):
    q, p = q_i, p_i
    q_grads = None
    while not is_u_turning(q_i, q, p):
      q, p, q_grads, _ = leapfrog(q, p, pe, m_inv, step_size, q_grads=q_grads)
    return (q, p)


In [None]:
q_i = {'z': torch.zeros(D)}
p_i = {'z': torch.randn(D)}
inv_mass_matrix = torch.eye(D)
step_size = 0.001
num_steps = 10000 


In [None]:
%timeit -n 1 -r 1 get_final_state(potential_fn, inv_mass_matrix, step_size, q_i, p_i)

1 loop, best of 1: 1.78 s per loop


### jit, grad, and vmap
Unlike in JAX, the PyTorch JIT cannot yet inline and optimize grad calls. 

In [None]:
torch.jit.trace(lambda q, p: get_final_state(potential_fn, inv_mass_matrix, step_size, q, p), (q_i, p_i))

In addition, some ops used here are not yet supported by `torch.vmap` batching.

In [None]:
K = 50
q_i = {'z': torch.randn(K, D)}
p_i = {'z': torch.randn(K, D)}
torch.vmap(lambda q, p: get_final_state(potential_fn, inv_mass_matrix, step_size, q, p))(q_i, p_i)



RuntimeError: ignored

## 4. Helpful References
1. [NUTS paper](https://arxiv.org/abs/1111.4246)
2. [NUTS JAX implementation](https://github.com/pyro-ppl/numpyro/blob/master/numpyro/infer/hmc.py)
   - [Iterative NUTS in NumPyro (unrolling the recursive algorithm for JIT and vmap)](https://github.com/pyro-ppl/numpyro/wiki/Iterative-NUTS)
   - [Iterative NUTS in TensorFlow Probability](https://github.com/tensorflow/probability/blob/master/discussion/technical_note_on_unrolled_nuts.md)
3. [NUTS PyTorch implementation (FB internal)](https://www.internalfb.com/intern/diffusion/FBS/browse/master/fbcode/beanmachine/beanmachine/ppl/inference/proposer/single_site_no_u_turn_sampler_proposer.py?commit=68b1672d648dba714d3f7c2ce13494b01925b103&lines=18)