Skip to content

Commit

Permalink
DOC: add references for haskell-style signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 13, 2022
1 parent bd659df commit 8a3bfe0
Show file tree
Hide file tree
Showing 10 changed files with 38 additions and 16 deletions.
4 changes: 3 additions & 1 deletion docs/autodidax.ipynb
Expand Up @@ -2442,7 +2442,9 @@
"### `linearize`\n",
"\n",
"In the case of `linearize`, we want to stage out the linear part of a `jvp`\n",
"computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n",
"computation. That is, in terms of\n",
"[Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n",
"if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n",
"then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to\n",
"mean \"the tangent type of `a`\" and using the \"lollipop\" `-o` rather than the\n",
"arrow `->` to indicate a _linear_ function. We define the semantics of\n",
Expand Down
4 changes: 3 additions & 1 deletion docs/autodidax.md
Expand Up @@ -1895,7 +1895,9 @@ computation.
### `linearize`

In the case of `linearize`, we want to stage out the linear part of a `jvp`
computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
computation. That is, in terms of
[Haskell-like type signatures](https://wiki.haskell.org/Type_signature),
if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to
mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the
arrow `->` to indicate a _linear_ function. We define the semantics of
Expand Down
4 changes: 3 additions & 1 deletion docs/autodidax.py
Expand Up @@ -1884,7 +1884,9 @@ def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint:
# ### `linearize`
#
# In the case of `linearize`, we want to stage out the linear part of a `jvp`
# computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
# computation. That is, in terms of
# [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),
# if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,
# then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to
# mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the
# arrow `->` to indicate a _linear_ function. We define the semantics of
Expand Down
5 changes: 4 additions & 1 deletion docs/jaxpr.rst
Expand Up @@ -352,7 +352,8 @@ constructed with the :py:func:`jax.lax.scan` function::

lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])

Here ``C`` is the type of the scan carry, ``A`` is the element type of the
This is written in terms of a `Haskell Type Signature`_:
``C`` is the type of the scan carry, ``A`` is the element type of the
input array(s), and ``B`` is the element type of the output array(s).

For the example consider the function ``func11`` below
Expand Down Expand Up @@ -474,3 +475,5 @@ parameter. The value of this parameter is a Jaxpr with 2 input variables.
The parameter ``in_axes`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast and the value of ``arr`` is mapped.

.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature
3 changes: 2 additions & 1 deletion docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
Expand Up @@ -1090,7 +1090,8 @@
"source": [
"### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n",
"\n",
"Here's a canonical basic example of using `jax.custom_jvp`:"
"Here's a canonical basic example of using `jax.custom_jvp`, where the comments use\n",
"[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):"
]
},
{
Expand Down
3 changes: 2 additions & 1 deletion docs/notebooks/Custom_derivative_rules_for_Python_code.md
Expand Up @@ -570,7 +570,8 @@ A limitation to this approach is that the argument `f` can't close over any valu

### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules

Here's a canonical basic example of using `jax.custom_jvp`:
Here's a canonical basic example of using `jax.custom_jvp`, where the comments use
[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):

```{code-cell} ipython3
:id: nVkhbIFAOGZk
Expand Down
6 changes: 4 additions & 2 deletions docs/notebooks/autodiff_cookbook.ipynb
Expand Up @@ -676,7 +676,8 @@
"id": "m1VJgJYQGfCK"
},
"source": [
"In terms of Haskell-like type signatures, we could write\n",
"In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n",
"we could write\n",
"\n",
"```haskell\n",
"jvp :: (a -> b) -> a -> T a -> (b, T b)\n",
Expand Down Expand Up @@ -769,7 +770,8 @@
"id": "oVOZexCEkvv3"
},
"source": [
"In terms of Haskell-like type signatures, we could write\n",
"In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n",
"we could write\n",
"\n",
"```haskell\n",
"vjp :: (a -> b) -> a -> (b, CT b -> CT a)\n",
Expand Down
6 changes: 4 additions & 2 deletions docs/notebooks/autodiff_cookbook.md
Expand Up @@ -378,7 +378,8 @@ y, u = jvp(f, (W,), (v,))

+++ {"id": "m1VJgJYQGfCK"}

In terms of Haskell-like type signatures, we could write
In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),
we could write

```haskell
jvp :: (a -> b) -> a -> T a -> (b, T b)
Expand Down Expand Up @@ -451,7 +452,8 @@ v = vjp_fun(u)

+++ {"id": "oVOZexCEkvv3"}

In terms of Haskell-like type signatures, we could write
In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),
we could write

```haskell
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
Expand Down
7 changes: 4 additions & 3 deletions jax/_src/custom_derivatives.py
Expand Up @@ -781,8 +781,8 @@ def custom_gradient(fun):
and the VJP (gradient) function. See
https://www.tensorflow.org/api_docs/python/tf/custom_gradient.
If the mathematical function to be differentiated has type signature ``a ->
b``, then the Python callable ``fun`` should have signature
If the mathematical function to be differentiated has Haskell-like signature
``a -> b``, then the Python callable ``fun`` should have the signature
``a -> (b, CT b --o CT a)`` where we use ``CT x`` to denote a cotangent type
for ``x`` and the ``--o`` arrow to denote a linear function. See the example
below. That is, ``fun`` should return a pair where the first element
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
linear_args):
"""Call a linear function, with a custom implementation for its transpose.
The type signatures of ``fun`` and ``fun_transpose`` are:
The `Haskell-like type signatures`_ of ``fun`` and ``fun_transpose`` are:
.. code-block:: haskell
Expand Down Expand Up @@ -1081,6 +1081,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
Returns:
The call result, i.e. ``fun(residual_args, linear_args)``.
.. _Haskell-like type signatures: https://wiki.haskell.org/Type_signature
"""
operands_res, res_tree = tree_flatten(residual_args)
operands_lin, lin_tree = tree_flatten(linear_args)
Expand Down
12 changes: 9 additions & 3 deletions jax/_src/lax/control_flow.py
Expand Up @@ -153,7 +153,7 @@ def scanned_fun(loop_carry, _):
def fori_loop(lower, upper, body_fun, init_val):
"""Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`.
The type signature in brief is
The `Haskell-like type signature`_ in brief is
.. code-block:: haskell
Expand Down Expand Up @@ -191,6 +191,8 @@ def fori_loop(lower, upper, body_fun, init_val):
Returns:
Loop value from the final iteration, of type ``a``.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
if not callable(body_fun):
raise TypeError("lax.fori_loop: body_fun argument should be callable.")
Expand Down Expand Up @@ -235,7 +237,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric],
init_val: T) -> T:
"""Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True.
The type signature in brief is
The `Haskell-like type signature`_ in brief is
.. code-block:: haskell
Expand Down Expand Up @@ -275,6 +277,8 @@ def while_loop(cond_fun, body_fun, init_val):
Returns:
The output from the final iteration of body_fun, of type ``a``.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
if not (callable(body_fun) and callable(cond_fun)):
raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.")
Expand Down Expand Up @@ -1354,7 +1358,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
unroll: int = 1) -> Tuple[Carry, Y]:
"""Scan a function over leading array axes while carrying along state.
The type signature in brief is
The `Haskell-like type signature`_ in brief is
.. code-block:: haskell
Expand Down Expand Up @@ -1422,6 +1426,8 @@ def scan(f, init, xs, length=None):
A pair of type ``(c, [b])`` where the first element represents the final
loop carry value and the second element represents the stacked outputs of
the second output of ``f`` when scanned over the leading axis of the inputs.
.. _Haskell-like type signature: https://wiki.haskell.org/Type_signature
"""
if not callable(f):
raise TypeError("lax.scan: f argument should be a callable.")
Expand Down

0 comments on commit 8a3bfe0

Please sign in to comment.