[toc]

# TensorFlow Layer

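The recommended practice is to subclass `tf.keras.layers.Layer` and implement three methods: `__init__`, `build`, and `call`. `__init__` usually performs initialization that does not depend on the input shape. `build` performs initialization that does depend on it, typically creating the weights, whose sizes depend on the input dimension. `call` defines the forward pass.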
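A minimal sketch of this lifecycle, assuming TensorFlow 2.x. `TracingLayer`, `build_calls`, and `seen_input_dim` are hypothetical names used only for illustration. The layer counts how often `build` runs. No weights exist after construction. The first call to the layer runs `build` once with the input's shape, and later calls reuse the same weights.

```python
import tensorflow as tf


class TracingLayer(tf.keras.layers.Layer):
    """Hypothetical layer that records when build() runs."""

    def __init__(self, units):
        super().__init__()
        self.units = units          # shape-independent configuration only
        self.build_calls = 0
        self.seen_input_dim = None

    def build(self, input_shape):
        # Runs on the first call, once the input shape is known.
        self.build_calls += 1
        self.seen_input_dim = int(input_shape[-1])
        self.kernel = self.add_weight(
            name="kernel", shape=(self.seen_input_dim, self.units))

    def call(self, inputs):
        # Forward pass only; no weights are created here.
        return tf.matmul(inputs, self.kernel)


layer = TracingLayer(3)
print(layer.built, layer.build_calls)   # False 0: no weights yet
layer(tf.zeros([2, 4]))                 # first call: build() then call()
layer(tf.zeros([5, 4]))                 # second call: call() only
print(layer.built, layer.build_calls, layer.seen_input_dim)  # True 1 4
```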

## Recommended approach

```python
import tensorflow as tf

print(tf.__version__)


class MyDenseLayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs):
        super().__init__()
        # Shape-independent setup: only store the number of output units.
        self.num_outputs = num_outputs

    def build(self, input_shape):
        # Shape-dependent setup: the kernel size depends on the input's last dimension.
        input_dim = input_shape[-1]
        self.kernel = self.add_weight(
            name="kernel", shape=[input_dim, self.num_outputs])
        self.bias = self.add_weight(name="bias", shape=[1, self.num_outputs])

    def call(self, inputs):
        # Forward pass: inputs @ kernel + bias (the bias broadcasts over the batch).
        return tf.matmul(inputs, self.kernel) + self.bias


n_samples = 2
indim = 4
outdim = 3
layer = MyDenseLayer(outdim)
x = tf.random.normal([n_samples, indim])
# No explicit layer.build() is needed: the first call runs build() with x's shape.
layer(x)
```

```text
2.1.0
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[ 1.0828593, -0.3148812,  1.2370515],
       [ 2.5652494, -1.3047172, -1.8012459]], dtype=float32)>
```

```python
print([var.name for var in layer.trainable_variables])
```

```text
['my_dense_layer/kernel:0', 'my_dense_layer/bias:0']
```
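The weights are registered through `add_weight`, so they are exactly what `trainable_variables` returns. That list is also what `tf.GradientTape` differentiates against during training. A minimal sketch, assuming TensorFlow 2.x. The sum-of-squares loss is a toy chosen only to have a scalar to differentiate. The block redefines the same layer so that it is self-contained.

```python
import tensorflow as tf


class MyDenseLayer(tf.keras.layers.Layer):
    # Same structure as the layer above: weights are created in build().
    def __init__(self, num_outputs):
        super().__init__()
        self.num_outputs = num_outputs

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name="kernel", shape=(input_shape[-1], self.num_outputs))
        self.bias = self.add_weight(name="bias", shape=(1, self.num_outputs))

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) + self.bias


layer = MyDenseLayer(3)
x = tf.random.normal([2, 4])
layer(x)  # build the weights before recording gradients

with tf.GradientTape() as tape:
    # Toy scalar loss (an assumption for illustration only).
    loss = tf.reduce_sum(layer(x) ** 2)

# One gradient per trainable variable, each shaped like its variable.
grads = tape.gradient(loss, layer.trainable_variables)
print([tuple(g.shape) for g in grads])  # [(4, 3), (1, 3)]
```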


## Without overriding `build`

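Overriding `build` is optional. You can leave it out and do the shape-dependent initialization in `__init__` instead, but TensorFlow advises against this. When weights are created in `build`, their creation is deferred until the shape of the first input is known. When they are created in `__init__`, you must pass the input shape explicitly up front. See [1] for details.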

```python
import tensorflow as tf
from tensorflow import keras


class MyDense(keras.layers.Layer):
    def __init__(self, indim, outdim):
        super().__init__()
        # Weights are created here, so the input dimension must be given in advance.
        self.kernel = self.add_weight(name="kernel", shape=(indim, outdim))
        self.bias = self.add_weight(name="bias", shape=[1, outdim])

    def call(self, inputs, training=None):
        # `training` is accepted but unused: this layer behaves the same in training and inference.
        out = tf.matmul(inputs, self.kernel) + self.bias
        out = tf.nn.relu(out)
        return out


n_samples = 2
indim = 4
outdim = 3
x = tf.random.normal([n_samples, indim])
mydense = MyDense(indim, outdim)
mydense(x)
```

```text
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0.91596544, 0.9862181 , 0.23890126],
       [0.9338381 , 3.0983229 , 0.        ]], dtype=float32)>
```
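A minimal sketch of this trade-off, assuming TensorFlow 2.x. `LazyDense` and `EagerDense` are hypothetical names used only for illustration. It puts the two styles side by side. The `build`-based layer sizes its kernel from whatever input it first receives. The `__init__`-based layer is fixed to the `indim` it was constructed with, so an input of a different width fails in `tf.matmul`.

```python
import tensorflow as tf


class LazyDense(tf.keras.layers.Layer):
    """Creates its weights in build(), once the input shape is known."""

    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(name="kernel", shape=(input_shape[-1], self.units))
        self.bias = self.add_weight(name="bias", shape=(self.units,), initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) + self.bias


class EagerDense(tf.keras.layers.Layer):
    """Creates its weights in __init__, so the input dimension is fixed up front."""

    def __init__(self, indim, units):
        super().__init__()
        self.kernel = self.add_weight(name="kernel", shape=(indim, units))
        self.bias = self.add_weight(name="bias", shape=(units,), initializer="zeros")

    def call(self, inputs):
        return tf.matmul(inputs, self.kernel) + self.bias


lazy = LazyDense(3)
lazy_out = lazy(tf.zeros([2, 7]))   # kernel is sized (7, 3) from this first input
print(lazy_out.shape)               # (2, 3)

eager = EagerDense(4, 3)            # kernel is fixed at (4, 3) before any input is seen
mismatch_raised = False
try:
    eager(tf.zeros([2, 7]))
except tf.errors.InvalidArgumentError:
    mismatch_raised = True          # matmul of [2, 7] by [4, 3] is rejected
print(mismatch_raised)              # True
```

The `__init__` style still works when the input width is always known in advance, as with `MyDense(indim, outdim)` above. Using `build` removes the need to pass that width in the first place.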

# References

1. [Custom layers | TensorFlow Core](https://tensorflow.google.cn/tutorials/customization/custom_layers?hl=zh-cn)