# Quantization (Sampling)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/quantization_sampling.ipynb)

This example shows how to use quantization with Gemma for inference. For quantization-aware training (QAT), see the [QAT finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/quantization_aware_training.md) example.

In [None]:
!pip install -q gemma

In [None]:
# Common imports
import os
import jax
import jax.numpy as jnp
import treescope

# Gemma imports
from gemma import gm
from gemma import peft  # Parameter-efficient fine-tuning utilities

By default, JAX preallocates only 75% of the GPU memory rather than all of it, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [None]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

## Initializing the model

To use Gemma with quantization, wrap any Gemma model in `gm.nn.IntWrapper`. The example below uses int8 weights for inference:

In [None]:
model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(text_only=True), dtype=jnp.int8)
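
To give a rough sense of why int8 weights help, the sketch below estimates weight storage for a few dtypes. It is plain arithmetic, not a measurement: `NOMINAL_PARAMS` is an illustrative assumption rather than the exact parameter count of this model, and the estimate ignores quantization scales, activations, and the KV cache.

```python
# Back-of-the-envelope weight memory for different storage dtypes.
# NOMINAL_PARAMS is an illustrative assumption, not the exact Gemma 3 4B count.
NOMINAL_PARAMS = 4_000_000_000

BYTES_PER_WEIGHT = {"float32": 4, "bfloat16": 2, "int8": 1}


def weight_memory_gib(num_params: int, dtype: str) -> float:
    """Returns approximate weight storage in GiB (weights only)."""
    return num_params * BYTES_PER_WEIGHT[dtype] / 2**30


for dtype in BYTES_PER_WEIGHT:
    print(f"{dtype:>8}: {weight_memory_gib(NOMINAL_PARAMS, dtype):.2f} GiB")
```

Under these assumptions, int8 storage is half the size of bfloat16 and a quarter of float32.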

Initialize the weights:

In [None]:
token_ids = jnp.zeros((1, 256), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)

params = model.init(jax.random.key(0), token_ids)

params = params['params']




Restore the pre-trained params. We use `peft.quantize` to quantize the checkpoint.


In [None]:
del params
# Load the params from the checkpoint
original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)
# Quantize the checkpoint weights to int8
params = peft.quantize(original, method='INT8', checkpoint_kernel_key='w')
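
For intuition about what int8 weight quantization means, here is a minimal, library-independent sketch of symmetric per-tensor quantization. It is not the implementation behind `peft.quantize`, whose internal scheme (for example per-channel scales) may differ. The helper names `quantize_int8` and `dequantize_int8` are hypothetical and exist only for this illustration.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: returns (int8 values, float scale)."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    # Round to the nearest integer step and clamp to the symmetric int8 range.
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q, scale):
    """Maps int8 values back to approximate floats."""
    return [v * scale for v in q]


weights = [0.5, -1.27, 0.003, 0.9]  # toy values, not real model weights
q, scale = quantize_int8(weights)
restored = dequantize_int8(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

In this scheme the round-trip error of each weight is at most half a quantization step (`scale / 2`), which is why small weights such as `0.003` can collapse to zero.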

## Fine-tuning

See our [finetuning guide](https://github.com/google-deepmind/gemma/blob/main/docs/lora_finetuning.md) for more info.

## Inference

Here's an example of running a single model call:

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

prompt = tokenizer.encode('The capital of France is')
prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)


# Run the model
out = model.apply(
    {'params': params},
    tokens=prompt,
    return_last_only=True,  # Only predict the last token
)


# Show the token distribution
tokenizer.plot_logits(out.logits)
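
Logits are unnormalized scores; a token distribution is usually obtained by applying a softmax. The sketch below shows that conversion and a top-k lookup on toy values. It is independent of the Gemma API and makes no claim about how `plot_logits` is implemented; `softmax`, `top_k`, and `toy_logits` are hypothetical names used only here.

```python
import math


def softmax(logits):
    """Converts raw logits into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(probs, k):
    """Returns the k (index, probability) pairs with the highest probability."""
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


toy_logits = [2.0, 0.5, -1.0, 4.0]  # hypothetical logits over a 4-token vocabulary
probs = softmax(toy_logits)
for token_id, p in top_k(probs, k=2):
    print(f"token {token_id}: {p:.3f}")
```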

To sample an entire sentence:

In [None]:
sampler = gm.text.Sampler(
    model=model,
    params=params,
    tokenizer=tokenizer,
)

sampler.sample('The capital of France is', max_new_tokens=30)

' Paris.\n\nParis is a global center for art, fashion, gastronomy, and culture. It is known for its iconic landmarks such as the Eiffel Tower'