Load Libraries

In [1]:
from dotenv import load_dotenv
import os
from math import ceil
import torch
from trl import SFTTrainer
from peft import LoraConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline

# Read secrets from a local .env file into the process environment.
load_dotenv()
# Hugging Face libraries read the access token from HF_TOKEN (presumably needed
# here for the gated meta-llama checkpoint). os.environ only accepts str values,
# so assigning os.getenv(...) directly raises TypeError when HUGGINGFACE_TOKEN is
# missing; only export it when present so an existing HF login can still be used.
hf_token = os.getenv('HUGGINGFACE_TOKEN')
if hf_token:
    os.environ["HF_TOKEN"] = hf_token

Load the LLM (4-bit quantized)

In [2]:
# Quantize the base model's weights to 4-bit NF4 on load, with bfloat16 as the
# compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quantization_config,
)

# Training-time config tweaks:
# - the KV cache only helps autoregressive generation, so it is turned off for training;
# - pretraining_tp = 1 selects the plain linear-layer path (per the transformers
#   LlamaConfig docs, values > 1 emulate tensor-parallel pretraining numerics).
model.config.use_cache = False
model.config.pretraining_tp = 1

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# Tokenizer for the same checkpoint as the model. Llama is natively supported by
# transformers, so trust_remote_code is not needed; enabling it would permit
# executing custom Python shipped in the Hub repo on load.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Reuse EOS as the padding token (NOTE(review): assumes the checkpoint has no
# dedicated pad token — confirm).
tokenizer.pad_token = tokenizer.eos_token
# Right-pad sequences for causal-LM training batches.
tokenizer.padding_side = 'right'

Load Dataset

In [4]:
# Medical-term dataset pre-formatted in the Llama-2 "[INST]" style; only the
# train split is used, and it exposes a single "text" column.
dataset = load_dataset("aboonaji/wiki_medical_terms_llam2_format", split="train")
dataset

Dataset({
    features: ['text'],
    num_rows: 6861
})

Training Arguments

In [5]:
# Effective batch size per optimizer step (per device) = batch_size * gradient_accumulation_steps = 8.
batch_size, gradient_accumulation_steps = 1, 8

In [6]:
args = TrainingArguments(
    output_dir="./llama_finetune",
    # Optimisation schedule
    num_train_epochs=3,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=5e-5,
    warmup_steps=500,  # ramp up to the peak learning rate before it starts decaying
    optim="adamw_torch",
    # Mixed precision in bfloat16, the same dtype used as the 4-bit compute dtype
    bf16=True,
    # Write a checkpoint to output_dir at the end of every epoch
    save_strategy="epoch",
)

Supervised Fine-Tuning

In [7]:
# Supervised fine-tuning: LoRA adapters are trained on top of the 4-bit base model.
# NOTE(review): this run logs "Deprecated positional argument(s) used in SFTTrainer";
# the installed trl expects dataset_text_field on an SFTConfig passed as `args`.
# TODO: migrate once the trl version is pinned.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    dataset_text_field="text",  # train on the dataset's "text" column
    peft_config=LoraConfig(
        task_type="CAUSAL_LM",
        r=128,          # adapter rank
        lora_alpha=16,  # LoRA update is scaled by lora_alpha / r
        lora_dropout=.1,
    ),
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


In [8]:
# Run the full training loop (checkpoints are written per save_strategy="epoch").
train_output = trainer.train()
train_output

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhawkiyc[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2571 [00:00<?, ?it/s]

{'loss': 1.5151, 'grad_norm': 0.1564282327890396, 'learning_rate': 5e-05, 'epoch': 0.58}
{'loss': 1.1573, 'grad_norm': 0.16565777361392975, 'learning_rate': 3.7928536938676965e-05, 'epoch': 1.17}
{'loss': 1.1089, 'grad_norm': 0.13002827763557434, 'learning_rate': 2.5857073877353938e-05, 'epoch': 1.75}
{'loss': 1.108, 'grad_norm': 0.13994261622428894, 'learning_rate': 1.3785610816030902e-05, 'epoch': 2.33}
{'loss': 1.0926, 'grad_norm': 0.17559672892093658, 'learning_rate': 1.7141477547078708e-06, 'epoch': 2.92}
{'train_runtime': 15876.7396, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'train_loss': 1.193868968701001, 'epoch': 3.0}


TrainOutput(global_step=2571, training_loss=1.193868968701001, metrics={'train_runtime': 15876.7396, 'train_samples_per_second': 1.296, 'train_steps_per_second': 0.162, 'total_flos': 6.162075592922726e+17, 'train_loss': 1.193868968701001, 'epoch': 2.9978137297770004})

In [42]:
# Switch from training to inference mode before generating:
# - eval() disables dropout, including the LoRA adapters' lora_dropout=0.1, which
#   trainer.train() leaves active;
# - the KV cache was disabled for training (use_cache = False) and speeds up decoding.
model.eval()
model.config.use_cache = True

prompt = "What is malaria?"
text_generation_pipeline = pipeline(
    task = "text-generation",
    model = model,
    tokenizer = tokenizer,
    max_new_tokens = 512, )
# Prompts are wrapped in the same Llama-2 "[INST] ... [/INST]" template as the
# fine-tuning data. autocast("cuda") defaults to float16; pin bfloat16 to match
# the training precision (bf16=True, bnb_4bit_compute_dtype=torch.bfloat16).
with torch.autocast("cuda", dtype = torch.bfloat16):
    model_answer = text_generation_pipeline(f"<s>[INST] {prompt} [/INST]")
print(model_answer[0]['generated_text'])

<s>[INST] What is malaria? [/INST] Malaria is a serious disease caused by a parasite that is transmitted to humans through the bite of an infected mosquito. There are four different species of the malaria parasite, and they cause different types of malaria. The most common type is Plasmodium falciparum, which is found primarily in sub-Saharan Africa. Malaria is caused by the bite of an infected female Anopheles mosquito. The parasite is transmitted when the mosquito bites a person, and the parasite is then transmitted to the person’s liver. The parasite then multiplies in the liver and is released into the bloodstream, where it causes illness. Malaria is typically characterized by fever, chills, and flu-like symptoms. In severe cases, malaria can cause coma, seizures, and death. Malaria is treated with antimalarial medication, and in severe cases, hospitalization may be necessary. Malaria is a major cause of illness and death in tropical and subtropical regions of the world, particular

In [43]:
# Ensure inference mode (idempotent): no dropout, KV cache on for generation.
model.eval()
model.config.use_cache = True

user_prompt = "Please tell me about Bursitis"
# Reuse the pipeline built in the previous cell; generation kwargs given at call
# time override those set at construction, so no rebuild is needed for a shorter limit.
# bfloat16 autocast matches the training precision (the CUDA default would be float16).
with torch.autocast("cuda", dtype = torch.bfloat16):
    model_answer = text_generation_pipeline(f"<s>[INST] {user_prompt} [/INST]", max_new_tokens = 300)
print(model_answer[0]['generated_text'])

<s>[INST] Please tell me about Bursitis [/INST] Bursitis is inflammation of a bursa, which is a fluid-filled sac or cavity that cushions and reduces friction between soft tissues and bones in the body. Bursae are located around joints and in areas prone to friction. The inflammation of a bursa can be caused by injury, infection, or overuse. Bursitis can cause pain, swelling, and redness of the affected area. The pain may be sharp and stabbing, or it may be a dull ache. The pain may be constant, or it may come and go. The pain may be worse with movement of the affected joint or area. Bursitis can be acute or chronic. Acute bursitis is sudden and short-lived. Chronic bursitis is long-lasting and may be caused by a chronic condition. Bursitis can be caused by a bacterial infection, which is known as septic bursitis. Septic bursitis is a medical emergency that requires immediate treatment. Bursitis can be treated with rest, ice, compression, and elevation (RICE). Pain relievers may be pres