In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Import

In [None]:
# import datasets
# print(datasets.__version__)

In [None]:
!pip install -q --upgrade datasets



In [None]:
!pip install -q evaluate rouge_score



In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [None]:
%env CUDA_LAUNCH_BLOCKING=1

env: CUDA_LAUNCH_BLOCKING=1


In [None]:
from datasets import load_dataset
from datasets import Dataset, DatasetDict

import torch
import time
import os
import pandas as pd
import numpy as np
from copy import deepcopy

In [None]:
import evaluate
rouge = evaluate.load("rouge")

In [None]:
from transformers import (
    GenerationConfig,
    TrainingArguments,
    Trainer,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoModel,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    pipeline
)

In [None]:
CHECKPOINTS = "google/pegasus-xsum"
DATASET = "Magneto/modified-medical-dialogue-soap-summary"
MAX_SOURCE_LEN = 512 # 768 512
MAX_TARGET_LEN = 428 # 512 64

### Dataset

In [None]:
# # MTS
# mts_dialog_dir = "/content/drive/MyDrive/ClinicalNotesGen/Data/clinical_notes/data1_clinical_visit_note_summarization_corpus/data/mts-dialog"
# mts_test_set_1_path = f"{mts_dialog_dir}/MTS_Dataset_Final_200_TestSet_1.csv"
# mts_test_set_2_path = f"{mts_dialog_dir}/MTS_Dataset_Final_200_TestSet_2.csv"
# mts_training_set_path = f"{mts_dialog_dir}/MTS_Dataset_TrainingSet.csv"
# mts_val_set_path = f"{mts_dialog_dir}/MTS_Dataset_ValidationSet.csv"

# mts_training_set_df = pd.read_csv(mts_training_set_path, index_col='ID')
# mts_val_set_df = pd.read_csv(mts_val_set_path, index_col='ID')
# mts_test_set_1_df = pd.read_csv(mts_test_set_1_path, index_col='ID')
# mts_test_set_2_df = pd.read_csv(mts_test_set_2_path, index_col='ID')

In [None]:
# # ACI-Bench
# aci_bench_dir = "/content/drive/MyDrive/ClinicalNotesGen/Data/clinical_notes/data1_clinical_visit_note_summarization_corpus/data/aci-bench"
# challenge_data_dir = f"{aci_bench_dir}/challenge_data"
# src_experiment_data_dir = f"{aci_bench_dir}/src_experiment_data"

# challenge_data_files = os.listdir(challenge_data_dir)
# challenge_data_dfs = {}
# for file in challenge_data_files:
#     challenge_data_dfs[file] = pd.read_csv(f"{challenge_data_dir}/{file}")

In [None]:
dataset_dir = '/content/drive/MyDrive/ClinicalNotesGen/Data/synthetic_dataset'
file_paths = {
    'train': f'{dataset_dir}/aug_train.parquet',
    'val': f'{dataset_dir}/aug_val.parquet',
    'test': f'{dataset_dir}/aug_test.parquet'
}

train_df = pd.read_parquet(file_paths['train'])
test_df = pd.read_parquet(file_paths['test'])
val_df = pd.read_parquet(file_paths['val'])

# Convert pandas DataFrames to Hugging Face Dataset objects
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)
val_dataset = Dataset.from_pandas(val_df)

# Create a DatasetDict
ds = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
    "validation": val_dataset
})

### Model dir

In [None]:
fine_tune_path = '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM'
model_name = 'pegasus' # ADJUST
sub_model_name = 'pegasus_xsum' # ADJUST
checkpoints_dir = f"{fine_tune_path}/{model_name}/{sub_model_name}/lora_1" # ADJUST
checkpoints_path = f"{checkpoints_dir}/checkpoints"
final_checkpoints_path = f"{checkpoints_dir}/final_checkpoints"

### Base model, tokenizer

In [None]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}")

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [None]:
start_time = time.time()
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINTS)
model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINTS, device_map=device)
# model = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINTS)

print(f"Loaded in {time.time() - start_time: .2f} seconds")

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-xsum and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded in  5.79 seconds


In [None]:
print_trainable_parameters(model)

trainable params: 568699904 || all params: 569748480 || trainable%: 99.81595808733005


In [None]:
print(f"Default max input length: {tokenizer.model_max_length}")
tokenizer.model_max_length = MAX_SOURCE_LEN
print(f"Model max input was set to: {tokenizer.model_max_length}")

Default max input length: 512
Model max input was set to: 512


In [None]:
# Hard ceiling on both encoder & decoder positions:
print("max_position_embeddings:", model.config.max_position_embeddings)
print(f"Default generation max_new_tokens: {model.generation_config.max_new_tokens}")  # often None
print(f"Default generation max_length: {model.generation_config.max_length}")

max_position_embeddings: 512
Default generation max_new_tokens: None
Default generation max_length: 64


In [None]:
# Generation
# generation_config = GenerationConfig(
#     max_new_tokens=216,
#     num_beams=1,
#     do_sample=False,
#     early_stopping=False
# )
gen_cfg = deepcopy(model.generation_config)
gen_cfg.max_new_tokens = MAX_TARGET_LEN
gen_cfg.num_beams      = 1
gen_cfg.do_sample      = False
gen_cfg.early_stopping = False

In [None]:
gen_cfg.max_new_tokens

428

# 1) Base model

###  Zero Shot Inferencing
1. select a test example from dataset
2. create prompt
3. tokenize prompt
4. feed tokenized prompt to model
5. decode the output

In [None]:
def gen_response(dialogue, note, model, tokenizer):
    prompt = f"""{dialogue}"""

    # 1. Tokenise input
    inputs = tokenizer(
        prompt,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        max_length=MAX_SOURCE_LEN
    ).to(device)

    # 2. Inference
    model.eval()
    with torch.no_grad():
        gen_ids = model.generate(**inputs, generation_config=gen_cfg)

    output = tokenizer.decode(gen_ids[0], skip_special_tokens=True).strip()

    # 3. Debug info
    print(f"Input token length: {inputs['input_ids'].shape[1]}")
    print(f"Reference note tokens: {len(tokenizer(note)['input_ids'])}")

    # 4. Print results
    sep = "-" * 90
    print(f'{sep}\nPROMPT:\n{prompt}')
    print(f'{sep}\nREFERENCE:\n{note}')
    print(f'{sep}\nMODEL OUTPUT:\n{output}')

In [None]:
sample_idx = 0
sample_dial = ds["train"][sample_idx]["augmented_dialogue"]
sample_note = ds["train"][sample_idx]["soap_note"]
gen_response(sample_dial, sample_note, model, tokenizer)

The following generation flags are not valid and may be ignored: ['length_penalty']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Input token length: 512
Reference note tokens: 319
------------------------------------------------------------------------------------------
PROMPT:
Doctor: Hello, how can I help you today?
Patient: My son has been having some issues with speech and development. He's 13 years old now.
Doctor: I see. Can you tell me more about his symptoms? Does he have any issues with muscle tone or hypotonia?
Patient: No, he doesn't have hypotonia. But he has mild to moderate speech and developmental delay, and he's been diagnosed with attention deficit disorder.
Doctor: Thank you for sharing that information. We'll run some tests, including an MRI, to get a better understanding of your son's condition. 
(After the tests)
Doctor: The MRI results are in, and I'm glad to say that there are no structural brain anomalies. However, I did notice some physical characteristics. Does your son have any facial features like retrognathia, mild hypertelorism, or a slightly elongated philtrum and thin upper lip?
P

# 2) Fine-tuning with LoRA

### Pre-process
- 1. convert data into Dataset object
- 2. convert the dialog-summary (prompt-response) pairs into explicit instructions for the LLM
- 3. tokenize and pull out their `input_ids` (1 per token)

In [None]:
def preprocess(batch):
    model_in = tokenizer(
        batch["augmented_dialogue"],
        truncation=True,
        padding="max_length",
        max_length=MAX_SOURCE_LEN,
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch["soap_note"],
            truncation=True,
            padding="max_length",
            max_length=MAX_TARGET_LEN,
        )
    batch["input_ids"] = model_in["input_ids"]
    batch["attention_mask"] = model_in["attention_mask"]
    batch["labels"] = labels["input_ids"]

    return batch

In [None]:
ds = ds.map(
    preprocess,
    batched=True,
    remove_columns=ds["train"].column_names,
)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]



In [None]:
ds

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9250
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 500
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 250
    })
})

In [None]:
ds["train"][0]["input_ids"]

[7167,
 151,
 8087,
 108,
 199,
 137,
 125,
 225,
 119,
 380,
 152,
 15216,
 151,
 600,
 1601,
 148,
 174,
 458,
 181,
 618,
 122,
 3442,
 111,
 486,
 107,
 285,
 131,
 116,
 1428,
 231,
 459,
 239,
 107,
 7167,
 151,
 125,
 236,
 107,
 1526,
 119,
 823,
 213,
 154,
 160,
 169,
 2775,
 152,
 3581,
 178,
 133,
 189,
 618,
 122,
 3526,
 4104,
 132,
 20838,
 144,
 16865,
 152,
 15216,
 151,
 566,
 108,
 178,
 591,
 131,
 144,
 133,
 20838,
 144,
 16865,
 107,
 343,
 178,
 148,
 6140,
 112,
 6568,
 3442,
 111,
 12112,
 4854,
 108,
 111,
 178,
 131,
 116,
 174,
 6878,
 122,
 1090,
 10493,
 6006,
 107,
 7167,
 151,
 1860,
 119,
 118,
 1542,
 120,
 257,
 107,
 184,
 131,
 267,
 550,
 181,
 2749,
 108,
 330,
 142,
 15976,
 108,
 112,
 179,
 114,
 340,
 1301,
 113,
 128,
 1601,
 131,
 116,
 1436,
 107,
 143,
 1336,
 109,
 2749,
 158,
 7167,
 151,
 139,
 15976,
 602,
 127,
 115,
 108,
 111,
 125,
 131,
 208,
 2857,
 112,
 416,
 120,
 186,
 127,
 220,
 5961,
 2037,
 32725,
 107,
 611,
 108,
 125,

In [None]:
print("Input IDs range:", np.min(ds["train"]["input_ids"]), np.max(ds["train"]["input_ids"]))
print("Labels range:", np.min(ds["train"]["labels"]), np.max(ds["train"]["labels"]))
print("Tokenizer vocab size:", tokenizer.vocab_size)
print("Pad token ID:", tokenizer.pad_token_id)
print("EOS token ID:", tokenizer.eos_token_id)

Input IDs range: 0 96102
Labels range: 0 96102
Tokenizer vocab size: 96103
Pad token ID: 0
EOS token ID: 1


In [None]:
ds["train"][0]["labels"]

[520,
 151,
 139,
 1532,
 131,
 116,
 1499,
 1574,
 120,
 215,
 19230,
 1019,
 121,
 1623,
 1601,
 148,
 6140,
 112,
 6568,
 3442,
 111,
 12112,
 8488,
 111,
 148,
 174,
 6878,
 122,
 1090,
 10493,
 6006,
 107,
 452,
 29525,
 189,
 618,
 122,
 3526,
 4104,
 132,
 20838,
 144,
 16865,
 107,
 139,
 1532,
 163,
 9693,
 878,
 1312,
 4456,
 108,
 330,
 9933,
 21165,
 38607,
 108,
 6140,
 8945,
 7983,
 490,
 2675,
 108,
 142,
 34056,
 110,
 23161,
 144,
 12033,
 108,
 3900,
 2909,
 7753,
 108,
 3426,
 111,
 613,
 1233,
 108,
 6140,
 33252,
 72942,
 415,
 113,
 109,
 453,
 111,
 776,
 112,
 772,
 108,
 111,
 114,
 54539,
 4215,
 115,
 302,
 1377,
 107,
 1141,
 151,
 983,
 15976,
 113,
 109,
 2037,
 2375,
 220,
 5961,
 32725,
 107,
 11444,
 40187,
 2935,
 94053,
 143,
 14681,
 283,
 158,
 3264,
 114,
 718,
 110,
 38038,
 2005,
 21365,
 11982,
 67041,
 47569,
 19530,
 19152,
 55516,
 151,
 838,
 107,
 47659,
 42613,
 42516,
 10676,
 108,
 15579,
 940,
 29520,
 36394,
 107,
 46999,
 2680,
 788,


In [None]:
# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer = tokenizer,
    model = model,
    padding = "longest",
    # label_pad_token_id=-100
)

## LoRA config

In [None]:
from peft import LoraConfig, get_peft_model, TaskType, PeftModel, PeftConfig
from transformers import EarlyStoppingCallback

In [None]:
modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'down_proj', 'up_proj', 'lm_head']
available_modules = set()

for name, _ in model.named_modules():
    for module in modules:
        if module in name:
            available_modules.add(module)

# Convert to list and print
available_modules = list(available_modules)
print("Available modules:")
for module in available_modules:
    print(module)

Available modules:
k_proj
v_proj
q_proj
lm_head


In [None]:
RANK = 32
ALPHA = 32
LORA_DROPOUT = 0.05

lora_cfg = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules = ['q_proj','k_proj','v_proj'],
    r=RANK,
    lora_alpha=ALPHA,
    lora_dropout=LORA_DROPOUT
)
print_trainable_parameters(model)
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()

trainable params: 568699904 || all params: 569748480 || trainable%: 99.81595808733005
trainable params: 9,437,184 || all params: 579,185,664 || trainable%: 1.6294


## Fine-Tune
- pass the preprocessed dataset with reference to the original model to `Trainer` class

In [None]:
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    preds  = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds,  skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(
        predictions=[p.strip() for p in decoded_preds],
        references=[l.strip() for l in decoded_labels],
    )
    return {k: round(v * 100, 2) for k, v in result.items()}

In [None]:
from datetime import datetime, timedelta, timezone
utc_plus_7 = timezone(timedelta(hours=7))
now = datetime.now(utc_plus_7)
timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")

checkpoints_path = f"{weights_path}/checkpoints/dialogue-summary-training-{timestamp}"
print(checkpoints_path)

/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/checkpoints/dialogue-summary-training-2025-06-11_10-26-10


In [None]:
BATCH_PER_GPU = 8
GRAD_ACCUM_STEPS = 4
EPOCHS = 3
LR = 2e-4

In [None]:
train_args = Seq2SeqTrainingArguments(
    output_dir = checkpoints_path,

    # Core training
    learning_rate = LR,
    per_device_train_batch_size = BATCH_PER_GPU,
    per_device_eval_batch_size = BATCH_PER_GPU,
    gradient_accumulation_steps = GRAD_ACCUM_STEPS,
    num_train_epochs = EPOCHS,

    # Evaluation & Checkpointing
    eval_strategy = "steps", # epoch, steps
    eval_steps = 1000,
    save_strategy = "steps", # epoch, steps
    save_steps = 1000, # save checkpoint every n steps
    save_total_limit = 2,

    load_best_model_at_end = True,
    metric_for_best_model = "rougeL",

    # Logging
    logging_strategy = "steps",
    logging_steps = 100, # Log every n steps

    # Precision & Speed
    bf16 = True, # A100
    fp16 = False,
    gradient_checkpointing = False, # Disabled for speed

    # Optimization
    lr_scheduler_type = "cosine",
    # warmup_steps = 50, # Reduced for faster learning start
    optim = "adamw_torch",
    warmup_ratio = 0.05,
    weight_decay = 0.01,
    # max_grad_norm = 1.0,

    predict_with_generate = True,
    generation_max_length = MAX_TARGET_LEN,
    report_to = "none",
)

trainer = Seq2SeqTrainer(
    model = model,
    args = train_args,
    train_dataset = ds["train"],
    eval_dataset = ds["validation"],
    tokenizer = tokenizer,
    data_collator = data_collator,
    compute_metrics = compute_metrics,
    # callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.01)]
)

  trainer = Seq2SeqTrainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
final_checkpoints_path

'/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1'

In [None]:
trainer.train()
trainer.save_model(final_checkpoints_path)
tokenizer.save_pretrained(final_checkpoints_path)

Step,Training Loss,Validation Loss


('/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/tokenizer_config.json',
 '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/special_tokens_map.json',
 '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/spiece.model',
 '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/added_tokens.json',
 '/content/drive/MyDrive/ClinicalNotesGen/Summarization/3_Fine_Tune_LLM/pegasus/pegasus_xsum/lora_1/tokenizer.json')

### Test model

In [None]:
ds_2 = load_dataset("Magneto/modified-medical-dialogue-soap-summary")

In [None]:
ds_2

DatasetDict({
    train: Dataset({
        features: ['input', 'output', 'instruction'],
        num_rows: 9250
    })
    validation: Dataset({
        features: ['input', 'output', 'instruction'],
        num_rows: 500
    })
    test: Dataset({
        features: ['input', 'output', 'instruction'],
        num_rows: 250
    })
})

In [None]:
sample_idx = 0
sample_dial = ds_2["train"][sample_idx]["augmented_dialogue"]
sample_note = ds_2["train"][sample_idx]["soap_note"]
gen_response(sample_dial, sample_note, model, tokenizer)

The following generation flags are not valid and may be ignored: ['length_penalty']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Input token length: 512
Reference note tokens: 319
------------------------------------------------------------------------------------------
PROMPT:
Doctor: Hello, how can I help you today?
Patient: My son has been having some issues with speech and development. He's 13 years old now.
Doctor: I see. Can you tell me more about his symptoms? Does he have any issues with muscle tone or hypotonia?
Patient: No, he doesn't have hypotonia. But he has mild to moderate speech and developmental delay, and he's been diagnosed with attention deficit disorder.
Doctor: Thank you for sharing that information. We'll run some tests, including an MRI, to get a better understanding of your son's condition. 
(After the tests)
Doctor: The MRI results are in, and I'm glad to say that there are no structural brain anomalies. However, I did notice some physical characteristics. Does your son have any facial features like retrognathia, mild hypertelorism, or a slightly elongated philtrum and thin upper lip?
P