# QLoRA Demo Notebook

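In this tutorial, we fine-tune the Qwen3 0.6B model using Low-Rank Adaptation (LoRA), a parameter-efficient way of fine-tuning LLMs, and its quantized variant QLoRA. Larger Qwen3 checkpoints (e.g. 14B) can be used by downloading them and selecting the matching model config.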

LoRA works by freezing the original weights of the pre-trained model and
injecting trainable low-rank matrices into each layer of the Transformer
architecture. During fine-tuning, only these newly introduced low-rank matrices
are updated, greatly decreasing the computational and memory resources required
compared to traditional full fine-tuning. This approach is based on the
observation that the changes in model weights needed for adaptation often have a
low rank. The benefits of using LoRA include reduced GPU memory usage, faster
training times, and the advantage that, after training, the LoRA adapters can be
merged with the original model weights, resulting in no additional inference
latency.

## Install necessary libraries

In [None]:
!pip install -q kagglehub

!pip install -q tensorflow
!pip install -q tensorflow_datasets
!pip install -q tensorboardX
!pip install -q grain
!pip install -q git+https://github.com/google/tunix
!pip install -q git+https://github.com/google/qwix

!pip uninstall -q -y flax
!pip install -q git+https://github.com/google/flax.git

!pip install -q datasets

In [None]:
!pip install "jax[tpu]==0.7.1" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Looking in links: https://storage.googleapis.com/jax-releases/libtpu_releases.html


NameError: name 'os' is not defined

In [1]:
import importlib.util
import os

# Point TPU_LIBRARY_PATH at the `libtpu.so` shipped inside the installed
# `libtpu` package instead of a user-specific absolute path, so the notebook
# runs for anyone. A TPU_LIBRARY_PATH exported by the user is left untouched.
# NOTE(review): presumably this must run before JAX initializes its TPU
# backend — keep this cell above the first `import jax`.
_libtpu_spec = importlib.util.find_spec("libtpu")
if _libtpu_spec is not None and _libtpu_spec.origin:
  os.environ.setdefault(
      "TPU_LIBRARY_PATH",
      os.path.join(os.path.dirname(_libtpu_spec.origin), "libtpu.so"),
  )

In [4]:
# If you want to upload your metrics to Weights & Biases, please install the package and login. Make sure to install `wandb` before importing `tunix`.
!pip install wandb

import wandb

wandb.login()



[34m[1mwandb[0m: Currently logged in as: [33mlinchai[0m ([33mlinchai-google[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [5]:
!pip install safetensors
!pip install transformers



## Hyperparameters

In [2]:
# ---- Data ----
BATCH_SIZE = 16  # global batch size handed to the dataset loader

# ---- Model ----
# (axis sizes, axis names) unpacked into `jax.make_mesh`.
MESH = [(1, 1), ("fsdp", "tp")]

# ---- LoRA ----
RANK = 16  # rank of the low-rank adapter matrices
ALPHA = 2.0  # LoRA scaling factor passed to qwix

# ---- Training ----
MAX_STEPS = 100
EVAL_EVERY_N_STEPS = 20
NUM_EPOCHS = 3

# ---- Output directories (checkpoints and profiler traces) ----
INTERMEDIATE_CKPT_DIR = "/tmp/content/intermediate_ckpt/"
CKPT_DIR = "/tmp/content/ckpts/"
PROFILING_DIR = "/tmp/content/profiling/"

In [3]:
import os
import logging
import sys
def create_dir(path):
  """Create directory `path`, including missing parents, if it does not exist.

  Failures (any OSError from `os.makedirs`) are logged at error level rather
  than raised, so the notebook keeps going.
  """
  try:
    os.makedirs(path, exist_ok=True)
  except OSError as err:
    logging.error(f"Error creating directory '{path}': {err}")
  else:
    logging.info(f"Created dir: {path}")


# Make sure every output directory exists before training/profiling writes to it.
for _output_dir in (INTERMEDIATE_CKPT_DIR, CKPT_DIR, PROFILING_DIR):
  create_dir(_output_dir)

## Download the weights from Kaggle

In [4]:
import os
import kagglehub

# Log in to Kaggle interactively unless both KAGGLE_USERNAME and KAGGLE_KEY
# are already set in the environment.
if not all(var in os.environ for var in ("KAGGLE_USERNAME", "KAGGLE_KEY")):
  kagglehub.login()

# Alternatively, place kaggle.json under ~/.kaggle/

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

In [5]:
import jax

# Build the device mesh from the (axis sizes, axis names) pair in MESH and display it.
mesh = jax.make_mesh(axis_shapes=MESH[0], axis_names=MESH[1])
mesh


Mesh(axis_sizes=(1, 1), axis_names=('fsdp', 'tp'), axis_types=(Auto, Auto))

In [6]:

from flax import nnx
import kagglehub
from tunix.models.qwen3 import model
from tunix.models.qwen3 import params

# Download the Qwen3 0.6B checkpoint (transformers format) from Kaggle.
MODEL_CP_PATH = kagglehub.model_download("qwen-lm/qwen-3/transformers/0.6b")

# The config must match the downloaded checkpoint: if you change the Kaggle
# handle above, pick the corresponding `ModelConfig` factory here.
config = model.ModelConfig.qwen3_0_6b()

# Build the model from the safetensors checkpoint using `config` and `mesh`.
qwen3 = params.create_model_from_safe_tensors(MODEL_CP_PATH, config, mesh)
# NOTE(review): with some flax/treescope versions this display fails with an
# AttributeError inside treescope (see the captured output).
nnx.display(qwen3)

Traceback (most recent call last):
  File "/home/linchai_google_com/miniconda3/envs/qwen/lib/python3.12/site-packages/treescope/renderers.py", line 290, in _render_subtree
    maybe_result = handler(node=node, path=path, subtree_renderer=rec)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linchai_google_com/miniconda3/envs/qwen/lib/python3.12/site-packages/treescope/_internal/handlers/custom_type_handlers.py", line 65, in handle_via_treescope_repr_method
    return treescope_repr_method(path, subtree_renderer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/linchai_google_com/miniconda3/envs/qwen/lib/python3.12/site-packages/flax/nnx/pytreelib.py", line 702, in __treescope_repr__
    if name.startswith('_'):
       ^^^^^^^^^^^^^^^
AttributeError: 'int' object has no attribute 'startswith'



In [7]:

from transformers import AutoTokenizer

# Load the tokenizer that ships alongside the downloaded checkpoint.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_CP_PATH)

In [8]:
def templatize(prompts):
  """Render each raw prompt through the tokenizer's chat template.

  Every prompt becomes a single user turn, rendered as text (not token ids),
  with the generation prompt appended and thinking mode enabled.

  Args:
    prompts: iterable of user prompt strings.

  Returns:
    A list with one templated prompt string per input prompt.
  """
  return [
      tokenizer.apply_chat_template(
          [{"role": "user", "content": prompt}],
          tokenize=False,
          add_generation_prompt=True,
          enable_thinking=True,
      )
      for prompt in prompts
  ]

In [9]:
# List the devices visible to JAX (displayed as the cell output).
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

In [None]:
from tunix.generate import sampler as sampler_lib

# Prompts: a decimal-comparison question, "How do I make mooncakes?" (Chinese),
# and a request to answer in Chinese.
inputs = templatize([
    "which is larger 9.9 or 9.11?",
    "如何制作月饼?",
    "tell me your name, respond in Chinese",
])

# Sample from the base (not yet fine-tuned) model. The KV-cache geometry is read
# from `config` instead of being hardcoded, so the cell keeps working when a
# different Qwen3 size is loaded. The instance gets its own name so it does not
# shadow the `sampler_lib` module.
base_sampler = sampler_lib.Sampler(
    qwen3,
    tokenizer,
    sampler_lib.CacheConfig(
        cache_size=256,
        num_layers=config.num_layers,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
    ),
)
out = base_sampler(inputs, max_generation_steps=128, echo=True)

for t in out.text:
  print(t)
  print("*" * 30)

<|im_start|>user
which is larger 9.9 or 9.11?<|im_end|>
<|im_start|>assistant
<think>
Okay, so I need to figure out which number is larger between 9.9 and 9.11. Let me think. Both numbers are in decimal form, right? 9.9 and 9.11. Hmm, decimal numbers can be tricky sometimes, but I remember that when comparing decimals, you can look at the digits from left to right, starting with the first non-zero digit. 

First, let me write them down to visualize better: 9.9 and 9.11. Both start with a 9. So, the first digit after the decimal is the tenths
******************************
<|im_start|>user
如何制作月饼?<|im_end|>
<|im_start|>assistant
<think>
好的，用户问如何制作月饼。首先，我需要确定用户的需求是什么。可能他们想了解基本的步骤，或者有特定的口味偏好，比如是否需要使用模具，或者是否需要特别的装饰。用户可能对月饼的制作流程不太熟悉，或者有时间限制，需要快速指导。

接下来，我应该考虑月饼的种类。用户可能想知道不同种类的制作方法，比如传统的中式月饼和一些创新的口味。但问题中没有提到，所以可能需要先给出通用的步骤，然后询问是否需要更详细的信息。

然后，我需要确保回答的结构清晰，分
******************************
<|im_start|>user
tell me your name, respond in Chinese<|im_end|>
<|im_start|>assistant
<think>
好的，用户让我告诉他

In [None]:
import functools
import humanize
def show_hbm_usage():
  """Print memory usage (bytes in use / byte limit) for each local JAX device.

  Devices that do not report usable memory statistics get an explicit
  "unavailable" line instead of raising.
  """
  fmt_size = functools.partial(humanize.naturalsize, binary=True)

  for d in jax.local_devices():
    # `memory_stats()` may return None on backends that do not expose
    # allocator statistics, so treat a missing dict as empty.
    stats = d.memory_stats() or {}
    used = stats.get("bytes_in_use")
    limit = stats.get("bytes_limit")
    # A missing or zero limit would otherwise raise KeyError/ZeroDivisionError.
    if used is None or not limit:
      print(f"Memory stats unavailable on {d}")
      continue
    print(f"Using {fmt_size(used)} / {fmt_size(limit)} ({used/limit:%}) on {d}")

In [16]:
# Print per-device memory usage at this point (after the base model has been loaded).
show_hbm_usage()

Using 1.5 GiB / 7.5 GiB (19.534145%) on TPU_0(process=0,(0,0,0,0))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_1(process=0,(0,0,0,1))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_2(process=0,(1,0,0,0))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_3(process=0,(1,0,0,1))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_4(process=0,(0,1,0,0))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_5(process=0,(0,1,0,1))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_6(process=0,(1,1,0,0))
Using 16.0 KiB / 7.5 GiB (0.000204%) on TPU_7(process=0,(1,1,0,1))


## Apply LoRA/QLoRA to the model

In [14]:
import qwix
def get_lora_model(base_model, mesh, module_path=None):
  """Wrap `base_model` with LoRA/QLoRA adapters and reshard its state on `mesh`.

  Args:
    base_model: the pre-trained tunix model to adapt. It must provide
      `get_model_input()`, whose result is forwarded to qwix.
    mesh: JAX device mesh used as context while applying sharding constraints.
    module_path: optional regex selecting the sub-modules that receive
      adapters. Defaults to the Qwen3 attention input projections
      (q/k/v_proj) and MLP projections (gate/up/down_proj). The original
      `q_einsum`/`kv_einsum` patterns are kept for models that use those names.

  Returns:
    The LoRA-wrapped model, with its state resharded according to the
    partition specs reported by `nnx.get_partition_spec`.
  """
  if module_path is None:
    # The original pattern only named `q_einsum`/`kv_einsum`, which match no
    # attention module in Qwen3 (its projections are `q_proj`/`k_proj`/
    # `v_proj`), so attention silently received no adapters.
    module_path = (
        ".*q_einsum|.*kv_einsum|.*q_proj|.*k_proj|.*v_proj"
        "|.*gate_proj|.*down_proj|.*up_proj"
    )

  lora_provider = qwix.LoraProvider(
      module_path=module_path,
      rank=RANK,
      alpha=ALPHA,
      # Comment out the two args below for plain LoRA (no weight quantization).
      weight_qtype="nf4",
      tile_size=256,
  )

  model_input = base_model.get_model_input()
  lora_model = qwix.apply_lora_to_model(
      base_model, lora_provider, **model_input
  )

  # Re-apply the model's partition specs so the new adapter params are sharded too.
  with mesh:
    state = nnx.state(lora_model)
    pspecs = nnx.get_partition_spec(state)
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(lora_model, sharded_state)

  return lora_model

In [15]:
# Wrap the base model with (Q)LoRA adapters; the backbone is shared with `qwen3`.
lora_qwen3 = get_lora_model(base_model=qwen3, mesh=mesh)
nnx.display(lora_qwen3)

## Load Datasets for SFT Training

In [16]:
# Loads the training and validation datasets

from tunix.examples.data import translation_dataset as data_lib
from tunix.rl import common
from tunix.sft import peft_trainer

# Build the train/validation splits of the 'mtnt/en-fr' translation dataset,
# batched to BATCH_SIZE and repeated for NUM_EPOCHS.
# NOTE(review): confirm that `create_datasets` accepts a Hugging Face tokenizer.
train_ds, validation_ds = data_lib.create_datasets(
    # To use a Hugging Face dataset instead, replace the name with
    # dataset_name='Helsinki-NLP/opus-100'. This requires upgrading the
    # 'datasets' package and restarting the Colab runtime.
    dataset_name='mtnt/en-fr',
    global_batch_size=BATCH_SIZE,
    max_target_length=256,
    num_train_epochs=NUM_EPOCHS,
    tokenizer=tokenizer,
)


def gen_model_input_fn(x: peft_trainer.TrainingInput):
  """Convert a training batch into the keyword inputs passed to the model.

  Padding positions are found by comparing tokens against the tokenizer's pad
  id; positions and a causal attention mask are derived from that pad mask.

  Args:
    x: a training batch exposing `input_tokens` and `input_mask`.

  Returns:
    Dict with 'input_tokens', 'input_mask', 'positions' and 'attention_mask'.
  """
  # `tokenizer` here is a Hugging Face AutoTokenizer, which exposes the pad id
  # as the `pad_token_id` attribute and has no `pad_id()` method. Fall back to
  # `pad_id()` for SentencePiece-style tokenizers that provide it.
  pad_id = getattr(tokenizer, "pad_token_id", None)
  if pad_id is None:
    pad_id = tokenizer.pad_id()
  pad_mask = x.input_tokens != pad_id
  positions = common.build_positions_from_mask(pad_mask)
  attention_mask = common.make_causal_attn_mask(pad_mask)
  return {
      'input_tokens': x.input_tokens,
      'input_mask': x.input_mask,
      'positions': positions,
      'attention_mask': attention_mask,
  }

ModuleNotFoundError: No module named 'tensorflow_datasets'

## SFT Training

### Full-weight training

First, fine-tune all weights of the base model `qwen3`.

In [None]:
from tunix.sft import metrics_logger

import optax

# Full-weight fine-tuning of the base model `qwen3`, logging metrics for TensorBoard.
logging_option = metrics_logger.MetricsLoggerOptions(
    log_dir="/tmp/tensorboard/full",
    flush_every_n_steps=20,
)
training_config = peft_trainer.TrainingConfig(
    max_steps=MAX_STEPS,
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    metrics_logging_options=logging_option,
)
trainer = peft_trainer.PeftTrainer(
    qwen3, optax.adamw(1e-5), training_config
).with_gen_model_input_fn(gen_model_input_fn)

# Train inside the mesh context while recording a profiler trace under PROFILING_DIR.
with jax.profiler.trace(os.path.join(PROFILING_DIR, "full_training")), mesh:
  trainer.train(train_ds, validation_ds)

### Training with LoRA/QLoRA

In [None]:
# Since the LoRA model shares its backbone with the base model, restart the
# Colab runtime (after the full-weight run) so the base model is loaded as
# pre-trained.

training_config = peft_trainer.TrainingConfig(
    eval_every_n_steps=EVAL_EVERY_N_STEPS,
    max_steps=MAX_STEPS,
    checkpoint_root_directory=CKPT_DIR,
)
# Train the LoRA/QLoRA-wrapped Qwen3 model (`lora_gemma` was an undefined
# leftover name and raised NameError).
lora_trainer = peft_trainer.PeftTrainer(
    lora_qwen3, optax.adamw(1e-3), training_config
).with_gen_model_input_fn(gen_model_input_fn)

with jax.profiler.trace(os.path.join(PROFILING_DIR, "peft")):
  with mesh:
    lora_trainer.train(train_ds, validation_ds)

## Generate with the LoRA/QLoRA model

In [None]:
from tunix.generate import sampler as sampler_lib

# Same prompts as for the base model, for a side-by-side comparison:
# decimal comparison, "How do I make mooncakes?" (Chinese), answer-in-Chinese.
inputs = templatize([
    "which is larger 9.9 or 9.11?",
    "如何制作月饼?",
    "tell me your name, respond in Chinese",
])

# Sample from the fine-tuned LoRA/QLoRA model. The KV-cache geometry is read
# from `config` rather than hardcoded, and the instance is named so it does
# not shadow the `sampler_lib` module.
lora_sampler = sampler_lib.Sampler(
    lora_qwen3,
    tokenizer,
    sampler_lib.CacheConfig(
        cache_size=256,
        num_layers=config.num_layers,
        num_kv_heads=config.num_kv_heads,
        head_dim=config.head_dim,
    ),
)
out = lora_sampler(inputs, max_generation_steps=128, echo=True)

for t in out.text:
  print(t)
  print("*" * 30)