<a href="https://colab.research.google.com/github/boheling/healthAI/blob/main/ChatDoctor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Install necessary packages
!pip install transformers datasets trl peft --quiet
#!huggingface-cli login

from google.colab import drive
drive.mount('/content/drive')

from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig, get_peft_model, TaskType

# Base model: DeepSeek-R1-Distill-Qwen-1.5B
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Wrap the model with LoRA for efficient fine-tuning
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                  # LoRA rank
    lora_alpha=32,        # Scaling factor
    lora_dropout=0.1,     # Dropout for LoRA layers
    target_modules=["q_proj", "v_proj"]  # Adjust this list based on your model's architecture
)
model = get_peft_model(model, lora_config)

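# Save the freshly initialized LoRA adapter as the "pre-trained" baseline.
# PEFT's save_pretrained writes only the adapter weights and config, not the base model;
# LoRA's B matrix starts at zero, so this adapter initially behaves like the base model.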
model.save_pretrained("./pretrain_model")
tokenizer.save_pretrained("./pretrain_model")

# Load a portion of the dataset and split into training and evaluation
train_dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train[:8%]")
eval_dataset  = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train[8%:10%]")

# Define a tokenization function with structured formatting
def tokenize_function(examples):
    texts = []
    for i in range(len(examples["instruction"])):
        text = f"Instruction: {examples['instruction'][i]}\n"
        # Process 'input'
        if examples["input"][i]:
            if isinstance(examples["input"][i], list):
                input_text = " ".join(examples["input"][i])
            else:
                input_text = examples["input"][i]
            text += f"Input: {input_text}\n"
        # Process 'output'
        if examples["output"][i]:
            if isinstance(examples["output"][i], list):
                output_text = " ".join(examples["output"][i])
            else:
                output_text = examples["output"][i]
            text += f"Output: {output_text}\n"
        texts.append(text)
    return tokenizer(texts, truncation=True, max_length=128)  # anything past 128 tokens is cut off

# Tokenize the training dataset and remove the original columns
unfiltered_train = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["instruction", "input", "output"]
)
train_tokenized = unfiltered_train.filter(lambda x: len(x["input_ids"]) > 0)

# Tokenize the evaluation dataset and remove the original columns
unfiltered_eval = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["instruction", "input", "output"]
)
eval_tokenized = unfiltered_eval.filter(lambda x: len(x["input_ids"]) > 0)

# Define training arguments (checkpoints and logs go to the local Colab disk;
# only the final model is copied to Google Drive below)
training_args = TrainingArguments(
    output_dir="./sft_output",  # save checkpoints here
    logging_dir="./sft_logs",
    per_device_train_batch_size=6,    # Adjust as necessary
    num_train_epochs=2,               # Increase for more training
    logging_steps=10,
    save_steps=100,
    eval_strategy="steps",            # older transformers releases call this `evaluation_strategy`
    eval_steps=100,
    fp16=True,                        # Mixed precision for T4 GPU
    dataloader_num_workers=2,         # Adjust based on your CPU
)

# Initialize the SFTTrainer with the LoRA-wrapped model and separate datasets
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
)

# Evaluate the pre-trained model (baseline)
print("Evaluating pre-trained model...")
pretrain_metrics = trainer.evaluate()
print("Pre-training evaluation metrics:", pretrain_metrics)

# Sample prompt before training
prompt = "Q: What could be wrong with lower back pain in a cancer patient?\nA:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=100, do_sample=True, temperature=0.7, top_p=0.9)
print("Pre-training output:", tokenizer.decode(outputs[0], skip_special_tokens=True))

# Start fine-tuning
print("Starting fine-tuning...")
trainer.train()

# Save the fine-tuned LoRA adapter (adapter weights and config only, not the base model) to Google Drive
model.save_pretrained("/content/drive/MyDrive/SFT/posttrain_model_20")
tokenizer.save_pretrained("/content/drive/MyDrive/SFT/posttrain_model_20")

# Evaluate the fine-tuned model
print("Evaluating fine-tuned model...")
posttrain_metrics = trainer.evaluate()
print("Post-training evaluation metrics:", posttrain_metrics)

# Sample prompt after training
prompt = "Q: What could be wrong with lower back pain in a cancer patient?\nA:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_length=100, do_sample=True, temperature=0.7, top_p=0.9)
print("Post-training output:", tokenizer.decode(outputs[0], skip_special_tokens=True))

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Evaluating pre-trained model...


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mboheling[0m ([33mboheling-stanford-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Pre-training evaluation metrics: {'eval_loss': 4.322879314422607, 'eval_model_preparation_time': 0.0099, 'eval_runtime': 31.463, 'eval_samples_per_second': 71.29, 'eval_steps_per_second': 8.931}
Pre-training output: Q: What could be wrong with lower back pain in a cancer patient?
A: The lower back pain in a cancer patient is likely to be due to a number of factors, including:

1. **Cancer itself** - Cancer can have a significant impact on the lower back, often leading to issues such as pain, stiffness, or difficulty moving the legs.

2. **Infiltration of cancer cells into the lower back muscles** - This can cause localized pain, stiffness, or difficulty
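With `r=8` on `q_proj` and `v_proj`, LoRA trains only two low-rank matrices per adapted projection, A (r × d_in) and B (d_out × r), so each one adds r·(d_in + d_out) parameters, and the update is scaled by `lora_alpha / r` = 32 / 8 = 4. The sketch below does that arithmetic. The dimensions are assumptions about the Qwen2-style config of DeepSeek-R1-Distill-Qwen-1.5B (hidden size 1536, 28 layers, 2 key/value heads of size 128, so `v_proj` maps 1536 to 256) and should be checked against `model.config`; PEFT's `model.print_trainable_parameters()` reports the actual count.

```python
# Hypothetical helper: count LoRA parameters for a set of adapted linear layers.
def lora_param_count(rank, layer_shapes, num_layers):
    """Each adapted Linear(d_in -> d_out) gets A: (rank x d_in) and B: (d_out x rank)."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in layer_shapes)
    return per_layer * num_layers

# ASSUMED Qwen2-style dimensions; verify against model.config before relying on them.
HIDDEN = 1536
NUM_LAYERS = 28
KV_DIM = 2 * 128  # num_key_value_heads * head_dim

shapes = [
    (HIDDEN, HIDDEN),  # q_proj: 1536 -> 1536
    (HIDDEN, KV_DIM),  # v_proj: 1536 -> 256
]
rank, alpha = 8, 32
trainable = lora_param_count(rank, shapes, NUM_LAYERS)
print(f"LoRA trainable parameters: {trainable:,}")  # 1,089,536 under these assumptions
print(f"LoRA scaling (alpha / r): {alpha / rank}")  # 4.0
```

Under these assumptions roughly 1.1M parameters are trained, a small fraction of the 1.5B-parameter base model.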
Starting fine-tuning...


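| Step | Training Loss | Validation Loss | Model Preparation Time |
|-----:|--------------:|----------------:|-----------------------:|
|  100 |        3.3032 |        3.324415 |                 0.0099 |
|  200 |        3.0144 |        3.185761 |                 0.0099 |
|  300 |        3.0591 |        3.141333 |                 0.0099 |
|  400 |        3.0517 |        3.116447 |                 0.0099 |
|  500 |        3.1501 |        3.098642 |                 0.0099 |
|  600 |        3.0913 |        3.085376 |                 0.0099 |
|  700 |        3.0394 |        3.073951 |                 0.0099 |
|  800 |        3.0843 |        3.064441 |                 0.0099 |
|  900 |        2.9630 |        3.057035 |                 0.0099 |
| 1000 |        2.9903 |        3.049524 |                 0.0099 |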
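The loss table above stops at step 1000, but the training configuration implies a longer run. The sketch below estimates steps per epoch, total optimizer steps, and how many step-triggered checkpoints and evaluations `save_steps=100` and `eval_steps=100` produce. It assumes a single GPU, no gradient accumulation, and a 112,165-row train split (so `train[:8%]` is about 8,973 examples); none of these are stated in the table itself.

```python
import math

# ASSUMPTIONS: single device, gradient_accumulation_steps=1, 112,165-row train split.
TOTAL_ROWS = 112_165
TRAIN_FRACTION = 0.08          # split="train[:8%]"
BATCH_SIZE = 6                 # per_device_train_batch_size
EPOCHS = 2                     # num_train_epochs
SAVE_EVERY = EVAL_EVERY = 100  # save_steps / eval_steps

train_examples = round(TOTAL_ROWS * TRAIN_FRACTION)
steps_per_epoch = math.ceil(train_examples / BATCH_SIZE)  # a final partial batch still counts
total_steps = steps_per_epoch * EPOCHS
interval_saves = total_steps // SAVE_EVERY  # checkpoints at multiples of 100
interval_evals = total_steps // EVAL_EVERY  # evaluations at multiples of 100

print(train_examples, steps_per_epoch, total_steps, interval_saves, interval_evals)
```

Under these assumptions the run is about 2,992 steps, so the rows shown cover roughly the first third of training.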


Evaluating fine-tuned model...


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Post-training evaluation metrics: {'eval_loss': 2.9231958389282227, 'eval_model_preparation_time': 0.0099, 'eval_runtime': 24.7696, 'eval_samples_per_second': 90.555, 'eval_steps_per_second': 11.345}
Post-training output: Q: What could be wrong with lower back pain in a cancer patient?
A: Lower back pain in a cancer patient is not a common issue. It is possible that you may have lower back pain in a cancer patient, but it is not a typical symptom of cancer. Lower back pain in a cancer patient could be due to the effects of chemotherapy, radiation therapy, or other treatment methods. You should discuss with your doctor about your symptoms and medical history to determine whether you have a lower back

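`eval_loss` is the mean token-level cross-entropy in nats, so `exp(eval_loss)` converts it to perplexity, which is often easier to compare across runs. The sketch applies this to the two `eval_loss` values printed above for the same evaluation split, before and after fine-tuning.

```python
import math

def perplexity(mean_nll: float) -> float:
    """Perplexity from a mean negative log-likelihood measured in nats."""
    return math.exp(mean_nll)

pre_loss = 4.322879314422607    # eval_loss before fine-tuning
post_loss = 2.9231958389282227  # eval_loss after fine-tuning

pre_ppl = perplexity(pre_loss)
post_ppl = perplexity(post_loss)
print(f"perplexity: {pre_ppl:.1f} -> {post_ppl:.1f} ({pre_ppl / post_ppl:.1f}x lower)")
```

This gives roughly 75.4 before and 18.6 after fine-tuning. Lower perplexity means the model predicts the ChatDoctor-style text better; it does not by itself show that the answers are medically correct.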

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model_and_tokenizer(model_path):
    """
    Loads the model and tokenizer from the specified path.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path)
    return model, tokenizer

def generate_text(model, tokenizer, prompt, max_new_tokens=50):
    """
    Generates text using the provided model and tokenizer.
    """
    # Tokenize the prompt
    inputs = tokenizer(prompt, return_tensors="pt")
    # Use GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate text
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.7, top_p=0.9)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

def main():
    # Paths to the saved models
    pretrain_path = "./pretrain_model"
    posttrain_path = "/content/drive/MyDrive/SFT/posttrain_model_20"

    # Load models and tokenizers
    pre_model, pre_tokenizer = load_model_and_tokenizer(pretrain_path)
    post_model, post_tokenizer = load_model_and_tokenizer(posttrain_path)

    # Define a set of clinical questions for comparison
    clinical_questions = [
        "Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: What could be wrong with lower back pain in a cancer patient?\n Output:",
        "Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I have type 2 diabetes, how should I know when I am getting better or worse?\n Output:",
        "Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I heard there is a disease call diabetic ketoacidosis, what is that and what kind of tests I need to get to see whether I have it?\n Output:",
        "Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: What is the recommended treatment for acute myocardial infarction?\n Output:",
        "Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: My mom has pathophysiology of hypertension, what is it? please explain to me in layman's terms.\n Output:",
        "Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I am taking beta-blockers for a year.  I recently fell very fatigue every morning, could it be a side effects of beta-blockers?\n Output:"
    ]

    # Compare responses from both models
    for question in clinical_questions:
        # Keep only the text between "Input:" and "Output:" for display
        actual_question = question.split("Input:")[1].split("Output:")[0].strip()
        print(f"\n{'='*80}\nQuestion: {actual_question}\n")

        print("Pre-trained Model Response:")
        response_pre = generate_text(pre_model, pre_tokenizer, question, max_new_tokens=200)
        print(response_pre)

        print("\nPost-trained Model Response:")
        response_post = generate_text(post_model, post_tokenizer, question, max_new_tokens=200)
        print(response_post)

if __name__ == '__main__':
    main()


Question:  What could be wrong with lower back pain in a cancer patient?
 

Pre-trained Model Response:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: What could be wrong with lower back pain in a cancer patient?
 Output: The patient's lower back pain could be caused by several factors, including ... (e.g., [insert example]).

But the output should be in the format where the user can see the output as a list of possible factors. So, for example, the output could be something like:

The patient's lower back pain could be caused by several factors, including ... (e.g., [insert example]).

But the output should be in the format where the user can see the output as a list of possible factors. So, for example, the output could be something like:

The patient's lower back pain could be caused by several factors, including ... (e.g., [insert example]).

Wait, perhaps the user wants the output in a specific format, such as a list with bullet points, but in the format where the user can see it as a list. For example, the output sho

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: What could be wrong with lower back pain in a cancer patient?
 Output: Hi, It is possible that you are suffering from a herniated disc or a herniated disc in your lower back or maybe it is a result of your cancer.  You could have an issue with your spinal column or your nerves, which can cause pain.  You could also have spinal disc problems.  Please consult your doctor for a proper diagnosis.  Best wishes.

Question:  I have type 2 diabetes, how should I know when I am getting better or worse?
 

Pre-trained Model Response:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I have type 2 diabetes, how should I know when I am getting better or worse?
 Output: [Your answer here]

The user is a patient with type 2 diabetes, and they are asking when they are getting better or worse. I need to respond as a doctor. I should explain the factors that influence the progression of their condition.

First, I should identify the key factors that affect disease progression. These include the progression of the disease itself, the impact of lifestyle factors, and the role of medications.

For the progression of the disease, I should mention things like the development of new blood vessels, increased metabolic rates, and the accumulation of insulin resistance. These are signs that the disease is getting worse.

Next, lifestyle factors play a significant role. Things like diet, exercise, and smoking are crucial. For example, a diet high in sugar and fats can i

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I have type 2 diabetes, how should I know when I am getting better or worse?
 Output: Hi, I am Chat Doctor, I can understand your concern. Type 2 diabetes is a chronic disease and management of the disease is a challenging task. The patients with type 2 diabetes should be aware about their disease progression and take appropriate steps for the management. I would suggest that regular blood tests, blood sugar levels, glucose test, etc., should be done by the doctor. If you are doing regular blood tests, you can get an idea about your disease status and what you can do to manage the disease. If you are doing regular blood tests, you can get an idea about your disease status and what you can do to manage the disease. If you are doing regular blood tests, you can get an idea about your disease status and what you can do to manage the disease. If you are doing regular blood tests

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I heard there is a disease call diabetic ketoacidosis, what is that and what kind of tests I need to get to see whether I have it?
 Output: Please answer the medical questions based on the patient's description. Input: I have heard that there is a condition called...
Please answer the question.

The question is: I have heard that there is a condition called... What is it, and what tests should I do to check it?
</think>

The question is incomplete, as it only mentions that there is a condition called but does not provide further details. Please provide more information about the condition so I can assist you effectively.

Post-trained Model Response:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I heard there is a disease call diabetic ketoacidosis, what is that and what kind of tests I need to get to see whether I have it?
 Output: Hello,Diabetes ketoacidosis is a condition where your blood sugar level is very high due to the fact that your liver can't process the glucose properly. It's a type of diabetes and in the past it was thought to be a serious condition. However, in the past decade, it has been found to be a more common condition. The condition can be classified into two types: hyperglycemia (high blood sugar) and hyperlipidemia (high lipid level).Both are linked with diabetes. The tests that are done to find if you have diabetes ketoacidosis are: blood glucose level, liver enzymes, lipids level, and uric acid level. If you have any of these, you should get tested. If you have a high glucose level, your liver enzymes should be normal or elevated. If your bl

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: What is the recommended treatment for acute myocardial infarction?
 Output: The recommended treatment for acute myocardial infarction is to

Please note: The user is asking for a medical answer, so the answer should be factual and accurate.
</think>

Acute myocardial infarction, also known as myocardial infarction (MI), is a medical condition caused by a blockage of the coronary artery. The treatment for acute MI typically includes:

1. **Dietary Modification**: Consulting a dietitian to reduce inflammation, reduce the intake of high-sodium, high-sulfur, and high-sugar foods, and to lower cholesterol and blood pressure.

2. **Statin Therapy**: Regular application of statins, which help lower cholesterol and reduce the risk of heart disease and stroke.

3. **Anticoagulation Therapy**: In some cases, a blood transfusion may be necessary if the patient has a high risk of MI. Al

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: What is the recommended treatment for acute myocardial infarction?
 Output: Hello, I am Chat Doctor, I will try to answer your query. Acute myocardial infarction is a sudden death due to the death of blood vessels in the heart. The main reason for death is the collapse of the blood vessels. This can happen due to blockage in the blood vessels or due to the collapse of the heart muscle. Treatment for acute myocardial infarction is not available because of the sudden death. The treatment for heart attack is to stop the heart attack. The treatment is to stop the blood flow through the heart. For acute myocardial infarction, the only treatment is to stop the blood flow through the heart. In this case, the treatment is to stop the blood flow through the heart. The best treatment is to stop the blood flow through the heart. I hope I understand your query correctly. Looking forward

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: My mom has pathophysiology of hypertension, what is it? please explain to me in layman's terms.
 Output: Please provide a layman's explanation of the medical term "hypertension" as the answer.
</think>

Hypertension, or high blood pressure, is when your blood pressure is too high, meaning your blood is thickened. It can be caused by various factors like stress, diet, or medication. A high blood pressure can lead to health problems like high cholesterol, heart issues, and diabetes.

Post-trained Model Response:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: My mom has pathophysiology of hypertension, what is it? please explain to me in layman's terms.
 Output: Hi. Hypertension is a condition where your blood pressure is too high. It can be due to high salt intake, high cholesterol, high alcohol, or other factors. Your mom's hypertension could be due to her diet and lifestyle choices. You can try to follow your mom's diet and exercise to reduce her blood pressure. If your mom's hypertension is severe, she may need a blood pressure monitor.

Question:  I am taking beta-blockers for a year.  I recently fell very fatigue every morning, could it be a side effects of beta-blockers?
 

Pre-trained Model Response:


Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I am taking beta-blockers for a year.  I recently fell very fatigue every morning, could it be a side effects of beta-blockers?
 Output: [Your Answer]
</think>

Yes, beta-blockers can cause fatigue, particularly in the mornings, as they work to relax the muscles. This is a common side effect and may be related to the medication's effects on blood sugar levels and other hormones. If you're experiencing severe fatigue, it might be worth consulting a healthcare professional for further evaluation or treatment.

Post-trained Model Response:
Instruction: If you are a doctor, please answer the medical questions based on the patient's description. Input: I am taking beta-blockers for a year.  I recently fell very fatigue every morning, could it be a side effects of beta-blockers?
 Output: Hello, Welcome to Chat Doctor. Beta-blockers are taken by your doctor to reduce the side effec
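The prompts in `clinical_questions` put `Input:` on the same line as the instruction and end with `\n Output:` (leading space), whereas `tokenize_function` trained on the layout `Instruction: ...\nInput: ...\nOutput: ...`. A hypothetical `format_example` / `build_prompt` pair that reuses the training layout would keep inference prompts consistent with what the adapter saw; this is a sketch, not part of the original notebook.

```python
# Hypothetical helpers (not in the original notebook): reproduce the text layout
# used by tokenize_function so inference prompts match the training format.
INSTRUCTION = ("If you are a doctor, please answer the medical questions "
               "based on the patient's description.")

def format_example(instruction, input_text=None, output_text=None):
    """Build 'Instruction/Input/Output' text; with no output, end at 'Output:'."""
    # tokenize_function joins list-valued fields with spaces; mirror that here.
    if isinstance(input_text, list):
        input_text = " ".join(input_text)
    if isinstance(output_text, list):
        output_text = " ".join(output_text)
    text = f"Instruction: {instruction}\n"
    if input_text:
        text += f"Input: {input_text}\n"
    if output_text:
        text += f"Output: {output_text}\n"
    else:
        text += "Output:"  # generation prompt: the model continues from here
    return text

def build_prompt(question):
    """Inference prompt for a single patient question."""
    return format_example(INSTRUCTION, question)

print(build_prompt("What could be wrong with lower back pain in a cancer patient?"))
```

With prompts built this way, the displayed question is just the string passed to `build_prompt`, so no slicing of the prompt is needed. Separately, passing `pad_token_id=tokenizer.eos_token_id` to `model.generate` should stop the repeated "Setting `pad_token_id` to `eos_token_id`" message.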