# Finetune Mistral-7B on Dolly using 🤗 peft, trl, bitsandbytes, Flash Attention 2 & transformers

This notebook runs on top of the image built using this Dockerfile: [GitHub Link](https://github.com/huggingface/Google-Cloud-Containers/blob/main/containers/pytorch/training/gpu/2.1/transformers/4.38.1/py310/Dockerfile)

With this image, you don't need to install any packages: everything the notebook requires is preinstalled.

## Import libraries and specify model to use

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, DataCollatorForLanguageModeling
from peft import LoraConfig
from trl import SFTTrainer

In [None]:
model_id = "mistralai/Mistral-7B-v0.1"

## Load and prepare Dataset

We will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open-source dataset of instruction-following records covering the categories outlined in the [InstructGPT paper](https://arxiv.org/abs/2203.02155), including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization. A record looks like this:

```python
{
  "instruction": "What is world of warcraft",
  "context": "",
  "response": "World of warcraft is a massive online multi player role playing game. It was released in 2004 by bizarre entertainment"
}
```
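To load and preprocess the `Dolly` dataset, we use the 🤗 Datasets library. The `format_dolly` function below turns each record into a single prompt string with an `### Instruction` section, an optional `### Context` section, and an `### Answer` section.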

In [None]:
def format_dolly(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = (
        f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    )
    response = f"### Answer\n{sample['response']}"
    # join all the parts together
    prompt = "\n\n".join(
        [i for i in [instruction, context, response] if i is not None]
    )
    sample["text"] = prompt
    return sample
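
To see what the template produces, here is a small self-contained sketch that applies the same formatting logic to two hand-written records. The second record, with a non-empty context, is a made-up example rather than a row taken from Dolly.

```python
def format_dolly(sample):
    # Same template as the notebook cell, repeated so this snippet runs on its own
    instruction = f"### Instruction\n{sample['instruction']}"
    context = (
        f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    )
    response = f"### Answer\n{sample['response']}"
    prompt = "\n\n".join(
        [i for i in [instruction, context, response] if i is not None]
    )
    sample["text"] = prompt
    return sample


# Record without context: the "### Context" section is skipped entirely
no_context = format_dolly(
    {
        "instruction": "What is world of warcraft",
        "context": "",
        "response": "An online role playing game.",
    }
)

# Hypothetical record with context: all three sections are emitted
with_context = format_dolly(
    {
        "instruction": "When was the game released?",
        "context": "The game was released in 2004.",
        "response": "2004",
    }
)

print(no_context["text"])
print("-----")
print(with_context["text"])
```

Sections are separated by a blank line, and an empty `context` leaves no stray header in the prompt.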

In [None]:
raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
# apply prompt template
format_dataset = raw_dataset.map(
    format_dolly, remove_columns=list(raw_dataset.features)
)

In [None]:
# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

## Fine-Tune Mistral 7B with QLoRA

In [None]:
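# BitsAndBytes 4bit config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,  # load the model weights quantized to 4-bit
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat data type
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the actual computation
)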
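
The following back-of-the-envelope sketch gives a feel for what this configuration saves in weight memory. It rests on assumptions that are not stated in this notebook: a parameter count of roughly 7.24 billion for Mistral-7B, and the block sizes described in the QLoRA paper (64 weights per quantization constant, 256 constants per second-level constant). It also treats every parameter as quantized, which slightly overstates the savings, so read the numbers as rough estimates only.

```python
# ASSUMPTION: approximate parameter count of Mistral-7B
N_PARAMS = 7.24e9
# ASSUMPTION: block sizes as described in the QLoRA paper
BLOCK_SIZE = 64       # weights sharing one quantization constant
DQ_BLOCK_SIZE = 256   # constants sharing one second-level constant


def bits_per_param(double_quant: bool) -> float:
    """Storage cost per weight: 4 bits plus the share of the quantization constants."""
    weight_bits = 4.0
    if double_quant:
        # 8-bit first-level constants plus 32-bit second-level constants
        overhead = 8 / BLOCK_SIZE + 32 / (BLOCK_SIZE * DQ_BLOCK_SIZE)
    else:
        # one 32-bit constant per block of weights
        overhead = 32 / BLOCK_SIZE
    return weight_bits + overhead


def gib(bits_per_weight: float) -> float:
    """Total weight memory in GiB for the assumed parameter count."""
    return N_PARAMS * bits_per_weight / 8 / 2**30


print(f"bf16 weights:            {gib(16):.2f} GiB")
print(f"nf4, no double quant:    {gib(bits_per_param(False)):.2f} GiB")
print(f"nf4, with double quant:  {gib(bits_per_param(True)):.2f} GiB")
```

Under these assumptions, double quantization trims the per-weight overhead of the quantization constants from 0.5 bits to about 0.127 bits.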

In [None]:
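# Load the base model with 4-bit quantized weights and Flash Attention 2
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place the layers on the available GPU(s)
    attn_implementation="flash_attention_2",
)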

In [None]:
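# LoRA config
peft_config = LoraConfig(
    lora_alpha=16,  # LoRA updates are scaled by lora_alpha / r = 0.5
    lora_dropout=0.1,  # dropout applied to the LoRA layers
    r=32,  # rank of the low-rank update matrices
    bias="none",  # do not train any bias terms
    task_type="CAUSAL_LM",
)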
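
As a rough illustration of how small the trainable part is, the sketch below counts the LoRA parameters that `r=32` would add. It relies on assumptions not stated in this notebook: the Mistral-7B layer shapes (hidden size 4096, 32 layers, key/value projection width 1024) and that PEFT falls back to adapting only `q_proj` and `v_proj` when no `target_modules` are given. If either assumption does not hold for your versions, the total changes accordingly.

```python
# ASSUMPTIONS: Mistral-7B layer shapes and PEFT's default target modules
HIDDEN_SIZE = 4096
NUM_LAYERS = 32
KV_DIM = 1024  # 8 key/value heads x 128 dims (grouped-query attention)

R = 32           # from the LoraConfig in this notebook
LORA_ALPHA = 16  # from the LoraConfig in this notebook

# (in_features, out_features) of each adapted linear layer
target_shapes = {
    "q_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "v_proj": (HIDDEN_SIZE, KV_DIM),
}


def lora_params(in_features: int, out_features: int, r: int) -> int:
    # LoRA adds two matrices: A with shape (r, in) and B with shape (out, r)
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, R) for i, o in target_shapes.values())
total_trainable = per_layer * NUM_LAYERS
scaling = LORA_ALPHA / R

print(f"trainable LoRA parameters: {total_trainable:,}")
print(f"scaling factor alpha / r:  {scaling}")
```

Under these assumptions the adapter holds about 13.6 million parameters, a small fraction of the roughly 7 billion frozen base weights.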

In [None]:
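# Define training arguments
training_args = TrainingArguments(
    output_dir="output",  # checkpoints and the final adapter are written here
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,  # effective batch size: 2 x 4 = 8 sequences per device
    num_train_epochs=3,
    logging_strategy="steps",
    logging_steps=20,  # log the training loss every 20 update steps
    bf16=True,  # train in bfloat16 mixed precision
    optim="paged_adamw_8bit",  # paged 8-bit AdamW from bitsandbytes
)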

# Initialize our Trainer
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    args=training_args,
    dataset_text_field="text",
    packing=True,
    train_dataset=format_dataset,
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
# Train the model
trainer.train()

# Save the trained LoRA adapter to output_dir
trainer.save_model()
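
The `packing=True` option concatenates short examples into fixed-length sequences so that little compute is spent on padding. The sketch below is a simplified illustration of that idea on toy token ids, not TRL's actual implementation. The function `pack_sequences` and its parameters are hypothetical names introduced here, and dropping the incomplete last chunk is a simplification chosen for this sketch.

```python
from typing import Iterable, Iterator, List


def pack_sequences(
    tokenized: Iterable[List[int]], seq_length: int, eos_id: int
) -> Iterator[List[int]]:
    """Concatenate examples (separated by EOS) and cut them into equal-length chunks."""
    buffer: List[int] = []
    for ids in tokenized:
        buffer.extend(ids)
        buffer.append(eos_id)  # mark the boundary between two examples
        while len(buffer) >= seq_length:
            yield buffer[:seq_length]
            buffer = buffer[seq_length:]
    # Tokens left in the buffer do not fill a whole chunk and are dropped here


# Toy "tokenized" examples of different lengths
examples = [[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]]
chunks = list(pack_sequences(examples, seq_length=4, eos_id=0))
for chunk in chunks:
    print(chunk)

# With the training arguments used in this notebook, one optimizer step
# sees per_device_train_batch_size x gradient_accumulation_steps sequences
PER_DEVICE_BATCH = 2
GRAD_ACCUM = 4
sequences_per_step = PER_DEVICE_BATCH * GRAD_ACCUM
print(f"packed sequences per optimizer step (per device): {sequences_per_step}")
```

Every chunk has exactly `seq_length` tokens, and a single chunk can contain the end of one example and the start of the next.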