# RNN Welding

In [1]:
# Imports: make the local MFlow checkout importable for the cells below.
import sys

# Machine-specific location of the MFlow source tree; adjust for your checkout.
MFLOW_ROOT = "E:/dataFiles/github/MFlow"
sys.path.append(MFLOW_ROOT)

In [2]:
# Data generation
from data.generater import waveData

seq_len = 96  # sequence length (number of time steps per sample)
in_dim = 16  # input feature dimension at each time step
state_dim = 12  # RNN hidden-state dimension

n_samples = 1000  # total samples requested from waveData
n_train = 700  # leading samples used for training; the rest are the test set

# NOTE(review): argument order (count, in_dim, seq_len) is taken from this call;
# the printed shapes below are (700, 96, 16) and (700, 2).
xs, ys = waveData(n_samples, in_dim, seq_len)

train_xs, test_xs = xs[:n_train], xs[n_train:]
train_ys, test_ys = ys[:n_train], ys[n_train:]

print(train_xs.shape, train_ys.shape)
print(train_xs[0], train_ys[0])

(700, 96, 16) (700, 2)
[[-0.47902943 -0.33888471 -0.15164262 ...  0.29760411  0.08401242
   0.6191582 ]
 [ 0.56697832 -1.08823451  0.24312622 ...  1.38201263  0.60539122
   0.08490446]
 [-0.74879086  0.20460376  0.34146502 ...  0.37672354 -0.38006992
   0.25312174]
 ...
 [ 0.79162252 -0.6409203  -0.01049418 ... -0.66999013  0.43906847
  -0.65729462]
 [-1.32582854  0.31737572  0.71984772 ... -1.16742353 -1.10996313
  -0.67468318]
 [-0.38289001 -1.70425701 -0.4322597  ...  0.00871695 -1.11056038
  -0.37275389]] [1. 0.]


In [3]:
# Training
import numpy as np
from mflow import core, ops, opts, lays

# Hyper-parameters
lr = 0.005
epoch = 30
batch_size = 16
log_every = 64  # print the training loss every `log_every` samples


def _feed_random_crop(feat, inputs, hiddens, welding_point):
    """Feed a random variable-length crop of ``feat`` into the unrolled RNN.

    The crop starts at a random step in ``[0, len(feat) // 3)`` and ends
    (exclusive) at a random step in ``[len(feat) // 3 + 30, len(feat))``.
    Each cropped step is written into the corresponding input node. The
    welding point is then attached to the hidden state of the crop's last
    step, so the head sees the state after the cropped length.

    Assumes ``len(feat) > len(feat) // 3 + 30`` (true for ``seq_len = 96``);
    otherwise ``np.random.randint`` raises ``ValueError``.
    """
    lens = len(feat)
    # Draw start before end to keep the same RNG stream as the original code.
    start = np.random.randint(lens // 3)
    end = np.random.randint(lens // 3 + 30, lens)
    crop = feat[start:end]
    for j, step in enumerate(crop):
        # np.asmatrix instead of np.mat (np.mat was removed in NumPy 2.0);
        # a 1-D step becomes a (1, in_dim) matrix, transposed to (in_dim, 1).
        inputs[j].setValue(np.asmatrix(step).T)
    # Weld onto the hidden state of the last fed step explicitly, instead of
    # relying on the loop variable leaking out of the for loop.
    welding_point.weld(hiddens[len(crop) - 1])


with core.NameScope("RNNWelding"):
    # Variables: one (in_dim, 1) input node per time step plus RNN weights.
    inputs = [core.Variable(size=(in_dim, 1), trainable=False) for _ in range(seq_len)]
    u = core.Variable(size=(state_dim, in_dim), trainable=True)
    w = core.Variable(size=(state_dim, state_dim), trainable=True)
    b = core.Variable(size=(state_dim, 1), trainable=True)
    y = core.Variable(size=(2, 1), trainable=False)
    last_step = None
    hiddens = []
    # Unrolled RNN: h_t = ReLU(U x_t + b + W h_{t-1}); the first step has no W term.
    for iv in inputs:
        h = ops.Add(ops.MatMal(u, iv), b)
        if last_step is not None:
            h = ops.Add(ops.MatMal(w, last_step), h)
        h = ops.ReLU(h)
        last_step = h
        hiddens.append(last_step)  # hiddens[t] is the hidden state after step t
    # Welding point: rewired per sample via weld() to one hidden state
    # (presumably a re-attachable placeholder node -- confirm against mflow).
    welding_point = ops.Welding()
    # Fully connected head; the lays.Linear arguments are assumed to be
    # (input, in_features, out_features, activation), judging by these calls.
    fc_1 = lays.Linear(welding_point, state_dim, 40, "ReLU")
    fc_2 = lays.Linear(fc_1, 40, 10, "ReLU")
    pred = lays.Linear(fc_2, 10, 2, None)
    # Element-wise logistic is monotonic, so its argmax equals the logits' argmax.
    predicter = ops.Logistic(pred)
    loss = ops.loss.CrossEntropyWithSoftMax(pred, y)
    # NOTE(review): presumably a numerical-stability epsilon inside the loss
    # (original note: "to avoid vanishing gradients") -- confirm against mflow.
    loss.eps = 1e-6
    adam = opts.Adam(core.DefaultGraph, loss, lr)
    # Training: adam.step() runs once per sample and adam.update() once per batch.
    for ep in range(1, epoch + 1):
        bs_idx = 0  # samples accumulated in the current mini-batch
        for i, (feat, lab) in enumerate(zip(train_xs, train_ys)):
            _feed_random_crop(feat, inputs, hiddens, welding_point)
            y.setValue(np.asmatrix(lab).T)
            adam.step()
            bs_idx += 1
            if bs_idx == batch_size:
                if (i + 1) % log_every == 0:
                    # Loss of the most recent sample, not a batch average.
                    print("Epoch: {:d}, iter: {:d}, loss: {:.7f}.".format(
                        ep, i + 1, loss.value[0, 0]))
                adam.update()
                bs_idx = 0
        # Apply the trailing partial batch (len(train_xs) % batch_size samples).
        # Otherwise its gradients leak into the next epoch's first batch, or are
        # lost after the final epoch. Assumes step() accumulates and update()
        # applies the accumulated gradients -- confirm against mflow's Adam.
        if bs_idx:
            adam.update()
        # Evaluate after each epoch. Test sequences are also randomly cropped,
        # so the reported accuracy varies from run to run.
        preds = []
        for feat in test_xs:
            _feed_random_crop(feat, inputs, hiddens, welding_point)
            predicter.forward()
            preds.append(predicter.value.A.ravel())
        preds = np.array(preds).argmax(axis=1)
        trues = test_ys.argmax(axis=1)
        acc = (trues == preds).mean()
        print("Epoch: {:d}, acc: {:.3f}.".format(ep, acc))

Epoch: 1, itet: 64, loss: 0.6957231.
Epoch: 1, itet: 128, loss: 0.6877459.
Epoch: 1, itet: 192, loss: 0.7096018.
Epoch: 1, itet: 256, loss: 0.7158942.
Epoch: 1, itet: 320, loss: 0.6733113.
Epoch: 1, itet: 384, loss: 0.6466772.
Epoch: 1, itet: 448, loss: 0.7469688.
Epoch: 1, itet: 512, loss: 0.6260917.
Epoch: 1, itet: 576, loss: 0.7846410.
Epoch: 1, itet: 640, loss: 0.8065098.
Epoch: 1, acc: 0.470.
Epoch: 2, itet: 64, loss: 0.1719376.
Epoch: 2, itet: 128, loss: 0.6756439.
Epoch: 2, itet: 192, loss: 0.9503417.
Epoch: 2, itet: 256, loss: 0.6805741.
Epoch: 2, itet: 320, loss: 0.4258916.
Epoch: 2, itet: 384, loss: 0.2345144.
Epoch: 2, itet: 448, loss: 0.0479421.
Epoch: 2, itet: 512, loss: 0.7575343.
Epoch: 2, itet: 576, loss: 0.0099397.
Epoch: 2, itet: 640, loss: 0.4588063.
Epoch: 2, acc: 0.777.
Epoch: 3, itet: 64, loss: 1.1665633.
Epoch: 3, itet: 128, loss: 0.0009733.
Epoch: 3, itet: 192, loss: 1.7433162.
Epoch: 3, itet: 256, loss: 0.0020654.
Epoch: 3, itet: 320, loss: 0.3440103.
Epoch: 3,