# HealTac 2024 Tutorial
## Instruction Tuning for Discharge Notes Summarization

- Yunsoo Kim (yunsoo.kim.23@ucl.ac.uk), Jinge Wu (jinge.wu.20@ucl.ac.uk), Honghan Wu (honghan.wu@ucl.ac.uk)

<a target="_blank" href="https://colab.research.google.com/github/knowlab/healtac_2024_tutorial/blob/main/discharge_notes_summarization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Set the Colab runtime type to T4 GPU before running the cells.

We start by installing the required packages and downloading the model.

In [1]:
# Run nvidia-smi to check the gpu resource
!nvidia-smi

Tue Jun  4 13:55:26 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P8              10W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [2]:
# First, install required packages
!pip install -q accelerate==0.25.0 peft==0.6.2 bitsandbytes==0.41.1 transformers==4.36.2 trl==0.7.4 einops gradio

In [3]:
# Import Libraries
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import gradio as gr



In [4]:
# Define Quantization Config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# Load Model and Dataset
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map="auto",
    revision="refs/pr/23"
)

tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

dataset = load_dataset("bluesky333/synthetic_discharge_summ")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
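
The arithmetic behind 4-bit loading can be sketched as follows. This is a rough estimate only: the parameter count of about 2.7 billion is an approximation, and the calculation ignores quantization constants, activations, and the CUDA context. The T4 memory figure is taken from the `nvidia-smi` output above.

```python
# Back-of-the-envelope weight memory for a ~2.7B-parameter model.
# ASSUMPTION: N_PARAMS is approximate, and only the weights are counted
# (no quantization constants, activations, or CUDA context).
N_PARAMS = 2.7e9
T4_MEMORY_MIB = 15360  # total memory reported by nvidia-smi above


def weight_memory_mib(n_params: float, bits_per_param: float) -> float:
    """Memory needed to store the weights alone, in MiB."""
    return n_params * bits_per_param / 8 / 2**20


fp16_mib = weight_memory_mib(N_PARAMS, 16)
nf4_mib = weight_memory_mib(N_PARAMS, 4)

print(f"fp16 weights:  {fp16_mib:,.0f} MiB ({fp16_mib / T4_MEMORY_MIB:.0%} of the T4)")
print(f"4-bit weights: {nf4_mib:,.0f} MiB ({nf4_mib / T4_MEMORY_MIB:.0%} of the T4)")
```

Storing weights in 4 bits takes a quarter of the fp16 footprint, which leaves more of the T4 for activations and optimizer state during fine-tuning.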


In [5]:
# First of all, let's have a look at the dataset
print(dataset['train'])
dataset['train'][0]

Dataset({
    features: ['patient_id', 'note', 'question', 'answer'],
    num_rows: 13584
})


{'patient_id': 5,
 'note': "Discharge Summary:\n\nPatient: 52-year-old male hospitalized for pneumonia and moderate ARDS\n\nHospital Course:\nThe patient was admitted to the hospital four days after the beginning of a dry cough, fever, and head and limb pain and tested COVID-19 positive. One day later, he was diagnosed with pneumonia that progressed into moderate ARDS and required mechanical ventilation and intermittent dialysis. After extubation, the patient experienced disorientation and an inability to communicate verbally due to global weakness (CPAx 11/50), accompanied by oral and pharyngeal weakness and paresthesia. Specialized physical therapy with the Gugging Swallowing Screen confirmed severe dysphagia, with the patient showing insufficient protection against aspiration. Treatment included therapy for dysphagia, such as intensive oral stimulation, facilitation of swallowing, and protection mechanism training, while receiving no food or drink by mouth. Over the next few days, t

In [6]:
# Convert the dataset into a Phi-2 compatible format
# Phi-2 instruction-answer format: "Instruct: <prompt>\nOutput:"

# Make your own prompt!
prompt_template="""Instruct: Please write down your own prompt.
For instance, you can insert the note as {{note}}
{note}
Model should answer to {{question}} based on the note.
{question}
You should maintain the phi-2 format
Accordingly, the last line must be like the below.
Do not forget to insert a new line between your prompt and 'Output'!
Output: {answer}
"""

# Receives a batch as Dict[str, List] and returns a list of formatted prompts
def format_dataset(samples):
    outputs = []
    # Look up columns by name so the result does not depend on column order
    for note, question, answer in zip(samples['note'], samples['question'], samples['answer']):
        out = prompt_template.format(note=note, question=question, answer=answer)
        outputs.append(out)
    return outputs

sample_input = format_dataset({k: [v] for k, v in dataset['train'][0].items()})[0]
print(sample_input)
print("*"*20)

# Sanity Check
prompt_len = len(tokenizer.encode(prompt_template))
if prompt_len > 180:
    raise ValueError(f"Your prompt is too long! Please reduce the length from {prompt_len} to 180 tokens")
print(f"Prompt Length: {prompt_len} tokens")

Instruct: Please write down your own prompt.
For instance, you can insert the note as {note}
Discharge Summary:

Patient: 52-year-old male hospitalized for pneumonia and moderate ARDS

Hospital Course:
The patient was admitted to the hospital four days after the beginning of a dry cough, fever, and head and limb pain and tested COVID-19 positive. One day later, he was diagnosed with pneumonia that progressed into moderate ARDS and required mechanical ventilation and intermittent dialysis. After extubation, the patient experienced disorientation and an inability to communicate verbally due to global weakness (CPAx 11/50), accompanied by oral and pharyngeal weakness and paresthesia. Specialized physical therapy with the Gugging Swallowing Screen confirmed severe dysphagia, with the patient showing insufficient protection against aspiration. Treatment included therapy for dysphagia, such as intensive oral stimulation, facilitation of swallowing, and protection mechanism training, while re
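
The printed sample shows `{note}` as literal text because the template escapes it as `{{note}}`. A minimal sketch of that `str.format` behaviour, using a shortened hypothetical template and made-up field values rather than the tutorial's own:

```python
# Hypothetical, shortened template for illustration only.
template = "Instruct: The note is inserted where you write {{note}}\n{note}\nOutput: {answer}\n"

# Doubled braces survive formatting as literal single braces;
# single-brace fields are substituted.
filled = template.format(note="Patient is stable.", answer="Stable.")
print(filled)

# For inference, leave the answer empty and strip trailing whitespace
# so that the prompt ends exactly with "Output:".
prompt_only = template.format(note="Patient is stable.", answer="").rstrip()
print(repr(prompt_only[-7:]))
```

The same stripping trick is one way to reuse a training template at inference time without leaving a dangling space or newline after `Output:`.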

In [7]:
sample_idx = 0
# Keep everything up to the answer and end the prompt with "Output:", as the Phi-2 format expects
sample_input = format_dataset({k: [v] for k, v in dataset['train'][sample_idx].items()})[0].split('Output: ')[0] + 'Output:'
input_ids = tokenizer.encode(sample_input, return_tensors='pt').to('cuda')
with torch.no_grad():
    output = model.generate(input_ids=input_ids,
                            max_length=512,
                            use_cache=True,
                            do_sample=False,  # greedy decoding
                            eos_token_id=tokenizer.eos_token_id
    )
print(tokenizer.decode(output.to('cpu')[0], skip_special_tokens=True))

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Instruct: Please write down your own prompt.
For instance, you can insert the note as {note}
Discharge Summary:

Patient: 52-year-old male hospitalized for pneumonia and moderate ARDS

Hospital Course:
The patient was admitted to the hospital four days after the beginning of a dry cough, fever, and head and limb pain and tested COVID-19 positive. One day later, he was diagnosed with pneumonia that progressed into moderate ARDS and required mechanical ventilation and intermittent dialysis. After extubation, the patient experienced disorientation and an inability to communicate verbally due to global weakness (CPAx 11/50), accompanied by oral and pharyngeal weakness and paresthesia. Specialized physical therapy with the Gugging Swallowing Screen confirmed severe dysphagia, with the patient showing insufficient protection against aspiration. Treatment included therapy for dysphagia, such as intensive oral stimulation, facilitation of swallowing, and protection mechanism training, while re

In [9]:
# Next, define the data collator and a small training subset.
response_template = "Output:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

train_dataset = dataset['train']
sampled_train_dataset = train_dataset.select(range(10))
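
A completion-only collator restricts the loss to the tokens after the response template. The sketch below illustrates that idea on plain Python lists; it is not the TRL implementation, and the token ids are hypothetical.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default


def mask_prompt_tokens(input_ids, response_ids):
    """Return labels in which everything up to and including the
    response template is replaced by IGNORE_INDEX."""
    n = len(response_ids)
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_ids:
            end = start + n
            return [IGNORE_INDEX] * end + input_ids[end:]
    # Template not found: nothing in this sequence contributes to the loss.
    return [IGNORE_INDEX] * len(input_ids)


# Hypothetical token ids: 7 and 8 stand for the tokens of "Output:".
input_ids = [1, 2, 3, 7, 8, 4, 5, 6]
labels = mask_prompt_tokens(input_ids, [7, 8])
print(labels)
```

Only the answer tokens keep their labels, so the model is trained to produce the summary and not to reproduce the note or the instruction.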

In [10]:
# SFTTrainer does everything else for you!

lora_config=LoraConfig(
    r=4,
    task_type="CAUSAL_LM",
    target_modules=["Wqkv", "fc1", "fc2"]
)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    fp16=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    optim="paged_adamw_32bit",
    save_strategy="no",
    warmup_ratio=0.03,
    logging_steps=5,
    lr_scheduler_type="cosine",
    report_to="tensorboard",
    gradient_checkpointing=True
)

trainer = SFTTrainer(
    model,
    training_args,
    train_dataset=sampled_train_dataset,
    formatting_func=format_dataset,
    data_collator=collator,
    peft_config=lora_config,
    max_seq_length=512,
    tokenizer=tokenizer,
)

You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it).Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model.


Map:   0%|          | 0/10 [00:00<?, ? examples/s]
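
To get a feel for how few weights LoRA trains with `r=4`, the count can be sketched as below. The layer shapes (hidden size 2560, MLP width 10240, 32 layers, a fused `Wqkv` projection three times the hidden size) are assumptions made for illustration and should be checked against the loaded model, for example with `model.print_trainable_parameters()` on the PEFT-wrapped model.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """LoRA adds A (r x in_features) and B (out_features x r) per adapted layer."""
    return r * (in_features + out_features)


# ASSUMPTION: illustrative shapes, not read from the actual checkpoint.
HIDDEN, INTERMEDIATE, N_LAYERS, R = 2560, 10240, 32, 4
shapes = {
    "Wqkv": (HIDDEN, 3 * HIDDEN),   # fused query/key/value projection
    "fc1": (HIDDEN, INTERMEDIATE),  # MLP up-projection
    "fc2": (INTERMEDIATE, HIDDEN),  # MLP down-projection
}

per_layer = sum(lora_param_count(i, o, R) for i, o in shapes.values())
total = per_layer * N_LAYERS
print(f"LoRA parameters per layer: {per_layer:,}")
print(f"LoRA parameters in total:  {total:,}")
```

Under these assumptions the adapters amount to a few million parameters, a small fraction of the base model, which is why the training step fits on a single T4.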

In [11]:
# Run Training
trainer.train()

You're using a CodeGenTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
For instance, you can insert the note as {note}
DISCHARGE SUMMARY:

Patient Name: X
Medical Record Number: X
Date of Admission: XX/XX/XX
Date of Discharge: XX/XX/XX

Hospital Course:
A 36-year-old premenopausal woman with a family history of colorectal, hepatobiliary cancerspresented with an abnormal right breast lump. Diagnostic mammogram and ultrasound showed a highly suggestive malignant breast mass, which was confirmed by a biopsy of the dominant lesion. The patient underwent a right breast simple mastectomy with axillary lymph node evaluation and her pathology showed a multifocal invasive mammary carcinoma with ductal and lobular features. She received adjuvant PMRT 5000 cGy dose, 25 fractions with 1000 cGy scar boost. Based on TEXT/SOFT data, ovarian supp

Step,Training Loss


TrainOutput(global_step=1, training_loss=0.4690917134284973, metrics={'train_runtime': 8.8987, 'train_samples_per_second': 1.124, 'train_steps_per_second': 0.112, 'total_flos': 81506284339200.0, 'train_loss': 0.4690917134284973, 'epoch': 1.0})
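
The `global_step=1` above, and the empty loss table, follow from the batch arithmetic. A minimal sketch, assuming a single GPU and that the trainer performs at least one optimizer update per epoch (the helper name is hypothetical):

```python
import math


def optimizer_steps(num_samples, per_device_batch_size, grad_accum_steps, epochs=1):
    """Approximate number of optimizer updates on a single device."""
    batches_per_epoch = math.ceil(num_samples / per_device_batch_size)
    # ASSUMPTION: at least one update per epoch, even when there are fewer
    # batches than gradient accumulation steps.
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * epochs


# Values from the cells above: 10 samples, batch size 2, 8 accumulation steps.
steps = optimizer_steps(num_samples=10, per_device_batch_size=2, grad_accum_steps=8)
logging_steps = 5
print(f"optimizer steps: {steps}")
print(f"loss is logged during training: {steps >= logging_steps}")
```

With only one update and `logging_steps=5`, no intermediate loss is logged. Selecting more samples or lowering `gradient_accumulation_steps` yields more updates.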

In [12]:
# Wrap-up Training
model = trainer.model
model.eval()

note_samples = train_dataset.select(range(len(train_dataset)-10, len(train_dataset)))['note']

def inference(note, question, model):
    # Leave the answer empty and strip trailing whitespace so the prompt ends with "Output:"
    prompt = prompt_template.format(note=note, question=question, answer="").rstrip()
    tokens = tokenizer.encode(prompt, return_tensors="pt").to('cuda')
    outs = model.generate(input_ids=tokens,
                          max_length=512,
                          use_cache=True,
                          do_sample=False,  # greedy decoding
                          eos_token_id=tokenizer.eos_token_id
                          )
    # Decode only the newly generated tokens, not the prompt
    new_tokens = outs[0][tokens.shape[1]:].to('cpu')
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


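def compare_models(note, question):
    with torch.no_grad():
        # Answer from the fine-tuned model (LoRA adapter active)
        tuned_answer = inference(note, question, trainer.model)
        # Answer from the base Phi-2 model (LoRA adapter temporarily disabled)
        with model.disable_adapter():
            phi_answer = inference(note, question, trainer.model)
    return tuned_answer, phi_answer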

demo = gr.Interface(fn=compare_models, inputs=[gr.Dropdown(note_samples), "text"], outputs=[gr.Textbox(label="Trained"), gr.Textbox(label="Phi-2")])
demo.launch()

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Running on public URL: https://1b4d31423d999d31a0.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)
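
The comparison works because a LoRA adapter is an additive update on top of frozen base weights, so disabling it recovers the base model. A toy sketch with plain Python lists; the matrices, the rank of 1, and the scaling value are made up for illustration and are not taken from the trained model:

```python
def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]


def lora_forward(W, A, B, x, scaling, adapter_enabled=True):
    """Compute W x, plus scaling * B (A x) when the adapter is enabled."""
    base = matvec(W, x)
    if not adapter_enabled:
        return base
    update = matvec(B, matvec(A, x))
    return [b + scaling * u for b, u in zip(base, update)]


# Toy values: a 2x2 frozen weight and a rank-1 adapter.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[0.5, 0.5]]     # shape (r, in_features)
B = [[1.0], [-1.0]]  # shape (out_features, r)
x = [2.0, 4.0]

tuned = lora_forward(W, A, B, x, scaling=2.0)
base = lora_forward(W, A, B, x, scaling=2.0, adapter_enabled=False)
print("with adapter:   ", tuned)
print("without adapter:", base)
```

Because the base weights `W` are never modified, one loaded model can serve both the "Trained" and the "Phi-2" outputs in the demo.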


