> <p><small><small>This Notebook is made available subject to the licence and terms set out in the <a href = "http://www.github.com/google-deepmind/ai-foundations">AI Research Foundations Github README file</a>.

# **Build Your Own Small Language Model, Lab 5: Training Your Own Small Language Model**

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_1/introduction_to_language_modeling_lab_5.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

In [None]:
# Packages used.
import tensorflow as tf
import pandas as pd

## Step 1: Load the dataset

To begin, you will load the dataset. This lab uses the [Africa Galore](https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json) dataset as in the previous lab.

In [None]:
africa_galore = pd.read_json('https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json')
train_dataset = africa_galore['description'].values
print('Training dataset contains', train_dataset.shape[0], 'paragraphs.')

## Step 2: Tokenize the text

At the end of the previous lab, you created a `SimpleWordTokenizer` class. Below, you will use it to tokenize the text dataset:

In [None]:
# Putting it all together.
class SimpleWordTokenizer:
    """A simple word tokenizer that can be initialized with texts
       or using a provided vocabulary list.

    The tokenizer splits the text sequence based on whitespace,
    using the `encode` method to convert the text into a sequence of indices
    and the `decode` method to convert indices back into text.

    Typical usage example:

        text = 'Hello there!'
        tokenizer = SimpleWordTokenizer(text)
        print(tokenizer.encode('Hello'))

    Attributes:
        texts: Input text dataset.
        vocab: A pre-defined vocabulary. Defaults to None. If None,
               the vocab is automatically inferred from the texts.
    """

    # Define constants.
    UNKNOWN_TOKEN = '<UNK>'
    PAD_TOKEN = '<PAD>'

    def __init__(self, texts: list[str], vocab: list[str] | None = None):
        """Initializes the tokenizer with texts or using a provided vocabulary.

        Args:
          texts: Input text dataset.
          vocab: A pre-defined vocabulary. Defaults to None. If None,
                the vocab is automatically inferred from the texts.
        """

        if vocab is None:
            # Build the vocab from scratch.
            if isinstance(texts, str):
              texts = [texts]

            # Convert text sequence to tokens.
            tokens = [token for text in texts
                      for token in self.split_text(text)]

            # Create a vocabulary comprising unique tokens.
            vocab = self.build_vocab(tokens)

            # Add the special pad and unknown tokens to the vocabulary list.
            # The pad token comes first, so its ID is 0.
            self.vocab = [self.PAD_TOKEN] + vocab + [self.UNKNOWN_TOKEN]

        else:
            self.vocab = vocab

        # Size of vocabulary.
        self.vocab_size = len(self.vocab)

        # Create token-to-index and index-to-token mappings.
        self.token_to_index = {token: index
                               for index, token in enumerate(self.vocab)}
        self.index_to_token = {index: token
                               for index, token in enumerate(self.vocab)}

        # Map the special tokens to their IDs.
        self.pad_token_id = self.token_to_index[self.PAD_TOKEN]
        self.unknown_token_id = self.token_to_index[self.UNKNOWN_TOKEN]

    def split_text(self, text: str) -> list[str]:
        """Splits a given text on whitespace into tokens."""
        return text.split(' ')

    def join_text(self, text_lists: list[str]) -> str:
        """Combines a list of tokens into a single string,
            with tokens separated by spaces.
        """
        return ' '.join(text_lists)

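    def build_vocab(self, tokens: list[str]) -> list[str]:
        """Creates a sorted vocabulary list from the set of unique tokens.

        Sorting makes the vocabulary, and therefore every token ID, the same
        each time the notebook is run, since set ordering can vary between runs.
        """
        return sorted(set(tokens))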

    def encode(self, text: str) -> list[int]:
        """Encodes a text sequence into a list of indices based on the
           vocabulary.

        Args:
            text: The input text to be encoded.

        Returns:
            list: A list of indices corresponding to the tokens in the
                  input text.
        """

        # Convert tokens into indexes.
        return [self.token_to_index.get(token,
                                        self.token_to_index[self.UNKNOWN_TOKEN])
                for token in self.split_text(text)]

    def decode(self, numbers: int | list[int]) -> str:
        """Decodes a list (or single index) of integers back into
        corresponding tokens from the vocabulary.

        Args:
            numbers: A single index or a list of indices to be
                     decoded into tokens.

        Returns:
            str: A string of decoded tokens corresponding to the input indices.
        """

        # If a single integer is passed, convert it into a list.
        if isinstance(numbers, int):
            numbers = [numbers]

        # Map indices to tokens.
        tokens = [self.index_to_token.get(number, self.UNKNOWN_TOKEN)
                  for number in numbers]

        # Join the decoded tokens into a single string.
        return self.join_text(tokens)

In [None]:
tokenizer = SimpleWordTokenizer(train_dataset)
encoded_tokens = [tokenizer.encode(text) for text in train_dataset]

Now, check the number of encoded paragraphs. It should match the number of paragraphs in the training dataset:

In [None]:
print(len(encoded_tokens))

Next, examine the first ten token IDs of the first tokenized paragraph in the training dataset:

In [None]:
encoded_tokens[0][:10]

## Step 3: Pad or truncate the tokens to the desired length


The padding token `'<PAD>'` ensures that all sequences have the same length. The paragraphs vary in length, but neural networks expect inputs of a uniform shape. Shorter paragraphs therefore have to be padded to a common length so that every input to the network has the same dimensions. The transformer model takes each paragraph as its context and learns the relationships between its tokens.
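Here is a minimal pure-Python sketch of padding and truncation that makes the idea concrete. It assumes the pad token has ID 0, as it does for `SimpleWordTokenizer`, which places `'<PAD>'` first in the vocabulary. It also roughly mimics the `padding` and `truncating` options of the Keras `pad_sequences` function used in this step. The helper name `pad_or_truncate` and the token IDs are hypothetical.

```python
def pad_or_truncate(sequences, maxlen, pad_id=0, padding="post", truncating="post"):
    """Returns a new list in which every sequence has exactly `maxlen` IDs.

    A plain-Python sketch of the behaviour of Keras' `pad_sequences`.
    """
    result = []
    for seq in sequences:
        seq = list(seq)
        # Truncate sequences that are too long ('post' drops tokens at the end,
        # 'pre' drops tokens at the beginning).
        if len(seq) > maxlen:
            seq = seq[:maxlen] if truncating == "post" else seq[-maxlen:]
        # Pad sequences that are too short ('post' appends, 'pre' prepends).
        pad = [pad_id] * (maxlen - len(seq))
        seq = seq + pad if padding == "post" else pad + seq
        result.append(seq)
    return result


# Three hypothetical "paragraphs" of token IDs with different lengths.
paragraphs = [[5, 8, 2], [7, 1], [4, 9, 6, 3, 2]]
longest = max(len(p) for p in paragraphs)

print(pad_or_truncate(paragraphs, maxlen=longest))
# → [[5, 8, 2, 0, 0], [7, 1, 0, 0, 0], [4, 9, 6, 3, 2]]
print(pad_or_truncate(paragraphs, maxlen=3))
# → [[5, 8, 2], [7, 1, 0], [4, 9, 6]]
```

With `padding='post'`, the real tokens stay at the start of each sequence and the pad IDs follow them, which is the setting this lab uses.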

Check the length of the first paragraph:

In [None]:
print('Length of first paragraph:', len(encoded_tokens[0]))

Find the minimum and maximum number of tokens per paragraph in the dataset. These values help you decide how long the padded sequences should be:

In [None]:
shortest_paragragh_length = len(min(encoded_tokens, key=len))
longest_paragragh_length = len(max(encoded_tokens, key=len))
print('Length of the shortest paragraph:', shortest_paragragh_length)
print('Length of the longest paragraph:', longest_paragragh_length)

As discussed earlier, all paragraphs are required to be the same length to prepare the training dataset. One option is to truncate longer paragraphs to match the shortest one. While this approach is efficient, it risks losing important context since longer paragraphs would have tokens cut off.

Another option is to pad the shorter paragraphs with the special `'<PAD>'` token, making all paragraphs the same length as the longest one. This method ensures that each paragraph retains the full context. However, while padding helps maintain meaning, it may introduce extra memory and computation overhead.

Alternatively, you can analyze the distribution of paragraph lengths and choose a padding length that covers most of the content without excessive padding. This approach balances context retention against computational cost. The next cell lets you choose how much padding to add. If you set the length too low, however, you will effectively be truncating the paragraphs.
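Below is a minimal sketch of this analysis. It assumes you have a list of paragraph lengths, for example `[len(t) for t in encoded_tokens]`. The sketch picks a candidate `maxlen` from a percentile of the length distribution and reports two numbers for each candidate: how many paragraphs would be truncated, and what fraction of the padded array would be padding. The lengths are made-up placeholders, and the 90th percentile is an arbitrary choice. The helper `summarize_maxlen` is hypothetical.

```python
import numpy as np


def summarize_maxlen(lengths, maxlen):
    """Reports the cost of padding or truncating every sequence to `maxlen`."""
    lengths = np.asarray(lengths)
    # Paragraphs longer than maxlen lose their tail when truncated.
    num_truncated = int(np.sum(lengths > maxlen))
    # Real tokens kept after truncation, versus all slots in the padded array.
    kept_tokens = np.minimum(lengths, maxlen).sum()
    total_slots = len(lengths) * maxlen
    padding_fraction = 1.0 - kept_tokens / total_slots
    return num_truncated, float(padding_fraction)


# Placeholder lengths for illustration only (not the Africa Galore statistics).
lengths = [40, 55, 60, 62, 70, 75, 80, 90, 120, 300]

for maxlen in [int(np.percentile(lengths, 90)), max(lengths)]:
    truncated, pad_frac = summarize_maxlen(lengths, maxlen)
    print(f"maxlen={maxlen}: {truncated} paragraph(s) truncated, "
          f"{pad_frac:.0%} of the array is padding")
```

A single very long paragraph can inflate the padding fraction considerably when you pad everything to the maximum length. In that situation, a percentile-based `maxlen` can be a reasonable compromise.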



Use the box below to enter the maximum length to pad your paragraphs to. To see how this works, first enter the minimum paragraph length you computed earlier, run the cell, and observe the output. Then try the maximum paragraph length. If the paragraphs in your dataset are fairly short, you can simply pad them all to the maximum paragraph length. If you later hit an "out of memory" error during training, return to this cell and reduce the maximum length:

In [None]:
maxlen = 320 #@param {type: 'number'}

# Ensure that maxlen is positive and no longer than the longest paragraph.
assert maxlen > 0, 'Max length must be greater than 0. Increase the `maxlen`.'
assert maxlen <= longest_paragragh_length, (
    f'`maxlen` ({maxlen}) exceeds the longest paragraph '
    f'({longest_paragragh_length} tokens), so every sequence, including the '
    f'longest, would be padded with token {tokenizer.pad_token_id}. '
    "You probably don't want that. Reduce the `maxlen`.")

# Check if maxlen is shorter or longer than the longest paragraph.
if maxlen < longest_paragragh_length:
    print('\033[33mWarning: The longest paragraph has '
    f'{longest_paragragh_length} tokens, but `maxlen` '
    f'is set to {maxlen}. As a result, paragraphs longer than '
    '`maxlen` will be truncated.\033[0m')

padded_sequences = tf.keras.preprocessing.sequence.pad_sequences(
    encoded_tokens,
    maxlen=maxlen,
    padding='post',
    truncating='post',
    value=tokenizer.pad_token_id)

print('New length of first paragraph:', len(padded_sequences[0]), '\n')

print('Padding makes the length of all sequences the same as the specified ' +
      '`maxlen`')

print('Notice that the first 10 tokens observed above appear at the start. '
      "With padding='post', any padding tokens "
      f'({tokenizer.pad_token_id}) are appended at the end.\n')
print('Padded tokens of first paragraph:\n', padded_sequences[0])

In [None]:
print('A different paragraph looks like this after padding:\n', padded_sequences[-1])

In [None]:
padded_sequences.shape

## Step 4: Prepare input and target

Before training, it is important to understand what the inputs and targets of the transformer model are and how to prepare them.

The model works *autoregressively*. This means it generates one token at a time and uses the previously generated tokens as context to predict the next one. Therefore, you need to organize the data in a specific way to train the model:

- **Input**: The input is a sequence of tokens that is fed into the transformer model. This can be part of a paragraph, a full paragraph, or even multiple paragraphs, depending on how the data is structured.
  
- **Target**: The target sequence is what you want the model to predict. It is the same as the input sequence, but *shifted left by one token*. At every position, the target therefore holds the token that comes right after the corresponding input token.

For example:
- Input: "The cat sat"
- Target: "cat sat on" (shifted by one token)

This setup helps the model learn how to predict the next token in the sequence based on the context provided by the previous ones.
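The shift can be sketched in plain Python using the example above. The token IDs are hypothetical. The slicing mirrors the `[:, :-1]` and `[:, 1:]` slicing applied to the padded array in this step, except that here it operates on a single sequence.

```python
# A tiny hypothetical vocabulary and one tokenized sentence.
vocab = {"The": 1, "cat": 2, "sat": 3, "on": 4}
sentence = ["The", "cat", "sat", "on"]
token_ids = [vocab[word] for word in sentence]

# The input drops the last token; the target drops the first token.
input_ids = token_ids[:-1]   # [1, 2, 3] -> "The cat sat"
target_ids = token_ids[1:]   # [2, 3, 4] -> "cat sat on"

# At position i, the model sees input tokens 0..i and must predict target_ids[i].
for i, next_id in enumerate(target_ids):
    context = sentence[: i + 1]
    print(f"context={context} -> predict {sentence[i + 1]!r} (id {next_id})")
```

A single input/target pair of length three therefore gives the model three next-token prediction exercises, one per position.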

Pairing inputs with target (or label) sequences in this way is a common approach in *supervised learning*, where the model is trained to predict the target given the input:


In [None]:
# Prepare input and target to the transformer model.
input = padded_sequences[:, :-1] # All tokens except the last one.
target = padded_sequences[:, 1:]  # All tokens except the first one.

Print out the first ten token IDs of the first input and target sequence:

In [None]:
print(input[0, :10])
print(target[0, :10])

Pay attention to how the input and target are shifted by 1.

Now, decode the numbers to visualize the texts that are shifted:


In [None]:
# Decodes the first 10 tokens of the first paragraph from input.
tokenizer.decode(input[0, :10])

In [None]:
# Decodes the first 10 tokens of the first paragraph from target.
tokenizer.decode(target[0, :10])

Print the input and target shapes:
- Each shape is a tuple of (number of paragraphs, sequence length).
- The sequence length is one less than the `maxlen` enforced through padding, because the input drops the last token and the target drops the first token.

In [None]:
input.shape, target.shape

Update `maxlen` to reflect that the input and target sequences are one token shorter:

In [None]:
maxlen = input.shape[1]

## Step 5: Shuffle the dataset and specify the batch size

Neural network training involves repeatedly selecting a number of random examples from the dataset, called a *batch*. The dataset is shuffled to achieve this randomness, and `batch_size` determines the number of examples in each batch.


The figure below illustrates a dataset with seven examples, each padded to `maxlen`, the length of the longest example. The dataset is then shuffled and split into batches of size three, so the final batch contains only one example.



<img src='https://storage.googleapis.com/dm-educational/assets/ai_foundations/evolve_graphic.png' width='1000'>
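Below is a minimal pure-Python sketch of the figure, using seven placeholder examples and a batch size of three. It only illustrates what shuffling followed by batching does, not how `tf.data` implements it. `tf.data` shuffles through a fixed-size buffer, whereas this sketch shuffles the whole list at once, which is equivalent to a buffer at least as large as the dataset. The helper `shuffle_and_batch` is hypothetical.

```python
import math
import random


def shuffle_and_batch(examples, batch_size, seed=None):
    """Shuffles a copy of `examples` and splits it into consecutive batches."""
    rng = random.Random(seed)
    shuffled = list(examples)
    rng.shuffle(shuffled)
    # The last batch keeps the leftover examples, so it can be smaller.
    return [shuffled[i:i + batch_size]
            for i in range(0, len(shuffled), batch_size)]


examples = [f"example_{i}" for i in range(7)]
batches = shuffle_and_batch(examples, batch_size=3, seed=812)

print([len(batch) for batch in batches])  # → [3, 3, 1]
# The number of batches is the dataset size divided by the batch size,
# rounded up.
print(math.ceil(len(examples) / 3))  # → 3
```

The same rounding-up rule should give the total that the batch-counting cell in this step reports, namely `ceil(len(input) / batch_size)`. `tf.data.Dataset.batch` keeps the smaller final batch by default. You can pass `drop_remainder=True` if every batch must be full.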

*Run the cell below to shuffle and create batches of examples:*

In [None]:
# (1) Create a TensorFlow dataset of (input, target) pairs, one per paragraph.
dataset = tf.data.Dataset.from_tensor_slices((input, target))

# (2) Randomly shuffle the dataset.
#     The buffer_size is the number of examples held in the shuffle buffer.
#     A buffer as large as the dataset gives a uniform shuffle. If you're
#     working with a very large dataset, reduce the buffer_size as needed.
dataset = dataset.shuffle(buffer_size=len(input))

# (3) Specify batch size.
batch_size = 32  #@param {type: 'number'}

# (4) Create batches.
dataset = dataset.batch(batch_size)

for batch in dataset.take(1):
    print(batch)

Count the total number of batches:

In [None]:
total_batches = 0
for batch in dataset:
    total_batches += 1
print('Total number of batches is:', total_batches)

## Step 6: Train a small language model (SLM)


In this next step, you will train a language model with around four million parameters. This is far smaller than production systems such as Google Gemini, which have billions of parameters and are referred to as large language models (LLMs).

The size of a transformer model affects its performance. Larger models, with more parameters, have the capacity to learn more complex patterns and can deliver better accuracy. However, they also require more computation and memory. This can lead to longer training times (how long the model needs to train to reach good performance) and higher costs.
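To get a feel for where the parameters of a model like this live, here is a rough count for three of the building blocks defined in the hidden code below. These are the token embedding, one feed-forward sub-layer (two `Dense` layers plus a `LayerNormalization`), and the final `Dense` output layer. The vocabulary size of 5,000 is a hypothetical placeholder, not the Africa Galore vocabulary size. The attention sub-layer is deliberately left out, so the result is a partial count, not the model's true total. The sinusoidal positional encoding is fixed and adds no trainable parameters.

```python
def embedding_params(vocab_size, d_model):
    # One d_model-dimensional vector per vocabulary entry.
    return vocab_size * d_model


def dense_params(in_dim, out_dim):
    # Weight matrix plus one bias per output unit.
    return in_dim * out_dim + out_dim


def layernorm_params(d_model):
    # One scale (gamma) and one offset (beta) per feature.
    return 2 * d_model


def feed_forward_params(d_model, ff_dim):
    # Expand to ff_dim, project back to d_model, then layer normalization.
    return (dense_params(d_model, ff_dim)
            + dense_params(ff_dim, d_model)
            + layernorm_params(d_model))


vocab_size = 5000            # Hypothetical vocabulary size.
d_model, ff_dim = 256, 256   # Defaults of `create_model`.

counts = {
    "token embedding": embedding_params(vocab_size, d_model),
    "feed-forward sub-layer": feed_forward_params(d_model, ff_dim),
    "output layer": dense_params(d_model, vocab_size),
}
for name, count in counts.items():
    print(f"{name:>24}: {count:,}")
print(f"{'partial total':>24}: {sum(counts.values()):,}")
```

In this sketch, both the embedding and the output layer grow linearly with the vocabulary size, so they dominate the count for a small model with a sizeable vocabulary. For the exact number of parameters of a built Keras model, `model.summary()` or `model.count_params()` report it directly.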


**What are parameters?**

Parameters are the numbers in a machine learning model that are adjusted during training so that the model performs its training task well. Our language model updates its parameters after processing each batch of training data, so that it better predicts the next token given a context (the prior tokens). At the start of training, the parameters are random numbers. In each training iteration, the model updates them so that it gets better at predicting the next token.
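The update loop can be illustrated on a deliberately tiny model. It has a single context whose only parameters are three logits, one per token in a hypothetical three-token vocabulary. The logits start as small random values. Plain gradient descent on the cross-entropy loss then raises the probability of the correct next token, using the fact that the gradient with respect to the logits is `softmax(logits) - one_hot(target)`. This is a sketch of the principle only. The lab's model uses the AdamW optimizer on millions of parameters and many contexts at once.

```python
import math
import random


def softmax(logits):
    # Subtract the max logit for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


random.seed(812)
target_id = 2  # Hypothetical ID of the correct next token.
# The parameters start as small random numbers.
logits = [random.uniform(-0.1, 0.1) for _ in range(3)]
learning_rate = 0.5

losses = []
for step in range(20):
    probs = softmax(logits)
    # Cross-entropy loss: negative log-probability of the correct token.
    losses.append(-math.log(probs[target_id]))
    # Gradient of the loss with respect to each logit.
    grads = [p - (1.0 if i == target_id else 0.0) for i, p in enumerate(probs)]
    # Gradient-descent update: move each parameter against its gradient.
    logits = [w - learning_rate * g for w, g in zip(logits, grads)]

final_loss = -math.log(softmax(logits)[target_id])
print(f"loss before training:  {losses[0]:.3f}")
print(f"loss after 20 updates: {final_loss:.3f}")
print(f"probability of the correct token: {softmax(logits)[target_id]:.3f}")
```

The initial loss is close to `log(3)`, which is the loss of a uniform guess over three tokens. It then shrinks with every update as the parameters move toward predicting the target token.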

In [None]:
# @title Hidden code used for training and sampling from the trained model.

import os
import numpy as np
import jax
import jax.numpy as jnp
from typing import Any
import plotly.express as px
import tensorflow as tf
import keras
from keras import ops, layers

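# Keras reads KERAS_BACKEND when it is first imported, so this setting only
# takes effect if it runs before Keras is imported for the first time.
os.environ['KERAS_BACKEND'] = 'jax'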
tf.random.set_seed(812)  # For TensorFlow operations.
keras.utils.set_random_seed(812)  # For Keras layers.


def create_model(vocab_size: int,
                 maxlen: int,
                 d_model: int = 256,
                 ff_dim: int = 256,
                 num_heads: int = 2,
                 n_blocks: int = 1,
                 optimizer: str = 'adamw',
                 learning_rate: float = 1e-4,
                 dropout_rate: float = 0.0,
                 activation: str = 'relu',
                 pad_token_id: int = 0) -> keras.Model:
    """Creates a transformer-based model for sequence processing tasks.

    Example:
        model = create_model(vocab_size=5000, maxlen=100,
                             d_model=256, ff_dim=512,
                             num_heads=8, n_blocks=2)
        model.summary()

    Notes:
        - The model uses causal (masked) attention to ensure that each token
          only attends to previous tokens and not future tokens.
        - The final dense layer produces logits over the vocabulary for
          each position in the sequence, and a softmax layer converts them
          into probabilities.
        - The loss function is `CustomMaskPadLoss`, which ignores padding
          tokens in the loss computation.

    Args:
        vocab_size: The size of the vocabulary, i.e.,
                    the number of unique tokens.
        maxlen: The maximum length of the input sequences.
        d_model: The dimensionality of the embedding space.
                   Default is 256.
        ff_dim: The number of units in the feed-forward network
                of each transformer block. Default is 256.
        num_heads: The number of attention heads in the multi-head
                   attention mechanism. Default is 2.
        n_blocks: The number of transformer blocks to stack in the model.
                  Default is 1.
        optimizer: The optimizer to use for training, either 'adamw'
                   (Adam with weight decay) or 'sgd'.
                   Default is 'adamw'.
        learning_rate: The learning rate for the optimizer. Default is 1e-4.
        dropout_rate: The dropout rate to prevent overfitting.
                      Default is 0.0 (no dropout).
        activation: The activation function to use in the feed-forward network
                    of each Transformer block. Default is 'relu'.
        pad_token_id: The ID used to represent padding tokens in the sequence.
                      This is used to mask padded tokens in the loss
                      calculation. Default is 0.

    Returns:
        keras.Model: The compiled Keras model, which outputs a probability
                     distribution over the vocabulary for the next token at
                     each position.

    Raises:
        NotImplementedError: If an unsupported optimizer is specified.
    """
    # Create input layer.
    inputs = layers.Input(shape=(maxlen,), dtype='int32')

    # Embedding layer that combines token and positional embeddings.
    embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, d_model)
    x = embedding_layer(inputs)

    # Apply a stack of transformer blocks.
    for _ in range(n_blocks):
        transformer_block = TransformerBlock(d_model,
                                            num_heads,
                                            ff_dim,
                                            dropout_rate=dropout_rate,
                                            activation=activation)
        x = transformer_block(x)

    # Dense layer that returns raw logits for next-token prediction.
    outputs = layers.Dense(vocab_size)(x)

    # Apply softmax to turn the raw logits into a probability distribution.
    outputs = layers.Softmax()(outputs)

    # Build the model.
    model = keras.Model(inputs=inputs, outputs=outputs)

    # Set up optimizer based on input string.
    optimizer_instance = get_optimizer(optimizer, learning_rate)

    # Define the loss function and compile the model.
    loss_fn = CustomMaskPadLoss(pad_token_id=pad_token_id)
    model.compile(optimizer=optimizer_instance, loss=loss_fn)

    # Final output layer returns the probability of next token prediction.
    return model


def get_optimizer(optimizer_name: str,
                  learning_rate: float) -> keras.optimizers.Optimizer:
    """Helper function to get the appropriate optimizer instance.

    Args:
        optimizer_name: The optimizer type ('adamw' or 'sgd').
        learning_rate: The learning rate for the optimizer.

    Returns:
        keras.optimizers.Optimizer: The corresponding optimizer instance.

    Raises:
        NotImplementedError: If an unsupported optimizer is specified.
    """
    if optimizer_name.lower() == 'sgd':
        return keras.optimizers.SGD(learning_rate=learning_rate)
    elif optimizer_name.lower() == 'adamw':
        return keras.optimizers.AdamW(learning_rate=learning_rate,
                                      weight_decay=0.005,
                                      gradient_accumulation_steps=None
                                      )
    else:
        raise NotImplementedError(f'Optimizer {optimizer_name}'
                                  ' is not implemented.')


# Decorator so that the custom class can be saved and loaded correctly.
@keras.saving.register_keras_serializable()
class CustomMaskPadLoss(keras.losses.Loss):
    """Custom loss function for masked padding in sequence-based tasks.

    This loss function computes the SparseCategoricalCrossentropy
    loss while ignoring the padding tokens (specified by `pad_token_id`).
    The padding tokens are not included in the loss calculation,
    allowing the model to focus on meaningful tokens during training.

    Attributes:
        name: The name of the loss function, used by Keras.
              Defaults to 'custom_mask_pad_loss'.
        pad_token_id: The ID of the padding token. If provided,
                      padding tokens will be ignored during loss calculation.
                      If None, no padding is masked.
        kwargs: Additional keyword arguments.
    """

    def __init__(self,
                 name: str = 'custom_mask_pad_loss',
                 pad_token_id: int | None = None,
                 **kwargs: dict):
        super().__init__(name=name, **kwargs)
        self.pad_token_id = pad_token_id

    def call(self,
             y_true: tf.Tensor,
             y_pred: tf.Tensor) -> tf.Tensor:
        """Computes the custom loss, optionally masking the padding
           tokens and normalizing the loss by the number of non-masked tokens.
           The loss is computed using the SparseCategoricalCrossentropy
           loss function.
        """
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
                        # The model already outputs a probability distribution.
                        # If it output raw logits, this should be True.
                        from_logits=False,

                        # Average the loss over all (unmasked) tokens.
                        reduction='sum_over_batch_size'
                    )

        if self.pad_token_id is not None:
            # Create a boolean mask: True for non-padding tokens.
            # Shape: (batch_size, sequence_length)
            mask = tf.not_equal(y_true, self.pad_token_id)

            # Use tf.boolean_mask to filter out padded tokens.
            # y_true_filtered will be a 1D tensor containing only
            # the valid token labels.
            y_true_filtered = tf.boolean_mask(y_true, mask)

            # y_pred_filtered will be a 2D tensor containing only
            # the predictions for valid tokens.
            y_pred_filtered = tf.boolean_mask(y_pred, mask)

            loss = loss_fn(y_true_filtered, y_pred_filtered)
        else:
            loss = loss_fn(y_true, y_pred)
        return loss


# Decorator so that the custom class can be saved and loaded correctly.
@keras.saving.register_keras_serializable()
class TokenAndPositionEmbedding(layers.Layer):
    """Combines token embeddings with positional embeddings.

    This layer creates combined token and positional embeddings
    for input sequences.
    The `mask_zero=True` setting in the token embeddings allows for
    automatic masking of padded tokens.

    Attributes:
        maxlen: The maximum expected sequence length. This determines the
                    range of positional embeddings.
        vocab_size: The size of the vocabulary. This determines the size
                        of the token embedding matrix.
        d_model: The dimensionality of the token and positional embeddings.
        positional_embedding_type: The type of positional embedding to use,
                                   either 'simple' (learned) or 'sinusoidal'.
                                   Defaults to 'sinusoidal'.
        kwargs: Additional keyword arguments passed to the base
                `keras.layers.Layer` constructor.
    """

    def __init__(self, maxlen: int,
                vocab_size: int,
                d_model: int,
                positional_embedding_type: str = 'sinusoidal',
                **kwargs: dict):
        super().__init__(**kwargs)

        self.d_model = d_model
        self.maxlen = maxlen
        self.positional_embedding_type=positional_embedding_type

        # Set mask_zero=True so that Keras generates a mask for padded tokens.
        self.token_emb = layers.Embedding(input_dim=vocab_size,
                                          output_dim=d_model,
                                          mask_zero=True)

        if self.positional_embedding_type == 'simple':
          self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=d_model)
        elif self.positional_embedding_type == 'sinusoidal':
          self.pos_emb = self.positional_encoding(length=maxlen, depth=d_model)
        else:
            raise NotImplementedError('Positional embedding type'
                                      f' {self.positional_embedding_type}'
                                      f' not implemented.')

    def positional_encoding(self, length: int, depth: int) -> tf.Tensor:
        """Creates a positional encoding for a sequence of tokens.
            This approach uses sine and cosine functions at varying
            frequencies to create
            a unique positional representation for each token in the sequence.

        Args:
          length: The length of the sequence (number of tokens).
          depth: The dimensionality of the encoding (must be even).

        Returns:
          A TensorFlow tensor of shape (length, depth) representing
          the positional encoding.
        """
        depth = depth // 2  # Use integer division to ensure an integer depth.

        positions = np.arange(length)[:, np.newaxis]  # (seq, 1)
        depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)

        angle_rates = 1 / (10000**depths)  # (1, depth)
        angle_rads = positions * angle_rates  # (pos, depth)

        pos_encoding = np.concatenate(
            [np.sin(angle_rads), np.cos(angle_rads)],
            axis=-1)

        return tf.cast(pos_encoding, dtype=tf.float32)

    def call(self, x: tf.Tensor) -> tf.Tensor:
        token_embeddings = self.token_emb(x)

        if self.positional_embedding_type == 'sinusoidal':
          # Scale the token embeddings by sqrt(d_model) to set their
          # magnitude relative to the positional encoding.
          token_embeddings *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
          position_embeddings = self.pos_emb[tf.newaxis, :, :]
        else:
          # Defaults to simple `positional_embedding_type`.
          positions = ops.arange(0, self.maxlen, 1)
          position_embeddings = self.pos_emb(positions)

        return token_embeddings + position_embeddings


# Decorator so that the custom class can be saved and loaded correctly.
@keras.saving.register_keras_serializable()
class TransformerBlock(layers.Layer):
  """A single Transformer block.

    The Transformer block is a fundamental component of the Transformer
    architecture, which is commonly used for sequence-based tasks. It consists
    of a MultiHeadAttention layer followed by a feed-forward network,
    with layer normalization and dropout applied at each step.

    Example:
        transformer_block = TransformerBlock(d_model=256, num_heads=8,
                                             ff_dim=1024)
        output = transformer_block(inputs)

    Attributes:
        d_model: The dimensionality of the input embedding (also the output
                 size of the attention layer).
        num_heads: The number of attention heads in the multi-head
                   attention mechanism.
        ff_dim: The number of units in the feed-forward network.
        dropout_rate: Dropout rate, between 0 and 1. Default is 0.0.
        activation: The activation function to use in the feed-forward network.
                     Default is 'relu'.
        kwargs: Additional keyword arguments to pass to the parent `Layer`
                class.

    Returns:
        tf.Tensor: The output of the Transformer block after applying the
                   multi-head attention, feed-forward network,
                   layer normalization, and residual connections.

    """

  def __init__(self,
               d_model: int,
               num_heads: int,
               ff_dim: int,
               dropout_rate: float = 0.0,
               activation: str = 'relu',
               **kwargs: dict):
    super().__init__(**kwargs)

    self.self_attention = MultiHeadSelfAttention(d_model,
                                                 num_heads,
                                                 dropout_rate)
    self.feed_forward = FeedForwardNetwork(d_model,
                                           ff_dim,
                                           dropout_rate,
                                           activation)

  def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Applies a single transformer block to the input tensor.

    Notes:
        - The transformer block follows the architecture with residual
          connections and layer normalization.

    Args:
        inputs: The input tensor of shape (batch_size, seq_len, d_model).

    Returns:
        tf.Tensor: The output tensor of shape (batch_size, seq_len, d_model)
                    after applying the transformer block.
    """
    # First sub-layer: masked self-attention.
    attn_output = self.self_attention(inputs)

    # Second sub-layer: feed-forward network applied to the attention output.
    ffn_output = self.feed_forward(attn_output)

    return ffn_output


# Decorator so that the custom class can be saved and loaded correctly.
@keras.saving.register_keras_serializable()
class FeedForwardNetwork(tf.keras.layers.Layer):
    """Feed forward network layer.

    This layer implements a two-layer feedforward network with a residual
    connection and layer normalization. It's a common component in
    transformer architectures, used to introduce non-linearity and improve
    the model's ability to capture complex relationships.

    Args:
        d_model: The dimensionality of the embedding space.
        ff_dim: The dimensionality of the hidden layer in the feedforward
                network (often larger than d_model).
        dropout_rate: The dropout rate applied to the output of the feedforward
                      network. Defaults to 0.0.
        activation: The activation function used in the first dense layer.
                    Defaults to 'relu'.
        **kwargs: Additional keyword arguments passed to the base Layer.

    Call Arguments:
        x: Input tensor of shape (batch_size, sequence_length, d_model).

    Returns:
        tf.Tensor: Output tensor of shape (batch_size, sequence_length, d_model)
                  after applying the feedforward network and residual connection.
    """

    def __init__(self,
                d_model: int,
                ff_dim: int,
                dropout_rate: float = 0.0,
                activation: str = 'relu',
                **kwargs: dict):
        super(FeedForwardNetwork, self).__init__(**kwargs)
        # Define a two-layer feedforward network.
        self.ffn = tf.keras.Sequential([
            # Expand dimension.
            tf.keras.layers.Dense(ff_dim, activation=activation),
            # Project back to d_model.
            tf.keras.layers.Dense(d_model)
        ])
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Applies the feedforward network to the input tensor.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, d_model).

        Returns:
            tf.Tensor: Output tensor of shape (batch_size, sequence_length,
                                              d_model).
        """
        ffn_output = self.ffn(x)
        ffn_output = self.dropout(ffn_output)
        # Add residual connection followed by layer normalization.
        output = self.layernorm(x + ffn_output)
        return output


# Decorator so that the custom class can be saved and loaded correctly.
@keras.saving.register_keras_serializable()
class MultiHeadSelfAttention(tf.keras.layers.Layer):
    """Multi-head self-attention Layer.

    This layer implements multi-head self-attention, a key component in
    Transformer architectures.
    It computes attention weights for each head and applies them to the
    input to generate a contextually enriched representation.

    Args:
        d_model: The dimensionality of the embedding space.
        num_heads: The number of attention heads.
        dropout_rate: The dropout rate applied to the attention output.
                      Defaults to 0.0.
        **kwargs: Additional keyword arguments passed to the base Layer.

    Call Arguments:
        x: Input tensor of shape (batch_size, sequence_length, d_model).

    Returns:
        tf.Tensor: Output tensor of shape (batch_size, sequence_length, d_model)
                    with self-attention applied.
    """

    def __init__(self,
               d_model: int,
               num_heads: int,
               dropout_rate: float = 0.0,
               **kwargs: dict):
        super(MultiHeadSelfAttention, self).__init__(**kwargs)

        # Multi-head self-attention layer.
        self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads,
                                                      key_dim=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.layernorm = tf.keras.layers.LayerNormalization()

    def call(self, x: tf.Tensor) -> tf.Tensor:
        """Applies multi-head self-attention to the input tensor.

        Args:
            x: Input tensor of shape (batch_size, sequence_length, d_model).

        Returns:
            tf.Tensor: Output tensor of shape (batch_size, sequence_length,
                                              d_model).
        """

        # Apply causal (look-ahead) self-attention: with use_causal_mask=True,
        # each position can only attend to itself and earlier positions.
        attn_output = self.mha(query=x, value=x, key=x, use_causal_mask=True)
        attn_output = self.dropout(attn_output)
        # Add residual connection followed by layer normalization.
        output = self.layernorm(x + attn_output)
        return output


class TextGenerator(keras.callbacks.Callback):
    """A callback to generate text from a trained model.

    1. Feed a starting prompt to the model.
    2. Predict probabilities for the next token.
    3. Sample the next token and add it to the input for the next prediction.

    Attributes:
        max_tokens: Number of tokens to be generated after the prompt.
        start_tokens: Token indices for the starting prompt.
        tokenizer: Tokenizer instance to convert token indices back to words.
        pad_token_id: Token ID for padding, default is 0.
        print_every: Print the generated text every this many epochs.
                     Default is 1.
        kwargs: Any additional keyword arguments.
    """

    def __init__(self, max_tokens: int,
                 start_tokens: list[int],
                 tokenizer: Any,
                 pad_token_id: int = 0,
                 print_every: int = 1,
                 **kwargs: dict):
        """Initializes the text generator callback.

        Args:
            max_tokens: Number of tokens to generate.
            start_tokens: Token indices for the initial prompt.
            tokenizer: The tokenizer used to decode generated token indices.
            pad_token_id: The padding token ID (default is 0).
            print_every: Print the generated text every `print_every` epochs.
                         Default is 1.
        """
        super().__init__(**kwargs)
        self.max_tokens = max_tokens
        self.start_tokens = start_tokens
        self.tokenizer = tokenizer
        self.print_every = print_every
        self.pad_token_id = pad_token_id  # ID for padding token.

    def greedy_decoding(self, probs: np.ndarray) -> int:
        """Select the token index with the highest probability.

        Args:
            probs: The probability distribution of next token prediction.

        Returns:
            int: The index of the predicted token with the highest probability.
        """
        predicted_index = np.argmax(probs)
        return predicted_index

    def sampling(self, probs: np.ndarray) -> int:
        """Sample a token index from the predicted next token probability.

        Args:
            probs: The probability distribution of predicted next token.

        Returns:
            int: The index of the sampled token.
        """
        return np.random.choice(np.arange(len(probs)), p=probs)

    def on_epoch_end(self, epoch: int, logs: dict | None = None) -> None:
        """Generate and print text from the starting tokens every
            `print_every` epochs.

        Args:
            epoch: The current epoch number.
            logs: Logs from the training process.
        """
        # Only generate text every `print_every` epochs.
        if (epoch + 1) % self.print_every != 0:
            return

        maxlen = self.model.layers[0].output.shape[1]
        # Make a copy of the start tokens.
        start_tokens = list(self.start_tokens)

        num_tokens_generated = 0
        tokens_generated: list[int] = []

        while num_tokens_generated < self.max_tokens:
            pad_len = maxlen - len(start_tokens)
            sample_index = len(start_tokens) - 1

            # Handle padding to ensure the sequence is of the correct length.
            if pad_len < 0:
                # Keep only the most recent `maxlen` tokens as context.
                x = start_tokens[-maxlen:]
                sample_index = maxlen - 1
            elif pad_len > 0:
                x = start_tokens + [self.pad_token_id] * pad_len
            else:
                x = start_tokens

            x = np.array([x])
            y = self.model.predict(x, verbose=0)
            sample_token = self.sampling(y[0][sample_index])

            tokens_generated.append(sample_token)
            start_tokens.append(sample_token)
            num_tokens_generated = len(tokens_generated)

        # Combine the starting tokens with the generated tokens.
        output_tokens = self.start_tokens + tokens_generated
        output_tokens = list(map(int, output_tokens))

        # Decode and print the generated text.
        txt = self.tokenizer.decode(output_tokens)
        print('Generated text:\n', txt, '\n')


def sampling(probs: np.ndarray) -> int:
    """Sample a token index from the predicted next token probability.

    Args:
        probs: The probability distribution of predicted next token.

    Returns:
        int: The index of the sampled token.
    """
    return np.random.choice(np.arange(len(probs)), p=probs)


def greedy_decoding(probs: np.ndarray) -> int:
    """Select the token index with the highest predicted probability.

    Args:
        probs: The probability distribution of predicted next token.

    Returns:
        int: The index of the token with the highest probability.
    """
    predicted_index = np.argmax(probs)
    return predicted_index


def generate_text(start_prompt: str,
                  n_tokens: int,
                  model: keras.Model,
                  tokenizer: object,
                  pad_token_id: int = 0,
                  do_sample: bool = False) -> tuple[str, list[np.ndarray]]:
    """Generate text based on a starting prompt using a trained model.

    Args:
        start_prompt: The initial prompt to start the generation.
        n_tokens: The number of tokens to generate after the prompt.
        model: The trained model to use for text generation.
        tokenizer: The tokenizer to encode and decode text.
        pad_token_id: The token ID used for padding (default is 0).
        do_sample: Whether to sample from the distribution or use
                   greedy decoding (default is False).

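    Returns:
        A tuple `(generated_text, probs)`: the prompt followed by the
        generated tokens as text (with padding removed), and a list with the
        next-token probability distribution for each generation step.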
    """
    maxlen = model.layers[0].output.shape[1]

    # Tokenize the starting prompt.
    start_tokens = tokenizer.encode(start_prompt)

    # Generate tokens.
    tokens_generated = list(start_tokens)  # Copy of the prompt tokens.
    probs: list[np.ndarray] = []
    for _ in range(n_tokens):
        pad_len = maxlen - len(start_tokens)
        sample_index = len(start_tokens) - 1
        if pad_len < 0:
            # Keep only the most recent `maxlen` tokens (the max context length).
            x = start_tokens[-maxlen:]
            sample_index = maxlen - 1
        elif pad_len > 0:
            x = start_tokens + [pad_token_id] * pad_len  # Pad the input sequence.
        else:
            x = start_tokens

        x = np.array([x])
        y = model.predict(x, verbose=0)  # Get predictions from the model.

        probs.append(y[0][sample_index])

        # Use greedy decoding or sampling based on the flag.
        if not do_sample:
            sample_token = greedy_decoding(y[0][sample_index])
        else:
            sample_token = sampling(y[0][sample_index])

        tokens_generated.append(sample_token)
        start_tokens.append(sample_token)

    # Convert tokens back to text.
    generated_text = tokenizer.decode(tokens_generated)
    generated_text = generated_text.replace(tokenizer.decode([pad_token_id]), '')
    return generated_text, probs


def plot_next_token(probs_or_logits: np.ndarray,
                    tokenizer: Any,
                    prompt: str,
                    keep_top: int = 30):
    """Plots the probability distribution of the next tokens.

    This function generates a bar plot showing the top `keep_top`
    tokens by probability.

    Adapted from the Gemma library:
    https://github.com/google-deepmind/gemma/blob/ee0d55674ecd0f921d39d22615e4e79bd49fce94/gemma/gm/text/_tokenizer.py#L249-L284

    Args:
        probs_or_logits: The raw logits output by the model or
                         the probability distribution for the next token
                         prediction.
        tokenizer: The tokenizer used to decode token IDs to human-readable
                   text.
        prompt: The input prompt used to generate the next token predictions.
        keep_top: The number of top tokens to display in the plot.
                  Default is 30.

    Returns:
        None: Displays a plot showing the probability distribution of the
              top tokens.
    """

    if np.isclose(probs_or_logits.sum(), 1):
        probs = probs_or_logits
    else:
        # Apply softmax to logits to get probabilities
        probs = jax.nn.softmax(probs_or_logits)

    # Select the top `keep_top` tokens by probability
    indices = jnp.argsort(probs)

    # Reverse to get highest probabilities first
    indices = indices[-keep_top:][::-1]

    # Get the probabilities and corresponding tokens
    probs = probs[indices].astype(np.float32)
    tokens = [repr(tokenizer.decode(i.item())) for i in indices]

    # Create the bar plot using Plotly.
    fig = px.bar(x=tokens, y=probs)

    # Customize the plot layout.
    fig.update_layout(
        title='Probability Distribution of Next '
              f'Tokens given the prompt="{prompt}"',
        xaxis_title='Tokens',
        yaxis_title='Probability',
    )

    # Display the plot.
    fig.show()
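
`MultiHeadSelfAttention` calls Keras's `MultiHeadAttention` with `use_causal_mask=True`, so each position can attend only to itself and earlier positions. This is what lets the model learn next-token prediction without seeing the token it is supposed to predict. The sketch below illustrates the idea in NumPy on a single matrix of raw attention scores; it is a simplified illustration, not Keras's internal implementation, and `causal_attention_weights` is a hypothetical name:

```python
import numpy as np


def causal_attention_weights(scores: np.ndarray) -> np.ndarray:
    """Applies a causal (look-ahead) mask to raw attention scores, then softmax.

    scores: array of shape (seq_len, seq_len); scores[i, j] is how strongly
    position i attends to position j before masking.
    """
    seq_len = scores.shape[-1]
    # Lower-triangular boolean mask: True where j <= i (allowed positions).
    allowed = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Future positions get -inf so that softmax gives them zero weight.
    masked = np.where(allowed, scores, -np.inf)
    # Numerically stable softmax over the last axis.
    masked = masked - masked.max(axis=-1, keepdims=True)
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)


# With all-zero scores, row i spreads its weight evenly over positions 0..i.
weights = causal_attention_weights(np.zeros((4, 4)))
print(np.round(weights, 2))
```

Every entry above the diagonal is zero, so the output at position `i` cannot depend on any later token.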

The `create_model` function used below constructs a transformer model, a powerful neural network architecture widely used in natural language processing. After creating the model, print a summary of it.



- **`create_model(maxlen=maxlen, vocab_size=tokenizer.vocab_size)`**:
    - This function builds the transformer model.
    - `maxlen` is the maximum length of the token sequences that the model will process.
    - `vocab_size` is the size of the vocabulary: the total number of unique tokens the model can take as input and predict. Here it is taken from the tokenizer, so it matches the vocabulary built from the text dataset.
- **`model.summary()`**:
    - This line prints out a summary of the model architecture.
    - The summary will show a breakdown of the different layers in the model, the number of parameters each layer has, and the output shape of each layer. This summary is useful for understanding the structure of the model, like how data flows through it and where the parameters are adjusted during training.

**What to look for in the summary:**

- Layer names: The different layers of the model (like Input, Dense, etc.) will be listed.

- Output shape: The shape of the data as it moves through each layer.

- Parameters: How many parameters are in each layer. These are the numbers that will be adjusted during training to decrease loss.
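
The parameter counts in the summary follow directly from each layer's shapes. A `Dense` layer mapping `n_in` inputs to `n_out` outputs has an `n_in x n_out` weight matrix plus `n_out` biases, a `LayerNormalization` layer has one scale and one offset per feature, and `Dropout` has no trainable parameters. A minimal sketch applying these rules to `FeedForwardNetwork`; the helper names and the `d_model`/`ff_dim` values are illustrative placeholders, not the lab's settings:

```python
def dense_params(n_in: int, n_out: int) -> int:
    # Weight matrix (n_in x n_out) plus one bias per output unit.
    return n_in * n_out + n_out


def layernorm_params(d_model: int) -> int:
    # One scale (gamma) and one offset (beta) per feature.
    return 2 * d_model


def feed_forward_params(d_model: int, ff_dim: int) -> int:
    # Dense(ff_dim) -> Dense(d_model) -> LayerNormalization, as in
    # FeedForwardNetwork. Dropout adds no trainable parameters.
    return (dense_params(d_model, ff_dim)
            + dense_params(ff_dim, d_model)
            + layernorm_params(d_model))


print(feed_forward_params(d_model=4, ff_dim=8))  # 40 + 36 + 8 = 84
```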

In [None]:
model = create_model(maxlen=maxlen, vocab_size=tokenizer.vocab_size)
model.summary()  # Prints the summary directly; it returns None.

To monitor progress, create a `TextGenerator` callback that prints generated text at the end of every training epoch. This lets you track how the language model's output improves as it learns. You can specify the number of tokens to generate and the initial prompt that guides the model's generation:

In [None]:
prompt = 'Abeni,'
start_words = tokenizer.encode(prompt)
text_gen_callback = TextGenerator(max_tokens=10, start_tokens=start_words, tokenizer=tokenizer)
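
Each time the callback (and later `generate_text`) predicts a token, it turns the variable-length list of tokens into a fixed-length input of `maxlen` ids and reads the prediction at one position. Because of the causal mask in `MultiHeadSelfAttention`, padding added on the right does not affect the prediction at the last real token. A minimal sketch of that windowing step, assuming pad id 0 and that only the most recent `maxlen` tokens are kept once the sequence is too long; `context_window` is a hypothetical helper, not part of the lab code:

```python
def context_window(tokens: list[int], maxlen: int,
                   pad_id: int = 0) -> tuple[list[int], int]:
    """Builds a fixed-length model input and the position to read from."""
    if len(tokens) >= maxlen:
        # Keep the most recent maxlen tokens; read the last position.
        return tokens[-maxlen:], maxlen - 1
    # Pad on the right; read the position of the last real token.
    padded = tokens + [pad_id] * (maxlen - len(tokens))
    return padded, len(tokens) - 1


print(context_window([7, 3], maxlen=5))              # ([7, 3, 0, 0, 0], 1)
print(context_window([1, 2, 3, 4, 5, 6], maxlen=5))  # ([2, 3, 4, 5, 6], 4)
```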

The training process updates the model parameters after each step.

- A step is one update of the model parameters after learning from a batch of examples.

- An epoch is one complete pass through the entire training dataset, and `num_epochs` sets how many epochs to train for. One epoch takes `num_examples / batch_size` steps, rounded up because the last batch may be smaller.

When training the model below, you will specify the number of epochs.
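
The relationship between examples, batches, and steps can be checked with a little arithmetic. A minimal sketch; the dataset size and batch size are hypothetical, and it assumes the last partial batch is kept, which is the default for `tf.data.Dataset.batch`:

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    # The final batch may be smaller than batch_size, so round up.
    return math.ceil(num_examples / batch_size)


# Hypothetical numbers for illustration only.
steps = steps_per_epoch(num_examples=1000, batch_size=32)
print(steps)        # 32 (31 full batches plus one batch of 8)
print(steps * 200)  # 6400 parameter updates over 200 epochs
```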

Run the cell below to train the model. Rerun it until you achieve a training loss of around 0.1.

It is recommended to train the model for at least `200` epochs. If training takes too long, you can reduce the number of epochs:

In [None]:
num_epochs = 200  #@param {type: 'number'}
# verbose=2: Instructs the model.fit method to print one line per
# epoch so you see how the loss is decreasing and generated texts improving.
history = model.fit(x=dataset, verbose=2,
                    epochs=num_epochs,
                    callbacks=[text_gen_callback])

Now that you have a trained model, you can prompt it like you did in "Lab 3 - Experiment with a Transformer Model".

## Step 7: Prompting the trained model

In this section, you will prompt the SLM and evaluate its quality by asking the following four key questions:

*   A. How good is the SLM at predicting the next word of a given prompt (prior words) based on patterns identified in the training dataset?
*   B. Is the generated text coherent, and does it make sense given the context?
*   C. Is the likely next token what you expect to see when the context is changed slightly?
*   D. How does the model handle unseen tokens?





**A. How good is the SLM at predicting the next word of a given prompt (prior words) based on patterns identified in the training dataset?**

- Prompt the model using a word or sequence of words from the training dataset. For example, you can start with `'Abeni, a bright-eyed'`.
- Visualize the probability distribution of the next token for the given prompt.
- Increase `num_next_words` to see more text.
- Set `do_sample=False` to "greedily" pick the most likely next token given the context (prior tokens).
- Inspect the generated text to see how well the model has learned to reproduce the patterns in its training data:

In [None]:
prompt = 'Abeni, a bright-eyed' #@param {type: 'string'}
num_next_words = 10 #@param {type: 'number'}
generated_text, probs = generate_text(prompt, num_next_words, model=model, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, do_sample=False)
plot_next_token(probs[0], tokenizer, prompt=prompt)
print('\n')
print('Generated Text:', generated_text)
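
The `greedy_decoding` and `sampling` helpers defined earlier differ in a single call: `np.argmax` versus a random draw weighted by the probabilities. A minimal sketch on a made-up four-token distribution (not model output) shows the consequence: greedy decoding returns the same token every time, while sampling returns each token roughly as often as its probability. It uses NumPy's `default_rng` with a fixed seed for reproducibility, whereas the lab's `sampling` uses the global `np.random.choice`:

```python
import numpy as np

# A toy next-token distribution over a hypothetical 4-token vocabulary.
probs = np.array([0.1, 0.6, 0.2, 0.1])

# Greedy decoding: always the single most likely index.
greedy_choice = int(np.argmax(probs))

# Sampling: draw indices in proportion to their probabilities.
rng = np.random.default_rng(seed=0)
samples = rng.choice(len(probs), size=10_000, p=probs)
frequencies = np.bincount(samples, minlength=len(probs)) / len(samples)

print(greedy_choice)             # 1, on every run
print(np.round(frequencies, 2))  # close to [0.1, 0.6, 0.2, 0.1]
```

This is why the same prompt can give different continuations with `do_sample=True` but always the same one with `do_sample=False`.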

**B. Is the generated text coherent, and does it make sense given the context?**
- Prompt the model with words or a phrase of your choosing.
- Increase `num_next_words` to see more text.
- Visualize the probability distribution of the next token for the given prompt.
- Set `do_sample=True` to sample words from the probability distribution of the next token given the context (prior tokens).
- Inspect the quality of the generated text.



In [None]:
prompt = 'Jide was hungry so she went looking for' #@param {type: 'string'}
num_next_words = 10 #@param {type: 'number'}
generated_text, probs = generate_text(prompt, num_next_words, model=model, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, do_sample=True)
plot_next_token(probs[0], tokenizer, prompt=prompt)
print('\n')
print('Generated Text:', generated_text)

**C. Is the likely next token what you expect to see when the context is changed slightly?**
- Change the context of the prompt slightly.
- Visualize the probability distribution of the next token for the given prompt.
- Increase `num_next_words` to see more text.
- Set `do_sample=True` to sample words from the probability distribution of the next token given the context (prior tokens).
- Inspect the quality of the generated text:

In [None]:
prompt = 'Jide was thirsty so she went looking for' #@param {type: 'string'}
num_next_words = 10 #@param {type: 'number'}
generated_text, probs = generate_text(prompt, num_next_words, model=model, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, do_sample=True)
plot_next_token(probs[0], tokenizer, prompt=prompt)
print('\n')
print('Generated Text:', generated_text)
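
To compare the "hungry" and "thirsty" prompts more precisely than by eye, you can list the most likely next tokens for each. A minimal NumPy sketch of the same top-k selection that `plot_next_token` performs; `top_k_tokens` is a hypothetical helper, not part of the lab code:

```python
import numpy as np


def top_k_tokens(probs: np.ndarray, k: int = 5) -> list[tuple[int, float]]:
    """Returns (token_id, probability) pairs for the k most likely tokens."""
    # argsort is ascending, so take the last k indices and reverse them.
    top = np.argsort(probs)[-k:][::-1]
    return [(int(i), float(probs[i])) for i in top]


print(top_k_tokens(np.array([0.05, 0.5, 0.05, 0.3, 0.1]), k=3))
# [(1, 0.5), (3, 0.3), (4, 0.1)]
```

For example, you could call it on `probs[0]` for each prompt and decode each id with `tokenizer.decode([token_id])` to get two ranked lists to compare side by side.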

**D. How does the model handle unseen tokens?**

- Prompt the model with words that are not present in the training dataset, like `'photosynthesis'`.
- Visualize the probability distribution of the next token for the given prompt.
- Increase `num_next_words` to see more text.
- Inspect the generated text:

In [None]:
prompt = 'Photosynthesis is the process ' #@param {type: 'string'}
num_next_words = 10 #@param {type: 'number'}
generated_text, probs = generate_text(prompt, num_next_words, model=model, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, do_sample=True)
plot_next_token(probs[0], tokenizer, prompt=prompt)
print('\n')
print('Generated Text:', generated_text)

Did you notice that the word `Photosynthesis` was replaced with the unknown `'<UNK>'` token? Because the word does not appear in the training dataset, the tokenizer maps it to `'<UNK>'`, so the model never sees the word itself. The generated sentence has nothing to do with photosynthesis, since the dataset contains no information about it. The data used to train a model dictates the knowledge it can exhibit, based on the patterns it learned during training.
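
Because every out-of-vocabulary word is mapped to the same `<UNK>` id, the model receives identical input for "photosynthesis", "volcano", or any other unseen word. A hypothetical sketch of that lookup, assuming a whitespace split and a vocabulary list that contains `'<UNK>'`; `encode_words` and the tiny vocabulary are made up for illustration and are not the lab's `SimpleWordTokenizer` code:

```python
UNK = '<UNK>'


def encode_words(text: str, vocab: list[str]) -> list[int]:
    """Hypothetical whitespace tokenizer: unknown words map to the <UNK> id."""
    word_to_id = {word: i for i, word in enumerate(vocab)}
    unk_id = word_to_id[UNK]
    return [word_to_id.get(word, unk_id) for word in text.split()]


# A tiny made-up vocabulary; the lab's vocabulary comes from Africa Galore.
vocab = ['<PAD>', UNK, 'is', 'the', 'process']
print(encode_words('Photosynthesis is the process', vocab))  # [1, 2, 3, 4]
```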

## Experimentation and hyperparameter tuning
Above, you have used a few hyperparameters to train the SLM.

**What are hyperparameters?**

Hyperparameters are settings that you define before training a model. They control the learning process: unlike the model's parameters, they are not learned from the data but are set by you. In other words, hyperparameters guide how the model learns from the data. You can learn about them in more detail [here](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)).

- `batch_size`: The batch size determines how many samples are included in each batch for model update.

- `num_epochs`: The number of times the model goes through the entire training dataset. Because the model processes the training data in batches, each epoch consists of several iterations (steps): `num_examples / batch_size` of them, rounded up. In each iteration, the model calculates the loss on one batch and uses it to update its parameters. More epochs give the model more chances to refine its understanding of the task and data, but they also take more time to train.


Another hyperparameter you've seen is `maxlen`, which refers to the maximum number of tokens that each sequence is padded or truncated to.
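
A minimal sketch of that padding and truncation step, assuming padding is appended at the end with pad id 0 (as `generate_text` does) and that overly long sequences keep their first `maxlen` tokens. `pad_or_truncate` is a hypothetical helper; the lab's own preprocessing may differ in these details:

```python
def pad_or_truncate(token_ids: list[int], maxlen: int,
                    pad_id: int = 0) -> list[int]:
    """Returns a sequence of exactly maxlen token ids."""
    if len(token_ids) >= maxlen:
        # Too long: keep the first maxlen tokens.
        return token_ids[:maxlen]
    # Too short: append padding ids at the end.
    return token_ids + [pad_id] * (maxlen - len(token_ids))


print(pad_or_truncate([5, 8, 2], maxlen=5))           # [5, 8, 2, 0, 0]
print(pad_or_truncate([5, 8, 2, 9, 7, 4], maxlen=5))  # [5, 8, 2, 9, 7]
```

A larger `maxlen` keeps more of each paragraph but makes every sequence, and therefore every batch, more expensive to process.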

Now, adjust some of these hyperparameters and compare the results:

Your tasks:

1. Adjust the `batch_size` and `num_epochs` hyperparameters and try out different configurations. For example, you can try different batch sizes (e.g., 8, 16, 32, 64) and different numbers of epochs (e.g., 5, 10, 20).

2. Train the model with each configuration.

3. Create a table recording the final training loss for each configuration.

Below is an example of the table you should create.


|batch_size | num_epochs | train_loss |
|---|---|---|
|8 | 5 | 6.91|
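
Rather than copying losses by hand, you could collect them from the `History` object that `model.fit` returns: its `history` attribute is a dict with one list per metric, and `'loss'` holds the training loss for each epoch. A minimal sketch; `make_row` and `runs` are hypothetical names, not part of the lab:

```python
import pandas as pd


def make_row(batch_size: int, num_epochs: int, history_dict: dict) -> dict:
    """Builds one table row from the metrics dict returned by model.fit."""
    # history_dict['loss'] holds one training-loss value per epoch; keep the last.
    return {
        'batch_size': batch_size,
        'num_epochs': num_epochs,
        'train_loss': history_dict['loss'][-1],
    }


# Hypothetical usage: after each run of the training cell below, append a row,
#     runs.append(make_row(batch_size, num_epochs, history.history))
# and then display every configuration tried so far as a table.
runs: list[dict] = []
print(pd.DataFrame(runs, columns=['batch_size', 'num_epochs', 'train_loss']))
```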


Change the `batch_size` and `num_epochs`. Then, run the cell below to train the model:

In [None]:
# Prepare the dataset for training.
dataset = tf.data.Dataset.from_tensor_slices((input, target))

# Shuffle the examples in the dataset.
dataset = dataset.shuffle(buffer_size=len(input))

# Specify the batch size.
batch_size = 8  #@param {type: 'number'}

# Create batches of examples used to update the model parameters.
dataset = dataset.batch(batch_size)

model = create_model(maxlen=maxlen, vocab_size=tokenizer.vocab_size)

# Specify the number of epochs to train the model for.
num_epochs = 5  #@param {type: 'number'}

# Train the model.
history = model.fit(x=dataset, verbose=2,
                    epochs=num_epochs)

> Are you running into an "Out of Memory" error?

If you're getting an "Out of Memory" error, your system doesn't have enough memory to process the data. Consider trying the following:

1. Reduce the `maxlen`.

    Lower the maximum number of tokens in each sequence. Shorter sequences need less memory, and self-attention is particularly sensitive to sequence length: each attention head scores every pair of positions, so that part of the memory grows quadratically with `maxlen`.

2. Reduce the `batch_size`.

    A smaller batch means less data is processed at once, reducing memory requirements.
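
A rough back-of-the-envelope sketch of why both remedies help. It counts only the attention-score tensor of one attention layer (one `maxlen x maxlen` matrix per head, per example, stored as float32); real memory use also includes other activations, gradients, and optimizer state. The function name and the sizes used are placeholders, not the lab's settings:

```python
def attention_score_bytes(batch_size: int, num_heads: int, maxlen: int,
                          bytes_per_value: int = 4) -> int:
    """Size of one layer's attention-score tensor (float32 by default)."""
    # One (maxlen x maxlen) score matrix per head, per example in the batch.
    return batch_size * num_heads * maxlen * maxlen * bytes_per_value


base = attention_score_bytes(batch_size=8, num_heads=2, maxlen=100)
print(base)                                     # 640000
print(attention_score_bytes(8, 2, 50) / base)   # 0.25: halving maxlen
print(attention_score_bytes(4, 2, 100) / base)  # 0.5: halving batch_size
```

Halving `maxlen` cuts this tensor to a quarter, while halving `batch_size` only halves it.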

## Reflection

This is the end of Lab 5 - Training Your Own Small Language Model (SLM).

In this lab, you trained your first SLM by completing the following steps:

- **Tokenized the dataset:** You used the `SimpleWordTokenizer` from the previous lab to convert the text descriptions into numerical representations.

- **Padded the sequences:** You ensured all sequences have the same length by padding them with a special `'<PAD>'` token. This is necessary because the model processes each batch as a single fixed-shape tensor, so every sequence in it must have the same length.

- **Prepared the input and target data:** You created input-target pairs, where the target is the input sequence shifted by one token. This teaches the model to predict the next word based on the context (prior words).

- **Shuffled and batched the data:** You shuffled the dataset to randomize the training examples and grouped them into batches for efficient processing.

- **Trained the SLM:** You built and trained a small transformer model, observing how the training loss decreased over epochs.

- **Prompted the trained model:** You experimented with prompting the model, observing its ability to predict the likely next word, generate coherent text, adapt to changes in context, and handle unseen words (represented by `'<UNK>'`).

- **Experimented with hyperparameters:** You learned about hyperparameters such as `batch_size` and `num_epochs`, and explored their impact on training by trying different values.


The next section of the course delves deeper into model evaluation.