# 🧠 MedChatGuard - PeFT Fine-Tuning in Colab
Fine-tune a QA-model (`deepset/roberta-base-squad2`) on synthetic EHR data using QLoRA.


### Install Dependencies

In [None]:
!pip install transformers datasets evaluate accelerate

### Load SQuAD-style dataset from Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

DATA_PATH = "/content/drive/MyDrive/Colab Notebooks/FineTuning/ehr_squad_format.json"

In [1]:
# Case where not from G-Drive
DATA_PATH = "../data/finetune/ehr_squad_format.json"

### Load Dataset and Tokenizer

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

dataset = load_dataset("json", data_files=DATA_PATH, field="data")
tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2")

  from .autonotebook import tqdm as notebook_tqdm


### Preprocessing Function

In [3]:
dataset["train"]

Dataset({
    features: ['title', 'paragraphs'],
    num_rows: 578
})

In [4]:
def preprocess_examples(example):
    context = example["paragraphs"][0]["context"]
    question = example["paragraphs"][0]["qas"][0]["question"]
    answers = example["paragraphs"][0]["qas"][0]["answers"]
    answer = answers[0]  # assuming single answer

    # Tokenize with offsets
    encoding = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=384,
        stride=128,
        return_overflowing_tokens=False,
        return_offsets_mapping=True,
        padding="max_length"
    )

    # Extract answer char positions
    start_char = answer["answer_start"]
    end_char = start_char + len(answer["text"])

    # Find token span that matches the character span
    offsets = encoding["offset_mapping"]
    start_pos, end_pos = 0, 0
    for idx, (start, end) in enumerate(offsets):
        if start <= start_char < end:
            start_pos = idx
        if start < end_char <= end:
            end_pos = idx
            break

    encoding["start_positions"] = start_pos
    encoding["end_positions"] = end_pos
    encoding.pop("offset_mapping")

    return encoding

tokenized = dataset["train"].map(preprocess_examples, remove_columns=dataset["train"].column_names)


### Load Model and Tokenizer

In [6]:
from transformers import TrainingArguments, Trainer, default_data_collator

training_args = TrainingArguments(
    output_dir="./roberta-qa-checkpoints",
    # evaluation_strategy="epoch",
    save_strategy="epoch",
    num_train_epochs=3,
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    # eval_dataset=tokenized["validation"] if "validation" in tokenized else None,
    tokenizer=tokenizer,
    data_collator=default_data_collator
)


  trainer = Trainer(


### Train

In [7]:
trainer.train()


Step,Training Loss
10,3.1181
20,2.2707
30,2.0596
40,2.1351
50,2.2473
60,2.0256
70,1.8983
80,1.9738
90,1.9736
100,1.9655


TrainOutput(global_step=219, training_loss=2.0241574239513103, metrics={'train_runtime': 2244.8828, 'train_samples_per_second': 0.772, 'train_steps_per_second': 0.098, 'total_flos': 339816432135168.0, 'train_loss': 2.0241574239513103, 'epoch': 3.0})

### Save Model

In [8]:
# SAVE_PATH = "/content/drive/MyDrive/Colab Notebooks/FineTuning/roberta_qa_finetuned"
SAVE_PATH = "../models/finetuned_model/roberta-base-squad2"

trainer.save_model(SAVE_PATH)
tokenizer.save_pretrained(SAVE_PATH)

('../models/finetuned_model/roberta-base-squad2\\tokenizer_config.json',
 '../models/finetuned_model/roberta-base-squad2\\special_tokens_map.json',
 '../models/finetuned_model/roberta-base-squad2\\vocab.json',
 '../models/finetuned_model/roberta-base-squad2\\merges.txt',
 '../models/finetuned_model/roberta-base-squad2\\added_tokens.json',
 '../models/finetuned_model/roberta-base-squad2\\tokenizer.json')