> <p><small><small>This Notebook is made available subject to the licence and terms set out in the <a href = "http://www.github.com/google-deepmind/ai-foundations">AI Research Foundations Github README file</a>.

# **Build Your Own Small Language Model, Lab 3: Experiment with a Transformer Model**

<a href='https://colab.research.google.com/github/google-deepmind/ai-foundations/blob/master/course_1/introduction_to_language_modeling_lab_3.ipynb' target='_parent'><img src='https://colab.research.google.com/assets/colab-badge.svg' alt='Open In Colab'/></a>

In [None]:
# Packages used.
from IPython.display import clear_output

Run the (hidden) code cell below to install the `gemma` package and define the helper functions used to load and prompt a pre-trained Gemma language model.

In [None]:
# @title Hidden code used for loading a pre-trained Gemma language model.
# The gemma package provides tools for working with Gemma language models,
# including loading and prompting.
# Install it using `!pip install gemma==3.0.0`.
!pip install gemma==3.0.0
from IPython.display import clear_output
clear_output()  # Clears the output
import os
import jax
import jax.numpy as jnp
import numpy as np
import plotly.express as px
from gemma import gm
from typing import Any


def prompt_transformer_model(input_text: str,
                             max_new_tokens: int = 10,
                             model_name: str = 'Gemma-1B',
                             do_sample: bool = True) -> tuple[str, np.ndarray, Any]:
    """Generate text from a transformer model (Gemma) based on the input text.

    Args:
        input_text: The input prompt for the model.
        max_new_tokens: The maximum number of new tokens to generate.
        model_name: The name of the model to load. Supported options are
                    'Gemma-1B' and 'Gemma-4B'. Defaults to 'Gemma-1B'.
        do_sample: Whether to use sampling for text generation (True for random
                  sampling, False for greedy).

    Returns:
        output_text: The generated text, including the input text and the
                     model's output.
        next_token_logits: Unnormalized scores (logits) for the next token;
                           apply a softmax to get a probability distribution.
        tokenizer: The tokenizer used for encoding/decoding the text.

    Raises:
        ValueError: If the model_name is not recognized or supported.
    """

    assert isinstance(do_sample, bool), 'do_sample must be a boolean value.'

    # Process for Gemma-based models.
    if model_name not in ['Gemma-1B', 'Gemma-4B']:
        raise ValueError(f'model_name=`{model_name}` is not supported. '
                        'Supported options are \'Gemma-1B\' and \'Gemma-4B\'')

    tokenizer, model, params = load_gemma(model_name)
    sampler = gm.text.Sampler(
        model=model,
        params=params,
        tokenizer=tokenizer,
    )

    if not do_sample:
        sampler_output_text = sampler.sample(input_text,
                                            max_new_tokens=max_new_tokens,
                                            sampling=gm.text.Greedy())
    else:
        sampler_output_text = sampler.sample(input_text,
                                            max_new_tokens=max_new_tokens,
                                            sampling=gm.text.RandomSampling())

    # Convert the input text to tokens and apply the model to generate predictions
    prompt = tokenizer.encode(input_text, add_bos=True)
    prompt = jnp.asarray(prompt)
    out = model.apply(
        {'params': params},
        tokens=prompt,
        return_last_only=True  # Only predict the last token.
    )
    next_token_logits = out.logits
    output_text = input_text + sampler_output_text

    return output_text, next_token_logits, tokenizer


def load_gemma(model_name: str = 'Gemma-1B') -> tuple:
    """Loads a Gemma model and its associated tokenizer and parameters.

    Args:
        model_name: The name of the Gemma model to load. Options are: 'Gemma-1B'
                    and 'Gemma-4B'.

    Returns:
        tokenizer: Tokenizer for the specified Gemma model.
        model: The Gemma model.
        params: The parameters for the specified Gemma model.

    Raises:
        ValueError: If an unsupported model name is provided.
    """
    # Let JAX use the full GPU memory (this only takes effect if it is set
    # before the JAX backend is first initialized).
    os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.00'

    # Initialize variables
    tokenizer = None
    model = None
    params = None

    # Model loading based on model_name
    if model_name == 'Gemma-1B':
        tokenizer = gm.text.Gemma3Tokenizer()
        model = gm.nn.Gemma3_1B()
        params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_1B_PT)
    elif model_name == 'Gemma-4B':
        tokenizer = gm.text.Gemma3Tokenizer()
        model = gm.nn.Gemma3_4B()
        params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_PT)
    else:
        raise ValueError(f'Unsupported model name: {model_name}. '
                        'Please use \'Gemma-1B\' or \'Gemma-4B\'.')

    return tokenizer, model, params


def plot_next_token(probs_or_logits: np.ndarray,
                    tokenizer: Any,
                    prompt: str,
                    keep_top: int = 30):
    """Plots the probability distribution of the next tokens.

    This function generates a bar plot showing the top `keep_top`
    tokens by probability.

    Adapted from the gemma package:
    https://github.com/google-deepmind/gemma/blob/ee0d55674ecd0f921d39d22615e4e79bd49fce94/gemma/gm/text/_tokenizer.py#L249-L284

    Args:
        probs_or_logits: The raw logits output by the model or
                          the probability distribution for the next token
                          prediction.
        tokenizer: The tokenizer used to decode token IDs to human-readable text.
        prompt: The input prompt used to generate the next token predictions.
        keep_top: The number of top tokens to display in the plot. Default is 30.

    Returns:
        None: Displays a plot showing the probability distribution of the top
              tokens.
    """

    if np.isclose(probs_or_logits.sum(), 1):
        probs = probs_or_logits
    else:
        # Apply softmax to logits to get probabilities.
        probs = jax.nn.softmax(probs_or_logits)

    # Select the top `keep_top` tokens by probability
    indices = jnp.argsort(probs)

    # Reverse to get highest probabilities first
    indices = indices[-keep_top:][::-1]

    # Get the probabilities and corresponding tokens
    probs = probs[indices].astype(np.float32)
    tokens = [repr(tokenizer.decode(i.item())) for i in indices]

    # Create the bar plot using Plotly
    fig = px.bar(x=tokens, y=probs)

    # Customize the plot layout
    fig.update_layout(
        title=f'Probability Distribution of Next Tokens given the prompt="{prompt}"',
        xaxis_title='Tokens',
        yaxis_title='Probability',
    )

    # Display the plot
    fig.show()


## Predicting a next token using a pre-trained Gemma model

**Your task**
1. Select the transformer model from the `model_name` dropdown menu.
2. Enter a prompt of your choice using the `prompt` text field.
3. Run the cell.
4. Inspect the model's prediction for the next token.

For example, given the prompt `'Jide was hungry so she went looking for'`, the transformer model predicts the next token. A token can be a single character (like `'T'`), a full word (like `'The'`), or a subword (such as `'Th'`).
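To make the idea of character, word, and subword tokens concrete, here is a minimal sketch of greedy longest-match tokenization over a tiny, hypothetical vocabulary. `TOY_VOCAB` and `toy_tokenize` are illustrative names, not part of the `gemma` package. Gemma's actual tokenizer uses a much larger, learned vocabulary and a different algorithm, so treat this only as an illustration of how one string can split into tokens of different sizes.

```python
# Hypothetical toy vocabulary mixing whole words, subwords, and characters.
TOY_VOCAB = {'The', 'Th', 'T', 'h', 'e', ' ', 'food', 'foo', 'f', 'o', 'd'}


def toy_tokenize(text: str, vocab: set[str]) -> list[str]:
    """Splits text by repeatedly taking the longest vocabulary entry that
    matches at the current position (a simplified, illustrative scheme)."""
    tokens = []
    i = 0
    while i < len(text):
        # Try the longest possible substring first, then shrink it.
        for end in range(len(text), i, -1):
            piece = text[i:end]
            if piece in vocab:
                tokens.append(piece)
                i = end
                break
        else:
            # No vocabulary entry matched: fall back to a single character.
            tokens.append(text[i])
            i += 1
    return tokens


print(toy_tokenize('The food', TOY_VOCAB))  # ['The', ' ', 'food']
print(toy_tokenize('Tho', TOY_VOCAB))       # ['Th', 'o']
print(toy_tokenize('Tx', TOY_VOCAB))        # ['T', 'x']
```

The same vocabulary yields a full word (`'The'`), a subword (`'Th'`), or a single character (`'T'`) depending on what follows.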

Try running the cell several times with different prompts and observe how the model responds.

> Is the cell below running slowly? Change the `model_name` to try out a different model size. Remember, a model with fewer parameters is faster.

In [None]:
model_name = 'Gemma-1B' #@param ['Gemma-1B', 'Gemma-4B']

prompt = 'Jide was hungry so she went looking for' #@param {type: 'string'}

output_text, next_token_logits, tokenizer = prompt_transformer_model(prompt, max_new_tokens=1, model_name=model_name)
clear_output() # Clears the output.

print(output_text)

### Visualize the probability distribution of the predicted next token

Now that you've seen the model's prediction, it's worth examining the probability distribution behind it. Given the context (prior words) of your prompt, the transformer model assigns a probability to every possible next token and then samples the next token from this distribution.
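The two steps described above, turning the model's raw scores (logits) into probabilities and then sampling from them, can be sketched with NumPy. The four-token vocabulary and the logit values are made up for illustration. In the lab, `next_token_logits` plays the role of `logits` and covers Gemma's entire vocabulary, and `plot_next_token` applies the same softmax step using `jax.nn.softmax`.

```python
import numpy as np

# Hypothetical next-token logits for a tiny 4-token vocabulary.
vocab = ['food', 'water', 'a', 'help']
logits = np.array([3.0, 1.0, 0.5, -1.0])


def softmax(x: np.ndarray) -> np.ndarray:
    """Converts logits into probabilities that are positive and sum to 1."""
    shifted = x - x.max()  # Subtract the max for numerical stability.
    exp = np.exp(shifted)
    return exp / exp.sum()


probs = softmax(logits)
for token, p in zip(vocab, probs):
    print(f'{token!r}: {p:.3f}')

# Sample the next token according to these probabilities.
rng = np.random.default_rng()
next_token = rng.choice(vocab, p=probs)
print('Sampled next token:', next_token)
```

Higher logits always map to higher probabilities, but sampling means a lower-probability token is still occasionally chosen.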

The plot below visualizes the probability distribution over the next token that the language model computes for your prompt. Each bar represents a different token, and its height corresponds to the probability the model assigns to that token. Only the 30 most probable tokens are shown.

Visualizing the probability distribution allows you to analyze the model's preferences for different token choices after being given a prompt.  A highly peaked distribution suggests high confidence in a single prediction, while a flatter distribution indicates greater uncertainty and a broader range of plausible next tokens.
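One way to put a number on how peaked or flat a distribution is uses Shannon entropy. Entropy is not computed anywhere in this lab and is added here only as an illustration. The sketch below uses two hypothetical distributions to show that a peaked distribution has low entropy, while a uniform one has the maximum entropy possible for its number of tokens.

```python
import numpy as np


def entropy_bits(probs: np.ndarray) -> float:
    """Shannon entropy in bits: low for peaked, high for flat distributions."""
    p = probs[probs > 0]  # Skip zero-probability tokens (0 * log 0 = 0).
    return float(-(p * np.log2(p)).sum())


# Two hypothetical next-token distributions over the same 4 tokens.
peaked = np.array([0.97, 0.01, 0.01, 0.01])
flat = np.array([0.25, 0.25, 0.25, 0.25])

print(f'peaked: {entropy_bits(peaked):.2f} bits')
print(f'flat:   {entropy_bits(flat):.2f} bits')  # log2(4) = 2 bits, the maximum
```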

Run the cell below:

In [None]:
plot_next_token(next_token_logits, tokenizer, prompt=prompt)

When you run the cell above, the plot shows the model's probability distribution for the next token. Some tokens have higher probabilities than others, meaning they are more likely to be chosen as the next token.

Here are a few likely observations using the `Gemma-1B` model:

1. The most probable token will usually be a common word that fits the context of the sentence (e.g., `'food'` after the prompt `'Jide was hungry so she went looking for'`).
2. The model might suggest words that seem plausible but aren't always the most expected, like `'a'` or `'something'`.
3. You might notice some tokens have low probabilities, meaning the model considers them less likely to fit but doesn't completely rule them out, like `'work'` or `'help'`.
4. Changing the transformer model may result in slight variations in the predicted next token. This is because the prediction depends on the model's parameters, which are shaped by the model's size and the data it was trained on.
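The ranking behind observations like these can be sketched in a few lines. The helper `top_k_tokens` below is hypothetical, but it follows the same `argsort`-then-reverse logic that `plot_next_token` uses to select the `keep_top` most probable tokens. The probabilities are invented to loosely mirror the examples above and are not taken from `Gemma-1B`.

```python
import numpy as np


def top_k_tokens(probs: np.ndarray, vocab: list[str],
                 k: int) -> list[tuple[str, float]]:
    """Returns the k most probable tokens, highest probability first."""
    # argsort sorts ascending, so take the last k indices and reverse them.
    indices = np.argsort(probs)[-k:][::-1]
    return [(vocab[i], float(probs[i])) for i in indices]


# Hypothetical probabilities for a handful of candidate next tokens.
vocab = ['food', 'a', 'something', 'work', 'help']
probs = np.array([0.55, 0.25, 0.15, 0.03, 0.02])

for token, p in top_k_tokens(probs, vocab, k=3):
    print(f'{token!r}: {p:.2f}')
# 'food': 0.55
# 'a': 0.25
# 'something': 0.15
```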

Try out different prompts and observe the probability distribution of the next token prediction.

### Changing the context slightly

What happens to the probability distribution if the context is changed? Try `'Jide was thirsty so she went looking for'`:

In [None]:
model_name = 'Gemma-1B' #@param ['Gemma-1B', 'Gemma-4B']

prompt = 'Jide was thirsty so she went looking for' #@param {type: 'string'}

output_text, next_token_logits, tokenizer = prompt_transformer_model(prompt,
                                                                     max_new_tokens=1,
                                                                     model_name=model_name)
clear_output() # Clears the output.

plot_next_token(next_token_logits, tokenizer, prompt=prompt)

#### What did you observe?

When running the transformer model with prompts like `'Jide was thirsty so she went looking for'`, you might notice certain patterns in the predicted next tokens. For instance, drink-related words like `'water'` may receive higher probabilities. This is because the transformer model is **context-aware**: from its training data, it has learned that words related to hunger and thirst tend to be followed by certain words, like `'food'` or `'water'`, and it uses the context provided by the prompt to adjust its predictions.

#### Comparison between transformer models

Different transformer models can sometimes generate different next tokens, even for the same prompt. The suggestions may vary depending on the size of the model and how it was trained. Larger models, with more parameters and typically trained on more data, tend to generate more accurate and contextually appropriate predictions. Smaller models might be more limited, occasionally offering less relevant or more generic predictions.

#### Transformer models versus n-gram models

When comparing transformer models to traditional n-gram models, you likely noticed some key differences.

N-gram models predict the next token based on a fixed window of the preceding tokens (e.g., the last two or three words). These models often struggle with longer-range dependencies or more complex sentence structures, as they only consider a limited context.

In contrast, transformer models have a much longer context window, typically thousands of tokens. The context window (or context length) is the number of preceding tokens the model can take into account when making a prediction. As a result, a transformer model can learn relationships between tokens that are far apart, whereas n-gram models are realistically limited to single-digit context lengths. This makes transformer models more flexible and accurate, especially when the relevant context stretches beyond the last few words.

For example, when comparing outputs for the same prompt, you may see that n-gram models often fail to predict context-specific words (like `'food'` after a prompt mentioning `'hungry'`, or `'water'` after one mentioning `'thirsty'`), because the key word falls outside their short context window. Transformer models, on the other hand, are likely to generate more contextually appropriate words, like `'food'` when the prompt mentions hunger or `'water'` when thirst is implied.
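The following sketch illustrates why a short context window matters for the hungry/thirsty example. It assumes, for simplicity, that an n-gram model works on whitespace-separated words and conditions only on the previous n - 1 words. `ngram_context` is a hypothetical helper, not code from the lab.

```python
def ngram_context(prompt: str, n: int) -> tuple[str, ...]:
    """Returns the (n - 1) preceding words an n-gram model conditions on."""
    words = prompt.split()
    return tuple(words[-(n - 1):]) if n > 1 else ()


hungry = 'Jide was hungry so she went looking for'
thirsty = 'Jide was thirsty so she went looking for'

for n in (3, 5, 7):
    same = ngram_context(hungry, n) == ngram_context(thirsty, n)
    print(n, ngram_context(hungry, n), 'identical contexts:', same)
```

With n = 3 or n = 5, both prompts reduce to the same context, so an n-gram model of that order has no way to tell them apart and must predict the same distribution for both. The word `'hungry'` only enters the context once n reaches 7, whereas a transformer sees the entire prompt.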

### Generating more samples

Now, try increasing `num_next_tokens` to generate longer text and observe how the model responds:

In [None]:
model_name = 'Gemma-1B' #@param ['Gemma-1B', 'Gemma-4B']

prompt = 'Jide was thirsty so she went looking for' #@param {type: 'string'}

num_next_tokens = 100 #@param {type: 'integer'}

output_text, next_token_logits, tokenizer = prompt_transformer_model(prompt,
                                                                     max_new_tokens=num_next_tokens,
                                                                     model_name=model_name)
clear_output() # Clears the output.

print(output_text)

#### Language models tend to follow their training data distributions

When you ran the cell above multiple times, what did you notice?

- Did you observe stereotypical outputs?
- Did the output seem contextually irrelevant?

Language models are adept at predicting the next token, but they closely follow the distribution of their training data. If a model is trained on biased data, it is likely to reproduce those biases in its outputs.

Similarly, if a language model is trained on data gathered from across the internet, it will reflect the dominant texts and perspectives found there. For instance, if most of the training text is written in English rather than another language, the tokens the model predicts after a prompt like `'Jide'` will most likely be English too. Adapting or fine-tuning pre-trained language models for specific tasks (like auto-completing `'Jide...'` in a specific writing style) is an exciting research area.
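A counting-based toy model makes the link between training data and predictions easy to see. The sketch below estimates a next-word distribution from bigram counts in a small, deliberately skewed, hypothetical corpus. Gemma is trained very differently and on vastly more text, but the principle that predictions mirror the statistics of the training data is the same.

```python
from collections import Counter


def next_word_distribution(corpus: list[str],
                           prev_word: str) -> dict[str, float]:
    """Estimates P(next word | prev_word) by counting bigrams in the corpus."""
    counts = Counter()
    for sentence in corpus:
        words = sentence.split()
        for w1, w2 in zip(words, words[1:]):
            if w1 == prev_word:
                counts[w2] += 1
    total = sum(counts.values())
    return {w: c / total for w, c in counts.items()}


# A hypothetical, deliberately skewed toy corpus.
corpus = [
    'the doctor said he was busy',
    'the doctor said he would call',
    'the doctor said she was late',
    'the doctor said he agreed',
]
print(next_word_distribution(corpus, 'said'))  # {'he': 0.75, 'she': 0.25}
```

The model has no opinion of its own: the 3-to-1 skew in the corpus becomes a 3-to-1 skew in its predictions.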

> **NOTE:** The `Gemma` models that are used here are pre-trained checkpoints and not "instruction-tuned" models. If you are curious, Gemma's <a href='https://github.com/google-deepmind/gemma'>code</a> and
<a href='https://gemma-llm.readthedocs.io/en/latest/index.html'>documentation</a> are publicly available.


**Does the output above change every time you run the cell?**

You likely noticed that the output of the transformer model changes each time you run the cell above, even with the same prompt. This is because the model samples each next token from a probability distribution, which introduces stochasticity (randomness) into the prediction. You saw the same behavior with the n-gram models, where the next word isn't always the same because the model samples from a set of possibilities.

This variability helps the model generate more diverse and creative outputs, just as sampling did for the n-gram model.
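The run-to-run variability comes from a random number generator. In the NumPy sketch below (hypothetical vocabulary and probabilities), reusing the same seed reproduces exactly the same sequence of sampled tokens, while a different seed usually produces a different one. The helper `sample_tokens` is illustrative and is not how the `gemma` sampler is configured in this lab.

```python
import numpy as np

# Hypothetical next-token distribution.
vocab = ['water', 'a', 'something', 'juice']
probs = np.array([0.6, 0.2, 0.15, 0.05])


def sample_tokens(seed: int, num_samples: int = 5) -> list[str]:
    """Draws tokens from `probs` using a generator seeded with `seed`."""
    rng = np.random.default_rng(seed)
    return [str(rng.choice(vocab, p=probs)) for _ in range(num_samples)]


print(sample_tokens(seed=0))
print(sample_tokens(seed=0))  # Same seed: the same sequence of tokens.
print(sample_tokens(seed=1))  # Different seed: usually a different sequence.
```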



### Controlling the model's output

The `do_sample` argument of `prompt_transformer_model` controls whether the model samples from the probability distribution. It defaults to `True`, which is why you may have observed diverse outputs when experimenting with the models above.

If you want the model to return a deterministic output, that is, to always pick the token with the highest probability of occurring next, set `do_sample=False`, as done in the cell below.
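The effect of `do_sample` can be sketched outside the model. The hypothetical helper `pick_next_token` below mirrors the choice `prompt_transformer_model` makes between `gm.text.Greedy()` and `gm.text.RandomSampling()`: greedy decoding takes the `argmax` of the distribution, while sampling draws from it. The distribution is invented for illustration.

```python
import numpy as np

# Hypothetical next-token distribution.
vocab = ['water', 'a', 'something', 'juice']
probs = np.array([0.6, 0.2, 0.15, 0.05])


def pick_next_token(probs: np.ndarray, vocab: list[str], do_sample: bool,
                    rng: np.random.Generator) -> str:
    """Greedy (argmax) when do_sample is False, random sampling otherwise."""
    if do_sample:
        return str(rng.choice(vocab, p=probs))
    return vocab[int(np.argmax(probs))]


rng = np.random.default_rng()
greedy = [pick_next_token(probs, vocab, False, rng) for _ in range(5)]
sampled = [pick_next_token(probs, vocab, True, rng) for _ in range(5)]
print('greedy: ', greedy)   # Always 'water', the most probable token.
print('sampled:', sampled)  # Varies from run to run.
```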

<!--The following is an example of the model with the variable set to False.

```python
prompt_transformer_model(prompt, max_new_tokens=num_next_tokens, model_name=model_name, do_sample=False)
```

With this setting, the output will be consistent across multiple runs for the same prompt as it always selects the most probable token.

**Sampling Mode (Default: `do_sample=True`)**
-->

<!--By default, `do_sample` is set to `True`, so the model samples from the probability distribution, which in turn introduces randomness and results in more varied and creative outputs. This is helpful when you want the model to explore a range of possible continuations for a prompt, rather than sticking strictly to the most likely outcome.-->

Because the cell below sets `do_sample=False`, running it multiple times with the same prompt returns the same output.

Run the cell below multiple times and observe the result:

In [None]:
model_name = 'Gemma-1B' #@param ['Gemma-1B', 'Gemma-4B']

prompt = 'Jide was thirsty so she went looking for' #@param {type: 'string'}

num_next_tokens = 100 #@param {type: 'integer'}

output_text, next_token_logits, tokenizer = prompt_transformer_model(prompt,
                                                                     max_new_tokens=num_next_tokens,
                                                                     model_name=model_name,
                                                                     do_sample=False)
clear_output() # Clears the output.

print(output_text)

**Balancing creativity and consistency**

Sampling from a probability distribution allows the transformer model to explore a range of possible next tokens, fostering creativity and generating varied outputs. This contrasts with greedy decoding, which always picks the token with the highest probability, as you saw above.

Different applications require different settings for this balance. For creative tasks such as generating stories, sampling from the probability distribution is usually preferable, because it allows the model to explore various possibilities and produce more imaginative results.

If accuracy, consistency, and reliability are important for your use case, it's often better to choose the token with the highest probability.
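A common way to move between these two extremes is temperature scaling. The helper functions in this lab do not expose it, so the sketch below is included only as an illustration. Dividing the logits by a temperature before the softmax sharpens the distribution when the temperature is below 1 (approaching greedy decoding) and flattens it when the temperature is above 1 (more varied samples). `apply_temperature` and the logits are hypothetical.

```python
import numpy as np


def apply_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax of logits / temperature. Lower temperature sharpens the
    distribution; higher temperature flattens it."""
    if temperature <= 0:
        raise ValueError('temperature must be positive.')
    scaled = logits / temperature
    scaled = scaled - scaled.max()  # Subtract the max for numerical stability.
    exp = np.exp(scaled)
    return exp / exp.sum()


logits = np.array([3.0, 1.0, 0.5, -1.0])  # Hypothetical next-token logits.
for t in (0.2, 1.0, 5.0):
    print(t, np.round(apply_temperature(logits, t), 3))
```

At a temperature of 0.2 almost all probability mass sits on the top token, while at 5.0 the four tokens become much closer to equally likely.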

## Reflection

This is the end of **Lab 3: Experiment with a Transformer Model.**

In this lab, you:

- Experimented with a transformer model, exploring its ability to predict the next token in a sequence. By trying different prompts and model sizes, you likely observed how the model's predictions and their probabilities shifted based on the context.

- Visualized the probability distribution to gain a deeper understanding of the model's confidence in different potential next tokens, given the context (prior words).

- Explored the impact of generating longer sequences of text, discovering the role of sampling in creating varied and sometimes unexpected outputs. This variability showcases the creative potential of transformer models while also highlighting their susceptibility to biases present in their training data. By toggling the `do_sample` parameter, you experienced the trade-off between generating consistent, predictable outputs and embracing the diversity of sampled predictions.

This lab offered a hands-on introduction to the power and nuances of transformer models. In the next part of the course, you'll compare n-grams and transformers, evaluating their performance based on fluency, coherence, and relevance. This will further solidify your understanding of these language models and their application in natural language processing.