In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

  from .autonotebook import tqdm as notebook_tqdm


[2024-07-25 18:33:41,039] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
from huggingface_hub import login

hf_token = ''

login(token = hf_token)

wandb.login()
run = wandb.init(
    project='Fine-tune Llama 3 8B', 
    job_type="training", 
    anonymous="allow"
)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/karthick/.cache/huggingface/token
Login successful


[34m[1mwandb[0m: Currently logged in as: [33mj_karthic[0m ([33mhighlight-ing[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
base_model = "meta-llama/Meta-Llama-3-8B-Instruct"

In [4]:
torch_dtype = torch.bfloat16
attn_implementation = "eager"

In [5]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

Loading checkpoint shards: 100%|███████████████████████████████████████████| 4/4 [00:08<00:00,  2.05s/it]


In [6]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [7]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

In [8]:
from datasets import load_dataset
from typing import Dict, List, Optional

dataset_name = "tryhighlight/task_dataset"
#Importing the dataset
input_dataset = load_dataset(dataset_name, split="train")
print (input_dataset)

Dataset({
    features: ['NAME', 'CONVERSATION', 'TASK'],
    num_rows: 30200
})


In [9]:
system_prompt = "The user will provide his/her full name followed by an email thread from their inbox that contains one or more email conversations. Looking at the conversation, detect if there are any TODOs that the user has to complete as a result of the conversation. If yes, just provide the short single line task that can be directly added to the todo list. If there is no task detected as a TODO, just output the exact phrase \"No task\". If the conversation is about a promotional or advertisement related, please output \"No task\". If the conversation is directed or addressed to someone else, then output \"No task\". Just provide the short single line task or the phrase \"No task\", no other explanation is needed."

In [10]:
max_len = 512

In [11]:
def check_length(row):
    # This function checks if the tokenized length is within the max length allowed
    input_content = tokenizer.encode(system_prompt + row["NAME"] + row["CONVERSATION"] + row["TASK"],
                                        add_special_tokens=True,
                                        truncation=False,
                                        return_length=True,
                                        max_length=None)
    return len(input_content) <= max_len - 20 # extra 20 tokens for the chat template

# Filter the dataset to exclude entries that are too long
input_dataset = input_dataset.filter(check_length, num_proc=4)
print(f"Size of the dataset after filtering: {len(input_dataset)} samples")

Size of the dataset after filtering: 17834 samples


In [12]:
tokenizer.special_tokens_map

{'bos_token': '<|im_start|>',
 'eos_token': '<|im_end|>',
 'pad_token': '<|im_end|>',
 'additional_special_tokens': ['<|im_start|>', '<|im_end|>']}

In [13]:
def format_chat_template(row):
    row_json = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": "My name is " + row["NAME"] + " \n" + row["CONVERSATION"]},
            {"role": "assistant", "content": row["TASK"]}]
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)
    return row

input_dataset = input_dataset.map(
    format_chat_template,
    num_proc=4,
)

In [14]:
dataset=input_dataset

In [15]:
dataset = dataset.train_test_split(test_size=0.1, seed=42)

In [16]:
new_model = "llama-3-8b-task-detect"
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=2,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)




In [17]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=max_len,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Map: 100%|██████████████████████████████████████████████████| 1784/1784 [00:01<00:00, 1087.53 examples/s]


In [None]:
trainer.train()

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Step,Training Loss,Validation Loss
3210,1.0169,0.717614


