In [1]:
%%capture
%pip install -U transformers 
%pip install -U datasets 
%pip install -U accelerate 
%pip install -U peft 
%pip install -U trl 
%pip install -U bitsandbytes 
%pip install -U wandb

In [2]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

In [3]:
from huggingface_hub import login
import wandb

# Log in to Hugging Face
hf_token = "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"
login(token=hf_token)

# Log in to Weights & Biases
wb_token = "6d093f67d658015661b11157166b74b60a25d11e"
wandb.login(key=wb_token)

# Initialize a wandb run
run = wandb.init(
    project='Fine-tune Llama 3 8B on Medical Dataset', 
    job_type="training", 
    anonymous="allow"
)


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mzlow3[0m ([33mzlow3-Georgia Tech Alumni Association[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset_name = "ruslanmv/ai-medical-chatbot"
new_model = "llama-3-8b-chat-doctor"

torch_dtype = torch.float16
attn_implementation = "eager"

# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/51.0k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/73.0 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

In [7]:
#Importing the dataset
dataset = load_dataset(dataset_name, split="all")
dataset = dataset.shuffle(seed=65).select(range(1000)) # Only use 1000 samples for quick demo

def format_chat_template(row):
    row_json = [{"role": "user", "content": row["Patient"]},
               {"role": "assistant", "content": row["Doctor"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

dataset = dataset.map(
    format_chat_template,
    num_proc=4,
)

dataset['text'][3]

Downloading readme:   0%|          | 0.00/863 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/142M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/256916 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

'<|im_start|>user\nFell on sidewalk face first about 8 hrs ago. Swollen, cut lip bruised and cut knee, and hurt pride initially. Now have muscle and shoulder pain, stiff jaw(think this is from the really swollen lip),pain in wrist, and headache. I assume this is all normal but are there specific things I should look for or will I just be in pain for a while given the hard fall?<|im_end|>\n<|im_start|>assistant\nHello and welcome to HCM,The injuries caused on various body parts have to be managed.The cut and swollen lip has to be managed by sterile dressing.The body pains, pain on injured site and jaw pain should be managed by pain killer and muscle relaxant.I suggest you to consult your primary healthcare provider for clinical assessment.In case there is evidence of infection in any of the injured sites, a course of antibiotics may have to be started to control the infection.Thanks and take careDr Shailja P Wahal<|im_end|>\n'

In [10]:
dataset

Dataset({
    features: ['Description', 'Patient', 'Doctor', 'text'],
    num_rows: 1000
})

In [12]:
dataset['Description'][0]

'Can blood pressure medication be stopped to check improvement in bp levels?'

In [13]:
dataset['Patient'][0]

"I'm 35, BP 150/100 without medicine, never smoke, drink, BMIMy hdl:45, LDL:107, Total Cholesterol:168, Triglyceride:95, Blood Sugar(Fasting):83 (all are without medicine) My question is, should I stop this medicine to see the afffect. it is worth to mention that recently I've increased my physical activity,"

In [14]:
dataset['Doctor'][0]

'Hello,Thanks for writing to Health Care Magic, I am Dr Asad Riaz, I have closely read your question and I understand your concerns, I will hereby guide you regarding your health related problem.BP is one major risk factor for major complication like MI or stroke...if pt have high BP as u r having ..1st we need to do life style modification to see whteher it work to de BP or not in whch i advice pt to lower salt intake,drinkng dec weight n exercise for a period of 6 month n still if pt have high BP then need to add med to control it so we prevent major comlication..as u having high bp u dont need to stop med n if u wanna it visit ur physian n then follow the what he/she advised to u..I hope this answered your question, if you have more feel free to ask.Regards.Dr.Asad Riaz.General and Family Physician.'

In [15]:
dataset['text'][0]

"<|im_start|>user\nI'm 35, BP 150/100 without medicine, never smoke, drink, BMIMy hdl:45, LDL:107, Total Cholesterol:168, Triglyceride:95, Blood Sugar(Fasting):83 (all are without medicine) My question is, should I stop this medicine to see the afffect. it is worth to mention that recently I've increased my physical activity,<|im_end|>\n<|im_start|>assistant\nHello,Thanks for writing to Health Care Magic, I am Dr Asad Riaz, I have closely read your question and I understand your concerns, I will hereby guide you regarding your health related problem.BP is one major risk factor for major complication like MI or stroke...if pt have high BP as u r having ..1st we need to do life style modification to see whteher it work to de BP or not in whch i advice pt to lower salt intake,drinkng dec weight n exercise for a period of 6 month n still if pt have high BP then need to add med to control it so we prevent major comlication..as u having high bp u dont need to stop med n if u wanna it visit u

In [16]:
dataset = dataset.train_test_split(test_size=0.1)

In [17]:
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)



In [18]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/900 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

In [19]:
trainer.train()

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Step,Training Loss,Validation Loss
90,2.6512,2.544062
180,2.1682,2.505385
270,2.2081,2.473809
360,2.6169,2.456046
450,2.5196,2.439003


TrainOutput(global_step=450, training_loss=2.52239581240548, metrics={'train_runtime': 345.0578, 'train_samples_per_second': 2.608, 'train_steps_per_second': 1.304, 'total_flos': 9231172574502912.0, 'train_loss': 2.52239581240548, 'epoch': 1.0})

In [21]:
wandb.finish()
model.config.use_cache = True

VBox(children=(Label(value='0.008 MB of 0.008 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eval/loss,‚ñà‚ñÖ‚ñÉ‚ñÇ‚ñÅ
eval/runtime,‚ñÅ‚ñÜ‚ñà‚ñà‚ñà
eval/samples_per_second,‚ñà‚ñÉ‚ñÅ‚ñÅ‚ñÅ
eval/steps_per_second,‚ñà‚ñÉ‚ñÅ‚ñÅ‚ñÅ
train/epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
train/global_step,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
train/grad_norm,‚ñÑ‚ñÉ‚ñÉ‚ñÑ‚ñÜ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÑ‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñà‚ñÇ‚ñÇ‚ñÇ‚ñÑ
train/learning_rate,‚ñÑ‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ
train/loss,‚ñá‚ñÑ‚ñÜ‚ñÑ‚ñá‚ñÖ‚ñÑ‚ñÇ‚ñÉ‚ñÉ‚ñÑ‚ñÖ‚ñÑ‚ñÉ‚ñÑ‚ñÇ‚ñÑ‚ñÜ‚ñÑ‚ñÉ‚ñÇ‚ñÉ‚ñÖ‚ñÜ‚ñÖ‚ñÉ‚ñá‚ñÖ‚ñÇ‚ñÖ‚ñÇ‚ñÑ‚ñÅ‚ñÉ‚ñÑ‚ñÇ‚ñà‚ñÖ‚ñÑ‚ñÜ

0,1
eval/loss,2.439
eval/runtime,14.1103
eval/samples_per_second,7.087
eval/steps_per_second,7.087
total_flos,9231172574502912.0
train/epoch,1.0
train/global_step,450.0
train/grad_norm,2.31442
train/learning_rate,0.0
train/loss,2.5196


In [22]:
messages = [
    {
        "role": "user",
        "content": "Hello doctor, I have bad acne. How do I get rid of it?"
    }
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, 
                                       add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors='pt', padding=True, 
                   truncation=True).to("cuda")

outputs = model.generate(**inputs, max_length=150, 
                         num_return_sequences=1)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(text.split("assistant")[1])


Hi. Acne is a common problem in teenagers. It can be treated with oral antibiotics, local applications and lifestyle modifications. You can use a retinoid cream or gel at night and an antibiotic gel in the morning. For further information consult a dermatologist online --> https://www.iclinq.com/ask-a-doctor --> choose a dermatologist. Hope I have answered your query. Let me know if I can assist you further. Regards, Dr. Sumanth MBBS, DCH, DNB (Paed).  I have gone through your query and understand your concerns. Acne is a common
