In [50]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from transformers.trainer import TrainingArguments
from peft import LoraConfig, AutoPeftModelForCausalLM, get_peft_model, prepare_model_for_int8_training, PeftModel
from trl import SFTTrainer, setup_chat_format
from datasets import load_dataset
from transformers import pipeline

In [2]:
# Read the Hugging Face access token from a local file so it is never
# hard-coded in the notebook. Surrounding whitespace is stripped: a trailing
# newline left by an editor would otherwise become part of the token string.
with open('hf_token.key', 'r', encoding='utf-8') as f:
    hf_token = f.read().strip()

In [3]:
# Hub id of the checkpoint that gets fine-tuned.
base_model: str = "meta-llama/Meta-Llama-3-8B-Instruct"
# Name used for the training output directory and the saved adapter.
new_model: str = "llama-3-8b-counsel"
# dtype shared by model loading and the 4-bit compute setting.
torch_dtype: torch.dtype = torch.bfloat16

In [4]:
# QLoRA quantization settings: 4-bit NF4 weights with double quantization,
# computing in the notebook-wide `torch_dtype`.
quantization_options = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch_dtype,
}
bnb_config = BitsAndBytesConfig(**quantization_options)

In [5]:
# Load the base checkpoint with the 4-bit quantization config, placing the
# whole model on the current CUDA device.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch_dtype,
    device_map={'': torch.cuda.current_device()},
    # Pass the same access token the tokenizer load uses, so the model
    # download does not depend on a separately cached Hub login.
    token=hf_token,
)

# Turn `use_cache` off for training; the evaluation cell switches it back on.
model.config.use_cache = False

Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.87s/it]


In [6]:
# Tokenizer for the base checkpoint, fetched with the Hub access token.
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
# Pad on the right and reuse the end-of-sequence token as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# Apply TRL's chat-format setup to the model/tokenizer pair and keep the
# returned objects. The formatted training sample displayed later in this
# notebook uses <|im_start|> role markers.
model, tokenizer = setup_chat_format(model, tokenizer)

In [8]:
# LoRA adapter settings. Adapters are attached to the projection layers named
# below (feed-forward projections first, then attention projections).
# NOTE(review): setup_chat_format ran before this cell, and the formatted
# sample uses <|im_start|> markers, which suggests new special tokens were
# added. No embedding modules are listed in this config, so any newly added
# token embeddings are presumably neither trained nor saved with the adapter
# — confirm.
mlp_projections = ['up_proj', 'down_proj', 'gate_proj']
attention_projections = ['k_proj', 'q_proj', 'v_proj', 'o_proj']

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=mlp_projections + attention_projections,
)

# Wrap the loaded model with the adapter configuration.
model = get_peft_model(model, peft_config)

In [9]:
# Load every split of the dataset, then keep a fixed 1000-row sample.
# The shuffle is seeded, so the same rows are selected on each run.
dataset_name = "ruslanmv/ai-medical-chatbot"
full_dataset = load_dataset(dataset_name, split="all")
shuffled_dataset = full_dataset.shuffle(seed=42)
dataset = shuffled_dataset.select(range(1000))

In [10]:
def format_chat_template(row):
    """Add a chat-formatted ``text`` field to one dataset row.

    The row's 'Patient' value becomes the user turn and its 'Doctor' value
    the assistant turn. The two turns are rendered to a single string with
    the chat template of the notebook-level ``tokenizer`` (no tokenization).

    Args:
        row: Mapping with 'Patient' and 'Doctor' entries.

    Returns:
        The same ``row`` object, with ``row["text"]`` set.
    """
    user_turn = {"role": "user", "content": row['Patient']}
    assistant_turn = {"role": "assistant", "content": row['Doctor']}

    row["text"] = tokenizer.apply_chat_template(
        [user_turn, assistant_turn],
        tokenize=False,
    )
    return row

In [11]:
# Render every row into the chat-formatted "text" column using 4 worker processes.
dataset = dataset.map(
    format_chat_template,
    num_proc=4,
)

# Hold out 10% of the rows for evaluation. The split is seeded so that the
# train/test partition, and therefore the reported validation loss, is
# reproducible across runs. The seed matches the one used for the shuffle
# when the 1000-row sample was drawn.
dataset = dataset.train_test_split(test_size=0.1, seed=42)

In [12]:
# Display the first training example, including the chat-formatted "text" field.
dataset['train'][0]

{'Description': 'How can cramps all over the body and jaw be treated?',
 'Patient': 'Hello, Im 51yrs.old and I have been have muscle cramps all over my body now for 5days...Went to hospital and got Fluids w an IV ,but at that time it was only my legs...then later that day I started have cramps in arms, jaw, upper lip and need some ideas what to do please...',
 'Doctor': 'Hello, I have studied your case. Cramps in the whole body are usually due to lack of sodium and calcium in the body. So you need to check if your calcium and sodium are normal. If you are on high blood pressure medicines then chances of salt deficiency are higher. So go through this test. You can also take salt with juices and water. Hope I have answered your query. Let me know if I can assist you further. Regards, Dr. Naveen Kumar Sharma, Orthopaedic Surgeon, Joint Replacement',
 'text': '<|im_start|>user\nHello, Im 51yrs.old and I have been have muscle cramps all over my body now for 5days...Went to hospital and got 

In [13]:
training_arguments = TrainingArguments(
    # Output location and reporting
    output_dir=new_model,
    report_to="none",
    # Batching
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    group_by_length=True,
    # Optimisation
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    warmup_steps=10,
    num_train_epochs=1,
    bf16=True,
    # Logging, evaluation and checkpointing.
    # eval_steps is given as a fraction; the training log in this notebook
    # shows evaluations at steps 12, 24, 36 and 48 of 56.
    logging_strategy="steps",
    logging_steps=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    save_strategy="epoch",
)

In [14]:
# Supervised fine-tuning trainer over the chat-formatted "text" column.
# `model` has already been wrapped by get_peft_model in an earlier cell, so
# the LoRA config is not passed a second time here; supplying both a wrapped
# model and `peft_config` is redundant.
# NOTE(review): max_seq_length=128 presumably truncates longer consultations
# — confirm this cap is intended.
trainer = SFTTrainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    tokenizer=tokenizer,
    dataset_text_field="text",
    max_seq_length=128,
    packing=False,
)

Map: 100%|██████████| 900/900 [00:00<00:00, 1940.09 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 1710.89 examples/s]
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [15]:
# Run fine-tuning; the table below reports training and validation loss at each evaluation step.
trainer.train()



Step,Training Loss,Validation Loss
12,2.7578,2.786915
24,2.5654,2.681115
36,2.6663,2.65308
48,2.6236,2.636054




TrainOutput(global_step=56, training_loss=2.7839642422539845, metrics={'train_runtime': 644.4814, 'train_samples_per_second': 1.396, 'train_steps_per_second': 0.087, 'total_flos': 5174017369767936.0, 'train_loss': 2.7839642422539845, 'epoch': 0.99})

### Saving the trained model

In [16]:
# Save the trainer's model under the `new_model` path; the merge section later loads it from there.
trainer.model.save_pretrained(new_model)



### Model Evaluation

In [18]:
# Re-enable `use_cache` for generation (it was switched off for training).
model.config.use_cache = True

# Single-turn conversation used as a smoke test of the fine-tuned model.
messages = [
    {
        "role": "user",
        "content": "Hello doctor, I have bad acne. How do I get rid of it?"
    }
]

# Render the conversation to a prompt string ending with the assistant generation prompt.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to(model.device)

# Bound the number of *generated* tokens; `max_length` would also count the
# prompt tokens against the budget.
outputs = model.generate(**inputs, max_new_tokens=150, num_return_sequences=1)

# Decode only the tokens produced after the prompt. Splitting the full decoded
# text on the word "assistant" breaks whenever the question or the answer
# itself contains that word.
prompt_length = inputs["input_ids"].shape[1]
text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

print(text.strip())

Merging the base model with the adapter

In [33]:
# Same values as the training section, declared again at the start of the merge section.
# Hub id of the base checkpoint the adapter is merged into.
base_model: str = "meta-llama/Meta-Llama-3-8B-Instruct"
# Path the trained adapter is loaded from.
new_model: str = "llama-3-8b-counsel"
# dtype used when reloading the base model.
torch_dtype: torch.dtype = torch.bfloat16

In [34]:
# Reload the base checkpoint without a quantization config, in `torch_dtype`,
# on the current CUDA device, so the adapter can be merged into it.
# NOTE(review): trust_remote_code=True permits code from the model repository
# to run at load time — confirm it is actually needed for this checkpoint.
base_model_reload = AutoModelForCausalLM.from_pretrained(
    base_model,
    return_dict=True,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    torch_dtype=torch_dtype,
    device_map={"": torch.cuda.current_device()},
    # Pass the same access token the tokenizer load uses, so the model
    # download does not depend on a separately cached Hub login.
    token=hf_token,
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:08<00:00,  2.22s/it]


In [35]:
# Fresh tokenizer for the base checkpoint, fetched with the Hub access token.
tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token)
# Pad on the right and reuse the end-of-sequence token as the padding token.
tokenizer.padding_side = "right"
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [36]:
# Apply the same chat-format setup to the reloaded base model and tokenizer
# before the adapter is attached.
base_model_reload, tokenizer = setup_chat_format(base_model_reload, tokenizer)

Merge adapter with the base model

In [37]:
# Attach the adapter saved under `new_model` to the reloaded base model.
model = PeftModel.from_pretrained(base_model_reload, new_model)

In [38]:
# Merge the adapter into the base model; `model` now refers to the merged result.
model = model.merge_and_unload()

Model Inference from Merged Model

In [40]:
# Single-turn conversation used to probe the merged model.
messages = [
    {
        "role": "user",
        "content": "Hello doctor, I have bad acne. How do I get rid of it?",
    },
]

In [41]:
# Render the conversation to a prompt string (not token ids), with the
# assistant generation prompt appended.
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

In [53]:
# Build the prompt for a single user question.
user_question = {"role": "user", "content": "Hello doctor, I have bad acne. How do I get rid of it?"}
messages = [user_question]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Text-generation pipeline around the merged model and its tokenizer.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    device_map={"": torch.cuda.current_device()},
)

# Sampling settings for this generation call.
sampling_options = dict(
    max_new_tokens=120,
    do_sample=True,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
)
outputs = pipe(prompt, **sampling_options)
print(outputs[0]["generated_text"])

<|im_start|>user
Hello doctor, I have bad acne. How do I get rid of it?<|im_end|>
<|im_start|>assistant
                                                                                                                        


In [39]:
# Save the merged model and its tokenizer side by side in one directory.
merged_model_dir = "llama-3-8b-chat-doctor"
model.save_pretrained(merged_model_dir)
tokenizer.save_pretrained(merged_model_dir)

('llama-3-8b-chat-doctor/tokenizer_config.json',
 'llama-3-8b-chat-doctor/special_tokens_map.json',
 'llama-3-8b-chat-doctor/tokenizer.json')