In [2]:
import tensorflow as tf
x = tf.Variable(1.0,dtype=tf.float32)

In [3]:
@tf.function(input_signature=[tf.TensorSpec(shape = [], dtype = tf.float32)])
def add_print(a):
    x.assign_add(a)
    tf.print(x)
    return(x)

In [4]:
add_print(tf.constant(3.0))

4


<tf.Tensor: id=18, shape=(), dtype=float32, numpy=4.0>

下面利用tf.Module的子类化将其封装一下：

In [5]:
class DemoModule(tf.Module):
    def __init__(self, init_value = tf.constant(0.0), name = None):
        super(DemoModule, self).__init__(name=name)
        with self.name_scope:
            self.x = tf.Variable(init_value, dtype = tf.float32, trainable=True)
    @tf.function(input_signature=[tf.TensorSpec(shape = [], dtype = tf.float32)])
    def addprint(self, a):
        with self.name_scope:
            self.x.assign_add(a)
            tf.print(self.x)
            return(self.x)

In [6]:
# 执行
demo = DemoModule(init_value = tf.constant(1.0))
result = demo.addprint(tf.constant(5.0))

6


In [7]:
# 查看模块中的全部变量和全部 
print(demo.variables)
print(demo.trainable_variables)

(<tf.Variable 'demo_module/Variable:0' shape=() dtype=float32, numpy=6.0>,)
(<tf.Variable 'demo_module/Variable:0' shape=() dtype=float32, numpy=6.0>,)


In [8]:
# 查看模型中的全部子模块
demo.submodules

()

In [9]:
# 使用tf.saved_model 保存模型，并指定需要跨平台部署的方法
tf.saved_model.save(demo,"./data/",signatures = {"serving_default":demo.addprint})

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: ./data/assets


In [10]:
# 加载模型
demo2 = tf.saved_model.load("./data/")
demo2.addprint(tf.constant(5.0))

11


<tf.Tensor: id=192, shape=(), dtype=float32, numpy=11.0>

### 查看模型文件相关信息， 红框标出来的输出信息在模型部署和跨平台使用时有可能会用到
! saved_model_model_cli show --dir ./data/ --all

In [14]:
import numpy as np

class MyModel(tf.keras.Model):
    
    def __init__(self, num_classes=10):
        super(MyModel, self).__init__(name='my_model')
        self.num_classes = num_classes
        # 定义自己需要的层
        self.dense_1 = tf.keras.layers.Dense(32, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(num_classes)
        
    @tf.function(input_signature=[tf.TensorSpec([None,32],tf.float32)])   
    def call(self, inputs):
        # 定义前向传播
        # 使用'__init__'定义的层
        x = self.dense_1(inputs)
        return self.dense_2(x)

In [15]:
data = np.random.random((1000,32))
labels = np.random.random((1000,10))

# Instantiate an optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Prepare the training dataset
batch_size =64
train_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

In [17]:
model = MyModel(num_classes=10)
epochs = 3
for epoch in range(epochs):
    print('Start of epoch %d ' % (epoch,))
    
    # 遍历数据集的batch_size
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        
        # 每200 batchs 打印一次
        if step % 200 == 0:
            print('Training loss (for one batch) at step %s: %s' % (step, float(loss_value)))
            print('Seen so far: %s samples ' % ((step + 1) * 64))

Start of epoch 0 


To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Training loss (for one batch) at step 0: 33.57910919189453
Seen so far: 64 samples 
Start of epoch 1 
Training loss (for one batch) at step 0: 18.434566497802734
Seen so far: 64 samples 
Start of epoch 2 
Training loss (for one batch) at step 0: 16.910335540771484
Seen so far: 64 samples 


In [18]:
tf.saved_model.save(model, 'my_saved_model')

INFO:tensorflow:Assets written to: my_saved_model\assets
