# GeoRAG — Fine-tune Mistral-7B on Planetary QA Data

This notebook runs the full pipeline on Google Colab (T4 or A100):
1. Install deps
2. Upload your QA dataset
3. Fine-tune with QLoRA
4. Test the model
5. Download the LoRA adapter weights

In [None]:
# check what GPU we got — prints the device name and total memory.
# The intro targets a T4 or A100 runtime; switch the Colab runtime type if this shows none.
!nvidia-smi --query-gpu=name,memory.total --format=csv,noheader

In [None]:
!pip install -q \
    torch>=2.1.0 \
    transformers>=4.38.0 \
    peft>=0.9.0 \
    bitsandbytes>=0.43.0 \
    accelerate>=0.27.0 \
    datasets>=2.18.0 \
    sentence-transformers>=2.5.0 \
    sentencepiece>=0.1.99 \
    wandb>=0.16.0 \
    tqdm

In [None]:
# (optional) log in to wandb — skip if you don't want experiment tracking
# import wandb
# wandb.login()

## Upload QA dataset

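Run the cell below and select `qa_train.jsonl` and `qa_test.jsonl` (from your local `georag/outputs/` directory) in Colab's file picker.
Alternatively, copy both files into `georag_data/` yourself (e.g. via the Files sidebar) and skip the upload cell.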

In [None]:
from google.colab import files
import os

# Uploaded files are copied into this directory; the config cell reads
# georag_data/qa_train.jsonl and georag_data/qa_test.jsonl from here.
os.makedirs("georag_data", exist_ok=True)
EXPECTED_UPLOADS = {"qa_train.jsonl", "qa_test.jsonl"}

print("Upload qa_train.jsonl and qa_test.jsonl")
uploaded = files.upload()
for name, data in uploaded.items():
    # keep only the base name so the file always lands directly in georag_data/
    safe_name = os.path.basename(name)
    with open(os.path.join("georag_data", safe_name), "wb") as f:
        f.write(data)
    print(f"  saved {safe_name} ({len(data)} bytes)")

# Fail loudly here rather than with a FileNotFoundError several cells later.
missing = sorted(EXPECTED_UPLOADS - set(os.listdir("georag_data")))
if missing:
    print(f"  WARNING: still missing {missing} — re-run this cell to upload them")

In [None]:
# sanity check: each non-empty line of a .jsonl file is one QA sample, so these counts ≈ dataset sizes
!wc -l georag_data/*.jsonl

## Config

In [None]:
import json
from pathlib import Path

# model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# lora
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# training
EPOCHS = 3
BATCH_SIZE = 4
GRAD_ACCUM = 8           # effective batch = 32
LR = 2e-4
WARMUP_RATIO = 0.06
MAX_SEQ_LEN = 1024       # every example is padded/truncated to this length
PATIENCE = 3             # epochs without val improvement before stopping

# reproducibility: seed the RNGs used for LoRA init, DataLoader shuffling and
# sampled generation. GPU kernels may still be nondeterministic.
SEED = 42
import random
import torch
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

# paths
TRAIN_FILE = Path("georag_data/qa_train.jsonl")
TEST_FILE = Path("georag_data/qa_test.jsonl")
OUTPUT_DIR = Path("georag_lora_weights")
OUTPUT_DIR.mkdir(exist_ok=True)

# wandb (set to False to skip)
USE_WANDB = False
WANDB_PROJECT = "georag"
WANDB_RUN = "mistral-7b-lora-r16-colab"

print("config loaded")

## Dataset

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

SYSTEM_PROMPT = (
    "You are GeoRAG, an expert assistant for NASA planetary science. "
    "Answer questions about surface features on the Moon, Mars, Mercury, "
    "and other celestial bodies using precise nomenclature data."
)


def format_prompt(question, context=""):
    """Build the instruction prompt used for both training and inference.

    Layout: an ``[INST] <<SYS>> … <</SYS>>`` header carrying SYSTEM_PROMPT,
    then an optional ``Context:`` section (only when ``context`` is truthy),
    then the question, closed by `` [/INST]``.
    """
    header = f"[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<</SYS>>\n\n"
    context_section = f"Context:\n{context}\n\n" if context else ""
    return f"{header}{context_section}{question} [/INST]"


class QADataset(Dataset):
    """Tokenised QA pairs for causal-LM fine-tuning, with the prompt masked out of the loss.

    Each non-empty line of the JSONL file at ``path`` must be an object with
    ``"question"`` and ``"answer"`` keys. Items are padded/truncated to
    ``max_len`` tokens. ``labels`` is -100 on padding and on the prompt, so
    the loss is computed only on the answer (and the trailing EOS).
    """

    def __init__(self, path, tokenizer, max_len=MAX_SEQ_LEN):
        # `with` closes the file; blank lines (e.g. a trailing newline) are skipped
        # instead of crashing json.loads.
        with open(path, encoding="utf-8") as fh:
            self.samples = [json.loads(line) for line in fh if line.strip()]
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict of 1-D ``input_ids``, ``attention_mask`` and ``labels`` tensors."""
        item = self.samples[idx]
        prompt = format_prompt(item["question"])
        full = f"{prompt} {item['answer']}{self.tokenizer.eos_token}"

        enc = self.tokenizer(full, max_length=self.max_len,
                             padding="max_length", truncation=True,
                             return_tensors="pt")
        input_ids = enc["input_ids"].squeeze(0)
        attn_mask = enc["attention_mask"].squeeze(0)

        # number of tokens (incl. any BOS the tokenizer adds) that belong to the prompt
        prompt_len = len(self.tokenizer(prompt, max_length=self.max_len,
                                        truncation=True)["input_ids"])

        # The tokenizer may pad on the LEFT (LlamaTokenizerFast, used by Mistral,
        # defaults to padding_side="left"). So the prompt starts at the first
        # attended position, not necessarily at index 0.
        real_positions = attn_mask.nonzero(as_tuple=True)[0]
        start = int(real_positions[0]) if real_positions.numel() else 0

        # only compute loss on the answer tokens
        labels = input_ids.clone()
        labels[attn_mask == 0] = -100
        labels[start:start + prompt_len] = -100

        return {"input_ids": input_ids, "attention_mask": attn_mask,
                "labels": labels}

print("dataset class ready")

## Load model + LoRA

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, TaskType

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
if tokenizer.pad_token is None:
    # reuse EOS for padding; assigning pad_token also updates pad_token_id
    tokenizer.pad_token = tokenizer.eos_token

# QLoRA: 4-bit NF4 base weights (with double quantization), fp16 compute.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Mistral is implemented natively in transformers (>=4.34; this notebook pins >=4.38),
# so trust_remote_code is not needed. Enabling it would execute arbitrary Python
# shipped in the Hub repository.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    device_map="auto",
)

model = prepare_model_for_kbit_training(model)

# LoRA adapters on every attention and MLP projection listed in the config cell.
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    target_modules=LORA_TARGET_MODULES,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
# Dataloaders: training data is shuffled and drops the ragged final batch;
# validation keeps every sample, in file order.
train_ds = QADataset(TRAIN_FILE, tokenizer)
val_ds = QADataset(TEST_FILE, tokenizer)
print(f"train: {len(train_ds)}, val: {len(val_ds)}")

shared_loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=2)
train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **shared_loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **shared_loader_kwargs)

## Training loop

In [None]:
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm

if USE_WANDB:
    import wandb
    wandb.init(project=WANDB_PROJECT, name=WANDB_RUN, config={
        "model": MODEL_NAME, "lora_r": LORA_R, "lora_alpha": LORA_ALPHA,
        "epochs": EPOCHS, "batch_size": BATCH_SIZE,
        "effective_batch": BATCH_SIZE * GRAD_ACCUM, "lr": LR,
    })

# batches are moved to the device that holds the model's first parameter
device = next(model.parameters()).device

# Only parameters with requires_grad (the LoRA adapters) are optimised and clipped.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)

n_batches = len(train_loader)
if n_batches == 0:
    raise ValueError(
        f"train_loader is empty: the training set has fewer than BATCH_SIZE={BATCH_SIZE} "
        "samples and the loader uses drop_last=True"
    )

# Ceil division: a trailing partial accumulation window still gets its own
# optimizer/scheduler step, so the schedule length matches the steps taken below.
steps_per_epoch = (n_batches + GRAD_ACCUM - 1) // GRAD_ACCUM
total_steps = steps_per_epoch * EPOCHS
warmup_steps = int(total_steps * WARMUP_RATIO)
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps,
)

best_val_loss = float("inf")
patience_counter = 0
global_step = 0


def eval_loss():
    """Return the mean per-batch loss over ``val_loader`` (0.0 if it yields no batches).

    Leaves the model in eval mode; the epoch loop switches back with ``model.train()``.
    """
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            total += model(**batch).loss.item()
            n += 1
    return total / max(n, 1)


for epoch in range(1, EPOCHS + 1):
    model.train()
    epoch_loss = 0.0
    optimizer.zero_grad()

    pbar = tqdm(train_loader, desc=f"epoch {epoch}/{EPOCHS}")
    for step, batch in enumerate(pbar, 1):
        batch = {k: v.to(device) for k, v in batch.items()}
        out = model(**batch)

        # Normalise by the size of the current accumulation window, so the last
        # (possibly shorter) window yields a gradient on the same scale as the others.
        window_start = (step - 1) // GRAD_ACCUM * GRAD_ACCUM
        window_len = min(GRAD_ACCUM, n_batches - window_start)
        loss = out.loss / window_len
        loss.backward()
        epoch_loss += out.loss.item()

        # Step at the end of every window, including the final partial one.
        # Otherwise its gradients would be discarded by the next epoch's zero_grad().
        if step % GRAD_ACCUM == 0 or step == n_batches:
            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            pbar.set_postfix(loss=f"{out.loss.item():.4f}",
                             lr=f"{scheduler.get_last_lr()[0]:.2e}")

            if USE_WANDB:
                wandb.log({"train/loss": out.loss.item(),
                           "train/lr": scheduler.get_last_lr()[0],
                           "global_step": global_step})

    avg_train = epoch_loss / n_batches
    vloss = eval_loss()
    print(f"epoch {epoch}/{EPOCHS}  train_loss={avg_train:.4f}  val_loss={vloss:.4f}")

    if USE_WANDB:
        wandb.log({"epoch": epoch, "train/epoch_loss": avg_train, "val/loss": vloss})

    # early stopping; only the best adapter (by val loss) is kept in OUTPUT_DIR
    if vloss < best_val_loss:
        best_val_loss = vloss
        patience_counter = 0
        model.save_pretrained(str(OUTPUT_DIR))
        tokenizer.save_pretrained(str(OUTPUT_DIR))
        print(f"  saved best model -> {OUTPUT_DIR}")
    else:
        patience_counter += 1
        print(f"  no improvement ({patience_counter}/{PATIENCE})")
        if patience_counter >= PATIENCE:
            print("  early stopping.")
            break

if USE_WANDB:
    wandb.finish()

print(f"\ndone. best val loss = {best_val_loss:.4f}")

## Quick test

In [None]:
# Reload the best checkpoint. The training loop saves the adapter to OUTPUT_DIR
# only when val loss improves, so the in-memory weights after the last epoch
# may not be the best ones. Copy the saved adapter weights back into the model.
from peft import PeftModel, set_peft_model_state_dict
from peft.utils import load_peft_weights

if (OUTPUT_DIR / "adapter_config.json").exists():
    set_peft_model_state_dict(model, load_peft_weights(str(OUTPUT_DIR)))
    print(f"loaded best adapter from {OUTPUT_DIR}")
else:
    print(f"no saved adapter found in {OUTPUT_DIR}; testing the in-memory weights")

model.eval()

test_questions = [
    "What type of feature is Tycho?",
    "Where is Olympus Mons located?",
    "Is Caloris Planitia on Mars?",
    "Name some valleys on Mars.",
]

for q in test_questions:
    prompt = format_prompt(q)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True,
                       max_length=MAX_SEQ_LEN).to(device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=200,
                             temperature=0.3, top_p=0.9,
                             repetition_penalty=1.15, do_sample=True,
                             pad_token_id=tokenizer.pad_token_id)
    # decode only the newly generated tokens, not the echoed prompt
    answer = tokenizer.decode(out[0][inputs["input_ids"].shape[1]:],
                              skip_special_tokens=True)
    print(f"Q: {q}")
    print(f"A: {answer.strip()}")
    print()

## Download weights

Download the LoRA adapter to use locally with the RAG pipeline and eval harness.

In [None]:
# Zip the adapter directory (OUTPUT_DIR, written by save_pretrained during training)
# and send it to the browser. shutil.make_archive rewrites the archive from scratch,
# so re-running this cell never ships stale files from an earlier run.
# (`zip -r` updates an existing archive in place.) It also needs no `zip` binary.
import shutil
from google.colab import files

archive_path = shutil.make_archive("/content/georag_lora_weights", "zip",
                                   root_dir=str(OUTPUT_DIR))
files.download(archive_path)
print("download started — put the zip contents into georag/outputs/georag-mistral-lora/")