In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import tensorflow as tf
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Dataset preparation

In [None]:
from datasets import load_dataset
dataset = load_dataset("conll2003")

In [None]:
# len(dataset["train"])
NER_TAGS = {
    "O": 0,
    "B-PER": 1,
    "I-PER": 2,
    "B-ORG": 3,
    "I-ORG": 4,
    "B-LOC": 5,
    "I-LOC": 6,
    "B-MISC": 7,
    "I-MISC": 8,
}
NER_TAGS_INV = {v: k for k, v in NER_TAGS.items()}

In [None]:
dataset["train"][100]

In [None]:
from tqdm import tqdm

train_data = []
for data in tqdm(dataset["train"]):
    tags = [NER_TAGS_INV[tag] for tag in data["ner_tags"]]
    train_data.append((data["tokens"], tags))

In [None]:
validation_data = []
for data in tqdm(dataset["validation"]):
    tags = [NER_TAGS_INV[tag] for tag in data["ner_tags"]]
    validation_data.append((data["tokens"], tags))

In [None]:
import numpy as np
np.quantile([len(data[0]) for data in train_data], 0.95)

In [None]:
def prepare_data(item):
    """Turn one (tokens, IOB tag names) pair into a (prompt, target) text pair.

    Args:
        item: ``(tokens, tag_names)`` — parallel sequences of word tokens and tag
            names such as ``"O"``, ``"B-PER"``, ``"I-PER"``.

    Returns:
        ``(context, target_text)``. ``context`` is the tokens joined by spaces.
        ``target_text`` has one line per entity, ``"<entity words>//<TYPE>\\n"``,
        in sentence order.

    Chunking follows the usual CoNLL convention: an ``I-X`` tag continues the
    current entity only if that entity is also of type ``X``. Otherwise (after
    ``O`` or after an entity of another type) it starts a new entity.
    """
    tokens, tag_names = item

    entities = []  # list of (words, entity_type)
    current_words = []
    current_type = None

    for token, tag_name in zip(tokens, tag_names):
        # "B-PER" -> ("B", "-", "PER"); "O" -> ("O", "", "")
        prefix, _, entity_type = tag_name.partition("-")

        if prefix == "I" and current_words and entity_type == current_type:
            current_words.append(token)
            continue

        # Any other tag closes the entity being built (if any).
        if current_words:
            entities.append((current_words, current_type))
            current_words = []

        if prefix in ("B", "I"):
            current_words = [token]
            current_type = entity_type

    if current_words:
        entities.append((current_words, current_type))

    context = " ".join(tokens)
    target_text = "".join(
        f"{' '.join(words)}//{entity_type}\n" for words, entity_type in entities
    )

    return context, target_text


item = train_data[5]
for token, tag_name in zip(*item):
    print(f"{token:15}{tag_name}")

context_text, target_text = prepare_data(item)

print(f"prompt:{context_text}")
print(f"target:{target_text}")

In [None]:
prepared_train_data = [prepare_data(item) for item in tqdm(train_data)]
prepared_validation_data = [prepare_data(item) for item in tqdm(validation_data)]

# Build tf.data.Dataset iterators

In [None]:
from tf_nano_gpt.model import GPT2Tokenizer

sequence_length = 256
context_length =  sequence_length // 2
tokenizer = GPT2Tokenizer()

In [None]:
def prepare_lm_targets(context, target):
    """Build one teacher-forcing example from a (context, target) string pair.

    The context is padded/sliced to ``context_length`` tokens and the target to
    ``context_length + 1`` tokens. Inputs ``x`` are context + target[:-1]; labels
    ``y`` are context + target[1:], so the next-token shift happens inside the
    target only.

    Returns:
        ``({"inputs_ids": x, "targets_ids": targets_ids}, targets_ids)``, where
        ``targets_ids`` stacks ``y`` and a 0/1 loss mask along the last axis.
    """
    context_tokens, target_tokens = tokenizer.tokenize_sample(context, target)

    # 1 for every real target token, 0 for padding. It is padded/sliced with the
    # same helper as the target, so it stays aligned with it position by position.
    target_is_token = tokenizer.pad_or_slice(
        tf.ones_like(target_tokens), context_length + 1, pad_value=0
    )

    context_tokens = tokenizer.pad_or_slice(context_tokens, context_length)
    target_tokens = tokenizer.pad_or_slice(target_tokens, context_length + 1)

    x = tf.concat([context_tokens, target_tokens[:-1]], 0)
    y = tf.concat([context_tokens, target_tokens[1:]], 0)

    # Loss mask: context positions never count. A target position counts when its
    # label y is a real token. Comparing x != y instead would also drop labels that
    # repeat the previous token, and would train on the stop -> pad transition.
    mask = tf.concat(
        [
            tf.zeros(tf.shape(context_tokens), dtype=tf.int32),
            tf.cast(target_is_token[1:], tf.int32),
        ],
        0,
    )
    targets_ids = tf.stack([y, mask], 1)

    return {"inputs_ids": x, "targets_ids": targets_ids}, targets_ids

In [None]:
def build_dataset_iterator(
    prepared_data,
    is_training: bool,
    batch_size: int,
    shuffle_buffer_size: int = 2048,
) -> tf.data.Dataset:
    """Create a batched, prefetched ``tf.data`` pipeline of LM training examples.

    Args:
        prepared_data: sequence of ``(context, target)`` string pairs.
        is_training: if True, repeat forever and shuffle.
        batch_size: examples per batch (the last batch may be smaller).
        shuffle_buffer_size: size of the shuffle buffer (training only).

    Returns:
        Dataset of ``(features, targets_ids)`` batches produced by
        ``prepare_lm_targets``.
    """
    dataset = tf.data.Dataset.from_tensor_slices(prepared_data)

    if is_training:
        dataset = dataset.repeat(-1)
        # Shuffle the raw string pairs before tokenization, so the buffer holds
        # small strings rather than fully tokenized examples.
        dataset = dataset.shuffle(shuffle_buffer_size)

    dataset = dataset.map(
        lambda x: prepare_lm_targets(x[0], x[1]), num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset

In [None]:
batch_size = 8
train_dataset = build_dataset_iterator(prepared_train_data, is_training=True, batch_size=batch_size)
train_dataset

In [None]:
validation_dataset = build_dataset_iterator(prepared_validation_data, is_training=False, batch_size=batch_size)
validation_dataset

In [None]:
for x, y in train_dataset:
    break

x = x["inputs_ids"]
pad_str = tokenizer.pad_token_str
print(tokenizer.detokenize(x[0, :context_length]).numpy().decode("utf-8").replace(pad_str, ""))
print(tokenizer.detokenize(y[0, context_length:, 0]).numpy().decode("utf-8").replace(pad_str, ""))

# Build model

In [None]:
from transformers import TFGPT2Model
from tf_nano_gpt.model import freeze_embeddings, freeze_layers
from tf_nano_gpt.metrics import masked_lm_loss, masked_accuracy

base_model = TFGPT2Model.from_pretrained('gpt2')

In [None]:
freeze_embeddings(base_model)
freeze_layers(base_model, num_blocks_to_freeze=8, use_lora=False)

In [None]:
def model_predict_fn(
    inputs_ids: tf.Tensor,
    past_key_values: tf.Tensor = None,
    position_ids: tf.Tensor = None,
):
    """Run ``base_model`` and project its hidden states to vocabulary logits.

    Args:
        inputs_ids: token ids fed to the GPT-2 body. Every position is attended
            (the attention mask is all ones).
        past_key_values: cache from a previous call, or None. Despite the
            ``tf.Tensor`` annotation, the cached exporter passes the per-block list
            taken from ``output.past_key_values``.
        position_ids: explicit position ids, or None (left to ``base_model``).

    Returns:
        ``(logits, past_key_values)``. The logits come from applying dropout and
        then the token embedding ``wte`` in ``"linear"`` mode. The second element
        is the cache returned by ``base_model``.
    """
    model_output = base_model(
        {
            "input_ids": inputs_ids,
            "attention_mask": tf.ones_like(inputs_ids),
            "past_key_values": past_key_values,
            "position_ids": position_ids,
        }
    )
    hidden = tf.keras.layers.Dropout(0.2)(model_output.last_hidden_state)
    vocab_logits = base_model.transformer.wte(hidden, mode="linear")
    return vocab_logits, model_output.past_key_values


inputs_ids = tf.keras.layers.Input(
    shape=(sequence_length,), dtype=tf.int32, name="inputs_ids"
)
targets_ids = tf.keras.layers.Input(
    shape=(sequence_length, 2), dtype=tf.int32, name="targets_ids"
)

logits, _ = model_predict_fn(inputs_ids)

loss_value = masked_lm_loss(targets_ids, logits)
accuracy_value = masked_accuracy(targets_ids, tf.argmax(logits, -1))

train_model = tf.keras.Model(inputs=[inputs_ids, targets_ids], outputs=logits)

train_model.add_loss(loss_value)
train_model.add_metric(accuracy_value, name="accuracy")

train_model.summary(180)

In [None]:
inference_model = tf.keras.Model(inputs=inputs_ids, outputs=logits)

# Train the model

In [None]:
epochs = 10
steps_per_epoch = 1000
save_dir = "models/test-gpt-2-model-v1"


In [None]:

decay_steps = steps_per_epoch * epochs
# `validation_dataset` is already batched, so its cardinality is the number of
# batches. Dividing by batch_size again would evaluate only 1/batch_size of it.
validation_steps = validation_dataset.cardinality().numpy()
validation_steps

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
train_model.compile(optimizer, jit_compile=True)

In [None]:
from pathlib import Path

lr_scheduler = tf.keras.optimizers.schedules.CosineDecay(0.0001, epochs, alpha=0.01)

def lr_schedule(epoch):
    """Return ``lr_scheduler(epoch)`` after logging it as the 'learning rate' summary scalar."""
    rate = lr_scheduler(epoch)
    tf.summary.scalar('learning rate', data=rate, step=epoch)
    return rate


callbacks = [
    tf.keras.callbacks.TensorBoard(
        log_dir=Path(save_dir) / "logs",
        histogram_freq=0,
        embeddings_freq=0,
        update_freq="epoch",
        write_steps_per_second=True,
        profile_batch=(200, 250),
        write_graph=False,
    ),
    tf.keras.callbacks.LearningRateScheduler(lr_schedule),
    tf.keras.callbacks.ModelCheckpoint(
        Path(save_dir) / "model",
        monitor = "val_accuracy",
        verbose = 1,
        save_best_only = True,
        save_weights_only = True,
        mode = "auto",
    ),
]

In [None]:
# train_model.load_weights(Path(save_dir) / "model")

In [None]:
for x, y in train_dataset:
    break

y_pred = train_model(x)
y_pred.shape

In [None]:
train_model.fit(
    train_dataset,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_dataset,
    validation_steps=validation_steps,
    verbose=1,
    callbacks=callbacks,
    epochs=epochs,
)

# Evaluate model

In [None]:
validation_dataset_iter = iter(validation_dataset)

In [None]:
def greedy_predict_next_token(inputs_ids: tf.Tensor) -> tf.Tensor:
    """Write one greedily decoded token right after the last non-zero id of each row.

    Args:
        inputs_ids: int32 tensor ``(batch, maxlen)``. Positions not yet generated
            hold 0.

    Returns:
        A copy of ``inputs_ids`` where each row gets the argmax token of
        ``inference_model``'s logits at its last non-zero position. The token is
        written one column later, clamped to the final column.
    """
    num_sentences, maxlen = tf.shape(inputs_ids)[0], tf.shape(inputs_ids)[1]

    # Index of the last non-zero id per row (-1 for an all-zero row). Unlike
    # counting non-zero ids, this stays correct when id 0 also occurs inside
    # the prompt.
    is_filled = tf.cast(inputs_ids != 0, tf.int32)
    positions = tf.range(maxlen, dtype=tf.int32)[None, :]
    current_index = tf.reduce_max(is_filled * positions + (is_filled - 1), axis=-1)

    logits = inference_model(inputs_ids)
    next_logits = tf.gather(logits, current_index, batch_dims=1)
    sampled_indices = tf.argmax(next_logits, axis=-1, output_type=tf.int32)

    next_index = tf.minimum(current_index + 1, maxlen - 1)
    scatter_indices = tf.stack([tf.range(num_sentences), next_index], axis=1)

    # Overwrite instead of add, so a clamped write at the last column cannot sum
    # two token ids into an invalid one.
    return tf.tensor_scatter_nd_update(inputs_ids, scatter_indices, sampled_indices)

In [None]:
x, y = next(validation_dataset_iter)
x = x['inputs_ids']

In [None]:
idx = 0
x_test = tf.concat([x[:, :context_length + 1], tf.zeros_like(x)[:, context_length + 1:]], -1)
context, target = tokenizer.detokenize(x)[idx].numpy().decode().split(tokenizer.start_token_str)
target_text = tokenizer.detokenize(y[:, context_length:, 0])[idx].numpy().decode()

print(context.replace(tokenizer.pad_token_str, ""))
print(target_text.replace(tokenizer.pad_token_str, ""))

In [None]:
stop_token = tokenizer.stop_token_id
# The prompt occupies positions 0..context_length, so only this many positions are free.
max_new_tokens = sequence_length - context_length - 1
for i in range(max_new_tokens):
    x_test = greedy_predict_next_token(x_test)
    # A row is finished once a stop token appears after the prompt. Use the real
    # row count (not batch_size) so a short final batch can also stop early.
    finished = tf.reduce_any(x_test[:, context_length:] == stop_token, axis=-1)
    if bool(tf.reduce_all(finished).numpy()):
        break

In [None]:
result = tokenizer.detokenize(x_test[:, context_length + 1:])[idx].numpy().decode().replace("!", "")
print(result)

# Sampling using keras_nlp functions

In [None]:
import keras_nlp


@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.int32)])
def token_probability_fn(inputs):
    """Return the next-token logits at the last column of ``inputs``.

    ``inputs`` may be shorter than ``sequence_length``. Each row is padded (or
    sliced) to ``sequence_length`` with ``tokenizer.pad_or_slice`` before calling
    ``inference_model``. The logits are then read at the original last column.
    """
    prefix_length = tf.shape(inputs)[1]

    def _pad_row(row):
        return tokenizer.pad_or_slice(row, sequence_length)

    padded_inputs = tf.map_fn(_pad_row, inputs)

    all_logits = inference_model(padded_inputs)
    return all_logits[:, prefix_length - 1, :]

In [None]:
prompt = x[idx, : context_length + 1][None, :]

predicted_tokens = keras_nlp.utils.top_k_search(
    token_probability_fn,
    prompt,
    max_length=sequence_length,
    end_token_id=tokenizer.stop_token_id,
    from_logits=True,
    k=5,
)

if len(predicted_tokens.shape) == 1:
    predicted_tokens = predicted_tokens[None, :]
predicted_tokens.shape

In [None]:
result = tokenizer.detokenize(predicted_tokens[:, context_length + 1:])[0].numpy().decode().replace("!", "")
print(result)

# Export model - base method

In [None]:
class GPT2Exporter(tf.Module):
    """Exportable greedy decoder: raw text in, predicted entity lines out.

    ``__call__`` takes a scalar string and builds a ``(1, sequence_length)`` prompt:
    the context padded/sliced to ``context_length``, then the start token, then
    zeros. It generates tokens one at a time with ``predict_next_token_fn`` and
    returns the detokenized text after the start token under the ``"outputs"`` key.

    NOTE(review): decoding goes through the module-level
    ``greedy_predict_next_token``, which calls the global ``inference_model``.
    ``model`` is only stored as an attribute, presumably so ``tf.Module`` tracks
    its variables for ``tf.saved_model.save``.
    """

    def __init__(
        self, model: tf.keras.Model, tokenizer: GPT2Tokenizer, jit_compile: bool = None
    ):
        super(GPT2Exporter, self).__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.jit_compile = jit_compile
        self.predict_next_token_fn = greedy_predict_next_token

        if jit_compile:
            self.predict_next_token_fn = tf.function(
                greedy_predict_next_token, jit_compile=True
            )

    def prepare_inputs(self, text: str):
        """Tokenize ``text`` into a ``(1, sequence_length)`` prompt ending in the start token."""
        input_ids = self.tokenizer.tokenize(text)
        input_ids = self.tokenizer.pad_or_slice(input_ids, context_length)
        # Append the start token, then fill the generation area with 0 so
        # greedy_predict_next_token treats those positions as empty.
        input_ids = self.tokenizer.pad_or_slice(
            tf.concat([input_ids, [self.tokenizer.start_token_id]], 0),
            sequence_length,
            pad_value=0,
        )

        input_ids = tf.reshape(input_ids, [1, -1])

        return input_ids

    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def __call__(self, text):
        input_ids = self.prepare_inputs(text)

        # The prompt fills positions 0..context_length, so at most this many
        # tokens fit. Generating more would overwrite the clamped last column.
        max_new_tokens = sequence_length - context_length - 1

        i = tf.constant(0)
        while i < max_new_tokens:
            i += 1
            input_ids = self.predict_next_token_fn(input_ids)

            completed = tf.reduce_any(input_ids == self.tokenizer.stop_token_id)

            if completed:
                break

        # Generated tokens occupy positions context_length+1 .. context_length+i.
        # The slice ends before position context_length+i-1, so it drops the final
        # generated token (the stop token when completed) and the one before it.
        prediction = self.tokenizer.detokenize(input_ids[:, : context_length + i - 1])
        prediction = tf.strings.split(prediction, self.tokenizer.start_token_str)[0, 1]

        return {"outputs": prediction}


gpt2_predictor = GPT2Exporter(inference_model, tokenizer)
gpt2_predictor_jit = GPT2Exporter(inference_model, tokenizer, jit_compile=True)

In [None]:
prompt = "Headingley is a suburb of Leeds, West Yorkshire, England, approximately two miles out of the city centre, to the north west along the A660 road. Headingley is the location of the Beckett Park campus of Leeds Beckett University and Headingley Stadium."

prediction = gpt2_predictor(prompt)
prediction['outputs'].numpy().decode().split("\n")

In [None]:
prediction = gpt2_predictor_jit(prompt)
prediction['outputs'].numpy().decode().split("\n")

In [None]:
%timeit gpt2_predictor(prompt)

In [None]:
%timeit gpt2_predictor_jit(prompt)

In [None]:
tf.saved_model.save(gpt2_predictor, Path(save_dir) / "exported-models/gpt2-ner/1/")
tf.saved_model.save(gpt2_predictor_jit, Path(save_dir) / "exported-models/gpt2-ner-jit/1/")

# Export model with cached keys and values
This export only works for GPT-2 small, which has 12 transformer blocks: the cached exporter hard-codes 12 key/value caches.

In [None]:
def sample_argmax(logits):
    """Greedy "sampling": the int32 index of the largest logit along the last axis."""
    best_ids = tf.math.argmax(input=logits, axis=-1, output_type=tf.int32)
    return best_ids


class CachedGPT2Exporter(GPT2Exporter):
    """Greedy exporter that reuses GPT-2 key/value caches between decoding steps.

    The prompt (context + start token) is run once through ``model_predict_fn``.
    Each later step feeds only the newest token plus the cached
    ``past_key_values``. The loop unrolls exactly 12 caches with shape invariant
    ``[2, 1, 12, None, 64]``, so this class only works for GPT-2 small.
    """

    def __init__(
        self, model: tf.keras.Model, tokenizer: GPT2Tokenizer, jit_compile: bool = False
    ):
        # The parent's greedy step is replaced below, so never jit-compile it there.
        super(CachedGPT2Exporter, self).__init__(model, tokenizer, jit_compile=False)
        self.predict_next_token_fn = model_predict_fn
        if jit_compile:
            self.predict_next_token_fn = tf.function(model_predict_fn, jit_compile=True)

    @tf.function(input_signature=[tf.TensorSpec([], tf.string)])
    def __call__(self, text):
        input_ids = self.prepare_inputs(text)
        # Keep only the context and the start token; new tokens come from the cache loop.
        input_ids = input_ids[:, : context_length + 1]
        prompt_length = context_length + 1

        logits, past_key_values = self.predict_next_token_fn(input_ids)
        new_input_ids = tf.expand_dims(sample_argmax(logits[:, -1, :]), axis=-1)

        states = tf.TensorArray(tf.int32, size=context_length)
        states = states.write(0, new_input_ids)

        # For autograph to compile this loop, each loop variable must be spelled out
        # explicitly, because set_loop_options does not work with Python lists.
        # That is why this function only works with the GPT-2 small model.
        kv0 = past_key_values[0]
        kv1 = past_key_values[1]
        kv2 = past_key_values[2]
        kv3 = past_key_values[3]
        kv4 = past_key_values[4]
        kv5 = past_key_values[5]
        kv6 = past_key_values[6]
        kv7 = past_key_values[7]
        kv8 = past_key_values[8]
        kv9 = past_key_values[9]
        kv10 = past_key_values[10]
        kv11 = past_key_values[11]

        for i in tf.range(context_length - 1):
            tf.autograph.experimental.set_loop_options(
                shape_invariants=[
                    (kv0, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv1, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv2, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv3, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv4, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv5, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv6, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv7, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv8, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv9, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv10, tf.TensorShape([2, 1, 12, None, 64])),
                    (kv11, tf.TensorShape([2, 1, 12, None, 64])),
                ]
            )
            past_key_values = [
                kv0,
                kv1,
                kv2,
                kv3,
                kv4,
                kv5,
                kv6,
                kv7,
                kv8,
                kv9,
                kv10,
                kv11,
            ]

            # The token fed at step i is generated token number i. The prompt filled
            # positions 0..context_length, so this token sits at absolute position
            # prompt_length + i, which is also the current cache length.
            position = i + prompt_length
            position_ids = tf.expand_dims(
                tf.range(position, position + 1), axis=0
            )

            logits, new_past_key_values = self.predict_next_token_fn(
                inputs_ids=new_input_ids,
                past_key_values=past_key_values,
                position_ids=position_ids,
            )
            new_input_ids = tf.expand_dims(sample_argmax(logits[:, -1, :]), axis=-1)
            states = states.write(i + 1, new_input_ids)

            kv0 = new_past_key_values[0]
            kv1 = new_past_key_values[1]
            kv2 = new_past_key_values[2]
            kv3 = new_past_key_values[3]
            kv4 = new_past_key_values[4]
            kv5 = new_past_key_values[5]
            kv6 = new_past_key_values[6]
            kv7 = new_past_key_values[7]
            kv8 = new_past_key_values[8]
            kv9 = new_past_key_values[9]
            kv10 = new_past_key_values[10]
            kv11 = new_past_key_values[11]

            completed = tf.reduce_any(new_input_ids == self.tokenizer.stop_token_id)

            if completed:
                break

        # states[0..i+1] were written. Keeping [:i] drops the last written token
        # (the stop token when completed) and the one before it.
        input_ids = tf.reshape(states.stack(), [1, -1])[:, :i]

        prediction = self.tokenizer.detokenize(input_ids)

        return {"outputs": prediction[0]}


gpt2_cached_predictor = CachedGPT2Exporter(
    inference_model, tokenizer, jit_compile=False
)
gpt2_cached_predictor_jit = CachedGPT2Exporter(
    inference_model, tokenizer, jit_compile=True
)

In [None]:
prediction = gpt2_cached_predictor(prompt)
prediction['outputs'].numpy().decode().replace("!", "").split("\n")

In [None]:
prediction = gpt2_cached_predictor_jit(tf.constant(prompt))
prediction['outputs'].numpy().decode().replace("!", "").split("\n")

In [None]:
%timeit gpt2_cached_predictor(tf.constant(prompt))

In [None]:
%timeit gpt2_cached_predictor_jit(tf.constant(prompt))

In [None]:
tf.saved_model.save(gpt2_cached_predictor, Path(save_dir) / "exported-models/gpt2-ner-cached/1/")
tf.saved_model.save(gpt2_cached_predictor_jit, Path(save_dir) / "exported-models/gpt2-ner-cached-jit/1/")

In [None]:
# Load the cached-JIT export written above into `save_dir`, rather than a
# hard-coded directory from a different run.
loaded_model = tf.saved_model.load(
    str(Path(save_dir) / "exported-models/gpt2-ner-cached-jit/1")
)
loaded_model.signatures

# Test TensorFlow Serving

In [None]:
from pathlib import Path


cwd = Path.cwd()
# Serve the model exported above from `save_dir` (set in the training config cell)
# instead of re-binding `save_dir` to a different directory.
export_dir = cwd / save_dir / "exported-models" / "gpt2-ner"

run_serving_cmd = f"docker run -p 8501:8501 --rm --gpus all --name tfserving_models --mount type=bind,source={export_dir},target=/models/model -e MODEL_NAME=model -t tensorflow/serving:2.11.1-gpu"
print(run_serving_cmd)

In [None]:
import requests
import json

context = "Headingley is a suburb of Leeds, West Yorkshire, England, approximately two miles out of the city centre, to the north west along the A660 road. Headingley is the location of the Beckett Park campus of Leeds Beckett University and Headingley Stadium."

prediction_url = "http://localhost:8501/v1/models/model:predict"
post_data = {"inputs": {"text": context}}
# Bound the wait so a stalled server does not hang the kernel.
response = requests.post(prediction_url, data=json.dumps(post_data), timeout=60)
# Surface serving errors directly instead of a confusing KeyError on "outputs".
response.raise_for_status()
prediction = response.json()["outputs"]
prediction.split("\n")