# Model fine-tuning: beyond prompting only

Now we are going to fine-tune a Llama model using the train and test datasets we have created.

1. Load the datasets for training and evaluation
2. Define the model and tokenizer
3. Set up the training configuration using `peft` for LoRA
4. Train the model using the `Trainer` class from `transformers`
5. Save the trained model and tokenizer

# The finetune dataset

We have a dataset pushed to the Hugging Face Hub in which the text of every example has the following structure:

```
You are an expert AI specializing in multiple-choice questions.
Your task is to analyze the provided context, question, and options, then identify the single best answer.
Respond with only the capital letter (A, B, C, or D) corresponding to your choice.

Context:

{{context}}

Question:

{{question}}

Options:

A) {{options[0]}}
B) {{options[1]}}
C) {{options[2]}}
D) {{options[3]}}

Answer: B
```


In [25]:
import os
import random
import json
from typing import Callable

import torch
from datasets import load_dataset, DatasetDict
from datasets import Dataset as HFDataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType
from torch.utils.data import Dataset

os.environ["WANDB_DISABLED"] = "true" # Disable Weights & Biases logging because it is not needed for this task

In [19]:
# Reproducibility: seed Python's RNG and torch (CPU and all CUDA devices) with one value.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [36]:
# Tokenizer of the base checkpoint that is fine-tuned later in this notebook.
MODEL_NAME = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Reuse the EOS token for padding; the tokenization step below pads to a fixed length.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the raw question splits from disk.
# NOTE(review): `questions` and `eval_questions` are not referenced by any later cell
# visible in this notebook, and the paths look like Hub repository ids rather than
# local directories — confirm whether `load_dataset(...)` was intended.
questions = DatasetDict.load_from_disk("maia-pln-2025/pubmed_QA")
eval_questions = DatasetDict.load_from_disk("maia-pln-2025/pubmed_QA_test_questions")

In [54]:
class TrainAndEval(Dataset):
    """
    Dataset backed by a JSON Lines file (one JSON object per line).

    Lines are kept in memory as raw strings and decoded on access, so every
    ``__getitem__`` call returns a freshly parsed object.

    Each record is a map which contains:
    - "id"
    - "excerpt"
    - "question"
    - "statement": the correct option
    - "distractors"
    """
    def __init__(self, file_path: str):
        """
        Read every non-empty line of ``file_path`` into memory.

        Args:
            file_path: path to the JSON Lines file.
        """
        self.file_path = file_path
        # Explicit UTF-8 so decoding does not depend on the platform's default encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            stripped_lines = (line.strip() for line in f)
            # Blank lines (e.g. a trailing newline) are not records: keeping them would
            # inflate __len__ and make json.loads fail in __getitem__.
            self._raw_data = [line for line in stripped_lines if line]

    def __len__(self):
        """Return the number of records read from the file."""
        return len(self._raw_data)

    def __getitem__(self, idx: int):
        """Decode and return the record stored at position ``idx``."""
        return json.loads(self._raw_data[idx])

class EvalWithAnswers(Dataset):
    """
    A dataset that takes a TrainAndEval dataset and adds the statement to the distractors
    to create a multiple choice question. The statement is inserted at a random position
    among the distractors.

    Every returned item carries two extra keys:
    - "options": the distractors plus the statement, in their final order
    - "answer_idx": position of the statement inside "options"

    The position is drawn with ``random.randint`` on every access, so repeated reads
    of the same index may yield a different ordering unless the RNG is re-seeded.
    """
    def __init__(self, dataset: "TrainAndEval"):
        """
        Args:
            dataset: source of items exposing "statement" and "distractors".
        """
        self.dataset = dataset

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int):
        """Return a copy of item ``idx`` extended with "options" and "answer_idx"."""
        # Shallow-copy the item so the new keys are not written into the wrapped dataset.
        item = dict(self.dataset[idx])

        # Copy the distractors: inserting into the original list would leave the correct
        # answer inside item["distractors"] and grow the list on every repeated access.
        options = list(item["distractors"])
        # randint is inclusive on both ends, so the statement can also land last.
        index = random.randint(0, len(options))
        options.insert(index, item["statement"])

        item["options"] = options
        item["answer_idx"] = index

        return item

OPTIONS =  ['A', 'B', 'C', 'D']

class Prompted(Dataset):
    """
    Wrap a multiple-choice dataset and attach the rendered prompt to each item.

    The prompt is produced by ``prompter(question, context, options, answer)`` and
    stored under the "text" key of the item returned by the wrapped dataset.
    """

    def __init__(
            self,
            dataset: EvalWithAnswers,
            prompter: Callable,
            options: list = OPTIONS,
    ):
        # `options` holds the answer labels used to translate "answer_idx" into a label.
        self.dataset, self.prompter, self.options = dataset, prompter, options

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        example = self.dataset[item]
        # Read the four fields in a fixed order before doing anything else.
        answer_idx, choices, excerpt, question = (
            example[key] for key in ("answer_idx", "options", "excerpt", "question")
        )

        answer_label = index_to_answer(answer_idx, self.options)
        example["text"] = self.prompter(question, excerpt, choices, answer_label)

        return example

class Tokenized(Dataset):
    """
    A dataset that tokenizes the "text" field of every item of a Prompted dataset.

    Each returned item is the wrapped item extended with three keys:
    - "input_ids": token ids, truncated/padded to ``max_length``
    - "attention_mask": mask returned by the tokenizer
    - "labels": the same object as "input_ids" (full-sequence causal LM target)
    """

    def __init__(self, tokenizer, dataset: "Prompted", max_length: int = 1200):
        """
        Args:
            tokenizer: callable tokenizer applied to the "text" of each item.
            dataset: source of items exposing a "text" key.
            max_length: length every sequence is truncated and padded to.
        """
        self.dataset = dataset
        self.tokenizer = tokenizer
        # Keep the requested length; it used to be ignored in favour of a hard-coded 1200.
        self.max_length = max_length

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int):
        """Return item ``idx`` with "input_ids", "attention_mask" and "labels" added."""
        item = self.dataset[idx]

        # Tokenize the whole prompt; truncation=True already selects the
        # longest-first strategy, so no legacy `truncation_strategy` kwarg is needed.
        tokenized = self.tokenizer(
            item["text"],
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_attention_mask=True,
        )

        item["input_ids"] = tokenized["input_ids"][0]  # Remove batch dimension
        item["attention_mask"] = tokenized["attention_mask"][0]
        # NOTE(review): labels mirror input_ids unchanged, so neither the prompt part nor
        # the padding positions are masked out of the loss — confirm this is intended.
        item["labels"] = tokenized["input_ids"][0]

        return item

def index_to_answer(index: int, options: list) -> str:
    """
    Look up the answer label stored at position ``index`` of ``options``.
    """
    answer = options[index]
    return answer


def to_transformers_dataset(dataset: Prompted) -> HFDataset:
    """
    Materialize every item of ``dataset`` and build a Hugging Face dataset from them.

    Items are fetched by index, from 0 to ``len(dataset) - 1``.
    """
    records = []
    for position in range(len(dataset)):
        records.append(dataset[position])

    return HFDataset.from_list(records)

def generate_prompt(
    question: str,
    context: str,
    options: list[str],
    answer: str = None
) -> str:
    """
    Render a multiple-choice prompt.

    Args:
        question: the question text.
        context: the excerpt the question refers to.
        options: candidate answers; they are labelled "A)", "B)", ... in order.
        answer: label of the correct option. When given, the prompt ends with
            "Answer: <answer>" (training example). When omitted, the prompt ends
            with a bare "Answer:" so it can be completed by the model.

    Returns:
        The full prompt string.
    """
    # chr(65) is "A": label options with consecutive capital letters.
    formatted_options = "\n".join(
        f"{chr(65 + i)}) {option}" for i, option in enumerate(options)
    )
    prompt = f"""You are an expert AI specializing in multiple-choice questions.
Your task is to analyze the provided context, question, and options, then identify the single best answer.
Respond with only the capital letter (A, B, C, or D) corresponding to your choice.

Context:
{context}

Question:
{question}

Options:
{formatted_options}

Answer:"""
    # Only append the answer when there is one; formatting None would leak the
    # literal text "None" into the prompt.
    if answer is not None:
        prompt += f" {answer}"
    return prompt

In [55]:
# Evaluation pipeline: raw records -> options with answer index -> prompt text -> token ids.
eval_dataset = TrainAndEval("./data/pubmed_QA_eval.json")
eval_with_answers = EvalWithAnswers(dataset=eval_dataset)
eval_prompted = Prompted(dataset=eval_with_answers, prompter=generate_prompt)
eval_tokenized = Tokenized(tokenizer, eval_prompted, max_length=1200)

In [56]:
# Training pipeline: raw records -> options with answer index -> prompt text -> token ids.
train_dataset = TrainAndEval("./data/pubmed_QA_train.json")
train_with_answers = EvalWithAnswers(dataset=train_dataset)
train_prompted = Prompted(dataset=train_with_answers, prompter=generate_prompt)
train_tokenized = Tokenized(tokenizer, train_prompted, max_length=1200)

In [57]:
eval_prompted[0]["text"] # check the first item in the eval dataset

'You are an expert AI specializing in multiple-choice questions.\nYour task is to analyze the provided context, question, and options, then identify the single best answer.\nRespond with only the capital letter (A, B, C, or D) corresponding to your choice.\n\nContext:\nTemporal changes in medial basal hypothalamic LH-RH correlated with plasma LH during the rat estrous cycle and following electrochemical stimulation of the medial preoptic area in pentobarbital-treated proestrous rats. In the present studies we have simultaneously measured changes in medial basal hypothalamic (MBH) leutenizing hormone-releasing hormone (LH-RH) and in plasma LH by radioimmunoassay in female rats at various hours during the 4-day estrous cycle and under experimental conditions known to alter pituitary LH secretion. In groups of rats decapitated at 12.00 h and 15.00 h on estrus and diestrus, plasma LH remained at basal levels (5-8 ng/ml) and MBH-LH-RH concentrations showed average steady state concentration

In [58]:
tokenizer.decode(eval_tokenized[0]["labels"], skip_special_tokens=True) # check the first item in the eval dataset after tokenization

'You are an expert AI specializing in multiple-choice questions.\nYour task is to analyze the provided context, question, and options, then identify the single best answer.\nRespond with only the capital letter (A, B, C, or D) corresponding to your choice.\n\nContext:\nTemporal changes in medial basal hypothalamic LH-RH correlated with plasma LH during the rat estrous cycle and following electrochemical stimulation of the medial preoptic area in pentobarbital-treated proestrous rats. In the present studies we have simultaneously measured changes in medial basal hypothalamic (MBH) leutenizing hormone-releasing hormone (LH-RH) and in plasma LH by radioimmunoassay in female rats at various hours during the 4-day estrous cycle and under experimental conditions known to alter pituitary LH secretion. In groups of rats decapitated at 12.00 h and 15.00 h on estrus and diestrus, plasma LH remained at basal levels (5-8 ng/ml) and MBH-LH-RH concentrations showed average steady state concentration

In [59]:
# Materialize both tokenized pipelines and bundle them as named splits.
train_split = to_transformers_dataset(train_tokenized)
eval_split = to_transformers_dataset(eval_tokenized)
hf_dataset = DatasetDict({"train": train_split, "eval": eval_split})

In [60]:
len(hf_dataset["train"]), len(hf_dataset["eval"]) # check the length of the datasets

(16890, 5000)

In [None]:
# hf_dataset.push_to_hub("Claudia031/maia-pln-2025-training-v2") # push to hugging face hub

In [2]:
dataset = load_dataset("Claudia031/maia-pln-2025-training-v2") # claudia's dataset (under our team member account)

README.md:   0%|          | 0.00/654 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/39.4M [00:00<?, ?B/s]

eval-00000-of-00001.parquet:   0%|          | 0.00/11.5M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/16890 [00:00<?, ? examples/s]

Generating eval split:   0%|          | 0/5000 [00:00<?, ? examples/s]

In [10]:
print(tokenizer.decode(dataset["train"][0]["input_ids"], skip_special_tokens=True))

You are an expert AI specializing in multiple-choice questions.
Your task is to analyze the provided context, question, and options, then identify the single best answer.
Respond with only the capital letter (A, B, C, or D) corresponding to your choice.

Context:
The rate of action of calcium on the electrical and mechanical responses of the crayfish muscle fibers. The effects of sudden changes in external Ca concentration on the time courses of the changes in size of the action potential and of the associated contraction in a single crayfish muscle fiber were investigated. Procaine-HCl was added to the bathing solution to make the muscle fiber excitable. The concentration of the divalent cations (Ca and Mg) was high enough to keep the threshold potential constant. In Ca-free solution, neither action potential nor contraction was observed. When the external Ca concentration was suddenly increased from 0 to 14 mM, the full sized action potentials were generated within several seconds, b

# LoRA fine-tuning: search for all modules that can be trained

We are going to fine-tune the modules `['q_proj', 'v_proj', 'down_proj', 'o_proj', 'k_proj', 'up_proj', 'gate_proj']`, which
are the linear layers of the model (the `lm_head` output layer is excluded).

In order to prevent overfitting we will use LoRA (Low-Rank Adaptation) to fine tune the model with dropout and weight decay.

Dropout is a regularization technique that helps prevent overfitting by randomly setting a fraction of the input units to zero during training.

Another measure against overfitting is loading the best checkpoint at the end of training, which is done by setting `load_best_model_at_end=True` in the `TrainingArguments`, together with `metric_for_best_model="eval_loss"`.

In [12]:
# Load the base causal-LM checkpoint that the LoRA adapters are attached to later on.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    torch_dtype="auto",  # presumably keeps the dtype stored in the checkpoint — confirm
    device_map="auto",  # leave device placement to the library
    use_cache=False,  # NOTE(review): presumably disabled because the KV cache is only useful for generation — confirm
)

In [13]:
import torch.nn as nn

def find_all_linear_names(model):
    """
    Collect the leaf names of every ``nn.Linear`` submodule of ``model``.

    The leaf name is the last dot-separated component of the qualified module
    name (e.g. "model.layers.0.self_attn.q_proj" -> "q_proj"). The output head
    "lm_head" is excluded when present.

    Args:
        model: any module exposing ``named_modules()``.

    Returns:
        Alphabetically sorted list of unique leaf names, so the result does not
        depend on set iteration order.
    """
    linear_module_names = {
        # rsplit leaves a name without dots untouched, so no special case is needed.
        name.rsplit(".", 1)[-1]
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)  # Use standard Linear layer
    }

    # Optionally exclude output head; discard() is a no-op when it is absent.
    linear_module_names.discard("lm_head")

    return sorted(linear_module_names)

# Names of the linear layers found in the base model; passed to LoraConfig as target_modules below.
modules = find_all_linear_names(model)
modules

['q_proj', 'v_proj', 'down_proj', 'o_proj', 'k_proj', 'up_proj', 'gate_proj']

In [14]:
# dataset len
len(dataset["train"]), len(dataset["eval"])

(16890, 5000)

In [62]:
# NOTE(review): despite its name, `steps_per_epoch` is not derived from the dataset size.
# With 16890 training examples and batch_size = 4, one pass is about 4223 batches
# (assuming a single device and no gradient accumulation), so this value acts as the
# eval / save / logging interval rather than as a true epoch length.
steps_per_epoch = 2000
batch_size = 4

In [None]:
# LoRA adapter configuration, applied to every linear layer collected in `modules`.
lora_config = LoraConfig(
    r=16,  # we have tested that with 16 and 8 but 16 is better for this model
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05, # we have tested that with 0.05 and 0.1 but 0.05 is better for this model
    bias="lora_only",
    task_type=TaskType.CAUSAL_LM,
    init_lora_weights=True,  # Default PEFT init: the B matrix starts at zero, so the adapter begins as a no-op
    use_rslora=False,  # Standard LoRA (more predictable)
)

# Wrap the base model with the adapters and report how many parameters will be trained.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

training_args = TrainingArguments(
    output_dir="./checkpoints",

    # Evaluate, checkpoint and log on the same step interval so that the best
    # checkpoint (lowest eval_loss) can be restored at the end of training.
    eval_strategy="steps",
    eval_steps=steps_per_epoch,
    save_strategy="steps",
    save_steps=steps_per_epoch,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,      # because lower eval_loss is better
    logging_strategy="steps",
    logging_steps=steps_per_epoch,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01, # prevent overfitting scheduling a weight decay, which is a regularization technique that helps prevent overfitting by penalizing large weights
    logging_dir="./logs",
    fp16=True,
    save_total_limit=1,
    # The string "none" disables every reporting integration. Passing None does not:
    # in transformers 4.x it falls back to all installed integrations.
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["eval"],
    tokenizer=tokenizer,
)

trainer.train()

In [None]:
# Persist the trained model and the tokenizer to separate output directories.
# NOTE(review): `model` was wrapped by get_peft_model, so this presumably stores the LoRA
# adapter rather than merged full weights — confirm before loading it for inference.
trainer.save_model("./outputs/fine-tuning/trainer")
tokenizer.save_pretrained("./outputs/fine-tuning/tokenizer")