##### Copyright 2020 Google LLC.
Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# jax2tex
Sam Schoenholz

When debugging JAX code, it is sometimes desirable to understand what is going on inside a transformed function. Right now, as far as I know, the only ways to do this are to look at the jaxpr of the transformed function, to look at the XLA HLO, or to step through in a debugger. However, I have often found none of these options particularly satisfactory. For mathematical functions, all of these op-by-op representations throw away a lot of useful structure that we, as humans, rely on to understand calculations (such as the grouping and structure of terms in equations). After all, we have spent thousands of hours training our visual faculties to understand math. Why throw that away?

I have recently found it useful to look at representations of functions that are closer to the mathematical expressions we are familiar with. It has been especially nice to decide on semantic groupings of variables that make it easier to understand the flow of calculations. To that end, I've been playing around with a small (three-function) library, `jax2tex`, that converts JAX-traceable functions to LaTeX. This notebook contains some examples showing how the library works.

There are many different choices that one could make when transcribing JAX functions to LaTeX. Here are a few decisions that I made that others might disagree with.

1. I have chosen a style that is very explicit in terms of indexing. This is to make it clear what the numpy code is doing.
2. There is some ambiguity and clutter due to numpy's flexibility re: singleton dimensions. Therefore, I suppress singleton dimensions (e.g. dot products over dimensions of size one are elided).
3. To further remove clutter, I remove ops that affect the computation without affecting the math (e.g. type conversion).
4. To generate effective LaTeX, you will likely need to hand-annotate the function rather than trying to do the annotation automatically. After all, people have carefully decided how to group terms, and the grouping is usually very problem-specific.
5. Notationally: the prefixes $d$ and $\delta$ refer to tangent vectors in the forward pass and cotangent vectors in the backward pass respectively.
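
To make the notation in item 5 concrete, here is the convention written out for a generic function $y = f(x)$. The symbols $f$, $x$, and $y$ are placeholders chosen for illustration and do not come from the library:

$$
dy_{i} = \sum_{j}{\partial f_{i} \over \partial x_{j}}dx_{j} \qquad \text{(forward pass, tangent)}\\
\delta x_{j} = \sum_{i}\delta y_{i}{\partial f_{i} \over \partial x_{j}} \qquad \text{(backward pass, cotangent)}
$$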

There are still several TODOs that I think would be nice to add.

1. Most ops are still not implemented. Let me know if you would like to use jax2tex but are blocked by a specific op.
2. It would be nice to do some very simple algebraic simplification to further reduce clutter.
3. If people start using this seriously, it should be more rigorously tested.

With that said, let's continue to the demo:

In [None]:
# GitHub no longer supports Subversion access, so clone the repository with git.
!git clone --depth 1 https://github.com/google-research/google-research.git
!pip install google-research/jax2tex/

In [None]:
import jax.numpy as jnp
import jax

import jax2tex as j2t

Let us first define a simple scalar function of two variables:

In [None]:
def f(x, y):
  return x * (x - y) / (x + y) + y

We can then look at the jaxpr for the function. For now we'll use scalar dummy inputs of `x = 1` and `y = 1`.

In [None]:
from jax import make_jaxpr

print(make_jaxpr(f)(1., 1.))

{ lambda  ; a b.
  let c = sub a b
      d = mul a c
      e = add a b
      f = div d e
      g = add f b
  in (g,) }
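
Read top to bottom, the jaxpr is a flat list of primitive operations. The following plain-Python sketch replays those five steps by hand and checks that they agree with the original one-line expression. The helper name `f_op_by_op` and the extra test points are assumptions made for illustration; they are not part of JAX or `jax2tex`.

```python
def f(x, y):
  return x * (x - y) / (x + y) + y

def f_op_by_op(a, b):
  # Hypothetical helper: one line per equation of the jaxpr above.
  c = a - b   # sub
  d = a * c   # mul
  e = a + b   # add
  q = d / e   # div (this intermediate is named `f` in the jaxpr)
  g = q + b   # add
  return g

# Both versions perform the same floating-point operations in the same order.
for x, y in [(1., 1.), (2., 3.), (-0.5, 4.)]:
  assert f(x, y) == f_op_by_op(x, y)
```

The two definitions compute the same thing, but only the first one keeps the grouping of terms visible.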


We can also look at the LaTeX:

In [None]:
print(j2t.jax2tex(f, 1, 1))

f &= {x\left(x - y\right) \over x + y} + y


$$f = {x\left(x - y\right) \over x + y} + y$$

It is frequently useful to define intermediate variables:

In [None]:
def f(x, y):
  z = j2t.tex_var(x + y, 'z')
  return x * (x - y) / z + y

print(j2t.jax2tex(f, 1, 1))

z &= x + y\\
f &= {x\left(x - y\right) \over z} + y


$$z = x + y\\
f = {x\left(x - y\right) \over z} + y$$

Now let's use this to get insight into what is happening inside a gradient calculation.

In [None]:
from jax import grad

@j2t.bind_names
def f(x, y):
  return x * (x - y) / (x + y) + y

print(j2t.jax2tex(grad(f), 1., 1.))

\delta x &= -1.0{\left(x + y\right)}^{-2}x\left(x - y\right) + {1.0 \over x + y}x + {1.0 \over x + y}\left(x - y\right)


$$
\delta x = -1.0{\left(x + y\right)}^{-2}x\left(x - y\right) + {1.0 \over x + y}x + {1.0 \over x + y}\left(x - y\right)$$

Here we see the product rule, with the three terms corresponding to the derivatives of $(x + y)^{-1}$, $x - y$, and $x$ respectively.
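
One way to convince yourself that the printed expression is right is to transcribe it by hand and compare it against a numerical derivative. The sketch below does this in plain Python. The helper names `delta_x` and `central_difference`, the step size, and the test points are assumptions made for illustration.

```python
def f(x, y):
  return x * (x - y) / (x + y) + y

def delta_x(x, y):
  # Term-by-term transcription of the printed expression for \delta x.
  return (-1.0 * (x + y) ** -2 * x * (x - y)   # from the derivative of (x + y)^{-1}
          + 1.0 / (x + y) * x                  # from the derivative of (x - y)
          + 1.0 / (x + y) * (x - y))           # from the derivative of x

def central_difference(fn, x, y, eps=1e-6):
  # Numerical estimate of the partial derivative of fn with respect to x.
  return (fn(x + eps, y) - fn(x - eps, y)) / (2 * eps)

for x, y in [(1., 1.), (2., 3.), (0.5, 4.)]:
  assert abs(delta_x(x, y) - central_difference(f, x, y)) < 1e-6
```

At the dummy inputs `x = 1`, `y = 1` the first and third terms vanish because $x - y = 0$, leaving only $x / (x + y)$.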

Now we will consider a slightly more realistic example: a simple one-hidden-layer MLP with ReLU activation functions.

In [None]:
@j2t.bind_names
def mlp(U, V, x):
  z1 = j2t.tex_var(x @ U, 'z^1')
  y1 = j2t.tex_var(jax.nn.relu(z1), 'y^1')
  return j2t.tex_var(y1 @ V, 'z^2')

Now we will evaluate the function using dummy inputs and weights. Here we take the input to have shape `[3,]`, with read-in weights of shape `[3, 2]` and read-out weights of shape `[2,]` that produce a scalar output.

In [None]:
print(j2t.jax2tex(mlp, jnp.ones((3, 2)), jnp.ones((2,)), jnp.ones((3,))))

z^1_{i} &= \sum_{j}x_{j}U_{ji}\\
y^1_{i} &= \text{relu}(z^1_{i})\\
z^2 &= \sum_{i}y^1_{i}V_{i}


$$
z^1_{i} = \sum_{j}x_{j}U_{ji}\\
y^1_{i} = \text{relu}(z^1_{i})\\
z^2 = \sum_{i}y^1_{i}V_{i}
$$

Here we see equations defining an MLP that look a lot like what one might find in a textbook. Notice that, as discussed above, we are very explicit about indices and summation. Next we can look at the gradient with respect to the weights $U$ and $V$:

In [None]:
print(j2t.jax2tex(grad(mlp, argnums=(0, 1)), jnp.ones((3, 2)), jnp.ones((2,)), jnp.ones((3,))))

z^1_{i} &= \sum_{j}x_{j}U_{ji}\\
\delta y^1_{i} &= 1.0V_{i}\\
\delta z^1_{i} &= \mathbbm 1_{z^1_{i}>0.0}\delta y^1_{i} + \left(1 - \mathbbm 1_{z^1_{i}>0.0}\right)0\\
\delta U_{ij} &= \delta z^1_{j}x_{i}\\
y^1_{i} &= \text{relu}(z^1_{i})\\
\delta V_{i} &= 1.0y^1_{i}


$$
z^1_{i} = \sum_{j}x_{j}U_{ji}\\
\delta y^1_{i} = 1.0V_{i}\\
\delta z^1_{i} = 1_{z^1_{i}>0.0}\delta y^1_{i} + \left(1 - 1_{z^1_{i}>0.0}\right)0\\
\delta U_{ij} = \delta z^1_{j}x_{i}\\
y^1_{i} = \text{relu}(z^1_{i})\\
\delta V_{i} = 1.0y^1_{i}
$$

Here we can see the forward pass (computing $z^1$ and $y^1$) as well as the backward pass (computing $\delta y^1$, $\delta z^1$, $\delta U$, and $\delta V$) written out in a relatively comprehensible form. 
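
As a sanity check on the printed backward pass, the sketch below transcribes each equation into plain Python with explicit loops, so that every loop index matches an index in the LaTeX, and compares the result against central finite differences. The helper names (`forward`, `mlp_grads`, `perturbed`), the weight and input values, and the step size are assumptions made for illustration and are not part of `jax2tex`. The values are chosen so that no $z^1_{i}$ sits at the ReLU kink.

```python
import copy

def relu(v):
  return max(v, 0.0)

def forward(U, V, x):
  # z^1_i = sum_j x_j U_ji,  y^1_i = relu(z^1_i),  z^2 = sum_i y^1_i V_i
  z1 = [sum(x[j] * U[j][i] for j in range(len(x))) for i in range(len(V))]
  y1 = [relu(z) for z in z1]
  z2 = sum(y1[i] * V[i] for i in range(len(V)))
  return z1, y1, z2

def mlp_grads(U, V, x):
  z1, y1, _ = forward(U, V, x)
  # Backward pass: one line per printed equation.
  dy1 = [1.0 * V[i] for i in range(len(V))]
  dz1 = [dy1[i] if z1[i] > 0.0 else 0.0 for i in range(len(V))]
  dU = [[dz1[j] * x[i] for j in range(len(V))] for i in range(len(x))]
  dV = [1.0 * y1[i] for i in range(len(V))]
  return dU, dV

def perturbed(M, index, eps):
  # Copy of a (possibly nested) list with a single entry shifted by eps.
  out = copy.deepcopy(M)
  if len(index) == 2:
    out[index[0]][index[1]] += eps
  else:
    out[index[0]] += eps
  return out

# Illustrative values with shapes [3, 2], [2,], and [3,].
U = [[0.5, -1.0], [0.25, 0.75], [-0.5, 0.1]]
V = [2.0, -3.0]
x = [1.0, 2.0, 3.0]
dU, dV = mlp_grads(U, V, x)

eps = 1e-6
for i in range(len(x)):
  for j in range(len(V)):
    fd = (forward(perturbed(U, (i, j), eps), V, x)[2]
          - forward(perturbed(U, (i, j), -eps), V, x)[2]) / (2 * eps)
    assert abs(fd - dU[i][j]) < 1e-6
for i in range(len(V)):
  fd = (forward(U, perturbed(V, (i,), eps), x)[2]
        - forward(U, perturbed(V, (i,), -eps), x)[2]) / (2 * eps)
  assert abs(fd - dV[i]) < 1e-6
```

Note that the second term in the printed equation for $\delta z^1_{i}$ is multiplied by zero, which is why the transcription reduces to a simple conditional.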

A final feature that I have found useful is to define variables that are functions of some argument. Consider the following pair of functions:

In [None]:
def g(x):
  return j2t.tex_var(x ** 2, 'z')

def f(x, y):
  return g(x) * g(y)

print(j2t.jax2tex(f, 1., 1.))

z &= {x}^{2}\\
z &= {y}^{2}\\
f &= zz


$$
z = x^{2}\\
z = y^{2}\\
f = zz
$$

We see that the variable $z$ is ambiguous and its dependence on $x$ and $y$ is suppressed. We can deal with this case properly by using `depends_on` to annotate that $z$ depends on its argument:

In [None]:
def g(x):
  return j2t.tex_var(x ** 2, 'z', depends_on=x)

def f(x, y):
  return g(x) * g(y)

print(j2t.jax2tex(f, 1., 1.))

z(x) &= {x}^{2}\\
z(y) &= {y}^{2}\\
f(x,y) &= z(x)z(y)


$$
z(x) = x^{2}\\
z(y) = y^{2}\\
f(x,y) = z(x)z(y)
$$

Now we see that the two instances of $z$ depend explicitly on their arguments.