In [60]:
#@formatter:off
%load_ext autoreload
%autoreload 2
#@formatter:on

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [96]:
import random
import string

import tensorflow as tf
from src.Generation import GenerationCallback, Generator
from src.Model import TokenAndPositionEmbedding
from src.Model import Transformer
from src.Model import WarmupScheduler
from src import Utils
from src.Configs import ModelConfig
from src.Configs import TrainingConfig
import pickle


In [97]:
def create_model(model_config: ModelConfig):
    # TODO: Remove this if we're using tokenizer
    embedding = TokenAndPositionEmbedding(model_config.M_MAX_LEN, model_config.M_VOCAB_SZ, model_config.M_DIM_EMB)
    transformer = Transformer(model_config.M_DIM_EMB, model_config.M_ATT_HEADS, model_config.M_DIM_FFN)

    l_input = tf.keras.layers.Input(shape=(model_config.M_MAX_LEN,), dtype=tf.int32)
    l_emb = embedding(l_input)
    l_trans = transformer(l_emb)
    l_output = tf.keras.layers.Dense(model_config.M_VOCAB_SZ)(l_trans)

    m = tf.keras.Model(inputs=l_input, outputs=[l_output, l_trans])
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    learning_rate = WarmupScheduler(model_config.M_DIM_EMB, model_config.M_WARMUP_STEPS)
    optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

    m.compile("adam", loss=[loss_fn, None])
    return m

In [98]:
def create_dataset(file_pth, batch_sz, buf_sz=1000, shuffle=True):
    # Shuffle the data and create batches
    if shuffle:
        random.shuffle(file_pth)
    ds = tf.data.TextLineDataset(file_pth)
    ds = ds.shuffle(buffer_size=buf_sz)
    ds = ds.batch(batch_sz)
    return ds


def create_tokenizer(dataset, max_vocab_size, max_seq_len):
    def preprocess_txt(input_string):
        # Preprocessing for word-level model
        s1 = tf.strings.lower(input_string)
        return tf.strings.regex_replace(s1, f"([{string.punctuation}])", r" \1")

    # Vectorization of the data
    vectorize = tf.keras.layers.TextVectorization(
        standardize=preprocess_txt,
        max_tokens=max_vocab_size - 1,
        output_mode="int",
        output_sequence_length=max_seq_len + 1,
    )
    vectorize.adapt(dataset)
    vocab = vectorize.get_vocabulary()
    return vectorize, vocab


config_model, config_training = ModelConfig(), TrainingConfig()
# Read in the data and create the dataset
dataset = create_dataset(config_training.T_DATASET, config_training.T_BATCH_SIZE)
# Create the tokenizer
tokenizer, vocab = create_tokenizer(dataset, config_model.M_VOCAB_SZ, config_model.M_MAX_LEN)

In [65]:
for d in dataset.take(1):
    print("=" * 80)
    print(d[0])
    print("=" * 80)
    print(d[1])

tf.Tensor(b'So many women wanna call me baby', shape=(), dtype=string)
tf.Tensor(b'Why you gotta do me like that?', shape=(), dtype=string)


In [66]:
def create_sequences(txt):
    txt = tf.expand_dims(txt, -1)
    txt_tok = tokenizer(txt)
    return txt_tok[:, :-1], txt_tok[:, 1:]


dataset = dataset.map(create_sequences).prefetch(tf.data.AUTOTUNE)

In [67]:
for d in dataset.take(1):
    print("=" * 80)
    print(d[0])
    print("=" * 80)
    print(d[1])

tf.Tensor(
[[   71    98    40 ...     0     0     0]
 [ 9408  3205     2 ...     0     0     0]
 [   33     9    83 ...     0     0     0]
 ...
 [  949    51   187 ...     0     0     0]
 [   52   147  8454 ...     0     0     0]
 [  379 19253  1158 ...     0     0     0]], shape=(256, 100), dtype=int64)
tf.Tensor(
[[   98    40     3 ...     0     0     0]
 [ 3205     2 10521 ...     0     0     0]
 [    9    83     3 ...     0     0     0]
 ...
 [   51   187  5640 ...     0     0     0]
 [  147  8454  2491 ...     0     0     0]
 [19253  1158  4869 ...     0     0     0]], shape=(256, 100), dtype=int64)


In [68]:
model = create_model(config_model)
callbacks, tb_file_writer = Utils.create_callbacks("logs", model)
gen_callback = GenerationCallback("i will always be", 100, config_model.M_MAX_LEN, vocab)
callbacks.append(gen_callback)

# Model Dir: logs/model_7
 - History Path: logs/model_7/history.csv
 - Checkpoint Path: logs/model_7/checkpoints.h5
 - TB Path: logs/model_7


In [69]:
model.summary()

Model: "model_7"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_8 (InputLayer)        [(None, 100)]             0         
                                                                 
 token_and_position_embeddin  (None, 100, 128)         2572800   
 g_7 (TokenAndPositionEmbedd                                     
 ing)                                                            
                                                                 
 transformer_7 (Transformer)  (None, 100, 128)         264192    
                                                                 
 dense_23 (Dense)            (None, 100, 20000)        2580000   
                                                                 
Total params: 5,416,992
Trainable params: 5,416,992
Non-trainable params: 0
_________________________________________________________________


In [70]:
# wandb.tensorboard.patch(root_logdir="logs")
# wandb.init(project='transformer')
model.fit(dataset, verbose=1, epochs=config_training.T_EPOCHS, callbacks=callbacks)

Epoch 1/20
    675/Unknown - 86s 126ms/step - loss: 0.7827 - dense_23_loss: 0.7827
Generated:
i will always be
Epoch 2/20
Generated:
i will always be
Epoch 3/20
Generated:
i will always be
Epoch 4/20
Generated:
i will always be
Epoch 5/20
Generated:
i will always be
Epoch 6/20
Generated:
i will always be
Epoch 7/20
Generated:
i will always be
Epoch 8/20
Generated:
i will always be
Epoch 9/20
Generated:
i will always be
Epoch 10/20
Generated:
i will always be
Epoch 11/20
Generated:
i will always be
Epoch 12/20
Generated:
i will always be
Epoch 13/20

KeyboardInterrupt: 

In [99]:

model.save("model_save")
with open("model_save/vocab.pkl", "wb") as f:
    pickle.dump(vocab, f)
with open("model_save/config_model.pkl", "wb") as f:
    pickle.dump(config_model, f)
with open("model_save/config_training.pkl", "wb") as f:
    pickle.dump(config_training, f)



INFO:tensorflow:Assets written to: model_save/assets


INFO:tensorflow:Assets written to: model_save/assets


In [87]:
# model.save("model_warmup.h5")
generator = Generator(model, config_model.M_MAX_LEN, vocab)
generated_txt = generator.generate("i will always be", 50)
print(generated_txt)

Generated 51 tokens
i will always be where my niggaz at ?                                               where my niggaz at ?                                              
