## Install dependencies

In [1]:
!pip install  keras-hub



In [2]:
!pip install -U tensorboard-plugin-profile

Collecting tensorboard-plugin-profile
  Downloading tensorboard_plugin_profile-2.19.0-cp311-none-manylinux2014_x86_64.whl.metadata (5.0 kB)
Collecting gviz-api>=1.9.0 (from tensorboard-plugin-profile)
  Downloading gviz_api-1.10.0-py2.py3-none-any.whl.metadata (2.6 kB)
Downloading tensorboard_plugin_profile-2.19.0-cp311-none-manylinux2014_x86_64.whl (25.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m25.8/25.8 MB[0m [31m72.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gviz_api-1.10.0-py2.py3-none-any.whl (13 kB)
Installing collected packages: gviz-api, tensorboard-plugin-profile
Successfully installed gviz-api-1.10.0 tensorboard-plugin-profile-2.19.0


## Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from datetime import datetime

In [2]:
import keras
import keras_hub

In [3]:
import json
import re
import string
from IPython.display import display, HTML
import warnings
warnings.filterwarnings("ignore")

In [4]:
import tensorflow as tf
import tensorflow.data as tf_data
import tensorflow.strings as tf_strings

In [5]:
from keras import layers, models, losses, callbacks
from keras.layers import Dense, Layer, Dropout
from keras.ops import softmax

In [6]:
from tensorflow import math, matmul, reshape, shape, transpose, cast, float32

## Set Configuration

In [7]:
gpus = tf.config.list_physical_devices('GPU')
num_gpus = len(gpus)

print("Num GPUs Available: ", num_gpus)

Num GPUs Available:  2


In [8]:
VOCAB_SIZE = 5000 # 50000 is the size of the vocabulary for gpt2 on common corpus. Simplebooks is smaller
MAX_LEN = 1024 # the maximum length of the input sequences,
EMBEDDING_DIM = 768 # the dimension of the word embeddings
KEY_DIM = 64 # the dimension of the keys in the attention mechanism
N_HEADS = 3 #  the number of attention heads
NUM_BLOCKS = 3 # Number of transformer blocks 12 in gpt2, 3 in gpt nano
FEED_FORWARD_DIM = 4*EMBEDDING_DIM #  the dimension of the feed-forward network in the block
VALIDATION_SPLIT = 0.2 # the fraction of data to be used for validation
SEED = 42 # the random seed for reproducibility

In [9]:
MIN_STRING_LEN = 256  # Strings shorter than this will be discarded
SEQ_LEN = 128 # 512  # Length of training sequences, in tokens

In [10]:
BATCH_SIZE = 256*num_gpus  # Batch size for training.

## Data ingestion and preprocessing

In [11]:
BASE_DIR="/kaggle/working"

In [12]:
keras.utils.get_file(
    origin="https://dldata-public.s3.us-east-2.amazonaws.com/simplebooks.zip",
    extract=True,
    cache_subdir=BASE_DIR+"/keras/",
)


'/kaggle/working/keras/simplebooks.zip'

In [12]:
dir = BASE_DIR+"/keras/simplebooks/"

In [13]:
# Load simplebooks-92 train set and filter out short lines.
raw_train_ds = (
    tf_data.TextLineDataset(dir + "simplebooks-92-raw/train.txt")
    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE, drop_remainder=True)
    .shuffle(buffer_size=256)
)

# Load simplebooks-92 validation set and filter out short lines.
raw_val_ds = (
    tf_data.TextLineDataset(dir + "simplebooks-92-raw/valid.txt")
    .filter(lambda x: tf_strings.length(x) > MIN_STRING_LEN)
    .batch(BATCH_SIZE, drop_remainder=True)
)

In [14]:
raw_val_ds

<_BatchDataset element_spec=TensorSpec(shape=(512,), dtype=tf.string, name=None)>

In [15]:
for element in raw_val_ds.take(1):
    print(element)

tf.Tensor(
[b'"I am glad of it," said a woolly Lamb on Wheels, who stood on the floor, just under the edge of the toy counter. She was rather too large to be up among the smaller toys. "Yes, I am glad of it," went on the Lamb. "I have kept still all day, and now I have something to tell you all, my friends."'
 b'For it was one of the rules of Toyland, as you know, that none of the folk who lived there could do anything while human eyes were watching them. The Dolls, Soldiers, Clowns, Rocking Horses, Lambs were not able to move, talk, or make believe come to life if a boy or a girl or any one at all looked at them.'
 b'"Yes, you are always ready to jump out of your box as soon as the cover is taken off," remarked the Lamb on Wheels. "But the rest of us are not such high kickers as you are. I cannot jump at all. I can only run around on my wheels, just as the White Rocking Horse, who used to live here, could only go on his rockers."'
 b'"Do you mean the Sawdust Doll who used to live here

In [16]:
#if there is no val_ds
#raw_train_ds, raw_val_ds= keras.utils.split_dataset(raw_train_ds, 0.8)

## Tokenization

In [24]:
# Train tokenizer vocabulary
vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
    raw_train_ds,
    vocabulary_size=VOCAB_SIZE,
    lowercase=True,
    reserved_tokens=["[PAD]", "[UNK]", "[BOS]"],
)

### Save vocabulary

In [25]:
import pickle
with open(BASE_DIR+'/simplebooks_vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)

### Retrieve vocabulary

In [17]:
#insert code to save and retrieve tokenizer
import pickle
with open('/kaggle/input/simplebooks-vocabulary/simplebooks_vocab.pkl', 'rb') as f:
    vocab = pickle.load(f)

In [18]:
vocab[100:120]

['that',
 'it',
 'had',
 '##s',
 'his',
 'as',
 'for',
 'with',
 'they',
 'on',
 'but',
 'her',
 'at',
 'she',
 'were',
 'not',
 'you',
 'be',
 'him',
 'all']

In [19]:
tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=vocab,
    sequence_length=SEQ_LEN,
    lowercase=True,
)

In [20]:
# packer adds a start token
start_packer = keras_hub.layers.StartEndPacker(
    sequence_length=SEQ_LEN,
    start_value=tokenizer.token_to_id("[BOS]"),
)

def preprocess(inputs):
    outputs = tokenizer(inputs)
    features = start_packer(outputs)
    labels = outputs
    return features, labels


# Tokenize and split into train and label sequences.
train_ds = raw_train_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
    tf_data.AUTOTUNE
)
val_ds = raw_val_ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(
    tf_data.AUTOTUNE
)

## Define model

### Attention

In [21]:
@keras.saving.register_keras_serializable()
def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    i = tf.range(n_dest)[:, None]
    j = tf.range(n_src)
    m = i >= j - n_src + n_dest
    mask = tf.cast(m, dtype)
    mask = tf.reshape(mask, [1, n_dest, n_src])
    mult = tf.concat(
        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
    )
    return tf.tile(mask, mult)
np.transpose(causal_attention_mask(1, 10, 10, dtype=tf.int32)[0])

array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=int32)

## LLama additions

In [22]:
## RMS normalization
@keras.saving.register_keras_serializable()
class RMSNorm(Layer):
    def __init__(self, epsilon=1e-6, **kwargs):
        super(RMSNorm, self).__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        self.gamma = self.add_weight(
            name='gamma',
            shape=(input_shape[-1],),
            initializer='ones',
            trainable=True
        )

    def call(self, inputs):
        rms = tf.sqrt(tf.reduce_mean(tf.square(inputs), axis=-1, keepdims=True) + self.epsilon)
        return inputs * self.gamma / rms


In [23]:
## SwiGLU Activation
@keras.saving.register_keras_serializable()
class SwiGLU(Layer):
    def __init__(self, bias=True, dim=-1, **kwargs):
        """
        SwiGLU Activation Layer
        """
        super(SwiGLU, self).__init__(**kwargs)
        self.bias = bias
        self.dim = dim
        self.dense = Dense(2, use_bias=bias)

    def call(self, x):
        out, gate = tf.split(x, num_or_size_splits=2, axis=self.dim)
        gate = keras.activations.swish(gate)
        x = tf.multiply(out, gate)
        return x

### Transformer Decoder Block

In [24]:
@keras.saving.register_keras_serializable()
class TransformerBlock(Layer):
    def __init__(self, num_heads, key_dim, embed_dim, ff_dim, dropout_rate=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.embed_dim = embed_dim
        self.ff_dim = ff_dim
        self.dropout_rate = dropout_rate

        self.attn = layers.MultiHeadAttention(num_heads, key_dim, embed_dim)
        self.dropout_1 = Dropout(self.dropout_rate, seed=SEED)
        self.ln_1 = RMSNorm(epsilon=1e-6)
        self.ffn_1 = Dense(self.ff_dim, use_bias=False)
        self.swiglu=SwiGLU()
        self.ffn_2 = Dense(self.embed_dim, use_bias=False)
        self.dropout_2 = Dropout(self.dropout_rate, seed=SEED)
        self.ln_2 = RMSNorm(epsilon=1e-6)

    def call(self, inputs):
      input_shape = tf.shape(inputs)
      batch_size = input_shape[0]
      seq_len = input_shape[1]
      causal_mask = causal_attention_mask(
          batch_size, seq_len, seq_len, tf.bool
      )
      inputs_n=self.ln_1(inputs)
      attention_output = self.attn(
          query=inputs_n,
          value=inputs_n,
          attention_mask=causal_mask,
      )
      attention_output = self.dropout_1(attention_output)
      out1 = self.ln_1(inputs + attention_output)
      ffn_1 = self.ffn_1(out1)
      swiglu_1 = self.swiglu(ffn_1)
      ffn_2 = self.ffn_2(swiglu_1)
      ffn_output = self.dropout_2(ffn_2)
      return (out1 + ffn_output) #(self.ln_2(out1 + ffn_output), attention_scores)

    def get_config(self):
      config = super().get_config()
      config.update(
          {
              "key_dim": self.key_dim,
              "embed_dim": self.embed_dim,
              "num_heads": self.num_heads,
              "ff_dim": self.ff_dim,
              "dropout_rate": self.dropout_rate,
          }
      )
      return config

## Compile and Train

In [25]:
print (N_HEADS, NUM_BLOCKS)

3 3


In [26]:
strategy = tf.distribute.MirroredStrategy()


In [27]:
with strategy.scope():

    ## Optimizer with decaying weights
    optimizer=keras.optimizers.AdamW(learning_rate=1e-5*num_gpus,
                                  weight_decay=0.1,
                                  beta_1=0.9,
                                  beta_2=0.95,
                                  epsilon=1e-5,
                                  )

    ## Model definition
    inputs = keras.layers.Input(shape=(None,), dtype="int32")

    ## From Llama
    #x = TokenAndPositionEmbedding(MAX_LEN, VOCAB_SIZE, EMBEDDING_DIM)(inputs)
    x=layers.Embedding(input_dim=VOCAB_SIZE, output_dim=EMBEDDING_DIM )(inputs)
    x=keras_hub.layers.RotaryEmbedding()(x)

    for i in range(NUM_BLOCKS):
        x = TransformerBlock(
            N_HEADS, KEY_DIM, EMBEDDING_DIM, FEED_FORWARD_DIM, dropout_rate=0.2,
            name=f"transformer_block_{i}",
        )(x)

    x=RMSNorm(epsilon=1e-6)(x)
    outputs = Dense(VOCAB_SIZE,
                    use_bias=False,
                    #activation="softmax" #we'll use the logits
                    )(x)

    gpt = keras.Model(inputs=inputs, outputs=[outputs])

    gpt.compile(optimizer=optimizer,
            loss=[losses.SparseCategoricalCrossentropy(from_logits=True)],
            metrics=[keras_hub.metrics.Perplexity(from_logits=True, mask_token_id=0)],
            )


## Train

In [28]:
gpt.summary()

In [29]:
earlystop=keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0,
    patience=5,
    verbose=1,
    mode="min",
    baseline=None,
    restore_best_weights=True,
    start_from_epoch=0,
)

In [30]:
tensorboard=keras.callbacks.TensorBoard(
    log_dir=BASE_DIR+"/logs",
    histogram_freq=1,
    write_graph=True,
    write_images=False,
    write_steps_per_second=False,
    update_freq="epoch",
    profile_batch=0,
    embeddings_freq=1,
    embeddings_metadata=None,
)

In [31]:
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=BASE_DIR+"/models/checkpoints/gpu/llama2-"+datetime.now().strftime("%Y%m%d-%H%M%S")+".keras",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_freq="epoch",
    verbose=1)

In [None]:
%%time
history=gpt.fit(train_ds,
                validation_data=val_ds,
                epochs=50,
                callbacks=[earlystop, tensorboard, checkpoint],
                verbose=1)

Epoch 1/50
    982/Unknown [1m2150s[0m 2s/step - loss: 5.2200 - perplexity: 771.2275
Epoch 1: val_loss improved from inf to 3.26428, saving model to /kaggle/working/models/checkpoints/gpu/llama2-20250304-193730.keras
[1m982/982[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2160s[0m 2s/step - loss: 5.2195 - perplexity: 770.7723 - val_loss: 3.2643 - val_perplexity: 144.7335
Epoch 2/50
[1m982/982[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - loss: 4.1415 - perplexity: 165.0102
Epoch 2: val_loss improved from 3.26428 to 3.09866, saving model to /kaggle/working/models/checkpoints/gpu/llama2-20250304-193730.keras
[1m982/982[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2157s[0m 2s/step - loss: 4.1414 - perplexity: 164.9978 - val_loss: 3.0987 - val_perplexity: 112.6069
Epoch 3/50
[1m982/982[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - loss: 3.9442 - perplexity: 129.7001
Epoch 3: val_loss improved from 3.09866 to 2.97508, saving model to /kaggle/w

In [None]:
gpt.evaluate(val_ds)

## Plot losses

In [None]:
plt.plot(history.history['perplexity'])
plt.plot(history.history['val_perplexity'])

In [None]:
# Load the TensorBoard notebook extension if available
%load_ext tensorboard

In [None]:
%tensorboard --logdir $BASE_DIR/logs

## Inference with pre-trained model

In [None]:
prompt_tokens = start_packer(tokenizer([""]))
prompt_tokens

<tf.Tensor: shape=(1, 128), dtype=int32, numpy=
array([[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int32)>

In [None]:
def next(prompt, cache, index):
    logits = gpt(prompt)[:, index - 1, :]
    # Ignore hidden states for now
    hidden_states = None
    return logits, hidden_states, cache

### Test different kinds of samplers

In [None]:
sampler = keras_hub.samplers.GreedySampler()
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,  # Start sampling immediately after the [BOS] token.
)
txt = tokenizer.detokenize(output_tokens)
print(f"Greedy search generated text: \n{txt}\n")

Greedy search generated text: 
['[BOS] " i \' m going to have a good time , " said the old man , " and i \' ll have to go to the old house and get a good dinner . i \' ll have a good dinner , and i \' ll have a good dinner . i \' ll have a good dinner , and i \' ll have a good dinner . " [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']



In [None]:
sampler = keras_hub.samplers.BeamSampler(num_beams=10)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Beam search generated text: \n{txt}\n")

Beam search generated text: 
['[BOS] " i \' ll tell you what i \' ll do , " he said . " i \' ll tell you what i \' ll do . i \' ll tell you what i \' ll do . i \' ll tell you what i \' ll do . i \' ll tell you what i \' ll do . " [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']



In [None]:
sampler = keras_hub.samplers.RandomSampler()
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Random search generated text: \n{txt}\n")

Random search generated text: 
['[BOS] after this the forenoon began to play " states " in reciting the printed car . its owner met its plans . it got more dampier than the lastar tank searching for the createdchargent , and had to begin the fundament office . as it was reminded of the accident called , it was a type the envelope , the farmer had seen in it in the days of the big city . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']



In [None]:
sampler = keras_hub.samplers.TopKSampler(k=5, temperature=1.2)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-K search generated text: \n{txt}\n")

Top-K search generated text: 
['[BOS] " i think i have been a fool , " he said . " if you will come here in a day or two , and you can go to a house where there will be plenty of room for you . i am going to ask you to go out , if you will . " [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']



In [None]:
sampler = keras_hub.samplers.TopPSampler(p=0.8, k=5, temperature=1.2)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=1,
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-P search generated text: \n{txt}\n")

Top-P search generated text: 
["[BOS] the next day a large body of horsemen , riding on the ground , came to the edge of the plain , where they were riding , rode on . they were riding at a little distance from the road , the horse , a horseman and rider , riding in an avenue . the riders and the horse rode at a little distance , riding a long way , until they were within sight of the horse ' s riding horse , and rode up to the horse , riding at full speed . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]"]



### Optional: keep training with callbacks

In [None]:
class TopKTextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model using top-k."""

    def __init__(self, k, temperature):
        self.sampler = keras_hub.samplers.TopKSampler(k=k, temperature=temperature)

    def on_epoch_end(self, epoch, logs=None):
        output_tokens = self.sampler(
            next=next,
            prompt=prompt_tokens,
            index=1,
        )
        txt = tokenizer.detokenize(output_tokens)
        print(f"\nTop-K search generated text: \n{txt}\n")


text_generation_callback = TopKTextGenerator(k=5, temperature=1.2)



In [None]:
class TopPTextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model using top-p."""

    def __init__(self, p, k, temperature):
        self.sampler = keras_hub.samplers.TopPSampler(p=p, k=k, temperature=temperature)

    def on_epoch_end(self, epoch, logs=None):
        output_tokens = self.sampler(
            next=next,
            prompt=prompt_tokens,
            index=1,
        )
        txt = tokenizer.detokenize(output_tokens)
        print(f"\nTop-P search generated text: \n{txt}\n")


text_generation_callback = TopPTextGenerator(p=0.8, k=5, temperature=1.2)

In [None]:
# Dummy training loop to demonstrate callback.
gpt.fit(train_ds.take(1), verbose=1, epochs=1, callbacks=[text_generation_callback])

In [None]:
# Training loop with callbacks
gpt.fit(train_ds, validation_data=val_ds, epochs=100, callbacks=[earlystop, tensorboard, text_generation_callback])

In [None]:
gpt.summary()

In [None]:
gpt.save(BASE_DIR+'/models/gpt2-simplebooks-pt.keras')

## Instruction tuning
Read and preprocess Q&A dataset

In [None]:
df = pd.read_parquet("hf://datasets/vicgalle/alpaca-gpt4/data/train-00000-of-00001-6ef3991c06080e14.parquet")
df.head()

Unnamed: 0,instruction,input,output,text
0,Give three tips for staying healthy.,,1. Eat a balanced and nutritious diet: Make su...,Below is an instruction that describes a task....
1,What are the three primary colors?,,"The three primary colors are red, blue, and ye...",Below is an instruction that describes a task....
2,Describe the structure of an atom.,,An atom is the basic building block of all mat...,Below is an instruction that describes a task....
3,How can we reduce air pollution?,,There are several ways to reduce air pollution...,Below is an instruction that describes a task....
4,Describe a time when you had to make a difficu...,,"As an AI assistant, I do not have my own perso...",Below is an instruction that describes a task....


In [None]:
ds=tf_data.Dataset.from_tensor_slices(df["text"]).batch(BATCH_SIZE, drop_remainder=True)

In [None]:
ds

<_BatchDataset element_spec=TensorSpec(shape=(256,), dtype=tf.string, name=None)>

In [None]:
ids=ds.map(preprocess, num_parallel_calls=tf_data.AUTOTUNE).prefetch(tf_data.AUTOTUNE)

In [None]:
train_ids, val_ids=keras.utils.split_dataset(ids, left_size=0.8, shuffle=True, seed=SEED)

In [None]:
for element in ds.take(1):
    print(element)

tf.Tensor(
[b'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\nGive three tips for staying healthy.\n\n### Response:\n1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body with the essential nutrients to function at its best and can help prevent chronic diseases.\n\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise or 75 minutes of vigorous exercise each week.\n\n3. Get enough sleep: Getting enough quality sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night.'
 b'Below is an instruction that describes a 

### Training loop

In [None]:
text_generation_callback = TopKTextGenerator(k=5, temperature=1.2)

In [None]:
# Training loop with callbacks
it_history= gpt.fit(train_ids,
                    validation_data=val_ids,
                    epochs=100,
                    callbacks=[earlystop,
                               #text_generation_callback, ## very slow so comment out
                               tensorboard],
                    verbose=1,
                   )

Epoch 1/100
[1m162/162[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 179ms/step - loss: 1.5873 - perplexity: 5.2882
Top-K search generated text: 
['[BOS] below is an instruction that describes a task . write a response that appropriately completes the request . # # # instruction : generate a list of 19020 # # # response : 1 ) 2022 . the number of 10 is 144 . the number of numbers is 60 . 35000 , 17822 . 20221 . [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]']

[1m162/162[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m45s[0m 280ms/step - loss: 1.5873 - perplexity: 5.2880 - val_loss: 1.5167 - val_perplexity: 4.9200
Epoch 2/100
[1m162/162[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 179ms/step - loss: 1.5493 - perplexity: 5.0824
Top-K search generated text: 
['[BOS] below is an instruction that describes a task . write a response that appropriately completes the requ

In [None]:
plt.plot(it_history.history['perplexity'])
plt.plot(it_history.history['val_perplexity'])

In [None]:
gpt.evaluate(val_ids)

In [None]:
gpt.save(BASE_DIR+'/models/gpt2-simplebooks-it.keras')

## Inference with I.T. model

In [None]:
prompt_tokens = start_packer(tokenizer(["""Below is an instruction that describes a task, paired with an input that provides further context.
Write a response that appropriately completes the request.\n\n
Instruction:\nWrite a list of 3 ingredients for a sandwich.\n\n
Input:\nThere should be no ham.\n\n
Response: """ ]))
prompt_tokens

In [None]:
np.where(prompt_tokens.numpy().flatten()==0)[0][0]

In [None]:
sampler = keras_hub.samplers.TopPSampler(k=10,
                                         p=0.95,
                                         seed=SEED,
                                         temperature=1.5)
output_tokens = sampler(
    next=next,
    prompt=prompt_tokens,
    index=np.where(prompt_tokens.numpy().flatten()==0)[0][0],
)
txt = tokenizer.detokenize(output_tokens)
print(f"Top-P search generated text: \n{txt}\n")