# LoRA Fine-Tuning: lora_science_v1

This Colab-ready workflow fine-tunes `Qwen/Qwen2-0.5B-Instruct` with LoRA adapters on the 19-document offline dataset so we can benchmark the `lora_science_v1` experiment against the `rag_baseline_v1` control.

### 0. Runtime checklist
* Select a Colab runtime with GPU (T4 preferred) before running any code.
* Keep an eye on VRAM usage (a T4 provides ~16 GB). If memory errors appear, reduce the sequence length or the per-device batch size (raising gradient accumulation to keep the effective batch size).

In [None]:
# Show the GPU attached to this runtime and its current memory usage.
!nvidia-smi

### 1. Install Python dependencies
We require minimum versions compatible with Transformers 4.45+ and TRL's SFTTrainer; the remaining packages are installed unpinned.

In [None]:
%%capture
# Upgrade pip first, then install the training stack; %%capture hides the installer output.
!pip install --upgrade pip
!pip install --quiet "transformers>=4.45.0" "accelerate>=0.33.0" "datasets>=3.0.0" peft trl bitsandbytes sentencepiece evaluate huggingface_hub pynvml

### 2. Authenticate (optional but recommended)
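Set Hugging Face and GitHub tokens if you plan to pull private assets or push adapters. Tokens are kept in this runtime's environment variables; the Hugging Face CLI login additionally saves its token in the runtime's local credential store.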

In [None]:
import os
import subprocess
from getpass import getpass

# Hugging Face: reuse HF_TOKEN from the environment, otherwise prompt for it.
hf_token = os.environ.get("HF_TOKEN")
if hf_token is None:
    hf_token = getpass("Enter Hugging Face token (leave blank to skip): ").strip() or None

if not hf_token:
    print("Skipping Hugging Face login.")
else:
    os.environ["HF_TOKEN"] = hf_token
    login_command = ["huggingface-cli", "login", "--token", hf_token, "--add-to-git-credential"]
    # Best effort: a failed CLI login does not abort the notebook.
    subprocess.run(login_command, check=False)

# GitHub: only prompt when no token is already present in the environment.
if "GITHUB_TOKEN" not in os.environ:
    gh_token = getpass("Enter GitHub token for private repo access (leave blank to skip): ")
    cleaned_gh_token = gh_token.strip()
    if not cleaned_gh_token:
        print("Skipping GitHub token setup. Upload the dataset manually if download fails.")
    else:
        os.environ["GITHUB_TOKEN"] = cleaned_gh_token
        print("Stored GitHub token in this session.")

### 3. (Optional) Mount Google Drive
Use Drive if you want automatic persistence of adapters, logs, or dataset snapshots.

In [None]:
# Drive is mounted by default; set USE_DRIVE to False to skip mounting and work under /content.
USE_DRIVE = True
if USE_DRIVE:
    # Imported here so google.colab is only required when Drive is actually used.
    from google.colab import drive

    drive.mount("/content/drive")
    BASE_DIR = "/content/drive/MyDrive/beyond-the-cutoff"
else:
    BASE_DIR = "/content"
# BASE_DIR is the root for the dataset and output paths built in later cells.
print(f"Working directory base: {BASE_DIR}")

### 4. Retrieve the offline dataset
We mirror `evaluation/datasets/offline_dataset.jsonl`. If the repo is private and no token is provided, upload the file manually to `<BASE_DIR>/data/offline_eval/offline_dataset.jsonl` (that is `/content/data/offline_eval/offline_dataset.jsonl` when Drive is not mounted).

In [None]:
import json
from pathlib import Path

# Location of the offline dataset for this runtime (under Drive or /content, per BASE_DIR).
DATA_DIR = Path(BASE_DIR) / "data" / "offline_eval"
DATA_DIR.mkdir(parents=True, exist_ok=True)
DATASET_PATH = DATA_DIR / "offline_dataset.jsonl"

if DATASET_PATH.exists():
    # A previously downloaded or manually uploaded copy takes precedence.
    print(f"Dataset already available: {DATASET_PATH}")
else:
    import requests

    headers = {}
    github_token = os.environ.get("GITHUB_TOKEN")
    if github_token:
        # Needed only when the repository is private.
        headers["Authorization"] = f"token {github_token}"
    url = "https://raw.githubusercontent.com/ignaciolinari/beyond-the-cutoff/main/evaluation/datasets/offline_dataset.jsonl"
    response = requests.get(url, headers=headers, timeout=60)
    if response.status_code != 200:
        # Surface the status code and the exact upload location so the failure is actionable.
        raise RuntimeError(
            f"Failed to download dataset (HTTP {response.status_code}). "
            f"Upload the file manually to {DATASET_PATH} and rerun this cell."
        )
    # Store the payload byte-for-byte; decoding via response.text depends on the
    # charset the server declares and could corrupt non-ASCII characters.
    DATASET_PATH.write_bytes(response.content)
    print(f"Downloaded dataset to {DATASET_PATH}")

### 5. Create deterministic train/val/test splits
We group examples by paper and task type, keep each group inside a single split, and shuffle the groups with seed `20251101` to match the model adaptation plan.

In [None]:
import random
from collections import defaultdict


def load_examples(path: Path) -> list[dict]:
    """Read a JSON Lines file and return one parsed object per non-blank line.

    The file is iterated line by line rather than split with ``str.splitlines``:
    ``splitlines`` also breaks on characters such as U+2028, which may appear
    unescaped inside JSON strings and would cut a record in half. Lines that
    are empty or contain only whitespace are skipped.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            payload = line.strip()
            if payload:
                records.append(json.loads(payload))
    return records


def extract_paper_id(example: dict) -> str:
    """Derive a paper identifier (a path stem) for an example.

    Lookup order: ``metadata.source_path`` / ``metadata.paper_id``, then the
    first entry of ``sources``, then ``source_path`` of the first item in
    ``rag.retrieved``. Returns ``"unknown"`` when none of these is available.
    """
    metadata = example.get("metadata") or {}
    if isinstance(metadata, dict):
        for field in ("source_path", "paper_id"):
            value = metadata.get(field)
            if value:
                return Path(str(value)).stem

    source_list = example.get("sources")
    if source_list:
        return Path(str(source_list[0])).stem

    retrieved_items = (example.get("rag") or {}).get("retrieved")
    if retrieved_items:
        head = retrieved_items[0]
        if isinstance(head, dict):
            head_path = head.get("source_path")
            if head_path:
                return Path(str(head_path)).stem

    return "unknown"


# Load all records and bucket them by (paper id, task type); a bucket is never
# divided between splits.
examples = load_examples(DATASET_PATH)
print(f"Loaded {len(examples)} examples")

groups = defaultdict(list)
for record in examples:
    groups[(extract_paper_id(record), record.get("task_type"))].append(record)

# Fixed seed so the bucket order, and therefore the split, is reproducible.
buckets = list(groups.values())
rng = random.Random(20251101)
rng.shuffle(buckets)

# Target sizes: 70% train, 15% val, and the remainder for test.
total_examples = len(examples)
train_target = int(round(0.70 * total_examples))
val_target = int(round(0.15 * total_examples))
target_counts = {
    "train": train_target,
    "val": val_target,
    "test": total_examples - train_target - val_target,
}
splits = {name: [] for name in target_counts}

# Greedy assignment: each bucket goes to the split furthest below its target;
# once every split has reached its target, leftovers go to train.
for bucket in buckets:
    shortfall = {name: target_counts[name] - len(assigned) for name, assigned in splits.items()}
    neediest = max(shortfall, key=shortfall.get)
    destination = neediest if shortfall[neediest] > 0 else "train"
    splits[destination].extend(bucket)

# Persist each split as JSON Lines next to the source dataset.
SPLIT_DIR = DATA_DIR / "splits"
SPLIT_DIR.mkdir(parents=True, exist_ok=True)
for split, rows in splits.items():
    out_path = SPLIT_DIR / f"lora_science_v1_{split}.jsonl"
    with out_path.open("w", encoding="utf-8") as handle:
        handle.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in rows)
    print(f"{split:>5}: {len(rows):>2} examples → {out_path}")

### 6. Load dataset into Hugging Face `DatasetDict`
We keep the raw fields here; prompt construction happens in step 8, where the dataset is pre-formatted before it is handed to the trainer.

In [None]:
from datasets import Dataset, DatasetDict


def read_split(split: str) -> Dataset:
    """Load the JSONL file of one split from ``SPLIT_DIR`` into a ``Dataset``.

    Blank lines are ignored; every other line is parsed as a JSON object.
    """
    split_file = SPLIT_DIR / f"lora_science_v1_{split}.jsonl"
    with split_file.open("r", encoding="utf-8") as handle:
        records = [json.loads(line) for line in handle if line.strip()]
    return Dataset.from_list(records)


# Gather the three splits into one DatasetDict; the bare name displays it in the notebook.
split_datasets = {}
for split_name in ("train", "val", "test"):
    split_datasets[split_name] = read_split(split_name)
dataset = DatasetDict(split_datasets)
dataset

### 7. Initialize tokenizer, model, and LoRA config
We load the base model in bfloat16 on GPUs with compute capability 8.0 or higher (float16 otherwise) and target attention + MLP projections for LoRA adapters.

In [None]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "Qwen/Qwen2-0.5B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Pick the half-precision dtype: bfloat16 when the GPU's major compute
# capability is 8 or higher, float16 in every other case.
supports_cuda = torch.cuda.is_available()
if supports_cuda:
    compute_capability = torch.cuda.get_device_capability(0)[0]
else:
    compute_capability = None
prefer_bf16 = compute_capability is not None and compute_capability >= 8

model_dtype = torch.float16
if prefer_bf16:
    model_dtype = torch.bfloat16


def build_base_model() -> AutoModelForCausalLM:
    """Load a fresh copy of ``model_name`` in ``model_dtype`` with automatic device placement.

    ``use_cache`` is turned off on the returned model's config.
    """
    load_options = {
        "torch_dtype": model_dtype,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    fresh_model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)
    fresh_model.config.use_cache = False
    return fresh_model


model = build_base_model()

# LoRA adapters are attached to the attention projections and the MLP projections.
attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
mlp_projections = ["gate_proj", "up_proj", "down_proj"]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=attention_projections + mlp_projections,
)

print(f"Loaded model on {model.device} with dtype {model_dtype}")

### 8. Define prompt formatting and trainer
We stitch instruction + retrieval contexts into the user turn and keep the ground-truth response as the assistant turn.

In [None]:
from collections.abc import Mapping, Sequence

from trl import SFTConfig, SFTTrainer


def _get_batch_value(column, index):
    if column is None:
        return None
    if isinstance(column, Mapping):
        return {key: _get_batch_value(value, index) for key, value in column.items()}
    if isinstance(column, Sequence) and not isinstance(column, str | bytes):
        return column[index] if len(column) > index else None
    return column


def build_user_message(
    instruction: str,
    rag_entry: dict | None,
    contexts_fallback: Sequence[str] | None,
) -> str:
    contexts = []
    if rag_entry and isinstance(rag_entry, dict):
        contexts = rag_entry.get("contexts") or []
    if not contexts and contexts_fallback is not None:
        contexts = contexts_fallback
    processed = [str(ctx).strip() for ctx in (contexts or []) if ctx]
    context_block = "\n\n".join(processed[:6])
    parts = [instruction.strip()]
    if context_block:
        parts.append("Context:\n" + context_block)
    return "\n\n".join(parts)


def format_example(example: dict) -> dict[str, str]:
    """Render one raw example as chat-templated training text.

    The conversation has a fixed system prompt, a user turn built by
    ``build_user_message`` from the instruction and contexts, and the
    expected response as the assistant turn. Returns ``{"text": ...}``.
    """
    system_prompt = (
        "You are a scientific research assistant who answers with concise, "
        "evidence-grounded prose and includes inline numeric citations like [1]."
    )
    question = (example["instruction"] or "").strip()
    user_turn = build_user_message(question, example.get("rag"), example.get("contexts"))
    answer = (example.get("expected_response") or "").strip()

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_turn},
        {"role": "assistant", "content": answer},
    ]
    text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )
    # Terminate the sample with the EOS token explicitly (the trainer runs with packing=False).
    eos = tokenizer.eos_token
    if eos:
        text = text + eos
    return {"text": text}


# Root directory for this experiment's outputs (checkpoints, adapters, merged weights).
output_root = Path(BASE_DIR) / "outputs" / "lora_science_v1"
output_root.mkdir(parents=True, exist_ok=True)


# Reuse the dtype and bf16 preference chosen when the model was loaded, if those
# names exist in this session; otherwise derive them from the model's own dtype.
# Note the fallback expressions are evaluated eagerly, so `model` must already exist.
model_dtype = globals().get("model_dtype", getattr(model, "dtype", torch.float16))
supports_cuda = torch.cuda.is_available()
prefer_bf16 = globals().get("prefer_bf16", supports_cuda and model_dtype == torch.bfloat16)
# Mixed-precision switches handed to SFTConfig below; both are falsy without CUDA.
fp16_flag = supports_cuda and model_dtype == torch.float16
bf16_flag = supports_cuda and prefer_bf16
# preview_response truncates prompts to this length.
tokenizer.model_max_length = 1024


# format_example appends tokenizer.eos_token to every sample, so it must be a string.
if tokenizer.eos_token is None or not isinstance(tokenizer.eos_token, str):
    tokenizer.eos_token = "<|endoftext|>"  # fallback EOS used only when none is usable


# Make sure we do not stack multiple adapters on the same model instance:
# a model that already carries a `peft_config` attribute is replaced by a fresh base model.
if hasattr(model, "peft_config"):
    model = build_base_model()


# Render every split to a single "text" column up front; the raw columns are dropped.
formatted_dataset = dataset.map(format_example, remove_columns=dataset["train"].column_names)


# Training configuration. The seed is set explicitly so the run uses the same
# seed (20251101) that the split cell uses and the metadata cell records,
# instead of silently falling back to the library default.
training_args = SFTConfig(
    output_dir=str(output_root),
    seed=20251101,
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size of 8 per device
    learning_rate=1e-4,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    logging_strategy="steps",
    logging_steps=10,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    gradient_checkpointing=True,
    fp16=fp16_flag,
    bf16=bf16_flag,
    max_grad_norm=1.0,
    report_to="none",
    packing=False,  # Keep packing=False as we formatted the text ourselves
)


# Build the trainer on the pre-formatted splits; the LoRA adapter is configured
# through `peft_config` rather than applied to the model beforehand.
trainer = SFTTrainer(
    model=model,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["val"],
    # No formatting_func: the dataset already holds rendered text.
    peft_config=lora_config,
    args=training_args,
    # No tokenizer argument is passed here.
    # NOTE(review): confirm the trainer's own tokenizer matches the settings applied to `tokenizer` above.
)
# Display the trainer object in the notebook output.
trainer

### 9. Fine-tune the model
Training typically finishes in a few minutes on a T4 due to the compact dataset.

In [None]:
# Run fine-tuning; the result is kept because the metadata cell reads `train_result.metrics`.
train_result = trainer.train()
train_result

#### Plot training loss curve
Use this after training to confirm the optimizer is behaving as expected.

In [None]:
# Plot the loss curve once training has been run
import matplotlib.pyplot as plt

# Collect (step, loss) pairs from the trainer's log, ignoring entries that lack either value.
loss_history = []
for record in trainer.state.log_history:
    logged_step = record.get("step")
    logged_loss = record.get("loss")
    if logged_step is not None and logged_loss is not None:
        loss_history.append((logged_step, logged_loss))

if loss_history:
    steps, losses = zip(*loss_history, strict=False)
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(steps, losses, marker="o", linewidth=1)
    ax.set(xlabel="Step", ylabel="Loss", title="Training Loss (per logged step)")
    ax.grid(alpha=0.3)
    plt.show()
else:
    print("No loss logs available yet. Run `trainer.train()` first.")

### 10. Quick validation sanity check
Generate answers for a few validation samples to verify the adapter behaviour before exporting.

In [None]:
# Put the model in evaluation mode before generating previews.
model.eval()


def preview_response(example: dict, max_new_tokens: int = 256) -> str:
    """Generate the model's answer for one raw example using greedy decoding.

    The prompt uses the same system message and user-turn layout that
    ``format_example`` uses for training, so the preview reflects the
    conditions the adapter was tuned under.

    Args:
        example: Raw dataset row with ``instruction`` and optional ``rag`` /
            ``contexts`` fields.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded continuation (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    user_text = build_user_message(
        example["instruction"] or "", example.get("rag"), example.get("contexts")
    )
    messages = [
        {
            "role": "system",
            # Must stay identical to the system prompt used in format_example.
            "content": (
                "You are a scientific research assistant who answers with concise, "
                "evidence-grounded prose and includes inline numeric citations like [1]."
            ),
        },
        {"role": "user", "content": user_text},
    ]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(
        prompt_text,
        return_tensors="pt",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_attention_mask=True,
    )
    encoded = {key: value.to(model.device) for key, value in encoded.items()}
    prompt_length = encoded["input_ids"].shape[-1]
    if prompt_length == tokenizer.model_max_length:
        print(
            "Prompt truncated to fit model_max_length; consider tightening contexts if this occurs often."
        )
    # Greedy decoding: sampling parameters (temperature, top_p) are left out
    # because they have no effect when do_sample is False.
    gen_config = GenerationConfig(
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, generation_config=gen_config)
    # Keep only the tokens produced after the prompt.
    response_ids = generated[0, prompt_length:]
    return tokenizer.decode(response_ids, skip_special_tokens=True).strip()


# Show up to three validation examples alongside the model's answer.
for example in dataset["val"].select(range(min(3, len(dataset["val"])))):
    print("Instruction:", example["instruction"])
    # `or ""` also covers rows where the field exists but is None.
    print("Ground truth:", (example.get("expected_response") or "").strip())
    print("Model output:", preview_response(example))
    print("-" * 80)

### 11. Save adapters and tokenizer
Store artifacts under `outputs/lora_science_v1/adapters/lora_science_v1` (relative to `BASE_DIR`). Upload to Drive or remote storage as needed.

In [None]:
# Adapter weights and tokenizer files are written side by side in adapter_path.
adapter_dir = output_root / "adapters"
adapter_dir.mkdir(parents=True, exist_ok=True)
adapter_path = adapter_dir / "lora_science_v1"
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained(adapter_path)
print(f"Saved LoRA adapter to {adapter_path}")

### 12. Package artifacts
Compress the adapter, training args, and logs for upload back to the repo or Drive.

In [None]:
import shutil

# The archive is created next to output_root (not inside it) so it never contains itself.
archive_stem = output_root.parent / "lora_science_v1_artifacts"
zip_target = archive_stem.with_suffix(".zip")
# Remove an archive left over from an earlier run before rebuilding it.
if zip_target.exists():
    zip_target.unlink()
zip_path = shutil.make_archive(
    base_name=str(archive_stem),
    format="zip",
    root_dir=output_root,
)
print(f"Packed artifacts at {zip_path}")

### 13. Next steps
1. **Persist training metadata** &rightarrow; run the next cell to capture seed `20251101`, the key hyperparameters, and recent metrics in `outputs/lora_science_v1/adapters/lora_science_v1/EXPERIMENT_METADATA.json`.
2. **Materialize merged weights** &rightarrow; execute the following cell to merge the LoRA adapter into the base model and save a ready-to-quantize checkpoint under `outputs/lora_science_v1/merged_full_model`.
3. **Quantize for Ollama** &rightarrow; use the provided CLI snippet to convert the merged weights to GGUF via `llama.cpp`'s `convert-hf-to-gguf.py`, then register the artifact with Ollama.
4. **Re-benchmark** &rightarrow; rerun `scripts/evaluate_models.py` with the new Ollama tag to compare against `rag_baseline_v1`.

In [None]:
# Persist experiment metadata for reproducibility
import json
from datetime import datetime

import peft
import transformers
import trl
from packaging.version import Version

from datetime import timezone

# The metadata file is stored alongside the saved adapter.
metadata_target = output_root / "adapters" / "lora_science_v1" / "EXPERIMENT_METADATA.json"
metadata_target.parent.mkdir(parents=True, exist_ok=True)

# `train_result` exists only if the training cell ran in this session.
completed_run = globals().get("train_result")

training_summary = {
    # Seed used for the deterministic train/val/test split.
    "seed": 20251101,
    # Seed the trainer actually ran with, read from the training arguments.
    "training_seed": training_args.seed,
    # Timezone-aware UTC timestamp (datetime.utcnow() is deprecated); same
    # "YYYY-MM-DDTHH:MM:SSZ" format as before.
    "timestamp_utc": datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ"),
    "base_model": model_name,
    "adapter_dir": str(output_root / "adapters" / "lora_science_v1"),
    "output_dir": str(output_root),
    "hyperparameters": {
        "num_train_epochs": training_args.num_train_epochs,
        "per_device_train_batch_size": training_args.per_device_train_batch_size,
        "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
        "learning_rate": training_args.learning_rate,
        "weight_decay": training_args.weight_decay,
        "lr_scheduler_type": training_args.lr_scheduler_type,
        "warmup_ratio": training_args.warmup_ratio,
        "max_seq_length": 1024,
    },
    "optimizer_steps": trainer.state.global_step,
    # Most recent logged training loss, if any was logged.
    "train_loss": next(
        (entry.get("loss") for entry in reversed(trainer.state.log_history) if "loss" in entry),
        None,
    ),
    "final_metrics": completed_run.metrics if completed_run is not None else {},
    "libraries": {
        "transformers": Version(transformers.__version__).base_version,
        "trl": Version(trl.__version__).base_version,
        "peft": Version(peft.__version__).base_version,
    },
}

metadata_target.write_text(
    json.dumps(training_summary, indent=2, sort_keys=True) + "\n", encoding="utf-8"
)
print(f"Wrote metadata to {metadata_target}")

In [None]:
# Merge the LoRA adapter back into the base model and save full weights
from peft import AutoPeftModelForCausalLM

# Destination for the merged (adapter folded into base) checkpoint.
merged_output_dir = output_root / "merged_full_model"
merged_output_dir.mkdir(parents=True, exist_ok=True)

print("Merging adapter into the base model…")
# Use the dtype chosen at load time when available, float16 otherwise.
merge_dtype = globals().get("model_dtype", torch.float16)
peft_model = AutoPeftModelForCausalLM.from_pretrained(
    adapter_path,
    torch_dtype=merge_dtype,
    device_map="auto",
    trust_remote_code=True,
)
merged_model = peft_model.merge_and_unload()

# Move the weights to CPU, then write model and tokenizer into the same directory.
merged_model.to("cpu")
for artifact in (merged_model, tokenizer):
    artifact.save_pretrained(merged_output_dir)
print(f"Merged checkpoint written to {merged_output_dir}")

#### Quantize and evaluate
Run the commands below from your local checkout once the merged checkpoint is synced down:
```bash
# Convert merged HF weights to GGUF (example with Q4_0 quantization)
python /path/to/llama.cpp/convert-hf-to-gguf.py \
  --model-dir outputs/lora_science_v1/merged_full_model \
  --outfile outputs/lora_science_v1/merged_full_model/Qwen2-0.5B-lora_science_v1.Q4_0.gguf \
  --data-type q4_0

# Register a new Ollama model tag
ollama create qwen2-lora-science -f ollama/Modelfile
ollama push qwen2-lora-science

# Re-run the offline evaluation with the new tag
python scripts/evaluate_models.py \
  --model-tag qwen2-lora-science \
  --preset configs/retrieval_presets/qa_default.yaml \
  --output-dir evaluation/results/lora_science_v1
```

### 14. Re-package artifacts
Create a fresh archive after merging and metadata updates so downstream steps stay in sync.

In [None]:
# Rebuild the archive to include merged weights and metadata
import shutil

# Separate archive name so the pre-merge archive from step 12 is left untouched.
archive_stem = output_root.parent / "lora_science_v1_artifacts_postmerge"
zip_target = archive_stem.with_suffix(".zip")
# Drop any stale post-merge archive before rebuilding it.
if zip_target.exists():
    zip_target.unlink()
zip_path = shutil.make_archive(
    base_name=str(archive_stem),
    format="zip",
    root_dir=output_root,
)
print(f"Packed updated artifacts at {zip_path}")