# Age Range Conversation Trainer (Unsloth)

This notebook fine-tunes a 3B-parameter instruction-tuned model on the synthetic, age-labeled conversation dataset at `../data/age_prediction_dataset.jsonl` (resolved relative to the notebook's working directory). It targets a single 12 GB RTX 3060 GPU and uses the [Unsloth](https://github.com/unslothai/unsloth) QLoRA workflow.

## Prerequisites
- Ubuntu 22.04 / CUDA 12.x runtime
- NVIDIA RTX 3060 (12GB VRAM) or similar
- Python 3.10+
- A recent NVIDIA driver that supports CUDA 12.x
- `pip`, ideally inside a virtual environment

In [None]:
# Install dependencies
!pip install -q unsloth[colab-new] datasets accelerate bitsandbytes peft transformers==4.39.3

In [None]:
# Configuration and dataset preparation
from pathlib import Path
from datasets import load_dataset
import transformers

# Relative paths resolve against the current working directory (normally the
# notebook's own folder when run from Jupyter).
DATA_PATH = Path('../data/age_prediction_dataset.jsonl').resolve()
# Explicit check rather than `assert`: assertions are stripped under `python -O`.
if not DATA_PATH.is_file():
    raise FileNotFoundError(f'Dataset not found at {DATA_PATH}')

# JSON Lines file; each row is expected to carry 'input' (conversation) and
# 'output' (age-range label) fields, as read by format_record below.
raw_ds = load_dataset('json', data_files=str(DATA_PATH))['train']
SYSTEM_PROMPT = 'You are an analyst that infers the most likely age range of a speaker based on their written conversation. Respond only with the best matching age bucket.'

INSTRUCTION_PREFIX = 'Given the conversation, guess the most likely age range of the primary speaker.'

# The closed set of labels the model is asked to choose from.
TARGET_LABELS = 'Options: 10-13, 14-17, 18-24, 25-34, 35-44, 45-54, 55-64, 65-75.'

def format_record(example):
    """Render one dataset row as a single training string.

    Args:
        example: a row mapping with an ``'input'`` conversation and an
            ``'output'`` age-range label.

    Returns:
        A dict with the single key ``'text'``: the instruction prompt ending in
        ``[/INST]``, immediately followed by the label and a newline.
    """
    # NOTE(review): this is the Llama-2 chat layout, and its system block is
    # closed with "<<SYS>>" rather than "<</SYS>>". Any prompt built at
    # inference time must match this exact layout, so change both together.
    header = f"<s>[INST] <<SYS>>\n{SYSTEM_PROMPT}\n<<SYS>>\n"
    body = (
        f"{INSTRUCTION_PREFIX}\n{TARGET_LABELS}\n\n"
        f"Conversation:\n{example['input']}\n[/INST]"
    )
    answer = f"{example['output']}\n"
    return {'text': header + body + answer}

# Replace every original column with the single formatted 'text' column.
source_columns = raw_ds.column_names
train_dataset = raw_ds.map(format_record, remove_columns=source_columns)

# Preview the start of the first formatted example.
first_text = train_dataset[0]['text']
print(first_text[:400])

In [None]:
# Load base model with Unsloth
from unsloth import FastLanguageModel

max_seq_length = 1024
# None lets Unsloth auto-select bfloat16 on GPUs that support it (e.g. the
# Ampere-based RTX 3060) and float16 otherwise. Unsloth expects a torch dtype or
# None here; a plain string such as 'bfloat16' does not trigger that fallback.
dtype = None
load_in_4bit = True
# Llama 3.2 3B Instruct. There is no "llama-3-3b" checkpoint; the 3B size only
# exists in the Llama 3.2 family, which needs a recent transformers release.
base_model_name = 'unsloth/Llama-3.2-3B-Instruct'

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = base_model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# Only fall back to EOS as padding when the tokenizer has no pad token of its own.
# When pad == eos, collators that mask by pad id also mask the EOS labels.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Attach LoRA adapters to every attention and MLP projection of the 4-bit base.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    target_modules = ['q_proj','k_proj','v_proj','o_proj','gate_proj','up_proj','down_proj'],
    lora_alpha = 16,
    lora_dropout = 0.05,
    bias = 'none',
    use_gradient_checkpointing = True,
    random_state = 42,
)
print(f'Loaded {base_model_name} with LoRA adapters ready.')

In [None]:
# Training configuration
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

import torch
from transformers import DataCollatorForSeq2Seq

output_dir = 'checkpoints/age-predictor'


def _tokenize_with_eos(example):
    """Tokenize one formatted record and append EOS so the model learns to stop."""
    input_ids = tokenizer(example['text'])['input_ids'] + [tokenizer.eos_token_id]
    return {
        'input_ids': input_ids,
        'attention_mask': [1] * len(input_ids),
        # An explicit label copy: DataCollatorForSeq2Seq pads labels with -100
        # instead of masking by pad id, so the EOS label survives even when
        # pad_token == eos_token.
        'labels': list(input_ids),
    }


# Trainer drops columns the model's forward() does not accept, so a raw 'text'
# column would leave nothing to train on; tokenize up front instead.
tokenized_train_dataset = train_dataset.map(
    _tokenize_with_eos, remove_columns=train_dataset.column_names
)
num_before = len(tokenized_train_dataset)
# Drop over-long samples rather than truncating: right-truncation would cut off
# the age label, which sits at the very end of each sample.
tokenized_train_dataset = tokenized_train_dataset.filter(
    lambda ex: len(ex['input_ids']) <= max_seq_length
)
num_dropped = num_before - len(tokenized_train_dataset)
if num_dropped:
    print(f'Dropped {num_dropped} of {num_before} examples longer than {max_seq_length} tokens.')

# Use bf16 mixed precision where the GPU supports it (matching a model loaded in
# bfloat16); fall back to fp16 otherwise.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir = output_dir,
    num_train_epochs = 5,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 8,
    warmup_steps = 10,
    learning_rate = 2e-4,
    fp16 = not use_bf16,
    bf16 = use_bf16,
    logging_steps = 5,
    save_strategy = 'epoch',
    report_to = 'none',
    optim = 'paged_adamw_32bit',
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer = tokenizer,
    padding = True,
    label_pad_token_id = -100,
)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = tokenized_train_dataset,
    data_collator = data_collator,
)

trainer.train()

In [None]:
# Save adapter and tokenizer
from pathlib import Path
# Persist the LoRA adapter weights and the tokenizer side by side so the
# directory can be reloaded as a unit.
save_path = Path('checkpoints') / 'age-predictor'
save_path.mkdir(exist_ok=True, parents=True)
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_path)
print(f'Saved LoRA adapter and tokenizer to {save_path}')

In [None]:
# Quick sanity-check inference
import torch

from unsloth import FastLanguageModel

example_text = (
    "A: I'm finalizing college applications while finishing debate season and squeezing in my part-time barista shifts.\n"
    "B: That's a lot.\n"
    "A: Yeah, but I just need to get through FAFSA forms and scholarship essays."
)

# Build the prompt with the same formatter used for training instead of a copy of
# the template. format_record emits "<prompt>[/INST]<label>\n", so an empty label
# plus removing the trailing newline yields a prompt ending right where the
# training targets begin.
prompt = format_record({'input': example_text, 'output': ''})['text'].removesuffix('\n')

# Switch the model into Unsloth's inference mode (use
# FastLanguageModel.for_training(model) before any further training).
FastLanguageModel.for_inference(model)

inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
# No fp16 autocast: the model already runs in its load dtype.
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=20,
        do_sample=False,  # greedy: deterministic single best label
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens, not the echoed prompt.
prompt_length = inputs['input_ids'].shape[1]
print(tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip())