### Resources
- [Fine-tune Llama2 with DPO](https://huggingface.co/blog/dpo-trl) | [codebase](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2/scripts)

### Setup

In [1]:
from datasets import load_dataset, Dataset

import torch
from transformers import BitsAndBytesConfig, AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer
from ml_collections import config_dict

# Authenticate once from a terminal before running this notebook:
#   huggingface-cli login
#   wandb login

### Load dataset

In [2]:
# Open the paired Stack Exchange preference dataset in streaming mode, so the
# splits come back as iterable datasets that are read lazily.
dataset = load_dataset("lvwerra/stack-exchange-paired", streaming=True)

# Bind each split to its own name for the sampling step.
train_ds, test_ds = dataset["train"], dataset["test"]

Resolving data files:   0%|          | 0/72 [00:00<?, ?it/s]

In [3]:
# Display the dataset dict to inspect its splits, column names and shard counts.
dataset

IterableDatasetDict({
    train: IterableDataset({
        features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
        n_shards: 72
    })
    test: IterableDataset({
        features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
        n_shards: 12
    })
})

### Sample dataset

In [4]:
sample_size = 1000


def _take_samples(stream, limit, log_every=1000):
    """Collect the first ``limit`` examples of an iterable into a list.

    Iteration stops immediately after the last requested example has been
    appended, so nothing beyond the sample is pulled from ``stream``.

    Args:
        stream: Any iterable of examples (here, a streaming dataset split).
        limit: Maximum number of examples to collect; a value <= 0 yields
            an empty list without touching ``stream``.
        log_every: Print a progress line each time this many examples have
            been collected.

    Returns:
        A list holding at most ``limit`` examples, in iteration order.
    """
    samples = []
    if limit <= 0:
        return samples

    for count, example in enumerate(stream, start=1):
        samples.append(example)

        if count % log_every == 0:
            print(f"[INFO] processing {count} of {limit}...")

        # Stop here rather than at the top of the next iteration: checking
        # after the append avoids fetching one extra example from the stream.
        if count >= limit:
            break

    return samples


train_sample_data = _take_samples(train_ds, sample_size)
test_sample_data = _take_samples(test_ds, sample_size)

[INFO] processing 1000 of 1000...
[INFO] processing 1000 of 1000...


In [5]:
# Turn both lists of sampled examples into `Dataset` objects (train first, then test).
train_ds_sample, test_ds_sample = (
    Dataset.from_list(rows) for rows in (train_sample_data, test_sample_data)
)

In [6]:
# Print the summary of each sampled split, one per line.
print(train_ds_sample, test_ds_sample, sep="\n")

Dataset({
    features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
    num_rows: 1000
})
Dataset({
    features: ['qid', 'question', 'date', 'metadata', 'response_j', 'response_k'],
    num_rows: 1000
})


### Preprocessing

In [7]:
def return_prompt_and_responses(samples):
    """Convert a batch of raw rows into prompt / chosen / rejected columns.

    Args:
        samples: Batch mapping with "question", "response_j" and
            "response_k" entries; "question" is iterated as a sequence of
            strings.

    Returns:
        A dict with three keys: "prompt" (every question prefixed with
        "Question: " and followed by a blank line and "Answer: "), "chosen"
        (the "response_j" column as-is) and "rejected" (the "response_k"
        column as-is).
    """
    prompts = []
    for question in samples["question"]:
        prompts.append("Question: " + question + "\n\nAnswer: ")

    return {
        "prompt": prompts,
        "chosen": samples["response_j"],
        "rejected": samples["response_k"],
    }

In [8]:
# Remember the raw column names so they can be passed to `remove_columns`
# when the prompt/chosen/rejected columns are derived.
original_columns = train_ds_sample.column_names
original_columns

['qid', 'question', 'date', 'metadata', 'response_j', 'response_k']

In [9]:
# Apply the same batched conversion to both splits (train first, then test),
# dropping the raw columns so only the converter's output columns remain.
train_ds_sample_prepared, test_ds_sample_prepared = (
    split.map(
        return_prompt_and_responses,
        batched=True,
        remove_columns=original_columns,
    )
    for split in (train_ds_sample, test_ds_sample)
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [10]:
# Display the prepared training split to confirm its columns and row count.
train_ds_sample_prepared

Dataset({
    features: ['prompt', 'chosen', 'rejected'],
    num_rows: 1000
})

### Supervised Fine-tuning

In [11]:
# Base model name and LoRA hyperparameters, gathered in one config object
# with attribute-style access.
script_args = config_dict.ConfigDict(
    {
        "model_name": "meta-llama/Llama-2-7b-hf",
        "lora_r": 8,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
    }
)

In [12]:
def formatting_func(example):
    """Build the training text for one example.

    Args:
        example: Mapping with "prompt" and "chosen" entries.

    Returns:
        The "prompt" value immediately followed by the "chosen" value.
    """
    prompt_text = example["prompt"]
    chosen_text = example["chosen"]
    return prompt_text + chosen_text

In [13]:
# 4-bit NF4 quantisation with double quantisation; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)

model_kwargs = dict(
    device_map={"": 0},  # place the whole model on device 0
    trust_remote_code=True,
    # Use bfloat16 for the non-quantised weights so the dtype agrees with
    # `bnb_4bit_compute_dtype` above and `bf16=True` in the training arguments
    # (this was float16 before, which mixed two half-precision formats).
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_cache=False,  # the generation cache is not needed while training
    quantization_config=bnb_config
)

base_model = AutoModelForCausalLM.from_pretrained(script_args.model_name, **model_kwargs)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
# The tokenizer attribute is `padding_side`; assigning `padding_size` only
# created an unused attribute and left the padding side at its default.
tokenizer.padding_side = "right"

# LoRA adapters on the attention query/value projections only.
peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)

training_args = TrainingArguments(
    output_dir="./sft",
    max_steps=500,
    logging_steps=10,
    save_steps=10,
    # Batch size 1 with 16 accumulation steps: 16 sequences per optimizer
    # step on each device.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=False,
    group_by_length=False,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_steps=100,
    weight_decay=0.05,
    optim="paged_adamw_32bit",
    bf16=True,
    remove_unused_columns=False,
    run_name="sft_llama2",
    report_to="wandb"
)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_ds_sample_prepared,
    eval_dataset=test_ds_sample_prepared,
    peft_config=peft_config,
    packing=True,
    max_seq_length=None,  # leave the sequence length to the trainer's default
    tokenizer=tokenizer,
    args=training_args,
    formatting_func=formatting_func
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [14]:
# Start supervised fine-tuning; metrics are reported to Weights & Biases
# because the training arguments set report_to="wandb".
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mmatt24[0m. Use [1m`wandb login --relogin`[0m to force relogin


The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.


Step,Training Loss


KeyboardInterrupt: 