# QLoRA (Finetuning)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/qlora_finetuning.ipynb)

This is an example of fine-tuning Gemma with QLoRA (Quantized Low-Rank Adaptation). It builds on the [LoRA finetuning](https://gemma-llm.readthedocs.io/en/latest/lora_finetuning.html) tutorial, so reading that first is recommended.

QLoRA combines the parameter-efficient fine-tuning of LoRA with quantization of the frozen base model weights, which significantly reduces memory requirements while largely preserving quality. This makes it possible to fine-tune larger models on consumer hardware.
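To make the memory argument concrete, here is a back-of-the-envelope sketch of how much memory the weights alone occupy at different precisions. The helper `weight_memory_gb` and the round parameter count are illustrative assumptions, not part of the Gemma library, and the figures ignore activations, optimizer state, LoRA weights, and quantization scales.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed to store the weights alone, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Hypothetical round number of parameters, chosen only for illustration.
num_params = 4e9

for label, bits in [("float32", 32), ("bfloat16", 16), ("int8", 8), ("int4", 4)]:
    print(f"{label:>8}: {weight_memory_gb(num_params, bits):.1f} GB")
```

Going from 16-bit to 4-bit storage cuts the weight footprint by a factor of four, which is where most of the QLoRA savings come from.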

In [None]:
!pip install -q gemma

In [ ]:
# Common imports
import os
import optax
import treescope

# Gemma imports
from kauldron import kd
from gemma import gm

By default, JAX preallocates only 75% of the GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):

In [None]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

## Config updates

Like regular LoRA, QLoRA requires three main changes to the trainer configuration. The key differences are using `QLoRA` instead of `LoRA` and specifying a quantization method.

For an end-to-end example, see the [qlora.py](https://github.com/google-deepmind/gemma/tree/main/examples/qlora.py) config.

### 1. Model

Wrap the model in `gm.nn.QLoRA`. This applies model surgery that replaces all linear and compatible layers with quantized versions that have LoRA adapters.

In [ ]:
model = gm.nn.QLoRA(
    rank=8,  # QLoRA typically uses higher rank than standard LoRA
    quant_method=gm.peft.QuantizationMethod.INT4,  # 4-bit quantization
    model=gm.nn.Gemma3_4B(
        tokens="batch.input",
        text_only=True,  # Important: Make sure the model is text-only
    ),
)

Internally, this uses the [`gemma.peft`](https://github.com/google-deepmind/gemma/blob/main/gemma/peft) mini-library to perform model surgery with quantization.
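As a rough illustration of what a quantized layer with a LoRA adapter computes, the sketch below uses plain Python lists. It assumes symmetric absmax quantization per row and omits the usual LoRA scaling factor; the scheme `gemma.peft` actually uses may differ, and all function names here are hypothetical.

```python
def quantize_int4(row):
    """Symmetric absmax quantization of one row to integer codes in [-7, 7]."""
    absmax = max(abs(w) for w in row)
    scale = absmax / 7 if absmax > 0 else 1.0
    codes = [max(-7, min(7, round(w / scale))) for w in row]
    return codes, scale


def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a vector of length cols."""
    return [sum(m * v for m, v in zip(row, vector)) for row in matrix]


def qlora_linear(x, codes, scales, lora_a, lora_b):
    """y = dequant(W_q) x + B (A x), with W_q frozen and A, B trainable."""
    weight = [[c * s for c in row] for row, s in zip(codes, scales)]
    base = matvec(weight, x)
    update = matvec(lora_b, matvec(lora_a, x))
    return [b + u for b, u in zip(base, update)]


# Toy 2x3 weight matrix, quantized row by row.
weight = [[0.7, -0.3, 0.12], [0.0, 0.14, -0.06]]
quantized = [quantize_int4(row) for row in weight]
codes = [q[0] for q in quantized]
scales = [q[1] for q in quantized]

# Rank-1 adapter: A maps 3 -> 1 and B maps 1 -> 2. B starts at zero, so the
# adapter initially leaves the base output unchanged.
lora_a = [[0.1, 0.2, 0.3]]
lora_b = [[0.0], [0.0]]

y = qlora_linear([1.0, 1.0, 1.0], codes, scales, lora_a, lora_b)
print(codes, y)
```

Only `lora_a` and `lora_b` would receive gradient updates; the integer codes and scales stay fixed throughout training.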

### 2. Checkpoint

Just like with LoRA, wrap the init transform in `gm.ckpts.SkipLoRA`. The wrapper is required because the parameter structure with QLoRA differs from that of the original model.

In [ ]:
# Standard approach using SkipLoRA.
from kauldron import ckpts

# Thin pass-through wrapper around another init transform. As written, it
# only delegates to the wrapped transform; it is a hook where extra state
# handling (e.g. for RNG keys) could be added.
class WithRngKeys(ckpts.AbstractPartialLoader):
    def __init__(self, wrapped):
        self.wrapped = wrapped

    def transform(self, state):
        # Delegate to the wrapped transform and return its result unchanged.
        state = self.wrapped.transform(state)
        return state

init_transform = WithRngKeys(
    wrapped=gm.ckpts.SkipLoRA(
        wrapped=gm.ckpts.LoadCheckpoint(
            path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
        ),
    ),
)

Note: If you're loading the weights directly with `gm.ckpts.load_params`, you can use `peft.split_params` and `peft.merge_params` instead, as in the LoRA approach.
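The split-and-merge idea can be sketched on plain nested dictionaries. The helpers `split_by_name` and `merge_trees` below are hypothetical stand-ins written for illustration; they are not the `gemma.peft` implementation, and they assume adapter weights live under keys containing `"lora"`.

```python
def split_by_name(tree: dict, marker: str = "lora") -> tuple[dict, dict]:
    """Split a nested dict into (base, adapters) by matching key names."""
    base, adapters = {}, {}
    for key, value in tree.items():
        if marker in key:
            adapters[key] = value
        elif isinstance(value, dict):
            sub_base, sub_adapters = split_by_name(value, marker)
            if sub_base:
                base[key] = sub_base
            if sub_adapters:
                adapters[key] = sub_adapters
        else:
            base[key] = value
    return base, adapters


def merge_trees(a: dict, b: dict) -> dict:
    """Recursively merge two nested dicts; values from `b` win on conflicts."""
    merged = dict(a)
    for key, value in b.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = merge_trees(merged[key], value)
        else:
            merged[key] = value
    return merged


params = {
    "layer_0": {"attn": {"w": 1.0, "lora": {"a": 0.1, "b": 0.0}}},
    "embedder": {"w": 2.0},
}
base, adapters = split_by_name(params)
restored = merge_trees(base, adapters)
print(base, adapters)
```

Splitting first lets the base subtree be replaced by pretrained checkpoint values before the freshly initialized adapter weights are merged back in.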

### 3. Optimizer

Similar to LoRA, add a mask to the optimizer so only the LoRA weights are trained. With QLoRA, it's common to use a lower learning rate.

In [None]:
optimizer = kd.optim.partial_updates(
    optax.adafactor(learning_rate=1e-4),  # Lower learning rate for QLoRA
    # We only optimize the LoRA weights. The rest of the model is frozen.
    mask=kd.optim.select("lora"),
)
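Conceptually, the mask marks which leaves of the parameter tree may be updated. The sketch below shows that idea on nested dictionaries with a plain gradient step; `lora_mask` and `masked_sgd` are hypothetical helpers for illustration and do not reflect how `kd.optim.select` or `kd.optim.partial_updates` are implemented.

```python
def lora_mask(tree, path=()):
    """Return a tree of booleans: True for leaves under a key containing 'lora'."""
    if isinstance(tree, dict):
        return {k: lora_mask(v, path + (k,)) for k, v in tree.items()}
    return any("lora" in part for part in path)


def masked_sgd(params, grads, mask, lr):
    """Apply a gradient step only to leaves whose mask entry is True."""
    if isinstance(params, dict):
        return {k: masked_sgd(params[k], grads[k], mask[k], lr) for k in params}
    return params - lr * grads if mask else params


params = {"dense": {"kernel": 1.0, "lora": {"a": 0.5, "b": 0.0}}}
grads = {"dense": {"kernel": 1.0, "lora": {"a": 1.0, "b": 1.0}}}

mask = lora_mask(params)
updated = masked_sgd(params, grads, mask, lr=0.1)
print(mask)
print(updated)
```

The frozen `kernel` keeps its value even though it has a nonzero gradient, while both adapter leaves move.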

## Training

### Data pipeline

The data pipeline setup is identical to the regular LoRA approach:

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer.encode('This is an example sentence', add_bos=True)

In [None]:
ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=tokenizer,
            # Padding parameters
            max_length=200,
            truncate=True,
        ),
    ],
)

ex = ds[0]

treescope.show(ex)

We can decode an example from the batch to inspect the model input and check it is properly formatted:

In [None]:
text = tokenizer.decode(ex['input'][0])

print(text)
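For intuition about the `input`, `target`, and `loss_mask` fields, here is a sketch of a common next-token setup on raw token ids. It assumes the target is the input shifted by one position and that the loss covers only response tokens; the exact behavior of `gm.data.Seq2SeqTask` may differ, and `make_example` with its token ids is hypothetical.

```python
def make_example(prompt_ids, response_ids, max_length, pad_id=0):
    """Build padded input/target/loss_mask lists for next-token prediction."""
    sequence = prompt_ids + response_ids
    inputs = sequence[:-1]   # The model sees every token except the last.
    targets = sequence[1:]   # Each position predicts the following token.
    # The loss applies only where the target is a response token.
    mask = [0] * (len(prompt_ids) - 1) + [1] * len(response_ids)

    # Truncate to max_length, then pad on the right.
    inputs, targets, mask = (x[:max_length] for x in (inputs, targets, mask))
    pad = max_length - len(inputs)
    return {
        "input": inputs + [pad_id] * pad,
        "target": targets + [pad_id] * pad,
        "loss_mask": mask + [0] * pad,
    }


# Made-up token ids: prompt = [2, 10, 11], response = [20, 21, 1].
example = make_example([2, 10, 11], [20, 21, 1], max_length=8)
print(example)
```

Padding positions get a mask of zero, so they contribute nothing to the loss.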

### Trainer

Create the trainer, reusing the `model`, `init_transform` and `optimizer` defined above:

In [ ]:
trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',
    # Dataset
    train_ds=ds,
    # Model
    model=model,
    init_transform=init_transform,
    # Training parameters
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optimizer,
)
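The loss configured above combines logits, integer labels, and a mask. A minimal pure-Python sketch of a masked softmax cross-entropy follows; it assumes the loss is averaged over unmasked positions, which may not match the exact normalization used by `kd.losses.SoftmaxCrossEntropyWithIntLabels`.

```python
import math


def masked_cross_entropy(logits, labels, mask):
    """Mean softmax cross-entropy over the positions where mask is nonzero."""
    total, count = 0.0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        # Numerically stable log-sum-exp of the logits.
        peak = max(row)
        log_z = peak + math.log(sum(math.exp(x - peak) for x in row))
        total += log_z - row[label]
        count += 1
    return total / max(count, 1)


# Two positions over four classes; the second position is masked out.
logits = [[0.0, 0.0, 0.0, 0.0], [9.0, -3.0, 2.0, 1.0]]
labels = [2, 0]
mask = [1, 0]

loss = masked_cross_entropy(logits, labels, mask)
print(loss)
```

With uniform logits over four classes, the only unmasked position contributes a loss of log(4).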

In [ ]:
# Targeted debugging: walk through the parameter loading steps manually.
import jax
import jax.numpy as jnp
import logging
logging.basicConfig(level=logging.INFO)

# First, let's understand how the quantization init works
print("Creating standard model (without LoRA) for reference")
standard_model = gm.nn.Gemma3_4B(tokens="batch.input", text_only=True)

print("Creating QLoRA model")
qlora_model = gm.nn.QLoRA(
    rank=8,
    quant_method=gm.peft.QuantizationMethod.INT4,
    model=gm.nn.Gemma3_4B(tokens="batch.input", text_only=True),
)

print("Step 1: Loading checkpoint directly for reference")
try:
    # Load original parameters first for reference
    original_params = gm.ckpts.load_params(
        path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
    )
    print("Successfully loaded original checkpoint")
except Exception as e:
    print(f"Error loading original checkpoint: {str(e)}")
    raise

print("Step 2: Initializing the QLoRA model")
try:
    # Create dummy data for initialization
    dummy_ids = tokenizer.encode("Test", add_bos=True)
    dummy_input = jnp.array([dummy_ids])
    
    # Initialize the QLoRA model
    variables = qlora_model.init(
        jax.random.PRNGKey(42),
        tokens=dummy_input,
    )
    params = variables["params"]
    print("Successfully initialized QLoRA model")
except Exception as e:
    print(f"Error initializing QLoRA model: {str(e)}")
    raise

print("Step 3: Separating QLoRA parameters")
try:
    # Split the QLoRA parameters
    base_params, lora_params = gm.peft.split_params(params)
    print("Successfully split parameters")
except Exception as e:
    print(f"Error splitting parameters: {str(e)}")
    raise

print("Step 4: Restoring base parameters while preserving structure")
try:
    # Get the structure of original params that matches our base params
    restored_params = {}
    for key, subtree in original_params.items():
        if key in base_params:
            restored_params[key] = subtree
    
    print("Successfully prepared parameters for merging")
except Exception as e:
    print(f"Error preparing parameters: {str(e)}")
    raise

print("Step 5: Merging with LoRA parameters")
try:
    # Merge the parameters
    final_params = gm.peft.merge_params(restored_params, lora_params)
    print("Successfully merged parameters")
except Exception as e:
    print(f"Error merging parameters: {str(e)}")
    raise

print("All debug steps completed!")
print("\nNote: This debugging doesn't actually use the init_transform with SkipLoRA,")
print("but demonstrates how the manual parameter loading and merging should work.")

In [ ]:
# Manual approach for QLoRA initialization and training
import dataclasses
import jax
import jax.numpy as jnp
from kauldron import random

# Create proper RNG streams to use for module initialization
rng_streams = random.RngStreams(seed=42)
rngs = rng_streams.init_rngs()

# Initialize model with dummy data
dummy_ids = tokenizer.encode("Test", add_bos=True)
dummy_input = jnp.array([dummy_ids])

# Initialize QLoRA model
variables = model.init(
    rngs,  # Use the proper RNGs with all required streams
    tokens=dummy_input,
)
params = variables["params"]

# Split parameters
original_params, lora_params = gm.peft.split_params(params)

# Load original parameters
checkpoint_params = gm.ckpts.load_params(
    path=gm.ckpts.CheckpointPath.GEMMA3_4B_IT,
)

# Get common keys
restored_params = {}
for key in original_params:
    if key in checkpoint_params:
        restored_params[key] = checkpoint_params[key]

# Merge with LoRA parameters
final_params = gm.peft.merge_params(restored_params, lora_params)

# Create a custom init_transform that swaps in the prepared parameters.
@dataclasses.dataclass
class CustomInitTransform:
    def transform(self, state):
        # Return state with our custom parameters
        return state.replace(params=final_params)

# Create trainer with our custom init_transform and full rng_streams
manual_trainer = kd.train.Trainer(
    seed=42,  # Use the same seed for consistent behavior
    workdir='/tmp/ckpts',
    train_ds=ds,
    model=model,
    init_transform=CustomInitTransform(),
    num_train_steps=500,
    train_losses={
        "loss": kd.losses.SoftmaxCrossEntropyWithIntLabels(
            logits="preds.logits",
            labels="batch.target",
            mask="batch.loss_mask",
        ),
    },
    optimizer=optimizer,
    rng_streams=rng_streams,  # Explicitly provide RNG streams
)

# Try training with our manual approach
try:
    state, aux = manual_trainer.train()
    print("Training succeeded with manual parameter loading!")
except Exception as e:
    print(f"Error during training with manual loading: {str(e)}")
    # Print RNG keys to help diagnose issues
    print(f"Available RNG keys: {rngs.keys()}")
    raise

## Evaluation

Let's test our fine-tuned model with a sample input:

In [None]:
sampler = gm.text.ChatSampler(
    model=model,
    params=state.params,
    tokenizer=tokenizer,
)

We test a sentence, using the same formatting used during fine-tuning:

In [None]:
sampler.chat("I'm feeling happy today!")
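For reference, the sketch below shows the kind of prompt formatting involved, assuming the standard turn markers used by instruction-tuned Gemma models. The helper `format_turn` is hypothetical; in practice the sampler is expected to handle this formatting internally.

```python
def format_turn(prompt: str) -> str:
    """Wrap a user prompt in Gemma-style turn markers, ending at the model turn."""
    return (
        "<start_of_turn>user\n"
        f"{prompt}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


formatted = format_turn("I'm feeling happy today!")
print(formatted)
```

Matching the fine-tuning format at inference time matters because the model has only seen responses that follow this exact prefix.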

## QLoRA vs LoRA: Memory Comparison

QLoRA offers significant memory savings compared to regular LoRA, especially for larger models. A rough comparison:

| Model size | Full fine-tuning | LoRA   | QLoRA (INT4) |
|------------|------------------|--------|--------------|
| 4B         | ~8 GB            | ~5 GB  | ~3 GB        |
| 12B        | ~24 GB           | ~10 GB | ~5 GB        |

These figures are rough estimates rather than measurements. Actual memory usage depends on sequence length, batch size, optimizer state, and the specific hardware and framework implementation.

## Conclusion

QLoRA balances memory efficiency against fine-tuning quality. Quantizing the frozen base model weights reduces memory usage while keeping the benefits of parameter-efficient fine-tuning.

This approach is particularly valuable when working with larger models or when compute resources are limited.