# HealTac 2024 Tutorial
## Instruction Tuning for Discharge Notes Summarization

- Yunsoo Kim (yunsoo.kim.23@ucl.ac.uk), Jinge Wu (jinge.wu.20@ucl.ac.uk), Honghan Wu (honghan.wu@ucl.ac.uk)

<a target="_blank" href="https://colab.research.google.com/github/knowlab/healtac_2024_tutorial/blob/main/discharge_notes_summarization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Set the runtime type to T4 GPU before running the notebook.  

We start by installing the packages and downloading the model, because both steps take some time.

In [None]:
# Run nvidia-smi to check the GPU resource attached to this runtime
# (the tutorial expects a T4; if the command fails, no GPU is attached).
!nvidia-smi

In [None]:
# Install the required packages
!pip install -q accelerate==0.25.0 peft==0.6.2 bitsandbytes==0.41.1 transformers==4.36.2 trl==0.7.4 einops gradio

In [None]:
# Import packages
import torch
from datasets import load_dataset # loading the dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig # for LoRA
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM # Trainer and DataCollator
import gradio as gr # for deployment

In [None]:
# Define Quantization Config
# Weights are loaded in 4-bit NF4 and computation runs in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# Load Model
# We will use Phi-2
# NOTE(review): the revision is pinned to a specific PR ref; presumably so the
# layer names match the LoRA target_modules used later -- confirm before bumping.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    trust_remote_code=True,
    quantization_config=quantization_config,
    device_map="auto",
    revision="refs/pr/23"
)

# Load Tokenizer
# No dedicated pad token is configured here, so EOS is reused for padding.
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
tokenizer.pad_token = tokenizer.eos_token
# Fixed typo: the attribute is `padding_side` (the original `padding_sight`
# silently created an unused attribute and never configured padding).
tokenizer.padding_side = "right"

# Load Dataset
dataset = load_dataset("bluesky333/synthetic_discharge_summ")

PEFT Markup Description

LoRA

Quantization

In [None]:
# Peek at the training split: first its summary, then the first record
print(dataset["train"])
dataset["train"][0]

In [None]:
# Same for the test split: summary first, then the first record
print(dataset["test"])
dataset["test"][0]

In [None]:
# We convert this dataset into a Phi-2 compatible format.
# Phi-2 instruction-answer format: "Instruct: <prompt>\nOutput:"

# Make your own prompt!
# (Exercise placeholder: it is overwritten by the working template below.)
prompt_template="""Instruct: Please write down your own prompt.
For instance, you can insert the note as {{note}}
{note}
Model should answer to {{question}} based on the note.
{question}
You should maintain the phi-2 format
Accordingly, the last line must be like the below.
Do not forget to insert a new line between your prompt and 'Output'!
Output: {answer}
"""

# Working template used for training and inference.
# It must contain {note}, {question} and {answer}: str.format silently ignores
# unused keyword arguments, so a template without {question} would drop the
# question from every prompt without raising any error.
prompt_template="""Instruct: Answer the question about the following clinical note.
Note: {note}
Question: {question}
Output: {answer}
"""


# Takes a Dict[List] batch as input and returns a list of prompts
def format_dataset(samples):
    """Render every example of a column-oriented batch with `prompt_template`.

    The columns are consumed by position, not by name: the batch must hold
    exactly four columns, the first of which is ignored and the remaining
    three are used as note, question and answer, in that order.

    Returns one formatted prompt string per example.
    """
    return [
        prompt_template.format(note=note, question=question, answer=answer)
        for _, note, question, answer in zip(*samples.values())
    ]

# Render the first training example so the final prompt can be inspected by eye
sample_input = format_dataset(
    {key: [value] for key, value in dataset['train'][0].items()}
)[0]
print(sample_input)
print("*" * 20)

# Sanity Check: the bare template (placeholders included) must stay within 180 tokens
prompt_len = len(tokenizer.encode(prompt_template))
if prompt_len <= 180:
    print(f"Prompt Length: {prompt_len} tokens")
else:
    raise ValueError(f"Your prompt is too long! Please reduce the length from {prompt_len} to 180 tokens")

In [None]:
# Baseline: let the model answer one training example before any fine-tuning.
sample_idx = 0

# Build the prompt without the reference answer. The answer column is blanked
# (last column, the same positional convention format_dataset relies on)
# instead of splitting the rendered text on 'Output: ': splitting would
# truncate the prompt whenever the note itself contains that string
# (e.g. "Urine Output: ...") and would also remove the 'Output:' cue that the
# Phi-2 format expects at the end of the prompt.
sample = dataset['train'][sample_idx]
sample_batch = {k: [v] for k, v in sample.items()}
sample_batch[list(sample)[-1]] = [""]
# rstrip() removes the whitespace left behind by the empty answer, so the
# prompt ends exactly with 'Output:'.
sample_input = format_dataset(sample_batch)[0].rstrip()

input_ids = tokenizer.encode(sample_input, return_tensors='pt').to('cuda')
with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        max_length=512,
        use_cache=True,
        # Greedy decoding. `temperature` only applies when sampling, and 0 is
        # not a valid temperature, so request deterministic output explicitly.
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output.to('cpu')[0], skip_special_tokens=True))

In [None]:
# Then, let's define the dataset and the collator.
train_dataset = dataset['train']

# Completion-only collator keyed on the "Output:" marker of the prompt format.
# Per the TRL docs, tokens before the response template are excluded from the loss.
response_template = "Output:"
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template,
    tokenizer=tokenizer,
)

In [None]:
# SFTTrainer does everything else for you!

# LoRA adapter: rank-4 updates on the attention projection (Wqkv) and the
# two MLP layers (fc1, fc2).
lora_config=LoraConfig(
    r=4,
    task_type="CAUSAL_LM",
    target_modules= ["Wqkv", "fc1", "fc2" ]
)

# 2 samples per device x 8 accumulation steps = 16 samples per optimizer step.
training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    fp16=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=1e-4,
    optim="paged_adamw_32bit",
    save_strategy="no",
    warmup_ratio=0.03,
    logging_steps=1,
    lr_scheduler_type="cosine",
    # Must be the string "none": passing None falls back to "all" installed
    # integrations (e.g. wandb), which is the opposite of what is intended.
    report_to="none",
    gradient_checkpointing=True
)

# formatting_func renders each batch into prompt strings; the collator then
# restricts the loss to the part after "Output:".
trainer = SFTTrainer(
    model,
    training_args,
    train_dataset=train_dataset,
    formatting_func=format_dataset,
    data_collator=collator,
    peft_config=lora_config,
    max_seq_length=512,
    tokenizer=tokenizer,
)

In [None]:
# Run Training
# Starts fine-tuning with the trainer configured in the previous cell.
trainer.train()

In [None]:
# Wrap-up Training
# Notes from the test split populate the dropdown of the demo below
note_samples = dataset['test']['note']

# Keep a handle on the trained model and switch it to evaluation mode
model = trainer.model
model.eval()

def inference(note, question, model):
    """Generate the model's answer to `question` about `note`.

    Args:
        note: Clinical note text inserted into the prompt.
        question: Question text inserted into the prompt.
        model: Model whose `generate` method is used.

    Returns:
        The generated continuation only (prompt excluded), as a stripped string.
    """
    # Render with an empty answer and strip the trailing whitespace so the
    # prompt ends exactly with 'Output:', matching the text seen in training
    # (where the answer follows directly after 'Output: ').
    prompt = prompt_template.format(note=note, question=question, answer="").rstrip()
    tokens = tokenizer.encode(prompt, return_tensors="pt").to('cuda')
    outs = model.generate(input_ids=tokens,
                          max_length=512,
                          use_cache=True,
                          # Greedy decoding; `temperature` has no effect without sampling.
                          do_sample=False,
                          eos_token_id=tokenizer.eos_token_id,
                          pad_token_id=tokenizer.eos_token_id,
                          )
    # Slice off the prompt at token level: decoding does not always reproduce
    # the prompt text character-for-character, so cutting the decoded string
    # by len(prompt) can clip or leak characters.
    new_tokens = outs[0][tokens.shape[-1]:].to('cpu')
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def compare_models(note, question):
    """Answer the same question with and without the trained adapter.

    Runs `inference` twice on `trainer.model` under `torch.no_grad()`: once
    as-is, and once inside `model.disable_adapter()`.

    Returns:
        Tuple `(answer_with_adapter, answer_without_adapter)`.
    """
    with torch.no_grad():
        tuned_answer = inference(note, question, trainer.model)
        with model.disable_adapter():
            base_answer = inference(note, question, trainer.model)
    return tuned_answer, base_answer

# Side-by-side demo: pick a test note, type a question, compare both answers
demo = gr.Interface(
    fn=compare_models,
    inputs=[gr.Dropdown(note_samples), "text"],
    outputs=[gr.Textbox(label="Trained"), gr.Textbox(label="Phi-2")],
)
demo.launch()