# Llama2 NeurIPS exploratory model using TRL

In this notebook, I:

1. Train a Llama2 13B model on an A100 40GB using TRL & transformers
1. Upload the model weights to the MLflow tracking server
1. Test & document various ways to optimize model training

The NeurIPS 2023 competition restricts (1) the models that can be used and (2) the data that can be used. Hence, to be successful in the competition we need to make sure that we:

1. Can evaluate faster than our competitors on the HELM benchmark
1. Squeeze every bit of performance out of the A100 GPU that we're using on GCP. It should be big enough to load the 7B and perhaps the 13B models, depending on the settings. There are a number of tricks for this; see [here](https://lightning.ai/pages/community/tutorial/pytorch-memory-vit-llm).
1. Use as much data as we can find. This data needs to be open-source and cannot be machine-generated.
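
As a rough check on the second point, the weight footprint alone can be estimated from the parameter count and the bit width. The sketch below is back-of-the-envelope arithmetic, not a measurement: it ignores activations, gradients, optimizer state and quantization overhead, and the function name `weight_memory_gb` is made up for illustration.

```python
def weight_memory_gb(n_params: float, bits_per_param: int) -> float:
    """Approximate memory needed for the model weights alone, in GB (1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


for name, n_params in [("7B", 7e9), ("13B", 13e9)]:
    for bits in (32, 16, 8, 4):
        print(f"{name} @ {bits:>2}-bit: {weight_memory_gb(n_params, bits):5.1f} GB")
```

By this estimate, the 16-bit weights of a 13B model already take about 26 GB of a 40 GB card before anything else is allocated, which is one reason the base model is loaded in 4-bit further down.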

[Lightning](https://lightning.ai/pages/community/tutorial/neurips2023-llm-efficiency-guide/) have made a nice tutorial / getting started guide that supports most models. A [pull request](https://github.com/Lightning-AI/lit-gpt/pull/412) integrating Lit-GPT and HELM has recently been merged so we can use this framework to evaluate our models.

See [this repository](https://github.com/Lightning-AI/lit-gpt) for the starter code used in this notebook.

### GCP

On GCP Vertex AI Workbench, make sure to select the option 'Python 3 (with Intel® MKL)'; otherwise, PyTorch nightly cannot be installed properly.

To enable GPU monitoring, first enable Jupyter extensions in the menu on the far left of the JupyterLab instance.

Then, install the `jupyterlab-nvdashboard` extension. NB: you must restart the notebook server for the dashboard to show up.
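
If the dashboard is not available, GPU memory can also be polled from a cell. The sketch below assumes `nvidia-smi` is on the `PATH`; the helper names `parse_gpu_memory` and `gpu_memory` are hypothetical, and the second one returns an empty list instead of failing when no usable NVIDIA driver is found.

```python
import shutil
import subprocess


def parse_gpu_memory(csv_text):
    """Parse 'used, total' pairs (in MiB), one line per GPU, into a list of tuples."""
    rows = []
    for line in csv_text.strip().splitlines():
        used, total = (int(field) for field in line.split(","))
        rows.append((used, total))
    return rows


def gpu_memory():
    """Query nvidia-smi for used/total memory; return [] if that is not possible."""
    if shutil.which("nvidia-smi") is None:
        return []
    try:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=memory.used,memory.total", "--format=csv,noheader,nounits"],
            capture_output=True,
            text=True,
            check=True,
        )
        return parse_gpu_memory(result.stdout)
    except (subprocess.CalledProcessError, OSError, ValueError):
        return []


for index, (used, total) in enumerate(gpu_memory()):
    print(f"GPU {index}: {used} / {total} MiB")
```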

### Notes

1. I noticed that TRL has a [flash attention](https://github.com/huggingface/trl@flash-attn-sft) branch. I tried it out, and it slowed training down considerably, probably because sequences have to be packed rather than padded as we would normally do.
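
To make the difference between packing and padding concrete, here is a toy sketch on plain lists of token ids. It is not TRL's implementation: `pad_batch`, `pack_sequences` and the `PAD`/`EOS` ids are hypothetical, and a real pipeline would do this inside the trainer's data collation.

```python
PAD, EOS = 0, 2  # hypothetical token ids


def pad_batch(sequences, pad_id=PAD):
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (width - len(seq)) for seq in sequences]


def pack_sequences(sequences, block_size, eos_id=EOS):
    """Concatenate sequences (each followed by EOS) and cut fixed-size blocks.

    A trailing remainder that does not fill a whole block is dropped.
    """
    stream = []
    for seq in sequences:
        stream.extend(seq + [eos_id])
    n_blocks = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(n_blocks)]


examples = [[5, 6, 7], [8, 9], [10, 11, 12, 13, 14]]
print(pad_batch(examples))
print(pack_sequences(examples, block_size=4))
```

Packed blocks contain no padding tokens, but a single block can span several examples, which changes what the model sees compared with one padded example per row.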

In [None]:
!pip install -q -U bitsandbytes trl mlflow datasets
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git

In [None]:
import torch

type(torch.bfloat16)

In [None]:
import os

import torch
import mlflow
import numpy as np
import pandas as pd
import random
from peft import (
    get_peft_config,
    PeftModel,
    PeftConfig,
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainerCallback,
    TrainingArguments,
)
from trl import SFTTrainer
from datasets import load_dataset, Dataset

In [None]:
# Taken from the lit-gpt repository
def generate_prompt(example):
    """Generates a standardized message to prompt the model with an instruction, optional input and a
    'response' field."""

    if example["context"]:
        return (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['context']}\n\n### Response:"
        )
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{example['instruction']}\n\n### Response:"
    )

### Load data

We use dolly-15K for now, and format it using the lit-gpt formatting style.
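
For illustration, this is what one fully formatted training example looks like. The record is made up (it is not taken from dolly-15K), and `format_example` is a hypothetical, condensed version of the prompt template plus the response and end marker that the following cells add.

```python
def format_example(instruction: str, response: str, context: str = "") -> str:
    """Build the full training text: prompt, then response, then an end marker."""
    if context:
        prompt = (
            "Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n### Input:\n{context}\n\n### Response:"
        )
    else:
        prompt = (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n\n"
            f"### Instruction:\n{instruction}\n\n### Response:"
        )
    return prompt + response + "\n### End"


text = format_example("Name a primary colour.", "Red.")
print(text)
```

Note that, as in the cells below, the response is appended directly after `### Response:` without a separating newline.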

In [None]:
df_dolly = load_dataset("databricks/databricks-dolly-15k")
df_dolly = pd.DataFrame(df_dolly['train'])
df_dolly

In [None]:
df_dolly["prompt"] = df_dolly.apply(generate_prompt, axis=1)
df_dolly["response"] = df_dolly["response"] + "\n### End"
df_dolly = df_dolly[["prompt", "response"]]

In [None]:
df = df_dolly.copy()
df["text"] = df["prompt"] + df["response"]
df.drop(columns=["prompt", "response"], inplace=True)

In [None]:
dataset = Dataset.from_pandas(df).train_test_split(test_size=0.05, seed=42)

### Model & training options

We define:

1. BitsAndBytes configuration for loading the base model efficiently
1. Training options for the `SFTTrainer`
1. PEFT options (LoRA)

I'm following some best practices that you can find [here](https://huggingface.co/docs/transformers/perf_train_gpu_one). I also take some settings from the [Open Platypus paper](https://arxiv.org/pdf/2308.07317.pdf).
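
The batch-size settings in the training cell combine as sketched here. The function names are hypothetical, the example count is a placeholder rather than the exact size of the training split, and `Trainer` may round the number of steps per epoch slightly differently.

```python
import math


def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, n_devices=1):
    """Number of examples that contribute to each optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * n_devices


def optimizer_steps(n_examples, batch_size, epochs):
    """Approximate number of optimizer steps over the whole run."""
    return math.ceil(n_examples / batch_size) * epochs


batch = effective_batch_size(per_device_batch_size=4, gradient_accumulation_steps=16)
print(batch)
print(optimizer_steps(n_examples=14_000, batch_size=batch, epochs=2))
```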

In [None]:
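## Model
# NOTE: the LoRA adapter must be trained on the same base model that it is later merged into
# (the merge section further down loads "meta-llama/Llama-2-13b-hf").
model_id = "facebook/opt-350m"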

In [None]:
## BitsAndBytes
load_in_4bit = True
bnb_4bit_use_double_quant = True
bnb_4bit_quant_type = "nf4"
bnb_4bit_compute_dtype = torch.bfloat16

bnb_config = BitsAndBytesConfig(
    load_in_4bit=load_in_4bit,
    bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype
)

In [None]:
## PEFT
# target_modules = [
#     'gate_proj','down_proj', 'up_proj',
#     'k_proj','lm_head', 'q_proj',
#     'v_proj','o_proj'
# ]
target_modules = [
    "k_proj", "q_proj", "v_proj"
]
r = 16
lora_alpha = 16
lora_dropout = 0.05
bias = "none"
task_type = "CAUSAL_LM"

lora_config = LoraConfig(
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias=bias,
    target_modules = target_modules,
    task_type=task_type,
)
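
To get a feel for how many parameters this configuration trains: LoRA adds two low-rank matrices, of shapes `(r, d_in)` and `(d_out, r)`, per adapted linear layer. The sketch below does that arithmetic. The hidden size and layer count are assumptions corresponding to Llama-2-13B, not values read from the loaded model, so `print_trainable_parameters()` remains the number to trust.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A is (r, d_in), B is (d_out, r)."""
    return r * d_in + d_out * r


hidden_size = 5120    # assumption: Llama-2-13B hidden size
n_layers = 40         # assumption: Llama-2-13B decoder layers
n_target_modules = 3  # k_proj, q_proj, v_proj

per_layer = n_target_modules * lora_param_count(hidden_size, hidden_size, r=16)
total = per_layer * n_layers
print(f"{total:,} trainable LoRA parameters")
```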

In [None]:
## Training
# TODO: something with logging & logging artifacts
base_dir = "out"

save_strategy="steps"
save_steps=100
save_total_limit=3

num_train_epochs = 2
evaluation_strategy = "steps"
logging_strategy = "steps"
eval_steps = 100
logging_steps = 25

per_device_train_batch_size = 4
gradient_accumulation_steps = 16 # virtual batch size = 4 * 16 = 64. See https://huggingface.co/docs/transformers/perf_train_gpu_one#batch-size-choice
gradient_checkpointing = True
per_device_eval_batch_size = 2
eval_accumulation_steps = 8
#max_steps=50 # only debugging

learning_rate = 4e-4
lr_scheduler_type = "cosine"
max_grad_norm = 0.3
warmup_steps = 100
optim = 'paged_adamw_8bit' # See: https://huggingface.co/docs/transformers/perf_train_gpu_one#8bit-adam

dataloader_pin_memory = True
dataloader_num_workers = 1

tf32 = False # See: https://huggingface.co/docs/transformers/perf_train_gpu_one#tf32
group_by_length = True

# Error: -- FileNotFoundError: [Errno 2] No such file or directory: 'ldconfig'
torch_compile = False # See: https://huggingface.co/docs/transformers/perf_train_gpu_one#using-torchcompile

if tf32:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

training_args = TrainingArguments(
    output_dir=base_dir,
    evaluation_strategy=evaluation_strategy,
    eval_steps=eval_steps,
    num_train_epochs = num_train_epochs,
    logging_strategy = logging_strategy,
    logging_steps=logging_steps,
    save_strategy=save_strategy,
    save_steps=save_steps,
    save_total_limit=save_total_limit,
    per_device_eval_batch_size=per_device_eval_batch_size,
    eval_accumulation_steps=eval_accumulation_steps,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    learning_rate=learning_rate,
    #bf16=True,
    #max_steps=max_steps,
    tf32=tf32,
    fp16=not tf32,
    max_grad_norm=max_grad_norm,
    warmup_steps=warmup_steps,
    lr_scheduler_type=lr_scheduler_type,
    group_by_length=group_by_length,
    torch_compile=torch_compile,
    dataloader_pin_memory=dataloader_pin_memory,
    dataloader_num_workers=dataloader_num_workers
)
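
The learning-rate settings above (`warmup_steps = 100`, `lr_scheduler_type = "cosine"`) give a linear ramp followed by a half-cosine decay. Below is a minimal pure-Python sketch of that shape, assuming a hypothetical total of 440 optimizer steps; it illustrates the general formula and is not the library's own code path.

```python
import math


def lr_at_step(step, base_lr=4e-4, warmup_steps=100, total_steps=440):
    """Linear warmup to base_lr, then cosine decay to zero at total_steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (0, 50, 100, 270, 440):
    print(step, f"{lr_at_step(step):.2e}")
```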

In [None]:
# Other stuff
#MLFLOW_TRACKING_URI="localhost"
#MLFLOW_EXPERIMENT="jasper-train-testing"
#mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

### Loading & configuring the model

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# When loading 1st time this will be slow
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto"
)

In [None]:
model = prepare_model_for_kbit_training(model)

In [None]:
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

### Train the model

In [None]:
import os
from transformers.integrations import MLflowCallback

os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"
os.environ["MLFLOW_EXPERIMENT_NAME"] = "test"

callbacks = [MLflowCallback()]

In [None]:
trainer = SFTTrainer(
    model,
    train_dataset=dataset['train'],
    eval_dataset = dataset['test'],
    dataset_text_field="text",
    max_seq_length=2048,
    args=training_args,
    callbacks=callbacks
)

In [None]:
# mlflow.end_run()

In [None]:
# for name, module in trainer.model.named_modules():
#     if "norm" in name:
#         module = module.to(torch.float32)

In [None]:
# Use callbacks (https://huggingface.co/docs/trl/sft_trainer#trl.SFTTrainer.callbacks) & https://huggingface.co/docs/transformers/v4.33.0/en/main_classes/callback#transformers.integrations.MLflowCallback

with mlflow.start_run():
    trainer.train()

In [None]:
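# Directory for the trained LoRA adapter (this path is an example; adjust as needed)
lora_adapter = os.path.join(base_dir, "lora_adapter")
model.save_pretrained(lora_adapter, save_adapter=True, save_config=True)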

### Merge model weights

In [None]:
import os

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

## Model
# Must be the same base model that the LoRA adapter was trained on
model_id = "meta-llama/Llama-2-13b-hf"

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# When loading 1st time this will be slow
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    load_in_8bit=False,  # No 8-bit quantization: weights are loaded in half precision (fp16)
    torch_dtype=torch.float16,
    device_map={"":0},
)

In [None]:
lora_model = PeftModel.from_pretrained(
    model,
    "out/checkpoint-400",
    device_map={"":0},
    torch_dtype=torch.bfloat16,
)

In [None]:
merged_model = lora_model.merge_and_unload()

In [None]:
merged_model

In [None]:
merged_model.save_pretrained("merged/llama2-13B-instruct-dolly")
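
Conceptually, merging folds each adapter back into its base weight as `W + (lora_alpha / r) * B @ A`, so the merged model no longer needs PEFT at inference time. The toy example below illustrates this on tiny hand-written matrices. It is a sketch of the idea with made-up numbers, not what `merge_and_unload()` executes internally.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return weight + (lora_alpha / r) * (lora_b @ lora_a)."""
    scale = lora_alpha / r
    delta = matmul(lora_b, lora_a)
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(weight, delta)]


weight = [[1.0, 0.0], [0.0, 1.0]]  # base weight, shape (d_out, d_in) = (2, 2)
lora_a = [[1.0, 2.0]]              # shape (r, d_in) = (1, 2)
lora_b = [[0.5], [1.0]]            # shape (d_out, r) = (2, 1)

merged = merge_lora(weight, lora_a, lora_b, lora_alpha=2, r=1)
print(merged)
```

With the settings used in this notebook (`lora_alpha = 16`, `r = 16`), the scale factor is 1, so the low-rank product is added to the base weight unscaled.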