# Fine-tuning Mistral 7B for Andalusian Spanish

This notebook compiles all necessary steps to apply language transfer via fine-tuning to Mistral 7B.

## Acknowledgements

This code is a simplified and modified version of BrevDev's [Fine-tuning Mistral on your own data 🤙](https://github.com/brevdev/notebooks/blob/main/mistral-finetune-own-data.ipynb). Check it out if you want a more detailed explanation of the process executed here.

## General imports

In [None]:
import os
from datetime import datetime

## Settings

In [None]:
# Set run mode
mode = 'preprod' # 'preprod' for hyperparameter testing or 'prod' for final modelling

In [None]:
# Set max token window for each run mode
max_tokens = {
    "preprod" : 125,
    "prod" : 1250
}

In [None]:
# Set datasets location
datasets = {    
        "train" : os.path.join('data', 'processed', f'conversations_2E_ES_AND_{mode}_train.jsonl'),
        "eval" : os.path.join('data', 'processed', f'conversations_2E_ES_AND_{mode}_val.jsonl'),
}


`r` is the rank of the low-rank matrix used in the adapters, which thus controls the number of parameters trained. A higher rank will allow for more expressivity, but there is a compute tradeoff. The LoRA paper uses the values 1, 2, 4, 8 and 64.

`alpha` is the scaling factor for the learned weights. The low-rank update is scaled by `alpha/r`, so a higher value for `alpha` assigns more weight to the LoRA activations. The QLoRA paper reports desirable results with `alpha = 0.5 * r` or `alpha = 0.25 * r`. You can choose higher values of `alpha` depending on your needs.
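
To make the tradeoff concrete, here is a minimal sketch of how `r` drives the number of trained parameters and how `alpha/r` sets the scaling. The helper names and the 4096 x 4096 layer size are assumptions chosen for illustration, not values read from the model.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Parameters added by one adapter: A has shape (r, d_in), B has shape (d_out, r)."""
    return r * (d_in + d_out)


def lora_scaling(alpha: float, r: int) -> float:
    """Factor applied to the low-rank update before it is added to the frozen weights."""
    return alpha / r


hidden = 4096  # assumed square projection size, for illustration only
for r, alpha in [(1, 0.5), (8, 4), (64, 16)]:
    extra = lora_param_count(hidden, hidden, r)
    print(f"r={r:>2} alpha={alpha:>4} -> extra params per layer: {extra:>7}, scaling: {lora_scaling(alpha, r)}")
```

The parameter count grows linearly with `r`, while the scaling depends only on the ratio `alpha/r`.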

In [None]:
# Set QLoRA parameters
# Recommended parameters based on hyperparameter tuning for prod mode
qlora_parameters = {
    'r' : 1, 
    'alpha' : 0.5,
}

In [None]:
# Set training parameters
training_parameters = {
    'learning_rate' : 2e-4,
    'save_steps' : 50, # Save checkpoints every n steps
    'eval_steps' : 50, # Evaluate every n steps
}
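
As a rough guide for choosing `save_steps` and `eval_steps`, the sketch below estimates how many optimizer steps one epoch takes and how many checkpoints that produces. The helper names and the dataset size of 1000 examples are hypothetical, and the step count is an approximation that ignores how the trainer rounds the last partial batch. The alignment check reflects the requirement that, with step-based strategies and `load_best_model_at_end=True`, `save_steps` must be a round multiple of `eval_steps`.

```python
import math


def estimate_optimizer_steps(n_examples, per_device_batch_size, grad_accum_steps=1, n_devices=1, epochs=1):
    """Approximate number of optimizer steps for a training run."""
    effective_batch = per_device_batch_size * grad_accum_steps * n_devices
    return math.ceil(n_examples / effective_batch) * epochs


def checkpoints_aligned(save_steps, eval_steps):
    """True when every checkpoint coincides with an evaluation."""
    return save_steps % eval_steps == 0


n_examples = 1000  # hypothetical dataset size
total_steps = estimate_optimizer_steps(n_examples, per_device_batch_size=2)
n_checkpoints = total_steps // 50  # with save_steps = 50

print(f"steps: {total_steps}, checkpoints: {n_checkpoints}, aligned: {checkpoints_aligned(50, 50)}")
```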

## 0. Install modules

In [None]:
# You only need to run this once per machine
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U datasets scipy ipywidgets matplotlib

## 1. Preparation

### Data imports

In [None]:
from datasets import load_dataset

train_dataset = load_dataset('json', data_files = datasets['train'], split='train')
eval_dataset = load_dataset('json', data_files = datasets['eval'], split='train')
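
The loader above expects JSON Lines files, and the prompt formatting later in this notebook reads the `input` and `output` fields of each record. The following sketch writes a tiny sample file and validates it using only the standard library. The sample records and the `validate_jsonl` helper are hypothetical and only illustrate the expected shape.

```python
import json
import os
import tempfile

# Hypothetical records with the two fields used by the prompt formatting
records = [
    {"input": "sample question 1", "output": "sample answer 1"},
    {"input": "sample question 2", "output": "sample answer 2"},
]

path = os.path.join(tempfile.mkdtemp(), "sample.jsonl")
with open(path, "w", encoding="utf-8") as f:
    for rec in records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")


def validate_jsonl(path, required_keys=("input", "output")):
    """Return the number of records; raise ValueError if a record lacks a required key."""
    count = 0
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # skip blank lines
            rec = json.loads(line)
            missing = [k for k in required_keys if k not in rec]
            if missing:
                raise ValueError(f"line {line_no}: missing keys {missing}")
            count += 1
    return count


n_valid = validate_jsonl(path)
print(n_valid)
```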

### Accelerator

In [None]:
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

### HuggingFace

Since [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) is now a gated model, you first need to accept its terms of use on the 🤗 website and then log in.

In [None]:
# Log into Hugging Face
# Alternatively, use !huggingface-cli login --token 
from huggingface_hub import notebook_login
notebook_login()

### Weights & Biases

Set up WandB to track training. Skip this section if you don't want to use WandB.

In [None]:
!pip install -q wandb -U

import wandb, os
wandb.login()

wandb_project = "mistral-andalusian"
if len(wandb_project) > 0:
    os.environ["WANDB_PROJECT"] = wandb_project

## 2. Load Base Model

Load Mistral 7B using 4-bit quantization (NF4 with double quantization and `bfloat16` compute).

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-v0.1"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")
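
To see why 4-bit loading matters, this back-of-the-envelope sketch compares the memory needed for the weights alone at different precisions. The parameter count of about 7.24 billion is an approximation, and the estimate ignores quantization constants, activations, and optimizer state, so actual usage will be higher.

```python
def approx_weight_memory_gib(n_params: float, bits_per_param: float) -> float:
    """Memory for the weights only, in GiB."""
    return n_params * bits_per_param / 8 / 1024**3


n_params = 7.24e9  # approximate parameter count, assumed for illustration

for label, bits in [("fp32", 32), ("bf16", 16), ("4-bit", 4)]:
    print(f"{label:>5}: ~{approx_weight_memory_gib(n_params, bits):.1f} GiB")
```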

### Formatting prompts
Create a `formatting_func` that structures each training example as a prompt.

In [None]:
def formatting_func(example):
    text = f"### Pregunta: {example['input']}\n ### Respuesta: {example['output']}"
    return text

## 3. Tokenization

Set up the tokenizer. Add padding on the left as it [makes training use less memory](https://ai.stackexchange.com/questions/41485/while-fine-tuning-a-decoder-only-llm-like-llama-on-chat-dataset-what-kind-of-pa).

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token

Let's tokenize with padding and truncation, and set up the tokenize function to make labels and input_ids the same. This is basically what [self-supervised fine-tuning is](https://neptune.ai/blog/self-supervised-learning).

In [None]:
max_length = max_tokens[mode]

def generate_and_tokenize_prompt(prompt):
    result = tokenizer(
        formatting_func(prompt),
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    result["labels"] = result["input_ids"].copy()
    return result

In [None]:
tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)

Check that `input_ids` is padded on the left with the `eos_token` (2), that the prompt starts with a `bos_token` (1), and that an `eos_token` (2) is added at the end.

In [None]:
print(tokenized_train_dataset[1]['input_ids'])
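
The same check can be done programmatically instead of by eye. This is a minimal sketch that works on a plain list of token ids. The helper names are hypothetical, the ids 1 and 2 are the `bos_token` and `eos_token` ids mentioned above, and the other ids in the examples are arbitrary placeholders. A sequence that was truncated at `max_length` may have lost its final `eos_token`, in which case the check returns `False`.

```python
def split_left_padding(input_ids, pad_id=2):
    """Return (number of leading pad tokens, remaining content)."""
    n_pad = 0
    while n_pad < len(input_ids) and input_ids[n_pad] == pad_id:
        n_pad += 1
    return n_pad, input_ids[n_pad:]


def looks_well_formed(input_ids, bos_id=1, eos_id=2):
    """True if the content after the left padding starts with bos and ends with eos."""
    _, content = split_left_padding(input_ids, pad_id=eos_id)
    return bool(content) and content[0] == bos_id and content[-1] == eos_id


padded = [2, 2, 1, 5, 6, 2]   # two pad tokens, then bos ... eos
truncated = [1, 5, 6, 7]      # no padding and no trailing eos

print(split_left_padding(padded))
print(looks_well_formed(padded), looks_well_formed(truncated))
```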

## 4. Set Up LoRA

Before fine-tuning, the model needs some preprocessing to prepare it for training. To do so, use the `prepare_model_for_kbit_training` function from PEFT.

In [None]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [None]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

We will apply QLoRA to all the linear layers of the model. Those layers are `q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`, and `lm_head`.


Here we define the LoRA config.

`r` is the rank of the low-rank matrix used in the adapters, which thus controls the number of parameters trained. A higher rank will allow for more expressivity, but there is a compute tradeoff.

`alpha` is the scaling factor for the learned weights. The weight matrix is scaled by `alpha/r`, and thus a higher value for `alpha` assigns more weight to the LoRA activations.

The values used in the QLoRA paper were `r=64` and `lora_alpha=16`, and these are said to generalize well. This notebook instead reads both values from `qlora_parameters` in the Settings section (currently `r=1` and `alpha=0.5`), which keeps `alpha/r = 0.5` while training far fewer parameters.
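
The role of `r` and `alpha` is easier to see in a toy forward pass. The sketch below computes `W x + (alpha / r) * B (A x)` with plain Python lists. The matrices are made-up 2 x 2 examples, and the helpers are illustrative only (PEFT's actual implementation also applies dropout and works on tensors).

```python
def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def lora_forward(W, A, B, x, alpha, r):
    """Frozen base output plus the scaled low-rank update."""
    base = matvec(W, x)               # frozen pretrained weights
    update = matvec(B, matvec(A, x))  # rank-r update: B (A x)
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen weights, shape (2, 2)
A = [[1.0, 1.0]]              # trainable, shape (r, d_in) with r = 1
B = [[2.0], [0.0]]            # trainable, shape (d_out, r)
x = [3.0, 4.0]

out = lora_forward(W, A, B, x, alpha=0.5, r=1)
print(out)  # [10.0, 4.0]
```

Doubling `alpha` while keeping `r` fixed doubles the contribution of the adapter and leaves the base output unchanged.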

In [None]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r= qlora_parameters['r'],
    lora_alpha= qlora_parameters['alpha'],
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
print_trainable_parameters(model)

See how the model looks different now, with the LoRA adapters added:

In [None]:
print(model)

## 5. Run Training!

In [None]:
if torch.cuda.device_count() > 1: # If more than 1 GPU
    model.is_parallelizable = True
    model.model_parallel = True

In [None]:
model = accelerator.prepare_model(model)

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

project = "mistral-andalusian"
base_model_name = "mistral"
run_name = base_model_name + "-" + project
output_dir = "./" + run_name

trainer = Trainer(
    model = model,
    train_dataset = tokenized_train_dataset,
    eval_dataset = tokenized_val_dataset,
    args = TrainingArguments(
        output_dir = output_dir,
        warmup_steps = 1,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 1,
        gradient_checkpointing = True,
        # max_steps = max_steps,
        num_train_epochs  =  1,
        learning_rate= training_parameters['learning_rate'],
        # lr_scheduler_type = "reduce_lr_on_plateau",
        # lr_scheduler_kwargs = {}
        bf16 = True,
        optim = "paged_adamw_8bit",
        load_best_model_at_end = True,
        logging_steps = 10,            # Log the training loss every 10 steps
        logging_dir = "./logs",        # Directory for storing logs
        save_strategy = "steps",       # Save checkpoints on a step schedule
        save_steps = training_parameters['save_steps'],   # Save a checkpoint every n steps
        evaluation_strategy = "steps", # Evaluate on a step schedule (named `eval_strategy` in recent transformers releases)
        eval_steps = training_parameters['eval_steps'],   # Evaluate every n steps
        do_eval = True,                # Enable evaluation on eval_dataset
        report_to = "wandb",           # Comment this out if you don't want to use Weights & Biases
        run_name = f'{datetime.now().strftime("%Y%m%d%H%M%S")}_mode_{mode}_r_{qlora_parameters["r"]}_alpha_{qlora_parameters["alpha"]}_lr_{training_parameters["learning_rate"]}'
    ),
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

# Save checkpoint after training
trainer.save_model(output_dir + "/final_model")

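# Push the final adapter to the hub under an explicit repo id.
# Trainer.push_to_hub takes a commit message as its first argument, not a repo id,
# so push through the PEFT model instead.
trainer.model.push_to_hub('jgchaparro/MistrAND-7B-v2') # NOTE: still to be tested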