# Fine-tune models for replication and extension of "A Data-Centric Approach To Generate Faithful and High Quality Patient Summaries with Large Language Models"

## Install dependencies

In [None]:
!pip install -q transformers==4.46.1 \
            accelerate==0.34.2 \
            datasets==3.0.0 \
            peft==0.11.1 \
            trl==0.9.4 \
            rouge_score \
            bert_score \
            sacremoses \
            sacrebleu \
            evaluate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.0/10.0 MB[0m [31m63.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m324.4/324.4 kB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.3/474.3 kB[0m [31m43.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m251.6/251.6 kB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.7/226.7 kB[0m [31m22.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m7.0 MB/s[0m e

## Imports

In [None]:
from collections import defaultdict
from pathlib import Path
import shutil

import evaluate
import numpy as np
import torch
from datasets import load_dataset
from huggingface_hub import login
from rouge_score import rouge_scorer
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling
from peft import LoraConfig, PeftModel
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
import wandb
import json

from google.colab import files

## Config

In [None]:
# Set model and paths here
# Device string that the evaluation cell moves the model and inputs to, and
# that is forwarded to the metric computation.
device = "cuda"
# Identifier of the base model; also used to derive `output` and the wandb run name.
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
# JSON files for the three dataset splits.
train = "/content/data/hallucinations_mimic_di_cleaned_improved.json"
val = "/content/data/hallucinations_mimic_di_validation_cleaned_improved.json"
# NOTE: `test` points at the same file as `val`.
test = "/content/data/hallucinations_mimic_di_validation_cleaned_improved.json"
# Directory of the fine-tuned adapter that the evaluation cell loads.
adapter_path = "/content/mistralai_Mistral-7B-Instruct-v0.3_cleaned_improved_ft" # For evaluation
# Optional path to a JSONL file of existing predictions; when left empty (or
# when the file does not exist) the evaluation cell generates predictions.
predictions = "" # For evaluation
# Label of the training data variant; part of `output` and of the wandb run name.
training_data_used = "cleaned_improved"
# Directory used for trainer checkpoints, the saved adapter/tokenizer and the
# prediction file; "/" in the model name is replaced by "_".
output = f"./{model_name.replace('/','_')}_{training_data_used}_ft"

## Huggingface Login

In [None]:
# Authenticate with the Hugging Face Hub before any model is downloaded.
# NOTE(review): presumably required because the base model is gated -- confirm.
login()

## Upload and load data

In [None]:
!mkdir -p data

# Let the user pick files in the browser, then relocate each one into data/.
uploaded = files.upload()

for uploaded_name in uploaded:
    shutil.move(uploaded_name, f"data/{uploaded_name}")

# Load data: one JSON file per split, as configured in the Config cell.
split_files = {"train": train, "validation": val, "test": test}
data = load_dataset("json", data_files=split_files)

## Prompt formatter

In [None]:
# Prompt pieces shared by every example
instruction = "Summarize for the patient what happened during the hospital stay based on this doctor's note:\n"
response = "Summary for the patient:\n"


def format_batch(batch):
    """Render a batch of examples as complete training prompts.

    ``batch`` maps "text" and "summary" to parallel sequences. Each returned
    string is the instruction, the note, a blank line, the response header and
    the reference summary, in that order.
    """
    return [
        f"{instruction}{note}\n\n{response}{target}"
        for note, target in zip(batch["text"], batch["summary"])
    ]

## Tokenizer

In [None]:
# Tokenizer
# Load the tokenizer that belongs to the configured base model.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token as the padding token.
# NOTE(review): presumably the tokenizer ships without a pad token -- confirm.
tokenizer.pad_token = tokenizer.eos_token
# Pad on the right-hand side of each sequence.
tokenizer.padding_side = "right"

## Collator

### For Llama 2/Mistral

In [None]:
# Language-modeling collator with masked-LM masking switched off (mlm=False).
# NOTE(review): this cell and the "For Qwen" cell both assign `data_collator`,
# so whichever one ran last is the collator handed to the trainer -- run only one.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

### For Qwen

In [None]:
# Completion-only collator, given the prompt's instruction and response strings
# as its templates; relies on `instruction` and `response` from the prompt
# formatter cell.
# NOTE(review): presumably this restricts the loss to the tokens that follow
# the response template -- confirm against the TRL documentation.
data_collator = DataCollatorForCompletionOnlyLM(
    instruction_template=instruction,
    response_template=response,
    tokenizer=tokenizer
    )

## Set up model

In [None]:
# Model
# Load the base causal LM named by `model_name` with bfloat16 weights;
# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    )

## Setup LoRA & trainer

In [None]:
# LoRA setup
# Adapter hyperparameters: rank 16, alpha 32, dropout 0.05. Only the four
# attention projections are adapted; the commented-out list below additionally
# names the MLP projections and is the variant marked for Qwen.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], # For Llama 2 replication/Mistral
#     target_modules = [
#     "q_proj",
#     "k_proj",
#     "v_proj",
#     "o_proj",
#     "gate_proj",
#     "up_proj",
#     "down_proj",
# ], # For Qwen
    task_type="CAUSAL_LM"
)

# Trainer setup
# Batch size 1 with 8 gradient-accumulation steps gives 8 sequences per
# optimizer step on each device.
# No eval_steps is given. NOTE(review): per the Transformers documentation it
# should then fall back to logging_steps (50), matching save_steps -- confirm.
# load_best_model_at_end + metric_for_best_model="loss" + greater_is_better=False
# asks the trainer to restore the checkpoint with the lowest loss when training ends.
training_args = SFTConfig(
    output_dir=output,
    max_seq_length=4096,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    num_train_epochs=20,
    logging_steps=50,
    eval_strategy="steps",
    save_strategy="steps",
    save_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
    learning_rate=2e-5,
)

# Wire everything together: `format_batch` turns raw examples into prompt
# strings, `data_collator` is whichever collator cell was run, and passing
# `peft_config` hands the LoRA configuration to the trainer.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=data["train"],
    eval_dataset=data["validation"],
    formatting_func=format_batch,
    data_collator=data_collator,
    args=training_args,
    peft_config=lora_config
)


## Run trainer and save model and tokenizer

In [None]:
# Run fine-tuning with the trainer configured above.
trainer.train()

# Persist the trained model and the tokenizer side by side in `output`, the
# same directory the trainer was given as output_dir.
trainer.model.save_pretrained(output)
tokenizer.save_pretrained(output)

### Optionally zip for download

In [None]:
!zip -r Qwen_Qwen2.5-7B-Instruct_cleaned_improved_ft.zip Qwen_Qwen2.5-7B-Instruct_cleaned_improved_ft


# Helper functions

Adapted from Hegselmann et al.

In [None]:
# ROUGE-3 and ROUGE-4 are not available through huggingface, so the rouge_score
# package is used directly.
def get_rouge_score(gold, pred):
    """Return the ROUGE-1/2/3/4/L F-measures of ``pred`` against ``gold``.

    Stemming is enabled and each F-measure is scaled by 100. The result is a
    dict keyed by the ROUGE variant name.
    """
    variants = ["rouge1", "rouge2", "rouge3", "rouge4", "rougeL"]
    result = rouge_scorer.RougeScorer(variants, use_stemmer=True).score(gold, pred)
    percentages = {}
    for variant in variants:
        percentages[variant] = result[variant].fmeasure * 100
    return percentages


def compute_custom_metrics(srcs, golds, preds, device):
    """Score generated summaries against their references.

    Args:
        srcs: source documents; only SARI uses them.
        golds: reference summaries, parallel to ``preds``.
        preds: generated summaries.
        device: device handed to both BERTScore computations.

    Returns:
        A ``defaultdict`` holding the mean ROUGE-1/2/3/4/L F-measures, the mean
        word count ("words"), two mean BERTScore F1 values scaled by 100
        ("bert_score", "bert_score_deberta-large") and "sari".
    """
    scores = defaultdict(list)
    bertscore = evaluate.load("bertscore")
    sari = evaluate.load("sari")

    def mean_f1_percent(**bertscore_kwargs):
        # Average the per-example F1 values and scale the mean by 100.
        result = bertscore.compute(
            predictions=preds, references=golds, device=device, **bertscore_kwargs
        )
        return np.mean(result["f1"]) * 100

    # ROUGE and length are collected example by example and averaged afterwards.
    for gold, pred in zip(golds, preds):
        rouge = get_rouge_score(gold, pred)
        for metric_name in rouge:
            scores[metric_name].append(rouge[metric_name])
        scores["words"].append(len(pred.split(" ")))
    for metric_name in list(scores):
        scores[metric_name] = np.mean(scores[metric_name])

    # Default call using model_type="roberta-large", the same as in the paper
    # "Generation of Patient After-Visit Summaries to Support Physicians"
    # (AVS_gen/eval_summarization.py), which uses the library SummerTime.
    scores["bert_score"] = mean_f1_percent(lang="en")
    # The BERTScore authors recommend "microsoft/deberta-large-mnli"
    # (https://github.com/Tiiiger/bert_score).
    scores["bert_score_deberta-large"] = mean_f1_percent(
        model_type="microsoft/deberta-large-mnli"
    )
    # SARI expects a list of references per prediction, hence the wrapping.
    sari_result = sari.compute(
        sources=srcs, predictions=preds, references=[[g] for g in golds]
    )
    scores["sari"] = sari_result["sari"]
    # A readability score is not computed: importing the readability package
    # did not work (https://pypi.org/project/py-readability-metrics/).

    return scores


def print_metrics_as_latex(metrics):
    """Print the metrics as the cells of one LaTeX table row.

    Values are looked up in a fixed column order, formatted with two decimals,
    wrapped in ``$...$`` and joined with `` & ``. A missing key raises
    ``KeyError`` for plain dicts.
    """
    columns = (
        "rouge1",
        "rouge2",
        "rouge3",
        "rouge4",
        "rougeL",
        "bert_score",
        "bert_score_deberta-large",
        "sari",
        "words",
    )
    cells = []
    for column in columns:
        cells.append(f"${metrics[column]:.2f}$")
    print(" & ".join(cells))

# Evaluation

### Optionally unzip

If the model is already fine-tuned, upload it into Drive and unzip it for evaluation

In [None]:
!unzip -o "/content/drive/MyDrive/models/model.zip" -d "/content/model"

In [None]:
# Evaluation: load the fine-tuned adapter, obtain predictions (from a cached
# file or by generation), print a few examples and log the metrics to wandb.
# Relies on `tokenizer` and `data` created by the Tokenizer and data cells.
wandb.init(
    project="patient-summary-eval",
    name=f"eval_{model_name}_{training_data_used}",
)

print("Loading fine-tuned model...")
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)
eval_model = PeftModel.from_pretrained(base_model, adapter_path)
eval_model.eval().to(device)

# Use validation split (authors used validation as test)
eval_data = data["validation"]


def generate_prompt(text):
    """Return the inference prompt for one note.

    The prompt is the training prompt without the summary: instruction, note,
    a blank line and the response header.
    """
    instruction = "Summarize for the patient what happened during the hospital stay based on this doctor's note:\n"
    response = "Summary for the patient:\n"
    return f"{instruction}{text}\n\n{response}"


def predict_one(text):
    """Generate a summary for one note and return only the generated text."""
    prompt = generate_prompt(text)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = eval_model.generate(
            **inputs,
            max_new_tokens=350,
            eos_token_id=tokenizer.eos_token_id
        )

    # Drop the prompt by token count instead of by character count: decoding
    # the full sequence is not guaranteed to reproduce the prompt string
    # exactly, and a character slice would then cut at the wrong position.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()


# Load predictions
if predictions and Path(predictions).exists():
    print(f"Found existing predictions at {predictions}. Using existing predictions.")
    preds = []
    with open(predictions, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) in the JSONL file.
            if not line.strip():
                continue
            obj = json.loads(line)
            preds.append(obj["summary"])

    # Fail early: metrics over misaligned predictions would be meaningless.
    if len(preds) != len(eval_data):
        raise ValueError(
            f"{predictions} holds {len(preds)} predictions but the evaluation "
            f"set has {len(eval_data)} examples."
        )

# Generate predictions if they don't exist
else:
    print("Generating predictions...")
    preds = []
    for ex in tqdm(eval_data, desc="Predicting", ncols=100):
        preds.append(predict_one(ex["text"]))

    # Save predictions
    out_dir = Path(output)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Computed outside the f-string: reusing double quotes inside a
    # double-quoted f-string is a SyntaxError before Python 3.12.
    model_tag = model_name.replace("/", "_")
    pred_file = out_dir / f"{model_tag}_predictions_eval.jsonl"

    with open(pred_file, "w", encoding="utf-8") as f:
        for p in preds:
            obj = {"summary": p}
            f.write(json.dumps(obj) + "\n")

    print(f"Predictions saved to {pred_file}")


# Print examples (at most five; fewer when the evaluation set is smaller)
print("\n=== Example Outputs ===")
for i in range(min(5, len(preds))):
    print(f"\nExample {i}:")
    print("NOTE:", eval_data[i]["text"][:400], "...")
    print("GOLD:", eval_data[i]["summary"])
    print("PRED:", preds[i])

# Compute metrics
print("Computing metrics...")
metrics = compute_custom_metrics(
    srcs=[ex["text"] for ex in eval_data],
    golds=[ex["summary"] for ex in eval_data],
    preds=preds,
    device=device
)

# Log metrics to wandb (cast to plain floats first)
metrics_float = {k: float(v) for k, v in metrics.items()}
wandb.log(metrics_float)
wandb.finish()


In [None]:
!zip -r wandb_runs.zip wandb