Skip to content

Commit

Permalink
Pass pmap axis specs optionally to make_model_info.
Browse files Browse the repository at this point in the history
This helps outputting correct shapes from jax.lax.all_gather.

PiperOrigin-RevId: 485625883
  • Loading branch information
Haiku Contributor authored and Copybara-Service committed Nov 2, 2022
1 parent a191d4e commit d0ba451
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions haiku/_src/jaxpr_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
import logging
import os
import sys
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Set, Sequence, Optional
from typing import Any, Callable, Dict, List, Mapping, NamedTuple, Set, Sequence, Optional, Tuple

from haiku._src import summarise
import jax
Expand Down Expand Up @@ -128,6 +128,7 @@ def make_model_info(
name: Optional[str] = None,
include_module_info: bool = True,
compute_flops: Optional[ComputeFlopsFn] = None,
axis_env: Optional[Sequence[Tuple[Any, int]]] = None,
) -> Callable[..., Module]:
"""Creates a function that computes flop, param and state information.
Expand All @@ -142,6 +143,7 @@ def make_model_info(
information for haiku modules. Can be slow for very large computations.
compute_flops: Optional, a function that returns an estimate of the number
of flops required to execute an equation.
axis_env: Sizes of pmapped axes. See docs of jax.make_jaxpr for details.
Returns:
A wrapped version of `f` that when applied to example arguments returns a
Expand All @@ -153,7 +155,7 @@ def make_model_info(
"""
if not name:
name = f.__name__
make_jaxpr = jax.make_jaxpr(f)
make_jaxpr = jax.make_jaxpr(f, axis_env=axis_env)
if include_module_info:
# Wrap f in a lambda so eval_summary doesn't try to un-transform it.
# TODO(tomhennigan): remove lambda trick
Expand Down

0 comments on commit d0ba451

Please sign in to comment.