#### Install Dependencies

In [None]:
!pip install transformers peft bitsandbytes datasets accelerate loralib 

In [3]:
import torch
import os
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


##### Setup QLoRA config

In [6]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dype=torch.bfloat16
)

Unused kwargs: ['bnb_4bit_compute_dype']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


##### Get baseline model

In [5]:
model_id = "meta-llama/Meta-Llama-3-8B"
baseline_model  = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    use_cache=True,
    device_map=0,
    
)

NameError: name 'bnb_config' is not defined

Setup LoRA config

In [2]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

lora_config = LoraConfig(
    r=64,
    lora_alpha = 64,
    lora_dropout = 0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

In [None]:
# Setup model for 4bit SFT
base_model = prepare_model_for_kbit_training(baseline_model)
q_model = get_peft_model(base_model, lora_config)

##### Download and process dataset

In [5]:
from datasets import load_dataset

In [6]:
dataset_id = "c-s-ale/alpaca-gpt4-data"
dataset = load_dataset(dataset_id)

Downloading readme: 100%|██████████| 1.39k/1.39k [00:00<00:00, 1.64MB/s]
Downloading data: 100%|██████████| 43.4M/43.4M [00:02<00:00, 20.6MB/s]
Generating train split: 100%|██████████| 52002/52002 [00:00<00:00, 127068.36 examples/s]


In [9]:
# Start with small sample set of dataset
dataset_sample_set =dataset['train'].select(range(5000))

In [10]:
def preprocess_prompt(sample, gen_response=True):
    prompt = "Generate a simple instruction that could result in the provided context."
    prompt += f"[INST]CONTEXT: {sample['output']}[/INST]"
    
    if gen_response:
        prompt += "INSTRUCTION: "
        prompt += f"{sample['instruction']}"
    
    return [prompt]
        
        
        

    

In [11]:
preprocess_prompt(dataset_sample_set[0])[0]

'Generate a simple instruction that could result in the provided context.[INST]CONTEXT: 1. Eat a balanced and nutritious diet: Make sure your meals are inclusive of a variety of fruits and vegetables, lean protein, whole grains, and healthy fats. This helps to provide your body with the essential nutrients to function at its best and can help prevent chronic diseases.\n\n2. Engage in regular physical activity: Exercise is crucial for maintaining strong bones, muscles, and cardiovascular health. Aim for at least 150 minutes of moderate aerobic exercise or 75 minutes of vigorous exercise each week.\n\n3. Get enough sleep: Getting enough quality sleep is crucial for physical and mental well-being. It helps to regulate mood, improve cognitive function, and supports healthy growth and immune function. Aim for 7-9 hours of sleep each night.[/INST]INSTRUCTION: Give three tips for staying healthy.'

##### Setup Tokenizer

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


##### Setup SFT Training Arguments

In [25]:
training_args = TrainingArguments(
    output_dir="llama3-8b-sft-instruct",
    num_train_epochs=5,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    optim='paged_adamw_32bit',
    logging_steps=1,
    learning_rate=2e-4,
    bf16=True,
    tf32=False, #Using laptop old GPU
    lr_scheduler_type="constant",
)


In [None]:
from trl import SFTTrainer

In [None]:
max_seq_len = 2048

trainer = SFTTrainer(
    model=q_model,
    train_dataset = dataset_sample_set,
    peft_config = lora_config,
    max_seq_length=max_seq_len,
    tokenizer=tokenizer,
    formatting_func= preprocess_prompt,
    args=training_args
)