### LSTM 实现 

(https://zhuanlan.zhihu.com/p/81549798)[https://zhuanlan.zhihu.com/p/81549798]

(https://zhuanlan.zhihu.com/p/32085405)[https://zhuanlan.zhihu.com/p/32085405]


https://zhuanlan.zhihu.com/p/54868269




In [182]:
import tensorflow as tf
from tensorflow.python.keras import initializers
from tensorflow.python.keras import backend as K
from tensorflow.python.ops import array_ops
from tensorflow.python.keras import activations
from tensorflow.python.util.tf_export import keras_export
import numpy as np
from tensorflow import keras
import os
import time
import pickle
print('tf version: ', tf.__version__)
print('GPU : ', tf.test.is_gpu_available())
print('GPU list', tf.config.list_physical_devices('GPU'))

tf version:  2.1.0
GPU :  False
GPU list []


In [207]:
class LSTM_CELL(tf.keras.layers.Layer):
    def __init__(self, units=256, **kwargs):
        # lstm 维度
        self.units = units
        super(LSTM_CELL, self).__init__(**kwargs)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(shape=(input_dim, self.units * 4), name='kernel',
            initializer=initializers.get('glorot_uniform'))
        print("LSTM w.shape: {}".format(self.w.shape))
        

        # u 保存hadden 的权重
        self.u = self.add_weight(shape=(self.units, self.units * 4),
                                                name='recurrent_kernel',
                                                initializer=initializers.get('orthogonal'))
        print("LSTM u.shape: {}".format(self.u.shape))
        self.bias = self.add_weight(
            shape=(self.units * 4), name='bias',
            initializer=initializers.get('zeros'))
        print("LSTM b.shape: {}".format(self.bias.shape))
#         self.recurrent_activation = activations.get('hard_sigmoid')
#         self.activation = activations.get('tanh')
        
        self.sigmoid = activations.get('hard_sigmoid')
        self.tanh = activations.get('tanh')

    def call(self, inputs, states):
        """
        input shape 是三维 ，同时会计算batch_size 个样本数据
        """
#         print("---- call ---  states ", states)
        last_h = states[0]   # h(t-1)
        last_c = states[1]   # c(c-1)
        # i information 输入门
        # f forget 遗忘门
        # o output 输出门
        # c. cell
        # 四个权重保持在一个变量里面
        w_i, w_f, w_c, w_o = tf.split(self.w, num_or_size_splits=4, axis=1)
#         print('w_i.shape', w_i.shape)
        b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
#         print('b_i.shape', b_i.shape)
        # w x
        x_i = K.dot(inputs, w_i)
#         print("inputs shape {} * w_i shape {} = x_i shape {} ".format(inputs.shape, w_i.shape, x_i.shape))
        x_f = K.dot(inputs, w_f)
        x_c = K.dot(inputs, w_c)
        x_o = K.dot(inputs, w_o)
        # w x + b
        x_i = K.bias_add(x_i, b_i)
        x_f = K.bias_add(x_f, b_f)
        x_c = K.bias_add(x_c, b_c)
        x_o = K.bias_add(x_o, b_o)

        u_i, u_f, u_c, u_o = tf.split(self.u, num_or_size_splits=4, axis=1)
        # w x + u * h + x
        i = self.sigmoid(x_i + K.dot(last_h, u_i))
        f = self.sigmoid(x_f + K.dot(last_h, u_f))
#         c = f * last_c + self.tanh(x_c + K.dot(last_h, u_c))
        c = f * last_c + i * self.tanh(x_c + K.dot(last_h, u_c))

        o = self.sigmoid(x_o + K.dot(last_h, u_o))

        # 计算 h
        h = o * self.tanh(c)
        
        return h, (h, c)

class Rnn(tf.keras.layers.Layer):
    def __init__(self, units=128):
        super(Rnn, self).__init__()
        self.cell = LSTM_CELL(units)
        self.init_state = None
    def build(self, input_shape):
        print('Rnn shape: ', input_shape)
        shape = input_shape.as_list()
        n_batch = shape[0]
        init_h = tf.zeros(shape=[n_batch, self.cell.units])
        init_c = init_h
        self.init_state = (init_h, init_c)

    def call(self, inputs):
        """
        前向传播， 依次遍历每个时间序列
        第一个维度是 样本数量
        第二个维度是 时间序列
        """
        # time step
        ts = inputs.shape.as_list()[1]
#       print(inputs.shape.as_list())
        h, c = self.init_state
        for i in range(ts):
            h, (h, c) = self.cell(inputs[:, i],(h, c))
        return h


In [208]:
a = tf.random.normal(shape=(2, 3, 4))
print(a.shape)
rnn = Rnn(5)
h = rnn(a)
print(h.shape)
# print(a[0])
# [4, 28, 28]
# LSTM w.shape: (28, 1024)
# LSTM u.shape: (256, 1024)
# LSTM b.shape: (1024,)
# w_i.shape (28, 256)
# b_i.shape (256,)

(2, 3, 4)
Rnn shape:  (2, 3, 4)
LSTM w.shape: (4, 20)
LSTM u.shape: (5, 20)
LSTM b.shape: (20,)
(2, 5)


In [209]:
class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()

        # office lstm
        #self.rnn = tf.keras.layers.RNN(tf.keras.layers.LSTMCell(256))
        # my lstm
        self.rnn = Rnn(256)
        # 三层网络， 128 Dense + 10 softmax
        self.d1 = tf.keras.layers.Dense(128, activation="relu")
        self.d2 = tf.keras.layers.Dense(10, activation="softmax")
    def call(self, x):
        """
        前向传播， 预测， 输入x， 输出y。 
        """
        x = self.rnn(x)
        # [batch_size, d1.output_size], [4, 128]
        x = self.d1(x)
        # [batch_size, d2.output_size], [4, 10]
        # 最后输出分类
        x = self.d2(x)
#       print('------x.shape', x.shape)
        return x

In [210]:
@tf.function
def train_step(model, loss, opti, images, labels, train_loss, train_acc):
    with tf.GradientTape() as tape:
        # pred [batch_size, n_class] (4, 10)
        pred = model(images)
        loss_val = loss(labels, pred)
    train_loss.update_state(loss_val)
    train_acc.update_state(labels, pred)
    grad = tape.gradient(loss_val, model.trainable_variables)
    opti.apply_gradients(zip(grad, model.trainable_variables))

In [211]:
# 定义优化器
opti = tf.keras.optimizers.Adam()
# 定义损失函数
loss = tf.keras.losses.SparseCategoricalCrossentropy()
# 用于记录损失值
train_loss = tf.keras.metrics.Mean()
# 记录正确率
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
# 加载数据
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), _ = fashion_mnist.load_data()
train_images = train_images / 255.0
num_used = 5000
train_images = train_images[:num_used]
train_labels = train_labels[:num_used]
train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(10000).batch(4)
# 定义模型
model = MyModel()
epochs = 30



In [212]:
train_images[0].shape

(28, 28)

In [None]:
list_time_cost = list()
list_acc = list()
for epoch in range(epochs):
    # train
    train_loss.reset_states()
    train_acc.reset_states()
    # images [batch_size, height, width] (4, 28, 28)
    # labels [batch_size]
    start = time.time()
    for images, labels in train_ds:
        train_step(model, loss, opti, images, labels, train_loss, train_acc)
    ends = time.time()
    cost = ends - start
    list_time_cost.append(cost)
    list_acc.append(train_acc.result().numpy())
    print("Time: {:.2f} s, Epoch: {:2d}, loss: {:.5f}, acc: {:.5f}".format(cost, epoch, train_loss.result(), train_acc.result()))
# with open("./output/my_lstm_acc.pkl", "wb") as fw:
#     pickle.dump(list_acc, fw)
# with open("./output/my_lstm_time_cost.pkl", "wb") as fw:
#     pickle.dump(list_time_cost, fw)



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.

Rnn shape:  (4, 28, 28)
LSTM w.shape: (28, 1024)
LSTM u.shape: (256, 1024)
LSTM b.shape: (1024,)
Time: 27.05 s, Epoch:  0, loss: 1.07026, acc: 0.59680
Time: 21.57 s, Epoch:  1, loss: 0.72783, acc: 0.72920
Time: 21.61 s, Epoch:  2, loss: 0.58067, acc: 0.78720
Time: 21.49 s, Epoch:  3, loss: 0.52058, acc: 0.80900
Time: 21.50 s, Epoch:  4, loss: 0.47311, acc: 0.82900
Time: 21.60 s, Epoch:  5, loss: 0.43479, acc: 0.83560
Time: 21.55 s, Epoch:  6, loss: 0.41062, acc: 0.84780
Time: 21.50 s, Epoch:  7, loss: 0.38830, acc: 0.85960
Time: 21.66 s, Epoch:  8, loss: 0.36042, acc: 0.86460


```
Time: 23.789902687072754 s, Epoch: 0, loss: 1.1165646314620972, acc: 0.5685999989509583
Time: 19.317150831222534 s, Epoch: 1, loss: 0.8056125044822693, acc: 0.6935999989509583
Time: 19.2503023147583 s, Epoch: 2, loss: 0.6855613589286804, acc: 0.7441999912261963
Time: 19.254036903381348 s, Epoch: 3, loss: 0.5773577690124512, acc: 0.7815999984741211
Time: 19.334920406341553 s, Epoch: 4, loss: 0.5494149923324585, acc: 0.7972000241279602
Time: 19.337989330291748 s, Epoch: 5, loss: 0.5215439200401306, acc: 0.8105999827384949
Time: 19.4918429851532 s, Epoch: 6, loss: 0.47755488753318787, acc: 0.829200029373169
Time: 19.27661943435669 s, Epoch: 7, loss: 0.4736003279685974, acc: 0.8271999955177307
Time: 19.279857635498047 s, Epoch: 8, loss: 0.4204665422439575, acc: 0.843999981880188
Time: 19.294177293777466 s, Epoch: 9, loss: 0.4283628761768341, acc: 0.8410000205039978
Time: 19.24098014831543 s, Epoch: 10, loss: 0.39684203267097473, acc: 0.8497999906539917
Time: 19.27296805381775 s, Epoch: 11, loss: 0.36575642228126526, acc: 0.8640000224113464
Time: 19.261507272720337 s, Epoch: 12, loss: 0.4031131863594055, acc: 0.8593999743461609
Time: 19.254382848739624 s, Epoch: 13, loss: 0.33121854066848755, acc: 0.8722000122070312
Time: 19.271594047546387 s, Epoch: 14, loss: 0.314357727766037, acc: 0.8820000290870667
Time: 19.265925645828247 s, Epoch: 15, loss: 0.3158433437347412, acc: 0.8841999769210815
Time: 19.297107934951782 s, Epoch: 16, loss: 0.29900258779525757, acc: 0.8871999979019165
Time: 19.178181886672974 s, Epoch: 17, loss: 0.2762836217880249, acc: 0.8960000276565552
Time: 19.29625701904297 s, Epoch: 18, loss: 0.24783462285995483, acc: 0.9053999781608582
Time: 19.205928802490234 s, Epoch: 19, loss: 0.2404624968767166, acc: 0.9110000133514404
Time: 19.252617597579956 s, Epoch: 20, loss: 0.24850600957870483, acc: 0.9056000113487244
Time: 19.25585174560547 s, Epoch: 21, loss: 0.2243300825357437, acc: 0.9151999950408936
Time: 19.26186513900757 s, Epoch: 22, loss: 0.20446330308914185, acc: 0.9211999773979187
Time: 19.31869077682495 s, Epoch: 23, loss: 0.19534704089164734, acc: 0.9272000193595886
Time: 19.297197103500366 s, Epoch: 24, loss: 0.17534077167510986, acc: 0.9354000091552734
Time: 19.262805223464966 s, Epoch: 25, loss: 0.17609521746635437, acc: 0.9330000281333923
Time: 19.22975778579712 s, Epoch: 26, loss: 0.17506620287895203, acc: 0.9381999969482422
Time: 19.275073051452637 s, Epoch: 27, loss: 0.15615135431289673, acc: 0.9409999847412109
Time: 19.438047647476196 s, Epoch: 28, loss: 0.14412060379981995, acc: 0.9458000063896179
Time: 19.246238946914673 s, Epoch: 29, loss: 0.1463460624217987, acc: 0.9458000063896179
```

```
#c = f * last_c + self.tanh(x_c + K.dot(last_h, u_c))
c = f * last_c + i * self.tanh(x_c + K.dot(last_h, u_c))
```
 修改完上面一处错误后， acc 提高。        
        
```
Time: 26.709998607635498 s, Epoch: 0, loss: 1.152647614479065, acc: 0.5522000193595886
Time: 21.37618327140808 s, Epoch: 1, loss: 0.7082579135894775, acc: 0.7386000156402588
Time: 21.438363790512085 s, Epoch: 2, loss: 0.5829753875732422, acc: 0.7871999740600586
Time: 21.412479400634766 s, Epoch: 3, loss: 0.5179308652877808, acc: 0.8044000267982483
Time: 21.25425386428833 s, Epoch: 4, loss: 0.47378724813461304, acc: 0.8223999738693237
Time: 21.292400360107422 s, Epoch: 5, loss: 0.43394234776496887, acc: 0.8407999873161316
Time: 21.256401777267456 s, Epoch: 6, loss: 0.4198310077190399, acc: 0.8442000150680542
Time: 21.24660563468933 s, Epoch: 7, loss: 0.3805636763572693, acc: 0.8565999865531921
Time: 21.293148517608643 s, Epoch: 8, loss: 0.3567814826965332, acc: 0.871399998664856
Time: 21.26962113380432 s, Epoch: 9, loss: 0.3456974923610687, acc: 0.873199999332428
Time: 21.267067432403564 s, Epoch: 10, loss: 0.3175477981567383, acc: 0.8859999775886536
Time: 21.255532026290894 s, Epoch: 11, loss: 0.3029397130012512, acc: 0.8889999985694885
Time: 21.34101128578186 s, Epoch: 12, loss: 0.29142022132873535, acc: 0.8913999795913696
Time: 21.273035526275635 s, Epoch: 13, loss: 0.26742011308670044, acc: 0.8949999809265137
Time: 21.25704288482666 s, Epoch: 14, loss: 0.2575889527797699, acc: 0.9061999917030334
Time: 21.324198246002197 s, Epoch: 15, loss: 0.2270210236310959, acc: 0.9147999882698059
Time: 21.39011859893799 s, Epoch: 16, loss: 0.21490620076656342, acc: 0.9187999963760376
Time: 21.248369216918945 s, Epoch: 17, loss: 0.21009555459022522, acc: 0.920799970626831
Time: 21.258670806884766 s, Epoch: 18, loss: 0.19424863159656525, acc: 0.9246000051498413
Time: 21.290088176727295 s, Epoch: 19, loss: 0.18225084245204926, acc: 0.928600013256073
Time: 21.236432552337646 s, Epoch: 20, loss: 0.1635647863149643, acc: 0.9383999705314636
Time: 21.3145968914032 s, Epoch: 21, loss: 0.15962575376033783, acc: 0.9408000111579895
Time: 21.270111322402954 s, Epoch: 22, loss: 0.14077134430408478, acc: 0.9455999732017517
Time: 21.28567624092102 s, Epoch: 23, loss: 0.1404210329055786, acc: 0.9477999806404114
Time: 21.267570972442627 s, Epoch: 24, loss: 0.1189630851149559, acc: 0.9563999772071838
Time: 21.305675745010376 s, Epoch: 25, loss: 0.1059480682015419, acc: 0.9598000049591064
Time: 21.32499861717224 s, Epoch: 26, loss: 0.11065905541181564, acc: 0.9577999711036682
Time: 21.27551817893982 s, Epoch: 27, loss: 0.10627779364585876, acc: 0.9613999724388123
Time: 21.334386825561523 s, Epoch: 28, loss: 0.08813607692718506, acc: 0.9664000272750854
Time: 21.29513669013977 s, Epoch: 29, loss: 0.07467116415500641, acc: 0.9742000102996826
```