From 08bf20009602bdcae8f7e4b760ae56f41204da62 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Fri, 29 Mar 2024 11:19:19 -0700 Subject: [PATCH] Fix for JAX export on GPU. --- keras/export/export_lib.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/export/export_lib.py b/keras/export/export_lib.py index e0ca00e7d85..a539dad97aa 100644 --- a/keras/export/export_lib.py +++ b/keras/export/export_lib.py @@ -320,7 +320,8 @@ def stateless_fn(variables, *args, **kwargs): def stateful_fn(*args, **kwargs): return jax2tf_stateless_fn( - self._tf_trackable.variables, *args, **kwargs + # Change the trackable `ListWrapper` to a plain `list` + list(self._tf_trackable.variables), *args, **kwargs ) # Note: we truncate the number of parameters to what is