In [None]:
!pip -q install -U transformers accelerate peft trl datasets bitsandbytes sentencepiece

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import EarlyStoppingCallback

base_model = "mistralai/Mistral-Small-24B-Instruct-2501"
train_file = "/content/train_aml_full.jsonl"
val_file   = "/content/valid_aml_full.jsonl"

train_ds = load_dataset("json", data_files=train_file)["train"]
val_ds   = load_dataset("json", data_files=val_file)["train"]

# Convert your {input, output} -> TRL-standard {prompt, completion}
train_ds = train_ds.map(lambda x: {"prompt": x["input"], "completion": x["output"]},
                        remove_columns=["input", "output"])
val_ds   = val_ds.map(lambda x: {"prompt": x["input"], "completion": x["output"]},
                      remove_columns=["input", "output"])

In [None]:
# --- Tokenizer --------------------------------------------------------------------------
# NOTE(review): `fix_mistral_regex=True` is passed straight to the tokenizer loader; presumably it
# opts into a corrected pre-tokenization regex for Mistral tokenizers — confirm for the installed
# transformers version.
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True, fix_mistral_regex=True)
# Batches must be padded; fall back to the EOS token only when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Base model, loaded in 4-bit NF4 (QLoRA) --------------------------------------------
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,         # nested quantization of the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for compute on the dequantized weights
)
model = AutoModelForCausalLM.from_pretrained(
    base_model, quantization_config=bnb_config, device_map="auto"
)
# Keep both model configs in sync with whichever pad token the tokenizer ended up with.
model.config.pad_token_id = tokenizer.pad_token_id
model.generation_config.pad_token_id = tokenizer.pad_token_id

# --- LoRA adapter on the attention (q/k/v/o) and MLP (gate/up/down) projections ---------
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,  # LoRA update is scaled by lora_alpha / r = 2
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# Training configuration. Evaluation and checkpointing share the same 10-step cadence, which
# `load_best_model_at_end` needs so that the best-scoring evaluation has a checkpoint to restore.
args = SFTConfig(
    output_dir="mistral24b-qlora-sft",
    seed=42,  # the Trainer default, made explicit so the run's seed is visible
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # 2 x 4 = 8 sequences per optimizer step per device
    learning_rate=5e-5,
    num_train_epochs=8,  # upper bound; early stopping (below) can end the run much earlier
    weight_decay=0.05,
    logging_steps=1,
    save_steps=10,
    # Each checkpoint stores the LoRA adapter (adapter_model.safetensors is ~706 MB for this config)
    # plus optimizer state, so keep only the newest ones. With load_best_model_at_end=True the
    # Trainer always retains the best checkpoint in addition to these.
    save_total_limit=2,
    eval_strategy="steps",
    eval_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    report_to="none",
    bf16=True,
    gradient_checkpointing=True,

    max_length=None,  # no truncation: every example is kept at full length (watch memory on long ones)
    packing=False,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
)

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    peft_config=peft_config,  # the trainer wraps the 4-bit base model with this LoRA adapter
    processing_class=tokenizer,
    # Stop once eval_loss fails to improve for 2 consecutive evaluations (20 steps here). In the
    # recorded run eval_loss bottomed out at step 20 and training stopped at step 40; the step-20
    # weights are then restored because load_best_model_at_end=True.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
)

trainer.train()

# Write the trained LoRA adapter (not the 24B base weights) and the tokenizer files into a single
# directory. Because the training config sets load_best_model_at_end=True, `trainer.model` holds the
# best-evaluated weights at this point, not necessarily the ones from the last step.
adapter_dir = "mistral24b-qlora-adapter-final"
trainer.model.save_pretrained(save_directory=adapter_dir)
tokenizer.save_pretrained(save_directory=adapter_dir)
print("Saved adapter to:", adapter_dir)

Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]

Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
10,0.7359,0.869061,1.662713,225378.0,0.772172
20,0.534,0.829042,1.391366,434147.0,0.775398
30,0.2741,0.954668,1.223634,639680.0,0.779898
40,0.1941,0.99206,1.160615,852217.0,0.783281


Saved adapter to: mistral24b-qlora-adapter-final


In [None]:
!du -sh mistral24b-qlora-adapter*
!ls -lh mistral24b-qlora-adapter* | head

722M	mistral24b-qlora-adapter-final
total 722M
-rw-r--r-- 1 root root 1.1K Dec 17 01:08 adapter_config.json
-rw-r--r-- 1 root root 706M Dec 17 01:08 adapter_model.safetensors
-rw-r--r-- 1 root root 1.6K Dec 17 01:08 chat_template.jinja
-rw-r--r-- 1 root root 5.2K Dec 17 01:08 README.md
-rw-r--r-- 1 root root  21K Dec 17 01:08 special_tokens_map.json
-rw-r--r-- 1 root root 194K Dec 17 01:08 tokenizer_config.json
-rw-r--r-- 1 root root  17M Dec 17 01:08 tokenizer.json


In [None]:
!zip -r mistral24b-qlora-adapter-final.zip mistral24b-qlora-adapter-final

  adding: mistral24b-qlora-adapter-final/ (stored 0%)
  adding: mistral24b-qlora-adapter-final/README.md (deflated 65%)
  adding: mistral24b-qlora-adapter-final/adapter_config.json (deflated 58%)
  adding: mistral24b-qlora-adapter-final/special_tokens_map.json (deflated 87%)
  adding: mistral24b-qlora-adapter-final/tokenizer_config.json (deflated 95%)
  adding: mistral24b-qlora-adapter-final/chat_template.jinja (deflated 55%)
  adding: mistral24b-qlora-adapter-final/adapter_model.safetensors (deflated 54%)
  adding: mistral24b-qlora-adapter-final/tokenizer.json (deflated 84%)
