In [17]:
import tensorflow as tf

In [5]:
def scaled_elu(z, scale=1.0, alpha=1.0):
    """Scaled exponential linear unit, applied element-wise.

    Returns ``scale * z`` where ``z >= 0`` and ``scale * alpha * elu(z)``
    elsewhere. For negative inputs ``tf.nn.elu(z) == exp(z) - 1``, so e.g.
    ``scaled_elu(-3.)`` is ``exp(-3) - 1 ≈ -0.9502`` with the default arguments.

    Args:
        z: Input tensor (or value convertible to one). It must be a floating
            dtype, since ``tf.nn.elu`` is applied to it.
        scale: Multiplier applied to both branches.
        alpha: Extra multiplier applied only to the ELU (negative-input) branch.

    Returns:
        A tensor with the same shape as ``z``.
    """
    # tf.where evaluates both branches for every element; elu is finite
    # everywhere, so computing it for the non-negative entries is harmless.
    negative_part = alpha * tf.nn.elu(z)
    keep_linear = tf.greater_equal(z, 0.0)
    return scale * tf.where(keep_linear, z, negative_part)


print(scaled_elu(tf.constant(-3.)))
print(scaled_elu(tf.constant([-3, -2.5])))

tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor([-0.95021296 -0.917915  ], shape=(2,), dtype=float32)


#### Convert a Python function to a TensorFlow graph with `tf.function`

In [9]:
# Wrap the eager Python function so calls run as a traced TensorFlow graph.
scaled_elu_tf = tf.function(scaled_elu)

# A scalar input and a rank-1 input; a new input shape may trigger a new trace.
for example_input in (tf.constant(-3.), tf.constant([-3, -2.5])):
    print(scaled_elu_tf(example_input))

tf.Tensor(-0.95021296, shape=(), dtype=float32)
tf.Tensor([-0.95021296 -0.917915  ], shape=(2,), dtype=float32)


2021-08-08 14:11:12.238627: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
2021-08-08 14:11:12.250213: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


In [11]:
# The tf.function wrapper keeps a reference to the original Python callable.
wraps_original = scaled_elu_tf.python_function is scaled_elu
print(wraps_original)

True


In [20]:
% timeit scaled_elu(tf.random.normal((1000, 1000)))
% timeit scaled_elu_tf(tf.random.normal((1000, 1000)))

8.31 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


2021-08-08 14:18:32.486285: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


7.71 ms ± 340 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
# Partial geometric series with n terms: 1 + 1/2 + 1/2^2 + ... + 1/2^(n-1) = 2 - 2^(1-n)

@tf.function
def converge_to_2(n):
    """Sum the first ``n`` terms of the geometric series 1 + 1/2 + 1/4 + ...

    The result approaches 2 as ``n`` grows (``n=20`` gives 2 - 2**-19, which is
    1.9999981 in float32). For ``n <= 0`` the loop body never runs and the result is 0.

    Looping over ``tf.range(n)`` lets AutoGraph emit a single ``tf.while_loop``.
    A Python ``range(n)`` with a Python-int ``n`` would instead be unrolled
    into ``n`` copies of the loop body at trace time, so graph size grew with ``n``.
    Note that each distinct Python-int ``n`` still causes a new trace.

    Args:
        n: Number of terms, as a Python int or an integer scalar tensor.

    Returns:
        A float32 scalar tensor holding the partial sum.
    """
    total = tf.constant(0.)
    term = tf.constant(1.)
    for _ in tf.range(n):
        total += term
        term /= 2.0
    return total


print(converge_to_2(20))

tf.Tensor(1.9999981, shape=(), dtype=float32)


2021-08-08 15:05:48.345061: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


In [29]:
var = tf.Variable(0.)


@tf.function
def add_21():
    """Add 21 to the module-level variable ``var`` in place and return the new value.

    The update goes through ``assign_add`` rather than ``var += 21``. Inside this
    function, ``+=`` would rebind ``var`` as a function local instead of updating
    the shared variable. The variable is created before (outside) the function,
    because tf.function restricts creating variables inside the traced function.
    """
    increment = 21
    return var.assign_add(increment)


result = add_21()
print(result)

tf.Tensor(21.0, shape=(), dtype=float32)


2021-08-08 15:11:45.990138: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.


In [33]:
@tf.function
def cube(z):
    """Raise ``z`` element-wise to the third power.

    The input dtype is preserved; e.g. an int32 tensor yields an int32 result.
    """
    exponent = 3
    return tf.math.pow(z, exponent)


cubes = cube(tf.constant([1, 2, 3]))
print(cubes)


tf.Tensor([ 1  8 27], shape=(3,), dtype=int32)


2021-08-08 15:14:21.246627: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:112] Plugin optimizer for device_type GPU is enabled.
