## Parameters

In [None]:
# These parameters can be injected from Papermill
# Which model implementation to train: "pre_ln", "post_ln", "min_gpt" or "transformers"
model_type = "pre_ln"
# Paths of the raw text files that are tokenized for training and validation
train_file = "wikitext-103-raw/wiki.train.raw"
valid_file = "wikitext-103-raw/wiki.valid.raw"
# Maximum number of training epochs (early stopping may end training sooner)
epochs = 10
# Batch size passed to BlockDataset
batch_size = 2
# max_learning_rate and warmup_steps configure the WarmupLinearDecay schedule
max_learning_rate = 1e-4
warmup_steps = 0
# Directory passed to save_model for the trained model
save_model_dir = "tfchat_model"

In [None]:
# Assert parameters
# Fail fast, and report the offending value, when an unsupported model type is injected
assert model_type in ["pre_ln", "post_ln", "min_gpt", "transformers"], (
    f"Unsupported model_type {model_type!r}; "
    'expected one of "pre_ln", "post_ln", "min_gpt", "transformers"'
)

## Installation

In [None]:
!apt install -y git
!pip install git+https://github.com/noriyukipy/tfchat@8b3551

## Configure GPU

In [None]:
from tfchat.utils import set_memory_growth

In [None]:
# Configure the GPU through tfchat's helper before any model is created
# NOTE(review): presumably enables TensorFlow memory growth on the GPUs — confirm in tfchat.utils
set_memory_growth()

## Setup tokenizer

In [None]:
# Install transformers by HuggingFace to use GPT2 tokenizer
! pip install transformers==3.4.0
# Enable widgetsnbextention to avoid the following error when running GPT2.from_pretrained method
#     ImportError: IProgress not found. Please update jupyter and ipywidgets.
! jupyter nbextension enable --py widgetsnbextension

In [None]:
# setup tokenizer
from transformers import GPT2Tokenizer

# Load the tokenizer registered under the name "gpt2"; it is used below to
# encode the text files and to set the model's vocabulary size
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

## Prepare model config

In [None]:
from tfchat.configs import GPT2SmallConfig

# Start from the defaults of tfchat's GPT2SmallConfig
config = GPT2SmallConfig()

# Use the tokenizer's vocabulary size, which is larger than 33,278 (the vocab size of Wikitext-2)
config.vocab_size = tokenizer.vocab_size

In [None]:
# Display the configuration after the vocabulary size override
config

## Prepare Dataset

In [None]:
from pathlib import Path
from urllib.request import urlretrieve
import zipfile
import numpy as np


def encode_file(_tokenizer, _filepath):
    """Tokenize a text file line by line into one flat array of token ids.

    Args:
        _tokenizer: Object with an ``encode(text)`` method returning a
            sequence of integer ids.
        _filepath: Path of the UTF-8 encoded text file to read.

    Returns:
        A 1-D ``np.int32`` array holding the ids of all lines concatenated in
        file order. An empty file yields an empty array.
    """
    ids = []
    # Read as UTF-8 explicitly rather than relying on the platform's default encoding
    with open(_filepath, encoding="utf-8") as f:
        # Iterate lazily so the whole file is never held as a list of lines
        for line in f:
            # Only newline characters are removed; other whitespace reaches the tokenizer unchanged
            text = line.strip("\n")
            ids.extend(_tokenizer.encode(text))

    return np.array(ids, dtype=np.int32)

In [None]:
# Tokenize both splits into flat int32 arrays of token ids
train_ids = encode_file(tokenizer, train_file)
valid_ids = encode_file(tokenizer, valid_file)

In [None]:
# The arrays are 1-D, so each shape is the number of token ids in that split
print("Train:", train_ids.shape)
print("Valid:", valid_ids.shape)

In [None]:
# Unlabelled shapes of the two id arrays, one per line (train first, then valid)
print(train_ids.shape, valid_ids.shape, sep="\n")

In [None]:
from tfchat.data import BlockDataset


# One builder is shared by both splits so they use the same block and batch sizes
# NOTE(review): presumably block_size is the number of tokens per example — confirm in tfchat.data
dataset = BlockDataset(block_size=config.context_size, batch_size=batch_size)

# Only the training split is shuffled
train_dataset = dataset.build(train_ids, shuffle=True)
valid_dataset = dataset.build(valid_ids, shuffle=False)

In [None]:
# Count the batches by streaming through each dataset once. Summing over a
# generator avoids building a list that keeps every batch in memory at the same time.
# num_train_steps is later used to size the learning-rate schedule.
num_train_steps = sum(1 for _ in train_dataset)
num_valid_steps = sum(1 for _ in valid_dataset)
print("Train steps:", num_train_steps)
print("Valid steps:", num_valid_steps)

## Transformers model implementation

In [None]:
from transformers import TFGPT2LMHeadModel
from transformers import GPT2Config
import tensorflow.keras as keras
import tensorflow as tf
from tfchat.models import create_combined_mask

In [None]:
class TransformersGPT2(keras.Model):
    """Keras wrapper around HuggingFace's ``TFGPT2LMHeadModel``.

    The wrapped model is constructed from a fresh ``GPT2Config`` whose
    hyperparameters are copied from the given config, and ``call`` returns
    only the first element of the wrapped model's output.
    """

    def __init__(self, config):
        """Build the wrapped GPT-2 model.

        Args:
            config: Object providing ``num_layers``, ``d_model``, ``num_heads``,
                ``d_ff``, ``vocab_size``, ``context_size``,
                ``attention_dropout_rate``, ``residual_dropout_rate``,
                ``embedding_dropout_rate`` and ``epsilon``.
        """
        super().__init__()
        tf_config = GPT2Config(
            # GPT2Config's depth parameter is `n_layer` (singular); passing
            # `n_layers` would leave the number of layers at the library default
            n_layer=config.num_layers,
            n_embd=config.d_model,
            n_head=config.num_heads,
            n_inner=config.d_ff,
            vocab_size=config.vocab_size,
            n_ctx=config.context_size,
            n_positions=config.context_size,
            attn_pdrop=config.attention_dropout_rate,
            resid_pdrop=config.residual_dropout_rate,
            embd_pdrop=config.embedding_dropout_rate,
            layer_norm_epsilon=config.epsilon,
            activation_function="gelu_new",  # Default value of transformers implementation
        )
        self._decoder = TFGPT2LMHeadModel(tf_config)

    def call(self, inputs, training):
        """Run the wrapped model and return the first element of its output.

        Args:
            inputs: Token ids; cast to ``tf.int32`` before being passed on.
            training: Forwarded to the wrapped model as its ``training`` flag.

        Returns:
            The first element of the wrapped model's output (the language
            modeling logits according to the transformers documentation).
        """
        inputs = tf.cast(inputs, tf.int32)
        x = self._decoder(inputs, training=training)
        return x[0]



## Prepare Model

In [None]:
from tfchat.losses import PaddingLoss
from tfchat.schedules import WarmupLinearDecay
import tensorflow.keras as keras



def train(_model, _train_dataset, _valid_dataset, _epochs, _warmup_steps, _num_train_steps, _max_learning_rate):
    """Compile ``_model`` and fit it with Adam on a warmup/linear-decay schedule.

    Args:
        _model: Keras model to compile and train (modified in place).
        _train_dataset: Dataset passed to ``fit`` as training data.
        _valid_dataset: Dataset passed to ``fit`` as validation data.
        _epochs: Maximum number of epochs; early stopping may end sooner.
        _warmup_steps: Warmup steps of the learning-rate schedule.
        _num_train_steps: Number of training steps per epoch; the schedule
            spans ``_num_train_steps * _epochs`` steps.
        _max_learning_rate: Peak learning rate of the schedule.

    Returns:
        The ``History`` object returned by ``_model.fit``.
    """
    schedule = WarmupLinearDecay(max_learning_rate=_max_learning_rate,
                                 warmup_steps=_warmup_steps,
                                 training_steps=_num_train_steps*_epochs)
    optimizer = keras.optimizers.Adam(schedule, beta_1=0.9, beta_2=0.999, epsilon=1e-8, clipnorm=1.0)
    _model.compile(loss=PaddingLoss(), optimizer=optimizer)

    history = _model.fit(
        _train_dataset,
        validation_data=_valid_dataset,
        epochs=_epochs,
        callbacks=[
            # Stop after one epoch without improvement and restore the best weights seen
            keras.callbacks.EarlyStopping(patience=1, restore_best_weights=True),
            # If you want to save checkpoints, uncomment the next line
            #keras.callbacks.ModelCheckpoint("keras_model/", save_best_only=True)
        ],
        verbose=2,
    )
    # Hand the training history back so callers can inspect or plot the losses
    return history


In [None]:
# Instantiate the model implementation selected by `model_type`.
# Each implementation is imported only inside the branch that uses it.
if model_type == "pre_ln":
    from tfchat.models import PreLNDecoder
    model = PreLNDecoder(config)
elif model_type == "post_ln":
    from tfchat.models import PostLNDecoder 
    model = PostLNDecoder(config)
elif model_type == "transformers":
    # Wrapper class defined in this notebook around HuggingFace's GPT-2
    model = TransformersGPT2(config)
elif model_type == "min_gpt":
    # NOTE(review): no visible cell installs `mingpt`; if this is the PyTorch minGPT
    # package, the Keras build/fit cells below would not apply to it — confirm
    from mingpt.model import GPT, GPTConfig
    mconf = GPTConfig(config.vocab_size, config.context_size,
                      n_layer=config.num_layers, n_head=config.num_heads, n_embd=config.d_model)
    model = GPT(mconf)
else:
    # Not reached when the assertion on model_type has already passed
    raise Exception("Model type is wrong")

In [None]:
# Build the model for inputs of shape (batch, context_size), then print its summary
model.build(input_shape=(None, config.context_size))
model.summary()

In [None]:
# Train with the notebook parameters and the per-epoch step count computed earlier
train(model, train_dataset, valid_dataset, epochs, warmup_steps, num_train_steps, max_learning_rate)

In [None]:
from tfchat.eval import perplexity

# Report the perplexity of the trained model on the validation dataset
print("Validation PPL:", perplexity(model, valid_dataset))

In [None]:
from tfchat.utils import save_model

# Save the trained model together with its config under save_model_dir
save_model(save_model_dir, model, config)