In [None]:
!pip install -U -q bitsandbytes wandb git+https://github.com/huggingface/transformers sentencepiece accelerate datasets peft  trl huggingface_hub flash-attn

In [None]:
import os

from wandb import login

# Authenticate with Weights & Biases for experiment logging.
# The key is read from the environment instead of being pasted into the cell,
# where it would be saved in the .ipynb (and easily committed). With key=None,
# wandb uses its usual credential lookup or an interactive prompt.
login(key=os.environ.get("WANDB_API_KEY"))

In [None]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub (the meta-llama checkpoints are gated).
# The token is read from the environment rather than hard-coded in the
# notebook; with token=None, `login` prompts for one instead.
login(token=os.environ.get("HF_TOKEN"))

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, BitsAndBytesConfig, set_seed
from peft import LoraConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM, setup_chat_format
from datasets import load_dataset
from accelerate import Accelerator
import torch

# QLoRA fine-tune of Llama-2-7B on an English OASST2 subset, using the ChatML
# chat format installed by `setup_chat_format`.
# NOTE(review): the `SFTTrainer(tokenizer=..., max_seq_length=...)` keyword
# arguments and `DataCollatorForCompletionOnlyLM` follow the older TRL API;
# newer TRL releases moved these (e.g. into `SFTConfig`) — confirm against the
# installed TRL version.

# Seed the RNGs so LoRA initialisation and data shuffling are reproducible.
set_seed(42)

modelpath = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    # FlashAttention-2 only runs in fp16/bf16; load the non-quantized modules
    # in bf16 to match the 4-bit compute dtype and `bf16=True` below.
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
# The KV cache is useless during training and incompatible with gradient
# checkpointing (enabled below).
model.config.use_cache = False
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)

# Add the ChatML special tokens and chat template (resizing the embeddings
# for the new tokens).
model, tokenizer = setup_chat_format(model, tokenizer)
# The LM collator sets labels at pad_token_id positions to -100. If pad == EOS,
# every real end-of-turn label would be ignored and the model would never learn
# to stop, so use <unk> as a distinct pad token.
if tokenizer.pad_token in [None, tokenizer.eos_token]:
    tokenizer.pad_token = tokenizer.unk_token

# Dataset with "train" and "test" splits; presumably a conversational
# "messages" column that SFTTrainer renders with the chat template — confirm.
dataset = load_dataset("g-ronimo/oasst2_top4k_en")

training_arguments = TrainingArguments(
    output_dir="out",
    # `evaluation_strategy` was renamed `eval_strategy` (transformers 4.41) and
    # the old name has since been removed; the install cell pulls transformers
    # from git main, where only the new name is accepted.
    eval_strategy="steps",
    label_names=["labels"],
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    save_steps=250,
    eval_steps=250,
    logging_steps=1,
    learning_rate=0.0002,
    lr_scheduler_type="constant",
    optim="paged_adamw_32bit",
    bf16=True,
    gradient_checkpointing=True,
    # PyTorch recommends non-reentrant checkpointing and warns when
    # `use_reentrant` is not passed explicitly.
    gradient_checkpointing_kwargs={"use_reentrant": False},
    group_by_length=True,
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    # Train only on assistant turns: tokens outside the spans that follow
    # "<|im_start|>assistant" (up to the next "<|im_start|>user") get label -100.
    data_collator=DataCollatorForCompletionOnlyLM(
        instruction_template="<|im_start|>user",
        response_template="<|im_start|>assistant",
        tokenizer=tokenizer,
        mlm=False,
    ),
    max_seq_length=512,
    # LoRA adapters on every linear layer. lm_head and embed_tokens are trained
    # in full because setup_chat_format added new token rows to them.
    peft_config=LoraConfig(
        target_modules="all-linear",
        modules_to_save=["lm_head", "embed_tokens"],
    ),
    args=training_arguments,
)

trainer.train()