To get peak performance and to make your model deployable, use `tf.function` to turn your programs into graphs. Some tips for using `tf.function`:
* Don't rely on Python side effects such as object mutation or list appends.
* `tf.function` works best with TensorFlow ops rather than NumPy ops or Python primitives.
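The second tip is easy to see with random numbers. The sketch below assumes TensorFlow 2.x; the names `numpy_noise` and `tf_noise` are illustrative and do not come from the original. The NumPy call runs in Python only while the function is being traced, so its value is frozen into the graph. The TensorFlow op, by contrast, runs on every call.

```python
import numpy as np
import tensorflow as tf

@tf.function
def numpy_noise():
  # np.random.rand() runs in Python during tracing only; the result
  # is baked into the graph as a constant.
  return tf.constant(np.random.rand(), dtype=tf.float32)

@tf.function
def tf_noise():
  # tf.random.uniform is a graph op, so it runs on every call.
  return tf.random.uniform([])

numpy_samples = [float(numpy_noise()) for _ in range(3)]
tf_samples = [float(tf_noise()) for _ in range(3)]
print(numpy_samples)  # the same number three times
print(tf_samples)     # (almost surely) three different numbers
```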

A function decorated with `tf.function` behaves like a core TensorFlow op. `tf.py_function`, in contrast, wraps an ordinary Python function so that it can be called as a TensorFlow op.

In [0]:
!pip install -q tf-nightly

In [2]:
import tensorflow as tf
import numpy as np
import traceback
import contextlib

print("Tensorflow Version: {}".format(tf.__version__))
print("Eager Mode: {}".format(tf.executing_eagerly()))
print("GPU {} available".format("is" if tf.config.experimental.list_physical_devices("GPU") else "not"))

Tensorflow Version: 2.2.0-dev20200119
Eager Mode: True
GPU is available


Before we start, let's define a helper function that demonstrates the kinds of errors you may encounter.

In [0]:
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print("Caught excepted exception: {}".format(error_class))
  except Exception as e:
    raise e
  else:
    raise Exception("Expected {} to be raised but no error was raised.".format(error_class))

# Basics

A `tf.function` behaves just like a core TensorFlow op: you can call it eagerly, use it inside other graphs, compute gradients through it, and so on.

In [0]:
@tf.function
def add(a, b):
  return a + b

In [5]:
add(tf.ones(shape=[2, 2]), tf.ones(shape=[2, 2]))

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
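Here is a minimal sketch that makes the eager/graph distinction concrete. It assumes TensorFlow 2.3 or later, where `tf.config.run_functions_eagerly` is available; the notebook itself ran a 2.2 nightly. The sketch first inspects the traced graph through a concrete function. It then temporarily runs the same Python body eagerly, which can be handy for debugging.

```python
import numpy as np
import tensorflow as tf

@tf.function
def add(a, b):
  return a + b

x = tf.ones([2, 2])

# By default, calling the function runs its traced graph.
graph_result = add(x, x)

# The traced graph is available through a concrete function.
concrete = add.get_concrete_function(x, x)
print(len(concrete.graph.inputs))  # 2: one placeholder each for a and b

# For debugging, run the Python body eagerly instead of the graph.
tf.config.run_functions_eagerly(True)
eager_result = add(x, x)
tf.config.run_functions_eagerly(False)

print(np.array_equal(graph_result.numpy(), eager_result.numpy()))  # True
```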

You can also compute gradients through a `tf.function`.

In [6]:
val = tf.Variable(2.0)

with tf.GradientTape() as tape:
  y = add(val, 10.0)  # y = val + 10
tape.gradient(y, val)

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can nest `tf.function`s: one `tf.function` can call another.

In [0]:
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

In [8]:
dense_layer(tf.ones(shape=[2, 3]), tf.ones(shape=[3,2]), tf.ones(shape=[2]))

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[4., 4.],
       [4., 4.]], dtype=float32)>
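Gradients also flow through nested `tf.function`s. The sketch below repeats the `add` and `dense_layer` definitions from above so that it runs on its own, and it uses an all-ones input. Under those assumptions, the gradient of the summed output is 2 for every weight, because the input has two rows. It is also 2 for every bias entry, because the bias is broadcast over those two rows.

```python
import tensorflow as tf

@tf.function
def add(a, b):
  return a + b

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)  # nested tf.function call

x = tf.ones([2, 3])
w = tf.Variable(tf.ones([3, 2]))
b = tf.Variable(tf.zeros([2]))

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(dense_layer(x, w, b))

grad_w, grad_b = tape.gradient(loss, [w, b])
print(grad_w.numpy().tolist())  # [[2.0, 2.0], [2.0, 2.0], [2.0, 2.0]]
print(grad_b.numpy().tolist())  # [2.0, 2.0]
```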

You can also pass NumPy arrays or Python scalars as arguments.

In [9]:
add(np.ones(shape=[3, 3]), np.ones(shape=[3, 3]))

<tf.Tensor: shape=(3, 3), dtype=float64, numpy=
array([[2., 2., 2.],
       [2., 2., 2.],
       [2., 2., 2.]])>

In [10]:
add(2.0, 3.0).numpy()



5.0

# Tracing and Polymorphism

Python's dynamic typing lets you call a function with arguments of many different types, and the Python runtime behaves differently for each. TensorFlow graphs, however, require static dtypes and shapes. `tf.function` bridges this gap by tracing a separate concrete graph for each new input signature. **That is, a `tf.function` accepts inputs of many types, but each call runs as a TensorFlow graph.**

In [0]:
@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

In [12]:
print(double(tf.constant(1)))
print(double(tf.constant(1.1)))
print(double(tf.constant("s")))

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)
Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'ss', shape=(), dtype=string)
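One way to watch the tracing cache at work is to record every trace with a Python side effect. In this sketch, the list `traced_dtypes` is an illustrative helper. A second call with the same dtype and shape reuses the existing trace, while a new dtype triggers a new one.

```python
import tensorflow as tf

traced_dtypes = []

@tf.function
def double(a):
  traced_dtypes.append(a.dtype)  # Python side effect: runs once per trace
  return a + a

double(tf.constant(1))    # int32 scalar: first trace
double(tf.constant(2))    # same dtype and shape: trace reused
double(tf.constant(1.5))  # float32 scalar: new trace
double(tf.constant(2.5))  # trace reused

print(traced_dtypes)  # [tf.int32, tf.float32]
```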


In [13]:
print("Obtaining a concrete function.")
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))

print("Executing traced function.")
print(double_strings(tf.constant("s")))
print(double_strings(tf.constant("x")))

print("Using a concrete trace with incompatible data types will cause an error.")
with assert_raises(tf.errors.InvalidArgumentError):
  print(double_strings(tf.constant(1.1)))

Obtaining a concrete function.
Tracing with Tensor("a:0", dtype=string)
Executing traced function.
tf.Tensor(b'ss', shape=(), dtype=string)
tf.Tensor(b'xx', shape=(), dtype=string)
Using a concrete trace with incompatible data types will cause an error.
Caught excepted exception: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>


You can also declare the input signature directly in the `tf.function` decorator. The signature constrains both the dtype and the shape of the inputs.

In [14]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32), ))
def next_collect(x):
  print("Tracing with", x)
  return tf.where(x % 2 == 0, x // 2, x * 3 + 1)

print(next_collect(tf.constant([1, 2])))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)


In [15]:
print("This causes an error because the shape is incompatible.")
with assert_raises(ValueError):
  print(next_collect([[1, 2], [3, 4]]))

This causes an error because the shape is incompatible.
Caught excepted exception: <class 'ValueError'>
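An `input_signature` also limits retracing. Because the shape is declared as `[None]`, inputs of any length share a single trace. In this sketch, `next_collatz` is an illustrative variant of `next_collect` above, and the list `traces` counts traces.

```python
import tensorflow as tf

traces = []

@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  traces.append(x.shape)  # Python side effect: runs once per trace
  return tf.where(x % 2 == 0, x // 2, x * 3 + 1)

short = next_collatz(tf.constant([1, 2]))
longer = next_collatz(tf.constant([3, 4, 5]))
print(short.numpy().tolist())   # [4, 1]
print(longer.numpy().tolist())  # [10, 2, 16]
print(len(traces))              # 1: the [None] signature covers both lengths
```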


# Python or Tensor Args

Python arguments are often used to control hyperparameters and graph construction. However, `tf.function` retraces whenever it sees a new value of a Python argument. If the argument does not actually need to change the graph, this produces many traces of essentially identical graphs, which can hurt performance.

In [0]:
def train_step():
  pass

@tf.function
def train(num_steps):
  print("Tracing with num_step: {}".format(num_steps))
  for _ in tf.range(num_steps):
    train_step()

In the following example, each new Python value of `num_steps` triggers a new trace, even though the resulting graphs are identical.

In [17]:
train(num_steps=10)
train(num_steps=20)

Tracing with num_step: 10
Tracing with num_step: 20


In this example, you can simply pass the argument as a Tensor instead. Tensors with the same dtype and shape share one trace, so the function is traced only once.

In [18]:
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

Tracing with num_step: Tensor("num_steps:0", shape=(), dtype=int32)
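The difference can be counted directly. This sketch records each trace; the names `count_up` and `traced_with` are illustrative. Two Python integers produce two traces, while two scalar `int32` tensors share one.

```python
import tensorflow as tf

traced_with = []

@tf.function
def count_up(num_steps):
  traced_with.append(num_steps)  # Python side effect: runs once per trace
  total = tf.constant(0)
  for i in tf.range(num_steps):  # Tensor loop: converted to tf.while_loop
    total += i
  return total

count_up(10)
count_up(20)                        # new Python value: retraces
count_up(tf.constant(10))           # first Tensor call: one more trace
result = count_up(tf.constant(20))  # same dtype and shape: trace reused

print(len(traced_with))  # 3
print(int(result))       # 190 == 0 + 1 + ... + 19
```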


# Side Effects

In general, Python side effects (e.g., printing or mutating objects) happen only during tracing. To have an effect occur on every call, use TensorFlow ops such as `tf.Variable.assign`, `tf.print`, or `tf.summary`.

In [19]:
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)  # runs on every call, not only during tracing

f(1)
f(1)
f(2)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
Executed with 2
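The same contrast applies to state. In this sketch, `trace_log`, `call_count`, and `step` are illustrative names. The Python list is appended to only once, during tracing. The `tf.Variable` update, however, is a graph op that runs on each of the three calls.

```python
import tensorflow as tf

trace_log = []
call_count = tf.Variable(0)

@tf.function
def step():
  trace_log.append("traced")  # Python side effect: runs only while tracing
  call_count.assign_add(1)    # TensorFlow op: runs on every call

for _ in range(3):
  step()

print(len(trace_log))      # 1
print(call_count.numpy())  # 3
```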


If you want to run Python code on every call of a `tf.function`, `tf.py_function` is an escape hatch. However, `tf.py_function` is not portable, is not particularly performant, and does not work well in distributed setups.

In [20]:
external_list = []

def python_pri(x):
  print("python primitive effects")
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(python_pri, inp=[x], Tout=[])

f(1)
f(2)
f(3)
assert len(external_list) == 3
assert external_list[-1].numpy() == 3

python primitive effects
python primitive effects
python primitive effects


# Beware of Python State

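Be careful with Python state such as iterators and generators. Python code runs only while the function is being traced. In the code below, `next(iterator)` is therefore evaluated once, during the first trace, and its value (1) is baked into the graph as a constant. Later calls reuse that graph, so every call adds 1 instead of consuming the next element of the iterator.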

In [0]:
external_var = tf.Variable(0)

@tf.function
def consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("The value of the external variable: ", external_var)

In [22]:
iterator = iter([1, 2, 3, 4])
consume_next(iterator)  # prints 1
consume_next(iterator)  # prints 2, not 3: the traced value 1 is added again
consume_next(iterator)  # prints 3, not 6

The value of the external variable:  1
The value of the external variable:  2
The value of the external variable:  3
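A safer variant of the same pattern passes a `tf.data` iterator instead of a Python one. In that case `next()` becomes a graph op that fetches a new element on every call, so the variable accumulates 1, 3, 6 as originally intended. A minimal sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf

external_var = tf.Variable(0)

@tf.function
def consume_next(iterator):
  # With a tf.data iterator, next() is a graph op that runs on every call.
  external_var.assign_add(next(iterator))

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4])
iterator = iter(dataset)
consume_next(iterator)  # external_var == 1
consume_next(iterator)  # external_var == 3
consume_next(iterator)  # external_var == 6
print(external_var.numpy())  # 6
```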


If you want to iterate over Python data inside a `tf.function`, the safest way is to wrap it in a `tf.data.Dataset` and loop over the dataset with a `for` loop. The following helper counts the nodes in a function's graph.

In [0]:
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

In [0]:
@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x)
  return loss

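Iterating over a Python list unrolls the loop, which adds a copy of the loop body to the graph for every element. Iterating over a `tf.data.Dataset` instead converts the loop into a dataset op, so the graph size does not depend on the amount of data.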

In [25]:
small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10

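# Python lists: the loop is unrolled, so the graph grows with the data
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

# tf.data.Dataset: the loop becomes a dataset op, so the graph size stays constant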
measure_graph_size(train, tf.data.Dataset.from_generator(
  lambda: small_data, (tf.int32, tf.int32)
))
measure_graph_size(train, tf.data.Dataset.from_generator(
  lambda: big_data, (tf.int32, tf.int32)
))

train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph


When loading Python/NumPy data into a `tf.data.Dataset`, two common APIs are `tf.data.Dataset.from_generator` and `tf.data.Dataset.from_tensors`. The former keeps the data in Python and fetches it via `tf.py_function`, which can have performance implications. The latter bundles a copy of the data as one large node in the graph, which can have memory implications.
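Both approaches can feed the same training loop. The sketch below assumes TensorFlow 2.4 or later for the `output_signature` argument of `from_generator`. It uses `tf.data.Dataset.from_tensor_slices`, the per-element counterpart of `from_tensors`; `from_tensors` would instead produce a single element holding all the data. `total_abs_diff` and the sample data are illustrative.

```python
import tensorflow as tf

data = [(1, 2), (3, 1), (5, 5)]
xs = [x for x, _ in data]
ys = [y for _, y in data]

# Like from_tensors, this keeps a copy of the data inside TensorFlow,
# but it yields one (x, y) pair per element.
sliced = tf.data.Dataset.from_tensor_slices((xs, ys))

# The data stays in Python; output_signature describes each element.
generated = tf.data.Dataset.from_generator(
    lambda: data,
    output_signature=(tf.TensorSpec(shape=[], dtype=tf.int32),
                      tf.TensorSpec(shape=[], dtype=tf.int32)))

@tf.function
def total_abs_diff(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x)
  return loss

print(int(total_abs_diff(sliced)))     # 3 == |2-1| + |1-3| + |5-5|
print(int(total_abs_diff(generated)))  # 3
```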

# Automatic Control Dependencies

You don't need to add control dependencies manually. `tf.function` inserts them automatically so that stateful operations, such as variable reads and assignments, run in the order in which they appear in the code.

In [26]:
a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)  # == 6.0
  b.assign_add(x * a)  # == 12 + 2
  return a + b  # == 20.0

f(2.0, 3.0)

<tf.Tensor: shape=(), dtype=float32, numpy=20.0>
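Automatic control dependencies also order reads relative to writes. In this sketch, `write_then_read` is an illustrative name. Each `read_value()` observes the assignment that precedes it in the code.

```python
import tensorflow as tf

v = tf.Variable(1.0)

@tf.function
def write_then_read():
  v.assign(5.0)
  after_assign = v.read_value()  # ordered after the assign
  v.assign_add(1.0)
  after_add = v.read_value()     # ordered after the assign_add
  return after_assign, after_add

first, second = write_then_read()
print(float(first), float(second))  # 5.0 6.0
```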

# Variables

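Eager and graph execution treat variable creation differently. In eager mode, a function that creates a `tf.Variable` gets a new variable on every call. A traced graph, however, would reuse the same variable across calls, so this code is ambiguous. For that reason, `tf.function` raises a `ValueError` when a function tries to create a variable on any call other than the first.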

In [27]:
@tf.function
def f(x):
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

with assert_raises(ValueError):
  f(1.0)

Caught excepted exception: <class 'ValueError'>


Creating the variable outside the function removes the ambiguity, and the code works in both modes.

In [28]:
v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

print(v.numpy())
f(1)
print(v.numpy())
print(f(2))

1.0
2.0
tf.Tensor(4.0, shape=(), dtype=float32)
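A common way to keep variable creation unambiguous is to create variables once, eagerly, in an object's constructor and only use them inside the `tf.function`. A sketch, assuming a `tf.Module` subclass (`Accumulator` is an illustrative name):

```python
import tensorflow as tf

class Accumulator(tf.Module):
  def __init__(self):
    super().__init__()
    # Created once, eagerly, outside any traced function.
    self.total = tf.Variable(0.0)

  @tf.function
  def add(self, x):
    return self.total.assign_add(x)

acc = Accumulator()
acc.add(tf.constant(1.0))
print(acc.add(tf.constant(2.0)).numpy())  # 3.0
```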


You can still create variables inside a `tf.function`, as long as you guarantee that they are created only on the first call.

In [29]:
class C:
  pass

obj = C()
obj.v = None

@tf.function
def g(x):
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

print(g(1.0))
print(g(2.0))

tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)


In [30]:
state = []

@tf.function
def fn(x):
  if not state:
    state.append(tf.Variable(2.0 * x))
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))

tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)


# Using AutoGraph

AutoGraph, which is enabled by default in `tf.function`, rewrites conditionals and loops that depend on Tensors so that they run dynamically in the graph.

In [31]:
@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

[0.311216474 0.581785679 0.524046659 0.247471333 0.145412087]
[0.301543355 0.52396214 0.480817169 0.24254021 0.144395784]
[0.292724341 0.4807522 0.446897835 0.237893641 0.143400535]
[0.284640342 0.446845829 0.419345707 0.233505219 0.142425612]
[0.277194351 0.419302851 0.396379083 0.229351848 0.141470328]
[0.270306408 0.396342963 0.376846522 0.225413218 0.140534043]
[0.263909906 0.376815528 0.359965712 0.221671417 0.139616132]
[0.257948935 0.359938741 0.34518382 0.218110546 0.138716]
[0.252376109 0.345160067 0.332097411 0.214716449 0.137833074]
[0.247150928 0.332076281 0.320404112 0.21147649 0.13696681]
[0.242238626 0.320385158 0.30987227 0.208379313 0.136116698]
[0.237609103 0.309855133 0.300320894 0.205414727 0.135282248]
[0.233236179 0.300305277 0.291606247 0.202573508 0.134462968]
[0.229096949 0.291591942 0.28361252 0.199847311 0.133658409]
[0.225171268 0.283599377 0.276245236 0.197228581 0.132868141]
[0.221441343 0.276233107 0.269426405 0.194710419 0.132091746]
[0.21789141 0.269415

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20239802, 0.24159752, 0.23701175, 0.18139602, 0.12770325],
      dtype=float32)>

You can inspect the code that AutoGraph generates with `tf.autograph.to_code(func)`.

In [32]:
def _f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(_f))

def tf___f(x):
    do_return = False
    retval_ = ag__.UndefinedReturnValue()
    with ag__.FunctionScope('_f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

        def get_state():
            return (x,)

        def set_state(loop_vars):
            nonlocal x
            (x,) = loop_vars

        def loop_body():
            nonlocal x
            ag__.converted_call(tf.print, (x,), None, fscope)
            x = ag__.converted_call(tf.tanh, (x,), None, fscope)

        def loop_test():
            return (ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        do_return = True
        retval_ = fscope.mark_return_value(x)
    (do_return,)
    return ag__.retval(retval_)



## AutoGraph: Conditionals

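AutoGraph converts an `if` statement into the equivalent `tf.cond` call when its condition is a Tensor. Otherwise, the `if` runs as ordinary Python during tracing, and only the branch that is taken is added to the graph.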

The following helper function checks whether the resulting graph uses `tf.cond`.

In [0]:
def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == "cond" for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
      f.__name__, ', '.join(map(str, args))
    ))
  else:
    print("{}({}) executes normally.".format(
      f.__name__, ', '.join(map(str, args))
    ))
  
  print("Result: {}".format(f(*args).numpy()))

In [0]:
@tf.function
def dropout(x, training=True):
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

The `if` is converted to `tf.cond` only when the condition is a Tensor. With the Python value `True`, the branch is simply executed during tracing.

In [35]:
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)

dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), True) executes normally.
Result: [0. 2. 0. 2. 0. 2. 2. 2. 2. 2.]


In [36]:
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))

dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), tf.Tensor(True, shape=(), dtype=bool)) uses tf.cond.
Result: [0. 2. 0. 2. 0. 2. 0. 2. 2. 0.]


In [0]:
@tf.function
def f(x):
  if x > 0:
    x = x + 1
    print("Tracing `then` branch.")
  elif x == 0:
    x = x + 0
    print("Tracing `equal` branch.")
  else:
    x = x - 1
    print("Tracing `else` branch.")
  return x

In [38]:
f(-1).numpy()

Tracing `else` branch.


-2

In [39]:
f(1.0).numpy()

Tracing `then` branch.


2.0

When the condition is a Tensor, `tf.cond` traces every branch and selects the correct one at runtime.

In [40]:
f(tf.constant(1.0)).numpy()

Tracing `then` branch.
Tracing `equal` branch.
Tracing `else` branch.


2.0

When the condition is a Tensor, every branch of the resulting `tf.cond` must define the same variables. Here `x` is defined only in the `then` branch, so the conversion fails.

In [0]:
@tf.function
def f():
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

In [42]:
with assert_raises(ValueError):
  f()

Caught excepted exception: <class 'ValueError'>


In [0]:
@tf.function
def f(x, y):
  if bool(x):
    y = y + 1.
    print("Tracing `then` branch.")
  else:
    y = y - 1.
    print("Tracing `else` branch.")
  return y

In [44]:
f(True, 0).numpy()

Tracing `then` branch.


1.0

In [45]:
f(False, 0).numpy()

Tracing `else` branch.


-1.0

In [46]:
with assert_raises(TypeError):
  f(tf.constant(True), tf.constant(0.0))

Caught excepted exception: <class 'TypeError'>
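The `TypeError` comes from calling `bool()` on a symbolic Tensor while tracing, which TensorFlow does not allow. Assuming the intent is to branch at runtime, one fix is to drop the `bool()`. AutoGraph then sees a Tensor condition and converts the `if` into `tf.cond`. Note that `y` is assigned in both branches. `shift` is an illustrative name.

```python
import tensorflow as tf

@tf.function
def shift(x, y):
  # No bool() call: the Tensor condition is converted to tf.cond.
  if x:
    y = y + 1.
  else:
    y = y - 1.
  return y

print(shift(tf.constant(True), tf.constant(0.0)).numpy())   # 1.0
print(shift(tf.constant(False), tf.constant(0.0)).numpy())  # -1.0
```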


## AutoGraph: Loops

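AutoGraph converts `for` and `while` loops into TensorFlow loops when the iterated object or the loop condition is a Tensor. Most become a `tf.while_loop`, and a loop of the form `for x in tf.data.Dataset` becomes a `tf.data.Dataset.reduce`. Loops over Python values run as ordinary Python during tracing, which statically unrolls them into the graph.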

Let's build a helper function that reports how a loop was converted.

In [0]:
def test_dynamically_unrolled(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'while' for node in g.as_graph_def().node):
    print("{}({}) uses tf.while_loop.".format(
        f.__name__, ', '.join(map(str, args))))
  elif any(node.name == 'ReduceDataset' for node in g.as_graph_def().node):
    print("{}({}) uses tf.data.Dataset.reduce.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) gets unrolled.".format(
        f.__name__, ', '.join(map(str, args))))

### For loops

The following `tf.function`s loop over Python values, so AutoGraph leaves them as Python loops and they are statically unrolled into the graph.

In [48]:
@tf.function
def for_in_range():
  x = 0
  for i in range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_range)

for_in_range() gets unrolled.


In [49]:
@tf.function
def while_py_cond():
  x = 5
  while x > 0:
    x = x - 1
  return x

test_dynamically_unrolled(while_py_cond)

while_py_cond() gets unrolled.


If a loop iterates over a Tensor or a `tf.data.Dataset`, or if its condition is a Tensor, AutoGraph rewrites it into a `tf.while_loop` or a `tf.data.Dataset.reduce`.

In [50]:
@tf.function
def for_in_tfrange():
  x = tf.constant(0, dtype=tf.int32)
  for i in tf.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfrange)

for_in_tfrange() uses tf.while_loop.


In [51]:
@tf.function
def for_in_tfdataset():
  x = tf.constant(0, dtype=tf.int64)
  for i in tf.data.Dataset.range(5):
    x += i
  return x

test_dynamically_unrolled(for_in_tfdataset)

for_in_tfdataset() uses tf.data.Dataset.reduce.


In [52]:
@tf.function
def while_tf_cond():
  x = tf.constant(5)
  while x > 0:
    x = x - 1
  return x

test_dynamically_unrolled(while_tf_cond)

while_tf_cond() uses tf.while_loop.


If you introduce a break or an early return clause that depends on a Tensor, the top-level condition or iterable should also be a Tensor.

In [53]:
@tf.function
def while_py_true_py_break(x):
  """Static unrolling."""
  while True:
    if x == 0:
      break
    x = x - 1
  return x

test_dynamically_unrolled(while_py_true_py_break, 5)

while_py_true_py_break(5) gets unrolled.


In [54]:
@tf.function
def buggy_while_py_true_py_break(x):
  while True:
    if tf.equal(x, 0):  # Tensor condition inside a Python loop: raises TypeError
      break
    x = x - 1
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_while_py_true_py_break, 5)

Caught excepted exception: <class 'TypeError'>


The correct approach is to make the loop condition a Tensor as well.

In [55]:
@tf.function
def while_tf_true_tf_break(x):
  while tf.constant(True):  # Tensor condition: converted to tf.while_loop
    if x == 0:
      break  # converted together with the loop
    x = x - 1
  return x

test_dynamically_unrolled(while_tf_true_tf_break, 5)

while_tf_true_tf_break(5) uses tf.while_loop.


In [56]:
@tf.function
def buggy_py_for_tf_break():
  x = 0
  for i in range(5):  # py loop
    if tf.equal(i, 3):  # tf conditional
      break
    x += i
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_py_for_tf_break)

Caught excepted exception: <class 'TypeError'>


Likewise, iterate over a Tensor (here `tf.range`) when the `break` depends on a Tensor.

In [57]:
@tf.function
def tf_for_break():
  x = 0
  for i in tf.range(5):  # tf loop
    if i == 3:  # i is a Tensor, so this condition is converted too
      break
    x += i
  return x

test_dynamically_unrolled(tf_for_break)

tf_for_break() uses tf.while_loop.
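The following sketch checks that the converted loop also computes the right values. Both the loop bound and the `break` condition are Tensors; the function `sum_until` is illustrative.

```python
import tensorflow as tf

@tf.function
def sum_until(limit, stop):
  total = tf.constant(0)
  for i in tf.range(limit):  # Tensor loop: converted to tf.while_loop
    if tf.equal(i, stop):    # Tensor condition inside a Tensor loop: allowed
      break
    total += i
  return total

print(sum_until(tf.constant(5), tf.constant(3)).numpy())   # 3 == 0 + 1 + 2
print(sum_until(tf.constant(5), tf.constant(10)).numpy())  # 10 == 0 + 1 + 2 + 3 + 4
```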


In order to collect the results from a dynamically unrolled loop, you'll need to use `tf.TensorArray`.

In [0]:
batch_size = 2
seq_len = 3  # time series/points
feature_size = 4  # embedding / rnn states

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])

In [60]:
dynamic_rnn(rnn_step,
            tf.random.uniform(shape=[batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.25938368, 0.950837  , 0.2547896 , 0.54169357],
        [1.2350857 , 1.4474765 , 0.7446759 , 1.0271566 ],
        [1.8774089 , 2.1505756 , 1.3160257 , 1.7554648 ]],

       [[0.13711762, 0.5610384 , 0.5977385 , 0.7911644 ],
        [0.15460062, 0.64549196, 1.0810224 , 0.9349692 ],
        [0.19282866, 1.286826  , 1.3374677 , 1.5134813 ]]], dtype=float32)>
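When the number of iterations is not known in advance, a `tf.TensorArray` created with `dynamic_size=True` can grow as it is written. The sketch below collects a Collatz sequence, mirroring `next_collect` above; `collatz_sequence` is an illustrative name.

```python
import tensorflow as tf

@tf.function
def collatz_sequence(n):
  # The sequence length is unknown ahead of time, so let the array grow.
  seq = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
  i = tf.constant(0)
  while tf.not_equal(n, 1):  # Tensor condition: becomes a tf.while_loop
    seq = seq.write(i, n)
    n = tf.where(tf.equal(n % 2, 0), n // 2, 3 * n + 1)
    i += 1
  seq = seq.write(i, n)  # record the final 1
  return seq.stack()

print(collatz_sequence(tf.constant(6)).numpy().tolist())
# [6, 3, 10, 5, 16, 8, 4, 2, 1]
```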

### Other Details

Every variable used in a converted loop must be defined before the loop starts.

In [74]:
@tf.function
def buggy_loop_var_uninitialized():
  for i in tf.range(3):
    x = i  # x is first defined inside the loop: raises ValueError
  return x

with assert_raises(ValueError):
  test_dynamically_unrolled(buggy_loop_var_uninitialized)

Caught excepted exception: <class 'ValueError'>


In [81]:
x = tf.Variable(1)

@tf.function
def loop_var_initialized():
  for i in tf.range(3):
    x.assign(i)
  return x

loop_var_initialized()

<tf.Tensor: shape=(), dtype=int32, numpy=2>
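For plain tensors, the fix is simply to define the loop variable before the loop, with the same dtype and shape it will have inside the loop. `last_loop_index` is an illustrative name.

```python
import tensorflow as tf

@tf.function
def last_loop_index():
  x = tf.constant(0)  # defined before the loop, same dtype (int32) as i
  for i in tf.range(3):
    x = i
  return x

print(last_loop_index().numpy())  # 2
```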

The dtype and shape of each loop variable must stay the same across all iterations.

In [83]:
@tf.function
def buggy_loop_type_changes():
  x = tf.constant(1.0, dtype=tf.float32)
  for i in tf.range(3):
    x = i  # int32 replaces float32: raises TypeError
  return x

with assert_raises(TypeError):
  test_dynamically_unrolled(buggy_loop_type_changes)

Caught excepted exception: <class 'TypeError'>
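One fix is to cast inside the loop so that `x` keeps its `float32` dtype on every iteration. `loop_with_cast` is an illustrative name.

```python
import tensorflow as tf

@tf.function
def loop_with_cast():
  x = tf.constant(1.0, dtype=tf.float32)
  for i in tf.range(3):
    x = tf.cast(i, tf.float32)  # x stays float32 on every iteration
  return x

print(loop_with_cast().numpy())  # 2.0
```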


In [84]:
@tf.function
def buggy_concat():
  x = tf.ones(shape=[0, 10])
  for i in tf.range(5):
    x = tf.concat([x, tf.ones(shape=[1, 10])], axis=0)  # shape grows each iteration: raises ValueError
  return x

with assert_raises(ValueError):
  buggy_concat()

Caught excepted exception: <class 'ValueError'>


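A correct way to build the result with `tf.concat` is to preallocate the full `[5, 10]` shape, overwrite one row per iteration, and restore the static shape after each `tf.concat`.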

In [90]:
@tf.function
def concat_with_padding():
  x = tf.zeros(shape=[5, 10], dtype=tf.float32)
  for i in tf.range(5):
    x = tf.concat([x[:i], tf.random.uniform(shape=[1,10]), x[i+1:]], axis=0)
    x.set_shape([5, 10])  # slicing with a Tensor index loses the static shape; restore it
  return x

concat_with_padding()

<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
array([[0.6898327 , 0.13407004, 0.9046999 , 0.28530312, 0.9393835 ,
        0.7386488 , 0.1098659 , 0.46744478, 0.4397533 , 0.0653286 ],
       [0.96035933, 0.46595097, 0.08150244, 0.5987668 , 0.27322912,
        0.11421788, 0.9216167 , 0.583694  , 0.26626718, 0.46127534],
       [0.8354796 , 0.2922089 , 0.7615863 , 0.1654191 , 0.59055746,
        0.57912564, 0.09645224, 0.05547762, 0.8927189 , 0.24779224],
       [0.7678379 , 0.53827715, 0.8584336 , 0.4450512 , 0.7394295 ,
        0.61915326, 0.0353024 , 0.31575656, 0.8718381 , 0.95665157],
       [0.50390637, 0.6010597 , 0.6973202 , 0.33764052, 0.22260141,
        0.7390568 , 0.4174404 , 0.24913824, 0.47793007, 0.7276628 ]],
      dtype=float32)>
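Alternatively, when the goal is simply to build a tensor row by row, a `tf.TensorArray` avoids both the growing shape and the `set_shape` call. The sketch below writes row `i` filled with the value `i`; `stack_rows` is an illustrative name.

```python
import tensorflow as tf

@tf.function
def stack_rows():
  rows = tf.TensorArray(tf.float32, size=5)
  for i in tf.range(5):
    # Each write stores one [10] row; the loop variable's structure never changes.
    rows = rows.write(i, tf.fill([10], tf.cast(i, tf.float32)))
  return rows.stack()  # shape [5, 10]

result = stack_rows()
print(result.shape)                    # (5, 10)
print(result.numpy()[:, 0].tolist())   # [0.0, 1.0, 2.0, 3.0, 4.0]
```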