
[BUG] bm.for_loop receives jit param but doesn't handle it #802

@Routhleck

brainpy.math.for_loop accepts a jit parameter, which the docstring describes as "Whether to just-in-time compile the function."

def for_loop(
    body_fun: Callable,
    operands: Any,
    reverse: bool = False,
    unroll: int = 1,
    remat: bool = False,
    jit: Optional[bool] = None,
    progress_bar: bool = False,
    unroll_kwargs: Optional[Dict] = None,
):

But the body of the function never uses it:

    if not isinstance(operands, (tuple, list)):
        operands = (operands,)
    return brainstate.transform.for_loop(
        warp_to_no_state_input_output(body_fun),
        *operands, reverse=reverse, unroll=unroll,
        pbar=brainstate.transform.ProgressBar() if progress_bar else None,
    )

Perhaps wrapping the call in a with jax.disable_jit(): block when jit=False would be a solution.
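
A rough sketch of that idea, not the actual brainpy fix. It uses jax.lax.scan as a stand-in for brainstate.transform.for_loop, because I have not checked how the brainstate version behaves under jax.disable_jit(). The name for_loop_sketch and the choice to treat jit=None the same as jit=True are assumptions made for illustration only.

```python
import contextlib
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp


def for_loop_sketch(body_fun: Callable, operands: Any, jit: Optional[bool] = None):
    """Hypothetical stand-in for bm.for_loop that actually honors `jit`."""
    if not isinstance(operands, (tuple, list)):
        operands = (operands,)

    def step(carry, xs):
        # No carried state in this sketch; only the per-step outputs are stacked.
        return carry, body_fun(*xs)

    # Assumption: only an explicit jit=False disables compilation;
    # jit=None and jit=True keep JAX's default (traced/compiled) behavior.
    ctx = jax.disable_jit() if jit is False else contextlib.nullcontext()
    with ctx:
        _, ys = jax.lax.scan(step, None, tuple(operands))
    return ys


calls = []

def body(x):
    calls.append(x)  # Python side effect, runs once per actual Python call
    return x * 2.0

# With jit=False the loop runs eagerly, so body is called once per element.
eager_out = for_loop_sketch(body, jnp.arange(3.0), jit=False)
eager_calls = len(calls)

# With the default, body is traced rather than called per element.
default_out = for_loop_sketch(body, jnp.arange(3.0))
```

With jit=False the body sees concrete values on every step, which is what makes print-debugging or pdb usable inside body_fun. The real fix would need the same check around the brainstate.transform.for_loop call.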


There are also some unused parameters, such as remat and unroll_kwargs, which should be removed in the future.

Labels: bug