# QLoRA (Finetuning)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/qlora_finetuning.ipynb)

This is an example of fine-tuning Gemma with QLoRA (Quantized Low-Rank Adaptation). It builds on the [LoRA finetuning](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) tutorial, so reading that tutorial first is recommended.

QLoRA combines the parameter-efficient fine-tuning of LoRA with weight quantization: the frozen base model weights are stored at reduced precision (4-bit in this example) and only the small LoRA adapters are trained. This significantly reduces memory requirements while largely maintaining performance, which allows fine-tuning larger models on consumer hardware.

In [None]:
!pip install -q gemma

In [None]:
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [None]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

## Config updates

Like regular LoRA, QLoRA requires three main changes to the trainer configuration. The key difference is that the model is wrapped in `QLoRA` instead of `LoRA`, with a quantization method specified.

For an end-to-end example, see the [qlora.py](https://github.com/google-deepmind/gemma/tree/main/examples/qlora.py) config.

### 1. Model

Wrap the model in `gm.nn.QLoRA`. This applies model surgery to replace all linear and compatible layers with quantized versions that have LoRA adapters.

In [None]:
# Model
model = gm.nn.QLoRA(
    rank=8,  # Adapter rank; QLoRA setups often use a higher rank than plain LoRA
    quant_method=gm.peft.QuantizationMethod.INT4,  # 4-bit quantization
    model=gm.nn.Gemma3_4B(
        tokens="batch.input",  # Must match the key under which tokens are stored in the batch
        text_only=True,
    ),
)

Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery with quantization.
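To make the quantization step concrete, here is a minimal pure-Python sketch of symmetric absmax quantization to a 4-bit integer range. It is an illustration only: the helper names `quantize_int4` and `dequantize_int4` are hypothetical, and the scheme actually used by `gm.peft.QuantizationMethod.INT4` (for example, how scales are grouped) may differ.

```python
def quantize_int4(weights):
    """Maps floats to integer codes in [-7, 7] using one shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    codes = [max(-7, min(7, round(w / scale))) for w in weights]
    return codes, scale


def dequantize_int4(codes, scale):
    """Reconstructs approximate float values from the integer codes."""
    return [c * scale for c in codes]


# Hypothetical weight values, chosen only for illustration.
weights = [0.70, -0.33, 0.10, 0.02, -0.61]
codes, scale = quantize_int4(weights)
restored = dequantize_int4(codes, scale)

# The rounding error of each weight is at most half a quantization step.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, scale, max_error)
```

Each weight is stored as a small integer plus a shared scale, which is where the memory saving comes from; the price is a bounded rounding error on the frozen weights.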

### 2. Checkpoint

Just like with LoRA, wrap the init transform in `gm.ckpts.SkipLoRA`. The wrapper is required because the parameter structure with QLoRA differs from that of the original model.

In [None]:
# Checkpoint
init_transform = gm.ckpts.SkipLoRA(
    wrapped=gm.ckpts.LoadCheckpoint(
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    ),
)

Note: If you're loading the weights directly with `gm.ckpts.load_params`, you can use `peft.split_params` and `peft.merge_params` instead, as in the LoRA approach.
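The idea behind splitting and merging can be sketched on a plain nested dict. This is not the `gemma.peft` implementation: the helpers `split_lora` and `merge_trees` are hypothetical, and the parameter layout (a `"lora"` sub-dict next to each base weight) is an assumption made for illustration.

```python
def split_lora(params):
    """Splits a nested dict into (base, lora) trees based on a 'lora' key."""
    base, lora = {}, {}
    for name, value in params.items():
        if name == "lora":
            lora[name] = value
        elif isinstance(value, dict):
            sub_base, sub_lora = split_lora(value)
            if sub_base:
                base[name] = sub_base
            if sub_lora:
                lora[name] = sub_lora
        else:
            base[name] = value
    return base, lora


def merge_trees(a, b):
    """Recursively merges two nested dicts whose leaves do not overlap."""
    merged = dict(a)
    for name, value in b.items():
        if isinstance(value, dict) and isinstance(merged.get(name), dict):
            merged[name] = merge_trees(merged[name], value)
        else:
            merged[name] = value
    return merged


# Assumed layout: adapters live under a "lora" key beside the base weight.
params = {
    "layer_0": {
        "attn": {"w": [1.0, 2.0], "lora": {"a": [0.1], "b": [0.0]}},
        "mlp": {"w": [3.0]},
    },
}
base, lora = split_lora(params)
restored = merge_trees(base, lora)
```

Splitting and then merging reproduces the original tree, which is why the pretrained weights can be loaded into the base part while the adapters are kept separately.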

### 3. Optimizer

Similar to LoRA, add a mask to the optimizer so only the LoRA weights are trained. With QLoRA, it's common to use a lower learning rate.

In [None]:
# Optimizer
optimizer = kd.optim.partial_updates(
    optax.adafactor(learning_rate=1e-4),  # Lower learning rate for QLoRA
    # We only optimize the LoRA weights. The rest of the model is frozen.
    mask=kd.optim.select("lora"),
)
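As a rough illustration of what a masked optimizer does, the sketch below applies an update only to parameters whose path contains `lora`. It deliberately uses plain SGD instead of Adafactor to stay short, and the helpers `select_mask` and `masked_sgd_step` as well as the flat `"a/b/c"` path layout are hypothetical, not Kauldron APIs.

```python
def select_mask(params, key):
    """Marks as trainable every parameter whose path contains `key`."""
    return {path: key in path.split("/") for path in params}


def masked_sgd_step(params, grads, mask, learning_rate):
    """Applies plain SGD to selected parameters and leaves the rest unchanged."""
    return {
        path: value - learning_rate * grads[path] if mask[path] else value
        for path, value in params.items()
    }


# Hypothetical scalar parameters and gradients, keyed by path.
params = {"attn/w": 1.0, "attn/lora/a": 0.5, "attn/lora/b": 0.0}
grads = {"attn/w": 10.0, "attn/lora/a": 2.0, "attn/lora/b": -4.0}

mask = select_mask(params, "lora")
new_params = masked_sgd_step(params, grads, mask, learning_rate=0.25)
print(new_params)
```

The base weight `attn/w` keeps its value even though it has a gradient, while both adapter entries are updated.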

## Training

### Data pipeline

The data pipeline setup is identical to the regular LoRA approach:

In [None]:
# Create tokenizer
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)

In [None]:
# Create dataset
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)

We can decode an example from the batch to inspect the model input and check it is properly formatted:

In [None]:
text = tokenizer.decode(ex['input'][0])

print(text)
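The relationship between `input`, `target` and `loss_mask` can be sketched as follows. This is a simplified illustration with made-up token ids; the helper `make_seq2seq_example` is hypothetical, and the exact behavior of `gm.data.Seq2SeqTask` (special tokens, prompt formatting) may differ.

```python
def make_seq2seq_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Builds next-token inputs, targets and a response-only loss mask."""
    sequence = prompt_ids + response_ids
    inputs = sequence[:-1]   # the model sees every token except the last
    targets = sequence[1:]   # and predicts the following token at each position
    # Only positions whose target is a response token contribute to the loss.
    loss_mask = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)
    # Truncate to max_length, then right-pad to a fixed length.
    inputs, targets, loss_mask = (
        x[:max_length] for x in (inputs, targets, loss_mask)
    )
    padding = max_length - len(inputs)
    return {
        "input": inputs + [pad_id] * padding,
        "target": targets + [pad_id] * padding,
        "loss_mask": loss_mask + [0] * padding,
    }


# Made-up ids: 2 stands in for a begin token, 1 for an end token.
example = make_seq2seq_example(
    prompt_ids=[2, 11, 12], response_ids=[21, 22, 1], max_length=8
)
print(example)
```

The target is the input shifted by one token, and the mask is zero on the prompt and on padding so that only the response is learned.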

### Trainer

Create the trainer, reusing the `model`, `init_transform` and `optimizer` defined above:

In [None]:
# Create trainer
trainer = kd.train.Trainer(
    seed=42,
    workdir='/tmp/ckpts',
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optimizer,
)
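The loss configured above is a softmax cross-entropy restricted by a mask. A minimal pure-Python sketch of that computation is shown below; the function `masked_cross_entropy` is hypothetical, and averaging over the unmasked positions is an assumption about the normalization rather than a statement about how Kauldron implements it.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the positions where mask is 1."""
    total, count = 0.0, 0
    for position_logits, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        # Log-sum-exp with max subtraction for numerical stability.
        top = max(position_logits)
        log_norm = top + math.log(sum(math.exp(x - top) for x in position_logits))
        total += log_norm - position_logits[label]
        count += 1
    return total / max(count, 1)


# Three positions over a two-token vocabulary; the last one is masked out.
logits = [[0.0, 0.0], [2.0, 0.0], [5.0, -5.0]]
labels = [0, 0, 1]
mask = [1, 1, 0]
loss = masked_cross_entropy(logits, labels, mask)
print(loss)
```

The third position would have a large loss, but because its mask entry is zero it does not affect the result.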

Training can be launched with the `.train()` method. Note that the trainer and model are immutable, so they do not store state or parameters. Instead, the state containing the trained parameters is returned.

In [None]:
# Start training
state, aux = trainer.train()

## Evaluation

After training, we can evaluate the model by sampling a response to an English prompt, which the fine-tuned model should translate to French:

In [None]:
# Create a sampler using the trained parameters
sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

In [None]:
# Test the sampler with a new prompt
sampler.chat("I'm feeling happy today!")

## How QLoRA Works

QLoRA combines two key techniques:

1. **Quantization**: Reduces the precision of the model's weights (to 4-bit in our case)
2. **Low-Rank Adaptation**: Adds small trainable matrices to adapt the quantized model
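The two techniques can be combined in a single linear layer, sketched here with NumPy. This is a conceptual illustration, not the Gemma implementation: the function `qlora_linear`, the per-column scales, and the zero initialization of the second adapter matrix are assumptions chosen for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, rank = 64, 64, 4

# Frozen base weight, stored as 4-bit-range codes plus one scale per column.
w = rng.normal(size=(d_in, d_out)).astype(np.float32)
scale = np.abs(w).max(axis=0, keepdims=True) / 7.0
w_codes = np.clip(np.round(w / scale), -7, 7).astype(np.int8)

# Trainable low-rank adapter: A is small and random, B starts at zero.
lora_a = rng.normal(scale=0.01, size=(d_in, rank)).astype(np.float32)
lora_b = np.zeros((rank, d_out), dtype=np.float32)


def qlora_linear(x, w_codes, scale, lora_a, lora_b):
    """Dequantized base projection plus the low-rank update."""
    base = x @ (w_codes.astype(np.float32) * scale)
    update = (x @ lora_a) @ lora_b
    return base + update


x = rng.normal(size=(2, d_in)).astype(np.float32)
y = qlora_linear(x, w_codes, scale, lora_a, lora_b)

trainable = lora_a.size + lora_b.size
print(y.shape, trainable, w.size)
```

With `lora_b` initialized to zero, the layer initially reproduces the quantized base model exactly, and only the `rank * (d_in + d_out)` adapter values need gradients and optimizer state.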

### Under the Hood

The implementation handles two distinct phases differently:

- **Training**: Quantized weights and LoRA adapters are created during parameter initialization.
- **Evaluation**: Outside of initialization mode, special handling bypasses adapter creation.

This avoids RNG key issues during evaluation while keeping the parameter structure consistent with the one used in training.

## QLoRA vs LoRA: Memory Comparison

QLoRA offers significant memory savings compared to regular LoRA, especially for larger models. A rough comparison:

| Model Size | Full Fine-tuning | LoRA   | QLoRA (INT4) |
|------------|------------------|--------|--------------|
| 4B         | ~8 GB            | ~5 GB  | ~3 GB        |
| 12B        | ~24 GB           | ~10 GB | ~5 GB        |

These values are approximate; actual memory usage depends on sequence length, batch size, and the specific hardware and framework implementation.
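For a back-of-the-envelope view of where the savings come from, the sketch below computes the storage needed for the weights alone at different precisions, plus the number of parameters a LoRA adapter adds to one layer. The helper names, the 1B parameter count, and the layer dimensions are hypothetical; the estimate ignores activations, optimizer state, quantization scales, and any layers left unquantized.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


def lora_params_per_layer(d_in, d_out, rank):
    """A rank-r adapter on a d_in x d_out layer adds r * (d_in + d_out) values."""
    return rank * (d_in + d_out)


num_params = 1_000_000_000  # hypothetical 1B-parameter model
bf16_gb = weight_memory_gb(num_params, 16)
int4_gb = weight_memory_gb(num_params, 4)

# Hypothetical layer size, for illustration only.
adapter = lora_params_per_layer(d_in=2048, d_out=2048, rank=8)
full_layer = 2048 * 2048

print(f"bf16 weights: {bf16_gb} GB, int4 weights: {int4_gb} GB")
print(f"adapter params: {adapter} vs layer params: {full_layer}")
```

Going from 16-bit to 4-bit storage cuts the weight memory by a factor of four, while the adapters add well under one percent of the parameters of the layer they are attached to in this example.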

## Saving and Loading QLoRA Weights

QLoRA weights can be saved and loaded in the same way as LoRA weights. Here's an example of how to save just the adapter weights:

In [None]:
# Example of how to save only the LoRA weights
import os

# Split model parameters to isolate LoRA weights
_, lora_params = gm.peft.split_params(state.params)

# Save only the LoRA parameters (much smaller size)
save_path = "/tmp/my_qlora_weights"
os.makedirs(save_path, exist_ok=True)
gm.ckpts.save_params(
    params=lora_params,
    path=save_path,
)

And to load the saved weights:

In [None]:
# Example of loading the saved QLoRA weights for inference

# Load base model parameters
base_params = gm.ckpts.load_params(
    path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)

# Load LoRA parameters
lora_params = gm.ckpts.load_params(
    path="/tmp/my_qlora_weights",
)

# Merge parameters for inference
merged_params = gm.peft.merge_params(base_params, lora_params)

# Create a new sampler with the merged parameters
inference_sampler = gm.text.ChatSampler(
    model=model,
    params=merged_params,
    tokenizer=tokenizer,
)

## Conclusion

QLoRA provides a good balance between memory efficiency and fine-tuning performance. By quantizing the frozen base model weights, we can substantially reduce memory usage while keeping the benefits of parameter-efficient fine-tuning.

This approach is particularly valuable when working with larger models or when compute resources are limited.