# KoSAIM 2025 Summer School

## Fine-tuning a clinical domain LLM

- 강사: 김지호(jiho.kim@kaist.ac.kr), 임수정(sujeongim@kaist.ac.kr)

- 발표자료: https://docs.google.com/presentation/d/1KGcN4iYkw7GH6zZinSW2o9mFWYXmtwpKfkA5Bcc2FDs/edit?usp=sharing

- 레퍼런스: https://github.com/starmpcc/KAIA-LLM-FT-2024

## [Step 1] 환경 세팅

### 패키지 설치

In [None]:
# Install the libraries used by this notebook (quiet mode).
# NOTE(review): versions are unpinned; `DataCollatorForCompletionOnlyLM` and the
# `max_seq_length` option used later may not exist in newer trl releases — confirm and pin if needed.
!pip install -q accelerate  peft  bitsandbytes  transformers trl  numpy einops gradio nltk triton gcsfs fsspec

### 라이브러리 가져오기

In [None]:
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt

## [Step 2] 사전 학습된 모델 (및 토크나이저) 불러오기

### 모델 가져오기

In [None]:
# Define the quantization config: load weights in 4-bit ("nf4" quant type)
# and use float16 as the compute dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)

# Load the pretrained base model with the quantization config applied
# NOTE(review): force_download=True presumably re-downloads the checkpoint on
# every run instead of reusing the local cache — confirm this is intended
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2",
    trust_remote_code=True,
    quantization_config=bnb_config,
    device_map="auto",
    force_download=True,
)

### 모델 살펴보기

In [None]:
# Inspect a (quantized) layer: the fc1 weight of the first layer's MLP
model.model.layers[0].mlp.fc1.weight

In [None]:
# Inspect the data type of the (quantized) layer weight
model.model.layers[0].mlp.fc1.weight.dtype

In [None]:
# Count the model parameters: total number of elements across all parameter tensors
sum(p.numel() for p in model.parameters())

In [None]:
# Check CUDA memory usage by printing torch's memory summary
print(torch.cuda.memory_summary())

### 토크나이저 가져오기

In [None]:
# Load the tokenizer that matches the base model
tokenizer = AutoTokenizer.from_pretrained('microsoft/phi-2')
# Reuse the EOS token as the padding token so that batches can be padded
tokenizer.pad_token = tokenizer.eos_token
# Pad on the right. The attribute is `padding_side`; the previous spelling
# (`padding_sight`) only created an unused attribute, so the setting never took effect.
tokenizer.padding_side = "right"

### 토크나이저 살펴보기

In [None]:
# Check the number of entries in the tokenizer vocabulary
len(tokenizer.vocab)

In [None]:
# Display the tokenizer vocabulary
tokenizer.vocab

In [None]:
# Display the tokenizer's special tokens
tokenizer.special_tokens_map

In [None]:
# Try tokenizing a sentence with the tokenizer
tokenizer.tokenize("Hi, my name is John.")

In [None]:
# Try encoding a sentence with the tokenizer
tokenizer.encode("Hi, my name is John.")

In [None]:
# Try encoding a sentence (2): call the tokenizer object directly
tokenizer("Hi, my name is John.")

In [None]:
# Try an encode -> decode round trip and print the decoded text
sample_text = "Hi, my name is John."
encoded_text = tokenizer.encode(sample_text)
decoded_text = tokenizer.decode(encoded_text)
print(decoded_text)

## [Step 3] Asclepius-Synthetic-Clinical-Notes 데이터 확인하기

- 데이터 링크 : https://huggingface.co/datasets/starmpcc/Asclepius-Synthetic-Clinical-Notes


### 데이터 불러오기

In [None]:
# Load the original Asclepius-Synthetic-Clinical-Notes dataset
dataset = load_dataset("starmpcc/Asclepius-Synthetic-Clinical-Notes")

In [None]:
# Check the dataset by displaying it
dataset

In [None]:
# Check note lengths: histogram of the character length of each 'note' in the train split
plt.hist([len(sample['note']) for sample in dataset['train']])
plt.show()

In [None]:
# Filter: keep only samples whose note is shorter than 1500 characters
# (batched=True, so the lambda returns one boolean per sample in the batch)
dataset = dataset.filter(lambda x: [len(i)<1500 for i in x['note']], batched=True)

In [None]:
# Check the dataset (after filtering)
dataset

In [None]:
# Define the token-length filter function
def prompt_shorter_than(samples):
    """Return one keep/drop flag per sample in a batch, based on token count.

    Each sample's 'note', 'question' and 'answer' fields are joined with single
    spaces, tokenized with the notebook-level `tokenizer`, and the sample is
    kept when the resulting sequence has at most 320 tokens.
    """
    joined_texts = []
    for fields in zip(samples['note'], samples['question'], samples['answer']):
        joined_texts.append(" ".join(fields))

    token_id_lists = tokenizer(joined_texts)['input_ids']
    return [len(token_ids) <= 320 for token_ids in token_id_lists]

In [None]:
# Filter: apply the tokenizer-based length filter defined above to each batch
dataset = dataset.filter(prompt_shorter_than, batched=True)

In [None]:
# Check the dataset (after filtering)
dataset

### 데이터 탐색하기

In [None]:
# Print the train split, followed by a blank line
print(dataset['train'])
print()

In [None]:
# Look at one sample from the train split
sample_idx = 0
sample_data = dataset['train'][sample_idx]
sample_data

In [None]:
# Convert the train split to a pandas DataFrame
df = pd.DataFrame(dataset['train'])

# Show part of the DataFrame (first 5 rows only)
df.head(5)

In [None]:
# Number of samples per task type (horizontal bar chart)
df.groupby('task').size().plot(kind='barh', color=plt.cm.Set3.colors)
# The bar length is the number of samples in each task group, so label it as such
# (the previous label/title said "Number of Tasks", which is not what is plotted)
plt.xlabel('Number of Samples')
plt.ylabel('Task Type')
plt.title('Number of Samples per Task')
# Hide the top and right borders of the plot
plt.gca().spines[['top', 'right',]].set_visible(False)
# Render the figure explicitly, as the pie-chart cell does
plt.show()

In [None]:
# Task distribution: pie chart of the share of each task, labelled with percentages
df['task'].value_counts().plot(kind='pie', autopct='%1.1f%%', colors=plt.cm.Set3.colors)
plt.ylabel('')
plt.title('Distribution of Tasks')
plt.show()

## [Step 4] 학습 데이터 전처리

### 프롬프트 데이터 전처리 함수 정의 (`formatting_func`)

In [None]:
# This prompt format can be used with the phi-2 model
# Phi-2 instruction-answer format: "Instruct: <prompt>\nOutput:"
# The template takes three placeholders: {note}, {question} and {answer}

prompt_template="""Instruct: Answer to the question for the given clinical note.
[note start]
{note}
[note end]

Question: {question}

Output: {answer}"""

In [None]:
# Preview the template with dummy values filled into each placeholder
print(prompt_template.format(note="xxx", question="yyy", answer="zzz"))

In [None]:
# Take a batch of samples and return prompts built in the expected format
def format_dataset(samples):
    """Format a batch of samples into prompt strings.

    Args:
        samples: Mapping from column name to a list of values (one entry per
            sample). Must contain the 'note', 'question' and 'answer' columns,
            aligned by index; any other columns are ignored.

    Returns:
        A list with one `prompt_template`-formatted string per sample.

    Raises:
        KeyError: If one of the required columns is missing.
    """
    # Look the columns up by name instead of unpacking `samples.values()`
    # positionally, so the result does not depend on the column order or on
    # how many extra columns the dataset carries.
    return [
        prompt_template.format(note=note, question=question, answer=answer)
        for note, question, answer in zip(
            samples['note'], samples['question'], samples['answer']
        )
    ]

# Wrap the first train sample into a batch of size one, format it, and print the resulting prompt
sample_input = format_dataset({k: [v] for k, v in dataset['train'][0].items()})[0]
print(sample_input)

In [None]:
# Sanity check: the bare template (placeholders left unfilled) must not exceed 180 tokens
prompt_len = len(tokenizer.encode(prompt_template))
if prompt_len > 180:
    raise ValueError(f"Your prompt is too long! Please reduce the length from {prompt_len} to 180 tokens")
print(f"Prompt Length: {prompt_len} tokens")

### 프롬프트 데이터 입출력 확인

In [None]:
# Build a sample prompt: wrap one train sample into a batch of size one and format it
sample_idx = 10
sample_data = dataset['train'][sample_idx]
sample_fmt_data = format_dataset({k: [v] for k, v in sample_data.items()})
print(sample_fmt_data[0])

In [None]:
# Sample prompt input: everything before the first "Output: " marker, with the marker re-appended
sample_input = sample_fmt_data[0].split("Output: ")[0] + "Output: "
print(sample_input)

In [None]:
# Sample prompt output: the text after the first "Output: " marker (up to the next marker, if any)
sample_output = sample_fmt_data[0].split("Output: ")[1]
print(sample_output)

### 프롬프트 입력 후 출력 생성

In [None]:
# Tokenize the sample prompt and move the token ids to the GPU
input_ids = tokenizer.encode(sample_input, return_tensors='pt').to('cuda')

# Generate an output for the input sequence with the model
with torch.no_grad():
    output = model.generate(
        input_ids=input_ids,
        max_length=512,
        use_cache=True,
        # Request greedy decoding explicitly. `temperature=0.` is not a valid
        # way to ask for it: temperature only applies when sampling is enabled,
        # and 0 is not an accepted sampling temperature.
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        # Provide the padding id explicitly so generate() does not have to guess it
        pad_token_id=tokenizer.eos_token_id,
    )

# Decode the generated token ids back into text
print(tokenizer.decode(output.to('cpu')[0], skip_special_tokens=True))

In [None]:
# Keep only the actual output part: the text segment right after the first "Output: " marker
print(tokenizer.decode(output.to('cpu')[0], skip_special_tokens=True).split("Output: ")[1])

### 학습할 데이터셋 정의  (`train_dataset`)

In [None]:
# Maximum number of training samples to use
TRAIN_DATASET_SIZE = 2000
train_dataset = dataset['train']
# Take the first TRAIN_DATASET_SIZE rows, capped at the number of rows actually
# available so that select() does not fail when filtering left fewer rows.
sampled_train_dataset = train_dataset.select(
    range(min(TRAIN_DATASET_SIZE, len(train_dataset)))
)

### Data Collator 정의 (`data_collator`)

In [None]:
# Marker that separates the prompt from the answer in `prompt_template`
response_template = "Output:"
# Collator built from the response marker and the tokenizer
# (presumably restricts the training loss to the tokens after the marker — confirm against the trl docs)
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

## [Step 5] 모델 학습

### 학습 환경 세팅 (`LoraConfig`, `SFTConfig`, `SFTTrainer`)

In [None]:
import os
# Set the WANDB_DISABLED environment variable (presumably turns off Weights & Biases logging — confirm)
os.environ["WANDB_DISABLED"] = "true"

# LoRA adapter configuration (rank 4) for causal language modeling.
# The model is accessed as `model.model.layers[i]` earlier in this notebook,
# i.e. the native transformers Phi layout, whose attention projections are
# named q_proj / k_proj / v_proj. The fused "Wqkv" name does not exist in that
# layout, so listing it silently left the attention layers without adapters.
lora_config=LoraConfig(
    r=4,
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "fc1", "fc2"]
)

# Training configuration for supervised fine-tuning
sft_config = SFTConfig(
    output_dir="./results",
    num_train_epochs=1,
    fp16=True,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 samples x 8 accumulation steps = 16 samples per optimizer step per device
    learning_rate=1e-4,
    optim="paged_adamw_32bit",
    save_strategy="no",
    warmup_ratio=0.03,
    logging_steps=5,
    lr_scheduler_type="cosine",
    gradient_checkpointing=True,
    max_seq_length=512,
)

# Build the trainer from the quantized model, the training subset, the prompt
# formatting function, the completion-only collator and the LoRA config
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=sampled_train_dataset,
    formatting_func=format_dataset,
    data_collator=collator,
    peft_config=lora_config,
    processing_class=tokenizer,
)

### 모델 학습하기

In [None]:
# Run training with the trainer configured above
trainer.train()

### huggingface에 모델 업로드하기

#### `huggingface-cli` 로그인

In [None]:
# Log in to the Hugging Face Hub from the command line
!huggingface-cli login # TODO: you need a 'write' type token

#### Trainer 업로드하기

In [None]:
your_name = "" # TODO: huggingface id (e.g. "Sujeongim")
# The first positional parameter of Trainer.push_to_hub is the commit message,
# not a repository id, so the string passed before only ended up as commit text.
# With no hub_model_id configured, the repository is named after the last path
# component of `output_dir` ("./results"), i.e. "<your_name>/results" — the id
# that the download step loads from.
trainer.push_to_hub(commit_message="Upload phi-2 LoRA adapter fine-tuned on Asclepius clinical notes")
print(f"Uploaded adapter repository: {your_name}/results")

#### 업로드한 모델 다운받기

In [None]:
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM

# Copy & paste the code that Hugging Face suggests (click 'Use this model')
your_name = "" # TODO: huggingface id (e.g., "Sujeongim")
# Adapter configuration of the uploaded repository (not used further in this cell)
config = PeftConfig.from_pretrained(f"{your_name}/results")
# NOTE(review): the base model is loaded here without the quantization config
# used earlier, so it presumably needs much more memory — confirm
base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2")
# Attach the uploaded adapter to the base model; this rebinds the notebook-level `model`
model = PeftModel.from_pretrained(base_model, f"{your_name}/results")

## [Step 6] 모델 추론



- 데이터를 통해 테스트해볼 수 있는 task의 종류 및 각 task에 해당하는 예시 질문은 다음과 같습니다.

- 질문 형식은 예시 질문에 국한될 필요는 없습니다.




| <b>Task</b> | <b>Task 설명</b> |  <b>예시 질문</b> |
|-----|-----|-----|
|Named Entity Recognition|텍스트에서 사람, 장소, 조직 등 고유명사를 식별합니다.|- Can Named Entity Recognition identify any thrombophilia-related entities in this discharge summary? <br> <br> - What named entities related to COVID-19 infections can be identified through Named Entity Recognition in this discharge summary?|
|Abbreviation Expansion|약어를 원래의 긴 형태로 확장합니다.| - What is the expanded form of the abbreviation 'CSF'? <br> <br> - What are the abbreviated terms in the given discharge summary that require expansion?|
|Relation Extraction|텍스트에서 두 개체 간의 관계를 식별하고 추출합니다.|- What was the treatment provided to the patient with hypokalaemia, malnutrition, and decreased renal function, and how did it improve their symptoms? <br><br> - What is the relationship extracted between ipilimumab treatment and the patient's thyroid storm in the given discharge summary?|
|Temporal Information Extraction|텍스트에서 날짜, 시간과 같은 시간 정보를 식별하고 추출합니다.|- When was the patient discharged following surgery? <br><br> - When did the patient first complain of swelling in the right sternoclavicular joint, and how long did it take to significantly resolve symptoms with therapy?|
|Coreference Resolution|문맥에서 같은 대상을 가리키는 다른 표현(지시어)을 연결합니다.|- What coreferences are resolved in the hospital course section related to the patient's diagnosis of DHR? <br><br> - What pronouns or nouns in the hospital course section of the discharge summary were subject to coreference resolution and how were they resolved?|
|Paraphrasing|문장을 다른 표현으로 바꾸어 재구성합니다.|- Can you rephrase the sentence "The patient was deemed to have a guarded prognosis with multiorgan failure" in a simpler way for a patient or family member to understand? <br><br> - How can the hospital course summary be paraphrased to make it more easily comprehensible for the patient and their family?|
|Summarization|긴 텍스트에서 중요한 정보를 추출하여 짧게 요약합니다.|- What is the summary of the patient's diagnosis and treatment during hospitalization and discharge? <br><br> - What was the primary diagnosis and treatment plan for the patient in the given discharge summary, and what persistent symptoms did they experience despite the treatment?|
|Question Answering|텍스트를 기반으로 질문에 대한 답을 제공합니다.|- What was the patient diagnosed with and what treatment was chosen for his refractory ascites? <br><br> - What was the treatment plan for the patient's multi-system process, and how effective was it in achieving remission?|


In [None]:
# Comparative evaluation: use the trainer's model and switch it to eval mode
model = trainer.model
model.eval()

# Notes offered in the demo: the 'note' field of the last 10 rows of the filtered train split
note_samples = train_dataset.select(range(len(train_dataset)-10, len(train_dataset)))['note']

def inference(note, question, model):
    """Generate an answer to `question` about `note` with the given model.

    Args:
        note: Clinical note text inserted into the prompt.
        question: Question text inserted into the prompt.
        model: Model whose `generate` method is used.

    Returns:
        The decoded text of the newly generated tokens only (prompt excluded).
    """
    # Leave the answer slot empty so the model completes it
    prompt = prompt_template.format(note=note, question=question, answer="")
    tokens = tokenizer.encode(prompt, return_tensors="pt").to('cuda')
    outs = model.generate(
        input_ids=tokens,
        max_length=512,
        use_cache=True,
        # Explicit greedy decoding; `temperature=0.` is not a valid way to
        # request it because temperature only applies when sampling
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )
    # The returned sequence starts with the prompt tokens. Cut them off by
    # token count rather than slicing the decoded string by `len(prompt)`:
    # decoding does not always reproduce the prompt text character for
    # character, which would shift the cut position.
    generated_ids = outs[0][tokens.shape[1]:]
    return tokenizer.decode(generated_ids.to('cpu'), skip_special_tokens=True)


def compare_models(note, question):
    """Answer the same question with and without the fine-tuned adapter.

    Args:
        note: Clinical note text.
        question: Question about the note.

    Returns:
        A tuple `(asc_answer, phi_answer)`: the answer with the adapter active
        and the answer with the adapter disabled.
    """
    # Use a single reference for both the adapter toggle and the inference
    # calls. Previously the adapter was disabled on the notebook-level `model`
    # while inference ran on `trainer.model`, which only works as long as both
    # names point to the same object.
    peft_model = trainer.model
    with torch.no_grad():
        asc_answer = inference(note, question, peft_model)
        with peft_model.disable_adapter():
            phi_answer = inference(note, question, peft_model)
    return asc_answer, phi_answer

# Gradio demo: pick a note from the dropdown, type a question, and see both answers side by side
demo = gr.Interface(
    fn=compare_models,
    inputs=[gr.Dropdown(note_samples), "text"],
    outputs=[gr.Textbox(label="Asclepius"), gr.Textbox(label="Phi-2")]
)
# Launch the demo with share=True
demo.launch(share=True)