<a href="https://colab.research.google.com/github/mudogruer/SLMs/blob/main/med_mixtral.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **MedMixtral: LLM Fine-Tuning with Predibase**

This quickstart shows you how to prompt, fine-tune, and deploy LLMs in Predibase. We follow a medical question-answering use case: the end result is a fine-tuned Mixtral model that takes a medical question (with optional context) as input and returns the answer as output.

In [5]:
!pip install -U predibase --quiet
!pip install -q -U transformers bert-score evaluate


# **Setup**

You'll first need to initialize your PredibaseClient object and configure your API token.

In [2]:
from predibase import PredibaseClient

pc = PredibaseClient(token="API_KEY")
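
A token typed directly into a notebook cell is easy to leak when the notebook is shared. As a minimal sketch, assuming the token is stored in an environment variable called `PREDIBASE_API_TOKEN` (a name chosen here for illustration, not taken from this notebook), it could be read like this before creating the client:

```python
import os


def read_token(env=None, var_name="PREDIBASE_API_TOKEN"):
    """Return the API token from the environment, or raise a clear error.

    `var_name` is a hypothetical variable name chosen for this sketch.
    """
    env = os.environ if env is None else env
    token = env.get(var_name, "").strip()
    if not token:
        raise RuntimeError(f"Set {var_name} before creating the client.")
    return token


# Usage, once the variable is set:
# pc = PredibaseClient(token=read_token())
```

Passing `env` explicitly is only there to make the helper easy to test; in the notebook the default `os.environ` is what you would use.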

# **Prompt a deployed LLM**

For our medical question-answering use case, let's first see how Mixtral performs out of the box.

If you are in the Predibase SaaS environment, you have access to shared [serverless LLM deployments](https://docs.predibase.com/ui-guide/llms/query-llm/shared_deployments), including the Mixtral 8x7B Instruct deployment used below.

If you are in a VPC environment, you'll need to first [deploy a pretrained LLM](https://docs.predibase.com/user-guide/inference/dedicated_deployments#pretrained-llm-deployment).

In [6]:
llm_deployment = pc.LLM("pb://deployments/mixtral-8x7b-instruct-v0-1")
result = llm_deployment.prompt("""
    Answer the following question shortly.

    ### Question: Malaria relapse common with which type plasmodium species?

    ### Answer:
""", max_new_tokens=256)
print(result.response)

Malaria relapse is most commonly associated with the Plasmodium vivax and Plasmodium ovale species. These species can remain dormant in the liver for extended periods, leading to relapses even after initial treatment and apparent recovery.


# **Fine-tune a pretrained LLM**

Next we'll upload a dataset and fine-tune to see if we can get better performance.

The [MedMCQA](https://github.com/medmcqa/medmcqa) dataset is a collection of multiple-choice medical exam questions. The version used here for fine-tuning includes the following columns:

- `question`: the medical question
- `exp`: an explanation or additional context for the question, which may be empty
- `answer`: the expected output, that is, the text of the correct option


For the sake of this quickstart, we've created a version of the MedMCQA dataset with fewer rows so that the model trains significantly faster.

**Now we will perform the following actions to start our fine-tuning job:**
1. Upload the dataset to Predibase for training
2. Create a prompt template to use for fine-tuning
3. Select the LLM we want to fine-tune
4. Kick off the fine-tuning job


In [None]:
# Upload the dataset to Predibase (estimated time: 2 minutes, because Predibase builds a dataset profile):
# dataset = pc.upload_dataset("xzy.csv")
# If you've already uploaded the dataset, skip the upload and fetch it directly:
dataset = pc.get_dataset("med_train", "file_uploads")

In [None]:
dataset

Dataset(id=11346, name=med_train, object_name=c99f31ad5a104e748682f8cd72c340d4, connection_id=6363, author=mustafa.dogruer@iu-study.org, created=2024-03-21T21:06:02.125022Z, updated=2024-03-21T21:06:02.125022Z)

In [None]:
# Define the template used to prompt the model for each example
# Note the 4-space indentation, which is necessary for the YAML templating.
prompt_template = """
    Given a passage, you need to accurately identify and extract relevant spans of text that answer specific questions. Provide concise and coherent responses based on the information present in the passage as well as a reasonable coherent explanation for your response.
    ### Passage: {exp}

    ### Question: {question}

    ### Answer:
"""

# Specify the Huggingface LLM you want to fine-tune
# Kick off a fine-tuning job on the uploaded dataset
llm = pc.LLM("hf://mistralai/Mixtral-8x7B-Instruct-v0.1")
job = llm.finetune(
    prompt_template=prompt_template,
    target="answer",
    dataset=dataset,
    repo="med_mixtral"
)

# Wait for the job to finish and get training updates and metrics
model = job.get()

✓ Queued 0:02:46   
✓ Preprocessing 0:05:32   


┌──────────┬──────────┬──────────────────┬──────────────────────────┬──────────┬──────────┬──────────┐
│  epochs  │   time   │     feature      │          metric          │  train   │   val    │   test   │
├──────────┼──────────┼──────────────────┼──────────────────────────┼──────────┼──────────┼──────────┤
│ 3864/173681 steps ■■□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□□ │

# **Prompt your fine-tuned LLM**

Predibase supports both real-time inference, as well as [batch inference](https://docs.predibase.com/user-guide/inference/batch_prediction).

#### **Real-time inference using _LoRAX_** (Recommended)

[LoRA eXchange (LoRAX)](https://predibase.com/blog/lorax-the-open-source-framework-for-serving-100s-of-fine-tuned-llms-in) allows you to prompt your fine-tuned LLM without needing to create a new deployment for each model you want to prompt. Predibase automatically loads your fine-tuned weights on top of a shared LLM deployment on demand. While this means that there will be a small amount of additional latency, the benefit is that a single LLM deployment can support many different fine-tuned model versions without requiring additional compute.

Note: Inference using dynamic adapter deployments is available to both SaaS and VPC users. Predibase provides shared [serverless base LLM deployments](https://docs.predibase.com/user-guide/inference/serverless_deployments) for use in our SaaS environment. VPC users need to [deploy their own base model](https://docs.predibase.com/user-guide/inference/dedicated_deployments#pretrained-llm-deployment).

In [3]:
# Since our model was fine-tuned from a Mixtral-8x7B-Instruct base, we'll use the shared deployment with the same model type.
base_deployment = pc.LLM("pb://deployments/mixtral-8x7b-instruct-v0-1")

# Now we just specify the adapter to use, which is the model we fine-tuned.
model = pc.get_model("med_mixtral")
adapter_deployment = base_deployment.with_adapter(model)



In [4]:
question_exp = "This is a single choice question. You need to choose one of those options: 1- Leukemoid reaction, 2- Leukopenia, 3- Myeloid metaplasia, 4- Neutrophilia. Which one is true?"
question = "A 40-year-old man presents with 5 days of productive cough and fever. Pseudomonas aeruginosa is isolated from a pulmonary abscess. CBC shows an acute effect characterized by marked leukocytosis (50,000 mL) and the differential count reveals a shift to left in granulocytes. Which of the following terms best describes these hematologic findings?"

In [5]:
# Recall that our model was fine-tuned using a template that accepts an {exp}
# and a {question}. This template is automatically applied when prompting.
result = adapter_deployment.prompt(
    {"exp": question_exp,
    "question": question},
    max_new_tokens=256)

print(result.response)

Neutrophilia
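
To see roughly what text the model receives once the template is applied, the template can be rendered locally. This is only a sketch: it assumes the placeholders are filled by plain string substitution (Python's `str.format` here), which this notebook does not confirm, and it uses a shortened instruction line in place of the full one from the fine-tuning cell.

```python
# Shortened copy of the fine-tuning template; the instruction line is abbreviated.
PROMPT_TEMPLATE = """
    Given a passage, answer the question and explain your response.
    ### Passage: {exp}

    ### Question: {question}

    ### Answer:
"""


def render_prompt(template, fields):
    """Fill the {exp} and {question} placeholders with plain str.format."""
    return template.format(**fields)


example = {
    "exp": "Options: 1- Leukemoid reaction, 2- Leukopenia",
    "question": "Which term best describes these findings?",
}
print(render_prompt(PROMPT_TEMPLATE, example))
```

Printing the rendered prompt for a few rows is a quick way to spot empty or malformed fields before sending requests.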


In [7]:
import pandas as pd

dataset_test = pd.read_csv("med_test.csv")
dataset_test.head()

Unnamed: 0,question,exp,cop,opa,opb,opc,opd,subject_name,topic_name,id,choice_type,answer
0,Which of the following is not true for myelina...,,1,Impulse through myelinated fibers is slower th...,Membrane currents are generated at nodes of Ra...,Saltatory conduction of impulses is seen,Local anesthesia is effective only when the ne...,Physiology,,45258d3d-b974-44dd-a161-c3fccbdadd88,multi,Impulse through myelinated fibers is slower th...
1,Which of the following is not true about glome...,Ans-a. The oncotic pressure of the fluid leavi...,1,The oncotic pressure of the fluid leaving the ...,Glucose concentration in the capillaries is th...,Constriction of afferent aeriole decreases the...,Hematocrit of the fluid leaving the capillarie...,Physiology,,b944ada9-d776-4c2a-9180-3ae5f393f72d,multi,The oncotic pressure of the fluid leaving the ...
2,A 29 yrs old woman with a pregnancy of 17 week...,,3,No test is required now as her age is below 35...,Ultra sound at this point of time will definit...,Amniotic fluid samples plus chromosomal analys...,blood screening at this point of time will cle...,Medicine,,b64a9cd7-d076-4c55-8be1-f9c44fece6cc,single,Amniotic fluid samples plus chromosomal analys...
3,Axonal transport is:,Fast anterograde (400 mm/day) transport occurs...,3,Antegrade,Retrograde,Antegrade and retrograde,,Physiology,,c6365cce-507c-40f6-90a2-46b867f47b6e,multi,Antegrade and retrograde
4,Low insulin to glucagon ratio is seen in all o...,Answer- A. Glycogen synthesisLow insulin to gl...,1,Glycogen synthesis,Glycogen breakdown,Gluconeogenesis,Ketogenesis,Biochemistry,,72c1c5e0-b64f-4eef-bf22-ecfb60c5c19c,multi,Glycogen synthesis
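
In the rows above, the `answer` column matches the option selected by `cop`, counted from 1 (`cop=1` gives `opa`, `cop=3` gives `opc`). The sketch below reproduces that mapping for a single row; the 1-based indexing is inferred from these five sample rows only and may not hold for other versions of the dataset.

```python
OPTION_COLUMNS = ("opa", "opb", "opc", "opd")


def answer_from_row(row):
    """Return the text of the correct option, assuming `cop` is 1-based."""
    index = int(row["cop"]) - 1  # assumption inferred from the sample rows
    if not 0 <= index < len(OPTION_COLUMNS):
        raise ValueError(f"cop out of range: {row['cop']}")
    return row[OPTION_COLUMNS[index]]


row = {
    "cop": 3,
    "opa": "Antegrade",
    "opb": "Retrograde",
    "opc": "Antegrade and retrograde",
    "opd": None,
}
print(answer_from_row(row))  # Antegrade and retrograde
```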


In [8]:
data_subset = dataset_test[:1000]

In [9]:
data_subset.shape

(1000, 12)

In [19]:
# Preview the prompt payloads for the first few rows
preview = data_subset.head()
for ind in preview.index:
    d = {"exp": preview['exp'][ind], "question": preview['question'][ind]}
    print(d)

{'exp': nan, 'question': 'Which of the following is not true for myelinated nerve fibers:'}
{'exp': 'Ans-a. The oncotic pressure of the fluid leaving the capillaries is less than that of fluid entering it Guyton I LpJ1 4-.;anong 23/e p653-6_)Glomerular oncotic pressure (due to plasma protein content) is higher than that of filtrate oncotic pressure in Bowman\'s capsule"Since glucose is freely filtered and the fluid in the Bowman\'s capsule is isotonic with plasma, the concentration of glucose in the filtrate is the same as in the capillaries', 'question': "Which of the following is not true about glomerular capillaries')"}
{'exp': nan, 'question': 'A 29 yrs old woman with a pregnancy of 17 week has a 10 years old boy with down syndrome. She does not want another down syndrome kid; best advice to her is'}
{'exp': 'Fast anterograde (400 mm/day) transport occurs by kinesin molecular motor and retrograde transport (200 mm/day) occurs by dynein molecular motor.', 'question': 'Axonal transpo

In [10]:
# Run inference on the test subset, one row at a time
answers = []
for ind in data_subset.index:
    prompt = {"exp": data_subset['exp'][ind], "question": data_subset['question'][ind]}
    answer = adapter_deployment.prompt(prompt, max_new_tokens=256)
    answers.append(answer.response)

In [11]:
answers[:10]

['They are not surrounded by Schwann cells',
 'The oncotic pressure of the fluid leaving the capillaries is less than that of fluid entering it',
 'Nan',
 'Bidirectional',
 'Glycogen synthesis',
 '0.01',
 'Pregnant woman with sore throat can be stated immediately on oseltamivir without diagnostic testing under category B',
 'Anterior ethmoidal aery',
 'Broad QRS complex with normal sinus rhythm',
 'Pulmonary atresia']
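
The third prediction is the literal string `'Nan'`, and the corresponding test row has a missing `exp`. One possible cause is that the missing value is passed to the model as `nan`. A minimal sketch of a cleaning step, assuming an empty passage is an acceptable substitute for a missing one (the helper names are hypothetical):

```python
import math


def clean_field(value):
    """Map None and float NaN to an empty string; otherwise return stripped text."""
    if value is None:
        return ""
    if isinstance(value, float) and math.isnan(value):
        return ""
    return str(value).strip()


def build_prompt(exp, question):
    """Build the payload with the two template fields, with missing values blanked."""
    return {"exp": clean_field(exp), "question": clean_field(question)}


print(build_prompt(float("nan"), "Axonal transport is:"))
# {'exp': '', 'question': 'Axonal transport is:'}
```

Whether an empty passage actually improves the answers for such rows would need to be checked against the references.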

In [12]:
reference = data_subset["answer"]
reference[:10]

0    Impulse through myelinated fibers is slower th...
1    The oncotic pressure of the fluid leaving the ...
2    Amniotic fluid samples plus chromosomal analys...
3                             Antegrade and retrograde
4                                   Glycogen synthesis
5                                                 0.01
6    Pregnant woman with sore throat can be staed i...
7                              Anterior ethmoidal aery
8                                 Electrical alternans
9                                    Pulmonary atresia
Name: answer, dtype: object

In [13]:
reference = list(data_subset["answer"])
reference[:10]

['Impulse through myelinated fibers is slower than non-myelinated fibers',
 'The oncotic pressure of the fluid leaving the capillaries is less than that of fluid entering it',
 'Amniotic fluid samples plus chromosomal analysis will definitely tell her that next baby will be down syndromic or not',
 'Antegrade and retrograde',
 'Glycogen synthesis',
 '0.01',
 'Pregnant woman with sore throat can be staed immediately on oseltamivir without diagnostic testing under category B',
 'Anterior ethmoidal aery',
 'Electrical alternans',
 'Pulmonary atresia']

In [14]:
#evaluate with bert-score using distilbert model
from evaluate import load
import numpy as np
bertscore = load("bertscore")
predictions = answers
references = list(data_subset["answer"])
results = bertscore.compute(predictions=predictions, references=references, model_type="distilbert-base-uncased")
print("precision: ",round(np.mean(list(results["precision"])),5))
print("recall: ",round(np.mean(list(results["recall"])),5))
print("f1: ",round(np.mean(list(results["f1"])),5))

precision:  0.82437
recall:  0.80586
f1:  0.8142
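
BERTScore rewards semantic similarity, so a wrong option can still score fairly high. Because the references here are option texts, a stricter complementary check is normalized exact match. The sketch below is one way to compute it; the normalization (lowercasing and collapsing whitespace) is a choice made for this example, and the four pairs are taken from the predictions and references printed above.

```python
def normalize(text):
    """Lowercase and collapse whitespace so trivial differences do not count."""
    return " ".join(str(text).lower().split())


def exact_match_accuracy(predictions, references):
    """Fraction of predictions equal to their reference after normalization."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not predictions:
        return 0.0
    hits = sum(normalize(p) == normalize(r) for p, r in zip(predictions, references))
    return hits / len(predictions)


sample_predictions = ["Bidirectional", "Glycogen synthesis", "0.01",
                      "Broad QRS complex with normal sinus rhythm"]
sample_references = ["Antegrade and retrograde", "Glycogen synthesis", "0.01",
                     "Electrical alternans"]
print(exact_match_accuracy(sample_predictions, sample_references))  # 0.5
```

Note that exact match penalizes correct paraphrases such as "Bidirectional" for "Antegrade and retrograde", so the two metrics are best read together.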




In [21]:
print("precision std: ",round(np.std(list(results["precision"])),5))
print("recall std: ",round(np.std(list(results["recall"])),5))
print("f1 std: ",round(np.std(list(results["f1"])),5))

precision std:  0.1571
recall std:  0.17338
f1 std:  0.16437
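
The standard deviations describe the spread of per-example scores, not the uncertainty of the mean. As a rough sketch, the standard error of the mean F1 over the 1,000 examples can be derived from the numbers above, assuming the examples are independent and a normal approximation is adequate:

```python
import math


def standard_error(std, n):
    """Standard error of the mean for n independent observations."""
    if n <= 0:
        raise ValueError("n must be positive")
    return std / math.sqrt(n)


f1_mean, f1_std, n = 0.8142, 0.16437, 1000  # values reported above
sem = standard_error(f1_std, n)
half_width = 1.96 * sem  # approximate 95% interval under a normal approximation
print(f"f1 = {f1_mean:.4f} +/- {half_width:.4f}")
```

With these inputs the interval half-width comes out at roughly 0.01, so the mean itself is estimated fairly tightly even though individual scores vary widely.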


In [16]:
# assign() returns a new DataFrame, which avoids pandas' SettingWithCopyWarning on the slice
data_subset = data_subset.assign(predicted_answer=answers)


In [17]:
data_subset.head()

Unnamed: 0,question,exp,cop,opa,opb,opc,opd,subject_name,topic_name,id,choice_type,answer,predicted_answer
0,Which of the following is not true for myelina...,,1,Impulse through myelinated fibers is slower th...,Membrane currents are generated at nodes of Ra...,Saltatory conduction of impulses is seen,Local anesthesia is effective only when the ne...,Physiology,,45258d3d-b974-44dd-a161-c3fccbdadd88,multi,Impulse through myelinated fibers is slower th...,They are not surrounded by Schwann cells
1,Which of the following is not true about glome...,Ans-a. The oncotic pressure of the fluid leavi...,1,The oncotic pressure of the fluid leaving the ...,Glucose concentration in the capillaries is th...,Constriction of afferent aeriole decreases the...,Hematocrit of the fluid leaving the capillarie...,Physiology,,b944ada9-d776-4c2a-9180-3ae5f393f72d,multi,The oncotic pressure of the fluid leaving the ...,The oncotic pressure of the fluid leaving the ...
2,A 29 yrs old woman with a pregnancy of 17 week...,,3,No test is required now as her age is below 35...,Ultra sound at this point of time will definit...,Amniotic fluid samples plus chromosomal analys...,blood screening at this point of time will cle...,Medicine,,b64a9cd7-d076-4c55-8be1-f9c44fece6cc,single,Amniotic fluid samples plus chromosomal analys...,Nan
3,Axonal transport is:,Fast anterograde (400 mm/day) transport occurs...,3,Antegrade,Retrograde,Antegrade and retrograde,,Physiology,,c6365cce-507c-40f6-90a2-46b867f47b6e,multi,Antegrade and retrograde,Bidirectional
4,Low insulin to glucagon ratio is seen in all o...,Answer- A. Glycogen synthesisLow insulin to gl...,1,Glycogen synthesis,Glycogen breakdown,Gluconeogenesis,Ketogenesis,Biochemistry,,72c1c5e0-b64f-4eef-bf22-ecfb60c5c19c,multi,Glycogen synthesis,Glycogen synthesis


In [19]:
data_subset.to_csv('answer_med_test.csv', index=False)

In [20]:
data_subset.to_excel('answer_med_test.xlsx')

In [None]:
from datasets import load_dataset

comparison_dataset = load_dataset('sciq', split='test[:20%]')

In [None]:
answers_comparison = []
for question in comparison_dataset["question"]:
    prompt = question
    answer = adapter_deployment.prompt(prompt,temperature=0.1,max_new_tokens=256)
    answers_comparison.append(answer.response)

In [None]:
# Import the helpers again so this cell also runs in a fresh session
from evaluate import load
import numpy as np

bertscore = load("bertscore")
predictions = answers_comparison
references = comparison_dataset["correct_answer"]
results = bertscore.compute(predictions=predictions, references=references, model_type="distilbert-base-uncased")
print("precision: ", round(np.mean(list(results["precision"])), 5))
print("recall: ", round(np.mean(list(results["recall"])), 5))
print("f1: ", round(np.mean(list(results["f1"])), 5))