# Phi-3 Mini (3.8B) — QLoRA on RTX 4070 (bf16) for JSON Exam Generation

This notebook trains a **Phi-3 Mini Instruct** LoRA adapter with QLoRA (int4) on your dataset,
adds a robust evaluation metric (**JSON parse rate**), uses a cosine LR schedule + early stopping,
and provides **strict** and **creative** decoding that still returns valid JSON.

> GPU: **RTX 4070** (bf16-capable). Settings tuned for 12 GB VRAM.

In [None]:
import os, math, random, json, re
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple

import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig,
    Trainer, TrainingArguments, EarlyStoppingCallback
)
from peft import LoraConfig, get_peft_model, PeftModel

# Report the runtime environment so a training run can be reproduced later.
print("PyTorch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))
    print("Compute capability:", torch.cuda.get_device_capability())
    # The notebook treats compute capability major >= 8 as "bf16 capable".
    print("BF16 support:", torch.cuda.get_device_capability()[0] >= 8)
    torch.set_float32_matmul_precision("high")

# `torch.nn.attention.sdpa_kernel` is a context manager: building it without a
# `with` block changes nothing. The global backend switches below persist for
# the whole session, which is what this cell is meant to achieve.
try:
    torch.backends.cuda.enable_flash_sdp(True)
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    torch.backends.cuda.enable_math_sdp(True)
    print("SDPA flash kernels: enabled")
except Exception as e:  # best-effort: older torch builds may lack these toggles
    print("SDPA flash kernels: not enabled", e)

In [None]:
# Diagnostics cell: print library versions and platform details.
import torch, transformers, bitsandbytes as bnb
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda, "| cuda_available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("bitsandbytes:", bnb.__version__)

# Best-effort bitsandbytes CUDA self-check; any failure is printed, not raised.
# NOTE(review): `bitsandbytes.cuda_setup` may not exist in every bitsandbytes
# release -- confirm against the installed version.
try:
    import bitsandbytes.cuda_setup.main as bnb_setup
    bnb_setup.main_check()
except Exception as e:
    print("bnb cuda_setup check raised:", e)

import platform, sys
print("OS:", platform.platform(), "| Python:", sys.version)

In [None]:
import bitsandbytes

# --- Run configuration -------------------------------------------------------
BASE_MODEL = "microsoft/Phi-3-mini-4k-instruct"
DATA_PATH = "data/train.json"
OUT_DIR = "out-phi3-lora-4070"
SEED = 42

random.seed(SEED); torch.manual_seed(SEED)

# Use bf16 compute when the GPU's compute capability major version is >= 8,
# otherwise fall back to fp16.
bf16_ok = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if bf16_ok else torch.float16,
)

# --- Tokenizer ----------------------------------------------------------------
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
tok.padding_side = "right"

# --- 4-bit base model ---------------------------------------------------------
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)

# --- LoRA adapter -------------------------------------------------------------
# Phi-3 fuses its projections: attention uses `qkv_proj` + `o_proj`, and the MLP
# uses `gate_up_proj` + `down_proj`. The Llama-style names (q_proj, k_proj,
# v_proj, gate_proj, up_proj) match nothing in this architecture, so listing
# them would silently leave only o_proj/down_proj adapted.
lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
    target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora)
model.print_trainable_parameters()

# The KV cache is disabled because gradient checkpointing is turned on;
# input grads are enabled so gradients can flow through the frozen base.
model.config.use_cache = False
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

print("Loaded model + LoRA for QLoRA.")

In [None]:

from datasets import load_dataset
# Load the raw records from the JSON file at DATA_PATH as a single "train" split.
ds = load_dataset("json", data_files=DATA_PATH, split="train")
print("Total records:", len(ds))

def build_source(prompt: str) -> str:
    """Wrap a control string in the instruction text used for training inputs.

    Args:
        prompt: The control description for one example.

    Returns:
        The prompt prefixed with ``controls: `` and followed, on a new line,
        by the instruction to return only a JSON object.
    """
    source_text = f"controls: {prompt}\nReturn ONLY a JSON object."
    return source_text

# Upper bound on tokens per training example (source + target).
MAX_SEQ_LEN = 1024

def preprocess_batch(batch):
    """Tokenize a batch of (prompt, output) pairs for causal-LM fine-tuning.

    Each example becomes ``source_ids + target_ids``. Source positions are
    labelled -100 so the loss covers only the target.

    Args:
        batch: Mapping with parallel lists under ``"prompt"`` and ``"output"``.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels``, each a
        list of per-example token lists no longer than ``MAX_SEQ_LEN``.
    """
    eos_id = tok.eos_token_id
    input_ids_batch, attn_batch, labels_batch = [], [], []
    for p, y in zip(batch["prompt"], batch["output"]):
        src_ids = tok(build_source(p), add_special_tokens=False)["input_ids"]
        tgt_ids = tok(y, add_special_tokens=False)["input_ids"]
        # Terminate the target so the model learns where the JSON ends.
        if eos_id is not None:
            tgt_ids = tgt_ids + [eos_id]
        # An over-long source is cut so at least one target token still fits.
        src_ids = src_ids[:MAX_SEQ_LEN - 1]
        # Overflow is trimmed from the END of the target: the start (the
        # opening brace and first keys) is kept, and a truncated example loses
        # its EOS, so the model is not taught to stop mid-object.
        tgt_ids = tgt_ids[:MAX_SEQ_LEN - len(src_ids)]
        combined = src_ids + tgt_ids
        input_ids_batch.append(combined)
        attn_batch.append([1] * len(combined))
        labels_batch.append([-100] * len(src_ids) + tgt_ids)
    return {"input_ids": input_ids_batch, "attention_mask": attn_batch, "labels": labels_batch}

# Tokenize every record in batches and drop the original columns so only the
# fields returned by preprocess_batch remain.
cols = ds.column_names
ds_proc = ds.map(preprocess_batch, batched=True, remove_columns=cols)
print(ds_proc)

In [None]:
# Hold out 10% of the shuffled data (at least one example) for evaluation.
# NOTE(review): with a single record the training split ends up empty.
n = len(ds_proc)
eval_size = max(1, int(0.1 * n))
ds_proc = ds_proc.shuffle(seed=SEED)
ds_train = ds_proc.select(range(n - eval_size))
ds_eval  = ds_proc.select(range(n - eval_size, n))

from dataclasses import dataclass
from typing import Any, Dict, List
import torch

@dataclass
class DataCollatorForCausal:
    """Pad pre-tokenized causal-LM features into equally sized tensors.

    Each feature must provide ``input_ids``, ``attention_mask`` and ``labels``
    as lists of equal length. Padding is appended on the right.
    """

    # Tokenizer whose ``pad_token_id`` is used to pad ``input_ids``.
    tokenizer: Any
    # Round the padded length up to a multiple of this value (falsy disables).
    pad_to_multiple_of: int = 8
    # Value used to pad ``labels``.
    label_pad_token_id: int = -100

    def __call__(self, features: List[Dict[str, Any]]):
        """Collate ``features`` into a dict of ``torch.long`` tensors.

        Raises:
            ValueError: If the tokenizer has no pad token id.
        """
        # Use the tokenizer passed to this collator, not a module-level global.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("tokenizer.pad_token_id must be set to pad batches")

        max_len = max(len(f["input_ids"]) for f in features)
        m = self.pad_to_multiple_of
        if m and max_len % m != 0:
            max_len = ((max_len // m) + 1) * m

        input_ids, attn_mask, labels = [], [], []
        for f in features:
            pad_len = max_len - len(f["input_ids"])
            input_ids.append(f["input_ids"] + [pad_id] * pad_len)
            attn_mask.append(f["attention_mask"] + [0] * pad_len)
            labels.append(f["labels"] + [self.label_pad_token_id] * pad_len)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attn_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }

# Build the collator used by the trainer; padded lengths are rounded up to a multiple of 8.
collator = DataCollatorForCausal(tokenizer=tok, pad_to_multiple_of=8)
print("Train/Eval sizes:", len(ds_train), len(ds_eval))

In [None]:
def _extract_or_repair_text(s: str):
    s = s.strip()
    if s.startswith("{") and s.endswith("}"):
        try: return json.loads(s)
        except json.JSONDecodeError: pass
    m = re.search(r"\{[\s\S]*\}", s)
    if m:
        cand = m.group(0)
        try: return json.loads(cand)
        except json.JSONDecodeError:
            cand2 = re.sub(r",(\s*[}\]])", r"\1", cand)
            cand2 = cand2.replace("“", '"').replace("”", '"').replace("’", "'")
            cand2 = re.sub(r"\bNone\b", "null", cand2)
            cand2 = re.sub(r"\bTrue\b", "true", cand2)
            cand2 = re.sub(r"\bFalse\b", "false", cand2)
            try: return json.loads(cand2)
            except Exception: return None
    return None

def compute_metrics(eval_preds):
    """Compute the share of evaluation predictions that parse as a JSON object.

    Accepts token ids or raw logits (a trailing vocabulary axis, reduced by
    argmax). Negative ids, such as the -100 filler used when ragged batches are
    concatenated, are replaced with the pad token id before decoding, because
    the tokenizer cannot decode them.

    Args:
        eval_preds: Object exposing ``predictions`` (array/tensor or a tuple
            whose first element is one).

    Returns:
        ``{"json_parse_rate": float}`` in the range [0, 1].
    """
    preds = eval_preds.predictions
    if isinstance(preds, tuple):
        preds = preds[0]
    ids = torch.as_tensor(preds)
    if ids.dim() == 3:  # (batch, seq, vocab) logits -> greedy token ids
        ids = ids.argmax(dim=-1)
    pad_id = tok.pad_token_id if tok.pad_token_id is not None else 0
    ids = ids.masked_fill(ids < 0, pad_id)
    texts = tok.batch_decode(ids.tolist(), skip_special_tokens=True)
    ok = sum(isinstance(_extract_or_repair_text(t), dict) for t in texts)
    return {"json_parse_rate": ok / max(1, len(texts))}

In [None]:
bf16_ok = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8


def _logits_to_target_ids(logits, labels):
    """Reduce eval logits to greedy token ids for the target span only.

    Passed to the Trainer as ``preprocess_logits_for_metrics`` so evaluation
    keeps small id tensors instead of full (batch, seq, vocab) logits.
    Predictions are teacher-forced, so the resulting JSON parse rate is a
    proxy for free-running generation quality, not a measurement of it.
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    pad_id = tok.pad_token_id
    # Logits at position i predict token i+1, so align them with labels[1:].
    pred_ids = logits.argmax(dim=-1)[:, :-1]
    is_target = labels[:, 1:] != -100
    pred_ids = pred_ids.masked_fill(~is_target, pad_id)
    # Pad every batch to one fixed width so concatenating batches of different
    # lengths does not introduce -100 filler into the ids.
    short_by = MAX_SEQ_LEN - pred_ids.shape[1]
    if short_by > 0:
        pred_ids = torch.nn.functional.pad(pred_ids, (0, short_by), value=pad_id)
    return pred_ids


# `predict_with_generate` / `generation_*` belong to Seq2SeqTrainingArguments;
# plain TrainingArguments rejects them with a TypeError, so they are not passed.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- adjust to the installed version.
args = TrainingArguments(
    output_dir=OUT_DIR,
    num_train_epochs=6,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.01,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="json_parse_rate",
    greater_is_better=True,
    bf16=bf16_ok,
    fp16=not bf16_ok,
    optim="paged_adamw_8bit",
    report_to="none",
    group_by_length=True,
    gradient_checkpointing=True,
    max_grad_norm=0.5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    data_collator=collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=_logits_to_target_ids,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

model.print_trainable_parameters()
trainer.train()
# Persist the adapter weights and tokenizer files.
model.save_pretrained(OUT_DIR)
tok.save_pretrained(OUT_DIR)
print("Saved LoRA adapter to:", OUT_DIR)

In [None]:
# Optional step: fold the trained LoRA adapter into an fp16 copy of the base
# model and save the standalone result next to the adapter directory.
MERGED_DIR = OUT_DIR + "-merged"
# Best-effort: any failure is reported and the merge is skipped rather than
# aborting the notebook.
try:
    base_fp16 = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float16, device_map="auto")
    peft_loaded = PeftModel.from_pretrained(base_fp16, OUT_DIR)
    merged = peft_loaded.merge_and_unload()
    os.makedirs(MERGED_DIR, exist_ok=True)
    merged.save_pretrained(MERGED_DIR)
    tok.save_pretrained(MERGED_DIR)
    print("Merged model saved to:", MERGED_DIR)
except Exception as e:
    print("Merge skipped due to environment/VRAM:", e)

In [None]:
from transformers import StoppingCriteria, StoppingCriteriaList, BitsAndBytesConfig

# Choose which checkpoint to run inference with: the merged fp16 model
# (USE_MERGED = True) or the 4-bit base model plus the LoRA adapter (default).
USE_MERGED = False
CKPT_DIR = OUT_DIR + ("-merged" if USE_MERGED else "")

if USE_MERGED:
    # Merged path: tokenizer and weights both come from the merged directory.
    tok_inf = AutoTokenizer.from_pretrained(CKPT_DIR, use_fast=True)
    if tok_inf.pad_token is None: tok_inf.pad_token = tok_inf.eos_token
    model_inf = AutoModelForCausalLM.from_pretrained(CKPT_DIR, torch_dtype=torch.float16, device_map="auto").eval()
else:
    # Adapter path: reload the base model in 4-bit with the same quantization
    # settings used for training, then attach the adapter saved in OUT_DIR.
    tok_inf = AutoTokenizer.from_pretrained(OUT_DIR, use_fast=True)
    if tok_inf.pad_token is None: tok_inf.pad_token = tok_inf.eos_token
    bf16_ok = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    base_q = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16 if bf16_ok else torch.float16
        ),
        device_map="auto"
    )
    model_inf = PeftModel.from_pretrained(base_q, OUT_DIR).eval()

# Token budgets: prompts are truncated to MAX_SRC_LEN tokens and generation
# may add at most MAX_NEW_TOKENS tokens.
MAX_SRC_LEN = 512
MAX_NEW_TOKENS = 900

def make_controls(topics, difficulty, length, fmt):
    """Serialize exam settings into the ``key=value; ...`` control string.

    Args:
        topics: Iterable of topic strings, joined with commas.
        difficulty: Difficulty label, inserted as-is.
        length: Exam length, inserted as-is.
        fmt: Iterable of question-format strings, joined with ``+``.

    Returns:
        The assembled control string.
    """
    topic_part = ','.join(topics)
    format_part = '+'.join(fmt)
    return f"topics={topic_part}; difficulty={difficulty}; length={length}; format={format_part}"

def make_src(ctrl):
    """Build the inference prompt for a control string.

    The prompt is the ``controls:`` line followed by an instruction that spells
    out the expected JSON keys.

    NOTE(review): training inputs are produced by ``build_source``, which uses
    a shorter instruction without the schema hint -- confirm the difference
    between training and inference prompts is intended.

    Args:
        ctrl: Control string, e.g. the output of ``make_controls``.

    Returns:
        The full prompt text.
    """
    schema_hint = (
        "Return ONLY a JSON object with keys: "
        "metadata(topics[],difficulty,length,format[]), "
        "questions[{id:int,text:str,type:multiple_choice|open_answer,options:[str]|null,answer:str,subquestions:null}]."
    )
    header = f"controls: {ctrl}\n"
    return header + schema_hint

def _balanced_braces(text: str) -> bool:
    depth, in_str, esc = 0, False, False
    for ch in text:
        if in_str:
            if esc: esc = False
            elif ch == '\\': esc = True
            elif ch == '"': in_str = False
        else:
            if ch == '"': in_str = True
            elif ch == '{': depth += 1
            elif ch == '}':
                depth -= 1
                if depth < 0: return False
    return depth == 0 and "{" in text

class BalancedJSONStop(StoppingCriteria):
    """Stop generation once the generated text is a brace-balanced JSON object.

    Only the tokens produced after the prompt are inspected. Decoding the
    prompt together with the output means the text never starts with ``{``,
    so the criterion could never fire.

    Args:
        tok: Tokenizer used to decode token ids.
        prompt_len: Number of prompt tokens to skip. When ``None`` it is
            inferred on the first call, assuming the criterion is first
            evaluated after exactly one new token has been appended. Use a
            fresh instance for every ``generate`` call.
    """

    def __init__(self, tok, prompt_len=None):
        self.tok = tok
        self.prompt_len = prompt_len

    def __call__(self, input_ids, scores, **kwargs):
        if self.prompt_len is None:
            self.prompt_len = max(0, input_ids.shape[1] - 1)
        # Only the first sequence (first beam) is inspected, as before.
        new_tokens = input_ids[0][self.prompt_len:]
        text = self.tok.decode(new_tokens, skip_special_tokens=True)
        return text.strip().startswith("{") and _balanced_braces(text)

# Look up the first token id the inference tokenizer produces for "{"; it is
# used to force generation to open with a brace.
brace_ids = tok_inf.encode("{", add_special_tokens=False)
first_brace = brace_ids[0] if brace_ids else None
# NOTE(review): `assert` is stripped under `python -O`; acceptable for a
# notebook sanity check.
assert first_brace is not None, "Tokenizer couldn't encode '{'."
def prefix_allowed_tokens_fn(batch_id, input_ids, prompt_len=1):
    """Constrain the first generated token to the opening-brace token.

    ``generate`` calls this with a 1-D tensor of the tokens so far for one
    sequence and expects a list of allowed token ids back on every step, so
    the length is read from the last axis and a full id list (never ``None``)
    is returned once the constraint no longer applies.

    Args:
        batch_id: Index of the batch item (unused).
        input_ids: Token ids generated so far, prompt included.
        prompt_len: Sequence length at which the brace is forced. The default
            of 1 keeps the original trigger; for a decoder-only model bind the
            real prompt length, e.g. with ``functools.partial``.

    Returns:
        ``[first_brace]`` at the trigger length, otherwise every tokenizer id.
    """
    if input_ids.shape[-1] == prompt_len:
        return [first_brace]
    # Built per call; cache it outside if this becomes a bottleneck.
    return list(range(len(tok_inf)))

def _extract_or_repair(s: str):
    s = s.strip()
    if s.startswith("{") and s.endswith("}"):
        try: return json.loads(s)
        except json.JSONDecodeError: pass
    m = re.search(r"\{[\s\S]*\}", s)
    if m:
        cand = m.group(0)
        try: return json.loads(cand)
        except json.JSONDecodeError:
            cand2 = re.sub(r",(\s*[}\]])", r"\1", cand)
            cand2 = cand2.replace("“", '"').replace("”", '"').replace("’", "'")
            cand2 = re.sub(r"\bNone\b", "null", cand2)
            cand2 = re.sub(r"\bTrue\b", "true", cand2)
            cand2 = re.sub(r"\bFalse\b", "false", cand2)
            try: return json.loads(cand2)
            except Exception: return None
    return None

def canonicalize(obj):
    """Normalize a parsed exam object in place and return it.

    Normalizations:
      * ``metadata.length`` is converted to a string.
      * Questions of type ``open_answer`` get ``options`` set to ``None``.
      * Non-null ``answer`` values are converted to strings.
      * Every question gets a ``subquestions`` key (``None`` if absent).

    The input comes from model output, so parts that do not have the expected
    shape (non-dict ``metadata``, non-list ``questions``, non-dict question
    entries) are left untouched instead of raising.

    Args:
        obj: Parsed JSON object (a dict).

    Returns:
        The same ``obj``, mutated.
    """
    md = obj.get("metadata")
    if isinstance(md, dict) and "length" in md:
        md["length"] = str(md["length"])

    questions = obj.get("questions")
    if isinstance(questions, (list, tuple)):
        for q in questions:
            if not isinstance(q, dict):
                continue
            if q.get("type") == "open_answer":
                q["options"] = None
            if q.get("answer") is not None:
                q["answer"] = str(q["answer"])
            q.setdefault("subquestions", None)
    return obj

def _generate_exam_json(topics, difficulty, length, fmt, **gen_kwargs):
    """Shared generation path: prompt -> generate -> decode -> parse.

    Returns ``(parsed_object_or_None, generated_text)``. The text holds only
    the newly generated tokens.
    """
    src = make_src(make_controls(topics, difficulty, length, fmt))
    enc = tok_inf(src, return_tensors="pt", truncation=True, max_length=MAX_SRC_LEN).to(model_inf.device)
    prompt_len = enc["input_ids"].shape[1]
    all_token_ids = list(range(len(tok_inf)))

    def _open_with_brace(batch_id, input_ids):
        # Called with the 1-D token sequence so far (prompt included) and must
        # always return a list: force "{" as the first new token only.
        return [first_brace] if input_ids.shape[-1] == prompt_len else all_token_ids

    out = model_inf.generate(
        **enc,
        max_new_tokens=MAX_NEW_TOKENS,
        prefix_allowed_tokens_fn=_open_with_brace,
        stopping_criteria=StoppingCriteriaList([BalancedJSONStop(tok_inf)]),
        **gen_kwargs,
    )
    # Drop the prompt tokens: the prompt's schema hint itself contains braces
    # and would otherwise be picked up by the JSON extraction.
    text = tok_inf.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    obj = _extract_or_repair(text)
    return (canonicalize(obj), text) if isinstance(obj, dict) else (None, text)


@torch.no_grad()
def generate_exam_strict(topics, difficulty, length, fmt):
    """Generate an exam deterministically with beam search.

    Args:
        topics: Topic strings for the controls line.
        difficulty: Difficulty label.
        length: Exam length.
        fmt: Question-format strings.

    Returns:
        ``(obj, text)``: ``obj`` is the canonicalized dict, or ``None`` if the
        output could not be parsed; ``text`` is the generated text.
    """
    # `no_repeat_ngram_size` is deliberately not used: a JSON array of objects
    # has to repeat the same key token patterns for every question.
    return _generate_exam_json(
        topics, difficulty, length, fmt,
        do_sample=False,
        num_beams=5,
        length_penalty=0.9,
        early_stopping=True,
    )


@torch.no_grad()
def generate_exam_creative(topics, difficulty, length, fmt, temperature=1.1, top_p=0.93, top_k=80, repetition_penalty=1.07):
    """Generate an exam with sampling for more varied questions.

    Args:
        topics: Topic strings for the controls line.
        difficulty: Difficulty label.
        length: Exam length.
        fmt: Question-format strings.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.
        top_k: Top-k sampling cutoff.
        repetition_penalty: Penalty applied to already generated tokens.

    Returns:
        ``(obj, text)``: ``obj`` is the canonicalized dict, or ``None`` if the
        output could not be parsed; ``text`` is the generated text.
    """
    # `no_repeat_ngram_size` is deliberately not used (see generate_exam_strict).
    return _generate_exam_json(
        topics, difficulty, length, fmt,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )

# Smoke test: run both decoding modes once and report whether the output parsed.
obj_s, raw_s = generate_exam_strict(["algebra","linear-equations"], "hard", 12, ["multiple_choice","open_answer"])
print("STRICT RAW (first 500):\n", raw_s[:500], "\nSTRICT PARSED:", "OK" if obj_s else "FAIL")

obj_c, raw_c = generate_exam_creative(["algebra","polynomials"], "medium", 12, ["multiple_choice","open_answer"], temperature=1.1)
print("\nCREATIVE RAW (first 500):\n", raw_c[:500], "\nCREATIVE PARSED:", "OK" if obj_c else "FAIL")
# Show the top-level keys only when the creative output parsed successfully.
if obj_c:
    print("\nCREATIVE JSON keys:", list(obj_c.keys()))