Import Libraries and Load Credentials

In [1]:
from dotenv import load_dotenv
import os
from math import ceil
import torch
from trl import SFTTrainer
from peft import LoraConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline

# Pull secrets from a local .env file instead of hard-coding them in the notebook.
load_dotenv()
# huggingface_hub authenticates with HF_TOKEN; accept an already-exported HF_TOKEN
# as a fallback when HUGGINGFACE_TOKEN is not defined.
hf_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
if not hf_token:
    # os.environ only accepts str values, so assigning None would fail with an
    # opaque TypeError; fail early with an actionable message instead.
    raise RuntimeError(
        "HUGGINGFACE_TOKEN is not set: add it to your .env file or environment "
        "so the model can be downloaded from the Hugging Face Hub."
    )
os.environ["HF_TOKEN"] = hf_token

Load the 4-bit Quantized LLM

In [2]:
# 4-bit NF4 weights with nested ("double") quantization of the quantization
# constants; the 4-bit matmuls are computed in bfloat16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    quantization_config=quantization_config,
)
# The KV cache only helps autoregressive generation; it is switched off for training.
model.config.use_cache = False
# Per the transformers Llama config docs, 1 selects the standard (non-sliced)
# linear-layer computation.
model.config.pretraining_tp = 1

`low_cpu_mem_usage` was None, now default to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
# The model above loads without remote code, so the tokenizer does too;
# trust_remote_code=True would allow arbitrary Python from the Hub repo to run.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
# Reuse EOS as the padding token (presumably the checkpoint defines no pad
# token of its own - confirm) so training batches can be padded.
tokenizer.pad_token = tokenizer.eos_token
# Pad training sequences on the right.
tokenizer.padding_side = "right"

Load Dataset

In [4]:
# Medical wiki-terms corpus; the displayed summary shows one `text` column
# (the field SFTTrainer trains on) and 6,861 rows.
dataset = load_dataset("aboonaji/wiki_medical_terms_llam2_format", split="train")
dataset

Dataset({
    features: ['text'],
    num_rows: 6861
})

Training Arguments

In [5]:
# One sample per device per forward pass, accumulated over 8 passes before each
# optimizer step: 8 samples per update per device
# (6,861 rows / 8 -> ~857 steps per epoch, 2,571 over 3 epochs).
batch_size, gradient_accumulation_steps = 1, 8

In [6]:
args = TrainingArguments(
    output_dir="./llm_finetune",
    # Schedule: 500 warmup steps up to the peak LR, then decay over the rest
    # (see the learning_rate values in the training log).
    num_train_epochs=3,
    learning_rate=5e-5,
    warmup_steps=500,
    optim="adamw_torch",
    # Batching, driven by the constants in the previous cell.
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    # Precision and checkpointing: bf16 mixed precision, one checkpoint per epoch.
    bf16=True,
    save_strategy="epoch",
)

Supervised Fine-Tuning

In [7]:
# LoRA adapters for causal-LM fine-tuning; the effective LoRA scaling is
# lora_alpha / r = 16 / 128.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=128,
    lora_alpha=16,
    lora_dropout=0.1,
)

# NOTE(review): trl warns (see output) that dataset_text_field passed here is
# deprecated and should be set on an SFTConfig instead.
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=lora_config,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
  super().__init__(


In [8]:
# Runs the full fine-tuning loop (2,571 optimizer steps over 3 epochs, ~16.5k s per
# the log below); checkpoints are written to ./llm_finetune at the end of each epoch.
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhawkiyc[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/2571 [00:00<?, ?it/s]

{'loss': 1.6594, 'grad_norm': 0.10192049294710159, 'learning_rate': 5e-05, 'epoch': 0.58}
{'loss': 1.4494, 'grad_norm': 0.1142600029706955, 'learning_rate': 3.7928536938676965e-05, 'epoch': 1.17}
{'loss': 1.4311, 'grad_norm': 0.12236087024211884, 'learning_rate': 2.5857073877353938e-05, 'epoch': 1.75}
{'loss': 1.424, 'grad_norm': 0.12590987980365753, 'learning_rate': 1.3785610816030902e-05, 'epoch': 2.33}
{'loss': 1.4152, 'grad_norm': 0.13874340057373047, 'learning_rate': 1.7141477547078708e-06, 'epoch': 2.92}
{'train_runtime': 16526.4682, 'train_samples_per_second': 1.245, 'train_steps_per_second': 0.156, 'train_loss': 1.4745152969390027, 'epoch': 3.0}


TrainOutput(global_step=2571, training_loss=1.4745152969390027, metrics={'train_runtime': 16526.4682, 'train_samples_per_second': 1.245, 'train_steps_per_second': 0.156, 'total_flos': 6.162075592922726e+17, 'train_loss': 1.4745152969390027, 'epoch': 2.9978137297770004})

In [9]:
prompt = "What is malaria?"

# trainer.train() leaves the model in training mode, which keeps the LoRA dropout
# layers (lora_dropout=0.1) active; switch to eval mode before generating.
model.eval()
# Re-enable the KV cache that was disabled for training so generation can reuse it.
model.config.use_cache = True

# torch_dtype / trust_remote_code are not passed: they only take effect when
# pipeline() loads a model from a Hub ID, not for an already-loaded model.
text_generation_pipeline = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    # Explicit do_sample so temperature/top_p are honored regardless of the
    # checkpoint's generation_config defaults.
    do_sample=True,
    temperature=0.7,
    top_p=0.95,
    max_new_tokens=1024,
    # Return only the generated continuation, not the echoed prompt.
    return_full_text=False,
)

# Llama-2-style [INST] template, presumably matching the training data
# ("..._llam2_format" dataset) - confirm. Autocast in bfloat16 to match the bf16
# training precision and the 4-bit compute dtype (CUDA autocast defaults to float16).
with torch.autocast("cuda", dtype=torch.bfloat16):
    model_answer = text_generation_pipeline(f"<s>[INST] {prompt} [/INST]")
print(model_answer[0]["generated_text"])

Malaria is a serious and sometimes life-threatening disease that is most commonly transmitted by an infected female Anopheles mosquito. The disease is caused by a parasite, specifically Plasmodium falciparum, Plasmodium vivax, Plasmodium ovale, or Plasmodium malariae. Malaria is usually found in tropical and subtropical regions of the world. The disease can be diagnosed through a blood test. Symptoms of malaria include fever, chills, and flu-like symptoms. In severe cases, malaria can cause coma, and death if left untreated. If treated promptly with antimalarial medication, the disease usually has a good prognosis. In areas where malaria is common, preventive medication can be taken. In areas where the disease is common, mosquito nets can be used to prevent bites. In areas where the disease is rare, travelers can take preventive medication. In areas where the disease is common, bed nets with insecticide can be used. The nets should be washed every six months. Mosquitoes that transmit m

In [10]:
prompt = "Please tell me about Bursitis"
# Same [INST] prompt template as the previous cell; bfloat16 autocast matches the
# bf16 training precision (CUDA autocast would otherwise default to float16).
with torch.autocast("cuda", dtype=torch.bfloat16):
    model_answer = text_generation_pipeline(f"<s>[INST] {prompt} [/INST]")
print(model_answer[0]["generated_text"])

 Bursitis is a condition characterized by inflammation of a bursa, which is a fluid-filled sac, usually located near the joints that cushion and reduce friction between bone and soft tissue. The most commonly affected bursae are the subacromial, olecranon, prepatellar, infrapatellar, and retrocalcaneal bursae.  The inflammation is usually caused by repetitive trauma or friction to the affected area, and is sometimes accompanied by infection. Bursitis can be diagnosed with physical examination, X-ray, and ultrasound, and treatment depends on the severity of symptoms and underlying cause, and may include rest, ice, compression, elevation, pain management, and aspiration of the bursa. The symptoms of bursitis may be accompanied by fever, redness, swelling, and warmth in the affected area.  The most common complications of bursitis include chronic inflammation, infection, and adhesions in the affected area. The word "bursitis" is derived from the Greek words "bursa," meaning sack, and "iti