<a href="https://colab.research.google.com/github/e-gluzman/biomedical-ai/blob/main/medical_question_llm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, I showcase finetuning a state of the art open source Large Language Model (Mistral) to solve Medical Questions using the Medical Question Dataset.

I load the LLM from Hugging Face, create prompts from training data, fine-tune the model on the medical question dataset and then apply it to medical question answering.


---

Some useful resources: <br>
https://www.datacamp.com/tutorial/mistral-7b-tutorial
https://saankhya.medium.com/large-language-models-llms-a-comprehensive-guide-58ce825c8c0b

In [None]:
!pip install -q accelerate peft bitsandbytes
!pip install -q git+https://github.com/huggingface/transformers
!pip install -q trl py7zr auto-gptq optimum

In [None]:
from google.colab import drive
drive.mount("/content/drive")

In [3]:
import pandas as pd
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GPTQConfig, TrainingArguments, GenerationConfig
import torch
from datasets import Dataset
from peft import LoraConfig, AutoPeftModelForCausalLM, prepare_model_for_kbit_training, get_peft_model, PeftConfig
from trl import SFTTrainer
import os
import tqdm
import json
from huggingface_hub import notebook_login


In [None]:
notebook_login()

In [None]:
# lets load and view the traning data
dataset = load_dataset("medmcqa")

In [None]:
train_df = pd.DataFrame(dataset['train'])

In [None]:
train_df.head()

Unnamed: 0,id,question,opa,opb,opc,opd,cop,choice_type,exp,subject_name,topic_name,text
0,e9ad821a-c438-4965-9f77-760819dfa155,Chronic urethral obstruction due to benign pri...,Hyperplasia,Hyperophy,Atrophy,Dyplasia,2,single,Chronic urethral obstruction because of urinar...,Anatomy,Urinary tract,\n Question:\n Chronic urethral obstruct...
1,e3d3c4e1-4fb2-45e7-9f88-247cc8f373b3,Which vitamin is supplied from only animal sou...,Vitamin C,Vitamin B7,Vitamin B12,Vitamin D,2,single,Ans. (c) Vitamin B12 Ref: Harrison's 19th ed. ...,Biochemistry,Vitamins and Minerals,\n Question:\n Which vitamin is supplied...
2,5c38bea6-787a-44a9-b2df-88f4218ab914,All of the following are surgical options for ...,Adjustable gastric banding,Biliopancreatic diversion,Duodenal Switch,Roux en Y Duodenal By pass,3,multi,"Ans. is 'd' i.e., Roux en Y Duodenal Bypass Ba...",Surgery,Surgical Treatment Obesity,\n Question:\n All of the following are ...
3,cdeedb04-fbe9-432c-937c-d53ac24475de,Following endaerectomy on the right common car...,Central aery of the retina,Infraorbital aery,Lacrimal aery,Nasociliary aretry,0,multi,The central aery of the retina is a branch of ...,Ophthalmology,,\n Question:\n Following endaerectomy on...
4,dc6794a3-b108-47c5-8b1b-3b4931577249,Growth hormone has its effect on growth through?,Directly,IG1-1,Thyroxine,Intranuclear receptors,1,single,"Ans. is 'b' i.e., IGI-1GH has two major functi...",Physiology,,\n Question:\n Growth hormone has its ef...


In [None]:
# lets generate instructions for model training and encode them into 'text' column
def generate_prompt(x):
    cop = 'Nothing'
    if x['cop'] == 0:
        cop = x['opa']
    elif x['cop'] == 1:
        cop = x['opb']
    elif x['cop'] == 2:
        cop = x['opc']
    elif x['cop'] == 3:
        cop = x['opd']
    question = '{}\nOptions:\n1. {}\n2. {}\n3. {}\n4. {}\n'.format(x['question'], x['opa'], x['opb'], x['opc'], x['opd'])
    answer = cop
    prompt = f"""
    Question:
    {question}
    [INST] Solve this post graduate medical entrance exam MCQ and provide the correct option. [/INST]
    Answer: {answer} </s>"""
    return prompt

In [None]:
train_df = pd.DataFrame(dataset['train'])

In [None]:
train_df['text'] = train_df.apply(lambda x: generate_prompt(x),axis=1)

In [None]:
# here is an example prompt for our training
train_df['text'][0]

'\n    Question:\n    Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchyma\nOptions:\n1. Hyperplasia\n2. Hyperophy\n3. Atrophy\n4. Dyplasia\n\n    [INST] Solve this post graduate medical entrance exam MCQ and provide the correct option. [/INST]\n    Answer: Hyperophy </s>'

In [None]:
# this code loads the Mistral-7B LLM fine-tuned for instruction prompts
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id =  tokenizer.unk_token_id
tokenizer.padding_side = 'left'

quantization_config_loading = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False
)

model = AutoModelForCausalLM.from_pretrained(
                                model_id,
                                quantization_config=quantization_config_loading,
                                device_map="auto"
                            )

model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache=False
model.config.pretraining_tp=1
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
# Lora allows to us accelerate the fine-tuning of large models while consuming less memory.
peft_config = LoraConfig(
                    r=16,
                    lora_alpha=32,
                    lora_dropout=0.05,
                    bias="none",
                    task_type="CAUSAL_LM",
                    target_modules=[
                        "q_proj",
                        "k_proj",
                        "v_proj",
                        "o_proj",
                        "gate_proj",
                        "up_proj",
                        "down_proj",
                        "lm_head",
                    ]
                )

model = get_peft_model(model, peft_config)

In [None]:
data = Dataset.from_pandas(train_df)

In [None]:
training_arguments = TrainingArguments(
                            output_dir="mistral-gptq-finetuned-medmcqa",
                            per_device_train_batch_size=8,
                            gradient_accumulation_steps=1,
                            optim="paged_adamw_32bit",
                            learning_rate=2e-4,
                            lr_scheduler_type="cosine",
                            save_strategy="epoch",
                            logging_steps=50,
                            num_train_epochs=1,
                            max_steps=5000,
                            fp16=True,
                            push_to_hub=True
                        )

trainer = SFTTrainer(
            model=model,
            train_dataset=data,
            peft_config=peft_config,
            dataset_text_field="text",
            args=training_arguments,
            tokenizer=tokenizer,
            packing=False,
            max_seq_length=512
    )

In [None]:
# train the model on our data and save to hub
trainer.train()
trainer.push_to_hub()


In [None]:
username = "egluzman"
model_id = username + "mistral-finetuned-medmcqa-2"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoPeftModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="cuda",
    )

In [7]:
generation_config = GenerationConfig(
    do_sample=True,
    top_k=1,
    temperature=0.1,
    max_new_tokens=25,
    pad_token_id=tokenizer.pad_token_id
)

In [8]:
def generate_test_prompt(x):
    question = '{}\nOptions:\n1. {}\n2. {}\n3. {}\n4. {}\n'.format(x['question'], x['opa'], x['opb'], x['opc'], x['opd'])
    prompt = f"""
    Question:
    {question}
    [INST] Solve this post graduate medical entrance exam MCQ and answer correctly. [/INST]
    Answer: """
    return prompt


In [9]:
val_data_df = pd.DataFrame(dataset['validation'])
val_data_df['text'] = val_data_df.apply(lambda x: generate_test_prompt(x),axis=1)

In [10]:
example = val_data_df['text'][3]
inputs = tokenizer(example, return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, generation_config=generation_config)
answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(answer)


    Question:
    Axonal transport is:
Options:
1. Antegrade
2. Retrograde
3. Antegrade and retrograde
4. None

    [INST] Solve this post graduate medical entrance exam MCQ and answer correctly. [/INST]
    Answer: 1. Antegrade

    Axonal transport refers to the movement of vesicles and organelles along the ax


In [None]:
# Iterate through the set of questions and generate model answers
def solve_question(question_prompt):
    inputs = tokenizer(question_prompt, return_tensors="pt", padding=True, truncation=True).to("cuda")
    outputs = model.generate(**inputs, generation_config=generation_config)
    answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return answer

all_answers = []
import re
val_data_prompts = list(val_data_df['text'])
for i in tqdm.tqdm(range(0, len(val_data_prompts), 16)):
    question_prompts = val_data_prompts[i:i+16]
    ans = solve_question(question_prompts)
    ans_option = []
    for text in ans:
        ans_option.append(re.search(r'Answer: \s*(.*)', text).group(1))
    all_answers.extend(ans_option)

In [25]:
correct_answers = []
for i in range(len(val_data_df)):
    if val_data_df['cop'][i] == 0:
        correct_answers.append(val_data_df['opa'][i])
    elif val_data_df['cop'][i] == 1:
        correct_answers.append(val_data_df['opb'][i])
    elif val_data_df['cop'][i] == 2:
        correct_answers.append(val_data_df['opc'][i])
    elif val_data_df['cop'][i] == 3:
        correct_answers.append(val_data_df['opd'][i])
correct_count = 0
for i in range(len(val_data_df)):
    correct_count += correct_answers[i] == all_answers[i]

In [43]:
# Here is an example propmpt and the model's answers
print(question_prompts[0])
print('')
print('The correct answer is:')
print('')
print(correct_answers[0])
print('')
print('The model answered:')
print(all_answers[0])


    Question:
    To remove centric interference, reduce:
Options:
1. Supporting cusps
2. Central fossa
3. Both of the above
4. None

    [INST] Solve this post graduate medical entrance exam MCQ and answer correctly. [/INST]
    Answer: 

The correct answer is:

Impulse through myelinated fibers is slower than non-myelinated fibers

The model answered:
4. Local anesthesia is effective only when the nerve is not covered by myelin sheath


In [32]:
# calculate the model's accuracy
print(f'The final score is {round(correct_count/len(val_data_df),2)} %!')

The final score is 0.37 %!
