Load Libraries

In [1]:
from dotenv import load_dotenv
import os
from math import ceil
import torch
from trl import SFTTrainer
from peft import LoraConfig
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

Load LLM Model

In [2]:
quantization_config = BitsAndBytesConfig(load_in_8bit = True,)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B-Instruct",
                                             quantization_config = quantization_config, 
                                             token = os.getenv("HUGGINGFACE_TOKEN"))
model.config.use_cache = False
model.config.pretraining_tp = 1

`low_cpu_mem_usage` was None, now set to True since model is quantized.


In [3]:
tokenizer = AutoTokenizer.from_pretrained("pcuenq/Llama-3.2-1B-Instruct-tokenizer", 
                                          trust_remote_code = True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'right'

Load Dataset

In [4]:
dataset = load_dataset(path = "aboonaji/wiki_medical_terms_llam2_format", split = "train")
dataset

Dataset({
    features: ['text'],
    num_rows: 6861
})

Training Args

In [5]:
sample_size = dataset.num_rows
batch_size = 4

In [6]:
args = TrainingArguments(output_dir = "./llama_finetune", 
                         per_device_train_batch_size = batch_size, 
                         max_steps = 5 * (ceil(sample_size/batch_size)))

Supervised Fine-Tuning

In [7]:
trainer = SFTTrainer(model = model, args = args, 
                     train_dataset = dataset,
                     tokenizer = tokenizer, 
                     peft_config = LoraConfig(task_type = "CAUSAL_LM", r = 128, lora_alpha = 16, lora_dropout = .1, ),
                     dataset_text_field = "text")


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
max_steps is given, it will override any value given in num_train_epochs


In [8]:
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mhawkiyc[0m. Use [1m`wandb login --relogin`[0m to force relogin


  6%|▌         | 500/8580 [11:29<3:12:31,  1.43s/it]

{'loss': 1.8533, 'grad_norm': 0.13843896985054016, 'learning_rate': 4.708624708624709e-05, 'epoch': 0.29}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 12%|█▏        | 1000/8580 [23:04<2:52:11,  1.36s/it]

{'loss': 1.7424, 'grad_norm': 0.17795269191265106, 'learning_rate': 4.4172494172494175e-05, 'epoch': 0.58}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 17%|█▋        | 1500/8580 [34:37<2:44:39,  1.40s/it]

{'loss': 1.7314, 'grad_norm': 0.19075030088424683, 'learning_rate': 4.125874125874126e-05, 'epoch': 0.87}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 23%|██▎       | 2000/8580 [46:16<2:34:03,  1.40s/it]

{'loss': 1.7516, 'grad_norm': 0.26657453179359436, 'learning_rate': 3.834498834498835e-05, 'epoch': 1.17}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 29%|██▉       | 2500/8580 [57:51<2:23:18,  1.41s/it]

{'loss': 1.7234, 'grad_norm': 0.18379968404769897, 'learning_rate': 3.5431235431235434e-05, 'epoch': 1.46}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 35%|███▍      | 3000/8580 [1:09:24<2:08:59,  1.39s/it]

{'loss': 1.7201, 'grad_norm': 0.18118135631084442, 'learning_rate': 3.251748251748252e-05, 'epoch': 1.75}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 41%|████      | 3500/8580 [1:20:56<2:00:38,  1.42s/it]

{'loss': 1.7078, 'grad_norm': 0.1875627189874649, 'learning_rate': 2.9603729603729606e-05, 'epoch': 2.04}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 47%|████▋     | 4000/8580 [1:32:30<1:49:39,  1.44s/it]

{'loss': 1.7167, 'grad_norm': 0.19575345516204834, 'learning_rate': 2.6689976689976692e-05, 'epoch': 2.33}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 52%|█████▏    | 4500/8580 [1:44:09<1:36:21,  1.42s/it]

{'loss': 1.7323, 'grad_norm': 0.22222746908664703, 'learning_rate': 2.377622377622378e-05, 'epoch': 2.62}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 58%|█████▊    | 5000/8580 [1:55:44<1:23:10,  1.39s/it]

{'loss': 1.6974, 'grad_norm': 0.22694863379001617, 'learning_rate': 2.0862470862470865e-05, 'epoch': 2.91}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 64%|██████▍   | 5500/8580 [2:07:13<54:49,  1.07s/it]  

{'loss': 1.6982, 'grad_norm': 0.3130342662334442, 'learning_rate': 1.794871794871795e-05, 'epoch': 3.21}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 70%|██████▉   | 6000/8580 [2:18:53<1:01:06,  1.42s/it]

{'loss': 1.7119, 'grad_norm': 0.17226332426071167, 'learning_rate': 1.5034965034965034e-05, 'epoch': 3.5}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 76%|███████▌  | 6500/8580 [2:30:34<49:46,  1.44s/it]  

{'loss': 1.7183, 'grad_norm': 0.18550890684127808, 'learning_rate': 1.2121212121212122e-05, 'epoch': 3.79}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 82%|████████▏ | 7000/8580 [2:42:09<37:32,  1.43s/it]

{'loss': 1.7166, 'grad_norm': 0.2062445878982544, 'learning_rate': 9.207459207459208e-06, 'epoch': 4.08}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 87%|████████▋ | 7500/8580 [2:53:37<25:38,  1.42s/it]

{'loss': 1.6913, 'grad_norm': 0.2218349725008011, 'learning_rate': 6.2937062937062944e-06, 'epoch': 4.37}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 93%|█████████▎| 8000/8580 [3:05:17<13:06,  1.36s/it]

{'loss': 1.7189, 'grad_norm': 0.2134796380996704, 'learning_rate': 3.3799533799533803e-06, 'epoch': 4.66}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
 99%|█████████▉| 8500/8580 [3:16:52<01:54,  1.43s/it]

{'loss': 1.7083, 'grad_norm': 0.18375183641910553, 'learning_rate': 4.662004662004662e-07, 'epoch': 4.95}



Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.

Cannot access gated repo for url https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct/resolve/main/config.json.
Access to model meta-llama/Llama-3.2-1B-Instruct is restricted. You must have access to it and be authenticated to access it. Please log in. - silently ignoring the lookup for the file config.json in meta-llama/Llama-3.2-1B-Instruct.
100%|██████████| 8580/8580 [3:18:41<00:00,  1.39s/it]

{'train_runtime': 11929.895, 'train_samples_per_second': 2.877, 'train_steps_per_second': 0.719, 'train_loss': 1.725535412768384, 'epoch': 5.0}





TrainOutput(global_step=8580, training_loss=1.725535412768384, metrics={'train_runtime': 11929.895, 'train_samples_per_second': 2.877, 'train_steps_per_second': 0.719, 'total_flos': 2.01700558334976e+17, 'train_loss': 1.725535412768384, 'epoch': 5.0})

In [9]:
prompt = "What is malaria?"
text_generation_pipeline = pipeline(task = "text-generation", model = model, tokenizer = tokenizer, max_length = 300)
model_answer = text_generation_pipeline(f"<s>[INST] {prompt} [/INST]")
print(model_answer[0]['generated_text'])

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


<s>[INST] What is malaria? [/INST] Malaria is a serious and sometimes life-threatening disease caused by a parasite of the Plasmodium genus. The parasite infects red blood cells and causes the red blood cells to die. The disease is spread by the bite of an infected female Anopheles mosquito. Symptoms include fever, chills, flu-like symptoms, and in severe cases, a headache, vomiting, and in some cases, kidney failure and death. Malaria is a major public health problem in tropical and subtropical regions of the world, with an estimated 228 million cases and 435,000 deaths in 2019.
The disease is caused by the Plasmodium species P. vivax, P. ovale, P. malariae, and P. falciparum, which are all transmitted by the Anopheles mosquito. The parasite is transmitted through the bite of an infected female Anopheles mosquito that feeds on the blood of an infected individual, usually after 10 days of infection. The parasite multiplies in the red blood cells of the host, and the infected red blood 

In [10]:
user_prompt = "Please tell me about Bursitis"
text_generation_pipeline = pipeline(task = "text-generation", model = model, tokenizer = tokenizer, max_length = 300)
model_answer = text_generation_pipeline(f"<s>[INST] {user_prompt} [/INST]")
print(model_answer[0]['generated_text'])

<s>[INST] Please tell me about Bursitis [/INST] Bursitis is inflammation of a bursa, a fluid-filled sac that cushions a joint, tendon, or muscle. Bursitis is often caused by repetitive motion, injury, or infection. It can be a chronic or acute condition. Bursitis is usually painless and can be associated with other conditions such as arthritis. Bursitis is a common cause of joint pain and can be treated with medication, physical therapy, or surgery.
Inflammation of the bursa, the bursa is a fluid-filled sac that cushions a joint, tendon, or muscle. The bursa is filled with synovial fluid, a fluid that is produced by the synovial membrane. The bursa is located in the area of the joint, tendon, or muscle that is subjected to repetitive motion. The bursa is a protective layer that reduces friction between the joint, tendon, or muscle and the bone. When the bursa becomes inflamed, the friction between the joint, tendon, or muscle and the bone increases, causing pain and inflammation.
Bursi

wandb: ERROR Error while calling W&B API: context deadline exceeded (<Response [500]>)
wandb: ERROR Error while calling W&B API: context deadline exceeded (<Response [500]>)
wandb: ERROR Error while calling W&B API: context deadline exceeded (<Response [500]>)
wandb: ERROR Error while calling W&B API: context deadline exceeded (<Response [500]>)
