## [参考](https://www.tensorflow.org/guide/keras/custom_layers_and_models)

In [10]:
import tensorflow as tf
from tensorflow import keras

In [14]:
class Linear(keras.layers.Layer):
    """Fully connected layer computing ``inputs @ w + b``.

    The weights are created eagerly in ``__init__`` as raw ``tf.Variable``
    objects; assigning them as attributes is enough for the layer to track them.

    Args:
        units: Number of output features (columns of ``w``, length of ``b``).
        input_dim: Number of input features (rows of ``w``); must equal the last
            dimension of the inputs passed to ``call`` for ``tf.matmul`` to work.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            ``keras.layers.Layer``.
    """

    def __init__(self, units=32, input_dim=32, **kwargs):
        # Zero-argument super() binds to *this* class object, so it keeps
        # working even after a later cell rebinds the global name `Linear`.
        super().__init__(**kwargs)
        w_init = tf.random_normal_initializer()
        # Weight matrix: one row per input feature, one column per output unit.
        self.w = tf.Variable(
            initial_value=w_init(shape=(input_dim, units), dtype="float32"),
            trainable=True,
        )
        b_init = tf.zeros_initializer()
        # Bias vector: one entry per output unit, initialised to zero.
        self.b = tf.Variable(
            initial_value=b_init(shape=(units,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        """Return ``inputs @ w + b`` (matrix product plus broadcast bias)."""
        return tf.matmul(inputs, self.w) + self.b

In [15]:
x = tf.ones((2, 2))  # batch of 2 samples, 2 features each
linear_layer = Linear(units=4, input_dim=2)

# This Linear creates both variables in __init__, so they exist before any call.
print(linear_layer.w, linear_layer.b, sep="\n")

y = linear_layer(x)
print(y)

<tf.Variable 'Variable:0' shape=(2, 4) dtype=float32, numpy=
array([[-0.04163347,  0.02341541, -0.02248148, -0.06662707],
       [ 0.00872479,  0.01393286,  0.01271552, -0.02699128]],
      dtype=float32)>
<tf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>
tf.Tensor(
[[-0.03290868  0.03734827 -0.00976597 -0.09361836]
 [-0.03290868  0.03734827 -0.00976597 -0.09361836]], shape=(2, 4), dtype=float32)


In [16]:
# The layer should track exactly the two tf.Variable attributes, in assignment
# order. Compare by identity: `==` on tf.Variable is element-wise, so plain list
# equality only passes thanks to Python's identity shortcut, and a mismatch would
# raise an ambiguous-truth-value error instead of a clear AssertionError.
assert [id(v) for v in linear_layer.weights] == [id(linear_layer.w), id(linear_layer.b)], (
    "linear_layer.weights should be [linear_layer.w, linear_layer.b]"
)

## `add_weight()` を使うと簡単に書ける

In [18]:
class Linear(keras.layers.Layer):
    """Fully connected layer computing ``inputs @ w + b``, using ``add_weight``.

    ``add_weight`` creates and registers the variables in one step, replacing the
    manual initializer + ``tf.Variable`` boilerplate.

    Args:
        units: Number of output features (columns of ``w``, length of ``b``).
        input_dim: Number of input features (rows of ``w``); must equal the last
            dimension of the inputs passed to ``call`` for ``tf.matmul`` to work.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            ``keras.layers.Layer``.
    """

    def __init__(self, units=32, input_dim=32, **kwargs):
        # Zero-argument super() binds to *this* class object, so it keeps
        # working even after a later cell rebinds the global name `Linear`.
        super().__init__(**kwargs)
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        """Return ``inputs @ w + b``: a matrix product followed by a bias add."""
        return tf.matmul(inputs, self.w) + self.b

In [20]:
x = tf.ones((2, 2))  # batch of 2 samples, 2 features each
linear_layer = Linear(units=4, input_dim=2)

# add_weight() in __init__ has already created both variables.
print(linear_layer.w, linear_layer.b, sep="\n")

y = linear_layer(x)
print(y)

<tf.Variable 'Variable:0' shape=(2, 4) dtype=float32, numpy=
array([[ 0.00407009, -0.00869892, -0.0434149 ,  0.04944688],
       [ 0.00424434,  0.05113566,  0.04510517, -0.01551912]],
      dtype=float32)>
<tf.Variable 'Variable:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>
tf.Tensor(
[[0.00831443 0.04243674 0.00169028 0.03392775]
 [0.00831443 0.04243674 0.00169028 0.03392775]], shape=(2, 4), dtype=float32)


## トレーニングできない重み

In [21]:
class ComputeSum(keras.layers.Layer):
    """Layer that keeps a running, non-trainable total of column sums.

    Each call adds ``tf.reduce_sum(inputs, axis=0)`` to ``self.total`` and
    returns the updated total.

    Args:
        input_dim: Length of the running-total vector. Inputs must reduce along
            axis 0 to shape ``(input_dim,)``, i.e. be 2-D ``(batch, input_dim)``.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            ``keras.layers.Layer``.
    """

    def __init__(self, input_dim, **kwargs):
        # Zero-argument super() binds to *this* class object, robust to the
        # notebook rebinding class names in later cells.
        super().__init__(**kwargs)
        # trainable=False: tracked in `weights` / `non_trainable_weights` but
        # excluded from `trainable_weights`.
        self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)

    def call(self, inputs):
        """Accumulate the column sums of ``inputs`` and return the new total."""
        self.total.assign_add(tf.reduce_sum(inputs, axis=0))
        # Return a snapshot tensor, not the variable itself: otherwise a result
        # returned by an earlier call would silently change on the next call.
        return self.total.read_value()

In [27]:
x = tf.ones((2, 2))
my_sum = ComputeSum(input_dim=2)

# First call: the total becomes the column sums of x.
y = my_sum(x)
print(y.numpy())

# Second call on the same batch: the non-trainable total keeps accumulating.
y = my_sum(x)
print(y.numpy())

[2. 2.]
[4. 4.]


In [30]:
# `total` was created with trainable=False, so it shows up in `weights` and
# `non_trainable_weights` but not in `trainable_weights`.
print(
    f"weights: {len(my_sum.weights)}",
    f"non-trainable weights: {len(my_sum.non_trainable_weights)}",
    f"trainable_weights: {my_sum.trainable_weights}",
    sep="\n",
)

weights: 1
non-trainable weights: 1
trainable_weights: []


## 入力の形状が分かってから重みを作成する（`build()` による遅延生成）

In [38]:
class Linear(keras.layers.Layer):
    """Fully connected layer computing ``inputs @ w + b`` with lazily created weights.

    ``__init__`` only records ``units``; the weights are created in ``build()``,
    which receives the input shape, so ``input_dim`` no longer has to be given.

    Args:
        units: Number of output features.
        **kwargs: Standard Keras layer arguments (e.g. ``name``), forwarded to
            ``keras.layers.Layer``.
    """

    def __init__(self, units=32, **kwargs):
        # Zero-argument super() binds to *this* class object, so it keeps
        # working even after the global name `Linear` is rebound.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create ``w`` with shape ``(input_shape[-1], units)`` and ``b`` with shape ``(units,)``."""
        print("ビルド")  # "build": shows when build() runs relative to call()
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        # The bias uses a random-normal initializer here (not zeros).
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        """Return ``inputs @ w + b``."""
        print("コール")  # "call": printed every time this method executes
        return tf.matmul(inputs, self.w) + self.b

In [39]:
x = tf.ones((2, 2))
# No weights exist yet: the first call runs build() with x's shape, then call().
linear_layer = Linear(units=32)
y = linear_layer(x)
y

ビルド
コール


<tf.Tensor: shape=(2, 32), dtype=float32, numpy=
array([[ 0.06242579, -0.05435263, -0.06270741, -0.00311189,  0.03304204,
         0.06391817, -0.14059798, -0.03104004,  0.06532896,  0.16151224,
         0.03892711, -0.05108436, -0.12631126, -0.02692964, -0.02476245,
         0.0265835 ,  0.00109719, -0.06123249,  0.0980624 , -0.07449006,
        -0.01964549,  0.08673774, -0.02240579, -0.0448319 , -0.05188194,
         0.01540685,  0.05266732,  0.24476007, -0.03392565, -0.1852138 ,
         0.05699131,  0.10827867],
       [ 0.06242579, -0.05435263, -0.06270741, -0.00311189,  0.03304204,
         0.06391817, -0.14059798, -0.03104004,  0.06532896,  0.16151224,
         0.03892711, -0.05108436, -0.12631126, -0.02692964, -0.02476245,
         0.0265835 ,  0.00109719, -0.06123249,  0.0980624 , -0.07449006,
        -0.01964549,  0.08673774, -0.02240579, -0.0448319 , -0.05188194,
         0.01540685,  0.05266732,  0.24476007, -0.03392565, -0.1852138 ,
         0.05699131,  0.10827867]], dtyp