From 8bcf358fdef98771ae66bd61182ed803b01e46bd Mon Sep 17 00:00:00 2001 From: Roman Ring Date: Mon, 26 Sep 2022 17:14:09 +0100 Subject: [PATCH] Remove unused _remat_static_argnums import. --- jax/_src/api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 101a72411b98..5cefee55af49 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -73,7 +73,6 @@ process_count, host_id, host_ids, host_count, default_backend) from jax.ad_checkpoint import checkpoint_policies, checkpoint as new_checkpoint -from jax._src.ad_checkpoint import _remat_static_argnums from jax.core import ShapedArray, raise_to_shaped from jax.custom_batching import custom_vmap from jax.custom_derivatives import (closure_convert, custom_gradient, custom_jvp,