#### 不含模型参数的自定义层

In [1]:
# 定义一个将输入减掉均值的层CenteredLayer
import mxnet as mx
from mxnet import nd, gluon
from mxnet.gluon import nn

class CenteredLayer(nn.Block):
    """Parameter-free layer that shifts its input to zero mean.

    The mean is a single scalar taken over *all* elements of the input
    (``x.mean()`` with no axis), not a per-sample or per-feature mean.

    Parameters
    ----------
    **kwargs
        Passed straight through to ``nn.Block``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x):
        # One global scalar mean, broadcast back over every element of x.
        overall_mean = x.mean()
        return x - overall_mean

In [2]:
layer = CenteredLayer()
# The mean of 1..5 is 3, so the layer should return [-2, -1, 0, 1, 2].
toy_input = nd.array([1, 2, 3, 4, 5])
layer(toy_input)


[-2. -1.  0.  1.  2.]
<NDArray 5 @cpu(0)>

In [3]:
net = nn.Sequential()
with net.name_scope():
    # Blocks are constructed inside the name scope, in this order, and then
    # appended one at a time: Dense(128) -> Dense(10) -> CenteredLayer.
    for layer_block in (nn.Dense(128), nn.Dense(10), CenteredLayer()):
        net.add(layer_block)

In [4]:
# Seed MXNet's global RNG so that weight initialization and the random input
# below (and in later cells) are reproducible under Restart & Run All.
mx.random.seed(42)
net.initialize()
y = net(nd.random.uniform(shape=(4, 8)))
# The last block is CenteredLayer, so the output mean should be ~0 up to
# float32 rounding.
assert abs(y.mean().asscalar()) < 1e-5
y.mean()


[  3.25962896e-10]
<NDArray 1 @cpu(0)>

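#### 含模型参数的自定义层

Gluon通过`ParameterDict`管理模型参数：调用`get`时，参数的完整名称是字典前缀加上给定的名称（例如下面的`block1_param2`）；如果该名称的参数尚不存在，`get`会新建一个。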

In [5]:
params = gluon.ParameterDict(prefix='block1_')
# The dict's prefix is prepended to the requested name, so the stored
# parameter shows up as `block1_param2`.
param2 = params.get('param2', shape=(2, 3))
params

block1_ (
  Parameter block1_param2 (shape=(2, 3), dtype=<class 'numpy.float32'>)
)

下面实现一个含权重参数和偏差参数的全连接层，并使用ReLU作为激活函数。其中，`in_units`和`units`分别是输入单元个数和输出单元个数。

In [10]:
class MyDense(nn.Block):
    """Fully connected layer with its own weight and bias, followed by ReLU.

    Computes ``relu(dot(x, weight) + bias)``.

    Parameters
    ----------
    units : int
        Number of output units (columns of ``weight``, length of ``bias``).
    in_units : int
        Number of input units (rows of ``weight``).
    **kwargs
        Passed through to ``nn.Block`` (e.g. ``prefix``).
    """

    def __init__(self, units, in_units, **kwargs):
        super().__init__(**kwargs)
        weight_shape = (in_units, units)
        bias_shape = (units,)
        with self.name_scope():
            self.weight = self.params.get('weight', shape=weight_shape)
            self.bias = self.params.get('bias', shape=bias_shape)

    def forward(self, x):
        # Assumes x is (batch, in_units) so nd.dot yields (batch, units).
        w, b = self.weight.data(), self.bias.data()
        return nd.relu(nd.dot(x, w) + b)

In [11]:
# With an explicit prefix, the parameters are named
# `o_my_dense_weight` and `o_my_dense_bias`.
dense = MyDense(units=5, in_units=10, prefix='o_my_dense_')
dense.params

o_my_dense_ (
  Parameter o_my_dense_weight (shape=(10, 5), dtype=<class 'numpy.float32'>)
  Parameter o_my_dense_bias (shape=(5,), dtype=<class 'numpy.float32'>)
)

In [12]:
dense.initialize()
# 2 samples with 10 features each; 10 must match in_units of `dense`.
x_batch = nd.random.uniform(shape=(2, 10))
dense(x_batch)


[[ 0.          0.09092736  0.          0.17156085  0.        ]
 [ 0.          0.06395531  0.          0.09730551  0.        ]]
<NDArray 2x5 @cpu(0)>

In [13]:
net = nn.Sequential()
with net.name_scope():
    # (in_units, units) per layer: 64 -> 32 -> 2.
    for in_dim, out_dim in ((64, 32), (32, 2)):
        net.add(MyDense(out_dim, in_units=in_dim))
net.initialize()
batch = nd.random.uniform(shape=(2, 64))
# The final MyDense also applies ReLU, so every non-positive pre-activation is
# clamped to 0; that is why the sample output can come out all zeros.
net(batch)


[[ 0.  0.]
 [ 0.  0.]]
<NDArray 2x2 @cpu(0)>