Skip to content

Commit

Permalink
Use num_tokens consistently
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 538941855
  • Loading branch information
jax authors committed Jun 9, 2023
1 parent 492fd4b commit 00f2a8c
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,8 @@ def lower_jaxpr_to_fun(
corresponding output that should alias them.
api_name: The name of the higher level primitive which should show up in the
name stack.
Returns the name of the function.
Returns:
MLIR func op
"""
def aval_to_types(aval):
if replace_tokens_with_dummy and aval is core.abstract_token:
Expand All @@ -847,10 +848,10 @@ def aval_to_types(aval):
# MLIR function.
output_token_types = []
token_types = [token_type() for _ in effects]
token_avals = [core.AbstractToken] * len(effects)
token_avals = [core.AbstractToken] * num_tokens
input_avals = dim_var_avals + token_avals + jaxpr.in_avals
input_types = [*dim_var_types, *token_types, *input_types]
output_avals = [core.AbstractToken] * (len(output_token_types) + len(token_types)) + jaxpr.out_avals
output_avals = [core.AbstractToken] * (len(output_token_types) + num_tokens) + jaxpr.out_avals
output_types = [*output_token_types, *token_types, *output_types]
if input_output_aliases is not None:
token_input_output_aliases = [None] * (num_dim_vars + num_tokens)
Expand Down

0 comments on commit 00f2a8c

Please sign in to comment.