# Sampling example

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sampling.ipynb)

An example of how to load a Gemma model and run inference with it.

In [None]:
!pip install -q gemma

In [1]:
# Common imports
import os
import jax
import jax.numpy as jnp

# Gemma imports
from gemma import gm

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [None]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
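
The variable controls the fraction of GPU memory that JAX preallocates, so it should be set before JAX initializes the GPU backend (setting it before the first JAX computation is the safe choice). Below is a minimal sketch of a validating wrapper that uses only the standard library. The helper name `set_jax_mem_fraction` is hypothetical and not part of JAX or Gemma, and the `(0, 1]` range check is an assumption made for this sketch.

```python
import os


def set_jax_mem_fraction(fraction: float) -> str:
    """Hypothetical helper: validate and export the XLA memory fraction."""
    # Assumption for this sketch: only accept values in (0, 1].
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction}")
    value = f"{fraction:.2f}"
    # Same environment variable as in the cell above.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = value
    return value


set_jax_mem_fraction(1.0)
print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])  # prints 1.00
```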

Load the model, the params and the tokenizer.

In [2]:
tokenizer = gm.text.Gemma2Tokenizer()

model = gm.nn.Gemma2_2B()

params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA2_2B_IT)


## Single token

Here's an example of predicting a single token by calling the model directly.

In [5]:
# Encode the prompt
prompt = tokenizer.encode('My name is', add_bos=True)  # /!\ Don't forget to add the BOS token
prompt = jnp.asarray(prompt)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Sample a token from the predicted logits
next_token = jax.random.categorical(
    jax.random.key(1),
    out.logits
)
tokenizer.decode(next_token)

' Mary'
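
For intuition, categorical sampling draws token index `i` with probability proportional to `exp(logits[i])`, whereas greedy decoding always picks the largest logit. The pure-Python sketch below illustrates the difference on toy logits. It uses only the standard library rather than JAX, and the helper names (`softmax`, `sample_categorical`, `greedy`) are hypothetical, chosen for this example.

```python
import math
import random


def softmax(logits):
    """Convert logits to probabilities (max is subtracted for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_categorical(logits, rng):
    """Draw one index with probability softmax(logits)[index]."""
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


def greedy(logits):
    """Return the index of the largest logit."""
    return max(range(len(logits)), key=lambda i: logits[i])


toy_logits = [2.0, 1.0, 0.1]  # Toy values, not real model output.
rng = random.Random(0)
draws = [sample_categorical(toy_logits, rng) for _ in range(1000)]

print(greedy(toy_logits))  # prints 0
# Empirical frequencies should roughly follow softmax(toy_logits).
print([draws.count(i) / len(draws) for i in range(len(toy_logits))])
```

Changing the seed changes which token is sampled, which is why the cell above passes an explicit `jax.random.key`.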

You can also display the next-token probabilities.

In [6]:
tokenizer.plot_logits(out.logits)
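
As a text-only counterpart to a plot, logits can be turned into probabilities with a softmax and the most likely candidates listed. The sketch below does this for a toy vocabulary. The vocabulary, the logits, and the helper `top_k_probs` are all hypothetical; it does not reproduce what `plot_logits` draws.

```python
import math


def top_k_probs(logits, vocab, k=3):
    """Return the k most likely (token, probability) pairs."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort tokens from most to least likely.
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


toy_vocab = [" Mary", " John", " Sarah", " the"]  # Hypothetical tokens.
toy_logits = [3.0, 2.5, 2.0, -1.0]  # Toy values, not real model output.

for token, prob in top_k_probs(toy_logits, toy_vocab):
    print(f"{token!r}: {prob:.3f}")
```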

## Multiple tokens

In practice, Gemma provides a `gm.text.Sampler` to perform efficient sampling (with KV caching, early stopping, ...).

In [7]:
sampler = gm.text.Sampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.sample('My name is', max_new_tokens=30)

" Sarah and I'm a freelance writer. I'm passionate about helping people tell their stories and share their knowledge with the world. \n\nI"
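
Conceptually, multi-token sampling repeats the single-token step: predict the next token, append it, and stop at `max_new_tokens` or at an end-of-sequence token. The sketch below shows that loop with a hypothetical stand-in, `toy_next_token`, in place of a real model and without KV caching. It illustrates the idea only and is not how `gm.text.Sampler` is implemented.

```python
from typing import Callable, List

EOS = 0  # Hypothetical end-of-sequence token id for this toy example.


def toy_next_token(tokens: List[int]) -> int:
    """Hypothetical stand-in for a model: counts down from the last token."""
    return max(tokens[-1] - 1, EOS)


def generate(
    prompt: List[int],
    next_token_fn: Callable[[List[int]], int],
    max_new_tokens: int,
) -> List[int]:
    """Autoregressive loop with early stopping on EOS."""
    tokens = list(prompt)
    new_tokens = []
    for _ in range(max_new_tokens):
        tok = next_token_fn(tokens)  # One single-token prediction.
        if tok == EOS:
            break  # Early stopping: the "model" signalled the end.
        tokens.append(tok)  # Feed the new token back in as context.
        new_tokens.append(tok)
    return new_tokens


print(generate([5], toy_next_token, max_new_tokens=30))  # prints [4, 3, 2, 1]
print(generate([5], toy_next_token, max_new_tokens=2))   # prints [4, 3]
```

This naive loop re-reads the full context at every step; a KV cache avoids that by reusing the attention keys and values computed for earlier tokens.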