In [78]:
import os
import random
import string

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.python.keras.layers import TextVectorization
import wandb
# NOTE: the Weights & Biases run is started once, in the training cell, together with
# the hyper-parameter config. `config=tf.flags.FLAGS` is not usable here: `tf.flags`
# is not part of the TensorFlow 2 public API (it lives in `tf.compat.v1.flags`), and
# calling `wandb.init()` twice in the same notebook is redundant.

In [79]:
# --- data / tokenization -------------------------------------------------
vocab_size = 20_000      # rows of the token-embedding table and width of the output logits
maxlen = 80              # token ids per training example (model input length)

# --- model / training ----------------------------------------------------
num_heads = 2            # attention heads in the Transformer block
embed_dim = 256          # width of the token and position embeddings
feed_forward_dim = 256   # hidden width of the Transformer's feed-forward sub-layer
batch_size = 256         # text lines per batch in create_dataset

In [80]:
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    """Build a causal attention mask, repeated once per batch element.

    Entry ``[b, i, j]`` is true (or 1) when destination position ``i`` may attend
    to source position ``j``, i.e. when ``j <= i + (n_src - n_dest)``. With
    ``n_src == n_dest`` this is the usual lower-triangular mask.

    Args:
        batch_size: number of copies along the leading axis (int32 scalar).
        n_dest: number of destination (query) positions.
        n_src: number of source (key) positions.
        dtype: dtype of the returned mask, e.g. ``tf.bool``.

    Returns:
        Tensor of shape ``[batch_size, n_dest, n_src]``.
    """
    rows = tf.range(n_dest)[:, tf.newaxis]
    cols = tf.range(n_src)
    offset = n_src - n_dest
    allowed = tf.cast(rows + offset >= cols, dtype)
    allowed = tf.reshape(allowed, [1, n_dest, n_src])

    batch_multiple = tf.expand_dims(batch_size, -1)
    tile_multiples = tf.concat([batch_multiple, tf.constant([1, 1], dtype=tf.int32)], 0)
    return tf.tile(allowed, tile_multiples)


class Transformer(layers.Layer):
    """One causal (decoder-style) Transformer block.

    Post-norm layout::

        h   = LayerNorm(x + Dropout(MultiHeadAttention(x, x, causal mask)))
        out = LayerNorm(h + Dropout(FeedForward(h)))

    The causal mask lets each position attend only to itself and earlier
    positions, as needed for next-token prediction.

    Args:
        embedding_dim: feature width of the inputs/outputs; also passed as the
            per-head ``key_dim`` of the attention layer.
        num_att_heads: number of attention heads.
        state_dims: hidden width of the two-layer feed-forward network.
        dropout_rate: dropout applied to the attention and feed-forward outputs.
        **kwargs: standard ``layers.Layer`` arguments (``name``, ``trainable``,
            ``dtype``, ...). They are forwarded to the base class so the dict
            returned by ``get_config()`` -- which contains those base keys -- can
            be passed back to ``__init__`` when Keras reloads a saved model.
    """

    def __init__(self, embedding_dim, num_att_heads, state_dims, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_att_heads = num_att_heads
        self.state_dims = state_dims
        self.dropout_rate = dropout_rate
        self.attention = layers.MultiHeadAttention(num_att_heads, embedding_dim)
        self.feed_forward = keras.Sequential([
            layers.Dense(state_dims, activation="relu"),
            layers.Dense(embedding_dim)
        ])
        self.norm1, self.norm2 = layers.LayerNormalization(epsilon=1e-6), layers.LayerNormalization(epsilon=1e-6)
        self.drop1, self.drop2 = layers.Dropout(dropout_rate), layers.Dropout(dropout_rate)

    def call(self, inputs, training=None):
        """Apply the block; ``inputs`` assumed ``(batch, seq_len, embedding_dim)``.

        ``training`` is forwarded explicitly so dropout is active only in training.
        """
        inp_shape = tf.shape(inputs)
        batch_sz, seq_len = inp_shape[0], inp_shape[1]
        causal_mask = causal_attention_mask(batch_sz, seq_len, seq_len, tf.bool)
        attention_out = self.attention(inputs, inputs, attention_mask=causal_mask, training=training)
        attention_out = self.drop1(attention_out, training=training)
        out1 = self.norm1(inputs + attention_out)
        feed_forward_out = self.feed_forward(out1, training=training)
        feed_forward_out = self.drop2(feed_forward_out, training=training)
        return self.norm2(out1 + feed_forward_out)

    def get_config(self):
        """Return the constructor arguments plus the base ``Layer`` config."""
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "num_att_heads": self.num_att_heads,
            "state_dims": self.state_dims,
            "dropout_rate": self.dropout_rate,
        })
        return config

class TokenAndPositionEmbedding(layers.Layer):
    """Sum of a learned token embedding and a learned absolute-position embedding.

    Args:
        max_len: number of position embeddings, i.e. the longest sequence this
            layer can embed.
        vocab_size: number of token ids in the token-embedding table.
        embed_dim: width of both embeddings.
        **kwargs: standard ``layers.Layer`` arguments (``name``, ``trainable``,
            ``dtype``, ...). They are forwarded so the dict from ``get_config()``
            (which includes those base keys) can be passed back to ``__init__``
            when Keras reloads a saved model.
    """

    def __init__(self, max_len, vocab_size, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.token_emb = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.pos_emb = layers.Embedding(input_dim=max_len, output_dim=embed_dim)

    def call(self, x):
        """Embed integer ids ``x``; positions are indexed along the last axis."""
        seq_len = tf.shape(x)[-1]
        # One position vector per index 0..seq_len-1; it broadcasts over the
        # leading axes when added to the token embeddings.
        pos = self.pos_emb(tf.range(start=0, limit=seq_len, delta=1))
        return self.token_emb(x) + pos

    def get_config(self):
        """Return the constructor arguments plus the base ``Layer`` config."""
        config = super().get_config()
        config.update({
            "max_len": self.max_len,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        })
        return config

In [81]:
def create_model():
    """Build and compile the language model.

    Graph: int token ids of length ``maxlen`` -> TokenAndPositionEmbedding ->
    one causal Transformer block -> Dense projection to ``vocab_size`` logits.
    The model has two outputs: the logits (trained with sparse categorical
    cross-entropy computed from logits) and the Transformer's hidden states
    (no loss attached).

    Reads the module-level hyper-parameters ``maxlen``, ``vocab_size``,
    ``embed_dim``, ``num_heads`` and ``feed_forward_dim``.

    Returns:
        A ``keras.Model`` compiled with the Adam optimizer.
    """
    # TODO: Remove this if we're using tokenizer
    embed_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
    block = Transformer(embed_dim, num_heads, feed_forward_dim)

    token_ids = layers.Input(shape=(maxlen,), dtype=tf.int32)
    states = block(embed_layer(token_ids))
    logits = layers.Dense(vocab_size)(states)

    model = keras.Model(inputs=token_ids, outputs=[logits, states])
    model.compile(
        optimizer="adam",
        loss=[tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), None],
    )
    return model

In [82]:
def get_files(dirs):
    """Collect the paths of the regular files directly inside each directory.

    Subdirectories (e.g. a ``.ipynb_checkpoints`` folder) are skipped, because
    ``tf.data.TextLineDataset`` cannot read a directory. Names are sorted
    within each directory so the result does not depend on the arbitrary
    order of ``os.listdir``.

    Args:
        dirs: iterable of directory paths.

    Returns:
        list of ``os.path.join(directory, name)`` strings, grouped by directory
        in the order ``dirs`` was given.
    """
    paths = []
    for directory in dirs:
        for name in sorted(os.listdir(directory)):
            path = os.path.join(directory, name)
            if os.path.isfile(path):
                paths.append(path)
    return paths


def create_dataset(file_pth, batch_sz, buf_sz=1000, shuffle=True):
    """Build a batched ``tf.data`` dataset of raw text lines.

    Args:
        file_pth: a path or a list of text-file paths. The caller's list is not
            modified; a copy is shuffled instead.
        batch_sz: number of lines per batch (the last batch may be smaller).
        buf_sz: size of the line-level shuffle buffer.
        shuffle: when True, shuffle the file order and the lines (through a
            ``buf_sz``-element buffer). When False, lines come out in file order.

    Returns:
        ``tf.data.Dataset`` yielding batches of string lines.
    """
    # Copy so shuffling never reorders the list the caller holds.
    paths = [file_pth] if isinstance(file_pth, str) else list(file_pth)
    if shuffle:
        random.shuffle(paths)
    ds = tf.data.TextLineDataset(paths)
    if shuffle:
        ds = ds.shuffle(buffer_size=buf_sz)
    return ds.batch(batch_sz)


def create_tokenizer(dataset, max_vocab_size):
    """Fit a word-level ``TextVectorization`` layer on ``dataset``.

    Standardization lower-cases the text and inserts a space before every
    punctuation character, so punctuation is split from neighbouring words
    by the layer's whitespace splitting. Each line is padded/truncated to
    ``maxlen + 1`` ids (module-level ``maxlen``), one more than the model
    input, so it can be cut into (inputs, shifted targets).

    Args:
        dataset: ``tf.data.Dataset`` of string batches to adapt the vocabulary on.
        max_vocab_size: vocabulary budget; the layer receives
            ``max_tokens=max_vocab_size - 1``.

    Returns:
        ``(vectorize, vocab)``: the adapted layer and its vocabulary list
        (token id -> token string).
    """
    def preprocess_txt(input_string):
        # Preprocessing for word-level model: lower-case, then put a space
        # before each punctuation character.
        s1 = tf.strings.lower(input_string)
        return tf.strings.regex_replace(s1, f"([{string.punctuation}])", r" \1")

    # Vectorization of the data
    vectorize = TextVectorization(
        standardize=preprocess_txt,
        # Use the caller's budget (previously the global `vocab_size` was read
        # and this argument was ignored).
        max_tokens=max_vocab_size - 1,
        output_mode="int",
        output_sequence_length=maxlen + 1,
    )
    vectorize.adapt(dataset)
    vocab = vectorize.get_vocabulary()
    return vectorize, vocab


# Gather the text files and build a shuffled, batched dataset of raw lines.
d_files = get_files(["data_test"])
dataset = create_dataset(d_files, batch_sz=batch_size)

# Fit the word-level tokenizer (and its vocabulary) on those raw lines.
tokenizer, vocab = create_tokenizer(dataset, max_vocab_size=vocab_size)


def create_sequences(txt):
    """Turn a batch of raw lines into ``(inputs, targets)`` for next-token prediction.

    Uses the module-level ``tokenizer``. ``targets`` holds the same token ids as
    ``inputs`` but shifted one position to the left.
    """
    token_ids = tokenizer(tf.expand_dims(txt, -1))
    inputs = token_ids[:, :-1]
    targets = token_ids[:, 1:]
    return inputs, targets


# Replace each batch of lines with (inputs, targets) id pairs, then prefetch
# so input preparation overlaps with training.
dataset = dataset.map(create_sequences)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

In [83]:
# Peek at one (inputs, targets) batch to check the shift-by-one alignment.
for first_batch in dataset.take(1):
    print(first_batch)

(<tf.Tensor: shape=(256, 80), dtype=int64, numpy=
array([[   3,   17,  376, ...,    0,    0,    0],
       [  72,   13, 7782, ...,    0,    0,    0],
       [ 235,    4, 5345, ...,    0,    0,    0],
       ...,
       [   9,   16,    6, ...,    0,    0,    0],
       [   3,   50, 1756, ...,    0,    0,    0],
       [  85,    3, 1548, ...,    0,    0,    0]])>, <tf.Tensor: shape=(256, 80), dtype=int64, numpy=
array([[  17,  376,   19, ...,    0,    0,    0],
       [  13, 7782, 1032, ...,    0,    0,    0],
       [   4, 5345,  201, ...,    0,    0,    0],
       ...,
       [  16,    6,  239, ...,    0,    0,    0],
       [  50, 1756,   33, ...,    0,    0,    0],
       [   3, 1548,   30, ...,    0,    0,    0]])>)


In [84]:
class TextGenerator(keras.callbacks.Callback):
    """Keras callback that samples text from the model at the end of epochs.

    Starting from ``start_tokens``, it repeatedly feeds the current context
    (right-padded with id 0 up to the module-level ``maxlen``, or the most recent
    ``maxlen`` ids once it is longer) to the model. It takes the logits at the
    last real position, samples the next id among the ``top_k`` highest, and
    appends it. The prompt plus the generated text is then printed.

    Args:
        max_tokens: number of tokens to generate.
        start_tokens: sequence of prompt token ids.
        index_to_word: sequence mapping token id -> token string (e.g. the
            tokenizer vocabulary).
        top_k: number of highest-scoring ids to sample from.
        print_every: generate only every ``print_every`` epochs.
    """

    def __init__(self, max_tokens, start_tokens, index_to_word, top_k=10, print_every=1):
        super().__init__()
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.index_to_word = index_to_word
        self.print_every = print_every
        self.k = top_k

    def sample_from(self, logits):
        """Sample one token id from the ``self.k`` highest logits (softmax over them)."""
        logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
        indices = np.asarray(indices).astype("int32")
        preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
        preds = np.asarray(preds).astype("float32")
        return np.random.choice(indices, p=preds)

    def detokenize(self, number):
        """Map a token id back to its string."""
        return self.index_to_word[number]

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.print_every != 0:
            return
        context = list(self.start_tokens)
        tokens_generated = []
        # `<` (not `<=`) so exactly max_tokens tokens are produced.
        while len(tokens_generated) < self.max_tokens:
            if len(context) >= maxlen:
                # Slide the window: condition on the most recent maxlen ids so
                # newly generated tokens keep influencing the next prediction.
                x = context[-maxlen:]
                sample_index = maxlen - 1
            else:
                x = context + [0] * (maxlen - len(context))
                sample_index = len(context) - 1
            y, _ = self.model.predict(np.array([x]), verbose=0)
            sample_token = self.sample_from(y[0][sample_index])
            tokens_generated.append(sample_token)
            context.append(sample_token)
        words = (self.detokenize(t) for t in list(self.start_tokens) + tokens_generated)
        # Skip empty strings: with a TextVectorization vocabulary, the padding
        # id 0 maps to "", which otherwise shows up as runs of trailing spaces.
        txt = " ".join(w for w in words if w)
        print(f"Generated:\n{txt}\n")

In [85]:
def create(start_prompt, vocabulary, max_tokens=40, top_k=10, print_every=1):
    """Build a ``TextGenerator`` callback seeded with ``start_prompt``.

    The prompt is lower-cased and a space is inserted before every punctuation
    character before whitespace-splitting. This matches ``create_tokenizer``'s
    ``preprocess_txt``, so e.g. ``"Hello,"`` becomes ``hello`` + ``,`` instead
    of an unknown word. Words missing from ``vocabulary`` map to id 1 (the OOV
    id in a TextVectorization vocabulary).

    Args:
        start_prompt: prompt text.
        vocabulary: list of token strings indexed by token id.
        max_tokens: number of tokens the callback generates.
        top_k: number of highest-scoring ids sampled from.
        print_every: generate every ``print_every`` epochs.

    Returns:
        A ``TextGenerator`` instance.
    """
    import re

    word_to_index = {word: index for index, word in enumerate(vocabulary)}
    standardized = re.sub(f"([{re.escape(string.punctuation)}])", r" \1", start_prompt.lower())
    prompt_tokens = [word_to_index.get(word, 1) for word in standardized.split()]
    return TextGenerator(max_tokens, prompt_tokens, vocabulary, top_k=top_k, print_every=print_every)

def create_callbacks(base_dir, model, defaults: list = None):
    """Assemble training callbacks and a TensorBoard summary writer.

    Outputs go under ``<base_dir>/<model.name>/``: ``history.csv`` (written by
    ``CSVLogger`` in append mode) and a ``logs`` directory for the returned
    summary writer. A ``TextGenerator`` seeded with "I will always" (built from
    the module-level ``vocab``) is also added.

    Args:
        base_dir: root output directory.
        model: the Keras model; only its ``name`` is used here.
        defaults: optional callbacks to place first. The list is copied, never
            modified.

    Returns:
        ``(callbacks, tb_file_writer)`` -- unpack it and pass only ``callbacks``
        to ``model.fit``.
    """
    print(f'base_dir: {base_dir}')
    dir_models = os.path.join(base_dir, model.name)
    path_csv = os.path.join(dir_models, 'history.csv')
    print("History CSV:", path_csv)
    path_ckp = os.path.join(dir_models, 'checkpoints.h5')
    print("Checkpoint:", path_ckp)
    path_tb = os.path.join(dir_models, "logs")
    # Create the output directory before anything that writes into it.
    os.makedirs(dir_models, exist_ok=True)
    tb_file_writer = tf.summary.create_file_writer(path_tb)
    callbacks = [] if defaults is None else list(defaults)
    callbacks.append(tf.keras.callbacks.CSVLogger(path_csv, separator=",", append=True))
    callbacks.append(create("I will always", vocab))
    # Checkpointing to `path_ckp` is currently disabled; to enable it, append
    # tf.keras.callbacks.ModelCheckpoint(path_ckp, monitor='loss', save_best_only=True).
    return callbacks, tb_file_writer

In [86]:
# Build and compile a fresh model; re-running this cell discards any training done so far.
model = create_model()

In [87]:
# Start the W&B run for this training session and record the hyper-parameters.
wandb.init(
    project='transformer',
    sync_tensorboard=True,
    config={
        "vocab_size": vocab_size,
        "maxlen": maxlen,
        "num_heads": num_heads,
        "embed_dim": embed_dim,
        "feed_forward_dim": feed_forward_dim,
        "batch_size": batch_size,
    },
)
# create_callbacks returns (callbacks, summary_writer); only the callback list
# may be passed to fit().
callbacks, tb_file_writer = create_callbacks("logs", model)
history = model.fit(dataset, verbose=2, epochs=5, callbacks=callbacks)

Epoch 1/5
Generated:
i will always be the things                                      

696/696 - 92s - loss: 0.7550 - dense_32_loss: 0.7550 - 92s/epoch - 132ms/step
Epoch 2/5
Generated:
i will always get away to go                                     

696/696 - 90s - loss: 0.5364 - dense_32_loss: 0.5364 - 90s/epoch - 130ms/step
Epoch 3/5
Generated:
i will always be home                                       

696/696 - 95s - loss: 0.4975 - dense_32_loss: 0.4975 - 95s/epoch - 137ms/step
Epoch 4/5
Generated:
i will always see the things .                                     

696/696 - 94s - loss: 0.4691 - dense_32_loss: 0.4691 - 94s/epoch - 134ms/step
Epoch 5/5
Generated:
i will always be the way of my life .                                  

696/696 - 97s - loss: 0.4462 - dense_32_loss: 0.4462 - 97s/epoch - 140ms/step


<keras.callbacks.History at 0x7f63c20c9460>

In [89]:
# Save as HDF5 (selected by the .h5 extension). Reloading needs
# custom_objects for TokenAndPositionEmbedding and Transformer.
model.save(filepath="model.h5")