In [1]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling
from peft import LoraConfig
from trl import SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm


[2024-02-28 17:06:21,151] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
# Hub identifier of the base checkpoint; the tokenizer and model cells both load from it.
model_id = "meta-llama/Llama-2-7b-hf"

# Training split of the Guanaco dataset used for supervised fine-tuning.
raw_dataset = load_dataset(
    path="timdettmers/openassistant-guanaco",
    split="train",
)

Repo card metadata block was not found. Setting CardData to empty.


In [3]:
# Load the tokenizer that belongs to the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_id)

# A pad token has to be assigned, but it must NOT be an alias of EOS:
# the trainer cell collates with DataCollatorForLanguageModeling(mlm=False),
# which replaces every label equal to pad_token_id with -100. If pad == EOS,
# all EOS positions are excluded from the loss and the fine-tuned model is
# never trained to emit EOS (i.e. to stop generating). Reuse <unk> instead,
# which needs no new vocabulary entry and therefore no embedding resize.
tokenizer.pad_token = tokenizer.unk_token

# Fail fast if the checkpoint is swapped for one without a usable <unk> token.
assert (
    tokenizer.pad_token_id is not None
    and tokenizer.pad_token_id != tokenizer.eos_token_id
), "pad token must exist and differ from EOS, otherwise EOS labels get masked"

In [4]:
# bitsandbytes settings for loading the base model in 4-bit precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 quantization format
    bnb_4bit_use_double_quant=True,         # double quantization enabled
    bnb_4bit_compute_dtype=torch.bfloat16,  # compute in bfloat16
)

In [5]:
# Load the 4-bit quantized base model with FlashAttention-2.
#
# torch_dtype is pinned explicitly: when a bitsandbytes quantization config is
# given and no dtype is passed, transformers falls back to float16 for the
# non-quantized weights. That would disagree with the bfloat16 compute dtype in
# `bnb_config` and with the bf16 training flag used by the trainer cell, and
# FlashAttention-2 only supports fp16/bf16 model dtypes.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # let accelerate place the weights on the available device(s)
    attn_implementation="flash_attention_2",
)

Loading checkpoint shards: 100%|█████████████████████████████████████| 2/2 [00:04<00:00,  2.23s/it]


In [6]:
# LoRA adapter configuration handed to the trainer for causal-LM fine-tuning.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,              # adapter rank
    lora_alpha=16,     # LoRA alpha
    lora_dropout=0.1,  # dropout applied inside the adapter
    bias="none",       # no bias terms are trained
)

In [None]:
# Run-level hyper-parameters: where to write, how long to train, batch layout,
# optimizer / precision, and logging cadence.
training_args = TrainingArguments(
    output_dir="output",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    optim="paged_adamw_8bit",
    bf16=True,
    logging_strategy="steps",
    logging_steps=20,
)

# Collator for the causal-LM objective (masked-LM mode switched off).
causal_lm_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Supervised fine-tuning trainer: quantized base model + LoRA config, reading
# the "text" column of the dataset with packing enabled.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    peft_config=peft_config,
    tokenizer=tokenizer,
    train_dataset=raw_dataset,
    dataset_text_field="text",
    packing=True,
    data_collator=causal_lm_collator,
)

# Kick off fine-tuning.
trainer.train()