# Quantization Aware Training (QAT)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/quantization_aware_training.ipynb)

This is an example of how to obtain and run quantized versions of Gemma models. It is best to read the [finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/finetuning.md) colab first, since this one builds on it.



In [1]:
!pip install -q gemma

In [2]:
# Common imports
import os
import treescope
import optax

# Gemma imports
from kauldron import kd
from gemma import gm
from gemma import peft

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [3]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"
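The same page documents a second variable, `XLA_PYTHON_CLIENT_PREALLOCATE`, which turns preallocation off entirely. As a minimal sketch, both settings can be wrapped in a small helper. The name `configure_gpu_memory` is hypothetical and not part of JAX or Gemma. The variables only take effect if they are set before JAX initializes its GPU backend.

```python
import os


def configure_gpu_memory(fraction=None, preallocate=True):
    """Set XLA client memory options; call before any JAX computation runs."""
    if not preallocate:
        # Allocate GPU memory on demand instead of reserving a pool up front.
        os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    if fraction is not None:
        if not 0.0 < fraction <= 1.0:
            raise ValueError("fraction must be in (0, 1]")
        # Fraction of total GPU memory that JAX preallocates.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{fraction:.2f}"


# Equivalent to the cell above: let JAX use the whole GPU.
configure_gpu_memory(fraction=1.0)
```

Lowering `fraction` leaves headroom for other processes that share the GPU.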

## Config updates

If you're familiar with the [finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/finetuning.md) tutorial, switching to QAT requires only one change to the trainer.

This is slightly different from LoRA (we discuss the difference below).

### 1. Model

Wrap the model in `gm.nn.QuantizationAwareWrapper`. This applies model surgery that replaces all linear and compatible layers with layers that simulate quantization. You can choose among several quantization options:

* SFP8: switched floating point in 8 bits (very efficient with gemma.cpp)
* Q4_0: per-block integer quantization (equivalent to 4.5 bits per weight), very popular with llama.cpp
* INT4: per-channel weight quantization (almost exactly 4 bits per weight)
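The bits-per-weight figures in the list follow from how many weights share one scale factor. The sketch below assumes 16-bit scales. That matches the llama.cpp Q4_0 layout (blocks of 32 weights), but it is only an assumption for the per-channel INT4 case. The helper `bits_per_weight` and the channel length of 2560 are illustrative, not taken from the Gemma code.

```python
def bits_per_weight(weight_bits, group_size, scale_bits=16):
    """Average storage per weight when `group_size` weights share one scale."""
    return weight_bits + scale_bits / group_size


# Q4_0: blocks of 32 four-bit weights, each block carrying one 16-bit scale.
q4_0 = bits_per_weight(weight_bits=4, group_size=32)

# Per-channel INT4: one scale per channel, so the overhead is amortized over
# the whole channel. The length 2560 is an illustrative assumption.
int4 = bits_per_weight(weight_bits=4, group_size=2560)

print(f"Q4_0: {q4_0} bits/weight, INT4: {int4:.5f} bits/weight")
```

The longer the group that shares a scale, the closer the cost gets to the raw 4 bits, at the price of a coarser scale.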

In [4]:
model = gm.nn.QuantizationAwareWrapper(
    method = peft.QuantizationMethod.INT8,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)

Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery.
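The simulated-quantization layers that the surgery swaps in can be pictured as a quantize-then-dequantize round trip on the weights: the forward pass sees values that a low-bit format could represent, while the stored parameters stay in floating point. The function below is a minimal pure-Python sketch of that idea for one channel with a symmetric scale. It is not the `gemma.peft` implementation. QAT implementations also commonly pass gradients through the rounding step unchanged (a straight-through estimator); whether Gemma does so is an assumption here.

```python
def fake_quantize(weights, bits=4):
    """Quantize then dequantize one channel with a shared symmetric scale."""
    qmax = 2 ** (bits - 1) - 1   # largest integer code, 7 for 4 bits
    qmin = -(2 ** (bits - 1))    # smallest integer code, -8 for 4 bits
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0.0:
        return list(weights), 1.0
    scale = max_abs / qmax
    # Round to the nearest integer code, clamped to the representable range.
    codes = [min(qmax, max(qmin, round(w / scale))) for w in weights]
    # Map the codes back to floats: these are the values the model "sees".
    return [c * scale for c in codes], scale


channel = [0.82, -0.31, 0.05, -0.77, 0.4]
simulated, scale = fake_quantize(channel, bits=4)
for w, s in zip(channel, simulated):
    print(f"{w:+.3f} -> {s:+.3f}")
```

Each simulated weight differs from the original by at most half a quantization step (`scale / 2`), which is the error the model learns to tolerate during training.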

In [5]:
init_transform = gm.ckpts.LoadCheckpoint(
    path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)

## Training

### Data pipeline

As in the [finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/finetuning.md) example, we recreate the tokenizer:

In [6]:
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)

[<_Gemma3SpecialTokens.BOS: 2>, 2094, 563, 614, 2591, 13315]

And the data pipeline:

In [7]:
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)
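To make the three output fields concrete, here is a minimal sketch of the conventional next-token layout for a prompt/response pair: the target is the input shifted by one token, and the loss mask is 1 only where the target is a response token. This layout is an assumption about what `Seq2SeqTask` produces, and `make_seq2seq_example` and `pad_id` are hypothetical names.

```python
def make_seq2seq_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Build input/target/loss_mask lists of length `max_length`.

    Assumes a non-empty prompt (it normally starts with a BOS token).
    """
    tokens = prompt_ids + response_ids
    inputs, targets = tokens[:-1], tokens[1:]
    # Predictions count toward the loss only when the target is a response token.
    loss_mask = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)

    def fit(seq, fill):
        seq = seq[:max_length]                           # truncate
        return seq + [fill] * (max_length - len(seq))    # then pad

    return {
        "input": fit(inputs, pad_id),
        "target": fit(targets, pad_id),
        "loss_mask": fit(loss_mask, 0),
    }


example = make_seq2seq_example(
    prompt_ids=[2, 10, 11], response_ids=[20, 21, 1], max_length=8
)
print(example)
```

Padding positions get a mask of 0, so they never contribute to the loss.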

We can decode an example from the batch to inspect the model input and check it is properly formatted:

In [8]:
text = tokenizer.decode(ex['input'][0])

print(text)

<start_of_turn>user
Is this a good place to ask about the ethnicity and intelligence debate?<end_of_turn>
<start_of_turn>model
Est-ce un bon endroit pour poser des questions sur le débat à propos de l'ethnicité et le renseignement ?<end_of_turn>


### Trainer

We then create the trainer, reusing the `model` and `init_transform` created above. The optimizer is defined inline:

In [9]:
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',  # TODO(epot): Make the workdir optional by default
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optax.adafactor(learning_rate=0.005),
)
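The loss in the trainer combines logits, integer labels and the loss mask. A pure-Python sketch of a masked softmax cross-entropy is shown below. It assumes the loss is averaged over unmasked positions; the exact normalization used by `kd.losses.SoftmaxCrossEntropyWithIntLabels` is not confirmed here. `masked_cross_entropy` is a hypothetical helper.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over positions where mask is non-zero."""
    total, count = 0.0, 0
    for row, label, m in zip(logits, labels, mask):
        if not m:
            continue
        # Numerically stable log-sum-exp over the vocabulary axis.
        top = max(row)
        log_z = top + math.log(sum(math.exp(x - top) for x in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 5.0, 0.0]]
labels = [0, 2, 1]
mask = [0, 1, 1]  # the first position belongs to the prompt
loss = masked_cross_entropy(logits, labels, mask)
print(f"loss = {loss:.4f}")
```

Only the second and third positions contribute here, so prompt tokens do not influence the gradient.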

Training can be launched with the `.train()` method.

Note that the trainer, like the model, is immutable, so it does not store the state or the params. Instead, the state containing the trained parameters is returned.

In [10]:
state, aux = trainer.train()

Configuring ...
Initializing ...
Disabling pygrain multi-processing (unsupported in colab).


Starting training loop at step 0


train:   0%|          | 0/501 [00:00<?, ?it/s]

## Inference

To run inference with the trained model, you have two options:

1. Evaluate the `QuantizationAwareWrapper` model directly. This only simulates quantization, so it does not reduce the memory footprint.
2. Use the `IntWrapper` as follows (only available for INT8 quantization).

In [11]:
quantized_model = gm.nn.IntWrapper(model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True))
quantized_params = peft.quantize(state.params, method=peft.QuantizationMethod.INT8)
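The memory saving of the second option comes from storing integer codes plus a scale instead of floating-point weights. The sketch below illustrates this with the standard-library `array` module for one channel. It is not how `peft.quantize` is implemented, and `quantize_int8` is a hypothetical helper.

```python
from array import array


def quantize_int8(weights):
    """Return (int8 codes, scale) for one channel using a symmetric scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs else 1.0
    codes = array("b", (max(-127, min(127, round(w / scale))) for w in weights))
    return codes, scale


weights = array("f", [0.82, -0.31, 0.05, -0.77, 0.4, 0.0, 0.12, -0.6])
codes, scale = quantize_int8(weights)
restored = [c * scale for c in codes]  # dequantized values used at inference

float_bytes = weights.itemsize * len(weights)  # float32 storage
int_bytes = codes.itemsize * len(codes)        # int8 storage, scale excluded
print(f"{float_bytes} bytes as float32 vs {int_bytes} bytes as int8")
```

The simulated-quantization model keeps the floating-point weights, which is why it offers no such reduction.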

The quantized model can then be evaluated with a sampler:

In [12]:
sampler = gm.text.Sampler(
    model=quantized_model,
    params=quantized_params,
    tokenizer=tokenizer,
)

prompt = """\
<start_of_turn>user
I'm feeling happy!<end_of_turn>
<start_of_turn>model
"""

sampler.sample(prompt, max_new_tokens=30)

'Je me sens bien !<end_of_turn>'
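The prompt uses the same turn markers as the decoded training example. A small helper can build it from a plain user message. `format_prompt` is a hypothetical convenience function, not part of the `gemma` package.

```python
def format_prompt(user_message):
    """Wrap a user message in the turn markers seen in the examples above."""
    return (
        "<start_of_turn>user\n"
        f"{user_message}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


prompt = format_prompt("I'm feeling happy!")
print(prompt)
```

The resulting string is identical to the hand-written `prompt` passed to `sampler.sample` above.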