# Parameter-Efficient Fine-Tuning of Llama 3.1-8B with LoRA/QLoRA on NVIDIA GPUs using JAX and Tunix

This tutorial walks you through parameter-efficient fine-tuning (PEFT) of Llama 3.1-8B using LoRA and QLoRA on NVIDIA GPUs with JAX, Tunix, and Qwix. Unlike full-parameter SFT, PEFT freezes the base model weights and trains only small adapter matrices, which sharply reduces memory requirements and training time, typically with little loss in model quality.

**What you'll do:**
1. Set up the environment and authenticate with Hugging Face
2. Load the Llama 3.1-8B base model and apply LoRA/QLoRA adapters
3. Prepare the UltraChat 200k dataset for instruction fine-tuning
4. Configure and run parameter-efficient fine-tuning
5. Visualize training metrics with TensorBoard
6. Run a quick inference sanity check

## Preliminaries

### Make sure you have supported hardware

**Hardware requirements.** QLoRA with 4-bit quantization can fine-tune Llama 3.1-8B on a single GPU with **16 GB+ of VRAM**. For LoRA without quantization, 24 GB+ is recommended. Multiple GPUs enable larger batch sizes and faster training through data parallelism; on multi-GPU systems, the model is automatically sharded across devices using FSDP, with tensor parallelism added when 8 or more GPUs are available.
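
As a rough sanity check on these figures, the sketch below estimates the memory taken by the base weights alone. It assumes about 8.03 billion parameters (the commonly quoted size of Llama 3.1-8B, not stated in this tutorial) and ignores activations, optimizer state, adapter weights, and quantization scales, so real usage will be higher.

```python
def weight_memory_gib(num_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 2**30

NUM_PARAMS = 8.03e9  # assumption: approximate parameter count of Llama 3.1-8B

bf16_gib = weight_memory_gib(NUM_PARAMS, 16)  # LoRA: base weights kept in bf16
nf4_gib = weight_memory_gib(NUM_PARAMS, 4)    # QLoRA: base weights stored in 4 bits

print(f"bf16 weights:  {bf16_gib:.2f} GiB")
print(f"4-bit weights: {nf4_gib:.2f} GiB")
```

The bf16 weights alone come to roughly 15 GiB, which is why unquantized LoRA is more comfortable on a 24 GB card, while the 4-bit weights leave headroom on a 16 GB card.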

In [None]:
!nvidia-smi

### Set your Hugging Face token

Create a [Hugging Face](https://huggingface.co/) access token in your Hugging Face account [settings](https://huggingface.co/settings/tokens), copy it, and paste it into the field below. This token is required to authenticate with the Hugging Face Hub and download the Llama 3.1 model and related assets; once saved, it will be reused by this environment for the rest of the tutorial.

In [None]:
import os

from ipywidgets import Password, Button, HBox, Output
from IPython.display import display

from huggingface_hub import whoami

def _verify_token(token: str) -> str:
    """Return the Hugging Face username for `token`; raises if the token is invalid."""
    return whoami(token=token).get("name", "unknown")

token_box = Password(description="HF Token:", placeholder="paste your token here", layout={"width": "400px"})
save_btn = Button(description="Save", button_style="success")
out = Output()

def save_token(_):
    out.clear_output()
    with out:
        existing = os.environ.get("HF_TOKEN")
        entered = token_box.value.strip()
        if existing and not entered:
            user = _verify_token(existing)
            print(f"Using existing HF_TOKEN. Logged in as: {user}")
            return
        if not entered:
            print("No HF token entered.")
            return
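        # Verify the token before saving it, so an invalid token is never stored
        try:
            user = _verify_token(entered)
        except Exception as e:
            print(f"Token verification failed: {e}")
            return
        os.environ["HF_TOKEN"] = entered
        print(f"Token saved. Logged in as: {user}")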

save_btn.on_click(save_token)
display(HBox([token_box, save_btn]), out)

### Authenticate with Hugging Face

Verify that your Hugging Face token is set and valid. If the token is missing, an error is raised immediately rather than failing silently during model download.

In [None]:
# Read the token from the environment (set by the widget above or exported beforehand)
from huggingface_hub import whoami

HF_TOKEN = os.environ.get("HF_TOKEN")

if HF_TOKEN:
    try:
        user = whoami(token=HF_TOKEN)["name"]
        print(f"Authenticated with Hugging Face as: {user} (via HF_TOKEN env)")
    except Exception as e:
        print("HF_TOKEN is set but authentication failed:", e)
else:
    raise RuntimeError(
        "HF_TOKEN is not set. Please create a Hugging Face access token "
        "and export it as an environment variable."
    )

### Acquire permission to use the gated model

Llama 3.1-8B is a gated model, so you must explicitly request access before it can be downloaded. Visit the [model page](https://huggingface.co/meta-llama/Llama-3.1-8B) on Hugging Face, log in with the same account linked to your access token, and click **Request access**. You'll need to agree to Meta's license terms; approval is usually granted quickly but is not automatic. Once approved, your Hugging Face token will authorize downloads transparently. If you skip this step, model downloads will fail even with a valid token.

### Set up the environment

### Import dependencies

Import the core libraries needed for training:
- **JAX/Flax**: High-performance ML framework with automatic differentiation and XLA compilation
- **Optax**: Gradient processing and optimization library for JAX
- **Transformers**: Hugging Face library for tokenizers and model configurations
- **Qwix**: Quantization and LoRA utilities for JAX models
- **Tunix**: Training utilities including `PeftTrainer` and `AutoModel` for streamlined fine-tuning

The easiest way to get a working environment is the [NVIDIA NGC JAX container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax), which ships with all dependencies preinstalled. To install the dependencies manually:

```bash
pip install 'jax[cuda13]' flax optax transformers datasets qwix
```

Whichever route you choose (container or manual), you also need to install Tunix:

```bash
pip install tunix
```

In [None]:
# Imports
import time
import shutil
import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import nnx
import transformers
from datasets import load_dataset

import qwix
from tunix.models.automodel import AutoModel
from tunix.sft import peft_trainer, metrics_logger

print(f"JAX {jax.__version__} | Devices: {jax.devices()}")

### Create the device mesh

JAX uses a device mesh to define how computation and data are distributed across GPUs. The mesh assigns logical axis names to physical device dimensions, enabling FSDP (Fully Sharded Data Parallel) and TP (Tensor Parallel) strategies. The configuration adapts automatically based on available GPUs:

| GPUs | Mesh Shape | Strategy |
|------|------------|----------|
| 8+ | `(1, 4, 2)` | data + FSDP + TP |
| 2–7 | `(N, 1)` | FSDP only |
| 1 | `(1, 1)` | No sharding |

The `fsdp` axis shards model parameters across devices to reduce per-device memory, while `tp` enables tensor-parallel splitting of large weight matrices.

In [None]:
# Create mesh for sharding
NUM_DEVICES = jax.local_device_count()

if NUM_DEVICES >= 8:
    mesh = jax.make_mesh((1, 4, 2), ("data", "fsdp", "tp"),
        axis_types=(jax.sharding.AxisType.Auto,) * 3)
elif NUM_DEVICES >= 2:
    # Shard model across GPUs using FSDP
    mesh = jax.make_mesh((NUM_DEVICES, 1), ("fsdp", "tp"),
        axis_types=(jax.sharding.AxisType.Auto,) * 2)
else:
    # Single GPU - no sharding, but keep axis names for API consistency
    mesh = jax.make_mesh((1, 1), ("fsdp", "tp"),
        axis_types=(jax.sharding.AxisType.Auto,) * 2)

print(f"Devices: {NUM_DEVICES} | Mesh: {mesh.shape}")
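
The selection logic above can be checked without any GPUs attached. The sketch below mirrors it in plain Python; `choose_mesh` is a hypothetical helper written for illustration and is not part of JAX or Tunix.

```python
def choose_mesh(num_devices: int):
    """Return (mesh_shape, axis_names) following the table above."""
    if num_devices >= 8:
        return (1, 4, 2), ("data", "fsdp", "tp")
    if num_devices >= 2:
        return (num_devices, 1), ("fsdp", "tp")
    return (1, 1), ("fsdp", "tp")

for n in (1, 2, 4, 8):
    shape, axes = choose_mesh(n)
    print(f"{n} GPU(s): shape={shape}, axes={axes}")
```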

## Define model and training parameters

All training hyperparameters are defined in one place for easy experimentation. The key parameters control model selection, LoRA configuration, quantization, batch size, sequence length, and training duration. Set `CLEAN_START = True` to remove existing checkpoints before training, or `False` to resume from a previous run.

In [None]:
# Configuration
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppress CUDA/TF warnings

MODEL_ID = "meta-llama/Llama-3.1-8B"
TOKENIZER_ID = "meta-llama/Llama-3.1-8B-Instruct"

LORA_RANK = 16
LORA_ALPHA = 32.0
USE_QUANTIZATION = True  # set True for QLoRA (4-bit), set False for regular LoRA

BATCH_SIZE = 2
MAX_SEQ_LENGTH = 512
LEARNING_RATE = 1e-4
MAX_STEPS = 100

OUTPUT_DIR = "/workspace/llama3_lora_output"
CLEAN_START = True  # Set to False to resume from checkpoint

if CLEAN_START and os.path.exists(f"{OUTPUT_DIR}/checkpoints"):
    shutil.rmtree(f"{OUTPUT_DIR}/checkpoints")
    print("Removed old checkpoints (CLEAN_START=True)")

os.makedirs(OUTPUT_DIR, exist_ok=True)

## Load the model

### Load the tokenizer

Load the tokenizer from the Instruct model variant, which includes the chat template for formatting conversations. The pad token is set to the EOS token if not already defined, which is standard for decoder-only models like Llama.

In [None]:
# Load tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_ID, token=HF_TOKEN)
tokenizer.pad_token = tokenizer.pad_token or tokenizer.eos_token
print(f"Tokenizer loaded: {TOKENIZER_ID}")

### Load the base model

`AutoModel.from_pretrained()` handles the complete model loading pipeline: downloading weights from the Hugging Face Hub (cached in `model_download_path`), converting them to JAX-compatible format, and initializing the model architecture with proper sharding across the mesh.

The model is loaded within the mesh context to ensure parameters are distributed correctly across devices from the start.

In [None]:
# Load model using AutoModel
print(f"Loading {MODEL_ID}...")
load_start = time.time()

with mesh:
    base_model, model_path = AutoModel.from_pretrained(
        MODEL_ID,
        mesh,
        model_download_path="/hf_cache",
    )

print(f"Model loaded in {time.time() - load_start:.1f}s")
print(f"Model path: {model_path}")

## Apply LoRA / QLoRA

Low-Rank Adaptation (LoRA) freezes the base model weights and injects small trainable matrices into attention and MLP layers. This dramatically reduces the number of trainable parameters while preserving model quality.
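
To make the idea concrete, here is a minimal NumPy sketch of the standard LoRA formulation for a single linear layer: the frozen weight `w` is augmented with a low-rank product `a @ b` scaled by `alpha / rank`. The names `lora_forward`, `a`, and `b` and the toy dimensions are illustrative assumptions; Qwix's internal implementation may differ in detail.

```python
import numpy as np

def lora_forward(x, w, a, b, alpha, rank):
    """Frozen base projection plus the scaled low-rank update."""
    return x @ w + (alpha / rank) * (x @ a @ b)

rng = np.random.default_rng(0)
d_in, d_out, rank, alpha = 512, 512, 16, 32.0

w = rng.normal(size=(d_in, d_out))        # frozen base weight
a = rng.normal(size=(d_in, rank)) * 0.01  # trainable, small random init
b = np.zeros((rank, d_out))               # trainable, zero init

x = rng.normal(size=(2, d_in))
y0 = lora_forward(x, w, a, b, alpha, rank)

# With b initialized to zero, the adapted layer starts out identical to the base layer.
print(np.allclose(y0, x @ w))
print(f"trainable: {a.size + b.size} vs frozen: {w.size}")
```

Zero-initializing one of the two factors is the usual choice because fine-tuning then starts exactly from the pretrained model's behavior.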

**QLoRA** adds 4-bit NF4 quantization on top of LoRA to further reduce memory:
- Base weights are quantized to 4-bit NormalFloat format
- Only the small LoRA adapter weights remain in full precision
- `tile_size=32` controls the quantization block size (must divide the smallest weight dimension)
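
The effect of the tile size can be sketched with a simplified block-wise quantizer. Note the assumptions: this uses a uniform 16-level codebook as a stand-in for the real NF4 codebook, and `quantize_blockwise` and `dequantize_blockwise` are hypothetical helpers, not Qwix functions. Each tile of 32 values shares one absmax scale, and every value is stored as a 4-bit index into the codebook.

```python
import numpy as np

def quantize_blockwise(w, tile_size=32, levels=16):
    """Quantize `w` tile by tile to indices into a uniform codebook on [-1, 1]."""
    codebook = np.linspace(-1.0, 1.0, levels)
    tiles = w.reshape(-1, tile_size)
    scales = np.abs(tiles).max(axis=1, keepdims=True)  # one scale per tile
    scales = np.where(scales == 0, 1.0, scales)        # avoid division by zero
    normalized = tiles / scales
    codes = np.abs(normalized[..., None] - codebook).argmin(axis=-1).astype(np.uint8)
    return codes, scales, codebook

def dequantize_blockwise(codes, scales, codebook, shape):
    """Look up each code and rescale by its tile's scale."""
    return (codebook[codes] * scales).reshape(shape)

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 128)).astype(np.float32)

codes, scales, codebook = quantize_blockwise(w)
w_hat = dequantize_blockwise(codes, scales, codebook, w.shape)

print(f"max abs error: {np.abs(w - w_hat).max():.4f}")
```

Smaller tiles track local weight magnitudes more closely at the cost of storing more scales, which is the trade-off the tile size controls.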

**Target modules** specify which layers receive LoRA adapters using regex patterns matching attention projections (`q_proj`, `k_proj`, `v_proj`, `o_proj`) and MLP layers (`gate_proj`, `up_proj`, `down_proj`).
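
For a sense of scale, the sketch below counts the adapter parameters that rank 16 adds when all seven projections are targeted. The layer dimensions are taken from the published Llama 3.1-8B configuration (hidden size 4096, MLP size 14336, 32 layers, key/value projection width 1024); they are assumptions in the sense that this tutorial does not list them.

```python
def lora_param_count(rank, hidden=4096, intermediate=14336, layers=32, kv_dim=1024):
    """Adapter parameters: rank * (d_in + d_out) for every targeted projection."""
    projections = {
        "q_proj": (hidden, hidden),
        "k_proj": (hidden, kv_dim),
        "v_proj": (hidden, kv_dim),
        "o_proj": (hidden, hidden),
        "gate_proj": (hidden, intermediate),
        "up_proj": (hidden, intermediate),
        "down_proj": (intermediate, hidden),
    }
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in projections.values())
    return per_layer * layers

trainable = lora_param_count(rank=16)
print(f"{trainable:,} trainable adapter parameters")  # 41,943,040
print(f"{trainable / 8.03e9:.2%} of ~8.03B base parameters")
```

Under these assumptions the adapters amount to roughly half a percent of the base model, which is why optimizer state and gradients become cheap.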

In [None]:
# Apply QLoRA / LoRA
target_modules = ".*q_proj|.*k_proj|.*v_proj|.*o_proj|.*gate_proj|.*up_proj|.*down_proj"

lora_provider = qwix.LoraProvider(
    module_path=target_modules,
    rank=LORA_RANK,
    alpha=LORA_ALPHA,
    weight_qtype="nf4" if USE_QUANTIZATION else None,
    tile_size=32 if USE_QUANTIZATION else None,
)

dummy_input = {
    'input_tokens': jnp.ones((1, 128), dtype=jnp.int32),
    'positions': jnp.arange(128)[None, :],
    'cache': None,
    'attention_mask': jnp.ones((1, 128, 128), dtype=jnp.bool_),
}

print(f"Applying {'QLoRA' if USE_QUANTIZATION else 'LoRA'} (rank={LORA_RANK})...")
lora_model = qwix.apply_lora_to_model(
    base_model, lora_provider,
    rngs=nnx.Rngs(params=0),  # For reproducible LoRA weight initialization
    **dummy_input
)

with mesh:
    state = nnx.state(lora_model)
    sharded = jax.lax.with_sharding_constraint(state, nnx.get_partition_spec(state))
    nnx.update(lora_model, sharded)

print(f"{'QLoRA' if USE_QUANTIZATION else 'LoRA'} applied!")

## Prepare the training data

Load the [UltraChat 200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k) dataset, a large collection of multi-turn conversations commonly used for instruction fine-tuning. For this tutorial, a subset of 2,000 training and 200 evaluation examples is used.

The data processing pipeline applies the chat template to format conversations with special tokens, tokenizes with padding to `MAX_SEQ_LENGTH`, and creates attention masks to ignore padding tokens. Training uses an infinite generator that cycles through the data, while evaluation uses a finite iterator that yields exactly one pass through the eval set. Batch size is scaled by `NUM_DEVICES` for data parallelism.
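
The padding and masking step can be illustrated on toy token IDs without downloading a tokenizer. `pad_or_truncate` below is a hypothetical stand-in for what the tokenizer does with `padding="max_length"` and `truncation=True`, assuming right-side padding.

```python
import numpy as np

def pad_or_truncate(token_ids, max_len, pad_id):
    """Cut to `max_len`, right-pad with `pad_id`, and mark real tokens in the mask."""
    ids = list(token_ids)[:max_len]
    n_real = len(ids)
    n_pad = max_len - n_real
    return {
        "input_tokens": np.array(ids + [pad_id] * n_pad, dtype=np.int32),
        "input_mask": np.array([True] * n_real + [False] * n_pad, dtype=bool),
    }

short = pad_or_truncate([5, 6, 7], max_len=5, pad_id=0)
long = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], max_len=5, pad_id=0)

print(short["input_tokens"], short["input_mask"])
print(long["input_tokens"], long["input_mask"])
```

Because every example ends up with the same length, batches can be stacked into fixed-shape arrays, which avoids recompilation under XLA.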

In [None]:
# Prepare dataset
dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft").select(range(2000))
eval_dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="test_sft").select(range(200))

def tokenize(ex):
    text = tokenizer.apply_chat_template(ex["messages"], tokenize=False)
    # The chat template already emits the BOS token, so don't let the tokenizer add a second one
    tok = tokenizer(text, max_length=MAX_SEQ_LENGTH, padding="max_length", truncation=True, add_special_tokens=False)
    return {"input_tokens": np.array(tok["input_ids"]), "input_mask": np.array(tok["attention_mask"], dtype=bool)}

train_data = [tokenize(ex) for ex in dataset]
eval_data = [tokenize(ex) for ex in eval_dataset]

# Infinite generator for training (cycles through data, always yielding full batches)
def train_batches(data, bs):
    i = 0
    while True:
        if i + bs > len(data):
            i = 0  # Not enough examples left for a full batch: restart from the beginning
        batch = data[i:i+bs]
        yield {k: np.stack([x[k] for x in batch]) for k in batch[0]}
        i += bs

# Reusable eval dataset - returns fresh finite iterator each time
class EvalDataset:
    def __init__(self, data, bs):
        self.data = data
        self.bs = bs
    def __iter__(self):
        for i in range(0, len(self.data), self.bs):
            batch = self.data[i:i+self.bs]
            if len(batch) == self.bs:
                yield {k: np.stack([x[k] for x in batch]) for k in batch[0]}

train_ds = train_batches(train_data, BATCH_SIZE * NUM_DEVICES)
eval_ds = EvalDataset(eval_data, BATCH_SIZE * NUM_DEVICES)

print(f"Train: {len(train_data)} examples | Eval: {len(eval_data)} examples | Batch size: {BATCH_SIZE * NUM_DEVICES}")
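
It can help to check how much of the training subset a run actually covers. The sketch below is simple arithmetic using the tutorial's defaults (100 steps, per-device batch size 2, 2,000 examples); `epochs_covered` is a hypothetical helper name.

```python
def epochs_covered(max_steps, per_device_batch, num_devices, num_examples):
    """Fraction of the dataset seen, given the global batch size."""
    global_batch = per_device_batch * num_devices
    return max_steps * global_batch / num_examples

one_gpu = epochs_covered(100, 2, 1, 2000)
eight_gpus = epochs_covered(100, 2, 8, 2000)

print(f"1 GPU:  {one_gpu:.1f} epochs")
print(f"8 GPUs: {eight_gpus:.1f} epochs")
```

With the defaults, a single-GPU run sees only a tenth of the 2,000 examples, so raise `MAX_STEPS` if you want a full pass over the subset.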

## Provide the training configuration

The model expects specific input formats: position indices for rotary embeddings and a 3D causal attention mask `[batch, seq, seq]` combining causal attention with padding. The `gen_model_input` function constructs these from the tokenized batch.
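
A worked example on a toy mask shows what these two inputs look like. The sketch below reimplements the same logic in NumPy for a single sequence of five slots, of which the last two are padding.

```python
import numpy as np

def build_positions(mask):
    """Position index of each real token; padding repeats the last real position."""
    return np.clip(np.cumsum(mask, axis=-1) - 1, 0, None).astype(np.int32)

def build_causal_mask(mask):
    """[batch, seq, seq] mask: lower-triangular AND key-is-not-padding."""
    n = mask.shape[-1]
    return np.tril(np.ones((n, n), dtype=bool))[None] & mask[:, None, :]

mask = np.array([[True, True, True, False, False]])

positions = build_positions(mask)
attn = build_causal_mask(mask)

print(positions)            # [[0 1 2 2 2]]
print(attn[0].astype(int))  # row i lists the keys that query i may attend to
```

Each query row can see only earlier real tokens: the padding columns are zero everywhere, and the upper triangle is zero because of causality.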

`PeftTrainer` orchestrates the training loop with an AdamW optimizer, periodic checkpointing, TensorBoard-compatible metrics logging, and evaluation every `eval_every_n_steps` steps.

In [None]:
# Input processing helpers
def build_positions(mask):
    return jnp.clip(jnp.cumsum(mask, axis=-1) - 1, 0).astype(jnp.int32)

def build_causal_mask(mask):
    n = mask.shape[-1]
    return jnp.tril(jnp.ones((n, n), dtype=jnp.bool_))[None] & mask[:, None, :]

def gen_model_input(x):
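    # Use the tokenizer's attention mask: comparing against pad_token_id would also
    # mask real end-of-turn tokens whenever the pad token falls back to the EOS token
    mask = x["input_mask"]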
    return {
        "input_tokens": x["input_tokens"],
        "positions": build_positions(mask),
        "attention_mask": build_causal_mask(mask),
        "input_mask": x["input_mask"],
    }

# Create trainer
trainer = peft_trainer.PeftTrainer(
    lora_model,
    optax.adamw(LEARNING_RATE),
    peft_trainer.TrainingConfig(
        max_steps=MAX_STEPS,
        eval_every_n_steps=25,  # Evaluate every 25 steps
        checkpoint_root_directory=f"{OUTPUT_DIR}/checkpoints",
        metrics_logging_options=metrics_logger.MetricsLoggerOptions(log_dir=f"{OUTPUT_DIR}/logs"),
    ),
).with_gen_model_input_fn(gen_model_input)

print("Trainer ready!")

## Run the training

This block launches the PEFT training loop. It runs a baseline evaluation first to measure initial loss, then trains for `MAX_STEPS` steps with periodic evaluation. The first step is slower due to XLA JIT compilation, which is cached for subsequent steps.

In [None]:
# Training with progress
NUM_EVAL_BATCHES = len(eval_data) // (BATCH_SIZE * NUM_DEVICES)

class Progress:
    def __init__(self, n):
        self.n = n
        self.t0 = None
        self.eval_count = 0
        self.eval_started = False
    def on_train_start(self, _):
        self.t0 = time.time()
        print("Training (first step includes JIT)...")
    def on_train_end(self, _):
        print(f"\nDone in {time.time()-self.t0:.0f}s")
    def on_train_step_start(self, _):
        self.eval_started = False
    def on_train_step_end(self, _, step, loss, dt):
        if step <= 2 or step % 10 == 0:
            print(f"Step {step}/{self.n} | Loss: {float(loss):.4f} | {dt:.1f}s/step")
    def on_eval_step_start(self, _):
        if not self.eval_started:
            self.eval_count += 1
            label = "Baseline eval" if self.eval_count == 1 else f"Eval #{self.eval_count}"
            print(f"{label}...", end=" ", flush=True)
            self.eval_started = True
    def on_eval_step_end(self, _, eval_loss):
        avg_loss = float(eval_loss) / NUM_EVAL_BATCHES
        print(f"loss: {avg_loss:.4f} (avg over {NUM_EVAL_BATCHES} batches)")

trainer.training_hooks = Progress(MAX_STEPS)

print("Starting (baseline eval + JIT compilation first)...")
with mesh:
    trainer.train(train_ds, eval_ds)

print(f"Checkpoints: {OUTPUT_DIR}/checkpoints")

### Visualize training with TensorBoard

To monitor training loss and other metrics, launch TensorBoard in a separate terminal:

```bash
tensorboard --logdir=/workspace/llama3_lora_output/logs --host 0.0.0.0 --port 6006 --load_fast=false
```

Then open [http://127.0.0.1:6006/](http://127.0.0.1:6006/) in your browser.

## Test inference

Run a quick sanity check to verify that the fine-tuned model produces coherent output. The code below tokenizes a prompt using the Llama 3.1 chat template, then runs greedy autoregressive generation for up to 10 tokens, stopping early if the model produces an EOS token. A sensible answer suggests that the adapters are applied correctly; it is not a substitute for a proper evaluation.

Note: this is naive autoregressive generation without KV-caching, so each step recomputes attention over the full sequence. For production use, consider a dedicated serving framework with KV-cache support.
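
To see why this matters, the sketch below counts how many token positions pass through the model during generation, with and without a KV cache. It is a simplified cost model (it counts positions fed to the network, not FLOPs), and `positions_processed` is a hypothetical helper.

```python
def positions_processed(prompt_len, new_tokens, use_kv_cache):
    """Token positions fed through the model while generating `new_tokens` tokens."""
    if use_kv_cache:
        # One prefill over the prompt, then one position per additional step
        return prompt_len + max(new_tokens - 1, 0)
    # Without a cache, every step reprocesses the whole sequence so far
    return sum(prompt_len + i for i in range(new_tokens))

naive = positions_processed(prompt_len=20, new_tokens=10, use_kv_cache=False)
cached = positions_processed(prompt_len=20, new_tokens=10, use_kv_cache=True)

print(naive)   # 245
print(cached)  # 29
```

The gap grows quadratically with the number of generated tokens, which is acceptable for a 10-token check but not for longer generations.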

In [None]:
# Quick inference test with the fine-tuned LoRA model
prompt = "What is the capital of France?"
messages = [{"role": "user", "content": prompt}]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

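# Tokenize (the chat template already includes BOS, so skip the tokenizer's own special tokens)
tokens = jnp.array(tokenizer(text, add_special_tokens=False)["input_ids"])[None, :]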

# Greedy autoregressive generation
max_new_tokens = 10
generated_ids = []
eos_token_id = tokenizer.eos_token_id

for _ in range(max_new_tokens):
    seq_len = tokens.shape[1]
    positions = jnp.arange(seq_len)[None, :]
    attention_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=jnp.bool_))[None, :]

    with mesh:
        output = lora_model(tokens, positions, None, attention_mask)
        logits = output[0] if isinstance(output, tuple) else output

    next_token_id = int(jnp.argmax(logits[0, -1]))
    generated_ids.append(next_token_id)

    if next_token_id == eos_token_id:
        break

    tokens = jnp.concatenate([tokens, jnp.array([[next_token_id]])], axis=1)

# Decode all generated tokens
generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

print(f"Prompt: {prompt}")
print(f"Generated ({len(generated_ids)} tokens): '{generated_text}'")