# QLoRA Demo Notebook

In this tutorial, we fine-tune a Qwen3 model (the 0.6B checkpoint by default) using Low-Rank Adaptation (LoRA), a parameter-efficient way of fine-tuning LLMs, optionally combined with quantised base weights (QLoRA).

LoRA works by freezing the original weights of the pre-trained model and
injecting trainable low-rank matrices into each layer of the Transformer
architecture. During fine-tuning, only these newly introduced low-rank matrices
are updated, greatly decreasing the computational and memory resources required
compared to traditional full fine-tuning. This approach is based on the
observation that the changes in model weights needed for adaptation often have a
low rank. The benefits of using LoRA include reduced GPU memory usage, faster
training times, and the advantage that, after training, the LoRA adapters can be
merged with the original model weights, resulting in no additional inference
latency.

## Install necessary libraries

In [None]:
# Kaggle client, used further down to download the model weights.
!pip install -q kagglehub

# Remaining dependencies; tunix and qwix are installed from their GitHub
# repositories rather than from a PyPI release.
!pip install -q tensorflow
!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

# Remove any already-installed flax, then install it from the GitHub repository.
!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

# `datasets` package (see the optional Hugging Face dataset in the data-loading
# cell).
!pip install -q datasets

In [None]:
# Optional: if you want to upload your metrics to Weights & Biases, install the
# package and log in. Make sure to install `wandb` before importing `tunix`.
!pip install wandb

import wandb

# Authenticate with Weights & Biases.
wandb.login()

## Hyperparameters

In [None]:
# Data
# Passed to the dataset builder as `global_batch_size`.
BATCH_SIZE = 16

# Model
# Unpacked into `jax.make_mesh(*MESH)`: axis shapes first, then axis names.
MESH = [(1, 1), ("fsdp", "tp")]
# LoRA
# Passed to `qwix.LoraProvider` as `rank` and `alpha`.
RANK = 16
ALPHA = 2.0

# Train
# Passed to `peft_trainer.TrainingConfig` as `max_steps` and
# `eval_every_n_steps`.
MAX_STEPS = 100
EVAL_EVERY_N_STEPS = 20
# Passed to the dataset builder as `num_train_epochs`.
NUM_EPOCHS = 3


# Checkpoint saving
# The three directories below are created by the next cell.
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
# Used as `checkpoint_root_directory` for the LoRA/QLoRA training run.
CKPT_DIR = "/tmp/content/ckpts/"
# Parent directory for the `jax.profiler.trace` outputs of the training runs.
PROFILING_DIR = "/tmp/content/profiling/"

In [None]:
import os
import logging
import sys
def create_dir(path):
  """Create `path` (and any missing parents), logging the outcome.

  An already existing directory is not treated as an error. An `OSError`
  raised while creating the directory is logged and swallowed, so the function
  returns `None` either way.

  Args:
    path: Directory path to create.
  """
  try:
    os.makedirs(path, exist_ok=True)
  except OSError as err:
    logging.error(f"Error creating directory '{path}': {err}")
  else:
    logging.info(f"Created dir: {path}")


# Make sure every output directory exists before training starts.
for _output_dir in (INTERMEDIATE_CKPT_DIR, CKPT_DIR, PROFILING_DIR):
  create_dir(_output_dir)

## Download the weights from Kaggle

In [None]:
import os
import kagglehub

# Log in interactively unless both Kaggle credential variables are already
# present in the environment.
# alternatively place kaggle.json under ~/.kaggle/
if not ("KAGGLE_USERNAME" in os.environ and "KAGGLE_KEY" in os.environ):
  kagglehub.login()

In [None]:
import jax
# Build the device mesh from the MESH hyperparameter, which unpacks into the
# axis shapes followed by the axis names.
mesh = jax.make_mesh(*MESH)

In [None]:
from flax import nnx
import kagglehub
from tunix.models.qwen3 import model
from tunix.models.qwen3 import params

# Download the Qwen3 0.6B checkpoint through kagglehub.
MODEL_CP_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

# Pick the config corresponding to the downloaded model version.
config = model.ModelConfig.qwen3_0_6b()

# Load the safetensors weights onto the mesh and show the module tree.
qwen3 = params.create_model_from_safe_tensors(MODEL_CP_PATH, config, mesh)
nnx.display(qwen3)

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer from the same downloaded checkpoint path as the weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_CP_PATH)

In [None]:
def templatize(prompts):
  """Render each prompt as a single-turn user message via the chat template.

  Uses the notebook-level `tokenizer`. The template is applied without
  tokenizing, with the generation prompt appended and thinking enabled.

  Args:
    prompts: Iterable of user prompt strings.

  Returns:
    A list holding one `tokenizer.apply_chat_template` result per prompt, in
    input order.
  """

  def _render(prompt):
    conversation = [{"role": "user", "content": prompt}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )

  return [_render(prompt) for prompt in prompts]

In [None]:
from tunix.generate import sampler

# Prompts are wrapped in the chat template before being sent to the sampler.
inputs = templatize([
    "which is larger 9.9 or 9.11?",
    "如何制作月饼?",
    "tell me your name, respond in Chinese",
])

# KV-cache settings for generation.
# NOTE(review): the layer/head numbers are hard-coded, presumably for the 0.6B
# config loaded above -- confirm and update them when switching model size.
cache_config = sampler.CacheConfig(
    cache_size=256, num_layers=28, num_kv_heads=8, head_dim=128
)
# This rebinds `sampler` from the imported module to the Sampler instance.
sampler = sampler.Sampler(qwen3, tokenizer, cache_config)
out = sampler(inputs, max_generation_steps=128, echo=True)

# Print every generated text followed by a separator line.
separator = "*" * 30
for text in out.text:
  print(text)
  print(separator)

## Apply LoRA/QLoRA to the model

In [None]:
import qwix
def get_lora_model(base_model, mesh):
  """Wrap `base_model` with LoRA/QLoRA adapters and constrain its sharding.

  The adapter rank and alpha come from the notebook-level RANK and ALPHA.

  Args:
    base_model: Model to adapt; it must provide `get_model_input()`, whose
      result is forwarded as keyword arguments to `qwix.apply_lora_to_model`.
    mesh: Mesh entered as a context manager while the sharding constraint is
      applied.

  Returns:
    The model returned by `qwix.apply_lora_to_model`, updated in place with
    the sharding-constrained state.
  """
  # NOTE(review): the pattern selects modules by name; confirm that every
  # alternative matches a module of the model being adapted.
  provider = qwix.LoraProvider(
      module_path=".*q_einsum|.*kv_einsum|.*gate_proj|.*down_proj|.*up_proj",
      rank=RANK,
      alpha=ALPHA,
      # comment the two args below for LoRA (w/o quantisation).
      weight_qtype="nf4",
      tile_size=256,
  )

  adapted = qwix.apply_lora_to_model(
      base_model, provider, **base_model.get_model_input()
  )

  with mesh:
    # Apply the partition specs derived from the state as a sharding
    # constraint, then write the constrained state back into the model.
    unsharded = nnx.state(adapted)
    constrained = jax.lax.with_sharding_constraint(
        unsharded, nnx.get_partition_spec(unsharded)
    )
    nnx.update(adapted, constrained)

  return adapted

In [None]:
# LoRA model
# Wrap the loaded base model with adapters and display the resulting module
# tree.
lora_qwen3 = get_lora_model(qwen3, mesh=mesh)
nnx.display(lora_qwen3)

## Load Datasets for SFT Training

In [None]:
# Loads the training and validation datasets

from tunix.examples.data import translation_dataset as data_lib
from tunix.rl import common
from tunix.sft import peft_trainer

# Build the training and validation datasets, passing in the batch size, epoch
# count and tokenizer defined earlier in the notebook.
train_ds, validation_ds = data_lib.create_datasets(
    dataset_name='mtnt/en-fr',
    # Uncomment the line below to use a Hugging Face dataset.
    # Note that this requires upgrading the 'datasets' package and restarting
    # the Colab runtime.
    # dataset_name='Helsinki-NLP/opus-100',
    global_batch_size=BATCH_SIZE,
    max_target_length=256,
    num_train_epochs=NUM_EPOCHS,
    tokenizer=tokenizer,
)


def gen_model_input_fn(x: peft_trainer.TrainingInput):
  """Convert a training batch into the keyword arguments fed to the model.

  Positions and the causal attention mask are derived from a padding mask that
  marks every token different from the tokenizer's padding id.

  Args:
    x: Training batch providing `input_tokens` and `input_mask`.

  Returns:
    A dict with the keys `input_tokens`, `input_mask`, `positions` and
    `attention_mask`.
  """
  # The notebook loads a Hugging Face `AutoTokenizer`, which exposes the
  # padding id as the `pad_token_id` attribute rather than a `pad_id()`
  # method. Prefer `pad_id()` when the tokenizer has it and fall back
  # otherwise, so both tokenizer styles work.
  pad_id_fn = getattr(tokenizer, 'pad_id', None)
  pad_id = pad_id_fn() if callable(pad_id_fn) else tokenizer.pad_token_id

  pad_mask = x.input_tokens != pad_id
  positions = common.build_positions_from_mask(pad_mask)
  attention_mask = common.make_causal_attn_mask(pad_mask)
  return {
      'input_tokens': x.input_tokens,
      'input_mask': x.input_mask,
      'positions': positions,
      'attention_mask': attention_mask,
  }

## SFT Training

In [None]:
from tunix.sft import metrics_logger

import optax

# Full fine-tuning run on the base `qwen3` model, with metrics written under
# /tmp/tensorboard/full.
logging_option = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/full", flush_every_n_steps=20
)
training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    metrics_logging_options=logging_option,
)
trainer = peft_trainer.PeftTrainer(
    qwen3, optax.adamw(1e-5), training_config
).with_gen_model_input_fn(gen_model_input_fn)

# Record a profiler trace of the run; training executes inside the mesh context.
with jax.profiler.trace(os.path.join(PROFILING_DIR, "full_training")), mesh:
  trainer.train(train_ds, validation_ds)

### Training with LoRA/QLoRA

In [None]:
# Since the LoRA model shares its backbone with the base model, restart the
# Colab runtime before running this cell so that the base model is loaded as
# pre-trained.

# Same schedule as the full fine-tuning run, with checkpoints written under
# CKPT_DIR.
training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    checkpoint_root_directory=CKPT_DIR,
)
# Train the adapter-wrapped model created earlier as `lora_qwen3`; the
# previous `lora_gemma` name is not defined anywhere in this notebook.
lora_trainer = peft_trainer.PeftTrainer(
    lora_qwen3, optax.adamw(1e-3), training_config
).with_gen_model_input_fn(gen_model_input_fn)

# Record a profiler trace of the run; training executes inside the mesh context.
with jax.profiler.trace(os.path.join(PROFILING_DIR, "peft")):
  with mesh:
    lora_trainer.train(train_ds, validation_ds)

## Generate with the LoRA/QLoRA model

In [None]:
from tunix.generate import sampler

# Same prompts as for the base model, wrapped in the chat template.
inputs = templatize([
    "which is larger 9.9 or 9.11?",
    "如何制作月饼?",
    "tell me your name, respond in Chinese",
])

# KV-cache settings for generation.
# NOTE(review): the layer/head numbers are hard-coded, presumably for the 0.6B
# config loaded above -- confirm and update them when switching model size.
cache_config = sampler.CacheConfig(
    cache_size=256, num_layers=28, num_kv_heads=8, head_dim=128
)
# This rebinds `sampler` from the imported module to the Sampler instance,
# this time wrapping the LoRA/QLoRA model.
sampler = sampler.Sampler(lora_qwen3, tokenizer, cache_config)
out = sampler(inputs, max_generation_steps=128, echo=True)

# Print every generated text followed by a separator line.
separator = "*" * 30
for text in out.text:
  print(text)
  print(separator)