# Fine-Tuning T5 for Healthcare

In [1]:
import os
import torch
import time
import json
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig, TrainingArguments, Trainer
from transformers import Trainer, TrainingArguments
from transformers import TextDataset, DataCollatorForLanguageModeling
from datasets import load_dataset, concatenate_datasets, Dataset

2024-10-07 11:01:52.829864: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Environment
Check settings

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"        
print(f"Using device: {device}")

Using device: cpu


## Data
Load datasets

In [44]:
dataset_train = load_dataset("ruslanmv/ai-medical-chatbot", split='train[:8000]')
dataset_train = dataset_train.rename_columns({'Description':'instruction', 'Patient':'context', 'Doctor':'response'})

dataset_test = load_dataset("ruslanmv/ai-medical-chatbot", split='train[-2000:-1000]')
dataset_test = dataset_test.rename_columns({'Description':'instruction', 'Patient':'context', 'Doctor':'response'})

dataset_validation = load_dataset("ruslanmv/ai-medical-chatbot", split='train[-1000:]')
dataset_validation = dataset_validation.rename_columns({'Description':'instruction', 'Patient':'context', 'Doctor':'response'})

In [45]:
dataset_train_merged = concatenate_datasets(
    [dataset_train]
)
dataset_test_merged = concatenate_datasets(
    [dataset_test]
)
dataset_validation_merged = concatenate_datasets(
    [dataset_validation]
)

In [46]:
dataset_train_merged.to_csv('train_merged.csv', index=False)
dataset_test_merged.to_csv('test_merged.csv', index=False)
dataset_validation_merged.to_csv('validation_merged.csv', index=False)

Creating CSV from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Creating CSV from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

984002

In [3]:
dataset = load_dataset('csv', data_files={
    "train": "train_merged.csv", 
    "test": "test_merged.csv", 
    "validation": "validation_merged.csv"
    })

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'context', 'response'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['instruction', 'context', 'response'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['instruction', 'context', 'response'],
        num_rows: 1000
    })
})

In [5]:
dataset['train'][0]

{'instruction': 'Q. What does abutment of the nerve root mean?',
 'context': 'Hi doctor,I am just wondering what is abutting and abutment of the nerve root means in a back issue. Please explain. What treatment is required for\xa0annular bulging and tear?',
 'response': 'Hi. I have gone through your query with diligence and would like you to know that I am here to help you. For further information consult a neurologist online -->'}

In [6]:
dataset_train = pd.read_csv('./train_merged.csv')

In [7]:
dataset_train.head()

Unnamed: 0,instruction,context,response
0,Q. What does abutment of the nerve root mean?,"Hi doctor,I am just wondering what is abutting...",Hi. I have gone through your query with dilige...
1,Q. What should I do to reduce my weight gained...,"Hi doctor, I am a 22-year-old female who was d...",Hi. You have really done well with the hypothy...
2,Q. I have started to get lots of acne on my fa...,Hi doctor! I used to have clear skin but since...,Hi there Acne has multifactorial etiology. Onl...
3,Q. Why do I have uncomfortable feeling between...,"Hello doctor,I am having an uncomfortable feel...",Hello. The popping and discomfort what you fel...
4,Q. My symptoms after intercourse threatns me e...,"Hello doctor,Before two years had sex with a c...",Hello. The HIV test uses a finger prick blood ...


In [8]:
idx = 100
idx_var = 50

check = dataset_train.iloc[idx]['instruction']
check2 = dataset_train.iloc[idx]['response']

test = dataset_train.iloc[idx_var]['instruction']
test_response = dataset_train.iloc[idx_var]['response']

print(f'instruction:\n{test}\n\nresponse:\n')

print(f'======================\nexpected reponse: \n {test_response}')

instruction:
Q. Shall I take Raw Bovine Ovary pills to increase breast size?

response:

expected reponse: 
 Hi. For further information consult an internal medicine physician online -->


## Model and Tokenizer
Define configuration settings

In [9]:
model_name='t5-small'
os.environ['TOKENIZERS_PARALLELISM'] = 'true' 

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
base_model = base_model.to(device)

In [11]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [12]:
tokenizer.chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [13]:
input_text = "translate English to German: How old are you?"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

outputs = base_model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

<pad> Wie alt sind Sie?</s>




In [14]:
def get_response(question):
    test_prompt = f'instruction:\n{question}\n\nresponse:\n'
    input_text = tokenizer(test_prompt,return_tensors="pt").input_ids
    output = base_model.generate(
        input_ids=input_text, 
        generation_config=GenerationConfig(max_new_tokens=200, 
        num_beams=1))
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [15]:
question = "Shall I take Raw Bovine Ovary pills to increase breast size?"
get_response(question)

':'

In [16]:
def tokenize_function(example):
    start_prompt = 'Instruction:\n'
    end_prompt = '\nResponse:'
    prompt = [start_prompt + dialogue + end_prompt for dialogue in example["instruction"]]
    example['input_ids'] = tokenizer(prompt, padding="max_length", truncation=True, return_tensors="pt").input_ids
    example['labels'] = tokenizer(example["response"], padding="max_length", truncation=True, return_tensors="pt").input_ids
    
    return example

In [17]:
shuffled_dataset = dataset.shuffle(seed=42)

In [18]:
tokenized_datasets = shuffled_dataset.map(tokenize_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(['instruction', 'context', 'response'])

Map:   0%|          | 0/8000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [19]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1000
    })
    validation: Dataset({
        features: ['input_ids', 'labels'],
        num_rows: 1000
    })
})

Fine Tune

In [20]:
finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
finetuned_model = finetuned_model.to('cpu')
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [21]:
output_dir = 't5-healthcare-log'

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    learning_rate=5e-3,
    num_train_epochs=2,
    per_device_train_batch_size=16,     # batch size per device during training
    per_device_eval_batch_size=16,      # batch size for evaluation
    weight_decay=0.01,
    logging_steps=50,
    evaluation_strategy='steps',        # evaluation strategy to adopt during training
    eval_steps=500,                 
)

trainer = Trainer(
    model=finetuned_model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
)



In [22]:
trainer.evaluate()

  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 5.344299793243408,
 'eval_model_preparation_time': 0.002,
 'eval_runtime': 40.3217,
 'eval_samples_per_second': 24.801,
 'eval_steps_per_second': 1.562}

In [23]:
%%time

trainer.train()

  0%|          | 0/1000 [00:00<?, ?it/s]

{'loss': 1.6891, 'grad_norm': 0.17573419213294983, 'learning_rate': 0.00475, 'epoch': 0.1}
{'loss': 0.8226, 'grad_norm': 0.11860419809818268, 'learning_rate': 0.0045000000000000005, 'epoch': 0.2}
{'loss': 0.8015, 'grad_norm': 0.16136465966701508, 'learning_rate': 0.00425, 'epoch': 0.3}
{'loss': 0.7289, 'grad_norm': 0.14195092022418976, 'learning_rate': 0.004, 'epoch': 0.4}
{'loss': 0.7817, 'grad_norm': 0.1362648457288742, 'learning_rate': 0.00375, 'epoch': 0.5}
{'loss': 0.7488, 'grad_norm': 0.14362695813179016, 'learning_rate': 0.0034999999999999996, 'epoch': 0.6}
{'loss': 0.7824, 'grad_norm': 0.08650536835193634, 'learning_rate': 0.0032500000000000003, 'epoch': 0.7}
{'loss': 0.7207, 'grad_norm': 0.10980451852083206, 'learning_rate': 0.003, 'epoch': 0.8}
{'loss': 0.7242, 'grad_norm': 0.14244355261325836, 'learning_rate': 0.0027500000000000003, 'epoch': 0.9}
{'loss': 0.6992, 'grad_norm': 0.12006626278162003, 'learning_rate': 0.0025, 'epoch': 1.0}


  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 1.034824013710022, 'eval_model_preparation_time': 0.002, 'eval_runtime': 38.3902, 'eval_samples_per_second': 26.048, 'eval_steps_per_second': 1.641, 'epoch': 1.0}
{'loss': 0.664, 'grad_norm': 0.10889982432126999, 'learning_rate': 0.0022500000000000003, 'epoch': 1.1}
{'loss': 0.6737, 'grad_norm': 0.10321959108114243, 'learning_rate': 0.002, 'epoch': 1.2}
{'loss': 0.6874, 'grad_norm': 0.13431867957115173, 'learning_rate': 0.0017499999999999998, 'epoch': 1.3}
{'loss': 0.6764, 'grad_norm': 0.12763723731040955, 'learning_rate': 0.0015, 'epoch': 1.4}
{'loss': 0.6824, 'grad_norm': 0.1370091289281845, 'learning_rate': 0.00125, 'epoch': 1.5}
{'loss': 0.6623, 'grad_norm': 0.12862040102481842, 'learning_rate': 0.001, 'epoch': 1.6}
{'loss': 0.6223, 'grad_norm': 0.0874997228384018, 'learning_rate': 0.00075, 'epoch': 1.7}
{'loss': 0.6112, 'grad_norm': 0.11442851275205612, 'learning_rate': 0.0005, 'epoch': 1.8}
{'loss': 0.6407, 'grad_norm': 0.101710744202137, 'learning_rate': 0.00025, '

  0%|          | 0/63 [00:00<?, ?it/s]

{'eval_loss': 1.0010883808135986, 'eval_model_preparation_time': 0.002, 'eval_runtime': 38.8036, 'eval_samples_per_second': 25.771, 'eval_steps_per_second': 1.624, 'epoch': 2.0}
{'train_runtime': 3646.1503, 'train_samples_per_second': 4.388, 'train_steps_per_second': 0.274, 'train_loss': 0.755041633605957, 'epoch': 2.0}
CPU times: user 9min 59s, sys: 5min 29s, total: 15min 28s
Wall time: 1h 46s


TrainOutput(global_step=1000, training_loss=0.755041633605957, metrics={'train_runtime': 3646.1503, 'train_samples_per_second': 4.388, 'train_steps_per_second': 0.274, 'total_flos': 2165468823552000.0, 'train_loss': 0.755041633605957, 'epoch': 2.0})

In [24]:
finetuned_model.save_pretrained("t5-healthcare")

In [25]:
tokenizer.save_pretrained("t5-healthcare")

('t5-healthcare/tokenizer_config.json',
 't5-healthcare/special_tokens_map.json',
 't5-healthcare/tokenizer.json')

In [26]:
model_path = "t5-healthcare"
finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
finetuned_model = finetuned_model.to('cpu')

finetuned_tokenizer = AutoTokenizer.from_pretrained(model_path)

In [27]:
def get_response(question):
    test_prompt = f'instruction:\n{question}\n\nresponse:\n'
    input_text = finetuned_tokenizer(test_prompt,return_tensors="pt").input_ids
    output = finetuned_model.generate(
        input_ids=input_text, 
        generation_config=GenerationConfig(max_new_tokens=200, 
        num_beams=1))
    return finetuned_tokenizer.decode(output[0], skip_special_tokens=True)

In [28]:
question = "i have a question about cancelling order {{Order Number}}"
get_response(question)

'Hello. I have gone through your query and can understand your concern. For further information consult an internal medicine physician online -->'

In [29]:
for i in range(0, 5, 1):
    print('Instruction: ' + dataset['test'][i]['instruction'])
    print('Predict. :' + get_response(dataset['test'][i]['instruction']))
    print('Expected: ' + dataset['test'][i]['response'])
    print('=================================\n')

Instruction: 23 year old married female, suffering with pimples and scars ?
Predict. :Hi. I have gone through your case. I can understand your concern. I have gone through the attachment (attachment removed to protect patient identity). I can understand your concern. I would like to know if you have any other history of pimples or pimples? Is there any history of pimples or pimples? Is there any history of pimples or pimples? Is there any history of pimples or pimples? Is there any history of pimples or pimples?
Expected: hello, pimples have become a life style disease & a disease of metropolitan cities...we hardly see them in villages.But can also be because of hormonal disturbances in our body,esp if it increases 2 days prior to periods.Never use a moisturizing cream on face.Avoid using oily soaps too.Use specific anti acne soap like acnelak or dermadew acne soap twice a day.Use nadibact cream over pimples in day time and adaferin at night.This might help you a lot.Now a days there a

Evaluation

In [30]:
import evaluate

In [31]:
rouge = evaluate.load('rouge')

In [32]:
instructions = dataset['test'][0:25]['instruction']
responses = dataset['test'][0:25]['response']

base_model_responses = []
finetuned_model_responses = []
peft_model_responses = []

for instruction in instructions:
    prompt = f""" 
    instruction:
    {instruction}

    response:
    """
    
    input_ids = tokenizer.encode(prompt, return_tensors='pt')

    base_model_outputs = base_model.generate(input_ids=input_ids,generation_config=GenerationConfig(max_new_tokens=200))
    base_model_output = tokenizer.decode(base_model_outputs[0], skip_special_tokens=True)

    finetuned_model_outputs = finetuned_model.generate(input_ids=input_ids,generation_config=GenerationConfig(max_new_tokens=200))
    finetuned_model_output = tokenizer.decode(finetuned_model_outputs[0], skip_special_tokens=True)


    # peft_model_outputs = peft_model.generate(input_ids=input_ids,generation_config=GenerationConfig(max_new_tokens=200))
    # peft_model_output = tokenizer.decode(peft_model_outputs[0], skip_special_tokens=True)

    base_model_responses.append(base_model_output)
    finetuned_model_responses.append(finetuned_model_output)
    # peft_model_responses.append(peft_model_output)

In [33]:
zipped_responses = list(zip(responses, base_model_responses, finetuned_model_responses))
df = pd.DataFrame(zipped_responses, columns=['human','base','finetuned'])
df

Unnamed: 0,human,base,finetuned
0,"hello, pimples have become a life style diseas...",,Hi. I have gone through your case. I can under...
1,"hello dear , you havnt mentioned ur age.Any wa...","Instruktion: Gutes ointment, um die pimples un...",Hi. I have gone through your case. I can under...
2,Your options from here are:- Try home therapy ...,instruction: 18-jährige girl suffering with in...,Hello. I have gone through your query and I ca...
3,"I got you, man If you have long hair, get it s...",Instruction: Having pimples on face response:,Hi. I have gone through your query and can und...
4,acne can be treated with local applications of...,: How to remove pimples.,Hi. I have gone through your query and can und...
5,i think u have that problem be bold to face an...,"- My friend has pimples, how to get rid of it?",Hi. I have gone through your case. I can under...
6,You may use ACNIL SOAP to wash your face 3-4 t...,Antwort:,Hello. I have gone through your query and can ...
7,"Hi, Thanks for the query. I understand your ...",:: Is pregnancy possible through non-pénétrati...,Hi. I have gone through your query with dilige...
8,"Hello Welcome to HCM, thanks for posting your ...",Antwort:,Hi. I have gone through your query with dilige...
9,I am not getting you.,Instruktion: White pimples around tip of my pe...,Hi. I have gone through your query. I can unde...


In [34]:
responses = dataset['test'][0:25]['response']

base_model_results = rouge.compute(
    predictions=base_model_responses,
    references=responses,
    use_aggregator=True,
    use_stemmer=True,
)
print('Base\n',base_model_results)

finetuned_model_results = rouge.compute(
    predictions=finetuned_model_responses,
    references=responses,
    use_aggregator=True,
    use_stemmer=True,
)
print('Fine-tuned\n',finetuned_model_results)

# peft_model_results = rouge.compute(
#     predictions=peft_model_responses,
#     references=response,
#     use_aggregator=True,
#     use_stemmer=True,
# )
# print('\PEFT\n',peft_model_results)

Base
 {'rouge1': 0.03300179735629219, 'rouge2': 0.002857142857142857, 'rougeL': 0.0276179280413884, 'rougeLsum': 0.027777065448010948}
Fine-tuned
 {'rouge1': 0.10866909425618543, 'rouge2': 0.005276240164577634, 'rougeL': 0.0723023524180746, 'rougeLsum': 0.07273447515611937}
