In [1]:
import itertools
import os

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

# Enable memory growth on every GPU that TensorFlow can see, so one process
# does not have to hold the whole device memory for the entire session.
physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(physical_device, enable=True)

In [2]:
# Number of articles stacked into one padded batch (first axis of every tensor fed to the model).
BATCH_SIZE = 4096
# Number of input tokens per training chunk; chunks are cut one token wider so
# that inputs and shifted labels can both be taken from the same slice.
BATCHED_ITEM_LENGTH = 12
# NumPy dtype used for the token-id arrays handed to tf.data.
TYPE=np.int16

# Pre-built subword vocabulary loaded from the 'vocab_4096' file prefix.
# NOTE(review): presumably 4096 subwords, which fits comfortably in int16 — confirm.
subword_text_encoder = tfds.features.text.SubwordTextEncoder.load_from_file('vocab_4096')

class Articles:
    """Encoded articles, ordered by length, served as (inputs, labels) datasets.

    The articles are read once, encoded with ``subword_text_encoder`` and kept
    sorted from shortest to longest so that articles batched together have
    similar lengths.
    """

    def __init__(self, path):
        """Load the NUL-separated articles stored in the file at ``path``."""
        with open(path, 'rb') as text_file:
            raw_articles = text_file.read().split(b'\0')

        encoded_articles = [subword_text_encoder.encode(raw_article) for raw_article in raw_articles]
        encoded_articles.sort(key=len)
        self.articles = encoded_articles

    def articles_generator(self, start, end):
        """Yield the articles in ``[start, end)`` as ``TYPE`` arrays.

        Empty arrays are emitted first so that the total number of yielded
        items is a multiple of ``BATCH_SIZE``.
        """
        # Same value as BATCH_SIZE - ((end - start - 1) % BATCH_SIZE + 1).
        leading_padding = -(end - start) % BATCH_SIZE
        yield from (np.array([], dtype=TYPE) for _ in range(leading_padding))
        yield from (np.array(tokens, dtype=TYPE) for tokens in itertools.islice(self.articles, start, end))

    def subbatch_generator(self, start, end):
        """Yield overlapping column windows of every padded article batch.

        Each window is at most ``BATCHED_ITEM_LENGTH + 1`` columns wide and
        consecutive windows share one column, so the last label of one window
        is the first input of the next.
        """
        padded_batches = (
            tf.data.Dataset
            .from_generator(self.articles_generator, args=(start, end), output_types=TYPE)
            .padded_batch(BATCH_SIZE, padded_shapes=([None]), drop_remainder=True)
            .shuffle(50)
        )

        for batch in padded_batches.as_numpy_iterator():
            width = batch.shape[1]
            # A window is produced only while more than one column is left.
            for offset in range(0, width - 1, BATCHED_ITEM_LENGTH):
                yield batch[:, offset:offset + BATCHED_ITEM_LENGTH + 1]

    def dataset(self, start, end):
        """Return a dataset of ``(inputs, labels)`` pairs for articles ``[start, end)``.

        ``labels`` is ``inputs`` shifted left by one token.
        """
        def split_inputs_and_labels(batch):
            return batch[:, :-1], batch[:, 1:]

        windows = tf.data.Dataset.from_generator(
            self.subbatch_generator,
            args=(start, end),
            output_types=TYPE,
            output_shapes=(BATCH_SIZE, None),
        )
        return windows.map(split_inputs_and_labels)

In [3]:
def build_model(vocab_size, embedding_dim, rnn_units):
    """Build the next-token model: embedding, three stateful GRU layers, dense logits.

    Args:
        vocab_size: number of token ids; used for both the embedding input
            size and the width of the output logits.
        embedding_dim: size of each token embedding.
        rnn_units: number of units in each GRU layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` accepting batches of
        ``BATCH_SIZE`` sequences of any length.

    NOTE(review): ``Masking`` is applied to 2-D token ids before ``Embedding``,
    and ``Embedding`` is created without ``mask_zero`` — padding (id 0) is
    probably not masked for the GRUs or the loss. Confirm against the Keras
    masking guide; changing it would also affect how the
    ``average_batch_length`` metric is averaged.
    """
    layers = tf.keras.layers

    padding_mask = layers.Masking(mask_value=0, batch_input_shape=[BATCH_SIZE, None])
    embedding = layers.Embedding(vocab_size, embedding_dim)
    recurrent_stack = [
        layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')
        for _ in range(3)
    ]
    logits = layers.Dense(vocab_size)

    return tf.keras.Sequential([padding_mask, embedding, *recurrent_stack, logits])

In [4]:
checkpoint_dir = './training_checkpoints-2' # Directory where the checkpoints will be saved
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}") # Name of the checkpoint files; {epoch} is filled in by the callback

# Saves only the weights (not the full model) under checkpoint_prefix during training.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix, save_weights_only=True)

class ModelStateResetter(tf.keras.callbacks.Callback):
    """Reset the model's recurrent state when a shorter-than-full chunk was trained.

    The callback has no direct access to the batch tensors, so it recovers
    the width of the batch that just finished from the
    ``average_batch_length`` entry of ``logs``. That entry is treated as the
    running mean of the per-batch label width, so ``mean * (batch + 1)`` is
    the total width seen so far, and the difference to the previous total is
    the width of the current batch.

    A width below ``BATCHED_ITEM_LENGTH`` is taken to mean that the chunk was
    the tail of an article batch, so the state is cleared before the next
    one starts.
    """

    def __init__(self):
        # Let Keras initialise its own callback attributes (model, params, ...).
        super().__init__()
        self.last_total_length = 0

    def on_epoch_begin(self, epoch, logs=None):
        """Start every epoch with a zeroed length counter and a clean state.

        ``batch`` restarts from 0 at each epoch, so a total carried over from
        the previous epoch would make the first batch look negative-length
        and trigger a reset in the middle of an article batch.
        """
        self.last_total_length = 0
        self.model.reset_states()

    def on_batch_end(self, batch, logs=None):
        """Reset the model state if the batch that just ended was a short one.

        Args:
            batch: zero-based index of the batch within the current epoch.
            logs: metric values reported for the batch; may be ``None``.
        """
        logs = logs or {}
        average_batch_length = logs.get('average_batch_length', 0)

        # Reconstruct the cumulative width from the running mean.
        total_length = int(round(average_batch_length * (batch + 1)))
        current_batch_length = total_length - self.last_total_length
        self.last_total_length = total_length

        if current_batch_length < BATCHED_ITEM_LENGTH:
            self.model.reset_states()
        
# Single callback instance passed to the model.fit(...) calls in the cells below.
model_state_resetter_callback = ModelStateResetter()

def loss(labels, logits):
    """Sparse categorical cross-entropy between integer labels and raw logits.

    Args:
        labels: integer token ids of the expected next tokens.
        logits: unnormalised scores produced by the model.

    Returns:
        The per-position loss tensor computed by Keras.
    """
    crossentropy = tf.keras.losses.sparse_categorical_crossentropy
    return crossentropy(y_true=labels, y_pred=logits, from_logits=True)

def average_batch_length(true_labels, predictions):
    """Metric that reports the size of the second axis of ``true_labels``.

    ``predictions`` is accepted only to satisfy the metric signature and is
    not used. ``ModelStateResetter`` reads this metric from the batch logs.
    """
    label_shape = tf.shape(true_labels)
    return label_shape[1]

# Build a fresh model sized to the encoder's vocabulary.
model = build_model(vocab_size = subword_text_encoder.vocab_size, embedding_dim=512, rnn_units=768)
# Uncomment to continue from the newest checkpoint in checkpoint_dir instead of starting from scratch.
# model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
# average_batch_length must stay in the metrics: ModelStateResetter reads it from the batch logs.
model.compile(optimizer='adam', loss=loss, metrics=[average_batch_length])

In [5]:
# Read and encode every article up front; the result is kept sorted by token count.
articles = Articles('page_revisions_text')

In [7]:
# Five epochs on the 200000 shortest articles (the list is sorted by length).
model.fit(articles.dataset(0, 200000), epochs=5, callbacks=[checkpoint_callback, model_state_resetter_callback])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x1e5f2c07188>

In [8]:
# Five more epochs on the same 200000 shortest articles, continuing from the current weights.
model.fit(articles.dataset(0, 200000), epochs=5, callbacks=[checkpoint_callback, model_state_resetter_callback])

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x1e5f2e5a448>

In [18]:
# Continue on the remaining, longer articles (index 200000 to the end) with the default number of epochs.
model.fit(articles.dataset(200000, len(articles.articles)), callbacks=[checkpoint_callback, model_state_resetter_callback])



<tensorflow.python.keras.callbacks.History at 0x1e5f2ecf6c8>

In [19]:
# One pass over the complete article set.
model.fit(articles.dataset(0, len(articles.articles)), callbacks=[checkpoint_callback, model_state_resetter_callback])



<tensorflow.python.keras.callbacks.History at 0x1e5f2f2d208>

In [20]:
# Another pass over the complete article set, continuing from the current weights.
model.fit(articles.dataset(0, len(articles.articles)), callbacks=[checkpoint_callback, model_state_resetter_callback])



<tensorflow.python.keras.callbacks.History at 0x1e5f2f5fbc8>

Видяхме колко време отнема обработката на batch-ове с динамичен размер. GPU-тата са известни с къси pipeline-ове за инструкции и лош branch prediction. Ще е интересно да видим дали batch-ове с фиксиран размер ще доведат до развиване (unroll-ване) на цикъла по размерността с дължина `BATCHED_ITEM_LENGTH` и, потенциално, до по-бързи итерации.

In [6]:
class Articles:
    """Encoded articles, ordered by length, served as fixed-width training chunks.

    Unlike a variable-width pipeline, every chunk produced here is exactly
    ``BATCHED_ITEM_LENGTH + 1`` columns wide; the tail of each article batch
    is right-padded with zeros to reach that width.
    """

    EMPTY_ARTICLE = np.array([], dtype=TYPE) # used for padding

    def __init__(self, path):
        """Read NUL-separated articles from ``path`` and encode each into a ``TYPE`` array."""
        with open(path, 'rb') as text_file:
            articles = [np.array(subword_text_encoder.encode(article), dtype=TYPE) for article in text_file.read().split(b'\0')]
            self.articles = sorted(articles, key=len)

    def articles_generator(self, start, end):
        """Yield the articles in ``[start, end)``, preceded by empty filler articles.

        The fillers make the number of yielded items a multiple of
        ``BATCH_SIZE``, so batching with ``drop_remainder=True`` loses nothing.
        """
        for _ in range(BATCH_SIZE - ((end - start - 1) % BATCH_SIZE + 1)):
            yield self.EMPTY_ARTICLE

        for article in itertools.islice(self.articles, start, end):
            yield article

    def subbatch_generator(self, start, end):
        """Yield ``(BATCH_SIZE, BATCHED_ITEM_LENGTH + 1)`` chunks of each padded batch.

        Consecutive chunks overlap by one column, so the last label of one
        chunk is the first input of the next. A final partial chunk is
        zero-padded on the right to the full width.
        """
        dataset = tf.data.Dataset.from_generator(self.articles_generator, args=(start, end), output_types=TYPE)
        dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=([None]), drop_remainder=True)
        dataset = dataset.shuffle(50)

        chunk_width = BATCHED_ITEM_LENGTH + 1

        for batch in dataset.as_numpy_iterator():
            remaining = batch
            while remaining.shape[1] >= chunk_width:
                yield remaining[:, :chunk_width]
                remaining = remaining[:, BATCHED_ITEM_LENGTH:]

            if remaining.shape[1] != 0:
                missing_columns = chunk_width - remaining.shape[1]
                # np.pad keeps the integer dtype; stacking with a default
                # np.zeros block would promote the whole chunk to float64.
                yield np.pad(remaining, ((0, 0), (0, missing_columns)), mode='constant')

    def dataset(self, start = None, end = None):
        """Return a dataset of ``(inputs, labels)`` pairs for articles ``[start, end)``.

        Args:
            start: index of the first article; defaults to 0.
            end: index one past the last article; defaults to all articles.

        Returns:
            A ``tf.data.Dataset`` whose elements are ``(inputs, labels)``,
            where ``labels`` is ``inputs`` shifted left by one token.
        """
        if start is None:
            start = 0

        if end is None:
            end = len(self.articles)

        dataset = tf.data.Dataset.from_generator(self.subbatch_generator, args=(start, end), output_types=TYPE, output_shapes=(BATCH_SIZE, BATCHED_ITEM_LENGTH + 1))
        return dataset.map(lambda batch: (batch[:, :-1], batch[:, 1:]))

In [7]:
# Re-load and re-encode the articles with the fixed-width Articles class defined in the previous cell.
articles = Articles('page_revisions_text')

In [8]:
def build_model(vocab_size, embedding_dim, rnn_units):
    """Build the next-token model for fixed-length input chunks.

    Args:
        vocab_size: number of token ids; used for both the embedding input
            size and the width of the output logits.
        embedding_dim: size of each token embedding.
        rnn_units: number of units in each GRU layer.

    Returns:
        An uncompiled ``tf.keras.Sequential`` whose input shape is fixed to
        ``(BATCH_SIZE, BATCHED_ITEM_LENGTH)``.

    NOTE(review): ``Masking`` is applied to 2-D token ids before ``Embedding``,
    and ``Embedding`` is created without ``mask_zero`` — padding (id 0) is
    probably not masked for the GRUs or the loss. Confirm against the Keras
    masking guide; changing it would also affect how the
    ``average_batch_length`` metric is averaged.
    """
    layers = tf.keras.layers

    padding_mask = layers.Masking(mask_value=0, batch_input_shape=[BATCH_SIZE, BATCHED_ITEM_LENGTH])
    embedding = layers.Embedding(vocab_size, embedding_dim)
    recurrent_stack = [
        layers.GRU(rnn_units, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform')
        for _ in range(3)
    ]
    logits = layers.Dense(vocab_size)

    return tf.keras.Sequential([padding_mask, embedding, *recurrent_stack, logits])

In [9]:
# Rebuild the model with the fixed-length input shape and continue from the newest saved weights.
# Relies on checkpoint_dir, loss and average_batch_length from the earlier training-setup cell.
model = build_model(vocab_size = subword_text_encoder.vocab_size, embedding_dim=512, rnn_units=768)
# NOTE(review): tf.train.latest_checkpoint gives None when checkpoint_dir holds no checkpoint — confirm one exists before running.
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.compile(optimizer='adam', loss=loss, metrics=[average_batch_length])

In [10]:
# One pass over all articles using fixed-width chunks.
# NOTE(review): every label tensor is now BATCHED_ITEM_LENGTH wide, so the
# `current_batch_length < BATCHED_ITEM_LENGTH` test in ModelStateResetter cannot
# fire for a short tail chunk any more — confirm the recurrent state is still reset between article batches.
model.fit(articles.dataset(), callbacks=[checkpoint_callback, model_state_resetter_callback])



<tensorflow.python.keras.callbacks.History at 0x1c008e798c8>

In [13]:
# Second pass over all articles with fixed-width chunks, to compare iteration speed.
model.fit(articles.dataset(), callbacks=[checkpoint_callback, model_state_resetter_callback])



<tensorflow.python.keras.callbacks.History at 0x1bf97e72148>

Никакъв ефект. Скоростта с batch-ове с фиксиран размер е абсолютно идентична на тази с динамичен размер.