Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions keras/src/utils/backend_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import inspect
import os
import sys
import warnings

from keras.src import backend as backend_module
from keras.src.api_export import keras_export
Expand Down Expand Up @@ -124,14 +125,22 @@ def set_backend(backend):

Example:

```python
import keras

keras.config.set_backend("jax")

del keras
import keras
```
>>> import os
>>> os.environ["KERAS_BACKEND"] = "tensorflow"
>>>
>>> import keras
>>> from keras import ops
>>> type(ops.ones(()))
<class 'tensorflow.python.framework.ops.EagerTensor'>
>>>
>>> keras.config.set_backend("jax")
UserWarning: Using `keras.config.set_backend` is dangerous...
>>> del keras, ops
>>>
>>> import keras
>>> from keras import ops
>>> type(ops.ones(()))
<class 'jaxlib.xla_extension.ArrayImpl'>

⚠️ WARNING ⚠️: Using this function is dangerous and should be done
carefully. Changing the backend will **NOT** convert
Expand All @@ -143,7 +152,7 @@ def set_backend(backend):

This includes any function or class instance that uses any Keras
functionality. All such code needs to be re-executed after calling
`set_backend()` and re-importing the `keras` module.
`set_backend()` and re-importing all imported `keras` modules.
"""
os.environ["KERAS_BACKEND"] = backend
# Clear module cache.
Expand All @@ -164,3 +173,16 @@ def set_backend(backend):
module_name = module_name[module_name.find("'") + 1 :]
module_name = module_name[: module_name.find("'")]
globals()[key] = importlib.import_module(module_name)

warnings.warn(
"Using `keras.config.set_backend` is dangerous and should be done "
"carefully. Already-instantiated objects will not be converted. Thus, "
"any layers / tensors / etc. already created will no longer be usable "
"without errors. It is strongly recommended not to keep around any "
"Keras-originated objects instances created before calling "
"`set_backend()`. This includes any function or class instance that "
"uses any Keras functionality. All such code needs to be re-executed "
"after calling `set_backend()` and re-importing all imported `keras` "
"modules.",
stacklevel=2,
)