In [None]:
!pip3 install datasets==2.13.1 transformers==4.39.1 peft==0.10.0 bitsandbytes==0.43.0 accelerate==0.28.0 flash-attn==2.5.6

In [None]:
import os
import torch
from functools import partial
from datasets import load_dataset

from transformers import Trainer
from transformers import TrainingArguments
from transformers import AutoTokenizer
from transformers import BitsAndBytesConfig
from transformers import AutoModelForCausalLM

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# Turn off Weights & Biases reporting for this session.
os.environ.update({"WANDB_DISABLED": "true"})

In [None]:
# bitsandbytes settings used when loading the base model in 4-bit
bnb_config = BitsAndBytesConfig(
    # half-precision compute dtype for the quantized layers
    bnb_4bit_compute_dtype=torch.float16,
    # nested (double) quantization stays off
    bnb_4bit_use_double_quant=False,
    # NF4 quantization data type
    bnb_4bit_quant_type="nf4",
    load_in_4bit=True,
)

In [None]:
# Load the base checkpoint, quantized with `bnb_config`
base_model = "microsoft/phi-2"

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    # place every module on device index 0
    device_map={"": 0},
    quantization_config=bnb_config,
    attn_implementation="flash_attention_2",
    trust_remote_code=True,
    low_cpu_mem_usage=True,
)

In [None]:
# get the tokenizer and register the ChatML control tokens
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=False)
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))

# Three ids were appended to the vocabulary. Grow the embedding matrix only if
# it cannot already hold them; when it is large enough this is a no-op, and when
# it is not, the lookup of the new ids would otherwise go out of range.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

# Keep the model config in sync with the tokenizer's special tokens.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

In [None]:
# load the dataset: constants shared by the tokenization and collation helpers

# Label value written for prompt tokens and padding positions
# (-100 is the default `ignore_index` of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100

# ChatML turn templates, looked up with a boolean "is this the user turn?":
#   templates[False] / templates[0] -> assistant turn
#   templates[True]  / templates[1] -> user turn
templates = [
    "<|im_start|>assistant\n{msg}<|im_end|>",
    "<|im_start|>user\n{msg}<|im_end|>",
]


def tokenize(input, max_length, tokenizer):
    """Turn one dataset row into a tokenized two-turn ChatML training sample.

    The row's "command" field is rendered as the user turn and its "cfr" field
    as the assistant turn, using the module-level `templates`. User-turn
    positions get `IGNORE_INDEX` as label; assistant-turn positions are labelled
    with their own token ids.

    Args:
        input: mapping with the keys "command" and "cfr".
        max_length: maximum number of tokens kept; longer samples are cut
            from the right.
        tokenizer: callable returning "input_ids" and "attention_mask" for a string.

    Returns:
        dict with the keys "input_ids", "attention_mask" and "labels".
    """
    ids, mask, targets = [], [], []

    turns = (input["command"], input["cfr"])
    for position in range(len(turns)):
        # even positions are the user side of the exchange
        is_prompt = position % 2 == 0

        rendered = templates[is_prompt].format(msg=turns[position])
        encoded = tokenizer(
            rendered, truncation=False, add_special_tokens=False)
        turn_ids = encoded["input_ids"]

        ids.extend(turn_ids)
        mask.extend(encoded["attention_mask"])
        if is_prompt:
            targets.extend([IGNORE_INDEX] * len(turn_ids))
        else:
            targets.extend(turn_ids)

    # the tokenizer itself does not truncate; the length cap is applied here
    return dict(
        input_ids=ids[:max_length],
        attention_mask=mask[:max_length],
        labels=targets[:max_length],
    )


def prepare_dataset(tokenizer, dataset_path="dataset", max_length=1024, split="train"):
    """Load a dataset and tokenize every row with `tokenize`.

    Args:
        tokenizer: tokenizer forwarded to `tokenize`.
        dataset_path: path or name passed to `load_dataset`
            (defaults to the previously hard-coded "dataset").
        max_length: per-sample token cap forwarded to `tokenize`
            (defaults to the previously hard-coded 1024).
        split: dataset split to load (defaults to "train").

    Returns:
        The mapped dataset; the original columns are removed, so only the
        fields returned by `tokenize` remain.
    """
    dataset = load_dataset(dataset_path, split=split)

    return dataset.map(
        partial(tokenize, max_length=max_length, tokenizer=tokenizer),
        batched=False,
        num_proc=os.cpu_count(),
        remove_columns=dataset.column_names,
    )


# collate to prepare a batch of data for input to a LLM for training
def collate(elements, tokenizer):
    """Right-pad a list of tokenized samples to a common length and stack them.

    Each sample is padded up to the longest "input_ids" in the batch:
    "input_ids" with `tokenizer.pad_token_id`, "attention_mask" with 0 and
    "labels" with `IGNORE_INDEX`. The padded rows are new lists, so the
    samples passed in are not modified.

    Args:
        elements: non-empty list of dicts with the keys "input_ids",
            "attention_mask" and "labels".
        tokenizer: object exposing `pad_token_id`.

    Returns:
        dict of tensors with the keys "input_ids", "labels" and "attention_mask".

    Raises:
        ValueError: if `elements` is empty.
    """
    batch_len = max(len(sample["input_ids"]) for sample in elements)

    padded_ids, padded_labels, padded_mask = [], [], []
    for sample in elements:
        # all three fields are padded by the amount missing from input_ids
        pad_len = batch_len - len(sample["input_ids"])

        padded_ids.append(
            list(sample["input_ids"]) + pad_len * [tokenizer.pad_token_id])
        padded_mask.append(list(sample["attention_mask"]) + pad_len * [0])
        padded_labels.append(list(sample["labels"]) + pad_len * [IGNORE_INDEX])

    return {
        "input_ids": torch.tensor(padded_ids),
        "labels": torch.tensor(padded_labels),
        "attention_mask": torch.tensor(padded_mask),
    }

# Tokenized training set, built once with the tokenizer configured above.
dataset = prepare_dataset(tokenizer=tokenizer)

In [None]:
# LoRA configuration
lora_config = LoraConfig(

    # Lora attention dimension (rank of the adapter matrices)
    r=32,

    # Scaling factor that changes how the adaptation layer's weights affect the base model's.
    # Higher alpha means the LoRA layers act more strongly on the base model
    lora_alpha=32,

    # Modules that receive LoRA adapters. These are the layer names of the
    # Transformers Phi implementation (attention: q_proj/k_proj/v_proj/dense,
    # MLP: fc1/fc2), the same naming scheme `embed_tokens` below relies on.
    # The previous "Wqkv"/"out_proj" entries are names from the older
    # remote-code model; PEFT only raises when *no* target matches, so the
    # attention layers were silently left without adapters.
    target_modules=["q_proj", "k_proj", "v_proj", "dense", "fc1", "fc2"],

    # List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint
    modules_to_save=["lm_head", "embed_tokens"],

    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

In [None]:
# Prepare the quantized model for training (with gradient checkpointing) ...
model = prepare_model_for_kbit_training(
    model=model,
    use_gradient_checkpointing=True,
)
# ... switch off the generation KV cache on the config ...
model.config.use_cache = False
# ... and wrap it with the LoRA adapters described by `lora_config`.
model = get_peft_model(model=model, peft_config=lora_config)

In [None]:
# Hyperparameters for the Trainer run
training_arguments = TrainingArguments(
    # output / bookkeeping
    output_dir="./results",
    logging_steps=20,
    save_steps=0,

    # schedule length: three epochs, no explicit step cap
    num_train_epochs=3,
    max_steps=-1,

    # batching
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=1,
    group_by_length=True,

    # precision / memory
    fp16=False,
    bf16=False,
    gradient_checkpointing=True,

    # optimizer and learning-rate schedule
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    weight_decay=0.001,
    max_grad_norm=0.3,
)

# Build the Trainer: PEFT model + tokenized dataset, batched by `collate`
trainer = Trainer(
    args=training_arguments,
    model=model,
    train_dataset=dataset,
    # `collate` needs the tokenizer for its pad token id, so bind it here
    data_collator=partial(collate, tokenizer=tokenizer),
    tokenizer=tokenizer,
)

In [None]:
# Train model
trainer.train()

# Save the trained adapter together with the tokenizer.
# `trainer.model` is the PEFT-wrapped model handed to the Trainer (the same
# object as `model`), so one `save_pretrained` call is enough. It writes the
# adapter checkpoint (LoRA weights plus `modules_to_save`), not a merged
# full-size model.
new_model = "phi-2-cfr-lora"
trainer.model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)