# Sampling

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/sampling.ipynb)

An example of how to load a Gemma model and run inference with it.

The Gemma library has 3 ways to prompt a model:

* `gm.text.ChatSampler`: Easiest to use; simply talk to the model and get an answer. Supports multi-turn conversations out of the box.
* `gm.text.Sampler`: Lower level, but gives more control. The chat state has to be handled manually for multi-turn conversations.
* `model.apply`: Calls the model directly; only predicts a single token.

In [1]:
# !pip install -q gemma

In [1]:
# Common imports
import os
import jax
import jax.numpy as jnp

# Gemma imports
from gemma import gm

By default, JAX preallocates 75% of the total GPU memory, but this fraction can be overridden with the `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variable. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [2]:
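# Let JAX preallocate all of the GPU memory.
# Set this before the first JAX operation runs.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

print("JAX devices:", jax.devices())

# Show the GPU status and current memory usage.
os.system("nvidia-smi")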

JAX devices: [CudaDevice(id=0)]
Tue Dec  2 03:04:54 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        On  |   00000000:01:00.0 Off |                  N/A |
|  0%   37C    P8             12W /  320W |     597MiB /  16376MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+


0
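For a sense of scale, the fraction is simply multiplied by the GPU's total memory. The sketch below is plain arithmetic with a hypothetical `preallocated_mib` helper (not part of JAX or Gemma). It assumes the 75% default described in the JAX GPU memory allocation docs and uses the 16376 MiB total reported by `nvidia-smi` above.

```python
import os
from typing import Mapping, Optional

# Fraction JAX preallocates when XLA_PYTHON_CLIENT_MEM_FRACTION is unset
# (75%, per the JAX GPU memory allocation docs).
DEFAULT_MEM_FRACTION = 0.75


def preallocated_mib(total_mib: float, env: Optional[Mapping[str, str]] = None) -> float:
    """Hypothetical helper: MiB that JAX would try to preallocate on one GPU.

    Pure arithmetic; it does not query the GPU or JAX.
    """
    env = os.environ if env is None else env
    fraction = float(env.get("XLA_PYTHON_CLIENT_MEM_FRACTION", DEFAULT_MEM_FRACTION))
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"memory fraction must be in (0, 1], got {fraction}")
    return total_mib * fraction


# The 16376 MiB card reported by nvidia-smi above.
print(preallocated_mib(16376, env={}))  # 12282.0
print(preallocated_mib(16376, env={"XLA_PYTHON_CLIENT_MEM_FRACTION": "1.00"}))  # 16376.0
```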

Load the model and the params. Here we load the instruction-tuned version of the model.

In [3]:
model = gm.nn.Gemma3_4B()

params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)



## Multi-turn conversations

The easiest way to chat with Gemma is to use the `gm.text.ChatSampler`. It hides the boilerplate of the conversation cache, as well as the `<start_of_turn>` / `<end_of_turn>` tokens used to format the conversation.

Here, we set `multi_turn=True` when creating `gm.text.ChatSampler` (by default, `ChatSampler` starts a new conversation every time).

In multi-turn mode, you can erase the previous conversation state by passing `sampler.chat(..., multi_turn=False)`.
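To illustrate what `multi_turn` controls, here is a toy, library-independent sketch. `ToyChat` and `count_turns` are hypothetical names, not the `gm.text.ChatSampler` implementation. Unlike the real sampler, which manages a conversation cache, this toy simply re-formats the whole history as text on every call, using the same `<start_of_turn>` / `<end_of_turn>` markers as the `gm.text.Sampler` prompt later in this notebook.

```python
from typing import Callable, List, Optional, Tuple


class ToyChat:
    """Toy model of multi-turn chat state (hypothetical, not the gemma API)."""

    def __init__(self, reply_fn: Callable[[str], str], multi_turn: bool = False):
        self.reply_fn = reply_fn  # Stand-in for the model: prompt -> answer.
        self.multi_turn = multi_turn  # Default used by every `chat` call.
        self.history: List[Tuple[str, str]] = []  # (user, model) turns.

    def build_prompt(self, message: str) -> str:
        """Format past turns plus the new message with Gemma turn markers."""
        parts = []
        for user, answer in self.history:
            parts.append(f"<start_of_turn>user\n{user}<end_of_turn>\n")
            parts.append(f"<start_of_turn>model\n{answer}<end_of_turn>\n")
        parts.append(f"<start_of_turn>user\n{message}<end_of_turn>\n")
        parts.append("<start_of_turn>model\n")  # Leave the model turn open.
        return "".join(parts)

    def chat(self, message: str, multi_turn: Optional[bool] = None) -> str:
        # A per-call `multi_turn` overrides the constructor default.
        keep_history = self.multi_turn if multi_turn is None else multi_turn
        if not keep_history:
            self.history.clear()  # Start a fresh conversation.
        answer = self.reply_fn(self.build_prompt(message))
        self.history.append((message, answer))
        return answer


# Fake "model" that reports how many user turns it was shown.
def count_turns(prompt: str) -> str:
    return f"user turns seen: {prompt.count('<start_of_turn>user')}"


chat = ToyChat(count_turns, multi_turn=True)
print(chat.chat("Hi"))                           # user turns seen: 1
print(chat.chat("And again?"))                   # user turns seen: 2
print(chat.chat("New topic", multi_turn=False))  # user turns seen: 1
```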

In [4]:
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=True,
    print_stream=True,  # Print output as it is generated.
)

turn0 = sampler.chat('Share one metaphor linking "shadow" and "laughter".')

Okay, here's a metaphor linking "shadow" and "laughter," aiming for a slightly evocative and layered feel:

“Laughter, like a shadow, is born of a shared light. It stretches and dances, mimicking the joy it reflects, but it’s always a little darker, a little cooler – a reminder that even the brightest moments hold a trace of something unseen, something that lingers just beyond the reach of the sun.”

---

**Why this works:**

*   **Shadow as a Reflection:** Shadows are inherently linked to light. Laughter is often a response to joy, so the shadow represents the underlying complexities or potential sadness that can exist even within happiness.
*   **Mimicry & Difference:** Shadows mimic the shape of the light source, but they aren't the light itself. Laughter mimics the joy it’s responding to, but it’s a separate, distinct experience.
*   **Underlying Darkness:** The “darker, cooler” aspect acknowledges that happiness isn’t always pure and unadulterated.  It hints at vulnerability or a 

In [5]:
turn1 = sampler.chat('Expand it in a haiku.')

Okay, here’s a haiku expanding on that metaphor:

Dark shadow dances,
Laughter’s echo, cool and brief,
Joy’s unseen heart beats.

---

**Explanation of choices:**

*   **First line (“Dark shadow dances”):** Immediately establishes the visual link and the movement associated with both.
*   **Second line (“Laughter’s echo, cool and brief”):** Captures the fleeting nature of laughter and the slightly detached feeling of a shadow. “Cool” adds a touch of melancholy.
*   **Third line (“Joy’s unseen heart beats”):**  Connects the outward expression (laughter) to a deeper, perhaps more vulnerable, core feeling.

Do you want me to try a different haiku variation, or perhaps explore a specific emotion related to the shadow/laughter connection?

Note: By default (`multi_turn=False`), the conversation state is reset every time, but you can still continue the previous conversation by passing `sampler.chat(..., multi_turn=True)`.

By default, greedy decoding is used. You can pass a custom sampling method with the `sampling=` kwarg:

* `gm.text.Greedy()`: (default) Greedy decoding
* `gm.text.RandomSampling()`: Simple random sampling with temperature, for more variety
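As a rough sketch of the difference between the two (plain Python, not the internals of `gm.text.Greedy` or `gm.text.RandomSampling`): greedy decoding always takes the argmax of the logits, while random sampling draws a token from the softmax of the logits divided by a temperature. Lower temperatures sharpen the distribution; higher ones flatten it. The helper names and logits below are made up for illustration.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Turn logits into probabilities, after dividing by the temperature."""
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # Subtract the max for numerical stability.
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def greedy(logits: Sequence[float]) -> int:
    """Greedy decoding: always pick the highest-scoring token id."""
    return max(range(len(logits)), key=lambda i: logits[i])


def random_sample(
    logits: Sequence[float],
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Random sampling: draw a token id from softmax(logits / temperature)."""
    rng = rng if rng is not None else random.Random()
    probs = softmax(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]  # Made-up scores over a 4-token vocabulary.
print(greedy(logits))  # 0
print(softmax(logits, temperature=0.5))  # Sharper: mass concentrates on id 0.
print(softmax(logits, temperature=2.0))  # Flatter: more variety.
rng = random.Random(0)
print([random_sample(logits, temperature=1.0, rng=rng) for _ in range(10)])
```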

## Sample a prompt

For more control, we also provide `gm.text.Sampler`, which still performs efficient sampling (with KV caching, early stopping, etc.).
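To make early stopping concrete, the toy loop below generates one token per step and stops as soon as a stop token appears, or after `max_new_tokens` steps. It is a hypothetical sketch, not `gm.text.Sampler`'s code: `decode`, `count_up`, and the stop id `5` are made up, and the loop re-runs the fake model on the full sequence at each step, where a real sampler would reuse its KV cache.

```python
from typing import Callable, Collection, List, Sequence


def decode(
    prompt_ids: Sequence[int],
    next_token_fn: Callable[[Sequence[int]], int],
    max_new_tokens: int,
    stop_ids: Collection[int],
) -> List[int]:
    """Toy autoregressive loop: append one token per step until a stop token."""
    tokens = list(prompt_ids)
    generated: List[int] = []
    for _ in range(max_new_tokens):
        # Toy only: re-reads all tokens; a real sampler would reuse a KV cache.
        next_id = next_token_fn(tokens)
        if next_id in stop_ids:
            break  # Early stopping: skip the remaining steps.
        generated.append(next_id)
        tokens.append(next_id)
    return generated


# Fake "model" that predicts last token + 1; id 5 plays the end-of-turn role.
def count_up(tokens: Sequence[int]) -> int:
    return tokens[-1] + 1


print(decode([1, 2], count_up, max_new_tokens=100, stop_ids={5}))  # [3, 4]
print(decode([1, 2], count_up, max_new_tokens=1, stop_ids={5}))    # [3]
```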

Prompting the sampler requires formatting the prompt correctly with the `<start_of_turn>` / `<end_of_turn>` tokens (see the custom tokens section of the [tokenizer](https://gemma-llm.readthedocs.io/en/latest/tokenizer.html) documentation).

In [6]:
sampler = gm.text.Sampler(
    model=model,
    params=params,
)

prompt = """<start_of_turn>user
Give me a list of inspirational quotes.<end_of_turn>
<start_of_turn>model
"""

out = sampler.sample(prompt, max_new_tokens=1000)
print(out)

Okay, here's a list of inspirational quotes, categorized a bit for different vibes, with a little context where helpful:

**1. On Perseverance & Resilience:**

*   “The only way to do great work is to love what you do.” – Steve Jobs (Focuses on passion and dedication)
*   “Fall seven times, stand up eight.” – Japanese Proverb (Highlights the importance of getting back up after failure)
*   “Success is not final, failure is not fatal: It is the courage to continue that counts.” – Winston Churchill (Emphasizes resilience and not giving up)
*   “The difference between ordinary and extraordinary is that little extra.” – Jimmy Johnson (Suggests pushing beyond the expected)
*   “Don't watch the clock; do what it does. Keep going.” – Sam Levenson (Focuses on consistent action)


**2. On Self-Love & Confidence:**

*   “You are enough.” – Brené Brown (A simple, powerful reminder of self-worth)
*   “Believe you can and you’re halfway there.” – Theodore Roosevelt (The power of positive self-belie

## Use the model directly

Here's an example of predicting a single token by calling the model directly.

The model expects encoded tokens as input, so we first need to encode the prompt with our tokenizer. See our [tokenizer](https://gemma-llm.readthedocs.io/en/latest/tokenizer.html) documentation for more information on using the tokenizer.

In [7]:
tokenizer = gm.text.Gemma3Tokenizer()

Note: When encoding the prompt, don't forget to add the beginning-of-string token with `add_bos=True`. All prompts fed to the model should start with this token.
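If you assemble token ids yourself (for example, by concatenating several encoded pieces), a guard like the hypothetical `ensure_bos` helper below keeps exactly one BOS token at the front instead of accidentally adding it twice. The BOS id `2` and the other token ids are assumptions for illustration; use the ids from your tokenizer.

```python
from typing import List, Sequence


def ensure_bos(token_ids: Sequence[int], bos_id: int) -> List[int]:
    """Prepend `bos_id` unless the sequence already starts with it."""
    ids = list(token_ids)
    if not ids or ids[0] != bos_id:
        ids.insert(0, bos_id)
    return ids


BOS_ID = 2  # Assumed value for illustration; read the real id from your tokenizer.
print(ensure_bos([651, 2204], BOS_ID))     # [2, 651, 2204]
print(ensure_bos([2, 651, 2204], BOS_ID))  # [2, 651, 2204] (no double BOS)
```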

In [8]:
prompt = tokenizer.encode('One word to describe Paris: \n\n', add_bos=True)
prompt = jnp.asarray(prompt)

We can then call the model and get the predicted logits.

In [9]:
# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only return the logits for the last token
)


# Sample a token from the predicted logits
next_token = jax.random.categorical(
    jax.random.key(1),
    out.logits
)
tokenizer.decode(next_token)

'Romantic'

You can also plot the next-token probabilities.

In [11]:
tokenizer.plot_logits(out.logits)
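The plotted values are next-token probabilities, i.e. the softmax of the logits. A library-independent sketch of ranking the top candidates follows; it does not reproduce how `plot_logits` works internally, and the logits and vocabulary are made up (not real model output).

```python
import math
from typing import List, Sequence, Tuple


def top_k_probs(
    logits: Sequence[float], vocab: Sequence[str], k: int = 3
) -> List[Tuple[str, float]]:
    """Softmax the logits and return the k most likely (token, probability) pairs."""
    peak = max(logits)  # Subtract the max for numerical stability.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(vocab[i], probs[i]) for i in ranked[:k]]


# Made-up logits over a made-up 4-token vocabulary.
vocab = ["tok_a", "tok_b", "tok_c", "tok_d"]
logits = [3.0, 2.5, 1.0, -2.0]
for token, prob in top_k_probs(logits, vocab):
    print(f"{token:>6}  {prob:.3f}")
```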

## Next steps

* See our [multimodal](https://gemma-llm.readthedocs.io/en/latest/multimodal.html) example to query the model with images.
* See our [finetuning](https://gemma-llm.readthedocs.io/en/latest/finetuning.html) example to train Gemma on your custom task.
* See our [tool use](https://gemma-llm.readthedocs.io/en/latest/tool_use.html) tutorial to extend Gemma with external tools.
