# Production QLoRA Training — Llama-2-13B (H100 Optimized)

Enterprise-grade fine-tuning pipeline optimized for NVIDIA H100 80GB.

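**Key Optimizations for Stability**:
- **Gradient Checkpointing**: Recomputes activations during the backward pass instead of storing them — the critical fix that reduced VRAM use from >80 GB to ~25 GB.
- **Batch Size 2 / Gradient Accumulation 4**: Effective batch size of 8 while keeping peak memory low.
- **8-bit Optimizer**: Stores AdamW optimizer state in 8 bits, shrinking optimizer memory.
- **Memory Callback**: Runs garbage collection and releases cached CUDA memory after training steps.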

In [None]:
import os
# Process-wide settings that must be in place before torch / tokenizers are imported.
os.environ.update({
    # Presumably silences pydevd (debugger) file-validation warnings in the notebook.
    'PYDEVD_DISABLE_FILE_VALIDATION': '1',
    # Turn off the HF tokenizers library's internal parallelism (dataloader workers fork later).
    'TOKENIZERS_PARALLELISM': 'false',
    # Let the CUDA caching allocator grow segments instead of allocating fixed blocks;
    # only honoured if set before CUDA is initialised.
    'PYTORCH_CUDA_ALLOC_CONF': 'expandable_segments:True',
})

In [None]:
!pip install -q -U bitsandbytes accelerate peft transformers datasets tqdm faiss-cpu sentence-transformers flash-attn --no-build-isolation

try:
    import bitsandbytes
    print(f'[OK] bitsandbytes {bitsandbytes.__version__}')
except ImportError:
    print('[FATAL] bitsandbytes not found. Restarting runtime...')
    import os; os.kill(os.getpid(), 9)

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the checkpoint, final-model and RAG-index
# directories used later in this notebook all live under /content/drive/MyDrive.
drive.mount('/content/drive')

In [None]:
import torch
import time

# === Hardware Verification ===
# Fail fast: 4-bit bitsandbytes kernels, Flash Attention 2 and BF16 all need a GPU.
# An explicit raise (not `assert`) still fires when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError('FATAL: No CUDA device found')

gpu_name = torch.cuda.get_device_name(0)
vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3

print('=' * 60)
print('HARDWARE DIAGNOSTICS')
print('=' * 60)
print(f'  GPU:           {gpu_name}')
print(f'  VRAM:          {vram_gb:.1f} GB')
print(f'  CUDA Version:  {torch.version.cuda}')
print(f'  PyTorch:       {torch.__version__}')
print(f'  BF16 Support:  {torch.cuda.is_bf16_supported()}')

# Enable TF32 tensor-core math for FP32 matmuls and convolutions (Ampere and newer).
# PyTorch 2.9+ exposes per-backend `fp32_precision` knobs. The matmul knob lives under
# torch.backends.cuda, but the convolution knob lives under torch.backends.cudnn.
# PyTorch's TF32 docs say not to mix these with the legacy `allow_tf32` flags, so
# exactly one API is used: the new one if present, otherwise the legacy flags.
try:
    torch.backends.cuda.matmul.fp32_precision = 'tf32'
    torch.backends.cudnn.conv.fp32_precision = 'tf32'
    tf32_mode = 'Strict'
except AttributeError:
    # Older PyTorch: the fp32_precision attributes do not exist.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    tf32_mode = 'Legacy'
print(f'  TF32:          Enabled ({tf32_mode})')

print('=' * 60)

## 1. Load Model — 4-bit QLoRA + Flash Attention 2

In [None]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from tqdm import tqdm

MODEL_NAME = 'NousResearch/Llama-2-13b-hf'
MAX_LENGTH = 1024  # Reduced from 2048 to save memory

print(f'Loading {MODEL_NAME} in 4-bit NF4...')

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse EOS as the padding token so the data collator can pad batches.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.model_max_length = MAX_LENGTH
# Batches are padded dynamically by the data collator, which makes a fast tokenizer emit
# a one-time "use __call__ instead of pad" advisory. Mark that warning as already shown.
# Update the existing dict instead of replacing it, so the tokenizer's other once-only
# warning flags are preserved.
tokenizer.deprecation_warnings['Asking-to-pad-a-fast-tokenizer'] = True

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    attn_implementation='flash_attention_2',
)

# The KV cache only helps generation and is incompatible with gradient checkpointing.
model.config.use_cache = False

# CRITICAL: gradient checkpointing trades recomputation for activation memory.
# prepare_model_for_kbit_training already:
#   - freezes every base-model parameter,
#   - enables input gradients (required when checkpointed layers are all frozen),
#   - turns gradient checkpointing on.
# A separate freeze loop or gradient_checkpointing_enable() call is therefore redundant.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

print('Base model loaded. Applying LoRA...')

In [None]:
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    bias='none',
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],
)

model = get_peft_model(model, peft_config)

# Verify only LoRA params are trainable. Use PEFT's own counter: a plain sum of
# p.numel() undercounts 4-bit (Params4bit) weights, which pack two values per stored
# element, and would therefore overstate the trainable percentage.
trainable, total = model.get_nb_trainable_parameters()

print('=' * 60)
print('LORA CONFIGURATION')
print('=' * 60)
print(f'  Total params:     {total:,}')
print(f'  Trainable params: {trainable:,}')
print(f'  Trainable %:      {100 * trainable / total:.2f}%')
# Report the values actually in effect instead of repeating literals.
print(f'  LoRA rank:        {peft_config.r}')
print(f'  LoRA alpha:       {peft_config.lora_alpha}')
print(f'  Targets:          {", ".join(sorted(peft_config.target_modules))}')
print(f'  Sequence length:  {MAX_LENGTH}')
print('=' * 60)

## 2. Prepare Streaming Data (Robust Cleaning)

**Strategy:**
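- **Streaming**: Never load the full dataset into RAM; rows are fetched on demand.
- **Dynamic Padding**: Tokenize with `padding=False`; the data collator pads each batch to its longest sequence.
- **Clean Columns**: Auto-detect the raw dataset's columns and remove all of them after tokenization, so only tokenizer outputs remain.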

In [None]:
from datasets import load_dataset

print('Initializing dataset stream...')

# streaming=True returns an iterable dataset: rows are fetched lazily as they are consumed.
raw_dataset = load_dataset('HuggingFaceFW/fineweb-edu', split='train', streaming=True)

# Approximate shuffling: rows are drawn at random from a rolling 10,000-row buffer.
shuffled_dataset = raw_dataset.shuffle(buffer_size=10000, seed=42)

def tokenize_function(examples):
    """Tokenize a batch of raw rows for causal-LM training.

    Every string in ``examples['text']`` is truncated to MAX_LENGTH tokens but
    deliberately left unpadded; the data collator pads each batch on the fly.
    """
    texts = examples['text']
    return tokenizer(texts, padding=False, truncation=True, max_length=MAX_LENGTH)

print('Configuring tokenization stream...')

# Every raw column is dropped after tokenization, so only tokenizer outputs reach the
# collator.
# 1. Prefer the schema the dataset declares: it costs nothing.
# 2. Otherwise peek at one row of the UNSHUFFLED stream. Iterating `shuffled_dataset`
#    would first download enough rows to fill its 10,000-row shuffle buffer just to
#    read the keys.
remove_cols = raw_dataset.column_names
if remove_cols is None:
    try:
        remove_cols = list(next(iter(raw_dataset)).keys())
    except Exception as e:
        # Best effort: a hard-coded guess could name columns that do not exist in the
        # stream. Drop only 'text'; Trainer (remove_unused_columns=True) is expected to
        # filter out any remaining columns the model's forward() does not accept.
        print(f'Warning: Could not auto-detect columns ({e}). Removing only "text".')
        remove_cols = ['text']
print(f'Detected columns to remove: {remove_cols}')

tokenized_dataset = shuffled_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=remove_cols,
)

print('Dataset stream ready.')

## 3. Training Configuration (Memory Optimized)

In [None]:
output_dir = '/content/drive/MyDrive/fineweb_edu_llama2_13b/checkpoints'
os.makedirs(output_dir, exist_ok=True)

# Causal-LM collator (mlm=False): pads each batch to its longest sequence and copies
# input_ids into labels, setting padding positions to -100. Because pad_token is EOS,
# any EOS token inside a sequence is also excluded from the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=2,   # Reduced from 8 to 2
    gradient_accumulation_steps=4,   # Increased from 1 to 4 (Net 8)
    learning_rate=1e-4,
    lr_scheduler_type='cosine',
    warmup_steps=150,                # Fixed steps instead of ratio
    max_steps=5000,
    bf16=True,
    fp16=False,
    optim='adamw_bnb_8bit',          # 8-bit optimizer to save memory
    gradient_checkpointing=True,     # CRITICAL for memory savings
    logging_steps=20,
    save_steps=500,
    save_total_limit=3,
    report_to='none',
    remove_unused_columns=True,
    dataloader_num_workers=4,        # Reduced CPU memory pressure
    dataloader_pin_memory=True,
    dataloader_drop_last=True,
    dataloader_persistent_workers=True,
)

# Derive the summary from training_args so it can never drift from the real configuration.
effective_batch = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
    * training_args.world_size
)
if training_args.bf16:
    precision = 'BF16'
elif training_args.fp16:
    precision = 'FP16'
else:
    precision = 'FP32'
# optim / lr_scheduler_type can be enum members after TrainingArguments validation;
# print their plain string values.
optim_name = getattr(training_args.optim, 'value', training_args.optim)
scheduler_name = getattr(training_args.lr_scheduler_type, 'value', training_args.lr_scheduler_type)
grad_ckpt = 'ENABLED (Critical)' if training_args.gradient_checkpointing else 'DISABLED'

print('=' * 60)
print('TRAINING ARGUMENTS')
print('=' * 60)
print(f'  Batch Size:          {training_args.per_device_train_batch_size} (per device)')
print(f'  Grad Accumulation:   {training_args.gradient_accumulation_steps}')
print(f'  Effective Batch:     {effective_batch}')
print(f'  Max Steps:           {training_args.max_steps:,}')
print(f'  Precision:           {precision}')
print(f'  Grad Checkpointing:  {grad_ckpt}')
print(f'  Optimizer:           {optim_name}')
print(f'  LR:                  {training_args.learning_rate:g} '
      f'({scheduler_name}, {training_args.warmup_steps} warmup steps)')
print(f'  Dataloader Workers:  {training_args.dataloader_num_workers}')
print('  Padding:             Dynamic')
print('=' * 60)

In [None]:
import logging
from transformers import TrainerCallback
import gc

class MemoryCallback(TrainerCallback):
    """Release unreferenced Python objects and cached CUDA blocks after training steps.

    Args:
        every_n_steps: run the cleanup only when ``state.global_step`` is a multiple
            of this value. The default of 1 (every step) keeps the original behavior.
            Larger values reduce the per-step cost of a full ``gc.collect()`` and of
            re-growing the CUDA cache after ``empty_cache()``.

    Raises:
        ValueError: if ``every_n_steps`` is smaller than 1.
    """

    def __init__(self, every_n_steps=1):
        super().__init__()
        if every_n_steps < 1:
            raise ValueError('every_n_steps must be >= 1')
        self.every_n_steps = every_n_steps

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step % self.every_n_steps:
            return
        # Collect first: the garbage collector may drop the last references to CUDA
        # tensors, and only after that can empty_cache() release their cached blocks.
        gc.collect()
        torch.cuda.empty_cache()

class ThroughputCallback(TrainerCallback):
    """Print optimizer-step throughput, sample throughput and an ETA at each log event.

    Args:
        batch_size: per-device micro-batch size.
        grad_accum: gradient-accumulation steps per optimizer step.
    """

    def __init__(self, batch_size, grad_accum):
        super().__init__()
        # Samples consumed per optimizer step on one device.
        self.batch_size = batch_size * grad_accum
        self.start_time = None
        self.start_step = 0

    def on_step_begin(self, args, state, control, **kwargs):
        # Anchor timing at the first step of this run. global_step is non-zero when
        # resuming. Anchoring here excludes time spent before training starts
        # (e.g. model loading).
        if self.start_time is None:
            self.start_time = time.time()
            self.start_step = state.global_step

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.start_time is None:
            return
        steps_done = state.global_step - self.start_step
        if steps_done <= 0:
            return

        elapsed = time.time() - self.start_time
        steps_per_sec = steps_done / elapsed if elapsed > 0 else 0.0
        samples_per_sec = steps_per_sec * self.batch_size

        # max_steps <= 0 means "train by epochs", where a step-based ETA is meaningless.
        if args.max_steps > 0 and steps_per_sec > 0:
            eta = f'{(args.max_steps - state.global_step) / steps_per_sec / 60:.0f} min'
        else:
            eta = 'unknown'

        print(f'  [PERF] Step {state.global_step}/{args.max_steps} | '
              f'{steps_per_sec:.2f} it/s | '
              f'{samples_per_sec:.1f} samples/s | '
              f'ETA: {eta}')

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
    callbacks=[
        # Read the batch geometry from training_args so the throughput numbers stay
        # correct if the arguments are changed.
        ThroughputCallback(
            batch_size=training_args.per_device_train_batch_size,
            grad_accum=training_args.gradient_accumulation_steps,
        ),
        MemoryCallback(),
    ],
)

print('Trainer ready.')

## 4. Train

In [None]:
from transformers.trainer_utils import get_last_checkpoint

# Resume from the newest checkpoint under output_dir if one exists (e.g. after a Colab
# disconnect). resume_from_checkpoint=None starts a fresh run.
# NOTE(review): with a streaming dataset, resuming presumably re-reads the
# already-consumed batches in order to skip them, which can take a while. Confirm.
last_checkpoint = get_last_checkpoint(output_dir)

train_start = time.time()

try:
    if last_checkpoint is not None:
        print(f'Resuming from checkpoint: {last_checkpoint}')
    else:
        print('Starting fresh Llama-2-13B QLoRA training...')
    trainer.train(resume_from_checkpoint=last_checkpoint)

except torch.cuda.OutOfMemoryError:
    print('\n' + '=' * 60)
    print('FATAL: CUDA Out of Memory')
    print('=' * 60)
    print(f'  Allocated: {torch.cuda.memory_allocated() / 1024**3:.1f} GB')
    print(f'  Reserved:  {torch.cuda.memory_reserved() / 1024**3:.1f} GB')
    print(f'  Total:     {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB')
    print('=' * 60)
    torch.cuda.empty_cache()
    # Re-raise so execution stops here. Swallowing the error would let the save cell
    # that follows store a partially trained adapter as the "final" model.
    raise

else:
    train_elapsed = time.time() - train_start
    print('=' * 60)
    print(f'TRAINING COMPLETE — {train_elapsed / 60:.1f} minutes')
    print('=' * 60)

In [None]:
# Persist the training result next to its tokenizer. The model is a PEFT/LoRA wrapper,
# so the saved artifact is the LoRA adapter (as the message says), not the base model.
final_model_dir = '/content/drive/MyDrive/fineweb_edu_llama2_13b/final_model'
print(f'Saving LoRA adapters to: {final_model_dir}')

trainer.save_model(output_dir=final_model_dir)
tokenizer.save_pretrained(save_directory=final_model_dir)

print('Model saved successfully.')

## 5. Build RAG Index

In [None]:
import faiss
from sentence_transformers import SentenceTransformer
import numpy as np

RAG_SAMPLES = 100_000
RAG_DIR = '/content/drive/MyDrive/fineweb_edu_llama2_13b/rag_index'
CHUNK_CHARS = 500      # passage window size, in characters (not tokens)
MIN_CHUNK_CHARS = 50   # stripped chunks of this length or shorter are discarded
os.makedirs(RAG_DIR, exist_ok=True)


def _chunk_text(text, size=CHUNK_CHARS, min_len=MIN_CHUNK_CHARS):
    """Split *text* into fixed-size character windows; keep stripped chunks longer than *min_len*."""
    text = text.strip()
    windows = (text[i:i + size].strip() for i in range(0, len(text), size))
    return [chunk for chunk in windows if len(chunk) > min_len]


# Re-read the first RAG_SAMPLES rows of the (unshuffled) stream for raw passage text.
rag_stream = raw_dataset.take(RAG_SAMPLES)

passages = []
print('Extracting passages...')
for row in tqdm(rag_stream, total=RAG_SAMPLES):
    passages.extend(_chunk_text(row['text']))

if not passages:
    # Otherwise encode() returns an empty array and indexing its shape fails obscurely.
    raise RuntimeError('No passages extracted; refusing to build an empty RAG index.')

print(f'Encoding {len(passages):,} passages...')
embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
embeddings = embedder.encode(passages, show_progress_bar=True, batch_size=256, convert_to_numpy=True)

# Inner product over L2-normalised vectors is cosine similarity.
# faiss.normalize_L2 works in place on a contiguous float32 array.
embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)

faiss.write_index(index, os.path.join(RAG_DIR, 'faiss_index.bin'))
# NOTE: np.save pickles object-dtype arrays. Loading requires
# np.load(..., allow_pickle=True), so only load this file from a trusted location.
np.save(os.path.join(RAG_DIR, 'passages.npy'), np.array(passages, dtype=object))
print(f'RAG index saved: {len(passages):,} passages, dim={embeddings.shape[1]}')