# Finetune Mistral-7B on Vertex AI

[Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) is a large language model (LLM) developed by `Mistral AI` and is an instruct fine-tuned version of [Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1).

In this tutorial you will learn how to finetune [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) on Vertex AI. 


What you'll learn in this tutorial:

1. [Setup development environment](#1-setup-development-environment)
2. [Load and prepare Dataset](#2-load-and-prepare-dataset)
3. [Fine-tune Mistral-7b using `trl` and `SFTTrainer`](#3-fine-tune-mistral-7b-using-trl-and-sfttrainer)

## 1. Setup development environment


In this example, we will use a Vertex AI Workbench instance with an A100 GPU and the [Hugging Face Deep Learning Containers](https://cloud.google.com/deep-learning-containers/docs/choosing-container#hugging-face) (DLCs). The Hugging Face PyTorch DLC comes with the important libraries, such as Transformers, Datasets, PEFT, and TRL, pre-installed. This makes it easy to get started, since there is no need for environment management. All Hugging Face containers are listed on [Google Cloud](https://cloud.google.com/deep-learning-containers/docs/choosing-container#hugging-face).


**ToDo**: Add info on how to spin-up a workbench instance or small intro about Vertex AI Workbench Instance.

**ToDo**: Update the link for the image once GPU containers are released.


Once the instance is up and running, we can access a Jupyter environment, which we can use for preparing our dataset and launching the training.

## 2. Load and prepare Dataset

We will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open-source dataset of instruction-following records covering the categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. A record looks like this:

```python
{
  "instruction": "What is world of warcraft",
  "context": "",
  "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment"
}
```
To load and preprocess the `Dolly` dataset, we use the 🤗 Datasets library.

```python
from datasets import load_dataset
```

To instruction-tune our model, we need to convert our structured examples into a collection of tasks described via instructions. We define a formatting function, `format_dolly`, that takes a sample and returns it with a new `text` field containing the formatted prompt.

```python
def format_dolly(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = (
        f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    )
    response = f"### Answer\n{sample['response']}"
    # join all the parts together
    prompt = "\n\n".join(
        [i for i in [instruction, context, response] if i is not None]
    )
    sample["text"] = prompt
    return sample
```



```python
raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
```

```
Downloading readme: 100%|██████████████████████████████████████████████████████| 8.20k/8.20k [00:00<00:00, 24.1MB/s]
Downloading data: 100%|████████████████████████████████████████████████████████| 13.1M/13.1M [00:00<00:00, 16.9MB/s]
Generating train split: 15011 examples [00:00, 262035.99 examples/s]
```


Before applying the formatting to our entire dataset, let's test our formatting function on a random example.


```python
from random import randrange

print(format_dolly(raw_dataset[randrange(len(raw_dataset))]))
```

```
{'instruction': 'Who is the antagonist on The X-Files?', 'context': '', 'response': 'There are many antagonists on The X-Files, but the most long-running individual antagonist is The Cigarette Smoking Man (CSM) also known as C.G.B. Spender. He was primarily responsible for orchestrating conspiracies. He was a member of "The Syndicate," which was a mysterious shadow government group that covered up the existence of extraterrestrial life.', 'category': 'classification', 'text': '### Instruction\nWho is the antagonist on The X-Files?\n\n### Answer\nThere are many antagonists on The X-Files, but the most long-running individual antagonist is The Cigarette Smoking Man (CSM) also known as C.G.B. Spender. He was primarily responsible for orchestrating conspiracies. He was a member of "The Syndicate," which was a mysterious shadow government group that covered up the existence of extraterrestrial life.'}
```


We can see that the example was properly formatted and that all parts have been joined into a single `text` field.
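
The random sample above has an empty `context`, so the `### Context` section was skipped. As a quick check of the other branch, the following sketch applies the same `format_dolly` logic to a hypothetical record that does carry context. The record contents are made up for illustration and are not taken from Dolly.

```python
def format_dolly(sample):
    """Same logic as the tutorial's formatting function, repeated so the snippet is self-contained."""
    instruction = f"### Instruction\n{sample['instruction']}"
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    response = f"### Answer\n{sample['response']}"
    # join the non-empty parts with a blank line between sections
    sample["text"] = "\n\n".join(
        part for part in [instruction, context, response] if part is not None
    )
    return sample


# Hypothetical record, made up for illustration (not from the Dolly dataset)
example = {
    "instruction": "What colour is the sky in the passage?",
    "context": "The sky was a deep orange at sunset.",
    "response": "Orange.",
}

formatted_text = format_dolly(dict(example))["text"]
print(formatted_text)
```

With a non-empty context, the prompt contains three sections (`### Instruction`, `### Context`, `### Answer`) separated by blank lines.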

```python
# apply prompt template
format_dataset = raw_dataset.map(
    format_dolly, remove_columns=list(raw_dataset.features)
)

# select only 2500 examples for faster training
format_dataset = format_dataset.shuffle(seed=42).select(range(2500))
```

## 3. Fine-tune Mistral-7b using `trl` and `SFTTrainer`

We will use the [SFTTrainer](https://huggingface.co/docs/trl/en/sft_trainer) from 🤗 `trl` to fine-tune our model. The `SFTTrainer` is built on top of the 🤗 Transformers `Trainer` and inherits all its core functionality, such as logging, evaluation, and checkpointing, while adding enhancements like:

- Packing datasets for more efficient training
- PEFT (parameter-efficient fine-tuning) support including Q-LoRA
- Preparing the model and tokenizer for conversational fine-tuning (e.g. adding special tokens)

You can read more about it in the [trl docs](https://huggingface.co/docs/trl/en/sft_trainer).
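
To make the packing idea from the list above concrete, here is a toy sketch: short tokenized examples are concatenated, separated by an end-of-sequence token, and cut into fixed-length blocks so that little of each batch is wasted on padding. This is only an illustration. The function name `pack_sequences`, the token ids, and the block size are assumptions, and this is not how `trl` implements packing internally.

```python
from typing import List


def pack_sequences(sequences: List[List[int]], eos_id: int, block_size: int) -> List[List[int]]:
    """Concatenate token sequences (each followed by EOS) and split into fixed-size blocks."""
    stream: List[int] = []
    for seq in sequences:
        stream.extend(seq)
        stream.append(eos_id)  # mark the boundary between examples
    # keep only full blocks; the trailing remainder is dropped in this sketch
    n_full = len(stream) // block_size
    return [stream[i * block_size:(i + 1) * block_size] for i in range(n_full)]


# Made-up token ids; 0 plays the role of the EOS token
toy_sequences = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
packed = pack_sequences(toy_sequences, eos_id=0, block_size=4)
print(packed)
```

Every block in `packed` has exactly `block_size` tokens, so no padding is needed, at the cost of examples sometimes being split across blocks.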


LLMs are large, and running or training them on consumer hardware is a major challenge for users and for accessibility. Therefore, we are going to use [QLoRA](https://arxiv.org/abs/2305.14314), a technique that reduces the memory footprint of LLMs during finetuning without significantly sacrificing performance. It works as follows:

- Quantize the pretrained model to 4 bits and freeze it.
- Attach small, trainable adapter layers (LoRA).
- Finetune only the adapter layers, backpropagating gradients through the frozen, quantized model.
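
As a rough back-of-the-envelope illustration of why the quantization step matters, the sketch below estimates the memory needed just to hold the weights at different precisions. It assumes roughly 7 billion parameters (a round figure, not the exact count) and ignores quantization constants, activations, gradients, and optimizer state, so real usage will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed to store the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


approx_params = 7e9  # assumption: round figure for a 7B-parameter model

for bits in (32, 16, 4):
    print(f"{bits:>2}-bit weights: ~{weight_memory_gb(approx_params, bits):.1f} GB")
```

Under these assumptions, going from 16-bit to 4-bit weights cuts the weight storage to a quarter.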

To further improve training efficiency, we'll combine `QLoRA` with `Flash Attention 2`, a fast, memory-efficient attention implementation that is integrated with Transformers and can be up to 3x faster than the standard attention mechanism.

```python
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers import TrainingArguments, Trainer
import torch
from trl import SFTTrainer
```

```
[2024-03-25 20:41:25,549] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
```


```python
model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # Hugging Face model id
```

```python
from transformers import BitsAndBytesConfig

# Note: use torch.float16 as the compute dtype on GPUs without bfloat16 support (e.g. T4, V100).
# Converting from bfloat16 to float16 may overflow; the opposite direction may lose precision.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize the model to 4 bits when loading it
    bnb_4bit_quant_type="nf4",              # NormalFloat4, a 4-bit data type designed for normally distributed weights
    bnb_4bit_use_double_quant=True,         # nested quantization: also quantize the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for computation
)
```
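
The note on `bnb_4bit_compute_dtype` can be turned into a small check. The sketch below is one way to pick the compute dtype from the GPU's CUDA compute capability, assuming that bfloat16 is available from compute capability 8.0 (Ampere, e.g. A100) onward and not on earlier GPUs such as the V100 (7.0) or T4 (7.5). The helper name `pick_compute_dtype_name` is hypothetical; in a real notebook the capability tuple would come from `torch.cuda.get_device_capability()`.

```python
from typing import Tuple


def pick_compute_dtype_name(capability: Tuple[int, int]) -> str:
    """Return the name of a compute dtype for a given CUDA compute capability."""
    major, _minor = capability
    # assumption: bfloat16 is supported on compute capability >= 8.0 (Ampere and newer)
    return "bfloat16" if major >= 8 else "float16"


# Compute capabilities of the GPUs mentioned in this tutorial
for gpu_name, capability in {"A100": (8, 0), "V100": (7, 0), "T4": (7, 5)}.items():
    print(gpu_name, "->", pick_compute_dtype_name(capability))
```

The returned name can be mapped to an actual dtype with `getattr(torch, name)`. Alternatively, `torch.cuda.is_bf16_supported()` answers the same question directly on the running machine.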

```python
# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="flash_attention_2",  # use flash-attention-2 for faster training
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
```

```
Downloading shards: 100%|█████████████████████████████████████████████████████████████| 3/3 [00:54<00:00, 18.11s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████| 3/3 [00:05<00:00,  1.99s/it]
```


To use QLoRA with the `SFTTrainer`, we need to create a `LoraConfig` and pass it as an argument to the `SFTTrainer`.

```python
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    task_type="CAUSAL_LM",
)
```
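
To get a feel for how small the adapters are, the sketch below counts the trainable LoRA parameters for `r=8`. A LoRA adapter on a weight matrix of shape `d_out × d_in` adds two matrices of shapes `r × d_in` and `d_out × r`, i.e. `r * (d_in + d_out)` parameters. The layer shapes, the number of layers, and the idea that adapters are attached only to the `q_proj` and `v_proj` attention projections are assumptions about the model and the PEFT defaults. They are not set in the config above, so check them against your installed versions.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


r = 8            # rank, as in the LoraConfig above
num_layers = 32  # assumption: number of transformer layers in Mistral-7B

# assumption: (d_in, d_out) of the projections that receive adapters
target_shapes = {
    "q_proj": (4096, 4096),
    "v_proj": (4096, 1024),
}

per_layer = sum(lora_param_count(d_in, d_out, r) for d_in, d_out in target_shapes.values())
total_trainable = per_layer * num_layers
print(f"trainable LoRA parameters: {total_trainable:,}")
```

Under these assumptions the adapters add about 3.4 million trainable parameters, a small fraction of a percent of the roughly 7 billion frozen base weights.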

Before we can start our training we need to define the hyperparameters (TrainingArguments) we want to use.

```python
training_args = TrainingArguments(
    output_dir="output",              # directory to save the trained model
    num_train_epochs=1,               # number of training epochs
    learning_rate=2e-4,               # learning rate
    optim="paged_adamw_8bit",         # paged 8-bit AdamW optimizer
    per_device_train_batch_size=1,    # batch size per device during training
    gradient_accumulation_steps=4,    # number of steps to accumulate gradients before updating the model
    logging_steps=10,                 # log every 10 steps
    bf16=True,                        # use fp16=True instead on GPUs without bfloat16 support (e.g. T4, V100);
                                      # converting between bfloat16 and float16 may overflow or lose precision
)
```
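
These settings determine how many optimizer updates one epoch takes. The sketch below works through the arithmetic, assuming a single GPU, no sequence packing, and the 2,500-example subset selected earlier.

```python
import math

num_examples = 2500                  # size of the training subset selected earlier
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_devices = 1                      # assumption: a single GPU
logging_steps = 10

# examples consumed per optimizer update
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices
# optimizer updates in one epoch
steps_per_epoch = math.ceil(num_examples / effective_batch_size)
# number of times the loss is logged in one epoch
num_log_lines = steps_per_epoch // logging_steps

print(effective_batch_size, steps_per_epoch, num_log_lines)
```

With an effective batch size of 4, one epoch over 2,500 examples takes 625 optimizer steps under these assumptions.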


```python
# Initialize the trl SFTTrainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=format_dataset,
    dataset_text_field="text",  # field that contains the text in the dataset
    args=training_args,
    peft_config=peft_config,
)
```

```
Map: 100%|█████████████████████████████████████████████████████████████| 2500/2500 [00:00<00:00, 7781.95 examples/s]
```


```python
# start training
trainer.train()

# save model
trainer.save_model()
```

```
The input hidden states seems to be silently casted in float32, this might be related to the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in torch.bfloat16.
```


| Step | Training Loss |
|------|---------------|
| 10   | 2.2079        |
| 20   | 1.9515        |
| 30   | 1.7392        |
| 40   | 1.6536        |
| 50   | 1.7724        |
| 60   | 1.5581        |
| 70   | 1.5356        |
| 80   | 1.5328        |
| 90   | 1.5113        |
| 100  | 1.6441        |


```python
# free the memory again
del model
del trainer
torch.cuda.empty_cache()
```