# Finetune an OSS model for our bot

We will use the [trl](https://github.com/huggingface/trl) library to make our life easy! Most of the code comes from the official [trl finetune example](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py).

In [1]:
# !pip install accelerate transformers datasets bitsandbytes peft trl

In [2]:
from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, HfArgumentParser, TrainingArguments

import wandb

from trl import SFTTrainer

What is really handy here is the data preprocessing that is baked into the `SFTTrainer` class. This trainer is a thin wrapper around the transformers `Trainer`, but it adds the preprocessing needed to format and pack our instruction dataset.

## Data

We will grab the dataset we created previously.

In [3]:
training_data_path = "dataset/"

In [4]:
# Load every JSON file under the dataset directory. By default the split is
# called "train", so we index it explicitly. No seed is passed to shuffle(),
# so the row order is not fixed between runs.
ds = load_dataset("json", data_files=f"{training_data_path}/*.json")["train"].shuffle()

In [5]:
ds

Dataset({
    features: ['user', 'answer'],
    num_rows: 616
})

In [6]:
ds[0:3]

{'user': ['Just...',
  "We'll be able to get out of here. We'll be able to get out of here. We'll be able to get out of here.",
  'So you have to dry them all the way from the home side.'],
 'answer': ['other()', 'other()', 'other()']}

In [7]:
splitted_ds = ds.train_test_split(test_size=0.1)

Let's save this split in Hugging Face dataset format (fast Arrow files under the hood)

In [8]:
splitted_ds.save_to_disk(f"{training_data_path}/split_dataset")

Saving the dataset (0/1 shards):   0%|          | 0/554 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/62 [00:00<?, ? examples/s]

Let's save this to W&B

In [9]:
# with wandb.init(project="otto", job_type="data_split"):
#     at = wandb.Artifact(name="split_dataset",
#                         type="dataset",
#                         description="The generated data splitted in 90/10")
#     at.add_dir(f"{training_data_path}/split_dataset")
#     wandb.log_artifact(at)

In [10]:
DATASET_ARTIFACT = 'capecape/otto/split_dataset:v2'

In [11]:
from datasets import load_from_disk
def load_from_artifact(at_address, type="dataset"):
    """Download a W&B Artifact and open its directory as a saved dataset.

    If a W&B run is currently active, the artifact is obtained with
    ``wandb.use_artifact``; otherwise it is looked up through ``wandb.Api``.
    The downloaded directory is then read with ``load_from_disk``.
    """
    if wandb.run is None:
        # No active run: go through the public API client instead.
        from wandb import Api

        artifact = Api().artifact(at_address, type=type)
    else:
        artifact = wandb.use_artifact(at_address, type=type)
    return load_from_disk(artifact.download())

## Prepare data for Training

> Depending on the model you will need to change this formatting function

We will train a Llama 2 model from Meta AI. Depending on whether it is the `chat` or the `vanilla` version, you will need to format your instructions differently. My go-to places to find these formats are the Hugging Face model card (though it is often missing), the official paper (which can be hard to find), or the [Axolotl training library](https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/prompt_strategies/llama2_chat.py).

In [12]:
def _create_prompt(user, answer=""):
    "Format the prompt to style"
    return ("Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
            "### User: {user}\n"
            "### Answer: {answer}").format(user=user, answer=answer)

def create_prompt(row):
    """Build the prompt for one dataset row.

    The row's keys are forwarded to ``_create_prompt`` as keyword arguments.
    """
    return _create_prompt(**row)

In [13]:
print(create_prompt(ds[0]))

Below is an instruction that describes a task. Write a response that appropriately completes the request.
### User: Just...
### Answer: other()


In [14]:
tqdm.pandas()

MODEL_NAME = 'meta-llama/Llama-2-7b-hf'

# Define and parse arguments.
@dataclass
class ScriptArguments:
    """
    Settings for fine-tuning a causal LM with SFTTrainer: which model to load,
    how to load it, the LoRA adapter sizes, and the training/checkpoint schedule.
    """

    # --- model / data / logging ---
    model_name: Optional[str] = field(
        default=MODEL_NAME,
        metadata={"help": "the model name"},
    )
    dataset_artifact: Optional[str] = field(
        default="otto dataset",
        metadata={"help": "the dataset name"},
    )
    log_with: Optional[str] = field(
        default="wandb",
        metadata={"help": "use 'wandb' to log with wandb"},
    )

    # --- optimisation ---
    learning_rate: Optional[float] = field(
        default=1.41e-5,
        metadata={"help": "the learning rate"},
    )
    batch_size: Optional[int] = field(
        default=2,
        metadata={"help": "the batch size"},
    )
    seq_length: Optional[int] = field(
        default=256,
        metadata={"help": "Input sequence length"},
    )
    gradient_accumulation_steps: Optional[int] = field(
        default=16,
        metadata={"help": "the number of gradient accumulation steps"},
    )

    # --- model loading ---
    load_in_8bit: Optional[bool] = field(
        default=True,
        metadata={"help": "load the model in 8 bits precision"},
    )
    load_in_4bit: Optional[bool] = field(
        default=False,
        metadata={"help": "load the model in 4 bits precision"},
    )
    use_peft: Optional[bool] = field(
        default=True,
        metadata={"help": "Wether to use PEFT or not to train adapters"},
    )
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable `trust_remote_code`"},
    )
    output_dir: Optional[str] = field(
        default="output",
        metadata={"help": "the output directory"},
    )

    # --- LoRA adapters ---
    peft_lora_r: Optional[int] = field(
        default=64,
        metadata={"help": "the r parameter of the LoRA adapters"},
    )
    peft_lora_alpha: Optional[int] = field(
        default=16,
        metadata={"help": "the alpha parameter of the LoRA adapters"},
    )

    # --- logging / auth ---
    logging_steps: Optional[int] = field(
        default=1,
        metadata={"help": "the number of logging steps"},
    )
    use_auth_token: Optional[bool] = field(
        default=True,
        metadata={"help": "Use HF auth token to access the model"},
    )

    # --- schedule / checkpoints ---
    # num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(
        default=100,
        metadata={"help": "the number of training steps"},
    )
    save_steps: Optional[int] = field(
        default=100,
        metadata={"help": "Number of updates steps before two checkpoint saves"},
    )
    save_total_limit: Optional[int] = field(
        default=10,
        metadata={"help": "Limits total number of checkpoints."},
    )

    # --- Hugging Face Hub ---
    push_to_hub: Optional[bool] = field(
        default=False,
        metadata={"help": "Push the model to HF Hub"},
    )
    hub_model_id: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the model on HF Hub"},
    )

## Model

We can load the model with all the bells and whistles from Transformers!

In [15]:
script_args = ScriptArguments()
script_args

ScriptArguments(model_name='meta-llama/Llama-2-7b-hf', dataset_artifact='otto dataset', log_with='wandb', learning_rate=1.41e-05, batch_size=2, seq_length=256, gradient_accumulation_steps=16, load_in_8bit=True, load_in_4bit=False, use_peft=True, trust_remote_code=False, output_dir='output', peft_lora_r=64, peft_lora_alpha=16, logging_steps=1, use_auth_token=True, max_steps=100, save_steps=100, save_total_limit=10, push_to_hub=False, hub_model_id=None)

In [16]:
# parser = HfArgumentParser(ScriptArguments)
# script_args = parser.parse_args_into_dataclasses()[0]

In [17]:
# Step 1: Load the model
# 8-bit and 4-bit loading are mutually exclusive.
if script_args.load_in_8bit and script_args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")

# Start from the non-quantized settings and override them only when a
# quantized load was requested.
quantization_config = device_map = torch_dtype = None
if script_args.load_in_8bit or script_args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=script_args.load_in_8bit,
        load_in_4bit=script_args.load_in_4bit,
    )
    # Copy the model to each device: the "" key maps the whole model onto
    # this process's local index.
    device_map = {"": Accelerator().local_process_index}
    torch_dtype = torch.bfloat16

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    torch_dtype=torch_dtype,
    device_map=device_map,
    quantization_config=quantization_config,
    trust_remote_code=script_args.trust_remote_code,
    use_auth_token=script_args.use_auth_token,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [18]:
# Step 3: Define the training arguments
# All values come from `script_args`. The same batch size is used for training
# and evaluation. `per_device_eval_batch_size` replaces the deprecated
# `per_gpu_eval_batch_size` argument.
training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.batch_size,
    per_device_eval_batch_size=script_args.batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    # num_train_epochs=script_args.num_train_epochs,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    save_total_limit=script_args.save_total_limit,
    push_to_hub=script_args.push_to_hub,
    hub_model_id=script_args.hub_model_id,
)

# Step 4: Define the LoraConfig
# A LoRA config is built only when PEFT is enabled; otherwise `peft_config`
# is None.
peft_config = (
    LoraConfig(
        task_type="CAUSAL_LM",
        bias="none",
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
    )
    if script_args.use_peft
    else None
)

Now we need to instantiate the `SFTTrainer` with the correct preprocessing:
- We want to pack sequences to a certain length (longer means more memory usage)
- We want to tokenize
- We also want to apply our prompt

In [19]:
script_args.seq_length

256

In [20]:
import evaluate
import numpy as np

def token_accuracy(eval_preds):
    """Compute next-token accuracy for a causal language model.

    A causal LM's logits at position ``i`` are its prediction for the token at
    position ``i + 1``, so predictions are compared against the labels shifted
    by one position. Label positions equal to ``-100`` (the ignore index) are
    excluded from the score.

    Args:
        eval_preds: A ``(logits, labels)`` pair. Assumes ``logits`` has the
            vocabulary on its last axis and ``labels`` has the sequence on its
            last axis. If ``logits`` is a tuple, its first element is used.

    Returns:
        ``{"accuracy": value}`` where ``value`` is the fraction of scored
        positions predicted correctly, or ``0.0`` if no position is scored.
    """
    logits, labels = eval_preds
    if isinstance(logits, tuple):
        logits = logits[0]

    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    # Align each prediction with the token it is predicting: drop the last
    # prediction (nothing follows it) and the first label (nothing predicts it).
    shifted_predictions = predictions[..., :-1].reshape(-1)
    shifted_labels = labels[..., 1:].reshape(-1)

    scored = shifted_labels != -100
    if not scored.any():
        return {"accuracy": 0.0}
    matches = shifted_predictions[scored] == shifted_labels[scored]
    return {"accuracy": float(matches.mean())}

In [21]:
# Run evaluation every 5 steps. These are set as attributes on the existing
# `training_args` object rather than passed to its constructor.
training_args.eval_steps = 5
training_args.evaluation_strategy = "steps"

In [None]:
# Start the W&B run for this fine-tuning job.
wandb.init(project="otto", job_type="finetune")

# Rebind `ds` to the dataset stored in the W&B artifact; it is indexed by
# "train" and "test" below.
ds = load_from_artifact(DATASET_ARTIFACT)

# Step 5: Define the Trainer
trainer = SFTTrainer(
    model=model,
    train_dataset=ds["train"],
    eval_dataset=ds["test"],
    args=training_args,
    max_seq_length=script_args.seq_length,
    packing=True,  # pack examples into sequences of `max_seq_length`
    formatting_func=create_prompt,  # turns each row into the prompt text
    peft_config=peft_config,
    compute_metrics=token_accuracy,
)

[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   7 of 7 files downloaded.  


To be sure, let's check the dataloader.

In [None]:
# Take the first batch from the training dataloader and decode its first
# sequence of token ids back to text, to inspect what the model is fed.
dl = trainer.get_train_dataloader()
b = next(iter(dl))
trainer.tokenizer.decode(b["input_ids"][0])

Let's sample from the model during training. To do this, we will add a custom `WandbCallback` that has access to the Trainer object (and therefore to the model and tokenizer). Normally, callbacks don't have access to these, which is why we build the callback from the already-instantiated Trainer and then add it to that Trainer.

In [None]:
from functools import partial
from transformers import GenerationConfig, Trainer
from transformers.integrations import WandbCallback

def has_exisiting_wandb_callback(trainer: Trainer):
    """Return True if any callback registered on the trainer is a ``WandbCallback``."""
    registered = trainer.callback_handler.callbacks
    return any(isinstance(callback, WandbCallback) for callback in registered)

def _generate(prompt, model, tokenizer, gen_config):
    tokenized_prompt = tokenizer(prompt, return_tensors='pt')['input_ids'].cuda()
    with torch.inference_mode():
        output = model.generate(inputs=tokenized_prompt, 
                                generation_config=gen_config)
    return tokenizer.decode(output[0][len(tokenized_prompt[0]):], skip_special_tokens=True)


class LLMSampleCB(WandbCallback):
    """W&B callback that logs a table of sample generations after each evaluation.

    A fixed subset of ``test_dataset`` is kept; on every ``on_evaluate`` the
    model generates a completion for each kept row's ``"text"`` field and the
    results are logged to W&B under ``"sample_predictions"``.

    Args:
        trainer: The instantiated trainer; its ``model`` and ``tokenizer`` are
            used for generation. Any ``WandbCallback`` already registered on it
            is removed.
        test_dataset: Dataset whose rows have a ``"text"`` field with the prompt.
        num_samples: Maximum number of rows to sample from; capped at the
            dataset length.
        max_new_tokens: Passed to the generation config.
    """

    def __init__(self, trainer, test_dataset, num_samples=10, max_new_tokens=256):
        super().__init__()
        # Cap the sample count so a small test set does not make select()
        # receive out-of-range indices.
        num_samples = max(0, min(num_samples, len(test_dataset)))
        self.sample_dataset = test_dataset.select(range(num_samples))
        self.gen_config = GenerationConfig.from_pretrained(trainer.model.name_or_path,
                                                           max_new_tokens=max_new_tokens)
        self.generate = partial(_generate,
                                model=trainer.model,
                                tokenizer=trainer.tokenizer,
                                gen_config=self.gen_config)

        #  we need to know if a wandb callback already exists
        if has_exisiting_wandb_callback(trainer):
            # if it does, we need to remove it
            trainer.callback_handler.pop_callback(WandbCallback)

    def log_generations_table(self, examples):
        """Generate for every example and log prompt, generation and config values as a table."""
        # The generation config is the same for every row; serialise it once.
        gen_params = self.gen_config.to_dict()
        gen_values = list(gen_params.values())
        records_table = wandb.Table(columns=["prompt", "generation"] + list(gen_params))
        for example in tqdm(examples, leave=False):
            prompt = example["text"]
            # Only the last 1000 characters of the prompt are sent to the model.
            generation = self.generate(prompt=prompt[-1000:])
            records_table.add_data(prompt, generation, *gen_values)
        self._wandb.log({"sample_predictions": records_table})

    def on_evaluate(self, args, state, control, **kwargs):
        """Run the parent's evaluation hook, then log the sample generations."""
        super().on_evaluate(args, state, control, **kwargs)
        self.log_generations_table(self.sample_dataset)

In [None]:
def create_test_prompt(row):
    """Return a ``text`` entry holding the row's prompt with an empty answer."""
    return {"text": _create_prompt(row["user"], "")}


test_dataset = ds["test"].map(create_test_prompt)

In [None]:
# Build the sampling callback from the existing trainer (it uses the trainer's
# model and tokenizer), then register it on that trainer.
wandb_cb = LLMSampleCB(trainer, test_dataset=test_dataset, num_samples=4, max_new_tokens=256)
trainer.add_callback(wandb_cb)

In [None]:
trainer.train()

In [None]:
wandb.finish()

In [None]:
# Step 6: Save the model
trainer.save_model(script_args.output_dir)