In [None]:
import os, random, re
# Restrict this kernel (and the worker processes it later launches, which inherit the
# environment) to GPUs 0 and 1. Set before `import torch` so it is in place before
# CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"

from typing import List, Literal, Optional

import torch

from accelerate import Accelerator, DeepSpeedPlugin, notebook_launcher
from datasets import DatasetDict, concatenate_datasets, load_dataset
from peft import LoraConfig, PeftConfig
from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer, set_seed
from trl import SFTTrainer


from configs import DataArguments, ModelArguments, SFTConfig

## Configuration

In [None]:
# Single source of truth for where checkpoints / the final model are written.
OUTPUT_DIR = "../../models/alignment-handbook/zephyr-7b-sft-lora"

data_args, model_args, training_args = (
    DataArguments(),
    ModelArguments(),
    SFTConfig(output_dir=OUTPUT_DIR),
)

# NOTE(review): the fields below are assigned after construction, so any
# normalization/validation done in the dataclasses' __post_init__ does not see
# them. Consider passing them to the constructors instead — confirm against the
# installed transformers version.

# data args
data_args.dataset_splits = ["train_sft", "test_sft"]
data_args.dataset_mixer = {"HuggingFaceH4/ultrachat_200k": 1.0}
data_args.preprocessing_num_workers = 12

# model args
model_args.model_name_or_path = "mistralai/Mistral-7B-v0.1"
model_args.torch_dtype = "auto"
model_args.use_flash_attention_2 = True  # default = False
# quantization
model_args.load_in_4bit = True  # default = False
model_args.load_in_8bit = False
model_args.bnb_4bit_quant_type = "nf4"
model_args.use_bnb_nested_quant = False
# LoRA
model_args.use_peft = True  # default = False
model_args.lora_r = 64  # default = 16
model_args.lora_alpha = 16  # default = 32
model_args.lora_dropout = 0.1  # default = 0.05
model_args.lora_target_modules = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
]  # default = None
model_args.lora_modules_to_save = None

# trainer config (SFT)
training_args.bf16 = True
training_args.do_eval = True
training_args.evaluation_strategy = "epoch"
training_args.gradient_accumulation_steps = 128
training_args.gradient_checkpointing = True
training_args.gradient_checkpointing_kwargs = {"use_reentrant": False}
training_args.hub_model_id = "zephyr-7b-sft-lora"
training_args.hub_strategy = "every_save"
training_args.learning_rate = 2.0e-05
training_args.log_level = "info"
training_args.logging_steps = 5
training_args.logging_strategy = "steps"
training_args.lr_scheduler_type = "cosine"
training_args.max_seq_length = 2048
training_args.max_steps = -1
training_args.num_train_epochs = 1
training_args.output_dir = OUTPUT_DIR
training_args.overwrite_output_dir = True
training_args.per_device_eval_batch_size = 4  # 8
training_args.per_device_train_batch_size = 2  # 4
training_args.push_to_hub = True
training_args.report_to = "none"  # default = ["tensorboard"]
training_args.save_strategy = "no"
training_args.save_total_limit = None
training_args.seed = 42

# Whether the concatenated train/test datasets are shuffled when built in Step 1.
shuffle_datasets = True

In [None]:
# Dump the three argument objects, separated by an 80-character rule.
print(data_args, model_args, training_args, sep="\n" + "=" * 80 + "\n")

## Step 1: Build Datasets

In [None]:
# Load every configured split of every dataset in the mixer. Train splits are
# subsampled by their mixer fraction in the next cell (which zips them with
# `fracs`, so one train split per dataset is assumed); test splits are kept whole.
# Fractions are validated before any download so a bad config fails fast.
if any(frac < 0 for frac in data_args.dataset_mixer.values()):
    raise ValueError("Dataset fractions cannot be negative.")

raw_datasets = DatasetDict()
raw_train_datasets = []
raw_val_datasets = []
fracs = []

for ds, frac in data_args.dataset_mixer.items():
    fracs.append(frac)
    for split in data_args.dataset_splits:
        if "train" in split:
            raw_train_datasets.append(load_dataset(ds, split=split))
        elif "test" in split:
            raw_val_datasets.append(load_dataset(ds, split=split))
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

In [None]:
# Take the leading `frac` share of each train split, concatenate them, and
# optionally shuffle with a fixed seed.
if raw_train_datasets:
    train_subsets = [
        dataset.select(range(int(frac * len(dataset))))
        for dataset, frac in zip(raw_train_datasets, fracs)
    ]
    combined_train = concatenate_datasets(train_subsets)
    raw_datasets["train"] = combined_train.shuffle(seed=42) if shuffle_datasets else combined_train

In [None]:
# Test splits are concatenated whole (no fraction subsampling) so evaluation
# numbers stay comparable across models.
if raw_val_datasets:
    combined_val = concatenate_datasets(raw_val_datasets)
    if shuffle_datasets:
        combined_val = combined_val.shuffle(seed=42)
    raw_datasets["test"] = combined_val

In [None]:
# Fail loudly when no train/test split was produced. Report the configured
# splits rather than a loop variable from another cell, which does not exist
# when the mixer or the split list is empty.
if len(raw_datasets) == 0:
    raise ValueError(
        f"Dataset {data_args.dataset_mixer} not recognized with splits {data_args.dataset_splits}. Check the dataset has been correctly formatted."
    )

In [None]:
# One "name : rows" entry per split in the assembled DatasetDict.
split_sizes = [f"{name} : {dset.num_rows}" for name, dset in raw_datasets.items()]
print(f"Training on the following datasets and their proportions: {split_sizes}")

In [None]:
# Display the assembled DatasetDict for inspection.
raw_datasets

In [None]:
# Peek at the first training row before the chat template is applied in Step 3.
raw_datasets["train"][0]

## Step 2: Build Tokenizer

In [None]:
# Fallback Jinja chat template, used only when neither `data_args.chat_template`
# nor the tokenizer itself provides one (see the tokenizer setup cell).
# Each turn renders as "<|role|>\n{content}{eos_token}"; when
# `add_generation_prompt` is true, a trailing "<|assistant|>" is emitted after
# the last message. Turns with other roles are skipped.
DEFAULT_CHAT_TEMPLATE = """\
{% for message in messages %}
{% if message['role'] == 'user' %}
{{ '<|user|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'system' %}
{{ '<|system|>\n' + message['content'] + eos_token }}
{% elif message['role'] == 'assistant' %}
{{ '<|assistant|>\n'  + message['content'] + eos_token }}
{% endif %}
{% if loop.last and add_generation_prompt %}
{{ '<|assistant|>' }}
{% endif %}
{% endfor %}"""

In [None]:
# Model id and revision, each on its own line.
print(model_args.model_name_or_path, model_args.model_revision, sep="\n")

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    model_args.model_name_or_path,
    revision=model_args.model_revision,
)

# Tokenizer defaults as shipped with the checkpoint, before the overrides in the next cell.
print(
    tokenizer.eos_token,
    tokenizer.eos_token_id,
    tokenizer.truncation_side,
    tokenizer.model_max_length,
    tokenizer.chat_template,
)
print(tokenizer.special_tokens_map)


In [None]:
# Padding: reuse EOS when the checkpoint defines no dedicated pad token.
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Truncation side: only override when explicitly configured.
truncation_override = data_args.truncation_side
if truncation_override is not None:
    tokenizer.truncation_side = truncation_override

# Tokenizers without a configured max length report a very large value;
# cap those at 2048 tokens.
if tokenizer.model_max_length > 100_000:
    tokenizer.model_max_length = 2048

# Chat template precedence: explicit config, then the tokenizer's own,
# then DEFAULT_CHAT_TEMPLATE.
chat_template_override = data_args.chat_template
if chat_template_override is not None:
    tokenizer.chat_template = chat_template_override
elif tokenizer.chat_template is None:
    tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE

In [None]:
# Same fields as the earlier inspection, now after the overrides.
tokenizer_summary = (
    tokenizer.eos_token,
    tokenizer.eos_token_id,
    tokenizer.truncation_side,
    tokenizer.model_max_length,
    tokenizer.chat_template,
)
print(*tokenizer_summary)
print(tokenizer.special_tokens_map)

## Step 3: Apply Chat Template

In [None]:
def apply_chat_template_for_sft(example, tokenizer):
    """Render an example's chat ``messages`` into a single training string.

    An empty system turn is prepended when the conversation does not already
    start with one. The input ``messages`` list is left unmodified.

    Args:
        example: A dataset row holding a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts.
        tokenizer: Object exposing ``apply_chat_template`` (e.g. a Hugging Face
            tokenizer).

    Returns:
        The same ``example`` dict with a ``"text"`` field holding the rendered
        conversation.
    """
    messages = example["messages"]

    # Build a new list instead of inserting in place, so the dataset's
    # "messages" column is not silently rewritten by `map`. An empty
    # conversation gets the system turn instead of raising IndexError.
    if not messages or messages[0]["role"] != "system":
        messages = [{"role": "system", "content": ""}, *messages]

    example["text"] = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=False
    )
    return example

In [None]:
# Render every split's conversations into a "text" column, parallelised with
# the configured `data_args.preprocessing_num_workers` (None -> single process).
raw_datasets = raw_datasets.map(
    apply_chat_template_for_sft,
    fn_kwargs={"tokenizer": tokenizer},
    num_proc=data_args.preprocessing_num_workers,
    desc="Applying chat template",
)

In [None]:
# Inspect the datasets again; each split should now carry the "text" column added by the map.
raw_datasets

In [None]:
# Spot-check one rendered conversation (row 500 of the train split).
sample_row = raw_datasets["train"][500]
print(sample_row["text"])

In [None]:
# Short aliases handed to the trainer.
train_dataset, eval_dataset = raw_datasets["train"], raw_datasets["test"]

In [None]:
# with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
#     for index in random.sample(range(len(raw_datasets["train"])), 3):
#         print(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")

## Step 4: Define Base Model

In [None]:
# Keep "auto"/None as-is; otherwise resolve a name such as "bfloat16" to the torch dtype.
if model_args.torch_dtype in ("auto", None):
    torch_dtype = model_args.torch_dtype
else:
    torch_dtype = getattr(torch, model_args.torch_dtype)
torch_dtype

In [None]:
def get_current_device() -> int:
    """Return the CUDA device index this process should load the model onto.

    Reads the ``LOCAL_RANK`` environment variable, which distributed launchers
    (``torchrun`` / ``accelerate``) set per worker, so each process in a
    multi-GPU run gets its own device. Falls back to 0 when it is unset, as in
    a plain single-process notebook kernel.
    """
    # Environment lookup avoids constructing an Accelerator just to learn the rank.
    return int(os.environ.get("LOCAL_RANK", "0"))


get_current_device()

In [None]:
def get_kbit_device_map() -> dict[str, int] | None:
    """Device map that places the whole (quantized) model on this process's GPU.

    Intended for ``device_map=get_kbit_device_map()`` when loading k-bit models.
    Returns ``None`` when CUDA is unavailable, so no explicit placement is passed.
    """
    if not torch.cuda.is_available():
        return None
    # The empty-string key targets the root module, i.e. the entire model.
    return {"": get_current_device()}


get_kbit_device_map()

In [None]:
# Show the quantization-related settings that drive the BitsAndBytesConfig in the next cell.
model_args.load_in_4bit, model_args.load_in_8bit, model_args.bnb_4bit_quant_type, model_args.use_bnb_nested_quant

In [None]:
# bitsandbytes quantization config: 4-bit takes precedence over 8-bit; None disables quantization.
if model_args.load_in_4bit:
    # Compute in the model's configured dtype when one is set explicitly, for
    # consistency with the weights; with "auto"/None fall back to float16.
    # NOTE(review): training runs with bf16=True, so torch_dtype="bfloat16" may
    # be the better match here — confirm.
    if model_args.torch_dtype in ("auto", None):
        bnb_compute_dtype = torch.float16
    else:
        bnb_compute_dtype = getattr(torch, model_args.torch_dtype)
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=bnb_compute_dtype,
        bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
        bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
    )
elif model_args.load_in_8bit:
    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = None

quantization_config

In [None]:
# Settings that feed model_kwargs in the next cell.
model_load_settings = (
    model_args.model_revision,
    model_args.trust_remote_code,
    model_args.use_flash_attention_2,
    torch_dtype,
    training_args.gradient_checkpointing,
)
print(*model_load_settings)

In [None]:
# Keyword arguments passed to SFTTrainer as `model_init_kwargs`, used when it
# loads the model from `model_name_or_path`.
model_kwargs = {
    "revision": model_args.model_revision,
    "trust_remote_code": model_args.trust_remote_code,
    "use_flash_attention_2": model_args.use_flash_attention_2,
    "torch_dtype": torch_dtype,
    # Disable the KV cache whenever gradient checkpointing is on.
    "use_cache": not training_args.gradient_checkpointing,
    "device_map": get_kbit_device_map(),
    "quantization_config": quantization_config,
}

## Step 5: Initialize Trainer

In [None]:
# LoRA-related settings used to build peft_config in the next cell.
lora_settings = (
    model_args.use_peft,
    model_args.lora_r,
    model_args.lora_alpha,
    model_args.lora_dropout,
    model_args.lora_target_modules,
    model_args.lora_modules_to_save,
)
print(*lora_settings)

In [None]:
# LoRA adapter configuration. It is only built when `use_peft` is enabled;
# otherwise peft_config stays None and no adapter config is passed to the trainer.
if model_args.use_peft:
    peft_config = LoraConfig(
        r=model_args.lora_r,
        lora_alpha=model_args.lora_alpha,
        lora_dropout=model_args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=model_args.lora_target_modules,
        modules_to_save=model_args.lora_modules_to_save,
    )
else:
    peft_config = None

peft_config

In [None]:
# The model is given by name, so SFTTrainer loads it itself using model_kwargs;
# peft_config (if not None) is applied on top. Samples come from the "text"
# column with packing enabled.
# NOTE(review): this presumably loads the model onto the GPU in the notebook's
# main process, before `notebook_launcher` forks workers — confirm this is intended.
trainer = SFTTrainer(
    model=model_args.model_name_or_path,
    model_init_kwargs=model_kwargs,
    peft_config=peft_config,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    dataset_text_field="text",
    packing=True,
    max_seq_length=training_args.max_seq_length,
)

## Step 6: Train

In [None]:
def train():
    """Training entry point executed in each worker started by ``notebook_launcher``.

    Seeds the run, trains the module-level ``trainer``, then persists the
    outcome: train metrics, trainer state and the final model/adapter.

    Returns:
        The ``TrainOutput`` returned by ``trainer.train()``.
    """
    set_seed(training_args.seed)

    # NOTE(review): this plugin/Accelerator pair is never handed to `trainer`
    # (built in an earlier cell from `training_args`, which sets no DeepSpeed
    # config), so it only initializes accelerate state in this process. Confirm
    # whether ZeRO-3 is really intended; it may not work with the 4-bit base model.
    deepspeed_plugin = DeepSpeedPlugin(
        offload_optimizer_device=None,
        offload_param_device=None,
        zero3_init_flag=True,
        zero3_save_16bit_model=True,
        zero_stage=3,
    )
    accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)

    train_result = trainer.train()

    # save_strategy="no" means no checkpoints are written during training, so
    # persist the result explicitly; otherwise it is lost when the worker exits.
    # These Trainer helpers write only from the main process.
    metrics = train_result.metrics
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()
    # With push_to_hub=True, save_model also uploads the saved model.
    trainer.save_model(training_args.output_dir)

    return train_result

In [None]:
# Run `train` in 2 forked worker processes on a single node with bf16 mixed precision.
# NOTE(review): notebook_launcher refuses to fork once CUDA has been initialized in
# this kernel; loading the model in the SFTTrainer cell likely does that. Confirm,
# and consider building the trainer inside `train()` instead.
notebook_launcher(train, mixed_precision="bf16", num_nodes=1, num_processes=2)