# Lightweight Fine-Tuning — LoRA + QLoRA (IMDb Sentiment)

This notebook is a job-ready refresh of your original PEFT demo. It includes:
- **Baseline** evaluation of a frozen model
- **LoRA** fine-tuning path (rank/configurable)
- **Optional QLoRA** path (4-bit base model with `bitsandbytes`)
- **Evaluation** with Accuracy and Macro-F1 (+ confusion matrix)
- **Save/Reload** adapters and **Merge** for export
- **Deployment-ready** inference helpers

You can toggle LoRA vs QLoRA via a config flag.


In [1]:
# %% [setup]
import os, sys, math, time
from dataclasses import dataclass
from typing import Optional, Dict

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSequenceClassification,
    DataCollatorWithPadding, TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Device:', DEVICE)

@dataclass
class Config:
    base_model: str = 'bert-base-uncased'  # keep BERT to align with original notebook
    task_name: str = 'imdb'
    max_length: int = 256
    sample_train: Optional[int] = 5000  # None for full
    sample_test: Optional[int] = 2500
    batch_size: int = 16
    num_epochs: int = 3
    lr: float = 5e-5
    lora_r: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    target_modules = ['query','key','value']  # include key as upgrade
    use_qlora: bool = False  # toggle here
    output_dir: str = 'outputs_peft'

cfg = Config()
os.makedirs(cfg.output_dir, exist_ok=True)
print(cfg)


Device: cuda
Config(base_model='bert-base-uncased', task_name='imdb', max_length=256, sample_train=5000, sample_test=2500, batch_size=16, num_epochs=3, lr=5e-05, lora_r=8, lora_alpha=32, lora_dropout=0.1, use_qlora=False, output_dir='outputs_peft')


## 1) Data — IMDb
We keep the dataset and pre-processing nearly identical to your original notebook for continuity.


In [2]:
# %% [data]
raw = load_dataset('imdb')
if cfg.sample_train:
    raw['train'] = raw['train'].select(range(cfg.sample_train))
if cfg.sample_test:
    raw['test'] = raw['test'].select(range(cfg.sample_test))

tok = AutoTokenizer.from_pretrained(cfg.base_model, use_fast=True)
def tokenize_fn(ex):
    return tok(ex['text'], truncation=True, max_length=cfg.max_length)

tok_ds = raw.map(tokenize_fn, batched=True)
tok_ds = tok_ds.remove_columns([c for c in tok_ds['train'].column_names if c not in ['input_ids','attention_mask','label']])
tok_ds.set_format(type='torch', columns=['input_ids','attention_mask','label'])

collator = DataCollatorWithPadding(tok)
num_labels = 2
print(tok_ds)


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 2500
    })
    unsupervised: Dataset({
        features: ['label', 'input_ids', 'attention_mask'],
        num_rows: 50000
    })
})


## 2) Baseline — frozen model evaluation
We evaluate the base model without fine-tuning to establish a reference.


In [3]:
# %% [baseline]
baseline_model = AutoModelForSequenceClassification.from_pretrained(
    cfg.base_model, num_labels=num_labels
).to(DEVICE)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    acc = accuracy_score(labels, preds)
    f1 = f1_score(labels, preds, average='macro')
    return {'accuracy': acc, 'f1_macro': f1}

baseline_args = TrainingArguments(
    output_dir=os.path.join(cfg.output_dir, 'baseline'),
    per_device_eval_batch_size=cfg.batch_size,
    do_train=False,
    do_eval=True,
    logging_steps=50,
)

baseline_trainer = Trainer(
    model=baseline_model,
    args=baseline_args,
    eval_dataset=tok_ds['test'],
    tokenizer=tok,
    data_collator=collator,
    compute_metrics=compute_metrics,
)
baseline_metrics = baseline_trainer.evaluate()
baseline_metrics


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  baseline_trainer = Trainer(


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mgrausoft-net[0m ([33mgrausoft-net-it-freelancer-oliver-grau[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


{'eval_loss': 0.7940624952316284,
 'eval_model_preparation_time': 0.0023,
 'eval_accuracy': 0.0948,
 'eval_f1_macro': 0.0865911582024114,
 'eval_runtime': 16.0773,
 'eval_samples_per_second': 155.499,
 'eval_steps_per_second': 9.765}

## 3) LoRA / QLoRA setup
Toggle `cfg.use_qlora` to switch between classic LoRA and QLoRA. QLoRA loads the base model in 4-bit and prepares it for k-bit training.


In [4]:
# %% [peft-setup]
if cfg.use_qlora:
    peft_base = AutoModelForSequenceClassification.from_pretrained(
        cfg.base_model,
        num_labels=num_labels,
        load_in_4bit=True,
        device_map='auto'
    )
    peft_base = prepare_model_for_kbit_training(peft_base)
else:
    peft_base = AutoModelForSequenceClassification.from_pretrained(
        cfg.base_model, num_labels=num_labels
    )
peft_base.to(DEVICE)

lora_cfg = LoraConfig(
    r=cfg.lora_r,
    lora_alpha=cfg.lora_alpha,
    lora_dropout=cfg.lora_dropout,
    bias='none',
    task_type='SEQ_CLS',
    target_modules=cfg.target_modules,
)
peft_model = get_peft_model(peft_base, lora_cfg)
peft_model.print_trainable_parameters()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 443,906 || all params: 109,927,684 || trainable%: 0.4038


## 4) Train adapters
We fine-tune only the LoRA parameters. Metrics include Accuracy and Macro-F1.


In [5]:
# %% [train]
train_args = TrainingArguments(
    output_dir=os.path.join(cfg.output_dir, 'lora' if not cfg.use_qlora else 'qlora'),
    num_train_epochs=cfg.num_epochs,
    per_device_train_batch_size=cfg.batch_size,
    per_device_eval_batch_size=cfg.batch_size,
    learning_rate=cfg.lr,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    logging_steps=50,
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=peft_model,
    args=train_args,
    train_dataset=tok_ds['train'],
    eval_dataset=tok_ds['test'],  # quick turnaround; swap in a real val set if desired
    tokenizer=tok,
    data_collator=collator,
    compute_metrics=compute_metrics,
)
trainer.train()
eval_metrics = trainer.evaluate()
eval_metrics


  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,F1 Macro
1,0.0008,0.000627,1.0,1.0
2,0.0003,0.000291,1.0,1.0
3,0.0003,0.000234,1.0,1.0


{'eval_loss': 0.00023359589977189898,
 'eval_accuracy': 1.0,
 'eval_f1_macro': 1.0,
 'eval_runtime': 17.2576,
 'eval_samples_per_second': 144.864,
 'eval_steps_per_second': 9.097,
 'epoch': 3.0}

## 5) Confusion matrix
A quick confusion matrix for additional signal on misclassifications.


In [6]:
# %% [confusion-matrix]
preds = trainer.predict(tok_ds['test'])
y_true = preds.label_ids
y_pred = preds.predictions.argmax(axis=-1)
cm = confusion_matrix(y_true, y_pred)
cm




array([[2500]])

## 6) Save adapters, reload, and merge for export
We save the LoRA adapters, show how to reload them onto the base model, and optionally merge for a single exportable model.


In [7]:
# %% [save-reload-merge]
adapters_dir = os.path.join(cfg.output_dir, 'adapters_lora' if not cfg.use_qlora else 'adapters_qlora')
trainer.model.save_pretrained(adapters_dir)
tok.save_pretrained(adapters_dir)
print('Saved adapters to', adapters_dir)

# Reload for inference
reload_base = AutoModelForSequenceClassification.from_pretrained(
    cfg.base_model, num_labels=num_labels
).to(DEVICE)
from peft import PeftModel
reload_peft = PeftModel.from_pretrained(reload_base, adapters_dir).to(DEVICE)
reload_peft.eval()

# Optional: merge and save a single safetensors file (only for non-quantized base)
if not cfg.use_qlora:
    merged = reload_peft.merge_and_unload()
    merged_dir = os.path.join(cfg.output_dir, 'merged_lora_model')
    merged.save_pretrained(merged_dir)
    print('Merged model saved to', merged_dir)
else:
    print('Merging is disabled for QLoRA since the base is 4-bit quantized.')


Saved adapters to outputs_peft/adapters_lora


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Merged model saved to outputs_peft/merged_lora_model


## 7) Inference helper
A quick function that runs text → label using the saved adapters.


In [8]:
# %% [inference]
def classify_text(text: str, model, tokenizer, max_length: int = 256):
    model.eval()
    with torch.no_grad():
        enc = tokenizer(text, truncation=True, max_length=max_length, return_tensors='pt').to(DEVICE)
        out = model(**enc)
        pred = out.logits.argmax(dim=-1).item()
    return int(pred)

sample = "This movie was surprisingly good and kept me engaged."
label = classify_text(sample, reload_peft, tok, max_length=cfg.max_length)
print('Prediction (0=neg,1=pos):', label)


Prediction (0=neg,1=pos): 0


## 8) Notes for MLOps / production
- Track runs with MLflow or W&B (plug into `Trainer` callbacks).
- Keep adapters separate for small artifacts; merge only if you need a single file.
- Add a red-team prompt set when moving from classification to instruction SFT.
- For on-prem clusters, parameterize hyperparams via env/CLI and store artifacts to shared storage.
