> <p><small><small>This Notebook is made available subject to the licence and terms set out in the <a href = "http://www.github.com/google-deepmind/ai-foundations">AI Research Foundations Github README file</a>.

![](https://storage.googleapis.com/dm-educational/assets/ai_foundations/GDM-Labs-banner-image-C4-white-bg.png)

# Lab: Trainable Parameters in the Transformer Model

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_4/gdm_lab_4_5_reflection_on_trainable_parameters.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

Explore how individual neural network components can be combined to assemble a full transformer model.

25 minutes

## Overview

In this lab, you will walk through a **full transformer implementation** in Keras and JAX, one component at a time. As you go through the implementation, you will define functions that compute the number of trainable parameters for each component to finally compute the number of trainable parameters of the entire transformer model.

This lab provides you with the opportunity to explore all of the details of a full transformer architecture. It further illustrates an important property of implementations of complex neural network models, namely **modularity**. Breaking up a complex architecture into smaller building blocks and then combining them makes the implementation a lot more manageable and makes individual components easy to adjust.

### What you will learn
By the end of this lab, you will:

- Understand which neural network components are required to assemble a transformer model.
- Understand how many trainable parameters each of these components has.
- Understand how hyperparameters such as the number of transformer blocks or the vocabulary size affect the overall number of model parameters.

### Tasks

In this lab, you will:

* Examine the following neural network components for assembling a transformer model:
    1. Layer normalization
    2. Embedding layer
    3. Multi-head attention
    4. Multi-layer perceptron
    5. Transformer block
    6. Output layer
    7. Full transformer model
* Implement a function that computes the number of trainable parameters for each of the components to gain a deeper understanding of each component.
* Modify the model's hyperparameters to see how changes in the size of different components impact the total parameter count.
* *(Optional)*: Train the model that is implemented in this lab on the **Africa Galore** dataset to verify that the implementation is working as expected.


## How to use Google Colaboratory (Colab)


Google Colaboratory (also known as Google Colab) is a platform that allows you to run Python code in your browser. The code is written in **cells** that are executed on a remote server.

To run a cell, hover over the cell and click on the `run` button to its left. The run button is the circle with the triangle (▶). Alternatively, you can also click on a cell and use the keyboard combination Ctrl+Return (or ⌘+Return if you are using a Mac).

To try this out, run the following cell. This should print today's day of the week below it.

In [None]:
from datetime import datetime

print(f"Today is {datetime.today():%A}.")

Note that the *order in which you run the cells matters*. When you are working through a lab, make sure to always run *all* cells in order, otherwise the code might not work. If you take a break while working on a lab, Colab may disconnect you and in that case, you have to execute all cells again before continuing your work. To make this easier, you can select the cell you are currently working on and then choose __Runtime → Run before__ from the menu above (or use the keyboard combination Ctrl/⌘ + F8). This will re-execute all cells before the current one.

## Imports

In this lab, you will mainly implement functions that perform simple computations and do not require any additional packages. The transformer model that is already implemented in this lab uses the Keras and JAX packages, and you will also use methods from the custom `ai_foundations` package to verify your implementation.

Run the following cell to import the required packages.

In [None]:
%%capture

import os # For setting Keras parameters.
os.environ["KERAS_BACKEND"] = "jax"

# Install the custom package for this course.
!pip install "git+https://github.com/google-deepmind/ai-foundations.git@main"

import jax # For working with vectors and matrices.
import jax.numpy as jnp # For working with vectors and matrices.
import pandas as pd # For loading the dataset.
import keras # For defining the transformer model.
import tqdm # For displaying progress bars.
from keras import layers # For defining the transformer model.
from ai_foundations import training # For training your transformer model.
from ai_foundations import generation # For prompting your transformer model.
from ai_foundations import tokenization # For loading the BPE tokenizer.
# For providing feedback.
from ai_foundations.feedback.course_4 import counting_parameters as feedback

## Components of the transformer

The following cells implement all the components of the transformer model in Keras. Walk through each component and, at the end of each one, implement the function that computes how many parameters are trained as part of that component.

### Layer normalization

The cell below implements the layer normalization component using Keras and JAX. Recall that the formula for layer normalization is:

$$ \mbox{LayerNorm}(\mathbf{x}) = \gamma \frac{\mathbf{x} - \mu(\mathbf{x})}{\sqrt{\mbox{Var}(\mathbf{x}) + \epsilon} } + \beta,$$

where $\mu(\mathbf{x})$ is the mean of elements in $\mathbf{x}$ and $\mbox{Var}(\mathbf{x})$ is the variance of the elements in $\mathbf{x}$. $\gamma$ (gamma) and $\beta$ (beta) are learnable parameters.

<br>

------
> **ℹ️ Info: Keras layers**
>
> To understand how individual neural network components can be implemented in Keras, note that all Keras layers consist of an `__init__` and a `call` method.
>
> **Initialization method (`__init__`)**:
>
>The initialization method is called with the options for the respective component. For example, in the case of layer normalization, the two options are the embedding size `embedding_dim` and the `epsilon` argument, which defines the constant that is added to the variance to avoid division-by-zero errors.
>
>Using the `embedding_dim` argument, the method initializes the $\gamma$ and $\beta$ vectors, which are the trainable parameters of this component.
>
>**`call` method**:
>
>The `call` method (also known as the forward function) is called whenever the model makes a prediction, either during training to compute the loss or during validation and testing to make predictions on data points that are not part of the training data. It defines how the input is transformed into the output of the component by using the component's parameters.
>
>Consider now the `call` function below. It is called with the input to the component `x`. It then outputs the layer normalized inputs.
>
>This function performs four steps:
>1. It computes the mean $\mu(\mathbf{x})$ across all features using `jnp.mean`. Note that the `axis` argument is set to `-1`. This tells the `jnp.mean` function to compute the mean across the last dimension of `x`, which by convention holds the features of a single data point.
>2. It computes the variance across all features using the formula $\mbox{Var}(\mathbf{x}) = \frac{1}{d_x}\sum_{i=1}^{d_x}\left(x_i-\mu(\mathbf{x})\right)^2$, where $d_x$ is the number of features in `x` and $x_i$ is its $i$-th element.
>3. It computes the normalized values of `x` using the mean and variance.
>4. It returns the normalized values combined with the scaling factor $\gamma$ and the shifting term $\beta$.
>
------

<br />
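
The steps listed above can be traced by hand on a single feature vector. The following is a minimal sketch in plain Python rather than JAX, so that every intermediate value is an ordinary float. The function name `layer_norm` and the toy input are assumptions made for illustration and are not part of the lab code.

```python
import math


def layer_norm(x, gamma, beta, epsilon=1e-6):
    """Layer-normalizes one feature vector `x` (a list of floats)."""
    d_x = len(x)
    # Step 1: mean across all features.
    mean = sum(x) / d_x
    # Step 2: variance across all features.
    variance = sum((x_i - mean) ** 2 for x_i in x) / d_x
    # Step 3: normalize using the mean and the variance.
    normalized = [(x_i - mean) / math.sqrt(variance + epsilon) for x_i in x]
    # Step 4: scale by gamma and shift by beta, element by element.
    return [g * n + b for g, n, b in zip(gamma, normalized, beta)]


x = [1.0, 2.0, 3.0, 4.0]
gamma = [1.0] * len(x)  # Same starting values as the 'ones' initializer.
beta = [0.0] * len(x)  # Same starting values as the 'zeros' initializer.
y = layer_norm(x, gamma, beta)
print([round(v, 3) for v in y])
```

With $\gamma$ set to ones and $\beta$ set to zeros, the output has a mean of approximately 0 and a variance of approximately 1.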

Run the following cell to define the `LayerNorm` component. Then walk through the code line by line.

In [None]:
class LayerNorm(layers.Layer):
    """A Keras implementation of Layer Normalization.

    This layer normalizes the activations of the previous layer for
    each given example in a batch independently, across the features dimension.

    Args:
      embedding_dim: The dimension of the output of the attention mechanism.
      epsilon: A small float added to variance to avoid dividing by zero.
    """

    def __init__(self, embedding_dim: int, epsilon: float = 1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

        # The parameters gamma and beta are vectors of size `embedding_dim`,
        # which is the size of the input embeddings and of the output of the
        # attention mechanism.
        shape = (embedding_dim,)

        # Initialize gamma (scale) as a vector of ones.
        self.gamma = self.add_weight(
            shape=shape,
            initializer='ones',
            name='gamma'
        )
        # Initialize beta (shift) as a vector of zeros.
        self.beta = self.add_weight(
            shape=shape,
            initializer='zeros',
            name='beta'
        )

    def call(self, x: jax.Array):
        """
        Applies the layer normalization logic.

        Args:
          x: The input tensor.
            Shape: (batch_size, sequence_length, embedding_dim).

        Returns:
          The normalized and transformed tensor.
            Shape: (batch_size, sequence_length, embedding_dim).
        """
        # Calculate mean and variance over the feature axis (-1).
        mean = jnp.mean(x, axis=-1, keepdims=True)
        variance = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)

        # Normalize the input.
        normalized_x = (x - mean) / jnp.sqrt(variance + self.epsilon)

        # Apply the learned scale (gamma) and shift (beta).
        return self.gamma * normalized_x + self.beta

How many trainable parameters does this component have? Here, and throughout this lab, assume that the model has been initialized with the following hyperparameters:

```python
{
    # The maximum number of input tokens in a sequence.
    "max_length": 128,
    # The dimension of the input embeddings, of the outputs of the
    # attention mechanism, and of the output of every transformer block.
    "embedding_dim": 256,
    # The dimension of the hidden MLP layer in each transformer block.
    "mlp_dim": 384,
    # The number of attention heads.
    "num_heads": 4,
    # The number of transformer blocks.
    "num_blocks": 2,
    # The number of unique tokens in the vocabulary.
    "vocabulary_size": 262144
}
```

The following cell defines these parameters and a function that computes the number of trainable parameters based on the model hyperparameters. To get you started, this cell is already complete. For the remaining components in this lab, you will have to complete the function for computing the trainable parameters.

In [None]:
# Define the model hyperparameters.
MODEL_HYPERPARAMETERS = {
    "max_length": 128,
    "embedding_dim": 256,
    "mlp_dim": 384,
    "num_heads": 4,
    "num_blocks": 2,
    "vocabulary_size": 262144
}


def parameter_count_layer_norm(hyperparams: dict[str, int]) -> int:
    embedding_dim = hyperparams["embedding_dim"]
    parameter_count = embedding_dim + embedding_dim
    return parameter_count


param_count = parameter_count_layer_norm(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print(f"Number of parameters for layer normalization: {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_layer_norm`
feedback.test_parameter_count_layer_norm(parameter_count_layer_norm)

### Coding Activity 1: Token and position embeddings

The following cell implements both the sinusoidal positional embeddings and a component to embed the input tokens. The sinusoidal positional embeddings do not contain any trainable parameters, so you may ignore the `_positional_encoding` method.
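
To see why sinusoidal encodings contribute nothing to the parameter count, it can help to compute them without any model at all. The sketch below is a plain-Python approximation of the `_positional_encoding` method of `TokenAndPositionEmbedding`. The standalone function `positional_encoding` and the toy sizes are assumptions for illustration only.

```python
import math


def positional_encoding(max_length, embedding_dim):
    """Returns a max_length x embedding_dim table of fixed encodings."""
    depth = embedding_dim // 2
    # One frequency per sine/cosine pair. None of these values is learned.
    angle_rates = [1 / (10000 ** (i / depth)) for i in range(depth)]
    table = []
    for position in range(max_length):
        angles = [position * rate for rate in angle_rates]
        sines = [math.sin(a) for a in angles]
        cosines = [math.cos(a) for a in angles]
        # Sine features first, then cosine features.
        table.append(sines + cosines)
    return table


# Toy sizes chosen for illustration only.
first = positional_encoding(max_length=4, embedding_dim=6)
second = positional_encoding(max_length=4, embedding_dim=6)
print(first == second)  # Recomputing yields the identical table.
print(first[0])  # Position 0: every sine is 0.0 and every cosine is 1.0.
```

The table depends only on `max_length` and `embedding_dim`, so there is nothing for the optimizer to update.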

<br />

------
> **💻 Your task:**
>
> Walk through the `TokenAndPositionEmbedding` implementation and then complete the `parameter_count_embedding` function to compute the number of trainable parameters in this component.
>
------

In [None]:
class TokenAndPositionEmbedding(layers.Layer):
    """
    A Keras layer that combines token and sinusoidal positional embeddings.

    Args:
      vocabulary_size: The size of the vocabulary.
      embedding_dim: The dimensionality of the embeddings. Must be even.
      max_length: The maximum length of the input sequences in tokens.
    """

    def __init__(
        self,
        vocabulary_size: int,
        embedding_dim: int,
        max_length: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size
        self.embedding_dim = embedding_dim
        self.max_length = max_length
        self.pos_encoding = self._positional_encoding()

        self.token_embeddings = layers.Embedding(
            input_dim=vocabulary_size, output_dim=embedding_dim, mask_zero=True
        )

    def _positional_encoding(self) -> jax.Array:
        """
        Creates a fixed sinusoidal positional encoding matrix.

        This function generates a unique positional representation for each
        token in a sequence using sine and cosine functions of different
        frequencies.

        Returns:
          A JAX array of shape (1, max_length, embedding_dim) for the positional
            encoding.
        """
        depth = self.embedding_dim // 2
        # Shape: (max_length, 1).
        positions = jnp.arange(self.max_length)[:, jnp.newaxis]
        # Shape: (1, depth).
        depths = jnp.arange(depth)[jnp.newaxis, :] / depth
        # Shape: (1, depth).
        angle_rates = 1 / (10000**depths)
        # Shape: (max_length, depth).
        angle_rads = positions * angle_rates
        # Shape: (max_length, embedding_dim).
        pos_encoding = jnp.concatenate(
            [jnp.sin(angle_rads), jnp.cos(angle_rads)], axis=-1
        )
        # Add a batch dimension for broadcasting.
        # Shape: (1, max_length, embedding_dim).
        return pos_encoding[jnp.newaxis, :, :]

    def call(self, x: jax.Array) -> jax.Array:
        """
        Applies the embedding layer.

        Args:
          x: Input tensor of token IDs. Shape: (batch_size, sequence_length).

        Returns:
          A JAX array of shape (batch_size, sequence_length, embedding_dim)
            representing the combined token and positional embeddings.
        """
        # Get token embeddings from the lookup table.
        # The input tensor `x` contains integer token IDs.
        token_embeddings = self.token_embeddings(x)

        # Scale token embeddings, as described in Vaswani et al., 2017.
        token_embeddings *= jnp.sqrt(self.embedding_dim)

        # Add the fixed positional embeddings.
        # The positional encoding has shape (1, max_length, embedding_dim). It
        # is sliced to the length of the input sequence and then broadcast
        # across the batch dimension of token_embeddings.
        sequence_length = x.shape[1]
        return token_embeddings + self.pos_encoding[:, :sequence_length, :]


In [None]:
# Implement the computation of the number of parameters in the embedding layer:
def parameter_count_embedding(hyperparams: dict[str, int]) -> int:

    parameter_count = ...

    return parameter_count


param_count = parameter_count_embedding(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print(f"Number of parameters for embedding component: {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_embedding`
feedback.test_parameter_count_embedding(parameter_count_embedding)

### Coding Activity 2: Multi-head attention

The following cell implements the complete multi-head attention mechanism.

<br>

------
> **💻 Your task:**
>
>Walk through this implementation and then complete the `parameter_count_attention` function.
>
>The `call` method here is complex because it is highly parallelized: the computations for all attention heads and all examples in a batch are performed at once using 4-dimensional tensors (arrays with four dimensions). To count the trainable parameters, you only have to understand the `__init__` method of `MultiHeadSelfAttention`, so it is okay if you are unsure what exactly is happening in the `call` method.
>
>
>**Hints:**
>
>
>- The layer normalization that follows the attention mechanism is defined in `TransformerBlock` (Coding Activity 4), not in `MultiHeadSelfAttention`. The trainable parameters of this component therefore come from the `Dense` layers created in its `__init__` method.
>
>- The queries, keys, and values are projected to a dimension of `embedding_dim / num_heads` per head. Since there are `num_heads` heads, the overall dimension of the projections across all heads is `embedding_dim / num_heads * num_heads = embedding_dim`. So for the purpose of computing the number of parameters, you may assume that there is only a single attention head that projects everything to a dimension of `embedding_dim`.
>
>- The implementation of `Dense` in Keras automatically adds bias terms. Therefore, if the input dimension that is passed to a `Dense` layer is $q$ and the output dimension is $r$, it will initialize a parameter matrix of dimension $q\times r$ and a bias vector of dimension $r$. This results in a total of $(q+1)\times r$ parameters.
>
-----
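
The last two hints can be checked with a few lines of arithmetic. The sketch below uses deliberately small toy sizes rather than the lab's hyperparameters, and the helper `dense_parameter_count` is a hypothetical name introduced only for this illustration.

```python
def dense_parameter_count(q: int, r: int) -> int:
    """Parameters of a dense layer with input size q and output size r."""
    weights = q * r  # One weight per (input, output) pair.
    biases = r  # One bias per output unit.
    return weights + biases


# Toy sizes for illustration only (not the lab's hyperparameters).
toy_embedding_dim = 8
toy_num_heads = 2
toy_head_dim = toy_embedding_dim // toy_num_heads

# One projection per head, each mapping toy_embedding_dim -> toy_head_dim ...
per_head = toy_num_heads * dense_parameter_count(
    toy_embedding_dim, toy_head_dim
)
# ... has as many parameters as one projection to the full toy_embedding_dim.
single = dense_parameter_count(toy_embedding_dim, toy_embedding_dim)

print(per_head, single)
```

Both counts agree, which is why treating the heads as one projection of size `embedding_dim` does not change the result.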

In [None]:
K_MASK = -2.3819763e38


class MultiHeadSelfAttention(layers.Layer):
    """A Keras layer for multi-head self-attention with causal masking.

    Args:
      embedding_dim: The dimensionality of the embeddings. Must be divisible by
        num_heads.
      num_heads: The number of attention heads.
      dropout_rate: The dropout rate to apply to the attention output.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        dropout_rate: float = 0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate

        # embedding_dim must be divisible by num_heads.
        assert (
            embedding_dim % num_heads == 0
        ), "embedding_dim must be divisible by num_heads"
        self.head_dim = embedding_dim // num_heads

        # Define projection layers for query, key, value, and the final output.
        self.q_dense = keras.layers.Dense(embedding_dim, name="q_projection")
        self.k_dense = keras.layers.Dense(embedding_dim, name="k_projection")
        self.v_dense = keras.layers.Dense(embedding_dim, name="v_projection")
        self.output_dense = keras.layers.Dense(embedding_dim,
                                               name="output_projection")

        # Dropout layer.
        self.dropout = layers.Dropout(rate=dropout_rate)

    def _split_heads(self, x: jax.Array) -> jax.Array:
        """Splits the last dimension into (num_heads, head_dim) and transposes
           the array such that the final shape is
           (batch_size, num_heads, sequence_length, head_dim).

        Args:
          x: Input tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
          jax.Array of shape (batch_size, num_heads, sequence_length, head_dim).
        """
        batch_size, sequence_length, _ = x.shape
        x = x.reshape(
            batch_size, sequence_length, self.num_heads, self.head_dim
        )
        return x.transpose((0, 2, 1, 3))

    def call(self, x: jax.Array, training: bool = False) -> jax.Array:
        """Forward pass for the MultiHeadSelfAttention layer.

        Args:
          x: Input tensor of shape (batch_size, sequence_length, embedding_dim).
          training: Python boolean indicating whether the layer should behave in
            training mode (apply dropout) or in inference mode.

        Returns:
          The output tensor of shape
            (batch_size, sequence_length, embedding_dim).
        """
        batch_size, sequence_length, _ = x.shape

        # Project inputs to Q, K, V.
        query = self.q_dense(x)  # (batch_size, sequence_length, embedding_dim).
        key = self.k_dense(x)  # (batch_size, sequence_length, embedding_dim).
        value = self.v_dense(x)  # (batch_size, sequence_length, embedding_dim).

        # Reshape for multi-head attention.

        # (batch, num_heads, sequence_length, head_dim).
        query = self._split_heads(query)

        # (batch, num_heads, sequence_length, head_dim).
        key = self._split_heads(key)

        # (batch, num_heads, sequence_length, head_dim).
        value = self._split_heads(value)

        # Scaled dot-product attention.
        # Matrix multiplication scores:
        # (..., sequence_length, d_k) x (..., d_k, sequence_length)
        #    -> (..., sequence_length, sequence_length).
        logits_raw = jnp.matmul(query, key.transpose((0, 1, 3, 2)))
        logits_raw /= jnp.sqrt(self.head_dim)

        # Apply causal (look-ahead) mask.
        # The mask ensures that each position attends only to itself and to
        # previous positions.
        causal_mask = jnp.tril(jnp.ones((sequence_length, sequence_length),
                                        dtype=bool))
        logits_masked = jnp.where(causal_mask, logits_raw, K_MASK)

        # Apply softmax to get attention weights.
        # (batch, num_heads, sequence_length, sequence_length).
        attention_weights = jax.nn.softmax(logits_masked, axis=-1)

        # Apply attention weights to values.
        # (batch, num_heads, sequence_length, head_dim).
        attention_output = jnp.matmul(attention_weights, value)

        # Concatenate heads and apply final projection.
        # Transpose back to (batch_size, sequence_length, num_heads, head_dim).
        attention_output = attention_output.transpose((0, 2, 1, 3))

        # Reshape to (batch_size, sequence_length, embedding_dim).
        attention_output = attention_output.reshape(
            batch_size, sequence_length, self.embedding_dim
        )
        attention_output = self.output_dense(attention_output)

        # Apply dropout.
        attention_output = self.dropout(attention_output, training=training)

        return attention_output
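
The causal masking step in the `call` method above can be looked at in isolation. The following is a minimal plain-Python sketch for a single head and a single example. The function `causal_softmax` and the toy score matrix are assumptions for illustration, and the masking itself involves no trainable parameters.

```python
import math

K_MASK = -2.3819763e38  # Same large negative constant as in the cell above.


def causal_softmax(logits):
    """Row-wise softmax over a square score matrix with a causal mask."""
    weights = []
    for i, row in enumerate(logits):
        # Position i may attend to positions 0..i. Later ones are masked.
        masked = [s if j <= i else K_MASK for j, s in enumerate(row)]
        # Subtract the row maximum for numerical stability.
        row_max = max(masked)
        exps = [math.exp(s - row_max) for s in masked]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Toy 3x3 score matrix with equal scores everywhere.
attention_weights = causal_softmax([[0.0] * 3 for _ in range(3)])
for row in attention_weights:
    print([round(w, 3) for w in row])
```

Each row sums to 1, and the entries above the diagonal are 0, so no position receives information from a later position.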

In [None]:
# Implement the computation of the number of parameters in the multi-head
# attention component here:
def parameter_count_attention(hyperparams: dict[str, int]) -> int:

    parameter_count = ...

    return parameter_count


param_count = parameter_count_attention(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print("Number of parameters for multi-head attention component:"
          f" {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_attention`
feedback.test_parameter_count_attention(parameter_count_attention)

### Coding Activity 3: MLP component

The following cell implements the MLP component of the transformer block. The layer normalization that follows the MLP is applied in the transformer block itself (Coding Activity 4).

<br>

------
> **💻 Your task:**
>
>Walk through this implementation and then complete the `parameter_count_mlp` function.
>
------

In [None]:
class MultiLayerPerceptron(layers.Layer):
    """Multi-layer perceptron component.

    This component implements a two-layer multi-layer perceptron. It introduces
    a non-linearity and improves the model's ability to learn complex patterns.

    Args:
      embedding_dim: The dimensionality of the embedding space.
      mlp_dim: The dimensionality of the hidden layer in the MLP component
        (often larger than embedding_dim).
      dropout_rate: The dropout rate applied to the output of the MLP
        component.
      activation: The activation function used in the first dense layer.

    Returns:
      Output tensor of shape (batch_size, sequence_length, embedding_dim)
        after applying the MLP.
    """

    def __init__(self,
                 embedding_dim: int,
                 mlp_dim: int,
                 dropout_rate: float = 0.0,
                 activation: str = "relu",
                 **kwargs: dict):
        super().__init__(**kwargs)
        # Define a two-layer MLP.
        self.mlp = keras.Sequential([
            # Hidden layer.
            layers.Dense(mlp_dim, activation=activation),
            # Output layer.
            layers.Dense(embedding_dim)
        ])
        self.dropout = layers.Dropout(dropout_rate)

    def call(self, x: jax.Array) -> jax.Array:
        """Applies the MLP to the input tensor.

        Args:
          x: Input tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
          Output tensor of shape (batch_size, sequence_length, embedding_dim).
        """
        # Shape: (batch_size, sequence_length, embedding_dim).
        mlp_output = self.mlp(x)

        # Apply dropout.
        # Shape: (batch_size, sequence_length, embedding_dim).
        mlp_output = self.dropout(mlp_output)

        return mlp_output

In [None]:
# Implement the computation of the number of parameters in the MLP component
# here:
def parameter_count_mlp(hyperparams):

    parameter_count = ...

    return parameter_count


param_count = parameter_count_mlp(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print(f"Number of parameters for MLP component: {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_mlp`
feedback.test_parameter_count_mlp(parameter_count_mlp)

### Coding Activity 4: Transformer block

The following cell implements the transformer block that first passes the input through the multi-head attention mechanism and then passes it through the MLP component.

<br>

------
> **💻 Your task:**
>
>Walk through this implementation and then complete the `parameter_count_transformer_block` function.
>
>
>**Hints:**
>- This block includes both a multi-head attention component and an MLP component. Use your `parameter_count_attention` and `parameter_count_mlp` functions from above to compute the number of parameters for these components.
>- This block includes layer normalization after both the multi-head attention component and the MLP component. Use your `parameter_count_layer_norm` function from above to compute the number of parameters for these components.
>
------
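
Before reading the block's code, it may help to see its data flow with the learned parts stripped away. The sketch below is plain Python on a single vector. The helpers `normalize`, `add_and_norm`, and `block`, as well as the two toy sublayers, are hypothetical stand-ins and not the lab's classes.

```python
import math


def normalize(x, epsilon=1e-6):
    """Layer normalization with gamma = 1 and beta = 0 (initial values)."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(variance + epsilon) for v in x]


def add_and_norm(x, sublayer):
    """Applies `sublayer`, adds the residual connection, then normalizes."""
    sublayer_output = sublayer(x)
    # The residual connection is an element-wise sum without any weights.
    summed = [out + inp for out, inp in zip(sublayer_output, x)]
    return normalize(summed)


def block(x, attention, mlp):
    """Data flow of one block: attention, then MLP, each with add-and-norm."""
    x = add_and_norm(x, attention)
    return add_and_norm(x, mlp)


def toy_attention(x):
    """Hypothetical stand-in for the attention component."""
    return [2.0 * v for v in x]


def toy_mlp(x):
    """Hypothetical stand-in for the MLP component."""
    return [v + 1.0 for v in x]


output = block([1.0, 2.0, 3.0], toy_attention, toy_mlp)
print([round(v, 3) for v in output])
```

The residual additions contribute no trainable parameters, so only the two sublayers and the two layer normalizations need to be counted.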



In [None]:
class TransformerBlock(layers.Layer):
  """A single transformer block.

    The transformer block is a fundamental component of the transformer
    architecture, which is commonly used for sequence-based tasks. It consists
    of a MultiHeadSelfAttention layer followed by an MLP,
    with layer normalization and dropout applied at each step.

    Example:
      transformer_block = TransformerBlock(embedding_dim=256, num_heads=8,
                                           mlp_dim=1024)
      output = transformer_block(inputs)

    Args:
      embedding_dim: The dimensionality of the input embedding (also the output
        size of the attention layer).
      num_heads: The number of attention heads in the multi-head attention
        mechanism.
      mlp_dim: The number of units in the MLP.
      dropout_rate: Dropout rate, between 0 and 1.
      activation: The activation function to use in the MLP.
    """

  def __init__(self,
               embedding_dim: int,
               num_heads: int,
               mlp_dim: int,
               dropout_rate: float = 0.0,
               activation: str = "relu",
               **kwargs: dict):
    super().__init__(**kwargs)

    self.self_attention = MultiHeadSelfAttention(embedding_dim,
                                                 num_heads,
                                                 dropout_rate)

    self.layer_norm_attention = LayerNorm(embedding_dim)
    self.feed_forward = MultiLayerPerceptron(embedding_dim,
                                             mlp_dim,
                                             dropout_rate,
                                             activation)
    self.layer_norm_mlp = LayerNorm(embedding_dim)

  def call(self, x: jax.Array) -> jax.Array:
    """Applies a single transformer block to the input tensor.

    Args:
      x: The input tensor of shape (batch_size, sequence_length, embedding_dim).

    Returns:
      The output tensor of shape (batch_size, sequence_length, embedding_dim)
        after applying the transformer block.
    """
    # Apply masked self-attention.
    # Shape: (batch_size, sequence_length, embedding_dim).
    attention_output = self.self_attention(x)

    # Add residual connection.
    attention_output = attention_output + x

    # Apply layer normalization.
    attention_output = self.layer_norm_attention(attention_output)

    # Multi-layer perceptron applied to attention output.
    # Shape: (batch_size, sequence_length, embedding_dim).
    mlp_output = self.feed_forward(attention_output)

    # Add residual connection.
    # Shape: (batch_size, sequence_length, embedding_dim).
    mlp_output = mlp_output + attention_output

    # Apply layer normalization.
    # Shape: (batch_size, sequence_length, embedding_dim).
    mlp_output = self.layer_norm_mlp(mlp_output)

    return mlp_output

In [None]:
# Implement the computation of the number of parameters in
# one transformer block here:
def parameter_count_transformer_block(hyperparams):

    parameter_count = ...

    return parameter_count


param_count = parameter_count_transformer_block(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print(f"Number of parameters for one transformer block: {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_transformer_block`
feedback.test_parameter_count_transformer_block(parameter_count_transformer_block)

### Coding Activity 5: The output layer

The following cell implements the last component that is needed, namely the output layer. This layer projects each position's representation to one logit per token in the vocabulary. Applying a softmax to these logits yields the probability distribution over the next token.

<br>

------
> **💻 Your task:**
>
>Walk through this implementation and then complete the `parameter_count_output_layer` function.
>
------
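
The `call` method of `OutputLayer` returns logits, that is, one unnormalized score per vocabulary token. A softmax over those logits gives the next-token probabilities. The following minimal plain-Python sketch shows that conversion. The `softmax` helper and the toy logits are assumptions for illustration.

```python
import math


def softmax(logits):
    """Converts a list of logits into probabilities that sum to 1."""
    largest = max(logits)  # Subtracted for numerical stability.
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for a toy vocabulary of four tokens.
toy_logits = [2.0, 1.0, 0.0, -1.0]
probabilities = softmax(toy_logits)
print([round(p, 3) for p in probabilities])
print(sum(probabilities))
```

The softmax has no weights of its own, so the trainable parameters of the output layer come entirely from its `Dense` projection.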

In [None]:
class OutputLayer(keras.layers.Layer):
    """
    A layer to compute the logits over the vocabulary.

    This layer projects the input tensor to the vocabulary size. It returns
    raw logits; no softmax or log-softmax is applied inside the layer.

    Args:
        vocabulary_size: The size of the vocabulary.
    """
    def __init__(self, vocabulary_size: int, **kwargs):
        super().__init__(**kwargs)
        self.vocabulary_size = vocabulary_size

        # The dense layer projects the input from embedding_dim to
        # vocabulary_size.
        self.output_layer = keras.layers.Dense(vocabulary_size,
                                        name="output_projection")

    def call(self, x: jax.Array) -> jax.Array:
        """
        Forward pass for the OutputLayer.

        Args:
          x: Input tensor of shape (batch_size, sequence_length, embedding_dim).

        Returns:
          The logits for each token in the vocabulary, with shape
            (batch_size, sequence_length, vocabulary_size).
        """
        # Project the embedding_dim dimension to the vocabulary size.
        logits = self.output_layer(x)

        return logits

In [None]:
# Implement the computation of the number of parameters in
# the output layer here:
def parameter_count_output_layer(hyperparams):

    parameter_count = ...

    return parameter_count


param_count = parameter_count_output_layer(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print(f"Number of parameters for output layer: {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_output_layer`
feedback.test_parameter_count_output_layer(parameter_count_output_layer)

### Coding Activity 6: Putting it all together

With all of the components in place, the following cell implements the final transformer model.

It consists of three components:

1. The input embedding layer
2. A stack of `num_blocks` transformer blocks
3. The output layer

<br>

------
> **💻 Your task:**
>
>Walk through this implementation and then complete the `parameter_count_transformer` function.
>
>As above, use the existing functions to compute the number of parameters for the individual components.
>
------


In [None]:
class TransformerModel(layers.Layer):
    """Implements the full transformer model in Keras.

    Args:
      vocabulary_size: The size of the vocabulary, i.e., the number of unique
        tokens.
      max_length: The maximum length of the input sequences.
      embedding_dim: The dimensionality of the embedding space.
      mlp_dim: The number of units in the MLP of each transformer block.
      num_heads: The number of attention heads in the multi-head attention
        mechanism.
      num_blocks: The number of transformer blocks to stack in the model.
      dropout_rate: The dropout rate to prevent overfitting.
      activation: The activation function to use in the MLP of each transformer
        block.
    """

    def __init__(self,
                 vocabulary_size: int,
                 max_length: int,
                 embedding_dim: int = 256,
                 mlp_dim: int = 256,
                 num_heads: int = 2,
                 num_blocks: int = 1,
                 dropout_rate: float = 0.0,
                 activation: str = "relu",
                 **kwargs):
        super().__init__(**kwargs)

        # Create an embedding layer that combines token and positional
        # embeddings.
        self.embedding_layer = TokenAndPositionEmbedding(vocabulary_size,
                                                         embedding_dim,
                                                         max_length)

        # Create a stack of transformer blocks.
        self.transformer_blocks = keras.Sequential()
        for _ in range(num_blocks):
            self.transformer_blocks.add(TransformerBlock(
                embedding_dim,
                num_heads,
                mlp_dim,
                dropout_rate=dropout_rate,
                activation=activation))

        # Create output layer.
        self.output_layer = OutputLayer(vocabulary_size)


    def call(self, x: jax.Array) -> jax.Array:

        # Embed input tokens.
        # Shape: (batch_size, sequence_length, embedding_dim).
        x = self.embedding_layer(x)

        # Apply the stack of transformer blocks.
        # Shape: (batch_size, sequence_length, embedding_dim).
        x = self.transformer_blocks(x)

        # Compute output (logits for each token in the vocabulary).
        # Shape: (batch_size, sequence_length, vocabulary_size).
        output = self.output_layer(x)

        return output

In [None]:
# Implement the computation of the total number of parameters in the
# transformer here:
def parameter_count_transformer(hyperparams):

    parameter_count = ...

    return parameter_count


param_count = parameter_count_transformer(MODEL_HYPERPARAMETERS)
if param_count != ...:
    print(f"Number of parameters for full transformer model: {param_count:,}")

In [None]:
# @title Run this cell to check your implementation of `parameter_count_transformer`
feedback.test_parameter_count_transformer(parameter_count_transformer)

You have now walked through the full implementation of the transformer model by exploring each of its components and how they are combined.

As you observed, this model has roughly 135 million trainable parameters, which is a considerable number. On every training step, gradients have to be computed for all of these parameters, and the optimizer then has to update every weight. This is why training usually takes a lot of time. Note, however, that this model is still much smaller than a model such as Gemma-1B.
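To get a feel for what this parameter count means during training, here is a rough, back-of-the-envelope estimate of the memory needed just for the weights, their gradients, and the optimizer state. It assumes 32-bit floats, the Adam optimizer used later in this lab (which keeps two moment estimates per parameter), and the approximate figure of 135 million parameters. Activations and framework overhead are ignored, so real usage will be higher.

```python
# Rough memory estimate for training; all values are approximations.
N_PARAMS = 135_000_000   # Approximate parameter count quoted in the text.
BYTES_PER_FLOAT32 = 4    # Assumption: parameters are stored as 32-bit floats.

weights_bytes = N_PARAMS * BYTES_PER_FLOAT32
gradients_bytes = N_PARAMS * BYTES_PER_FLOAT32
# Adam keeps a first and a second moment estimate for every parameter.
adam_state_bytes = 2 * N_PARAMS * BYTES_PER_FLOAT32

total_bytes = weights_bytes + gradients_bytes + adam_state_bytes
print(f"Weights:    {weights_bytes / 1e9:.2f} GB")
print(f"Gradients:  {gradients_bytes / 1e9:.2f} GB")
print(f"Adam state: {adam_state_bytes / 1e9:.2f} GB")
print(f"Total:      {total_bytes / 1e9:.2f} GB")
```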

Now go back through the number of parameters of each component. Which components introduce a lot of parameters? Which ones contribute relatively few?

Then investigate which hyperparameters have a big effect on the number of parameters of the model. Edit the hyperparameters below and observe how the parameter count of each component changes.


In [None]:
# @title Compute trainable parameters

max_length = 128 # @param {"type": "number"}
embedding_dim = 256 # @param {"type": "number"}
mlp_dim = 384 # @param {"type": "number"}
num_heads = 4 # @param {"type": "number"}
num_blocks = 2 # @param {"type": "number"}
vocabulary_size = 262144 # @param {"type": "number"}

your_hyperparameters = {
    "max_length": max_length,
    "embedding_dim": embedding_dim,
    "mlp_dim": mlp_dim,
    "num_heads": num_heads,
    "num_blocks": num_blocks,
    "vocabulary_size": vocabulary_size
}

# Compute parameter counts.
parameter_counts = {
    "embedding": parameter_count_embedding(your_hyperparameters)
}

for i in range(your_hyperparameters["num_blocks"]):
    parameter_counts[f"transformer_block_{i}"] = parameter_count_transformer_block(your_hyperparameters)
    parameter_counts[f"  attention_{i}"] = parameter_count_attention(your_hyperparameters)
    parameter_counts[f"  layer_norm_attention_{i}"] = parameter_count_layer_norm(your_hyperparameters)
    parameter_counts[f"  mlp_{i}"] = parameter_count_mlp(your_hyperparameters)
    parameter_counts[f"  layer_norm_mlp_{i}"] = parameter_count_layer_norm(your_hyperparameters)

parameter_counts["output_layer"] = parameter_count_output_layer(your_hyperparameters)

# Print the parameter counts.
total = parameter_count_transformer(your_hyperparameters)
max_key_width = len(max(parameter_counts.keys(), key=len))
max_value_width = 80 - max_key_width

separator_length = max_key_width + 2 + max_value_width
print("Parameters of each component:")
print("-" * separator_length)
print(f"{'Component':<{max_key_width}}  {'Parameters':>{max_value_width}}")
for key, value in parameter_counts.items():
    if not key.startswith(" "):
        print("-" * separator_length)
        print(f"{key:<{max_key_width}}  {value:>{max_value_width},}")
        if key.startswith("transformer_block"):
            print("~" * separator_length)
    else:
        print(f"{key:<{max_key_width}}  {value:>{max_value_width},}")

print("-" * separator_length)
print(f"{'Total':<{max_key_width}}  {total:>{max_value_width},}")
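To compare the effect of individual hyperparameters side by side, the following self-contained sketch recomputes the total with the same formulas as the reference solutions at the end of this notebook and changes one hyperparameter at a time. The function name `total_parameters` and the chosen variations are illustrative and not part of the lab code.

```python
def total_parameters(vocabulary_size, embedding_dim, mlp_dim, num_blocks):
    """Total parameter count, following the reference solution formulas."""
    embedding = vocabulary_size * embedding_dim
    # Query, key, value, and output projections, each with a bias.
    attention = 4 * (embedding_dim + 1) * embedding_dim
    # Two dense layers with biases.
    mlp = (embedding_dim + 1) * mlp_dim + (mlp_dim + 1) * embedding_dim
    # Layer norm count as used in the reference solution for a block.
    layer_norms = 2 * embedding_dim
    block = attention + mlp + layer_norms
    output = (embedding_dim + 1) * vocabulary_size
    return embedding + num_blocks * block + output


baseline = {
    "vocabulary_size": 262144,
    "embedding_dim": 256,
    "mlp_dim": 384,
    "num_blocks": 2,
}

variations = {
    "baseline": baseline,
    "num_blocks x 2": {**baseline, "num_blocks": 4},
    "mlp_dim x 2": {**baseline, "mlp_dim": 768},
    "vocabulary_size / 2": {**baseline, "vocabulary_size": 131072},
}

totals = {name: total_parameters(**params) for name, params in variations.items()}
for name, total in totals.items():
    print(f"{name:<22}{total:>15,}")
```

Comparing the printed rows shows which changes barely move the total and which ones change it substantially.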

## Optional: Training the model

As a last optional exercise, if you would like to see this model in action, you can run the following cells to load the Africa Galore dataset, tokenize and pad the data, and train the model. Training takes about one minute on a Colab instance with a GPU or about 10 minutes on a Colab instance with a CPU.

You can then sample continuations to a prompt from the model in the cell after the training loop.



### Load and tokenize the dataset

Run the following cell to load the Africa Galore dataset and tokenize it with the Byte Pair Encoding tokenizer from the previous courses.

In [None]:
# Load the dataset and the tokenizer.

africa_galore = pd.read_json(
    "https://storage.googleapis.com/dm-educational/assets/ai_foundations/africa_galore.json"
)
dataset = africa_galore["description"].values
print("Loaded dataset with", dataset.shape[0], "paragraphs.")
BPEWordTokenizer = tokenization.BPEWordTokenizer
tokenizer = BPEWordTokenizer.from_url("https://storage.googleapis.com/dm-educational/assets/ai_foundations/bpe_tokenizer_3000_v2.pkl")

encoded_tokens = []
for paragraph in tqdm.tqdm(dataset, unit="paragraphs"):
    encoded_tokens.append(tokenizer.encode(paragraph))

### Prepare the dataset for training

Run the following cell to pad and truncate the paragraphs and prepare the input and target sequences for training your model.

In [None]:
max_length = 300
padded_sequences = keras.preprocessing.sequence.pad_sequences(
        encoded_tokens,
        maxlen=max_length,
        padding="post",
        truncating="post",
        value=tokenizer.pad_token_id,
    )
# Prepare input and target for the transformer model.
# For each example, extract all tokens except the last one.
input_sequences = padded_sequences[:, :-1]
# For each example, extract all tokens except the first one.
target_sequences = padded_sequences[:, 1:]

max_length = input_sequences.shape[1]
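The shift by one position between inputs and targets is easiest to see on a toy example. The token IDs below are made up, and plain Python lists stand in for the padded array used in the notebook.

```python
# Toy padded batch with made-up token IDs; 0 stands in for the padding ID.
padded_batch = [
    [11, 12, 13, 14, 0],
    [21, 22, 23, 0, 0],
]

# Inputs: every token except the last one.
inputs = [sequence[:-1] for sequence in padded_batch]
# Targets: every token except the first one.
targets = [sequence[1:] for sequence in padded_batch]

for input_ids, target_ids in zip(inputs, targets):
    # At position t, the model has seen input_ids[:t + 1] and is trained to
    # predict target_ids[t], which is the next token in the original sequence.
    print(input_ids, "->", target_ids)
```

Both lists are one token shorter than the padded sequences, which is why `max_length` is recomputed from the shape of the input sequences.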

### Define and train the model

The following cell defines the transformer model using the `TransformerModel` class that you defined above. It also initializes the optimizer (Adam) and the loss function (a multi-class cross-entropy loss), and attaches both of these training components to the model. Finally, it initializes a callback that prints a generation every tenth epoch, which allows you to monitor the training progress.

In [None]:
# Set a seed for reproducibility.
keras.utils.set_random_seed(3112)

# Initialize the transformer model that you defined above.
transformer = TransformerModel(
    tokenizer.vocabulary_size,
    max_length=max_length,
    dropout_rate=0.1,
    num_blocks=2,
    embedding_dim=64,
    mlp_dim=128,
)

# Build the Keras model.
input_layer = keras.Input((max_length,))
output_layer = transformer(input_layer)
model = keras.Model(input_layer, output_layer)

# Initialize the optimizer.
optimizer = keras.optimizers.Adam(learning_rate=2.5e-3)
# Initialize the loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(
    # The output layer outputs raw logits rather than probabilities computed
    # through the softmax to improve efficiency and avoid very small numbers.
    from_logits=True,
    ignore_class=tokenizer.pad_token_id,
    # Average the loss across the batch size.
    reduction="sum_over_batch_size",
)

# Attach the optimizer and loss function to the model.
model.compile(optimizer=optimizer, loss=loss_fn)

# Initialize a callback function that prints a generation every 10 epochs.
prompt = "Jide"
prompt_ids = tokenizer.encode(prompt)
text_gen_callback = training.TextGenerator(
    max_tokens=11, start_tokens=prompt_ids, tokenizer=tokenizer, print_every=10
)
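The comment on `from_logits=True` refers to numerical stability. As a minimal sketch in plain Python (not the actual Keras implementation), a per-token cross-entropy loss can be computed directly from raw logits with the log-sum-exp trick, while positions whose target is the padding token are skipped. The helper names and toy values are hypothetical, and the final averaging step is left out.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax using the log-sum-exp trick."""
    max_logit = max(logits)
    log_sum = max_logit + math.log(
        sum(math.exp(value - max_logit) for value in logits)
    )
    return [value - log_sum for value in logits]


def token_losses(logits_per_position, targets, pad_token_id):
    """Returns one loss per position, or None where the target is padding."""
    losses = []
    for logits, target in zip(logits_per_position, targets):
        if target == pad_token_id:
            losses.append(None)  # Ignored, similar in spirit to `ignore_class`.
        else:
            losses.append(-log_softmax(logits)[target])
    return losses


toy_logits = [
    [0.0, 0.0, 0.0, 0.0],      # Uniform scores over four tokens.
    [1000.0, 0.0, 0.0, 0.0],   # Very confident; naive exp(1000) would overflow.
    [1.0, 2.0, 3.0, 4.0],      # Position whose target is padding.
]
toy_targets = [2, 0, 3]
losses = token_losses(toy_logits, toy_targets, pad_token_id=3)
print(losses)
```

Subtracting the largest logit before exponentiating keeps every exponent at or below zero, so the computation neither overflows nor has to form very small probabilities explicitly.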

Run the following cell to use the `fit` method to train your model for 100 epochs.

In [None]:
model.fit(
    x=input_sequences,
    y=target_sequences,
    batch_size=32,
    epochs=100,
    callbacks=[text_gen_callback],
)

### Prompt your model

Finally, you can use the following cell to prompt your transformer model.

In [None]:
# Use this cell to generate new texts.
# You can edit the prompt variable to modify the input prompt.
prompt = "Jide was thirsty so she went looking for a"

# Greedy sampling.
generated_text, _ = generation.generate_text(
    prompt,
    n_tokens=1,
    model=model,
    tokenizer=tokenizer,
    sampling_mode="greedy",
)
print(f"Generated text: {generated_text}")
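Greedy sampling always picks the token with the highest score. The following is a minimal sketch of one such step, independent of the `generation.generate_text` helper, whose internals are not shown in this lab. The toy vocabulary and logits are invented for illustration.

```python
def greedy_next_token(logits):
    """Returns the index of the largest logit (ties resolve to the first)."""
    return max(range(len(logits)), key=lambda index: logits[index])


# Hypothetical four-token vocabulary with made-up scores.
toy_vocabulary = ["water", "well", "river", "market"]
toy_logits = [0.3, 2.1, 1.7, -0.5]

next_id = greedy_next_token(toy_logits)
print(toy_vocabulary[next_id])  # Prints "well".
```

Because the choice is deterministic, the same prompt always leads to the same continuation under greedy sampling.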

## Summary

In this activity you explored all the **individual components of the transformer model**. You have seen how they are implemented, how many trainable parameters they have, and how different model hyperparameters affect the number of trainable parameters.

This lab has also illustrated how splitting up a model into its individual components makes it more manageable to implement and maintain the code for complex models. Instead of implementing every part of the transformer in one gigantic Keras model, you have seen that it is possible to combine smaller building blocks, each of which is less complex. This type of **modularity** is an important property of frameworks like Keras.



## Solutions

The following cells provide reference solutions to the coding activities in this notebook. If you really get stuck after trying to solve the activities yourself, you may want to consult these solutions.

It is recommended that you *only* look at the solutions after you have tried to solve the activities *multiple times*. The best way to learn challenging concepts in computer science and artificial intelligence is to debug your code piece-by-piece until it works, rather than copying existing solutions.

If you feel stuck, you may want to first try to debug your code, for example by adding print statements to see what your code is doing at every step. This will provide you with a much deeper understanding of the code and the materials. It will also provide you with practice on how to solve challenging coding problems beyond this course.

To view the solutions for an activity, click on the arrow to the left of the activity name. If you consult the solutions, do not copy and paste them into the cells above. Instead, look at them, and type them manually into the cell. This will help you understand where you went wrong.


### Coding Activity 1

In [None]:
def parameter_count_embedding(hyperparams: dict[str, int]) -> int:
    """Computes parameters for the token embedding matrix.

    The embedding matrix has shape `vocabulary_size x embedding_dim`.

    Args:
      hyperparams: Model hyperparameters. Expects `"vocabulary_size"` and
        `"embedding_dim"`.

    Returns:
      Total number of trainable parameters for the embedding layer.
    """

    vocabulary_size = hyperparams["vocabulary_size"]
    embedding_dim = hyperparams["embedding_dim"]
    # The embedding matrix is of size `vocabulary_size` x `embedding_dim`.
    parameter_count = vocabulary_size * embedding_dim
    return parameter_count

### Coding Activity 2

In [None]:
def parameter_count_attention(hyperparams: dict[str, int]) -> int:
    """Computes parameters for a multi-head attention sublayer with LayerNorm.

    Counts parameters for the query, key, value, and output linear projections.
    Each is modeled as a dense layer with a weight matrix of shape
    `embedding_dim x embedding_dim` plus a bias vector of size `embedding_dim`.

    Args:
      hyperparams: Model hyperparameters. Expects `"embedding_dim"`.

    Returns:
      Total number of trainable parameters for the attention sublayer.
    """

    embedding_dim = hyperparams["embedding_dim"]
    # Parameters for query projection.
    # Note that for the query, key, and value projections, the output dimension
    # is d_head * num_heads, which equals embedding_dim here.
    q_parameter_count = (embedding_dim + 1) * embedding_dim
    # Parameters for key projection.
    k_parameter_count = (embedding_dim + 1) * embedding_dim

    # Parameters for value projection.
    v_parameter_count = (embedding_dim + 1) * embedding_dim

    # Parameters for output projection.
    o_parameter_count = (embedding_dim + 1) * embedding_dim

    parameter_count = (
        q_parameter_count
        + k_parameter_count
        + v_parameter_count
        + o_parameter_count
    )

    return parameter_count
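The formula above does not depend on `num_heads`. One way to convince yourself is to count the parameters head by head. The sketch below assumes that `embedding_dim` is divisible by `num_heads` and that each head has dimension `d_head = embedding_dim // num_heads`. The function name is hypothetical.

```python
def attention_parameters_per_head_view(embedding_dim, num_heads):
    """Counts attention parameters by summing over individual heads."""
    # Assumption: embedding_dim is evenly divisible by num_heads.
    d_head = embedding_dim // num_heads
    # Each head has its own query, key, and value projection from
    # embedding_dim to d_head, each with a bias of size d_head.
    per_head = 3 * (embedding_dim + 1) * d_head
    # The output projection maps the concatenated heads back to embedding_dim.
    output_projection = (num_heads * d_head + 1) * embedding_dim
    return num_heads * per_head + output_projection


counts = {
    num_heads: attention_parameters_per_head_view(256, num_heads)
    for num_heads in (1, 2, 4, 8)
}
for num_heads, count in counts.items():
    print(f"num_heads={num_heads}: {count:,}")
```

More heads means smaller heads, so the total stays the same as long as `d_head * num_heads` equals `embedding_dim`.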

### Coding Activity 3

In [None]:
def parameter_count_mlp(hyperparams: dict[str, int]) -> int:
    """Computes parameters for the MLP component.

    The MLP is modeled as two dense layers with biases:
    - First projection: `embedding_dim -> mlp_dim`
    - Second projection: `mlp_dim -> embedding_dim`

    Args:
      hyperparams: Model hyperparameters. Expects `"embedding_dim"` and
          `"mlp_dim"`.

    Returns:
      Total number of trainable parameters for the MLP sublayer.
    """

    embedding_dim = hyperparams["embedding_dim"]
    mlp_dim = hyperparams["mlp_dim"]

    # Parameters for first projection component.
    ffn_parameter_count = (embedding_dim + 1) * mlp_dim
    # Parameters for second projection component.
    output_parameter_count = (mlp_dim + 1) * embedding_dim

    parameter_count = (
        ffn_parameter_count + output_parameter_count
    )
    return parameter_count

### Coding Activity 4

In [None]:
def parameter_count_transformer_block(hyperparams: dict[str, int]) -> int:
    """Computes parameters for a transformer block (attention + MLP).

    Sums the parameters from the multi-head attention component (plus its
    LayerNorm) and the MLP component (plus its LayerNorm).

    Args:
      hyperparams: Model hyperparameters needed by the attention and MLP
        parameter count calculators. Typically includes `"embedding_dim"`,
        `"mlp_dim"`, and potentially others.

    Returns:
      Total number of trainable parameters for one transformer block.
    """

    embedding_dim = hyperparams["embedding_dim"]

    # Parameters for multi-head attention mechanism.
    mha_parameter_count = parameter_count_attention(hyperparams)

    # Parameters for MLP component.
    mlp_parameter_count = parameter_count_mlp(hyperparams)

    # Parameters for two layer norm components.
    layer_norm_parameter_count = 2 * embedding_dim

    parameter_count = (
        mha_parameter_count + mlp_parameter_count + layer_norm_parameter_count
    )
    return parameter_count

### Coding Activity 5

In [None]:
def parameter_count_output_layer(hyperparams: dict[str, int]) -> int:
    """Computes parameters for the output projection layer.

    The output projection maps from `embedding_dim` to `vocabulary_size` and
    includes a bias term for each vocabulary entry.

    Args:
      hyperparams: Model hyperparameters. Expects `"vocabulary_size"` and
        `"embedding_dim"`.

    Returns:
      Total number of trainable parameters for the output layer.
    """

    embedding_dim = hyperparams["embedding_dim"]
    vocabulary_size = hyperparams["vocabulary_size"]

    # Parameters for output projection.
    output_parameter_count = (embedding_dim + 1) * vocabulary_size

    # Only the projection component has parameters,
    # the activation function does not.
    parameter_count = output_parameter_count

    return parameter_count

### Coding Activity 6

In [None]:
def parameter_count_transformer(hyperparams: dict[str, int]) -> int:
    """Computes parameters for an entire transformer model consisting of
    `num_blocks` blocks.

    Args:
      hyperparams: Model hyperparameters needed by the transformer block
        parameter count calculator, including `"embedding_dim"`,
        `"vocabulary_size"`, `"mlp_dim"`, `"num_blocks"`, and potentially others.

    Returns:
      Total number of trainable parameters for the transformer model.
    """

    num_blocks = hyperparams["num_blocks"]

    # Parameter count of embedding layer.
    embedding_parameter_count = parameter_count_embedding(hyperparams)

    # Parameter count of `num_blocks` transformer blocks.
    transformer_blocks_parameter_count = (
        num_blocks * parameter_count_transformer_block(hyperparams)
    )

    # Parameter count of output_layer.
    output_parameter_count = parameter_count_output_layer(hyperparams)

    parameter_count = (
        embedding_parameter_count
        + transformer_blocks_parameter_count
        + output_parameter_count
    )

    return parameter_count