# Concrete functions

## Setup

In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

In [None]:
import traceback
import textwrap

try:
    # IPython shell magic (the `!` prefix is notebook syntax, not plain
    # Python): quietly install the nightly TensorFlow build. Any failure is
    # ignored (best effort), so the notebook falls back to whatever
    # TensorFlow is already installed.
    !pip install -q tf-nightly
except Exception:
    pass

In [None]:
import tensorflow as tf

## Create a tf.function

In [None]:
@tf.function
def square(x):
    """Return `x` multiplied by itself, run through tf.function."""
    squared = x * x
    return squared

In [None]:
# Call the tf.function with a plain Python int; `.numpy()` converts the
# returned tensor into a NumPy value for display.
square(2).numpy()

In [None]:
# A plain Python function. NOTE: the name shadows the builtin `pow` for the
# rest of the notebook.
def pow(x,y):
    return x ** y

# Non-decorator form: wrap an existing Python function with tf.function.
pow = tf.function(pow)

In [None]:
pow(3,4).numpy()

### Attach a tf.function method to a tf.Module

In [None]:
class Pow(tf.Module):
    """Raise inputs to a power that is stored in a float32 `tf.Variable`.

    Because the exponent is a `tf.Variable` attribute, the module tracks it
    (it is listed in `.variables` and saved with the module).
    """

    def __init__(self, exponent):
        """Create the module.

        Args:
            exponent: Initial exponent; stored as a float32 variable.
        """
        # Run tf.Module's constructor so the module's name/name_scope are set
        # up, as a tf.Module subclass is expected to do.
        super(Pow, self).__init__()
        self.exponent = tf.Variable(exponent, dtype = tf.float32, name='Pow/exponent')

    @tf.function
    def __call__(self, x):
        """Return `x ** self.exponent`.

        `x` is presumably float32 to match the exponent's dtype - TODO
        confirm whether other dtypes are meant to be supported.
        """
        return x ** self.exponent

In [None]:
# Rebind the name `pow` (previously the wrapped function) to a Pow module
# with exponent 3.
pow = Pow(3)

In [None]:
# The variables tracked by the module (its exponent).
pow.variables

In [None]:
pow(tf.constant(2.0)).numpy()

In [None]:
# Update the exponent variable in place, then call again with the same input.
pow.exponent.assign(4)
pow(tf.constant(2.0)).numpy()

In [None]:
# Export the module as a SavedModel into the directory 'pow'.
tf.saved_model.save(pow, 'pow')

In [None]:
reloaded_pow = tf.saved_model.load('pow')

In [None]:
# The restored object is callable like the original module.
reloaded_pow(tf.constant(3.0)).numpy()

### Assign a tf.function as an attribute

In [None]:
# A bare tf.Module used as a container; the variable assigned to it is
# tracked as part of the module.
mod = tf.Module()
mod.increment_by = tf.Variable(2.0)

@tf.function
def increment(x):
    """Return `x` plus the module's `increment_by` variable."""
    return x+mod.increment_by

# Attaching the tf.function as an attribute makes it part of the module, so
# it is available again after save/load below.
mod.inc = increment
mod.inc(tf.constant(1.0)).numpy()

In [None]:
# A nested module; its exponent variable is expected to show up in
# mod.variables as well.
mod.cube = Pow(3)
mod.cube(tf.constant(2.0)).numpy()

In [None]:
mod.variables

In [None]:
tf.saved_model.save(mod, 'mod')
reloaded_mod = tf.saved_model.load('mod')

In [None]:
# The originals were traced with float32 scalar tensors; the restored
# functions are called here with plain Python floats.
reloaded_mod.inc(4.0).numpy()

In [None]:
reloaded_mod.cube(4.0).numpy()

### Interoperability with tf.keras

In [None]:
# A one-unit Dense model fitted to six points that lie on y = 2x - 1.
linear = tf.keras.Sequential([tf.keras.layers.Dense(units=1, input_shape=[1])])
linear.compile(optimizer='adam', loss='mean_squared_error')
linear.fit(x=[-1, 0, 1, 2, 3, 4], y=[-3, -1, 1, 3, 5, 7], epochs=50, verbose=0)

In [None]:
# NOTE(review): integer input; presumably cast by the Dense layer to its
# float dtype - TODO confirm.
linear(tf.constant([[1],[2]]))

In [None]:
linear.variables

In [None]:
# A Keras model can be attached to a tf.Module; its variables are then
# expected to appear in module.variables.
module = tf.Module()
module.linear = linear

In [None]:
module.variables

In [None]:
tf.saved_model.save(module,'module')

In [None]:
reloaded = tf.saved_model.load('module')

In [None]:
reloaded.linear([[1.0]])

## Tracing

In [None]:
@tf.function
def mul(a, b):
    """Return `a * b`, printing both arguments whenever a new trace is built.

    `print` is ordinary Python, so it runs while tf.function traces the body
    rather than on every call; each printed block marks a (re)trace.
    """
    trace_msg = 'Tracing:\n    {a}\n    {b}\n'.format(a=a, b=b)
    print(trace_msg)
    product = a * b
    return product

### Dtypes and shapes

In [None]:
# Trace with ints: the first call with int32 scalar tensors builds a trace.
mul(tf.constant(2), tf.constant(3)).numpy()

In [None]:
# Trace with floats: a new input dtype (float32) needs a new trace.
mul(tf.constant(2.0), tf.constant(3.0)).numpy()

In [None]:
# Call with ints again => no trace (the int32 scalar trace is reused).
mul(tf.constant(10), tf.constant(10))

In [None]:
# Trace with vectors: a rank-1 `a` is a new input shape, so a new trace.
mul(tf.constant([1.0,3.0]), tf.constant(3.0)).numpy()

In [None]:
# Trace with different-sized vectors: length 4 instead of 2 is again a new
# input shape, so another trace is expected.
mul(tf.constant([1.0,2.0,3.0, 4.0]), tf.constant(3.0))

### Immutable Python objects

In [None]:
@tf.function
def mul(a, b):
    """Return `a * b`, printing both arguments whenever a new trace is built.

    Same body as before; presumably redefined so that the following cells
    start from a new tf.function with no earlier traces.
    """
    trace_msg = 'Tracing:\n    {a}\n    {b}\n'.format(a=a, b=b)
    print(trace_msg)
    product = a * b
    return product

In [None]:
# Trace for a=3.0: `a` is a plain Python float (not a tensor), so its value
# takes part in choosing the trace and this first call builds a new one.
mul(3.0, tf.constant(3.0)).numpy()

In [None]:
# Don't trace for a=3.0 the second time: the same Python value reuses it.
mul(3.0, tf.constant(3.0)).numpy()

In [None]:
@tf.function
def power(a,b):
    """Return `a ** b`, printing `a` each time a new trace is created."""
    print('Tracing "power": a={}'.format(a))
    result = a ** b
    return result

In [None]:
# `n` is a plain Python int, so each distinct value is expected to trace
# `power` again (12 traces).
p = tf.constant(2)
for n in range(12):
    power(n,p)

In [None]:
# Same loop again: the traces from the previous cell are expected to be
# reused, so no "Tracing" lines should appear.
p = tf.constant(2)
for n in range(12):
    power(n,p)

In [None]:
# Iterating tf.range yields int32 scalar tensors that share dtype and shape,
# so one trace is expected to serve every iteration.
p = tf.constant(2)
for n in tf.range(12):
    power(n,p)

In [None]:
@tf.function(
    input_signature=(
        tf.TensorSpec(shape=[], dtype=tf.float32),
        tf.TensorSpec(shape=[], dtype=tf.float32),
    )
)
def power_with_sig(a,b):
    """Return `a ** b` for two float32 scalars.

    The fixed `input_signature` allows only this one trace; inputs that do
    not match it are rejected instead of causing a retrace.
    """
    print('Tracing "power_with_sig"')
    result = a ** b
    return result

In [None]:
# Python floats are compatible with the scalar float32 signature.
power_with_sig(3.0, 3.0).numpy()

In [None]:
# A rank-1 tensor does not match the scalar input_signature, so a ValueError
# is expected; `assert False` makes the cell fail if the call succeeds.
try:
    power_with_sig(tf.constant([1.0,2.0,3.0]),tf.constant(3.0))
    assert False
except ValueError:
    traceback.print_exc(limit=1)

### Example: Dropout

In [None]:
class Dropout(tf.Module):
    """Dropout implemented as a `tf.Module` with a traced `__call__`.

    While training, each element of `x` is kept with probability `1 - rate`.
    Kept elements are rescaled by `1 / (1 - rate)` ("inverted dropout"), so
    the output has the same expected value as the input. When `training` is
    false, the input is returned unchanged.
    """

    def __init__(self, rate, name=None):
        """Create the module.

        Args:
            rate: Probability of dropping each element; stored as a
                non-trainable float32 `tf.Variable`.
            name: Optional module name passed to `tf.Module`.
        """
        super(Dropout, self).__init__(name)
        self.rate = tf.Variable(rate, dtype = tf.float32, trainable=False)

    @tf.function
    def __call__(self, x, training=True):
        """Apply dropout to `x` if `training` is true, else return `x`.

        Args:
            x: Input tensor (a float dtype is assumed - TODO confirm for
                other dtypes).
            training: Python bool or scalar bool tensor selecting the train
                or the test branch.

        Returns:
            A tensor with the same shape and dtype as `x`.

        The `print` calls are plain Python, so they run only while a new trace
        is being built and report which arguments triggered it.
        """
        print(textwrap.dedent("""
                              Tracing "Dropout":
                                  training = {}
                                  x = {}
                                  name = {:s}
                              """.format(training, x, self.name)))
        if training:
            print('    - Train branch\n')
            # Dynamic shape, so inputs with unknown (None) dimensions can be
            # traced too; the static `x.shape` would be partially undefined.
            mask = tf.random.uniform(tf.shape(x)) > self.rate
            # Elements survive with probability 1 - rate, so rescale by
            # 1 / (1 - rate) (not 1 / rate) to preserve the expected value.
            # divide_no_nan returns zeros instead of NaN when rate == 1.
            keep_prob = tf.cast(1.0 - self.rate, x.dtype)
            return tf.math.divide_no_nan(x * tf.cast(mask, x.dtype), keep_prob)
        else:
            print('    - Test branch\n')
            return x

In [None]:
dropout = Dropout(0.5)

In [None]:
# Python bool True: `if training:` is decided while tracing, so only the
# train branch is traced (and printed).
dropout(tf.range(10, dtype=tf.float32), training=True).numpy()

In [None]:
# Same Python bool again: the existing trace is expected to be reused.
dropout(tf.range(10, dtype=tf.float32), training=True).numpy()

In [None]:
# Python bool False: a separate trace that contains only the test branch.
dropout(tf.range(10, dtype=tf.float32), training=False).numpy()

In [None]:
dropout(tf.range(10, dtype=tf.float32), training=False).numpy()

In [None]:
# A bool tensor instead of a Python bool: its value is unknown while
# tracing, so `if training:` is presumably converted into a graph
# conditional and both branch messages are printed during this trace.
dropout(tf.range(10, dtype=tf.float32), training=tf.constant(False)).numpy()

In [None]:
 dropout(tf.range(10, dtype=tf.float32), training=tf.constant(True)).numpy()

In [None]:
# Another bool scalar tensor: the tensor-typed trace is expected to be reused.
dropout(tf.range(10, dtype=tf.float32), training=tf.constant(False)).numpy()

### Other Python objects

In [None]:
# Two separately named Dropout instances. The trace message prints
# `self.name`, so the output shows which instance each trace belongs to.
dropout_a = Dropout(0.5, name='dropout_a')

In [None]:
print(dropout_a(tf.range(10, dtype=tf.float32), True).numpy())
print(dropout_a(tf.range(10, dtype=tf.float32), True).numpy())

In [None]:
dropout_b = Dropout(0.5, name='dropout_b')

In [None]:
# A different instance with the same arguments is expected to trace again.
print(dropout_b(tf.range(10, dtype=tf.float32), True).numpy())
print(dropout_b(tf.range(10, dtype=tf.float32), True).numpy())

In [None]:
@tf.function
def run(callable, x):
    """Return `callable(x)`, computed inside a tf.function.

    The trace message prints both the callable and `x`, so the output shows
    which (callable, x) combinations cause a new trace.

    NOTE: the parameter name shadows the builtin `callable`; it is kept
    unchanged because it is part of this function's keyword interface.
    """
    print('Tracing "run":\n    callable = {}\n    x = {}\n'.format(callable, x))
    output = callable(x)
    return output

In [None]:
# A plain Python function, passed to `run` as an ordinary Python object.
def plus_1(x):
    return x+1

print(run(plus_1, tf.constant(2.0)).numpy())
print(run(plus_1, tf.constant(5.0)).numpy())

In [None]:
# A tf.Module instance (the `dropout` object created earlier) as the callable;
# it is invoked with its default `training=True`.
print(run(dropout, tf.range(10.0)).numpy())
print(run(dropout, tf.range(10.0)).numpy())

### Weak references

In [None]:
@tf.function
def plus_var(x):
    """Return `x + var`, where `var` is a module-level tf.Variable.

    `var` is not a parameter: it is looked up as a global while the body is
    traced, so each trace uses whichever variable `var` named at that time.
    """
    trace_msg = 'Tracing "plus_var":\n    x = {}\n    var = {}\n\n'.format(x, var.name)
    print(trace_msg)
    return x + var

In [None]:
# Trace plus_var for int32 input against an int32 variable.
var = tf.Variable(1, name="IntVar")
plus_var(tf.constant([1,2])).numpy()

In [None]:
# Rebind the global `var` to a float variable and trace again for float32.
var = tf.Variable(2.0, name="FloatVar")
plus_var(tf.constant([2.0, 10.0])).numpy()

In [None]:
# int32 input selects the earlier int32 trace, which was built against the
# IntVar that `var` no longer refers to. This cell expects that call to raise
# FailedPreconditionError (the section is about weak references).
try:
    plus_var(tf.constant([1,2])).numpy()
    assert False
except tf.errors.FailedPreconditionError:
    traceback.print_exc(limit=1)

## Accessing concrete functions

### Using input_signature

In [None]:
@tf.function(
    input_signature=(
        tf.TensorSpec(shape=[None], dtype=tf.float32),
        tf.TensorSpec(shape=[None], dtype=tf.float32),
    )
)
def power_with_sig(a,b):
    """Return `a ** b` element-wise for float32 vectors of any length.

    The `[None]` dimension in the signature lets one trace serve every
    vector length.
    """
    print('Tracing "power_with_sig"\n')
    result = a ** b
    return result

In [None]:
# With a complete input_signature, get_concrete_function() needs no
# arguments. (This rebinds `p`, used earlier as an exponent tensor.)
p = power_with_sig.get_concrete_function()
type(p)

In [None]:
p(tf.constant([2.0,3.0,4.0]), tf.constant([5.0,4.0,3.0])).numpy()

### Using get_concrete_function

In [None]:
@tf.function
def power(a,b):
    """Return `a ** b`; prints a marker line whenever a new trace is built."""
    print('Tracing "power"\n')
    result = a ** b
    return result

In [None]:
# Request a concrete function for scalar float32 inputs, described by
# TensorSpecs rather than real data.
float_power = power.get_concrete_function(
  a = tf.TensorSpec(shape=[], dtype=tf.float32),
  b = tf.TensorSpec(shape=[], dtype=tf.float32))

In [None]:
float_power(tf.constant(3.0),tf.constant(3.0))

In [None]:
# Concrete functions can also be requested with example tensors:
# `row` has shape (10,) and `col` has shape (3, 1), so `row ** col`
# broadcasts to shape (3, 10).
row = tf.range(10)
col = tf.constant([[1],[2],[3]])

concrete_power = power.get_concrete_function(a = row, b = col)
concrete_power(row, col).numpy()

## Using a concrete function

In [None]:
float_power(tf.constant(2.0), tf.constant(3.0)).numpy()

In [None]:
# Unlike the tf.function, the concrete function is expected to reject plain
# Python floats with ValueError.
try:
    float_power(2.0,3.0)
    assert False
except ValueError:
    traceback.print_exc(limit=1)

In [None]:
# int32 tensors do not match the float32 trace; InvalidArgumentError expected.
try:
    float_power(tf.constant(1),tf.constant(3))
    assert False
except tf.errors.InvalidArgumentError:
    traceback.print_exc(limit=1)

In [None]:
# The trace was made for scalar `a`, yet this cell calls it with a length-5
# vector and expects it to run (no try/except around it).
float_power(tf.constant([1.,2.,3.,4.,5.]),tf.constant(3.)).numpy()

In [None]:
# Shapes (3,) and (2,) cannot broadcast; InvalidArgumentError expected at run time.
try:
    float_power(tf.constant([1.,2.,3.]),tf.constant([4., 5.])).numpy()
    assert False
except tf.errors.InvalidArgumentError:  
    traceback.print_exc(limit=1)

In [None]:
print(float_power.structured_input_signature)
print(float_power.structured_outputs)

## Python objects in signatures

In [None]:
# The Python value b=3.0 is fixed into this trace, so the resulting
# concrete function is called with `a` only.
cube = power.get_concrete_function(
    a = tf.TensorSpec([], dtype=tf.float32),
    b = 3.0)

In [None]:
print(cube.structured_input_signature)

In [None]:
cube(tf.constant(10.0)).numpy()

In [None]:
class Greeter(object):
    """Build greeting strings from a fixed salutation."""

    def __init__(self, greeting):
        # Salutation that greet() puts in front of every name.
        self.greeting = greeting

    def greet(self, who):
        """Return the stored greeting followed by one space and `who`."""
        pieces = (self.greeting, who)
        return " ".join(pieces)

# Analogy for concrete methods: a bound method already carries its instance
# (`self`), so it is called with the remaining arguments only.
p = Greeter("Hello")
m = p.greet
print(m)

In [None]:
print(m("TensorFlow!"))

In [None]:
class MyModel(tf.Module):
    """A single affine layer computing `x @ W + B`.

    Args:
        ins: Number of input features (rows of `W`).
        outs: Number of output features (columns of `W`, length of `B`).
    """

    def __init__(self, ins, outs):
        # Run tf.Module's constructor so the module's name/name_scope are set
        # up, as a tf.Module subclass is expected to do.
        super(MyModel, self).__init__()
        initializer = tf.initializers.GlorotNormal()
        self.W = tf.Variable(initializer([ins, outs]))
        self.B = tf.Variable(tf.zeros([outs], dtype = tf.float32))

    @tf.function
    def run(self, x):
        """Return `tf.matmul(x, self.W) + self.B`.

        `x` is presumably a 2-D float32 batch of shape (batch, ins) - TODO
        confirm. The print runs only while tracing and names this class.
        """
        print('Tracing "MyModel":\n    x={}\n'.format(x))
        return tf.matmul(x, self.W)+self.B

In [None]:
# Rebinds `mod` (earlier a bare tf.Module) to a 5-input, 3-output MyModel.
mod = MyModel(ins=5, outs=3)

In [None]:
mod.run([[1.0,1.0,1.0, 1.0, 1.0]]).numpy()

In [None]:
# Taken from the bound method `mod.run`, so only `x` needs a spec; `self` is
# already bound. TensorSpec's default dtype is float32; [None, None] accepts
# any 2-D shape.
concrete_run = mod.run.get_concrete_function(x = tf.TensorSpec([None, None]))

In [None]:
concrete_run(tf.constant([[1.0,1.0,1.0, 1.0, 1.0],
                          [2.0,2.0,2.0, 2.0, 2.0]])).numpy()

In [None]:
print(concrete_run.structured_input_signature)
print(concrete_run.structured_outputs)

## Accessing concrete functions from a SavedModel

In [None]:
# Fresh instance, called with two input shapes, (10,) and (2, 3), so that
# those traces exist before export; untraced shapes are not expected to be
# callable after loading.
dropout = Dropout(0.5)

_ = dropout(tf.range(10, dtype=tf.float32), tf.constant(True))
_ = dropout(tf.random.normal([2, 3]), tf.constant(True))

In [None]:
export_dir = 'dropout'
tf.saved_model.save(dropout, export_dir)

### Direct access

In [None]:
reloaded_dropout = tf.saved_model.load(export_dir)

In [None]:
# Both input shapes traced before export can be called on the restored object.
print(reloaded_dropout(tf.range(10, dtype=tf.float32), tf.constant(False)).numpy())
print(reloaded_dropout(tf.random.normal([2,3]), tf.constant(True)).numpy())

In [None]:
# Shape (12,) was never traced before saving, so ValueError is expected.
try:
    reloaded_dropout(tf.range(12, dtype=tf.float32), tf.constant(True))
    assert False
except ValueError:
    traceback.print_exc(limit=1)

In [None]:
# Retrieve one saved concrete function by describing its inputs; presumably
# the specs must match a trace that was exported.
cf = reloaded_dropout.__call__.get_concrete_function(
    x = tf.TensorSpec([10]), 
    training = tf.TensorSpec([], tf.bool))

In [None]:
result = cf(tf.range(10, dtype=tf.float32), tf.constant(True)).numpy()
print(result)

### Named signatures: Exporting for C++

#### Simple example

In [None]:
dropout = Dropout(0.5)

In [None]:
# Trace __call__ for a float32 (2, 3) input and a bool scalar `training`
# tensor; this concrete function becomes the exported signature.
cf = dropout.__call__.get_concrete_function(tf.zeros((2,3), dtype=tf.float32), tf.constant(False))

# `time` is only used by the commented-out suffix below, which would give
# each run its own export directory.
import time
export_dir = "./saved/" # +str(time.time())

# A single concrete function passed as `signatures` is exported under the
# default serving key (DEFAULT_SERVING_SIGNATURE_DEF_KEY, used below).
tf.saved_model.save(dropout, export_dir, signatures = cf)

In [None]:
reloaded = tf.saved_model.load(export_dir)

print(reloaded.signatures)

In [None]:
# Signature functions take keyword arguments named after the traced inputs
# (`x`, `training`). `training` is a tensor input, so True is presumably
# honored at run time even though the trace was made with False.
cf = reloaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
result = cf(x=tf.random.normal([2,3]), training=tf.constant(True))

print(result)

In [None]:
cf.structured_outputs

#### Example: Setting the output names

In [None]:
@tf.function
def named_result(x, training=True):
    """Run `dropout` on `x` and return the output under the key 'dropout'."""
    # Returning a dict lets its keys name the signature's outputs.
    return {'dropout': dropout(x, training)}

# Attach to the module so the function is saved along with it.
dropout.named_result = named_result

cf = dropout.named_result.get_concrete_function(tf.zeros((2,3), dtype=tf.float32),
                                                tf.constant(False))

#### Example: Setting the signature names

In [None]:
# A dict passed as `signatures` chooses the signature names ('simple').
# NOTE: this reuses ./saved/, presumably overwriting the earlier export.
export_dir = "./saved/"  # +str(time.time())
tf.saved_model.save(dropout, export_dir, signatures = {'simple':cf})

In [None]:
reloaded = tf.saved_model.load(export_dir)
cf = reloaded.signatures['simple']
result = cf(x=tf.random.normal([2,3]), training=tf.constant(True))

# The result is a dict keyed by the output names chosen in named_result.
print({key:value.numpy() for key,value in result.items()})

In [None]:
# Export one concrete function per input rank, each under a signature name
# that describes it: rank 1 ("vector"), rank 2 ("matrix"), rank 3 ("cube").
# All shapes are fully defined, so every trace has a static input shape.
vector = dropout.__call__.get_concrete_function(tf.TensorSpec((3,), dtype=tf.float32), tf.constant(False))
matrix = dropout.__call__.get_concrete_function(tf.TensorSpec((2, 3), dtype=tf.float32), tf.constant(False))
cube = dropout.__call__.get_concrete_function(tf.TensorSpec((2, 3, 4), dtype=tf.float32), tf.constant(False))

export_dir = "./saved/"

tf.saved_model.save(dropout, export_dir,
                    signatures = {
                        "vector": vector,
                        "matrix": matrix,
                        "cube": cube
                    })

In [None]:
reloaded = tf.saved_model.load(export_dir)
# Print the signatures mapping one entry per line: purely cosmetic string
# edits that insert newlines after "{" and between entries of its repr.
print('{}'.format(reloaded.signatures).replace("{","{\n    ").replace(">, ", ">,\n    "))