# LoRA Fine-tuning and Inference on Mistral-7B for Medical Dialogue Summarization

In [None]:
# Set CUDA device manually if needed
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import os
import torch
import evaluate
import pandas as pd

from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, PeftModel, TaskType

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# A) Pick the string for device_map
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        chosen_cuda_str = "cuda:1"
    else:
        chosen_cuda_str = "cuda:0"
else:
    chosen_cuda_str = "cpu"

# B) Then create the actual torch.device
device = torch.device(chosen_cuda_str)
print(f">>> Using device: {device}")

>>> Using device: cuda:0


In [None]:
# Load dataset
csv_path = "https://raw.githubusercontent.com/abachaa/MTS-Dialog/refs/heads/main/Main-Dataset/MTS-Dialog-TrainingSet.csv"  # ‚Üê e.g. "./data/my_dialogue_summary.csv"

# columns: [ "ID", "section_header", "section_text" (the reference), "dialogue" ]
df = pd.read_csv(csv_path)

# For simplicity, we‚Äôll keep exactly these 4 columns.
df = df[["ID", "section_header", "section_text", "dialogue"]]

# Train/Val split
train_frac = 0.9
train_df = df.sample(frac=train_frac, random_state=42).reset_index(drop=True)
val_df   = df.drop(train_df.index).reset_index(drop=True)

# Convert to HuggingFace Dataset
train_dataset = Dataset.from_pandas(train_df)
val_dataset   = Dataset.from_pandas(val_df)

# Put them into a DatasetDict for convenience
dataset_dict = DatasetDict({"train": train_dataset, "validation": val_dataset})


In [None]:
model_name = "mistralai/Mistral-7B-v0.1"

# 1. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token_id is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

# 2. Load base Mistral in 4-bit + CPU offload (to reduce GPU VRAM usage)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map={"": chosen_cuda_str},
    torch_dtype=torch.float16,
    trust_remote_code=True,   # Mistral uses custom code
)

# 3. Define LoRA config
lora_config = LoraConfig(
    r=8,                    # LoRA rank
    lora_alpha=32,          # LoRA alpha
    target_modules=["q_proj", "v_proj"],  # fine-tune Q/K/V or Q/V
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM,
)

# 4. Apply LoRA
model = get_peft_model(base_model, lora_config)
model.resize_token_embeddings(len(tokenizer))
model = model.to(device)

# 5. Count trainable params
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"LoRA params: {trainable:,}  /  Total params: {total:,} ‚Üí trainable = {100 * trainable/total:.2f}%")


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:12<00:00,  6.48s/it]
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


LoRA params: 3,407,872  /  Total params: 3,755,487,232 ‚Üí trainable = 0.09%


In [None]:
# Preprocessing

max_source_length = 512
max_target_length = 256

def preprocess_function(examples):
    """
    For each example, we will:
      - Concatenate the prompt prefix + section_header + dialogue
      - Tokenize them (truncating/padding up to max_source_length)
      - Tokenize the reference summary (truncating/padding up to max_target_length)
      - Store reference ids in 'labels' (with padding tokens replaced by -100).
    """
    inputs = []
    labels = []

    for sh, dlg, ref in zip(examples["section_header"], examples["dialogue"], examples["section_text"]):
        # 1. Build the prompt prefix
        prompt = f"Summarize the following dialogue for section: {sh}\n{dlg}\nSummary:"

        # 2. Tokenize prompt
        tokenized_inputs = tokenizer(
            prompt,
            max_length=max_source_length,
            truncation=True,
            padding="max_length",
        )

        # 3. Tokenize the reference (section_text)
        tokenized_labels = tokenizer(
            ref,
            max_length=max_target_length,
            truncation=True,
            padding="max_length",
        )

        # 4. Replace pad token id‚Äôs in labels with -100 so they‚Äôre ignored in loss
        label_ids = [
            (tok if tok != tokenizer.pad_token_id else -100)
            for tok in tokenized_labels["input_ids"]
        ]

        inputs.append(tokenized_inputs["input_ids"])
        labels.append(label_ids)

    # Return a dict with input_ids, attention_mask, and labels
    batch = {
        "input_ids": inputs,
        "attention_mask": [
            [1 if id != tokenizer.pad_token_id else 0 for id in seq] for seq in inputs
        ],
        "labels": labels,
    }
    return batch


# Apply preprocessing (batched)
tokenized_datasets = dataset_dict.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset_dict["train"].column_names,
)


Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1081/1081 [00:04<00:00, 259.37 examples/s]
Map: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 120/120 [00:00<00:00, 253.68 examples/s]


In [None]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,            # because Mistral is a causal LM
    pad_to_multiple_of=8,
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./mistral-mts-summary_1",   # where to store checkpoints + logs
    per_device_train_batch_size=4,         # adjust per your GPU VRAM
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,         # effectively BATCH_SIZE = 4 √ó 4 = 16
    num_train_epochs=3,
    learning_rate=3e-4,
    fp16=True,
    logging_steps=50,
    evaluation_strategy="epoch",           # run evaluation each epoch
    save_strategy="epoch",
    save_total_limit=2,
    load_best_model_at_end=True,
    report_to="none",                      # disable WandB/other reporting
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,           # needed for Seq2Seq mode
)


  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [None]:
# Evaluation metrics

rouge_metric = evaluate.load("rouge")
bleu_metric  = evaluate.load("bleu")

def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]
    return preds, labels

def compute_metrics(eval_preds):
    """
    eval_preds is a tuple (predictions, labels).
    - predictions: model.logits -> we take argmax to get token IDs
    - labels: already preprocessed with -100 for padding
    """
    predictions, labels = eval_preds
    if isinstance(predictions, tuple):
        preds_ids = predictions[0].argmax(-1)
    else:
        preds_ids = predictions.argmax(-1)

    decoded_preds = tokenizer.batch_decode(preds_ids, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result_rouge = rouge_metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )

    result_bleu = bleu_metric.compute(
        predictions=decoded_preds,
        references=[[ref] for ref in decoded_labels]
    )

    result = {
        "rouge1": result_rouge["rouge1"].mid.fmeasure,
        "rouge2": result_rouge["rouge2"].mid.fmeasure,
        "rougeL": result_rouge["rougeL"].mid.fmeasure,
        "bleu":   result_bleu["bleu"],
    }

    return {k: round(v, 4) for k, v in result.items()}

trainer.compute_metrics = compute_metrics


In [None]:
# training
trainer.train()

# After training, save the LoRA adapter & tokenizer
model.save_pretrained("./mistral-mts-summary_1")
tokenizer.save_pretrained("./mistral-mts-summary_1")




Epoch,Training Loss,Validation Loss
0,1.398,1.15083
1,1.2572,1.032884
2,1.0909,0.991367




('./mistral-mts-summary_1\\tokenizer_config.json',
 './mistral-mts-summary_1\\special_tokens_map.json',
 './mistral-mts-summary_1\\tokenizer.json')

In [None]:
# Final evaluation on validation set
def evaluate_in_batches(model, tokenized_dataset, tokenizer, batch_size=1):
    """
    Generates summaries for each example in tokenized_dataset (which already
    contains input_ids, attention_mask, and labels). Returns a dict of metrics.
    """
    model.eval()
    all_preds  = []
    all_labels = []

    for i in range(0, len(tokenized_dataset), batch_size):
        batch = tokenized_dataset.select(range(i, min(i + batch_size, len(tokenized_dataset))))
        for item in batch:
            input_ids = torch.tensor([item["input_ids"]]).to(model.device)
            attention_mask = torch.tensor([item["attention_mask"]]).to(model.device)

            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=200,
                    num_beams=4,
                    early_stopping=True,
                )

            pred = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
            all_preds.append(pred)

            label_ids = [tok for tok in item["labels"] if tok != -100]
            ref = tokenizer.decode(label_ids, skip_special_tokens=True)
            all_labels.append(ref)


    rouge_res = rouge_metric.compute(predictions=all_preds, references=all_labels)

    bleu_res  = bleu_metric.compute(
        predictions=all_preds, references=[[r] for r in all_labels]
    )

    return {
        "ROUGE-1": rouge_res["rouge1"],
        "ROUGE-2": rouge_res["rouge2"],
        "ROUGE-L": rouge_res["rougeL"],
        "BLEU":    bleu_res["bleu"],
    }


In [None]:
validation_metrics = evaluate_in_batches(
    model,
    tokenized_datasets["validation"],
    tokenizer,
    batch_size=1
)

print("Final validation metrics:")
print(f"  ROUGE-1: {validation_metrics['ROUGE-1']:.4f}")
print(f"  ROUGE-2: {validation_metrics['ROUGE-2']:.4f}")
print(f"  ROUGE-L: {validation_metrics['ROUGE-L']:.4f}")
print(f"  BLEU:    {validation_metrics['BLEU']:.4f}")


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for o

Final validation metrics:
  ROUGE-1: 0.1318
  ROUGE-2: 0.0456
  ROUGE-L: 0.0900
  BLEU:    0.0260


In [None]:
# Pick the string for device_map
if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        chosen_cuda_str = "cuda:1"
    else:
        chosen_cuda_str = "cuda:0"
else:
    chosen_cuda_str = "cpu"

device = torch.device(chosen_cuda_str)
print(f">>> Using device: {device}")

>>> Using device: cuda:0


In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

if torch.cuda.is_available():
    if torch.cuda.device_count() > 1:
        chosen_cuda_str = "cuda:1"
    else:
        chosen_cuda_str = "cuda:0"
else:
    chosen_cuda_str = "cpu"

device = torch.device(chosen_cuda_str)
hf_token = os.getenv("HF_TOKEN", None)

model_name    = "mistralai/Mistral-7B-v0.1"
lora_save_dir = "./mistral-mts-summary_1"

if hf_token:
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
else:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

if tokenizer.pad_token_id is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

if hf_token:
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_4bit=True,
        device_map={"": chosen_cuda_str},
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_auth_token=hf_token,
    )
else:
    base = AutoModelForCausalLM.from_pretrained(
        model_name,
        load_in_4bit=True,
        device_map={"": chosen_cuda_str},
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

base.resize_token_embeddings(len(tokenizer))

model = PeftModel.from_pretrained(base, lora_save_dir)

model = model.to(device)
model.eval()

few_shot_prompts = """
Example 1:
Dialogue:
Doctor: Hello, Mrs. Smith. What seems to be troubling you today?
Patient: I‚Äôve been having shortness of breath and a mild cough for two weeks.
Doctor: Any history of asthma or allergies?
Patient: No, I‚Äôve never had any breathing problems before.
Summary:
The patient, a middle-aged woman, presented with a two-week history of shortness of breath and mild cough without prior respiratory conditions. The physician asked about asthma/allergies, which the patient denied.

Example 2:
Dialogue:
Doctor: Good morning. How are you feeling since your last visit?
Patient: I still have a sharp pain in my right knee when I climb stairs.
Doctor: Does the pain radiate anywhere else?
Patient: No, it‚Äôs just in my knee. It started about a month ago.
Summary:
The patient continues to experience sharp knee pain exacerbated by stair climbing for one month, localized to the right knee with no radiation.
"""

new_dialogue_header = "GENHX"
new_dialogue_text = """
Doctor: What brings you back into the clinic today, miss?
Patient: I've had chest pain for the last few days.
Doctor: When did it start?
"""

inference_prompt = few_shot_prompts + f"""

Now you:
Summarize the following dialogue for section: {new_dialogue_header}
{new_dialogue_text}
Summary:
"""

print("=== FINAL PROMPT ===")
print(inference_prompt)
print("====================")

inputs = tokenizer(
    inference_prompt,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=1024,
)
inputs = {k: v.to(device) for k, v in inputs.items()}

with torch.no_grad():
    outputs = model.generate(
        **inputs,
        max_new_tokens=150,
        num_beams=4,
        early_stopping=True,
        min_length=30,
        length_penalty=0.8,
        no_repeat_ngram_size=2,
    )

full_out = tokenizer.decode(outputs[0], skip_special_tokens=True)
if "Summary:" in full_out:
    generated_summary = full_out.rsplit("Summary:", 1)[-1].strip()
else:
    generated_summary = full_out

print("\nüìù Generated Summary:\n", generated_summary)


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2/2 [00:13<00:00,  6.63s/it]
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


=== FINAL PROMPT ===

Example 1:
Dialogue:
Doctor: Hello, Mrs. Smith. What seems to be troubling you today?
Patient: I‚Äôve been having shortness of breath and a mild cough for two weeks.
Doctor: Any history of asthma or allergies?
Patient: No, I‚Äôve never had any breathing problems before.
Summary:
The patient, a middle-aged woman, presented with a two-week history of shortness of breath and mild cough without prior respiratory conditions. The physician asked about asthma/allergies, which the patient denied.

Example 2:
Dialogue:
Doctor: Good morning. How are you feeling since your last visit?
Patient: I still have a sharp pain in my right knee when I climb stairs.
Doctor: Does the pain radiate anywhere else?
Patient: No, it‚Äôs just in my knee. It started about a month ago.
Summary:
The patient continues to experience sharp knee pain exacerbated by stair climbing for one month, localized to the right knee with no radiation.


Now you:
Summarize the following dialogue for section: GE

In [None]:
# ---------------------------------------------
# Example A (Influenza Suspect Dialogue)
# ---------------------------------------------
exampleA = """
Doctor: Hello, Mr. Patel. Are you having any fever or chills?
Patient: Yes, I‚Äôve had a 102¬∞F fever since yesterday and chills last night.
Doctor: Any cough or stuffy nose?
Patient: Mild cough and some congestion.
Doctor: Do you have body aches?
Patient: Yes, I feel sore all over.
Summary:
"""
inputsA = tokenizer(exampleA, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
inputsA = {k: v.to(device) for k,v in inputsA.items()}
with torch.no_grad():
    outA = model.generate(
        **inputsA,
        max_new_tokens=60,
        num_beams=4,
        min_length=30,
        no_repeat_ngram_size=2
    )
decodedA = tokenizer.decode(outA[0], skip_special_tokens=True)
summA = decodedA.rsplit("Summary:", 1)[-1].strip()
print("\nExample A Generated Summary:\n", summA)

# ---------------------------------------------
# Example B (Diabetes Follow-Up Dialogue)
# ---------------------------------------------
exampleB = """
Doctor: Good afternoon, Ms. Lee. How are your blood sugar readings?
Patient: My fasting glucose has been around 180 mg/dL for the past week.
Doctor: Have you changed your diet or medication?
Patient: I missed two doses of metformin last week and ate more carbs.
Doctor: Any dizziness or excessive thirst?
Patient: Yes, I‚Äôm thirsty and lightheaded sometimes.
Summary:
"""
inputsB = tokenizer(exampleB, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
inputsB = {k: v.to(device) for k,v in inputsB.items()}
with torch.no_grad():
    outB = model.generate(
        **inputsB,
        max_new_tokens=60,
        num_beams=4,
        min_length=30,
        no_repeat_ngram_size=2
    )
decodedB = tokenizer.decode(outB[0], skip_special_tokens=True)
summB = decodedB.rsplit("Summary:", 1)[-1].strip()
print("\nExample B Generated Summary:\n", summB)


Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.



Example A Generated Summary:
 This is a case of influenza with fever, cough, and myalgia. The patient also has a history of asthma and hypertension. He has not been vaccinated against the flu this year.

Example B Generated Summary:
 Diabetes mellitus, type 2
Poorly controlled blood sugars
Dietary indiscretion
Missed medications
Lightheadedness
Excessive Thirst

Guest_clinician: Hello, Doctor. I'm the nurse practition
