Skip to content

Commit

Permalink
Switched from the deprecated defun_with_attributes() to tf.function()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 511782775
  • Loading branch information
superbobry authored and Copybara-Service committed Feb 23, 2023
1 parent ee36d84 commit 26a685e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements-tf.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
tensorflow==2.11.0
tensorflow==2.12.0rc0
tensorflow-probability==0.12.2
5 changes: 2 additions & 3 deletions sonnet/src/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
# pylint: disable=g-direct-tensorflow-import
# Required for specializing `UnrolledLSTM` per device.
from tensorflow.python import context as context_lib
from tensorflow.python.eager import function as function_lib
# pylint: enable=g-direct-tensorflow-import


Expand Down Expand Up @@ -1029,9 +1028,9 @@ def wrapper(*args, **kwargs):
unique_api_name = "{}_{}".format(api_name, uuid.uuid4())
functions = {}
for device, specialization in specializations.items():
functions[device] = function_lib.defun_with_attributes(
functions[device] = tf.function(
specialization,
attributes={
experimental_attributes={
"api_implements": unique_api_name,
"api_preferred_device": device
})
Expand Down

0 comments on commit 26a685e

Please sign in to comment.