In [1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras
print('tf    version: {}'.format(tf.__version__) )
print('keras version: {}'.format(keras.__version__) )
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

tf    version: 2.0.0
keras version: 2.2.4-tf
matplotlib 3.0.3
numpy 1.16.4
pandas 0.24.2
sklearn 0.21.2
tensorflow 2.0.0
tensorflow_core.keras 2.2.4-tf


In [2]:
input_filepath = './data/shakespeare.txt'

text = open(input_filepath, 'r').read()

In [3]:
print(len(text))
print(text[0:100])

1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


* generate vocab
* build mapping char->id
* data -> id_data
* abcd -> bcd<eos> 对应下一个字符的模型



In [4]:
vocab = sorted(set(text))
print(len(vocab))
print(vocab)

65
['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [5]:
char2idx = {char:idx for idx, char in enumerate(vocab)}
print(char2idx)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}


In [6]:
idx2char = np.array(vocab)
print(idx2char)

['\n' ' ' '!' '$' '&' "'" ',' '-' '.' '3' ':' ';' '?' 'A' 'B' 'C' 'D' 'E'
 'F' 'G' 'H' 'I' 'J' 'K' 'L' 'M' 'N' 'O' 'P' 'Q' 'R' 'S' 'T' 'U' 'V' 'W'
 'X' 'Y' 'Z' 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j' 'k' 'l' 'm' 'n' 'o'
 'p' 'q' 'r' 's' 't' 'u' 'v' 'w' 'x' 'y' 'z']


In [7]:
%%time
text_as_int = np.array([char2idx[c] for c in text])


CPU times: user 142 ms, sys: 16.4 ms, total: 158 ms
Wall time: 157 ms


In [8]:
print(len(text_as_int))
print(text_as_int[:10])
print(text[:10])

1115394
[18 47 56 57 58  1 15 47 58 47]
First Citi


In [9]:
def split_input_target(id_text):
    """
    abcde -> abcd, bcde, 输入和输出
    """
    return id_text[0:-1], id_text[1:]

char_dataset = tf.data.Dataset.from_tensor_slices(text_as_int)
seq_length = 100

seq_dataset = char_dataset.batch(seq_length +1 , drop_remainder=True)



for ch_id in char_dataset.take(2):
    print(ch_id, idx2char[ch_id.numpy()])
    
for seq_id in seq_dataset.take(2):
    print(seq_id)
    print(repr(''.join(idx2char[seq_id.numpy()])))


tf.Tensor(18, shape=(), dtype=int64) F
tf.Tensor(47, shape=(), dtype=int64) i
tf.Tensor(
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1], shape=(101,), dtype=int64)
'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
tf.Tensor(
[39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1
 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0
 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8
  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1
 63 53 59  1 49], shape=(101,), dtype=int64)
'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k'


In [10]:
seq_dataset = seq_dataset.map(split_input_target)

for item_input, item_output in seq_dataset.take(2):
    print(item_input.numpy())
    print(item_output.numpy())



[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59]
[47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43  1
 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43 39
 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49  6
  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0
 37 53 59  1]
[39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1
 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47 57 46 12  0  0
 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53 50 60 43 42  8
  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47 56 57 58  6  1
 63 53 59  1]
[56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39 58 46 43 56  1 58
 53  1 42

In [11]:
batch_size = 64
buffer_size = 10000

seq_dataset = seq_dataset.shuffle(buffer_size).batch(
    batch_size, drop_remainder=True)


In [12]:
vocab_size = len(vocab)
embedding_dim = 256
rnn_units = 1024

def build_model(vocab_size, embedding_dim, rnn_units, batch_size):
    model = keras.models.Sequential([
        keras.layers.Embedding(vocab_size, embedding_dim, 
                              batch_input_shape=[batch_size,None]), 
        keras.layers.SimpleRNN(units=rnn_units, 
                              return_sequences=True), 
        keras.layers.Dense(vocab_size)
    ])
    return model
    
model = build_model(
    vocab_size = vocab_size,
    embedding_dim = embedding_dim,
    rnn_units = rnn_units,
    batch_size= batch_size
)

model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
embedding (Embedding)        (64, None, 256)           16640     
_________________________________________________________________
simple_rnn (SimpleRNN)       (64, None, 1024)          1311744   
_________________________________________________________________
dense (Dense)                (64, None, 65)            66625     
Total params: 1,395,009
Trainable params: 1,395,009
Non-trainable params: 0
_________________________________________________________________


In [13]:
for input_example_batch, target_example_batch in seq_dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape)

(64, 100, 65)


In [14]:
# random sampling 随机策略
# greedy 贪心策略

sample_indices = tf.random.categorical(logits=example_batch_predictions[0], 
                     num_samples=1)
# print(sample_indices)
sample_indices = tf.squeeze(sample_indices, axis=1)
# print(sample_indices.numpy())
print("Input: ", repr(''.join(idx2char[input_example_batch[0]])) )
print('')
print("Output: ", repr(''.join(idx2char[target_example_batch[0]])) )
print('')
print("Predictions: ", repr(''.join(idx2char[sample_indices])) )

Input:  "tes: i' the people's name,\nI say it shall be so.\n\nCitizens:\nIt shall be so, it shall be so; let him "

Output:  "es: i' the people's name,\nI say it shall be so.\n\nCitizens:\nIt shall be so, it shall be so; let him a"

Predictions:  "n&p&iUq3\ni.IX.NyKFazYl.eMNqkbwPFcZBJr!kLA:gHQcxwA;tfc;ead Jmz\nh'YsDTRXxFNbhO$nKgnTQeI-Wx$YiA&,,Id'S."


In [15]:
def loss(labels, logits):
    return keras.losses.sparse_categorical_crossentropy(
        labels, logits, from_logits=True)
 
model.compile(optimizer= 'adam', loss=loss)
example_loss = loss(target_example_batch, example_batch_predictions)

print(example_loss.shape)
print(example_loss.numpy().mean())

(64, 100)
4.1851006


In [16]:
output_dir = './data/text_generation'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
checkpoint_prefix = os.path.join(output_dir, 'ckpt_{epoch}')

checkpoint_callback = keras.callbacks.ModelCheckpoint(
    filepath = checkpoint_prefix,
    save_weights_only = True)

epochs = 50
history = model.fit(seq_dataset, epochs = epochs,
                    callbacks = [checkpoint_callback])

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [17]:
tf.train.latest_checkpoint(output_dir)

'./data/text_generation/ckpt_50'

In [19]:
model2 = build_model(vocab_size, embedding_dim, 
                    rnn_units, 
                    batch_size=1)

model2.load_weights(tf.train.latest_checkpoint(output_dir))
model2.build(tf.TensorShape([1, None]))

# start ch sequence A, 
# A -> model -> b
# A.append(b) -> B
# B -> model -> c
# B.append(c) -> C
# C -> model -> d

def generate_text(model, start_string, num_generate=1000):
    input_eval = [char2idx[ch] for ch in start_string]
    input_eval = tf.expand_dims(input_eval, 0)
    
    text_generated = []
    model.reset_states()
    
    for _ in range(num_generate):
        #1. model inference -> predictions
        #2. sample -> ch -> text_generated.
        #3. update input_eval
        predictions = model(input_eval)
        
        predictions = tf.squeeze(predictions, 0)
        predicted_id = tf.random.categorical(
        predictions, num_samples= 1)[-1, 0].numpy()
        
        text_generated.append(idx2char[predicted_id])
        input_eval = tf.expand_dims([predicted_id], 0)
        
    return start_string + ''.join(text_generated)
        
new_text = generate_text(model2, "All: ")
print(new_text)

All: ad. the.
NGE:
Tir schitea t heng thid w u amanobe t, nonommar the h thathormealake, NThowallyind lis fite onteal thirord

So h w t lfone ceer d, ortho ncak, blor-m s ary y IOfed 'd, harome, o othed, thino Mowomm rrthetheer. wnits amiorisshecoo otoulas m:
A Camet, ot site sceafomoso whesis anoon tonakitee, s titertif t
Anghomees ughisive meathon ain talle sat t, tisu, d l
I a vese u tcave


mieno t wans; f ppoom nchorsp berajul pe are nould t g aies se m f fr?
LAne yo fud s.
Whef oushalig s iese l nd CERICin uthan N:
CLAMe sitor,
INomy Whaice, th Qurrneeeve vithes p'thm the surs me

sh t, ong IEOf she t foforsim rrsagld tidy I ng, IOfandore t uk:

S:
Le thy IRO:

Hmit, thigherremofrelatheno an
WA:
PUD helath thanoprer-vinareedithy tschis
N:

So, th! fon thero he a oos ur?
fts, IE INI
T:

Couloucity maio IZo postur, thou IZAn,


APinofors n d mathomimay turwe; alacis! thres thed t that the, d toupif d. hingh t nor-t

Ourealumy CK:
IIUS: ly ngug RAnod towan d sistowif, the,

s bllova