## Parameters

In [None]:
# These parameters can be injected from Papermill
# Which decoder implementation to train; the following cell asserts it is one of
# "pre_ln", "post_ln" or "transformers".
model_type = "pre_ln"
# Raw text files from which the training and validation datasets are built.
train_file = "wikitext-103-raw/wiki.train.raw"
valid_file = "wikitext-103-raw/wiki.valid.raw"
# Number of epochs requested from model.fit (early stopping may end training sooner).
epochs = 10
# Batch size handed to BlockDataset.from_generator for both splits.
batch_size = 2
# Peak learning rate and warmup length passed to the WarmupLinearDecay schedule.
max_learning_rate = 1e-4
warmup_steps = 0
# Gradient clipping norm passed to the Adam optimizer.
clipnorm = 1.0
# When True, set_mixed_precision_policy() is called before the model is created.
fp16 = False
# NOTE(review): the two paths below are derived when this cell runs. Presumably
# Papermill injects overrides in a cell placed after this one, in which case
# overriding only model_type would not change these paths — verify.
save_model_dir = f"output/tfdlg_train-{model_type}-model"
# NOTE(review): save_model_dir already starts with "output/", so this expands to
# "output/tensorboard/output/tfdlg_train-<model_type>-model-tensorboard" — confirm
# that the nested "output/" is intended.
tensorboard_dir = f"output/tensorboard/{save_model_dir}-tensorboard"

In [None]:
# Assert parameters
# Fail fast and name the offending value when an unsupported model type is supplied
# (for example through Papermill), instead of raising a bare AssertionError.
assert model_type in ["pre_ln", "post_ln", "transformers"], (
    f"Unsupported model_type: {model_type!r}"
)

## Configure GPU

In [None]:
from tfdlg.utils import set_memory_growth
from tfdlg.utils import set_mixed_precision_policy

In [None]:
# Configure GPU memory handling through the tfdlg helper before any model is created.
# NOTE(review): presumably enables TensorFlow's on-demand GPU memory growth —
# confirm in tfdlg.utils.
set_memory_growth()

In [None]:
# Mixed precision is opt-in: the policy helper runs only when the fp16 parameter is True.
if fp16:
    set_mixed_precision_policy()

## Setup tokenizer

In [None]:
# Install transformers by HuggingFace to use GPT2 tokenizer
! pip install transformers==3.4.0
# Enable widgetsnbextention to avoid the following error when running GPT2.from_pretrained method
#     ImportError: IProgress not found. Please update jupyter and ipywidgets.
! jupyter nbextension enable --py widgetsnbextension

In [None]:
# setup tokenizer
from transformers import GPT2Tokenizer

# Load the tokenizer published under the "gpt2" name. Later cells use its
# encode method to build the datasets and its vocab_size for the model config.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

## Prepare model config

In [None]:
from tfdlg.configs import GPT2SmallConfig

# Start from the GPT2SmallConfig defaults provided by tfdlg.
config = GPT2SmallConfig()

# Override vocab_size with the vocabulary size of the GPT-2 tokenizer.
# The original note states that the value must be larger than 33,278, the
# vocabulary size of WikiText-2.
# NOTE(review): the default parameters point at WikiText-103 files — confirm
# that the WikiText-2 figure is still the relevant bound.
config.vocab_size = tokenizer.vocab_size

In [None]:
# Display the resulting configuration as the cell output.
config

## Prepare Dataset

In [None]:
from pathlib import Path
from urllib.request import urlretrieve
import zipfile
import numpy as np


def read_file(_filepath):
    """Yield the lines of a UTF-8 text file with their newline characters removed.

    The file is opened when iteration starts and is closed as soon as the
    iterator is exhausted or closed, so re-reading a corpus every epoch does
    not leak file handles.

    Args:
        _filepath: Path of the text file to read.

    Yields:
        Each line of the file as a string, stripped of "\\n" characters.
    """
    # The encoding is explicit so decoding does not depend on the platform default.
    with open(_filepath, encoding="utf-8") as lines:
        for line in lines:
            yield line.strip("\n")


In [None]:
from tfdlg.data import BlockDataset


# Settings common to both splits: the same tokenizer encoding, block length
# (the model's context size) and batch size.
_shared_dataset_args = dict(
    encode_fn=tokenizer.encode,
    block_size=config.context_size,
    batch_size=batch_size,
)

# Training split, requested with shuffling. The lambda hands the dataset a
# callable that produces a fresh line generator each time it is invoked.
train_dataset = BlockDataset.from_generator(
    generator=lambda: read_file(train_file),
    shuffle=True,
    **_shared_dataset_args
)

# Validation split, built the same way but without shuffling.
valid_dataset = BlockDataset.from_generator(
    generator=lambda: read_file(valid_file),
    shuffle=False,
    **_shared_dataset_args
)

In [None]:
def _count_batches(_dataset):
    """Return how many items a full pass over ``_dataset`` yields."""
    total = 0
    for _batch in _dataset:
        total += 1
    return total


# One full pass over each split gives the number of steps per epoch;
# the training count is later used to size the learning-rate schedule.
num_train_steps = _count_batches(train_dataset)
num_valid_steps = _count_batches(valid_dataset)
print("Train steps:", num_train_steps)
print("Valid steps:", num_valid_steps)

## Transformers model implementation

In [None]:
from transformers import TFGPT2LMHeadModel
from transformers import GPT2Config
import tensorflow.keras as keras
import tensorflow as tf

In [None]:
class TransformersGPT2(keras.Model):
    """Keras model wrapping Hugging Face's ``TFGPT2LMHeadModel``.

    The wrapper translates the fields of a tfdlg model config into a
    ``GPT2Config`` and returns only the first element of the wrapped model's
    output, so it can be trained with the same loop as the tfdlg decoders.
    """

    def __init__(self, config):
        """Create the wrapped GPT-2 model from a tfdlg config.

        Args:
            config: Model config exposing ``num_layers``, ``d_model``,
                ``num_heads``, ``d_ff``, ``vocab_size``, ``context_size``,
                ``attention_dropout_rate``, ``residual_dropout_rate``,
                ``embedding_dropout_rate`` and ``epsilon``.
        """
        super().__init__()
        tf_config = GPT2Config(
            # GPT2Config names the depth argument `n_layer`. A misspelled
            # `n_layers` keyword is accepted as an extra attribute and the
            # default depth is used instead of config.num_layers.
            n_layer=config.num_layers,
            n_embd=config.d_model,
            n_head=config.num_heads,
            n_inner=config.d_ff,
            vocab_size=config.vocab_size,
            # Both the context window and the number of position embeddings
            # follow the tfdlg context size.
            n_ctx=config.context_size,
            n_positions=config.context_size,
            attn_pdrop=config.attention_dropout_rate,
            resid_pdrop=config.residual_dropout_rate,
            embd_pdrop=config.embedding_dropout_rate,
            layer_norm_epsilon=config.epsilon,
            activation_function="gelu_new",  # Default value of transformers implementation
        )
        self._decoder = TFGPT2LMHeadModel(tf_config)

    def call(self, inputs, training):
        """Run the wrapped model on a batch of token ids.

        Args:
            inputs: Token ids; cast to ``tf.int32`` before being passed on.
            training: Forwarded to the wrapped model as its ``training`` flag.

        Returns:
            The first element of the wrapped model's output (the language
            modeling logits according to the transformers documentation).
        """
        inputs = tf.cast(inputs, tf.int32)
        x = self._decoder(inputs, training=training)
        return x[0]



## Prepare Model

In [None]:
from tfdlg.losses import PaddingLoss
from tfdlg.schedules import WarmupLinearDecay
import tensorflow.keras as keras



def train(
    _model,
    _train_dataset,
    _valid_dataset,
    _epochs,
    _warmup_steps,
    _num_train_steps,
    _max_learning_rate,
    _clipnorm,
    _tensorboard_dir
):
    """Compile ``_model`` and fit it on the training dataset.

    Args:
        _model: Keras model to train; compiled in place with ``PaddingLoss``
            and an Adam optimizer.
        _train_dataset: Dataset passed to ``fit`` as training data.
        _valid_dataset: Dataset passed to ``fit`` as ``validation_data``.
        _epochs: Maximum number of epochs.
        _warmup_steps: Warmup steps of the ``WarmupLinearDecay`` schedule.
        _num_train_steps: Number of training steps in one epoch.
        _max_learning_rate: Peak learning rate of the schedule.
        _clipnorm: Gradient clipping norm given to Adam.
        _tensorboard_dir: Directory the TensorBoard callback writes logs to.

    Returns:
        The value returned by ``_model.fit`` (the Keras training history).
    """
    schedule = WarmupLinearDecay(
        max_learning_rate=_max_learning_rate,
        warmup_steps=_warmup_steps,
        # The schedule spans the whole run: steps per epoch times epochs.
        training_steps=_num_train_steps * _epochs
    )
    optimizer = keras.optimizers.Adam(
        schedule, beta_1=0.9, beta_2=0.999, epsilon=1e-8, clipnorm=_clipnorm
    )
    _model.compile(loss=PaddingLoss(), optimizer=optimizer)

    return _model.fit(
        _train_dataset,
        validation_data=_valid_dataset,
        epochs=_epochs,
        callbacks=[
            # Stop after one epoch without improvement and roll back to the
            # best weights seen so far.
            keras.callbacks.EarlyStopping(patience=1, restore_best_weights=True),
            keras.callbacks.TensorBoard(
                # Use the directory passed by the caller rather than the
                # notebook-level global of the same name.
                log_dir=_tensorboard_dir,
                update_freq=100,
                profile_batch=0,
            )
        ],
        verbose=2,
    )


In [None]:
# Instantiate the architecture selected by the model_type parameter. The tfdlg
# decoders are imported only in the branch that needs them.
if model_type == "transformers":
    # Hugging Face GPT-2 wrapped as a Keras model.
    model = TransformersGPT2(config)
elif model_type == "post_ln":
    from tfdlg.models import PostLNDecoder
    model = PostLNDecoder(config)
elif model_type == "pre_ln":
    from tfdlg.models import PreLNDecoder
    model = PreLNDecoder(config)
else:
    raise Exception("Model type is wrong")

In [None]:
# Build the model for inputs of shape (batch, context_size) so that summary()
# can list its layers and parameter counts.
model.build(input_shape=(None, config.context_size))
model.summary()

In [None]:
# Launch training; every argument is bound by name to train()'s parameters.
train(
    _model=model,
    _train_dataset=train_dataset,
    _valid_dataset=valid_dataset,
    _epochs=epochs,
    _warmup_steps=warmup_steps,
    _num_train_steps=num_train_steps,
    _max_learning_rate=max_learning_rate,
    _clipnorm=clipnorm,
    _tensorboard_dir=tensorboard_dir,
)

In [None]:
from tfdlg.eval import perplexity

# Score the trained model on the validation dataset and report the result.
valid_ppl = perplexity(model, valid_dataset)
print("Validation PPL:", valid_ppl)

In [None]:
from tfdlg.utils import save_model

# Hand the trained model and its config to the tfdlg helper, targeting save_model_dir.
# NOTE(review): the on-disk layout is decided by tfdlg.utils.save_model and is
# not visible here.
save_model(save_model_dir, model, config)