In [None]:
import gc
import os
from dataclasses import dataclass, field
from types import SimpleNamespace
from typing import Optional

import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset

## Config

In [None]:
# Every tunable value of the notebook lives in this namespace; the cells below
# read their settings from it.
config = SimpleNamespace(
    model_name="meta-llama/Llama-2-7b-hf", # make sure you accepted all the agreements to use this model
    # data
    dataset_name="lvwerra/stack-exchange-paired",
    subset="data/finetune",
    split="train",
    size_valid_set=4000, # streaming mode: the first N examples become the validation set
    streaming=True,
    shuffle_buffer=5000, # streaming mode: buffer size used when shuffling the training stream
    seq_length=1024,
    num_workers=4, # read by `create_datasets` only when `streaming=False`; it was missing, so that path raised AttributeError
    # models
    lora_alpha=16.0,
    lora_dropout=0.05,
    lora_r=8,
    # training
    learning_rate=0.0001,
    lr_scheduler_type="cosine",
    optimizer_type="paged_adamw_32bit",
    num_warmup_steps=100,
    max_steps=500,
    logging_steps=10,
    save_steps=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    group_by_length=False, # the `--group_by_length` option is only available for `Dataset`, not `IterableDataset`
    weight_decay=0.05,
    # logging
    log_with="wandb",
    output_dir="./llama2-stackexchange/sft/results",
    log_freq=1,
)

## HF Objects

In [None]:
# Quantization settings handed to `from_pretrained` when the base model is loaded:
# 4-bit loading, "nf4" quantization type, bfloat16 compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

In [None]:
# Load the base checkpoint named in `config`, quantized according to `bnb_config`.
base_model = AutoModelForCausalLM.from_pretrained(
    config.model_name, # "meta-llama/Llama-2-7b-hf"
    quantization_config=bnb_config,
    device_map={"": 0}, # presumably places the whole model on device index 0 — TODO confirm on multi-GPU hosts
    trust_remote_code=True, # NOTE(review): permits code shipped with the model repo to run — confirm it is required
)

# `use_cache` is switched off on the model config before training starts.
base_model.config.use_cache = False

In [None]:
# LoRA adapter settings; passed to `SFTTrainer` through its `peft_config` argument.
peft_config = LoraConfig(
    r=config.lora_r, # 8
    lora_alpha=config.lora_alpha, # 16.0
    lora_dropout=config.lora_dropout, # 0.05
    target_modules=["q_proj", "v_proj"], # module names the adapter is attached to
    bias="none",
    task_type="CAUSAL_LM",
)

In [None]:
# Tokenizer loaded from the same checkpoint name as the base model.
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training. NOTE(review): this notebook trains with bf16 — confirm the workaround is still needed

## Data

In [None]:
def prepare_sample_text(example):
    """Render one dataset record as a single "Question: ... Answer: ..." string.

    Reads the ``question`` and ``response_j`` fields of ``example`` and returns
    them joined by a blank line.
    """
    return f"Question: {example['question']}\n\nAnswer: {example['response_j']}"

In [None]:
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """Return the mean number of characters per token over a dataset sample.

    At most ``nb_examples`` records are pulled from ``dataset``, formatted with
    ``prepare_sample_text`` and tokenized; the result is the total character
    count divided by the total token count.
    """

    def count_tokens(sample_text):
        # Fast tokenizers expose the token list on the encoding object;
        # slow ones only offer `tokenize`.
        if tokenizer.is_fast:
            return len(tokenizer(sample_text).tokens())
        return len(tokenizer.tokenize(sample_text))

    char_count = 0
    token_count = 0
    # Zipping with a range caps the number of records read from the dataset.
    limited_records = zip(range(nb_examples), iter(dataset))
    for _, record in tqdm(limited_records, total=nb_examples):
        sample_text = prepare_sample_text(record)
        char_count += len(sample_text)
        token_count += count_tokens(sample_text)

    return char_count / token_count

In [None]:
def create_datasets(tokenizer, args):
    """Build the packed training and validation datasets.

    Loads ``args.dataset_name`` (data directory ``args.subset``, split
    ``args.split``), separates a validation part, estimates the
    characters-per-token ratio on the training part and wraps both parts in
    ``ConstantLengthDataset``.

    Args:
        tokenizer: Tokenizer passed to ``chars_token_ratio`` and to
            ``ConstantLengthDataset``.
        args: Namespace-like object providing ``dataset_name``, ``subset``,
            ``split``, ``streaming``, ``size_valid_set``, ``shuffle_buffer``
            and ``seq_length``. ``num_workers`` (used only when not streaming)
            and ``seed`` are optional and default to ``None``.

    Returns:
        Tuple ``(train_dataset, valid_dataset)`` of ``ConstantLengthDataset``
        objects; the training one is created with ``infinite=True``, the
        validation one with ``infinite=False``.
    """
    # Optional settings are read defensively: a config without `num_workers`
    # previously raised AttributeError as soon as streaming was switched off.
    num_proc = None if args.streaming else getattr(args, "num_workers", None)
    # An optional `seed` makes the shuffle / split reproducible; None keeps the
    # previous unseeded behaviour.
    seed = getattr(args, "seed", None)

    dataset = load_dataset(
        args.dataset_name, # "lvwerra/stack-exchange-paired"
        data_dir=args.subset, # "data/finetune"
        split=args.split, # "train"
        use_auth_token=True, # NOTE(review): newer `datasets` releases may prefer `token=` — confirm against the installed version
        num_proc=num_proc,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        # The first `size_valid_set` examples are held out; everything after
        # them is shuffled through a buffer and used for training.
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=seed)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    def _pack(data, infinite):
        # Both splits share every setting except whether iteration restarts.
        return ConstantLengthDataset(
            tokenizer,
            data,
            formatting_func=prepare_sample_text,
            infinite=infinite,
            seq_length=args.seq_length,
            chars_per_token=chars_per_token,
        )

    train_dataset = _pack(train_data, infinite=True)
    valid_dataset = _pack(valid_data, infinite=False)
    return train_dataset, valid_dataset

In [None]:
# Build the train / validation datasets from the settings in `config`.
train_dataset, eval_dataset = create_datasets(tokenizer, config)

## Train

In [None]:
def print_trainable_parameters(model):
    """Print how many of the model's parameters are trainable.

    Args:
        model: Any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs whose parameters provide ``numel()``
            and ``requires_grad``.

    Prints one line with the trainable count, the total count and the
    trainable percentage. A model without any parameters reports ``0.0``
    percent instead of raising ``ZeroDivisionError``.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
    # Guard the ratio so a parameter-free model does not divide by zero.
    percentage = 100 * trainable_params / all_param if all_param else 0.0
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {percentage}")

In [None]:
# Translate the notebook config into Hugging Face `TrainingArguments`.
# `weight_decay` and `gradient_checkpointing` are defined in `config` but were
# never forwarded, so the configured values had no effect on training.
training_args = TrainingArguments(
    output_dir=config.output_dir, # "./llama2-stackexchange/sft/results"
    per_device_train_batch_size=config.per_device_train_batch_size, # 4
    gradient_accumulation_steps=config.gradient_accumulation_steps, # 2
    per_device_eval_batch_size=config.per_device_eval_batch_size, # 1
    gradient_checkpointing=config.gradient_checkpointing, # True
    learning_rate=config.learning_rate, # 1e-4
    weight_decay=config.weight_decay, # 0.05
    logging_steps=config.logging_steps, # 10
    max_steps=config.max_steps, # 500
    report_to=config.log_with, # "wandb"
    save_steps=config.save_steps, # 10
    group_by_length=config.group_by_length, # False
    lr_scheduler_type=config.lr_scheduler_type, # "cosine"
    warmup_steps=config.num_warmup_steps, # 100
    optim=config.optimizer_type, # "paged_adamw_32bit"
    bf16=True,
    remove_unused_columns=False,
    run_name="llama2_sft_stackexchange",
)

In [None]:
# Wire the quantized base model, the LoRA config, the datasets and the training
# arguments into the trainer.
trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None, # NOTE(review): the datasets are already built with `config.seq_length` in `create_datasets` — confirm `None` is intended here
    tokenizer=tokenizer,
    args=training_args,
)

In [None]:
# Run fine-tuning with the settings held in `training_args`.
trainer.train()

## Save

In [None]:
# Write the trained model to the run's output directory.
trainer.save_model(config.output_dir)

In [None]:
# Also store the model in a dedicated sub-directory; the merge cell at the end
# of the notebook reloads it from this path.
output_dir = os.path.join(config.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

In [None]:
# Free memory for merging weights.
# Dropping only `base_model` is not enough: `trainer` was constructed from the
# same model object and still references it, so both names must be released
# before the CUDA allocator can give the memory back. Nothing after this cell
# uses `trainer`.
del trainer, base_model
gc.collect()  # collect any reference cycles that would keep the tensors alive
torch.cuda.empty_cache()

In [None]:
# Reload the saved checkpoint in bfloat16 and merge the adapter into the base weights.
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()

# Save the merged model in its own sub-directory of the run's output directory.
output_merged_dir = os.path.join(config.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)