In [1]:
import tensorflow as tf

# lstm

lstm, 全称是 long short term memory neural network, 中文名是长短记忆时网络。 该网络是RNN（Recurent neural network ）的变种。RNN类型相比于全连接神经网络的优势在于，同级别的网络层之间没有连接，那么输出信息只和输入信息相关，和输入的先后顺序无关。同时全连接网络的神经元不能存储信息，那么就没有记忆功能。对于和时间相关的输入数据，全连接网络就不能有效处理。
举个例子，
如果判别“地球上最高峰是珠穆朗玛峰”这个句子是否表达正确。全连接神经网络会将这个句子划分成几个词（地球，上，最高峰，是，珠穆朗玛峰）， 分布输入到模型中，忽略了词之间的关系。

RNN能够有效处理这类数据。然而，RNN自身也存在一些问题。RNN在理论上可以保留所有历史时刻的信息，但是在实际使用中，信息的传递往往会因为间隔太长而逐渐衰减，传递一段时刻以后，信息的作用效果就大大降低了。因此，普通RNN对于信息的长期依赖问题没有很好的处理办法。

为了克服这个问题， Hochreiter 等人改进了RNN， 提出了特殊的RNN模型， LSTM, 可以学习到长期依赖信息。


lstm网络的结构如下：
![title](img/lstm1.png)
lstm 基本思想
LSTM有以下参数， 输入x_t, 细胞状态C_t, 输出h_t, 遗忘门f, 输入门i, 输出门o。

遗忘门f的作用是将细胞状态C_t-1进行评估，0为完全遗忘，1为完全记忆。使用sigmoid函数来处理，将C_t-1 和一个0-1值进行相乘。
$$f_t=\sigma(W_f[h_t-1, x_t]+b_f)$$

输入门是针对候选的C_t来进行处理，同样使用sigmid函数进行0-1的映射。
$$i_t=\sigma(W_i[h_t-1, x_t]+b_i)$$

输出门是针对C_t 到 h_t的映射，同样使用sigmoid函数

$$o_t=\sigma(W_o[h_t-1, x_t)+b_o)$$






$$P(A \mid B) = \frac{ P(B \mid A) P(A) }{ P(B) }$$

In [2]:
(train_data, train_target), (test_data, test_target) = tf.keras.datasets.imdb.load_data()

In [3]:
max_len = 64

In [4]:
train_target = train_target.astype("float32")

In [5]:
train_padding = tf.keras.preprocessing.sequence.pad_sequences(train_data, padding="post", maxlen=max_len)
test_padding = tf.keras.preprocessing.sequence.pad_sequences(test_data, padding="post", maxlen=max_len)

In [6]:
train_padding = train_padding.astype("float32")
test_padding = test_padding.astype("float32")

In [7]:
word_index = tf.keras.datasets.imdb.get_word_index()

In [8]:
VOCAB_SIZE = 100000
EMBED_SIZE = 10
RNN_SIZE = 10

In [9]:
dataset = tf.data.Dataset.from_tensor_slices((train_padding, train_target))
dataset = dataset.shuffle(100).batch(100)

In [15]:
class JLSTM(tf.keras.Model):
    
    def __init__(self):
        super(JLSTM, self).__init__()
        
        self.embed = tf.keras.layers.Embedding(VOCAB_SIZE, output_dim=EMBED_SIZE)
        self.rnn = tf.keras.layers.LSTM(RNN_SIZE, return_sequences=True, return_state=True)
        self.out = tf.keras.layers.Dense(1, activation="sigmoid")
        
    def call(self, input_x):
        
        x = self.embed(input_x)
        out, state, _ = self.rnn(x)
        out = out[:,-1,:]
        logits = self.out(out)
        
        return logits

In [16]:
model = JLSTM()

In [17]:
model(tf.constant([[1, 2], [3, 4]]))

<tf.Tensor: id=232, shape=(2, 1), dtype=float32, numpy=
array([[0.50309056],
       [0.50318044]], dtype=float32)>

In [18]:
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [19]:
@tf.function()
def train_step(input_x, input_y):
    
    with tf.GradientTape() as tape:
        
        logits = model(input_x)
        loss = loss_func(logits, input_y)
        
    variables = model.variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    

    
    return loss

In [20]:
EPOCH = 100

for e in range(EPOCH):
    
    for i, (train_x, train_y) in enumerate(dataset):
        loss = train_step(train_x, train_y)
        
        if i% 100 == 0:
            print("epoch {0} batch {1} loss {2}".format(e, i, loss))

epoch 0 batch 0 loss 0.743574321269989
epoch 0 batch 100 loss 0.545647144317627
epoch 0 batch 200 loss 0.5006479024887085
epoch 1 batch 0 loss 0.5313969850540161
epoch 1 batch 100 loss 0.5219604969024658
epoch 1 batch 200 loss 0.5134150981903076
epoch 2 batch 0 loss 0.5318225622177124
epoch 2 batch 100 loss 0.5237939357757568
epoch 2 batch 200 loss 0.5045776963233948
epoch 3 batch 0 loss 0.545950710773468
epoch 3 batch 100 loss 0.526868999004364
epoch 3 batch 200 loss 0.5267236828804016
epoch 4 batch 0 loss 0.5380159616470337
epoch 4 batch 100 loss 0.5265653729438782
epoch 4 batch 200 loss 0.5075460076332092
epoch 5 batch 0 loss 0.5264567732810974
epoch 5 batch 100 loss 0.5074487924575806
epoch 5 batch 200 loss 0.5149799585342407
epoch 6 batch 0 loss 0.5225409865379333
epoch 6 batch 100 loss 0.48458412289619446
epoch 6 batch 200 loss 0.5262649655342102
epoch 7 batch 0 loss 0.530045747756958
epoch 7 batch 100 loss 0.5262291431427002
epoch 7 batch 200 loss 0.5034438371658325
epoch 8 batc

epoch 65 batch 0 loss 0.5259976387023926
epoch 65 batch 100 loss 0.5335954427719116
epoch 65 batch 200 loss 0.5070034265518188
epoch 66 batch 0 loss 0.5411931276321411
epoch 66 batch 100 loss 0.5108022689819336
epoch 66 batch 200 loss 0.49940571188926697
epoch 67 batch 0 loss 0.5259976983070374
epoch 67 batch 100 loss 0.5297964811325073
epoch 67 batch 200 loss 0.49940571188926697
epoch 68 batch 0 loss 0.5411930680274963
epoch 68 batch 100 loss 0.5411930680274963
epoch 68 batch 200 loss 0.48800909519195557
epoch 69 batch 0 loss 0.5032045245170593
epoch 69 batch 100 loss 0.5297965407371521
epoch 69 batch 200 loss 0.49560683965682983
epoch 70 batch 0 loss 0.5563884973526001
epoch 70 batch 100 loss 0.5183999538421631
epoch 70 batch 200 loss 0.49560680985450745
epoch 71 batch 0 loss 0.5411930680274963
epoch 71 batch 100 loss 0.4918079376220703
epoch 71 batch 200 loss 0.5146010518074036
epoch 72 batch 0 loss 0.5297964811325073
epoch 72 batch 100 loss 0.49560678005218506
epoch 72 batch 200 lo