diff --git a/keras/export/export_lib.py b/keras/export/export_lib.py index e0ca00e7d85..a539dad97aa 100644 --- a/keras/export/export_lib.py +++ b/keras/export/export_lib.py @@ -320,7 +320,8 @@ def stateless_fn(variables, *args, **kwargs): def stateful_fn(*args, **kwargs): return jax2tf_stateless_fn( - self._tf_trackable.variables, *args, **kwargs + # Change the trackable `ListWrapper` to a plain `list` + list(self._tf_trackable.variables), *args, **kwargs ) # Note: we truncate the number of parameters to what is