In [None]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/My Drive/TRLWS

In [None]:
# NOTE(review): xformers is capped below 0.0.24, presumably for compatibility with the
# Colab torch build / Unsloth release in use — confirm before bumping.
!pip install -qqq "xformers<0.0.24"

In [None]:
# Install Unsloth straight from the GitHub default branch with the "colab-new" extra.
# NOTE(review): unpinned git install — a later run may pick up a different Unsloth version;
# consider pinning a commit for reproducibility.
!pip install -qqq "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
# Training libraries installed with --no-deps, presumably so pip does not replace the
# torch/xformers versions set up above.
# NOTE(review): the trl<0.9.0 cap appears tied to the SFTTrainer call below, which passes
# dataset_text_field / max_seq_length / packing directly to SFTTrainer — confirm against the TRL changelog.
!pip install -qqq --no-deps "trl<0.9.0" peft accelerate bitsandbytes

In [None]:
import torch
from trl import SFTTrainer
from datasets import Dataset, load_dataset, concatenate_datasets
from transformers import TrainingArguments, TextStreamer
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel, is_bfloat16_supported
from tqdm import tqdm


In [None]:
# Load the pre-quantized (bitsandbytes 4-bit) Llama 3.1 8B Instruct checkpoint.
max_seq_length = 2048  # reused by SFTTrainer and when reloading the merged model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    dtype=None,  # None = let Unsloth auto-detect the compute dtype
)

# Attach LoRA adapters to every attention and MLP projection.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "o_proj", "gate_proj"],
    use_rslora=True,  # rank-stabilized LoRA: adapter scaling is lora_alpha / sqrt(r) rather than lora_alpha / r
    use_gradient_checkpointing="unsloth",
    random_state=0,
)
# print_trainable_parameters() prints its own summary and returns None, so it is called
# directly; wrapping it in print() only emitted an extra "None" line.
model.print_trainable_parameters()

In [None]:
# Render conversations with the ChatML template. The mapping lets the template consume
# ShareGPT-style turns: {"from": <human|gpt|system>, "value": <text>}.
# NOTE(review): the base model is Llama 3.1 Instruct, yet fine-tuning uses ChatML instead of
# the model's native llama-3 template — confirm this is intentional.
sharegpt_mapping = {
    "role": "from",
    "content": "value",
    "user": "human",
    "assistant": "gpt",
    "system": "system",
}
tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",  # Unsloth also offers zephyr, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping=sharegpt_mapping,
    # Map the template's <|im_end|> onto the tokenizer's EOS token. The "</s>" in Unsloth's
    # example comment is model-specific; Llama 3.1 uses a different EOS token.
    map_eos_token=True,
)

def apply_template(examples):
    """Render a batch of ShareGPT-style conversations into chat-formatted strings.

    A fixed medical-assistant system turn is placed in front of every conversation,
    and the result is passed through the module-level ``tokenizer``'s
    ``apply_chat_template`` without tokenizing and without a generation prompt.

    Args:
        examples: batch mapping (as produced by ``Dataset.map(batched=True)``) whose
            ``"conversations"`` entry is a list of conversations, each a list of
            ``{"from": ..., "value": ...}`` turns.

    Returns:
        ``{"text": [...]}`` with one rendered string per input conversation, in order.
    """
    system_turn = {
        "from": "system",
        "value": (
            "You are a highly knowledgeable medical assistant with expertise in various medical domains. "
            "Your primary role is to provide accurate, evidence-based medical information and advice. "
            "Your responses should be grounded in the latest clinical guidelines and research, ensuring reliability and precision. "
            "When answering questions, consider the context and complexity of the query, and aim to provide the most relevant and correct answers."
        ),
    }
    # Build new lists so the incoming conversations are not mutated.
    with_system = [[system_turn] + turns for turns in examples["conversations"]]

    rendered = []
    for convo in with_system:
        rendered.append(
            tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        )
    return {"text": rendered}



In [None]:
# Preparing the MedQA (USMLE, 4-option) training split as ShareGPT-style conversations.

medqa_dataset = load_dataset("GBaker/MedQA-USMLE-4-options-hf", split="train")


def _medqa_to_conversation(record):
    """Turn one MedQA record into a [human, gpt] turn pair.

    The prompt lists ``ending0``..``ending3`` as "A) ..".."D) ..". The target is
    "Answer: " followed by the full option line at index ``record['label']``.
    """
    option_lines = [f"{letter}) {record[f'ending{idx}']}" for idx, letter in enumerate("ABCD")]
    prompt = f"Question: {record['sent1']}\nOptions:\n" + "\n".join(option_lines)
    return [
        {"from": "human", "value": prompt},
        {"from": "gpt", "value": f"Answer: {option_lines[record['label']]}"},
    ]


converted_conversations = [_medqa_to_conversation(record) for record in medqa_dataset]

formatted_medqa_dataset = Dataset.from_dict({"conversations": converted_conversations})
formatted_medqa_dataset = formatted_medqa_dataset.map(apply_template, batched=True)


In [None]:
# Preparing the MedMCQA training split as ShareGPT-style conversations.

med_mcqa_dataset = load_dataset("openlifescienceai/medmcqa", split="train")


def _medmcqa_to_conversation(record):
    """Turn one MedMCQA record into a [human, gpt] turn pair.

    Fields ``opa``..``opd`` are rendered as "A) ..".."D) ..". The target is
    "Answer: " plus the option line at index ``record['cop']``.
    NOTE(review): assumes ``cop`` is a 0-based index (0 -> A); confirm against the dataset card.
    """
    option_lines = [
        f"{letter}) {record[field]}"
        for letter, field in zip("ABCD", ("opa", "opb", "opc", "opd"))
    ]
    prompt = f"Question: {record['question']}\nOptions:\n" + "\n".join(option_lines)
    return [
        {"from": "human", "value": prompt},
        {"from": "gpt", "value": f"Answer: {option_lines[record['cop']]}"},
    ]


converted_conversations = [_medmcqa_to_conversation(record) for record in med_mcqa_dataset]

formatted_med_mcqa_dataset = Dataset.from_dict({"conversations": converted_conversations})
formatted_med_mcqa_dataset = formatted_med_mcqa_dataset.map(apply_template, batched=True)

In [None]:
# Preparing PubMedQA (pqa_artificial) as yes/no/maybe multiple-choice conversations.

pub_medqa_dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial", split="train")

# Fixed answer choices shown in every prompt.
PUBMEDQA_OPTION_LINES = ["A) yes", "B) no", "C) maybe"]
# final_decision label -> option line used as the target; unknown labels become "Unknown".
PUBMEDQA_DECISION_TO_OPTION = {"yes": "A) yes", "no": "B) no", "maybe": "C) maybe"}


def _pubmedqa_to_conversation(record):
    """Turn one PubMedQA record into a [human, gpt] turn pair.

    All passages in ``record['context']['contexts']`` are joined with single spaces
    and placed before the question. The target is "Answer: " plus the option line
    for ``record['final_decision']``.
    """
    context_text = " ".join(record['context']['contexts'])
    options_block = "\n".join(PUBMEDQA_OPTION_LINES)
    prompt = f"Context: {context_text}\nQuestion: {record['question']}\nOptions:\n{options_block}"
    chosen_option = PUBMEDQA_DECISION_TO_OPTION.get(record['final_decision'], "Unknown")
    return [
        {"from": "human", "value": prompt},
        {"from": "gpt", "value": f"Answer: {chosen_option}"},
    ]


converted_conversations = [_pubmedqa_to_conversation(record) for record in pub_medqa_dataset]

formatted_pub_medqa_dataset = Dataset.from_dict({"conversations": converted_conversations})
formatted_pub_medqa_dataset = formatted_pub_medqa_dataset.map(apply_template, batched=True)

In [None]:
# Spot-check one rendered training example from each source dataset.
for preview_dataset, preview_index in (
    (formatted_medqa_dataset, 0),
    (formatted_med_mcqa_dataset, 0),
    (formatted_pub_medqa_dataset, 9),
):
    print(preview_dataset[preview_index]["text"])

In [None]:
# Concatenate the three sources, shuffle, and hold out 10% for validation.
# (train_test_split also shuffles by default, so the explicit shuffle is mostly for the
# printed combined view; both use seed 42 for reproducibility.)
source_datasets = [formatted_medqa_dataset, formatted_med_mcqa_dataset, formatted_pub_medqa_dataset]
combined_dataset = concatenate_datasets(source_datasets).shuffle(seed=42)
print("combined_dataset:", combined_dataset)

train_val_split = combined_dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, val_dataset = train_val_split['train'], train_val_split['test']

print("training dataset:", train_dataset)
print("validation dataset:", val_dataset)

In [None]:
# Supervised fine-tuning on the rendered "text" column. packing=True concatenates short
# examples into sequences of up to max_seq_length tokens.
# NOTE(review): val_dataset is 10% of all combined data and is evaluated every 360 steps;
# this may dominate runtime — consider evaluating on a fixed subset.
bf16_available = is_bfloat16_supported()
training_args = TrainingArguments(
    output_dir="output",
    seed=0,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 sequences per optimizer step per device
    learning_rate=3e-4,
    lr_scheduler_type="linear",
    warmup_steps=10,
    weight_decay=0.01,
    optim="adamw_8bit",
    bf16=bf16_available,
    fp16=not bf16_available,  # fall back to fp16 on GPUs without bfloat16 support
    logging_steps=360,
    eval_strategy="steps",
    eval_steps=360,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=True,
    args=training_args,
)

trainer.train()
# To resume an interrupted fine-tuning run from the latest checkpoint in output_dir:
#trainer.train(resume_from_checkpoint=True)

In [None]:
# Merge the LoRA adapters into the base weights and save a standalone 16-bit model,
# together with the tokenizer, under the current working directory.
merged_model_dir = "model_dir"
model.save_pretrained_merged(merged_model_dir, tokenizer, save_method="merged_16bit")
# To publish instead of saving locally:
#model.push_to_hub_merged("mlabonne/FineLlama-3.1-8B", tokenizer, save_method="merged_16bit")


In [None]:
# Reload the merged checkpoint for evaluation (quantized to 4-bit on load).
import gc

# Release the training-time objects first: `trainer` still references the LoRA model, so
# rebinding `model` alone would keep both copies resident on the GPU. pop() is a no-op
# when this cell runs in a fresh kernel where they were never created.
for _stale_name in ("trainer", "model"):
    globals().pop(_stale_name, None)
gc.collect()
torch.cuda.empty_cache()

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="/content/gdrive/My Drive/TRLWS/model_dir",
    max_seq_length=max_seq_length,
    dtype=None,
    load_in_4bit=True,
)

In [None]:
# Inference on MedQA test set
tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt", "system": "system"},  # ShareGPT style
    map_eos_token=True,  # Maps <|im_end|> onto the tokenizer's EOS token
)

# Switch Unsloth into inference mode.
model = FastLanguageModel.for_inference(model)

# Load test dataset
test_dataset = load_dataset("GBaker/MedQA-USMLE-4-options-hf", split="test")

# Loop-invariant values: same system prompt as used for training, and the valid answer letters.
system_message = "You are a highly knowledgeable medical assistant with expertise in various medical domains. Your primary role is to provide accurate, evidence-based medical information and advice. Your responses should be grounded in the latest clinical guidelines and research, ensuring reliability and precision. When answering questions, consider the context and complexity of the query, and aim to provide the most relevant and correct answers."
options_list = ["A", "B", "C", "D"]

correct = 0
total = 0

for example in tqdm(test_dataset):
    question = example['sent1']
    options = [
        f"A) {example['ending0']}",
        f"B) {example['ending1']}",
        f"C) {example['ending2']}",
        f"D) {example['ending3']}"
    ]
    correct_answer = chr(65 + example['label'])  # Convert 0-3 to A-D
    print("Correct Answer:", correct_answer)

    # Build the prompt exactly as in training (no trailing "Answer:" cue), so the model
    # sees the same format it was fine-tuned on.
    options_str = "\n".join(options)
    input_text = f"Question: {question}\nOptions:\n{options_str}"
    messages = [
        {"from": "system", "value": system_message},
        {"from": "human", "value": input_text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to("cuda")

    outputs = model.generate(input_ids=inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; decode only the new tokens so text inside
    # the prompt (e.g. "Answer: B" in a question stem) can never be scored as the prediction.
    prompt_length = inputs.shape[-1]
    predicted_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Collect every letter the model committed to via "Answer: <letter>"; anything other
    # than exactly the one correct letter (none, or several) is scored as wrong.
    predicted_answer = "".join(
        option for option in options_list if f"Answer: {option}" in predicted_text
    )
    print("Predicted Answer:", predicted_answer)

    # Check accuracy
    if predicted_answer == correct_answer:
        correct += 1
        print("Accurately predicted")
    else:
        print("Prediction is wrong")

    total += 1

accuracy = correct / total if total else 0.0
print("Accuracy:", accuracy)




In [None]:
# Inference on MedMCQA validation set

tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt", "system": "system"},  # ShareGPT style
    map_eos_token=True,  # Maps <|im_end|> onto the tokenizer's EOS token
)

# Switch Unsloth into inference mode.
model = FastLanguageModel.for_inference(model)

# Load evaluation dataset
test_dataset = load_dataset("openlifescienceai/medmcqa", split="validation")

# Loop-invariant values: same system prompt as used for training, and the valid answer letters.
system_message = "You are a highly knowledgeable medical assistant with expertise in various medical domains. Your primary role is to provide accurate, evidence-based medical information and advice. Your responses should be grounded in the latest clinical guidelines and research, ensuring reliability and precision. When answering questions, consider the context and complexity of the query, and aim to provide the most relevant and correct answers."
options_list = ["A", "B", "C", "D"]

correct = 0
total = 0

for example in tqdm(test_dataset):
    question = example['question']
    options = [
        f"A) {example['opa']}",
        f"B) {example['opb']}",
        f"C) {example['opc']}",
        f"D) {example['opd']}"
    ]
    correct_answer = chr(65 + example['cop'])  # Convert 0-3 to A-D
    print("Correct Answer:", correct_answer)

    # Build the prompt exactly as in training (no trailing "Answer:" cue), so the model
    # sees the same format it was fine-tuned on.
    options_str = "\n".join(options)
    input_text = f"Question: {question}\nOptions:\n{options_str}"
    messages = [
        {"from": "system", "value": system_message},
        {"from": "human", "value": input_text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to("cuda")

    outputs = model.generate(input_ids=inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; decode only the new tokens so text inside
    # the prompt can never be scored as the prediction.
    prompt_length = inputs.shape[-1]
    predicted_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Collect every letter the model committed to via "Answer: <letter>"; anything other
    # than exactly the one correct letter (none, or several) is scored as wrong.
    predicted_answer = "".join(
        option for option in options_list if f"Answer: {option}" in predicted_text
    )
    print("Predicted Answer:", predicted_answer)

    # Check accuracy
    if predicted_answer == correct_answer:
        correct += 1
        print("Accurately predicted")
    else:
        print("Prediction is wrong")

    total += 1

accuracy = correct / total if total else 0.0
print("Accuracy:", accuracy)



In [None]:
# Inference on the expert-labeled PubMedQA subset (pqa_labeled, loaded via its "train" split).
# Training used the separate pqa_artificial subset.

tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping={"role": "from", "content": "value", "user": "human", "assistant": "gpt", "system": "system"},  # ShareGPT style
    map_eos_token=True,  # Maps <|im_end|> onto the tokenizer's EOS token
)

# Switch Unsloth into inference mode.
model = FastLanguageModel.for_inference(model)

# Load evaluation dataset
test_dataset = load_dataset("qiaojin/PubMedQA", "pqa_labeled", split="train")

# Loop-invariant values: system prompt and answer choices identical to training,
# plus the label -> letter mapping used for scoring.
system_message = "You are a highly knowledgeable medical assistant with expertise in various medical domains. Your primary role is to provide accurate, evidence-based medical information and advice. Your responses should be grounded in the latest clinical guidelines and research, ensuring reliability and precision. When answering questions, consider the context and complexity of the query, and aim to provide the most relevant and correct answers."
options = [
    "A) yes",
    "B) no",
    "C) maybe"
]
options_str = "\n".join(options)
answer_to_option = {
    "yes": "A",
    "no": "B",
    "maybe": "C"
}
options_list = ["A", "B", "C"]

correct = 0
total = 0

for example in tqdm(test_dataset):
    context = " ".join(example['context']['contexts'])
    question = example['question']

    correct_answer = answer_to_option.get(example['final_decision'], "Unknown")
    print("Correct Answer:", correct_answer)

    # Build the prompt exactly as in training (no trailing "Answer:" cue), so the model
    # sees the same format it was fine-tuned on.
    input_text = f"Context: {context}\nQuestion: {question}\nOptions:\n{options_str}"
    messages = [
        {"from": "system", "value": system_message},
        {"from": "human", "value": input_text}
    ]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to("cuda")

    outputs = model.generate(input_ids=inputs, max_new_tokens=128)
    # generate() returns prompt + continuation; decode only the new tokens so text inside
    # the (long) context passage can never be scored as the prediction.
    prompt_length = inputs.shape[-1]
    predicted_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    # Collect every letter the model committed to via "Answer: <letter>"; anything other
    # than exactly the one correct letter (none, or several) is scored as wrong.
    predicted_answer = "".join(
        option for option in options_list if f"Answer: {option}" in predicted_text
    )
    print("Predicted Answer:", predicted_answer)

    # Check accuracy
    if predicted_answer == correct_answer:
        correct += 1
        print("Accurately predicted")
    else:
        print("Prediction is wrong")

    total += 1

accuracy = correct / total if total else 0.0
print("Accuracy:", accuracy)

