In [None]:
!pip install -qqq "xformers<0.0.24"

In [None]:
!pip install -qqq "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

In [None]:
!pip install -qqq --no-deps "trl<0.9.0" peft accelerate bitsandbytes

In [None]:
import torch
from trl import SFTTrainer
from datasets import Dataset, load_dataset, concatenate_datasets
from transformers import TrainingArguments, TextStreamer
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel, is_bfloat16_supported
from tqdm import tqdm
import json
import pandas as pd


In [None]:
# Load the base model and its tokenizer through Unsloth, quantized to 4 bits.
max_seq_length = 2048
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="Dasun7/Med-mix-Llama-3.1-8B",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    dtype=None,
)

# Wrap the model with LoRA adapters (rank 16, rank-stabilised scaling) on the
# attention and MLP projection layers.
# NOTE: this rebinds `model`, so re-running the cell without reloading would
# wrap an already-wrapped model; re-run from the top of the cell only.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj", "o_proj", "gate_proj"],
    use_rslora=True,
    use_gradient_checkpointing="unsloth",
    random_state=0,
)
# print_trainable_parameters() writes its own summary and returns None, so
# wrapping it in print() only appended a stray "None" line to the output.
model.print_trainable_parameters()

In [None]:
# Field and speaker names used by the ShareGPT-style records built in this
# notebook ("from"/"value" keys, "human"/"gpt"/"system" speakers).
sharegpt_mapping = {
    "role": "from",
    "content": "value",
    "user": "human",
    "assistant": "gpt",
    "system": "system",
}

tokenizer = get_chat_template(
    tokenizer,
    chat_template="chatml",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
    mapping=sharegpt_mapping,  # ShareGPT style
    map_eos_token=True,  # Maps <|im_end|> to </s> instead
)

def apply_template(examples):
    """Render a batch of conversations into chat-template text.

    Every conversation in ``examples["conversations"]`` is prefixed with a
    fixed system turn and handed to the module-level
    ``tokenizer.apply_chat_template`` with ``tokenize=False`` and
    ``add_generation_prompt=False``.

    Args:
        examples: batch mapping with a ``"conversations"`` entry holding one
            list of ``{"from": ..., "value": ...}`` turns per example.

    Returns:
        ``{"text": [...]}`` with one rendered item per conversation, in input
        order.
    """
    # Adjacent literals are concatenated by Python into one prompt string.
    system_turn = {
        "from": "system",
        "value": (
            "You are a highly knowledgeable medical assistant with expertise in various medical domains. "
            "Your primary role is to provide accurate, evidence-based medical information and advice. "
            "Your responses should be grounded in the latest clinical guidelines and research, ensuring reliability and precision. "
            "When answering questions, consider the context and complexity of the query, and aim to provide the most relevant and correct answers."
        ),
    }

    rendered = []
    for turns in examples["conversations"]:
        # The system turn always goes first, followed by the original turns.
        rendered.append(
            tokenizer.apply_chat_template(
                [system_turn] + turns,
                tokenize=False,
                add_generation_prompt=False,
            )
        )
    return {"text": rendered}



In [None]:

# Load the "train" split of the "akemiH/NoteChat" dataset via load_dataset.
note_chat_dataset = load_dataset("akemiH/NoteChat", split="train")

In [None]:
# Convert the Hugging Face dataset to a pandas DataFrame.
# Dataset.to_pandas() converts the whole table in one call; passing the
# dataset to pd.DataFrame() instead iterates it row by row and builds one
# Python dict per record.
df2 = note_chat_dataset.to_pandas()

In [None]:
# Convert a NoteChat transcript into the ShareGPT-style turn list used by the
# other datasets in this notebook.
def convert_conversation(conversation):
    """Split a transcript string into ``{"from", "value"}`` turns.

    The transcript is processed line by line:

    * blank lines and lines starting with ``"["`` are skipped;
    * a line starting with ``"Doctor:"`` becomes a ``"gpt"`` turn holding the
      text after that leading label;
    * any other line becomes a ``"human"`` turn; when the line contains
      ``": "`` everything up to the first ``": "`` is treated as a speaker
      label and dropped, otherwise the whole line is kept.

    Args:
        conversation: transcript text, one utterance per line.

    Returns:
        A list of ``{"from": "gpt" | "human", "value": str}`` dicts in
        transcript order (empty for an empty transcript).
    """
    doctor_prefix = "Doctor:"
    turns = []

    for raw_line in conversation.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("["):
            continue

        if line.startswith(doctor_prefix):
            # Remove only the leading label. str.replace() would also delete
            # any later "Doctor:" inside the utterance and corrupt the text.
            role, text = "gpt", line[len(doctor_prefix):]
        else:
            # Anything that is not a doctor line is attributed to the human.
            _speaker, separator, remainder = line.partition(": ")
            role, text = "human", (remainder if separator else line)

        turns.append({"from": role, "value": text.strip()})

    return turns


# Build the ShareGPT-style "conversations" column from the raw "conversation"
# transcripts, then drop the raw "conversation" and "data" columns.
# The result is bound to a new name rather than written back into df2, so the
# cell can be re-run without a KeyError on columns that were already dropped.
notechat_df = (
    df2
    .assign(conversations=df2["conversation"].apply(convert_conversation))
    .drop(columns=["conversation", "data"])
)

converted_notechat_dataset = Dataset.from_pandas(notechat_df)

print(converted_notechat_dataset)

In [None]:
# Run apply_template over the dataset in batches; it returns a "text" entry
# with the chat-template rendering of each conversation.
converted_notechat_dataset = converted_notechat_dataset.map(apply_template, batched=True)

In [None]:
# Show the dataset summary, then one converted record: its turn list and the
# rendered chat-template text.
print(converted_notechat_dataset)
sample_row = converted_notechat_dataset[92795]
print(sample_row["conversations"])
print(sample_row["text"])

In [None]:
# Shuffle once with a fixed seed so the two slices below are reproducible.
shuffled_notechat_dataset = converted_notechat_dataset.shuffle(seed=42)

# Rows 0..99,999 of the shuffled data form the training set and the following
# 10,000 rows (100,000..109,999) form the validation set; the ranges are disjoint.
notechat_train_indices = range(0, 100000)
notechat_val_indices = range(100000, 110000)

notechat_train_dataset = shuffled_notechat_dataset.select(notechat_train_indices)
notechat_val_dataset = shuffled_notechat_dataset.select(notechat_val_indices)

for notechat_split in (notechat_train_dataset, notechat_val_dataset):
    print(notechat_split)

In [None]:
# Preparing the MedQA dataset

medqa_dataset = load_dataset("GBaker/MedQA-USMLE-4-options-hf", split="train")

# Column names of the four candidate answers, in A-D order.
ending_keys = ("ending0", "ending1", "ending2", "ending3")

converted_conversations = []

for row in medqa_dataset:
    # Prefix each candidate answer with its option letter.
    choices = [f"{letter}) {row[key]}" for letter, key in zip("ABCD", ending_keys)]

    # "label" is used as the index of the correct choice.
    picked = choices[row["label"]]
    reply = f"Answer: {picked}"

    stem = row["sent1"]
    choice_block = "\n".join(choices)
    prompt = f"Question: {stem}\nOptions:\n{choice_block}"

    # One human question turn followed by one gpt answer turn.
    converted_conversations.append(
        [
            {"from": "human", "value": prompt},
            {"from": "gpt", "value": reply},
        ]
    )

formatted_medqa_dataset = Dataset.from_dict(
    {"conversations": converted_conversations}
).map(apply_template, batched=True)


In [None]:
# Preparing the MedMCQA dataset

med_mcqa_dataset = load_dataset("openlifescienceai/medmcqa", split="train")

# Column names of the four candidate answers, in A-D order.
option_keys = ("opa", "opb", "opc", "opd")

converted_conversations = []
skipped_rows = 0

for example in med_mcqa_dataset:
    question = example["question"]
    options = [f"{letter}) {example[key]}" for letter, key in zip("ABCD", option_keys)]

    # "cop" is used as the index of the correct option. A value outside 0-3
    # (for instance a -1 "no label" placeholder; TODO confirm whether this
    # split contains any) would be accepted by Python's negative indexing and
    # silently mark D) as the correct answer, so such rows are skipped.
    answer_index = example["cop"]
    if answer_index not in range(len(options)):
        skipped_rows += 1
        continue

    correct_answer = f"Answer: {options[answer_index]}"

    options_str = "\n".join(options)
    input_text = f"Question: {question}\nOptions:\n{options_str}"

    # One human question turn followed by one gpt answer turn.
    conversation = [
        {"from": "human", "value": input_text},
        {"from": "gpt", "value": correct_answer},
    ]
    converted_conversations.append(conversation)

if skipped_rows:
    print(f"MedMCQA: skipped {skipped_rows} rows with an out-of-range 'cop' label")

formatted_med_mcqa_dataset = Dataset.from_dict({"conversations": converted_conversations})

formatted_med_mcqa_dataset = formatted_med_mcqa_dataset.map(apply_template, batched=True)

In [None]:
# Preparing the PubMedQA dataset

pub_medqa_dataset = load_dataset("qiaojin/PubMedQA", "pqa_artificial", split="train")

# These do not depend on the example, so they are built once instead of on
# every loop iteration. Dict order gives the A/B/C order shown in the prompt.
answer_to_option = {
    "yes": "A) yes",
    "no": "B) no",
    "maybe": "C) maybe",
}
options_str = "\n".join(answer_to_option.values())

converted_conversations = []
skipped_rows = 0

for example in pub_medqa_dataset:
    # The context field holds a list of passages; join them into one string.
    context = " ".join(example["context"]["contexts"])
    question = example["question"]

    option = answer_to_option.get(example["final_decision"])
    if option is None:
        # A decision outside yes/no/maybe previously produced the target
        # "Answer: Unknown", which is not one of the options offered in the
        # prompt; skip the row rather than train on that.
        skipped_rows += 1
        continue

    correct_answer = f"Answer: {option}"
    input_text = f"Context: {context}\nQuestion: {question}\nOptions:\n{options_str}"

    # One human question turn followed by one gpt answer turn.
    conversation = [
        {"from": "human", "value": input_text},
        {"from": "gpt", "value": correct_answer},
    ]
    converted_conversations.append(conversation)

if skipped_rows:
    print(f"PubMedQA: skipped {skipped_rows} rows with an unrecognised 'final_decision'")

formatted_pub_medqa_dataset = Dataset.from_dict({"conversations": converted_conversations})

formatted_pub_medqa_dataset = formatted_pub_medqa_dataset.map(apply_template, batched=True)

In [None]:
# Concatenate the three multiple-choice QA datasets and shuffle them together.
combined_dataset = concatenate_datasets(
    [formatted_medqa_dataset, formatted_med_mcqa_dataset, formatted_pub_medqa_dataset]
).shuffle(seed=42)

print("combined_dataset:", combined_dataset)

# Hold out 10% of the combined data for validation (fixed seed).
train_val_split = combined_dataset.train_test_split(test_size=0.1, seed=42)
medmix_complete_train_dataset = train_val_split["train"]
medmix_complete_val_dataset = train_val_split["test"]

# Shuffle each side again with the specified seed before sub-sampling.
shuffled_medmix_complete_train_dataset = medmix_complete_train_dataset.shuffle(seed=42)
shuffled_medmix_complete_val_dataset = medmix_complete_val_dataset.shuffle(seed=42)

# Keep the first 33,333 rows for training and the first 3,333 rows for validation.
medmix_train_dataset = shuffled_medmix_complete_train_dataset.select(range(33333))
medmix_val_dataset = shuffled_medmix_complete_val_dataset.select(range(3333))

for medmix_split in (medmix_train_dataset, medmix_val_dataset):
    print(medmix_split)


In [None]:
# Merge the NoteChat and QA-mix training subsets, then shuffle (fixed seed) so
# the two sources are interleaved.
final_train_dataset = concatenate_datasets(
    [notechat_train_dataset, medmix_train_dataset]
).shuffle(seed=42)

print(final_train_dataset)

In [None]:
# Merge the NoteChat and QA-mix validation subsets, then shuffle (fixed seed).
final_val_dataset = concatenate_datasets(
    [notechat_val_dataset, medmix_val_dataset]
).shuffle(seed=42)

print(final_val_dataset)

In [None]:
# Probe bf16 support once; fp16 is enabled only when bf16 is not available.
use_bf16 = is_bfloat16_supported()

# Optimisation, logging and evaluation settings for a single epoch.
# Per-device batch of 4 with 4 gradient-accumulation steps; evaluation and
# logging both run every 360 steps.
training_args = TrainingArguments(
    output_dir="output",
    seed=0,
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    lr_scheduler_type="linear",
    warmup_steps=10,
    optim="adamw_8bit",
    weight_decay=0.01,
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=360,
    eval_strategy="steps",
    eval_steps=360,
)

# Supervised fine-tuning on the pre-rendered "text" column with packing enabled.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=final_train_dataset,
    eval_dataset=final_val_dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=True,
)

trainer.train()
# Below is how to resume training if you are interrupted in the middle of the fine-tuning process.
# trainer.train(resume_from_checkpoint=True)

In [None]:
# Push the model and tokenizer to the "Dasun7/Med-nirvana-v1-8B" repo with
# save_method="merged_16bit" (presumably the adapters merged into 16-bit
# weights; confirm against the Unsloth docs). Requires Hub write credentials.
model.push_to_hub_merged("Dasun7/Med-nirvana-v1-8B", tokenizer, save_method="merged_16bit")