In [1]:
import tensorflow as tf
import numpy as np

In [2]:
class DataLoader():
    def __init__(self):
        path = tf.keras.utils.get_file('nietzsche.txt',
            origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt')
        with open(path, encoding='utf-8') as f:
            self.raw_text = f.read().lower()
        self.chars = sorted(list(set(self.raw_text)))
        self.char_indices = dict((c, i) for i, c in enumerate(self.chars))
        self.indices_char = dict((i, c) for i, c in enumerate(self.chars))
        self.text = [self.char_indices[c] for c in self.raw_text]

    def get_batch(self, seq_length, batch_size):
        seq = []
        next_char = []
        for i in range(batch_size):
            index = np.random.randint(0, len(self.text) - seq_length)
            seq.append(self.text[index:index+seq_length])
            next_char.append(self.text[index+seq_length])
        return np.array(seq), np.array(next_char)       # [batch_size, seq_length], [num_batch]

In [3]:
class RNN(tf.keras.Model):
    def __init__(self, num_chars, batch_size, seq_length):
        super().__init__()
        self.num_chars = num_chars
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.cell = tf.keras.layers.LSTMCell(units=256)
        self.dense = tf.keras.layers.Dense(units=self.num_chars)

    def call(self, inputs, from_logits=False):
        inputs = tf.one_hot(inputs, depth=self.num_chars)       # [batch_size, seq_length, num_chars]
        state = self.cell.get_initial_state(batch_size=self.batch_size, dtype=tf.float32)
        for t in range(self.seq_length):
            output, state = self.cell(inputs[:, t, :], state)
        logits = self.dense(output)
        if from_logits:
            return logits
        else:
            return tf.nn.softmax(logits)

In [4]:
num_batches = 1000
seq_length = 40
batch_size = 50
learning_rate = 1e-3

In [6]:
data_loader = DataLoader()
model = RNN(num_chars=len(data_loader.chars), batch_size=batch_size, seq_length=seq_length)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for batch_index in range(num_batches):
    X, y = data_loader.get_batch(seq_length, batch_size)
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
        print("batch %d: loss %f" % (batch_index, loss.numpy()))
    grads = tape.gradient(loss, model.variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))

Downloading data from https://s3.amazonaws.com/text-datasets/nietzsche.txt
batch 0: loss 4.048768
batch 1: loss 4.023171
batch 2: loss 4.012205
batch 3: loss 4.004153
batch 4: loss 3.946578
batch 5: loss 3.912091
batch 6: loss 3.813337
batch 7: loss 3.487267
batch 8: loss 3.563107
batch 9: loss 3.378630
batch 10: loss 2.878976
batch 11: loss 3.183325
batch 12: loss 3.137625
batch 13: loss 3.182850
batch 14: loss 3.131811
batch 15: loss 2.900273
batch 16: loss 3.196112
batch 17: loss 2.986158
batch 18: loss 3.174043
batch 19: loss 3.186116
batch 20: loss 3.280022
batch 21: loss 3.321677
batch 22: loss 3.040190
batch 23: loss 3.229928
batch 24: loss 3.153940
batch 25: loss 3.148842
batch 26: loss 2.989758
batch 27: loss 3.063774
batch 28: loss 3.104701
batch 29: loss 2.952697
batch 30: loss 3.335694
batch 31: loss 2.977010
batch 32: loss 3.470153
batch 33: loss 3.009651
batch 34: loss 2.953448
batch 35: loss 3.143293
batch 36: loss 3.198764
batch 37: loss 3.165862
batch 38: loss 3.075455

batch 327: loss 2.980579
batch 328: loss 2.745222
batch 329: loss 2.703315
batch 330: loss 2.593916
batch 331: loss 2.871174
batch 332: loss 2.952435
batch 333: loss 2.853171
batch 334: loss 2.893400
batch 335: loss 2.775028
batch 336: loss 2.927334
batch 337: loss 2.988840
batch 338: loss 2.457830
batch 339: loss 2.850562
batch 340: loss 3.131493
batch 341: loss 2.867518
batch 342: loss 2.864533
batch 343: loss 2.772941
batch 344: loss 2.636870
batch 345: loss 2.775579
batch 346: loss 2.587173
batch 347: loss 2.700779
batch 348: loss 2.534466
batch 349: loss 2.798055
batch 350: loss 2.980839
batch 351: loss 2.883104
batch 352: loss 2.840234
batch 353: loss 2.933707
batch 354: loss 2.828535
batch 355: loss 2.918423
batch 356: loss 2.680420
batch 357: loss 2.728408
batch 358: loss 2.498987
batch 359: loss 2.778182
batch 360: loss 2.606412
batch 361: loss 2.527189
batch 362: loss 2.816576
batch 363: loss 2.863319
batch 364: loss 2.901592
batch 365: loss 2.632804
batch 366: loss 2.683136


batch 655: loss 2.467494
batch 656: loss 2.463978
batch 657: loss 2.566087
batch 658: loss 2.308928
batch 659: loss 2.590937
batch 660: loss 2.520720
batch 661: loss 2.198870
batch 662: loss 2.411905
batch 663: loss 2.260876
batch 664: loss 2.417760
batch 665: loss 2.405300
batch 666: loss 2.644347
batch 667: loss 2.443368
batch 668: loss 2.646087
batch 669: loss 2.607146
batch 670: loss 2.387889
batch 671: loss 2.239663
batch 672: loss 2.235370
batch 673: loss 2.661579
batch 674: loss 2.568810
batch 675: loss 2.325109
batch 676: loss 2.505965
batch 677: loss 2.190019
batch 678: loss 2.594875
batch 679: loss 2.250072
batch 680: loss 2.224346
batch 681: loss 2.310043
batch 682: loss 2.597472
batch 683: loss 2.689369
batch 684: loss 2.672514
batch 685: loss 2.470024
batch 686: loss 2.283690
batch 687: loss 2.676895
batch 688: loss 2.535475
batch 689: loss 2.466646
batch 690: loss 2.746387
batch 691: loss 2.629855
batch 692: loss 2.671843
batch 693: loss 2.546695
batch 694: loss 2.803538


batch 983: loss 2.503717
batch 984: loss 2.011304
batch 985: loss 2.443228
batch 986: loss 2.431129
batch 987: loss 2.397177
batch 988: loss 2.500897
batch 989: loss 2.345928
batch 990: loss 2.151245
batch 991: loss 2.568460
batch 992: loss 2.440712
batch 993: loss 2.529739
batch 994: loss 2.174102
batch 995: loss 2.903484
batch 996: loss 2.484757
batch 997: loss 2.448186
batch 998: loss 2.638718
batch 999: loss 2.212423


In [7]:
def predict(self, inputs, temperature=1.):
    batch_size, _ = tf.shape(inputs)
    logits = self(inputs, from_logits=True)
    prob = tf.nn.softmax(logits / temperature).numpy()
    return np.array([np.random.choice(self.num_chars, p=prob[i, :])
                     for i in range(batch_size.numpy())])

In [8]:
X_, _ = data_loader.get_batch(seq_length, 1)
for diversity in [0.2, 0.5, 1.0, 1.2]:
    X = X_
    print("diversity %f:" % diversity)
    for t in range(400):
        y_pred = model.predict(X, diversity)
        print(data_loader.indices_char[y_pred[0]], end='', flush=True)
        X = np.concatenate([X[:, 1:], np.expand_dims(y_pred, axis=1)], axis=-1)
    print("\n")

diversity 0.200000:


TypeError: slice indices must be integers or None or have an __index__ method