# AALM — Inference and Merge
Use the trained LoRA adapter for inference, and optionally merge it into the base weights.

In [None]:
# Install or upgrade (-U) the notebook's dependencies with reduced pip output (-q).
# transformers and peft are imported in the next cell; bitsandbytes is not
# referenced directly by the code visible in this notebook (presumably needed
# for quantized loading elsewhere — confirm before removing).
%pip -q install -U transformers peft bitsandbytes

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# --- Notebook configuration -------------------------------------------------
# Identifier of the base checkpoint passed to AutoModelForCausalLM.from_pretrained.
BASE_MODEL = "openai/gpt-oss-20b"
# Directory holding the trained LoRA adapter (the tokenizer is loaded from
# here too). Point this at your own training output.
ADAPTER_DIR = "outputs/aalm-gpt-oss-20b-qlora"
# True -> load weights as torch.bfloat16; False -> torch.float16.
USE_BF16 = True
# Default system prompt used by generate().
SYSTEM = "You are AALM, the Australian Administrative Law Model."


In [None]:
# Pick the half-precision dtype for the base weights from the USE_BF16 switch.
if USE_BF16:
    dtype = torch.bfloat16
else:
    dtype = torch.float16

# The tokenizer is read from the adapter directory rather than the base repo.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)
if tokenizer.pad_token is None:
    # No pad token configured: fall back to the end-of-sequence token.
    tokenizer.pad_token = tokenizer.eos_token

# Load the base checkpoint, then wrap it with the LoRA adapter.
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=dtype,
    device_map='auto',
)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)
model.eval()

# generate() only uses the chat-template path when the template is a string.
_chat_template = getattr(tokenizer, 'chat_template', None)
has_chat = isinstance(_chat_template, str)
has_chat


In [None]:
def generate(question: str, system: str = SYSTEM, max_new_tokens: int = 512) -> str:
    """Sample an answer to ``question`` from the adapter-wrapped model.

    Relies on the notebook-level ``tokenizer``, ``model`` and ``has_chat``
    objects created in the loading cell.

    Args:
        question: The user question to ask.
        system: System prompt placed ahead of the question.
        max_new_tokens: Upper bound on the number of tokens to generate.

    Returns:
        The first returned sequence decoded with special tokens skipped.
        NOTE(review): for causal LMs ``generate`` normally returns the prompt
        tokens followed by the continuation, so the prompt text is likely
        part of the returned string — confirm if only the answer is wanted.
    """
    if has_chat and hasattr(tokenizer, 'apply_chat_template'):
        # Let the tokenizer's own chat template lay out the conversation.
        messages = [
            {'role': 'system', 'content': system},
            {'role': 'user', 'content': question},
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        # Plain-text fallback prompt. The newlines are written as escapes: a
        # single-quoted string literal cannot span physical lines.
        text = f'{system}\n\nQuestion: {question}\nAnswer:'
    inputs = tokenizer(text, return_tensors='pt').to(model.device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        out = model.generate(**inputs, do_sample=True, temperature=0.7, top_p=0.9,
                             max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(out[0], skip_special_tokens=True)

# Smoke test: run one sample question through generate() and show the result.
demo_question = 'In New South Wales, when is procedural fairness required in administrative decision-making?'
print(generate(demo_question))


## Merge adapter into base (optional)

In [None]:
from peft import PeftModel

# Destination directory for the standalone (adapter-free) checkpoint.
OUTPUT_MERGED = 'outputs/aalm-gpt-oss-20b-merged'

# Fold the adapter into the wrapped model and take the resulting model object.
merged = model.merge_and_unload()

# Write the merged weights first, then the tokenizer, into the same directory.
for artifact in (merged, tokenizer):
    artifact.save_pretrained(OUTPUT_MERGED)

print('Merged model saved to', OUTPUT_MERGED)
