# AALM — QLoRA Fine-tune on openai/gpt-oss-20b
Fine-tune GPT‑OSS‑20B on the Australian legal QA dataset using TRL + PEFT (QLoRA).

In [None]:
%pip -q install -U datasets transformers accelerate peft bitsandbytes trl evaluate

In [None]:
import os, torch, json
from dataclasses import dataclass
from typing import Dict, List
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig

# --- Run configuration -------------------------------------------------------
# Each os.getenv call reads the named environment variable and falls back to
# the default given as its second argument.

# Model, data and output locations.
BASE_MODEL = os.getenv('BASE_MODEL', 'openai/gpt-oss-20b')
DATASET = os.getenv('HF_DATASET', 'isaacus/open-australian-legal-qa')
OUTPUT_DIR = os.getenv('OUTPUT_DIR', 'outputs/aalm-gpt-oss-20b-qlora')

# Optimisation schedule.
MAX_STEPS = int(os.getenv('MAX_STEPS', '300'))  # start small to sanity-check
PER_DEVICE_TRAIN_BATCH_SIZE = int(os.getenv('BATCH_SIZE', '1'))
GRAD_ACCUM = int(os.getenv('GRAD_ACCUM', '32'))
MAX_SEQ_LENGTH = int(os.getenv('MAX_SEQ_LENGTH', '2048'))
LEARNING_RATE = float(os.getenv('LR', '2e-4'))
WARMUP_RATIO = float(os.getenv('WARMUP_RATIO', '0.03'))

# Checkpointing, logging and reproducibility.
SAVE_STEPS = int(os.getenv('SAVE_STEPS', '100'))
LOGGING_STEPS = int(os.getenv('LOGGING_STEPS', '10'))
SEED = int(os.getenv('SEED', '42'))

# Fixed switches (not overridable from the environment).
USE_BF16 = True  # recommended on Ampere+
LOAD_IN_4BIT = True  # QLoRA
GRAD_CHECKPOINTING = True


## Load and inspect dataset

In [None]:
# Load the dataset named by DATASET and render its summary in the notebook.
ds = load_dataset(DATASET)
display(ds)
# Peek at the first record of the 'train' split to see which fields it has.
example = ds['train'][0]
example


## Tokenizer and chat template
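GPT‑OSS expects prompts in the Harmony chat format, which the model repo provides as the tokenizer's chat template. We format each sample with `tokenizer.apply_chat_template`, falling back to a plain-text prompt if the tokenizer has no string chat template.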

In [None]:
# Load the tokenizer published with the base model.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# When no pad token is defined, reuse the EOS token so one is available.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# True only when the tokenizer exposes a string-valued `chat_template`;
# format_sample reads this flag to choose between chat and plain-text formatting.
has_chat = isinstance(getattr(tokenizer, 'chat_template', None), str)
has_chat


## Map dataset to formatted text

In [None]:
# System message used for every training sample by format_sample, both in the
# chat-templated form and in the plain-text fallback.
SYSTEM_PROMPT = (
    'You are AALM, the Australian Administrative Law Model. '
    'Answer legal questions about Australian law and cases accurately and concisely. '
    'Cite the source when appropriate and never fabricate citations.'
)

def format_sample(sample: Dict) -> str:
    """Render one dataset record as a single training string.

    Args:
        sample: Mapping that may carry 'question', 'answer' and 'text' keys.

    Returns:
        When both question and answer are present and non-empty: the
        conversation (system prompt, question, answer) rendered through the
        tokenizer's chat template if ``has_chat`` is true, otherwise a
        plain-text prompt. When either is missing or empty: the record's
        'text' value, unchanged.

    Raises:
        ValueError: If the record has neither a usable question/answer pair
            nor a non-empty 'text' value.
    """
    q = sample.get('question')
    a = sample.get('answer')
    if not q or not a:
        # No usable QA pair: pass pre-formatted text through as-is.
        t = sample.get('text')
        if t:
            return t
        raise ValueError('Sample missing question/answer/text fields')
    if has_chat and hasattr(tokenizer, 'apply_chat_template'):
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': q},
            {'role': 'assistant', 'content': a},
        ]
        # The assistant turn is already in `messages`, so no generation
        # prompt is appended after it.
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    # Plain-text fallback. The line breaks must be written as escape
    # sequences: a single-quoted f-string cannot span physical lines.
    return f'{SYSTEM_PROMPT}\n\nQuestion: {q}\nAnswer: {a}'

def map_fn(examples):
    """Turn a column-oriented batch into formatted training strings.

    Args:
        examples: Mapping from column name to a list of per-row values.
            'question', 'answer' and 'text' columns are read when present;
            a missing column is treated as all ``None``.

    Returns:
        A dict with the single key 'text', holding one string per row as
        produced by ``format_sample``.
    """
    # Use the first column to get the batch size.
    anchor = next(iter(examples.keys()))
    batch_size = len(examples[anchor])

    # Stand-in for absent columns so every row still gets all three keys.
    absent = [None] * batch_size
    questions = examples.get('question', absent)
    answers = examples.get('answer', absent)
    raw_texts = examples.get('text', absent)

    return {
        'text': [
            format_sample({
                'question': questions[row],
                'answer': answers[row],
                'text': raw_texts[row],
            })
            for row in range(batch_size)
        ]
    }

# Format the 'train' split in batches, dropping its original columns so that
# only the 'text' column returned by map_fn remains, then shuffle with the
# fixed SEED.
train_ds = (
    ds['train']
    .map(map_fn, batched=True, remove_columns=ds['train'].column_names)
    .shuffle(seed=SEED)
)
# Show the first two formatted samples.
train_ds[:2]


## Load 4‑bit base and attach LoRA

In [None]:
# One dtype choice drives both the 4-bit compute dtype and the load dtype:
# bfloat16 when USE_BF16 is set, float16 otherwise.
dtype = torch.bfloat16 if USE_BF16 else torch.float16

# Quantisation settings: NF4 with double quantisation, enabled by LOAD_IN_4BIT.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=LOAD_IN_4BIT,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=dtype,
)

# Load the base model with the quantisation config and automatic device placement.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map='auto',
    torch_dtype=dtype,
    quantization_config=bnb_config,
)

# Architecture identifier from the model config ('' when unavailable);
# pick_lora_targets uses it to choose which modules get adapters.
model_type = getattr(getattr(model, 'config', None), 'model_type', '')
model_type


In [None]:
def pick_lora_targets(model_type: str) -> List[str]:
    """Choose the module names that LoRA adapters attach to.

    Args:
        model_type: Architecture identifier, matched case-insensitively.
            ``None`` or an empty string is treated as unknown.

    Returns:
        A new list of module-name suffixes: the fused attention and MLP
        layers for 'gpt_neox'; attention plus MLP projections for 'llama',
        'mistral' and 'mixtral'; the four attention projections for
        'gpt_oss' and for any unrecognised type.
    """
    attention_only = ['q_proj', 'k_proj', 'v_proj', 'o_proj']
    key = (model_type or '').lower()

    if key == 'gpt_neox':
        return ['query_key_value', 'dense', 'dense_h_to_4h', 'dense_4h_to_h']
    if key in ('llama', 'mistral', 'mixtral'):
        return attention_only + ['gate_proj', 'up_proj', 'down_proj']
    # 'gpt_oss' and every unrecognised architecture use the attention
    # projections only.
    return attention_only

# Adapter hyper-parameters; the target modules depend on the loaded
# architecture (see pick_lora_targets).
lora_cfg = LoraConfig(
    task_type='CAUSAL_LM',
    target_modules=pick_lora_targets(model_type),
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias='none',
)

# Wrap the loaded model with the adapters and report the trainable-parameter count.
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()


## Train

In [None]:
import inspect

# The notebook installs TRL unpinned (`-U`), and TRL has renamed SFT keywords
# across releases. Rather than hard-code one spelling, pick whichever name the
# installed version actually accepts.
# NOTE(review): confirm against the TRL release you end up pinning.
sft_config_fields = inspect.signature(SFTConfig.__init__).parameters
sft_trainer_fields = inspect.signature(SFTTrainer.__init__).parameters

# Sequence-length cap: `max_length` where available, else `max_seq_length`.
length_kwarg = 'max_length' if 'max_length' in sft_config_fields else 'max_seq_length'
# Tokenizer argument: `processing_class` where available, else `tokenizer`.
tokenizer_kwarg = 'processing_class' if 'processing_class' in sft_trainer_fields else 'tokenizer'

training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    max_steps=MAX_STEPS,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type='cosine',
    warmup_ratio=WARMUP_RATIO,
    weight_decay=0.0,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    # Exactly one of bf16/fp16 is enabled, following USE_BF16.
    bf16=USE_BF16, fp16=not USE_BF16,
    seed=SEED,
    packing=False,
    gradient_checkpointing=GRAD_CHECKPOINTING,
    do_eval=False,
    report_to=['none'],
    # The text column is configured on SFTConfig, not on the trainer.
    dataset_text_field='text',
    **{length_kwarg: MAX_SEQ_LENGTH},
)
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    **{tokenizer_kwarg: tokenizer},
)
trainer.train()

# Save the trained model and the tokenizer side by side in OUTPUT_DIR.
trainer.model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print('Saved to', OUTPUT_DIR)
