diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index e9836c73582f..38b1ab439eb7 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -2442,7 +2442,9 @@ "### `linearize`\n", "\n", "In the case of `linearize`, we want to stage out the linear part of a `jvp`\n", - "computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n", + "computation. That is, in terms of\n", + "[Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n", + "if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`,\n", "then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to\n", "mean \"the tangent type of `a`\" and using the \"lollipop\" `-o` rather than the\n", "arrow `->` to indicate a _linear_ function. We define the semantics of\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 05aa2adc66d8..fa142066c017 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -1895,7 +1895,9 @@ computation. ### `linearize` In the case of `linearize`, we want to stage out the linear part of a `jvp` -computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, +computation. That is, in terms of +[Haskell-like type signatures](https://wiki.haskell.org/Type_signature), +if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the arrow `->` to indicate a _linear_ function. We define the semantics of diff --git a/docs/autodidax.py b/docs/autodidax.py index 643d331c4bc1..8df4801a6f94 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -1884,7 +1884,9 @@ def pprint_xla_call(names: DefaultDict[Var, str], eqn: JaxprEqn) -> PPrint: # ### `linearize` # # In the case of `linearize`, we want to stage out the linear part of a `jvp` -# computation. That is, if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, +# computation. That is, in terms of +# [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), +# if we have `jvp : (a -> b) -> (a, T a) -> (b, T b)`, # then we write `linearize : (a -> b) -> a -> (b, T a -o T b)`, using `T a` to # mean "the tangent type of `a`" and using the "lollipop" `-o` rather than the # arrow `->` to indicate a _linear_ function. We define the semantics of diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst index ae704ed195c9..add946bb2208 100644 --- a/docs/jaxpr.rst +++ b/docs/jaxpr.rst @@ -352,7 +352,8 @@ constructed with the :py:func:`jax.lax.scan` function:: lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B]) -Here ``C`` is the type of the scan carry, ``A`` is the element type of the +This is written in terms of a `Haskell Type Signature`_: +``C`` is the type of the scan carry, ``A`` is the element type of the input array(s), and ``B`` is the element type of the output array(s). For the example consider the function ``func11`` below @@ -474,3 +475,5 @@ parameter. The value of this parameter is a Jaxpr with 2 input variables. The parameter ``in_axes`` specifies which of the input variables should be mapped and which should be broadcast. In our example, the value of ``extra`` is broadcast and the value of ``arr`` is mapped. + +.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature \ No newline at end of file diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index bd04dfa6954b..6e8ab41fc8dc 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -1090,7 +1090,8 @@ "source": [ "### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules\n", "\n", - "Here's a canonical basic example of using `jax.custom_jvp`:" + "Here's a canonical basic example of using `jax.custom_jvp`, where the comments use\n", + "[Haskell-like type signatures](https://wiki.haskell.org/Type_signature):" ] }, { diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index 32de12cd5dd9..75132f69209e 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -570,7 +570,8 @@ A limitation to this approach is that the argument `f` can't close over any valu ### Use `jax.custom_jvp` to define forward-mode (and, indirectly, reverse-mode) rules -Here's a canonical basic example of using `jax.custom_jvp`: +Here's a canonical basic example of using `jax.custom_jvp`, where the comments use +[Haskell-like type signatures](https://wiki.haskell.org/Type_signature): ```{code-cell} ipython3 :id: nVkhbIFAOGZk diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index 526466f6ccef..ab4cf8e8dfd0 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -676,7 +676,8 @@ "id": "m1VJgJYQGfCK" }, "source": [ - "In terms of Haskell-like type signatures, we could write\n", + "In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n", + "we could write\n", "\n", "```haskell\n", "jvp :: (a -> b) -> a -> T a -> (b, T b)\n", @@ -769,7 +770,8 @@ "id": "oVOZexCEkvv3" }, "source": [ - "In terms of Haskell-like type signatures, we could write\n", + "In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature),\n", + "we could write\n", "\n", "```haskell\n", "vjp :: (a -> b) -> a -> (b, CT b -> CT a)\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index d9a9f784b2de..a601217b8d30 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -378,7 +378,8 @@ y, u = jvp(f, (W,), (v,)) +++ {"id": "m1VJgJYQGfCK"} -In terms of Haskell-like type signatures, we could write +In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), +we could write ```haskell jvp :: (a -> b) -> a -> T a -> (b, T b) @@ -451,7 +452,8 @@ v = vjp_fun(u) +++ {"id": "oVOZexCEkvv3"} -In terms of Haskell-like type signatures, we could write +In terms of [Haskell-like type signatures](https://wiki.haskell.org/Type_signature), +we could write ```haskell vjp :: (a -> b) -> a -> (b, CT b -> CT a) diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index b2b81f37e4a9..5f55a68ec938 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -781,8 +781,8 @@ def custom_gradient(fun): and the VJP (gradient) function. See https://www.tensorflow.org/api_docs/python/tf/custom_gradient. - If the mathematical function to be differentiated has type signature ``a -> - b``, then the Python callable ``fun`` should have signature + If the mathematical function to be differentiated has Haskell-like signature + ``a -> b``, then the Python callable ``fun`` should have the signature ``a -> (b, CT b --o CT a)`` where we use ``CT x`` to denote a cotangent type for ``x`` and the ``--o`` arrow to denote a linear function. See the example below. That is, ``fun`` should return a pair where the first element @@ -1001,7 +1001,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, linear_args): """Call a linear function, with a custom implementation for its transpose. - The type signatures of ``fun`` and ``fun_transpose`` are: + The `Haskell-like type signatures`_ of ``fun`` and ``fun_transpose`` are: .. code-block:: haskell @@ -1081,6 +1081,7 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args, Returns: The call result, i.e. ``fun(residual_args, linear_args)``. + .. _Haskell-like type signatures: https://wiki.haskell.org/Type_signature """ operands_res, res_tree = tree_flatten(residual_args) operands_lin, lin_tree = tree_flatten(linear_args) diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 18bcbc7cf36d..2c90d745aeab 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -153,7 +153,7 @@ def scanned_fun(loop_carry, _): def fori_loop(lower, upper, body_fun, init_val): """Loop from ``lower`` to ``upper`` by reduction to :func:`jax.lax.while_loop`. - The type signature in brief is + The `Haskell-like type signature`_ in brief is .. code-block:: haskell @@ -191,6 +191,8 @@ def fori_loop(lower, upper, body_fun, init_val): Returns: Loop value from the final iteration, of type ``a``. + + .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature """ if not callable(body_fun): raise TypeError("lax.fori_loop: body_fun argument should be callable.") @@ -235,7 +237,7 @@ def while_loop(cond_fun: Callable[[T], BooleanNumeric], init_val: T) -> T: """Call ``body_fun`` repeatedly in a loop while ``cond_fun`` is True. - The type signature in brief is + The `Haskell-like type signature`_ in brief is .. code-block:: haskell @@ -275,6 +277,8 @@ def while_loop(cond_fun, body_fun, init_val): Returns: The output from the final iteration of body_fun, of type ``a``. + + .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature """ if not (callable(body_fun) and callable(cond_fun)): raise TypeError("lax.while_loop: body_fun and cond_fun arguments should be callable.") @@ -1354,7 +1358,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], unroll: int = 1) -> Tuple[Carry, Y]: """Scan a function over leading array axes while carrying along state. - The type signature in brief is + The `Haskell-like type signature`_ in brief is .. code-block:: haskell @@ -1422,6 +1426,8 @@ def scan(f, init, xs, length=None): A pair of type ``(c, [b])`` where the first element represents the final loop carry value and the second element represents the stacked outputs of the second output of ``f`` when scanned over the leading axis of the inputs. + + .. _Haskell-like type signature: https://wiki.haskell.org/Type_signature """ if not callable(f): raise TypeError("lax.scan: f argument should be a callable.")