# gru 生成模型 

参照官方文档    
https://tensorflow.google.cn/tutorials/text/text_generation

## 1数据处理

In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf

import numpy as np
import os
import time

In [2]:
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
# 读取并为 py2 compat 解码
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
text = text[:111539]
# 文本长度是指文本中的字符个数
print ('Length of text: {} characters'.format(len(text)))
text[:30]

Length of text: 111539 characters


'First Citizen:\nBefore we proce'

In [3]:
# 文本中的非重复字符
vocab = sorted(set(text))
print ('{} unique characters'.format(len(vocab)))


61 unique characters


In [7]:
# 创建从非重复字符到索引的映射
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

text_as_int = np.array([char2idx[c] for c in text])
text_as_int.shape

(111539,)

In [8]:
text_as_int_target=text_as_int[1:]
print(text_as_int[:100])
print(text_as_int_target[:100])


[16 43 52 53 54  1 13 43 54 43 60 39 48  8  0 12 39 40 49 52 39  1 57 39
  1 50 52 49 37 39 39 38  1 35 48 59  1 40 55 52 54 42 39 52  5  1 42 39
 35 52  1 47 39  1 53 50 39 35 45  7  0  0 11 46 46  8  0 29 50 39 35 45
  5  1 53 50 39 35 45  7  0  0 16 43 52 53 54  1 13 43 54 43 60 39 48  8
  0 34 49 55]
[43 52 53 54  1 13 43 54 43 60 39 48  8  0 12 39 40 49 52 39  1 57 39  1
 50 52 49 37 39 39 38  1 35 48 59  1 40 55 52 54 42 39 52  5  1 42 39 35
 52  1 47 39  1 53 50 39 35 45  7  0  0 11 46 46  8  0 29 50 39 35 45  5
  1 53 50 39 35 45  7  0  0 16 43 52 53 54  1 13 43 54 43 60 39 48  8  0
 34 49 55  1]


In [9]:
# 设定每个输入句子长度的最大值
seq_length = 100

examples_per_epoch = len(text)//(seq_length)
print("examples_per_epoch",examples_per_epoch)
# 从原文的 长串整数编码切出 111500个 字符编码
train_x=text_as_int[:examples_per_epoch*100]
train_y=text_as_int_target[:examples_per_epoch*100]
# reshape成1115行  每行是一个seq
train_x=train_x.reshape([examples_per_epoch,-1])
train_y=train_y.reshape([examples_per_epoch,-1])

print(train_x.shape)
print(train_y.shape)

examples_per_epoch 1115
(1115, 100)
(1115, 100)


In [11]:
# 批大小
BATCH_SIZE = 64

batchs_per_epoch = examples_per_epoch//(BATCH_SIZE)
print("batchs_per_epoch",batchs_per_epoch)

train_x = train_x[:17*64]
train_y = train_y[:17*64]

print(train_x.shape)
print(train_y.shape)

batchs_per_epoch 17
(1088, 100)
(1088, 100)


In [12]:
print(''.join([idx2char[i] for i in train_x[2]]))
print(''.join([idx2char[i] for i in train_y[2]]))

 know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us
know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us 


## 2模型构建

In [13]:


# 词集的长度
vocab_size = len(vocab)

# 嵌入的维度
embedding_dim = 256

# RNN 的单元数量
rnn_units = 1024


###  model 可以不指定第二位  就是时间步的那一维     

1 train_x 的shape是 1088\*100 (按batch feed的话是64\*100) 行数是句子数 列是指每行有100个char  这里一个char是一个0~vocab_size的编号 代表它在词表里是第几个   
2 模型第一层走一个embedding把 input的每个句子 的最后一维  就是长度是100的那一维 给多拉一维 拉成 100\*256 (256词向量的维度)  意思是给这100个char 每一char都给一个特定的表示性的 char vec   
3 模型第二层 单向gru 要是return_sequences=True的话 每个时间步都会返回  这样的话 这层的输出也是3维  是batchsize seqlen out_h   rnn_units = 1024  故 out_h共 1024个   
4 模型第三层  一个全连接层 而且没激活函数  是让上一层输出的最后一维映射到 61就是词表维度   它的算loss往外放了 没放这

In [14]:
def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = tf.keras.Sequential([
    tf.keras.layers.Embedding(vocab_size, embedding_dim,
                              batch_input_shape=[batch_size, None]),
    tf.keras.layers.GRU(rnn_units,
                        return_sequences=True,
                        stateful=True,
                        recurrent_initializer='glorot_uniform'),
    tf.keras.layers.Dense(vocab_size)
  ])
    return model


In [15]:
model = build_model(
  vocab_size = len(vocab),
  embedding_dim=embedding_dim,
  rnn_units=rnn_units,
  batch_size=BATCH_SIZE)


In [16]:
example_batch_predictions = model(train_x[:64])
print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")


(64, 100, 61) # (batch_size, sequence_length, vocab_size)


In [17]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (64, None, 256)           15616     
_________________________________________________________________
gru (GRU)                    (64, None, 1024)          3938304   
_________________________________________________________________
dense (Dense)                (64, None, 61)            62525     
Total params: 4,016,445
Trainable params: 4,016,445
Non-trainable params: 0
_________________________________________________________________


In [18]:
def loss(labels, logits):
    return tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

example_batch_loss  = loss(train_x[:64], example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("scalar_loss:      ", example_batch_loss.numpy().mean())


Prediction shape:  (64, 100, 61)  # (batch_size, sequence_length, vocab_size)
scalar_loss:       4.111576


In [19]:
model.compile(optimizer='adam', loss=loss)



In [20]:
# 检查点保存至的目录
checkpoint_dir = './training_checkpoints/gru_ssby_zg/'

# 检查点的文件名
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

checkpoint_callback=tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_prefix,
    save_weights_only=True)



In [50]:
history = model.fit(train_x,train_y,batch_size=BATCH_SIZE, epochs=10, callbacks=[checkpoint_callback])


Train on 1088 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


## 3文本生成

In [17]:
tf.train.latest_checkpoint(checkpoint_dir)


'./training_checkpoints/gru_ssby_zg/ckpt_10'

In [18]:
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

model.build(tf.TensorShape([1, None]))

model.summary()


Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding_1 (Embedding)      (1, None, 256)            15616     
_________________________________________________________________
gru_1 (GRU)                  (1, None, 1024)           3938304   
_________________________________________________________________
dense_1 (Dense)              (1, None, 61)             62525     
Total params: 4,016,445
Trainable params: 4,016,445
Non-trainable params: 0
_________________________________________________________________


In [19]:
idx2char

array(['\n', ' ', '!', '&', "'", ',', '-', '.', ':', ';', '?', 'A', 'B',
       'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O',
       'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'a', 'b', 'c', 'd',
       'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q',
       'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z'], dtype='<U1')

In [20]:
test_char = 'w'
test_char_index = char2idx[test_char]
test_char_index

57

In [21]:
test_char_predict = model(np.array([test_char_index]).reshape([1,1]))
test_char_predict

<tf.Tensor: id=4700, shape=(1, 1, 61), dtype=float32, numpy=
array([[[ 1.7514497 ,  2.2750309 ,  0.702665  , -3.3979805 ,
          1.1239854 ,  1.777669  ,  0.05389442,  1.3949445 ,
          0.71221805,  0.49949953,  1.0053072 , -0.45564556,
         -1.9303716 , -0.64923096, -2.3778992 , -1.7258878 ,
         -2.2563925 , -2.5792494 , -1.7512918 ,  0.55804664,
         -3.4456086 , -2.87575   , -4.5027575 , -2.5136354 ,
         -1.5305692 , -1.0664184 , -2.692316  , -3.1265128 ,
         -4.1851764 , -0.43765306, -1.0113206 , -0.31129268,
         -1.9276886 , -1.722134  , -2.4528635 ,  2.6098864 ,
         -2.252105  , -0.2655328 , -2.025079  ,  2.963385  ,
         -1.0747224 , -1.5135165 ,  4.288119  ,  2.816905  ,
         -2.5537174 , -4.4427905 , -0.33213013, -1.0699183 ,
          2.525613  ,  1.8245411 , -2.01107   , -1.8195901 ,
          0.78075075,  1.0319731 , -0.5754767 , -1.8902191 ,
         -2.103181  , -0.4223451 , -0.72295344, -0.21589315,
         -3.6508598 ]]],

In [22]:
pre_char_index = tf.argmax(test_char_predict.numpy().flatten())
print(pre_char_index,idx2char[pre_char_index])

tf.Tensor(42, shape=(), dtype=int64) h


In [23]:
def generate_text(model, start_string):
  # 评估步骤（用学习过的模型生成文本）

  # 要生成的字符个数
    num_generate = 1000

  # 将起始字符串转换为数字（向量化）
    input_eval = [char2idx[s] for s in start_string]
#     print(input_eval)
    input_eval = tf.expand_dims(input_eval, 0)
    print("input_eval",input_eval)

  # 空字符串用于存储结果
    text_generated = []

  # 低温度会生成更可预测的文本
  # 较高温度会生成更令人惊讶的文本
  # 可以通过试验以找到最好的设定
    temperature = 1.0

  # 这里批大小为 1
    model.reset_states()
    for i in range(num_generate):
        predictions = model(input_eval)
#         print(predictions.shape)
      # 删除批次的维度
        predictions = tf.squeeze(predictions, 0)
#         print(predictions.shape)
        
      # 用分类分布预测模型返回的字符
        predictions = predictions / temperature
        predicted_id=tf.argmax(predictions.numpy().flatten()).numpy()
        print("max predicted_id",predicted_id)
        predicted_id = tf.random.categorical(predictions, num_samples=1)[-1,0].numpy()
        print("random predicted_id",predicted_id)

      # 把预测字符和前面的隐藏状态一起传递给模型作为下一个输入
        input_eval = tf.expand_dims([predicted_id], 0)
        print(input_eval.shape)

        text_generated.append(idx2char[predicted_id])

    return (start_string + ''.join(text_generated))


In [24]:
print(generate_text(model, start_string=u"w"))


input_eval tf.Tensor([[57]], shape=(1, 1), dtype=int32)
max predicted_id 42
random predicted_id 2
(1, 1)
max predicted_id 0
random predicted_id 0
(1, 1)
max predicted_id 0
random predicted_id 0
(1, 1)
max predicted_id 23
random predicted_id 13
(1, 1)
max predicted_id 25
random predicted_id 25
(1, 1)
max predicted_id 28
random predicted_id 28
(1, 1)
max predicted_id 19
random predicted_id 19
(1, 1)
max predicted_id 25
random predicted_id 25
(1, 1)
max predicted_id 22
random predicted_id 22
(1, 1)
max predicted_id 11
random predicted_id 11
(1, 1)
max predicted_id 24
random predicted_id 24
(1, 1)
max predicted_id 31
random predicted_id 31
(1, 1)
max predicted_id 29
random predicted_id 29
(1, 1)
max predicted_id 8
random predicted_id 8
(1, 1)
max predicted_id 0
random predicted_id 0
(1, 1)
max predicted_id 33
random predicted_id 53
(1, 1)
max predicted_id 42
random predicted_id 49
(1, 1)
max predicted_id 5
random predicted_id 5
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max pre

max predicted_id 8
random predicted_id 8
(1, 1)
max predicted_id 0
random predicted_id 0
(1, 1)
max predicted_id 33
random predicted_id 24
(1, 1)
max predicted_id 49
random predicted_id 49
(1, 1)
max predicted_id 5
random predicted_id 57
(1, 1)
max predicted_id 1
random predicted_id 5
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 54
random predicted_id 11
(1, 1)
max predicted_id 48
random predicted_id 40
(1, 1)
max predicted_id 1
random predicted_id 40
(1, 1)
max predicted_id 43
random predicted_id 43
(1, 1)
max predicted_id 37
random predicted_id 37
(1, 1)
max predicted_id 39
random predicted_id 43
(1, 1)
max predicted_id 54
random predicted_id 52
(1, 1)
max predicted_id 39
random predicted_id 43
(1, 1)
max predicted_id 55
random predicted_id 52
(1, 1)
max predicted_id 39
random predicted_id 39
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 35
random predicted_id 35
(1, 1)
max predicted_id 48
random predicted_id 53
(1, 1)
max predicte

max predicted_id 48
random predicted_id 48
(1, 1)
max predicted_id 38
random predicted_id 53
(1, 1)
max predicted_id 1
random predicted_id 5
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 54
random predicted_id 46
(1, 1)
max predicted_id 39
random predicted_id 39
(1, 1)
max predicted_id 35
random predicted_id 51
(1, 1)
max predicted_id 55
random predicted_id 55
(1, 1)
max predicted_id 39
random predicted_id 43
(1, 1)
max predicted_id 48
random predicted_id 52
(1, 1)
max predicted_id 39
random predicted_id 1
(1, 1)
max predicted_id 54
random predicted_id 54
(1, 1)
max predicted_id 42
random predicted_id 42
(1, 1)
max predicted_id 39
random predicted_id 39
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 50
random predicted_id 50
(1, 1)
max predicted_id 39
random predicted_id 39
(1, 1)
max predicted_id 49
random predicted_id 40
(1, 1)
max predicted_id 50
random predicted_id 1
(1, 1)
max predicted_id 54
random predicted_id 59
(1, 1)
max pred

random predicted_id 43
(1, 1)
max predicted_id 53
random predicted_id 47
(1, 1)
max predicted_id 1
random predicted_id 39
(1, 1)
max predicted_id 1
random predicted_id 48
(1, 1)
max predicted_id 54
random predicted_id 49
(1, 1)
max predicted_id 55
random predicted_id 55
(1, 1)
max predicted_id 52
random predicted_id 53
(1, 1)
max predicted_id 1
random predicted_id 2
(1, 1)
max predicted_id 0
random predicted_id 0
(1, 1)
max predicted_id 0
random predicted_id 0
(1, 1)
max predicted_id 13
random predicted_id 13
(1, 1)
max predicted_id 25
random predicted_id 25
(1, 1)
max predicted_id 28
random predicted_id 28
(1, 1)
max predicted_id 19
random predicted_id 19
(1, 1)
max predicted_id 25
random predicted_id 25
(1, 1)
max predicted_id 22
random predicted_id 22
(1, 1)
max predicted_id 11
random predicted_id 11
(1, 1)
max predicted_id 24
random predicted_id 24
(1, 1)
max predicted_id 31
random predicted_id 31
(1, 1)
max predicted_id 29
random predicted_id 29
(1, 1)
max predicted_id 8
random pr

max predicted_id 39
random predicted_id 35
(1, 1)
max predicted_id 38
random predicted_id 53
(1, 1)
max predicted_id 54
random predicted_id 54
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 49
random predicted_id 42
(1, 1)
max predicted_id 39
random predicted_id 43
(1, 1)
max predicted_id 47
random predicted_id 53
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 37
random predicted_id 53
(1, 1)
max predicted_id 49
random predicted_id 42
(1, 1)
max predicted_id 35
random predicted_id 49
(1, 1)
max predicted_id 55
random predicted_id 48
(1, 1)
max predicted_id 38
random predicted_id 38
(1, 1)
max predicted_id 1
random predicted_id 5
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 35
random predicted_id 35
(1, 1)
max predicted_id 48
random predicted_id 48
(1, 1)
max predicted_id 38
random predicted_id 38
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 54
random predicted_id 52
(1, 1)
max predic

max predicted_id 41
random predicted_id 39
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 54
random predicted_id 54
(1, 1)
max predicted_id 42
random predicted_id 55
(1, 1)
max predicted_id 46
random predicted_id 39
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 42
random predicted_id 38
(1, 1)
max predicted_id 49
random predicted_id 43
(1, 1)
max predicted_id 53
random predicted_id 41
(1, 1)
max predicted_id 42
random predicted_id 42
(1, 1)
max predicted_id 54
random predicted_id 54
(1, 1)
max predicted_id 1
random predicted_id 59
(1, 1)
max predicted_id 1
random predicted_id 1
(1, 1)
max predicted_id 35
random predicted_id 56
(1, 1)
max predicted_id 39
random predicted_id 49
(1, 1)
max predicted_id 43
random predicted_id 49
(1, 1)
max predicted_id 38
random predicted_id 37
(1, 1)
max predicted_id 39
random predicted_id 42
(1, 1)
max predicted_id 5
random predicted_id 0
(1, 1)
max predicted_id 33
random predicted_id 12
(1, 1)
max predi