# Fine-tune Mixtral

## Goal

Learn to fine-tune Mixtral and make predictions with the fine-tuned model.

## References

- [Fine-tune Mixtral-8x7B on Your Computer (QLoRA)](https://colab.research.google.com/drive/1VDa0lIfqiwm16hBlIlEaabGVTNB3dN1A?usp=sharing)
- https://huggingface.co/blog/4bit-transformers-bitsandbytes
- [bnb 4bit training](https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing)
- https://huggingface.co/docs/datasets/en/index
- https://huggingface.co/docs/transformers/en/peft
- https://huggingface.co/docs/peft/main/en/conceptual_guides/lora
- https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.merge_and_unload
- [Taller + AMA: Entrenamiento de LLMs, Alejandro Vaca](https://www.youtube.com/watch?v=458UWBlBdtI&t=3494s)
- https://github.com/somosnlp/recursos/blob/main/hackathon_2024/entrenamiento_llm_instrucciones.ipynb
- https://gathnex.medium.com/mistral-7b-fine-tuning-a-step-by-step-guide-52122cdbeca8

> LoRA does not add any inference latency because adapter weights can be merged with the base model.
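
The quote above can be checked numerically. The sketch below is not taken from the referenced notebooks: it uses small random matrices (all shapes and values are made up) and assumes the usual LoRA update `W + (lora_alpha / r) * B @ A`. It compares a forward pass with a separate adapter against a forward pass with the merged weight.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up sizes, far smaller than any real projection layer.
d_out, d_in, r, lora_alpha = 6, 4, 2, 8
scaling = lora_alpha / r

W = rng.normal(size=(d_out, d_in))   # frozen base weight
A = rng.normal(size=(r, d_in))       # LoRA down-projection
B = rng.normal(size=(d_out, r))      # LoRA up-projection
x = rng.normal(size=(3, d_in))       # a batch of 3 input vectors

# Adapter kept separate: two extra matrix products per forward pass.
y_adapter = x @ W.T + scaling * (x @ A.T) @ B.T

# Adapter merged once into the base weight: a single matrix product.
W_merged = W + scaling * (B @ A)
y_merged = x @ W_merged.T

print(np.allclose(y_adapter, y_merged))
```

The merged matrix has the same shape as the base weight, so inference after merging costs the same as inference with the original model.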

## Imports

In [None]:
import torch
import gc
import time
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from tqdm.auto import tqdm
import yaml
import os
import hashlib

from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    pipeline, TrainingArguments
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import Dataset

plt.plot()
plt.close('all')
plt.rcParams["figure.figsize"] = (20, 5)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 16

pd.set_option('display.max_colwidth', 200)

## Load model

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    llm_int8_enable_fp32_cpu_offload=True,
)

torch.cuda.empty_cache()
gc.collect()
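
To get a feeling for why 4-bit loading is needed, here is a rough back-of-the-envelope estimate of the size of the weights alone. The parameter count of 46.7 billion is an assumption (the total usually quoted for Mixtral 8x7B), and the estimate ignores quantization constants, layers that stay unquantized, activations and optimizer state, so the real numbers will be somewhat higher.

```python
def weights_size_gib(n_params, bits_per_param):
    """Size of the weights alone in GiB (no activations, no optimizer state)."""
    return n_params * bits_per_param / 8 / 1024**3

# Assumption: total parameter count usually quoted for Mixtral 8x7B.
n_params = 46.7e9

for label, bits in [("float16", 16), ("4-bit", 4)]:
    print(f"{label}: {weights_size_gib(n_params, bits):.1f} GiB")
```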

In [None]:
auto_device_map = {
    'model.embed_tokens': 0,
    'model.layers.0': 0,
    'model.layers.1': 0,
    'model.layers.2': 0,
    'model.layers.3': 0,
    'model.layers.4': 0,
    'model.layers.5': 0,
    'model.layers.6': 0,
    'model.layers.7': 0,
    'model.layers.8': 0,
    'model.layers.9': 0,
    'model.layers.10': 0,
    'model.layers.11': 0,
    'model.layers.12': 0,
    'model.layers.13': 0,
    'model.layers.14': 1,
    'model.layers.15': 1,
    'model.layers.16': 1,
    'model.layers.17': 1,
    'model.layers.18': 1,
    'model.layers.19': 1,
    'model.layers.20': 1,
    'model.layers.21': 1,
    'model.layers.22': 1,
    'model.layers.23': 1,
    'model.layers.24': 1,
    'model.layers.25': 1,
    'model.layers.26': 1,
    'model.layers.27': 1,
    'model.layers.28': 1,
    'model.layers.29': 1,
    'model.layers.30': 1,
    'model.layers.31': 1,
    'model.norm': 1,
    'lm_head': 1
 }

def create_shared_device_map(transition_layer):
    shared_device_map = {}
    for idx, key in enumerate(auto_device_map):
        if idx <= transition_layer:
            shared_device_map[key] = 0
        else:
            shared_device_map[key] = 1
    return shared_device_map

def create_intertwined_device_map():
    device_map = {}
    for idx, key in enumerate(auto_device_map):
        if idx == 0:
            device_map[key] = 1
        elif idx >= 33:
            device_map[key] = 0
        else:
            device_map[key] = idx % 2
    return device_map
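
The hard-coded dictionary could also be generated. This is only a sketch of an alternative, not what the notebook runs: `build_device_map` and its parameters are hypothetical names, and it assumes the same module names and the 32 decoder layers listed in `auto_device_map`.

```python
def build_device_map(n_layers=32, first_gpu_layers=14):
    """Put the embeddings and the first `first_gpu_layers` decoder layers on GPU 0,
    everything else (remaining layers, final norm, lm_head) on GPU 1."""
    device_map = {'model.embed_tokens': 0}
    for layer_idx in range(n_layers):
        device_map[f'model.layers.{layer_idx}'] = 0 if layer_idx < first_gpu_layers else 1
    device_map['model.norm'] = 1
    device_map['lm_head'] = 1
    return device_map

# With the default split this reproduces the layout of auto_device_map.
for name, device in build_device_map().items():
    print(name, device)
```

With `first_gpu_layers=16` the result should match `create_shared_device_map(16)`, because that function counts the embedding entry as index 0.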

In [None]:
model_path = '/home/gbarbadillo/data/mixtral-8x7b-instruct-v0.1-hf/'
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map=create_shared_device_map(16),
    trust_remote_code=True,
    )

In [None]:
# this solves the problem of generations not ending
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True)
# this is needed to do batch inference
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
gc.collect()

In [None]:
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True)
tokenizer.add_special_tokens({'pad_token': '<pad>'})
tokenizer.padding_side = 'right' # the default is left; for training, right seems to be better
gc.collect()
model.resize_token_embeddings(len(tokenizer))

In [None]:
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

def chat_with_mixtral(prompt, max_new_tokens=200, verbose=True, do_sample=False, temperature=0.7, top_p=0.95):
    if not prompt.startswith('<s>[INST]'):
        print('Formatting the prompt to Mixtral needs.')
        prompt = f'<s>[INST] {prompt} [/INST]'
    start = time.time()

    if do_sample:
        sampling_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        sampling_kwargs = dict(do_sample=False)

    sequences = pipe(
        prompt ,
        max_new_tokens=max_new_tokens,
        # https://www.reddit.com/r/LocalLLaMA/comments/184g120/mistral_fine_tuning_eos_and_padding/
        # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/discussions/106
        pad_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
        return_full_text=False,
    )
    response = sequences[0]['generated_text']
    #response = re.sub(r'[\'"]', '', response)
    if verbose:
        stop = time.time()
        time_taken = stop-start
        n_tokens = len(tokenizer.tokenize(response))
        print(f"Execution Time : {time_taken:.1f} s, tokens per second: {n_tokens/time_taken:.1f}")
    return response

In [None]:
def print_gpu_memory():
    for device in range(torch.cuda.device_count()):
        print(f'GPU {device} memory allocated: {torch.cuda.memory_allocated(device)/1024**3:.1f} GB, max memory allocated: {torch.cuda.max_memory_allocated(device)/1024**3:.1f} GB')

## Prepare data

In [None]:
prompt_template = """<s>[INST]
Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

## Output format

Let's do the task step by step:

1. On a first step analyze the differences of the texts in less than 30 words.
2. On a second step write the most likely prompt using json format
[/INST] {response} </s>"""

In [None]:
def prepare_dataframe_for_training(filepath, target_col='gpt4_normalized_response'):
    df = pd.read_csv(filepath)
    texts = []
    for _, row in df.iterrows():
        texts.append(prompt_template.format(original_text=row['original_text'],
                                            rewritten_text=row['rewritten_text'],
                                            response=row[target_col]))
    df['text'] = texts
    df['n_tokens'] = df['text'].apply(lambda x: len(tokenizer.tokenize(x)))
    return df

In [None]:
train_df = prepare_dataframe_for_training('/mnt/hdd0/Kaggle/llm_prompt_recovery/data/mooney_test_with_gpt4.csv')
eval_df = prepare_dataframe_for_training('/mnt/hdd0/Kaggle/llm_prompt_recovery/data/gemma_suppl_rewrite_curated_with_gpt4.csv')

In [None]:
bad_indices = [164, 181, 235]
print(train_df.loc[bad_indices].rewritten_text.values)
train_df = train_df.drop(bad_indices)

In [None]:
plt.hist(train_df['n_tokens'], bins=20, alpha=0.5, label='train')
plt.hist(eval_df['n_tokens'], bins=20, alpha=0.5, label='eval')
plt.legend(loc=0)
plt.xlabel('Number of tokens')
plt.title('Token distribution of the texts');

I could train with a context of just 512 tokens if I remove ~18% of the data, which seems like a good deal.

In [None]:
train_df = train_df[train_df['n_tokens'] <= 512]
eval_df = eval_df[eval_df['n_tokens'] <= 512]
print(f'There are {len(train_df)} samples for training and {len(eval_df)} samples for evaluation.')
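
The trade-off between context length and discarded data can be computed directly from the token counts. A minimal sketch: `fraction_dropped` is a hypothetical helper and the token counts below are made up, they are not the real values of `train_df['n_tokens']`.

```python
def fraction_dropped(token_counts, max_tokens):
    """Fraction of samples that are longer than `max_tokens` and would be removed."""
    if not token_counts:
        return 0.0
    n_dropped = sum(1 for n in token_counts if n > max_tokens)
    return n_dropped / len(token_counts)

# Made-up token counts, only to illustrate the calculation.
token_counts = [180, 240, 310, 390, 450, 505, 512, 530, 700, 1024]

for max_tokens in (256, 512, 1024):
    print(f"max_tokens={max_tokens}: {fraction_dropped(token_counts, max_tokens):.0%} dropped")
```

On the real data the same function could be called with `train_df['n_tokens'].tolist()`.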

For the first experiments let's simply split the data in two.

In [None]:
train_dataset = Dataset.from_pandas(train_df)
eval_dataset = Dataset.from_pandas(eval_df)

## Inference before training

In [None]:
for text in train_df['text'].values[:5]:
    print(chat_with_mixtral(text.split('[/INST]')[0] + '[/INST]'))

In [None]:
for text in eval_df['text'].values[:5]:
    print(chat_with_mixtral(text.split('[/INST]')[0] + '[/INST]'))
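
The expression `text.split('[/INST]')[0] + '[/INST]'` removes the expected response so the model has to generate it. The same idea as a small helper that fails loudly when the tag is missing; `keep_only_instruction` is a hypothetical name and the example text is made up.

```python
def keep_only_instruction(text, end_tag='[/INST]'):
    """Return the prompt up to and including the first `end_tag`, dropping the response."""
    head, found, _ = text.partition(end_tag)
    if not found:
        raise ValueError(f'{end_tag!r} not found in the text')
    return head + end_tag

# Made-up training sample with the same layout as prompt_template.
sample = '<s>[INST] Say hi [/INST] Hi! </s>'
print(keep_only_instruction(sample))
```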

## Train

In [None]:
model = prepare_model_for_kbit_training(model)
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching

In [None]:
peft_config = LoraConfig(
    # lora_alpha: LoRA scaling factor.
    lora_alpha=64, #64,
    lora_dropout=0.1, # 0.1, although Vaca suggested using 0.05 for big models
    # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters.
    r=16, #16
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
    target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj']
)
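
To see what `r` and `lora_alpha` imply, the number of trainable LoRA parameters can be estimated by hand. The projection shapes and the number of layers below are assumptions about Mixtral's attention blocks (hidden size 4096, 8 key/value heads of size 128, 32 layers) and were not read from the loaded model, so treat the result as an estimate.

```python
def lora_param_count(shapes, r, n_layers):
    """Each adapted linear layer adds r * (d_in + d_out) trainable parameters."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * n_layers

# Assumed (d_in, d_out) of the attention projections; not read from the model.
attention_shapes = {
    'q_proj': (4096, 4096),
    'k_proj': (4096, 1024),
    'v_proj': (4096, 1024),
    'o_proj': (4096, 4096),
}

for r, lora_alpha in [(16, 64), (8, 32)]:
    n_trainable = lora_param_count(attention_shapes, r=r, n_layers=32)
    print(f"r={r}: {n_trainable:,} trainable parameters, scaling={lora_alpha / r}")
```

Both configurations keep `lora_alpha / r` equal to 4, so halving `r` halves the number of trainable parameters without changing the scaling of the update.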

In [None]:
training_arguments = TrainingArguments(
        output_dir="/mnt/hdd0/Kaggle/llm_prompt_recovery/trainings/2024-03-27_bigger_dataset/01_r16_alpha64",
        evaluation_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=8, # 4-16 should be fine for lora.
        gradient_accumulation_steps=2,
        per_device_eval_batch_size=8,
        log_level="debug",
        save_steps=50, #50,
        logging_steps=50, #50,
        learning_rate=2e-5, # maybe we can increase this
        eval_steps=50, #50,
        max_steps=1000, #300,
        warmup_steps=30,
        lr_scheduler_type="linear",
)
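
With `lr_scheduler_type="linear"`, `warmup_steps=30` and `max_steps=1000` the learning rate should ramp up linearly for 30 steps and then decay linearly to zero. The function below is my own re-implementation of that shape, to make the schedule explicit; it is not the code the trainer runs.

```python
def linear_schedule_lr(step, base_lr=2e-5, warmup_steps=30, max_steps=1000):
    """Linear warmup from 0 to base_lr, then linear decay to 0 at max_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining_steps = max(0, max_steps - step)
    return base_lr * remaining_steps / max(1, max_steps - warmup_steps)

for step in (0, 15, 30, 500, 1000):
    print(f"step {step:4d}: lr = {linear_schedule_lr(step):.2e}")
```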

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()

First train with default parameters

- I have registered to Weights & Biases because it has a free personal account.
- The model is training and is expected to take 3.5 hours for 300 steps; each GPU is using 17 GB of memory.
- The batch size is 16 with the initial configuration
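
The batch size of 16 is `per_device_train_batch_size` times `gradient_accumulation_steps`. The sketch below assumes that the two GPUs hold different layers of a single model copy (because of the device map), so there is no data parallelism multiplying the batch. The number of training samples used to convert steps into epochs is hypothetical.

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, n_data_parallel=1):
    """Number of samples that contribute to each optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * n_data_parallel

def steps_to_epochs(steps, batch_size, n_samples):
    """How many passes over the dataset a given number of optimizer steps represents."""
    return steps * batch_size / n_samples

batch_size = effective_batch_size(per_device_batch_size=8, gradient_accumulation_steps=2)
print(batch_size)

# Hypothetical dataset size, only to show the conversion.
print(steps_to_epochs(steps=300, batch_size=batch_size, n_samples=1600))
```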

Second train with updated parameters (50 epochs, 1024 max_seq_length)

- The GPUs are now using 23 and 24 GB of memory.
- Training speed was halved by doubling max_seq_length, which implies that if I filter the data and train only on short samples I could train much faster.
- The model has been trained successfully for 50 steps in less than one hour.
- https://wandb.ai/guillermobarbadillo/huggingface/runs/u6az7ac1

Third train: limit max_seq_length to 512 again by removing part of the dataset.

- https://wandb.ai/guillermobarbadillo/huggingface/runs/eqx2bm11
- The train is expected to take around 3 hours

Fourth train: 02_r8_alpha32

- Val loss is slightly higher
- https://wandb.ai/guillermobarbadillo/huggingface/runs/kb27pqnb

Add new pad token

- Is it training faster? Yes, it has trained in 1h25m
- https://wandb.ai/guillermobarbadillo/huggingface/runs/ywc27wou

Final train with all the train data

- Estimated train time is around 4 hours, but I believe it will overfit earlier.
- https://wandb.ai/guillermobarbadillo/huggingface/runs/jn7oiuqp?nw=nwuserguillermobarbadillo
- With this setup the best validation loss is achieved around epoch 15.

## Make a few inferences

In [None]:
for text in train_df['text'].values[:5]:
    print(chat_with_mixtral(text.split('[/INST]')[0] + '[/INST] '))

In [None]:
for text in eval_df['text'].values[:5]:
    print(chat_with_mixtral(text.split('[/INST]')[0] + '[/INST] '))

The conversation does not end correctly, which is weird.

```
Execution Time : 6.5 s, tokens per second: 5.7
 1. The rewritten text uses more playful and engaging language.
2. {prompt: Rewrite the memo with a more engaging tone.}  // json format
Execution Time : 10.1 s, tokens per second: 6.4
 1. The rewritten text is more tentative and polite, using possibly and uncertain.
2. {prompt: Rewrite the message to be more polite and tentative.} 

Note: This prompt is inferred from the changes in the rewritten text and may not be explicitly stated.
Execution Time : 10.1 s, tokens per second: 6.3
 1. The rewritten text changes intricate carvings to cleat-adorned design and emphasizes balance and stability.
2. {prompt: Rewrite the text to emphasize a different aspect.}  // hint: use Rewrite the text to for stylistic changes
Execution Time : 9.8 s, tokens per second: 6.8
 1. The rewritten text changes the subject from a peony to a parser and the action from blooming to parsing.
2. {prompt: Rewrite the sentence with a different subject and action.}  OR  {prompt: Change the subject and action of the sentence.}  [generic prompt]
Execution Time : 6.7 s, tokens per second: 6.3
 1. The rewritten text condenses information, removing some details and making it more concise.
2. {prompt: Summarize the information about the cornet.}  // json format
```
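
A plausible explanation, which I have not verified in this notebook: if the pad token is the same as the EOS token and the data collator ignores padding positions in the loss, then the real EOS at the end of each sample is ignored as well and the model never learns to stop. The sketch below only illustrates that masking logic; the token ids are made up, `mask_padding` is a hypothetical helper, and whether the trainer's collator behaves like this was not checked.

```python
IGNORE_INDEX = -100  # label value that the cross-entropy loss skips

def mask_padding(input_ids, pad_token_id):
    """Build labels from input ids, ignoring every position that holds the pad token."""
    return [IGNORE_INDEX if token == pad_token_id else token for token in input_ids]

EOS_ID = 2          # assumption: id of </s>
NEW_PAD_ID = 32000  # assumption: id of a newly added <pad> token

sample = [1, 733, 16289, 28793, EOS_ID]  # made-up ids ending with EOS

# pad == eos: the genuine EOS is masked together with the padding.
print(mask_padding(sample + [EOS_ID] * 3, pad_token_id=EOS_ID))

# dedicated pad token: EOS stays in the labels, so the model is trained to emit it.
print(mask_padding(sample + [NEW_PAD_ID] * 3, pad_token_id=NEW_PAD_ID))
```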


When using pad_token=unk_token it does end the conversation correctly:

```
Execution Time : 5.9 s, tokens per second: 7.4
 1. The rewritten text uses a playful tone and metaphors related to fish and water.
2. {"prompt": "Rewrite the memo in a playful and engaging tone."}
Execution Time : 5.0 s, tokens per second: 6.9
 1. The rewritten text is more formal and uncertain.
2. {"prompt": "Rewrite the text to be more formal and uncertain."}
Execution Time : 8.1 s, tokens per second: 7.8
 1. The rewritten text removes the description of the zither's intricate carvings and replaces it with a reference to its cleat-adorned design.
2. {"prompt": "Rewrite the text to focus on the zither's design."}
Execution Time : 6.1 s, tokens per second: 7.7
 1. The original text is about a peony blooming, while the rewritten text is about a parser working efficiently.
2. {"prompt": "Rewrite the text to be about technology and efficiency."}
Execution Time : 4.6 s, tokens per second: 7.4
 1. The rewritten text is a shorter summary of the original text.
2. {"text prompt": "Rewrite the text as a short summary"}
```

Using a new pad token also results in good generations.

```
Execution Time : 4.7 s, tokens per second: 7.3
 1. The rewritten text is more formal and indirect.
2. {"prompt": "Rewrite the text in a formal and indirect way."}
Execution Time : 5.4 s, tokens per second: 7.6
 1. The original text describes a natural event, while the rewritten text describes a technical process.
2. {"prompt": "Rewrite the text to describe a technical process."}
Execution Time : 5.5 s, tokens per second: 7.5
 1. The rewritten text uses more abstract language and replaces negative words with positive ones.
2. {"prompt": "Rewrite the text using positive language and abstract concepts."}
Execution Time : 7.0 s, tokens per second: 7.6
 1. The rewritten text is more elaborate and descriptive, using metaphors and adjectives to create a vivid image.
2. {"prompt": "Rewrite the text to be more descriptive and imaginative."}
Execution Time : 6.8 s, tokens per second: 7.7
 1. The rewritten text focuses on the transformation of the mural and the captivation of passersby.
2. {"prompt": "Rewrite the text to focus on the mural's impact on the urban landscape."}
```
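
Since the fine-tuned model now answers with valid json, the predicted prompt can be extracted from the response. A minimal sketch using only the standard library: `extract_prompt` is a hypothetical helper, and it takes the first value of the json object because the key is not always `"prompt"` (one of the outputs above uses `"text prompt"`).

```python
import json
import re

def extract_prompt(response):
    """Return the first value of the json object found in the response, or None."""
    match = re.search(r'\{.*\}', response, flags=re.DOTALL)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, dict) or not parsed:
        return None
    return next(iter(parsed.values()))

good = ' 1. The rewritten text is more formal and uncertain.\n2. {"prompt": "Rewrite the text to be more formal and uncertain."}'
bad = ' 1. The rewritten text uses more playful and engaging language.\n2. {prompt: Rewrite the memo with a more engaging tone.}  // json format'

print(extract_prompt(good))
print(extract_prompt(bad))
```

Responses without valid json, such as the ones produced when the generation did not end correctly, return `None` and would need a fallback.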

## TODO

- [x] How to format the dataset? Apparently it is as easy as creating a dataset from a pandas dataframe
- [x] Study the token length distribution of the data. How does it relate to max_seq_length?
- [x] Verify the predictions of the model before and after training
- [x] How to save the lora model? The training saves checkpoints
- [x] How to load the lora model? Done on a following notebook
- [ ] rslora? https://huggingface.co/docs/peft/main/en/conceptual_guides/lora#common-lora-parameters-in-peft
- [ ] Reduce alpha, or reduce r as well
- [x] What is the optimal train duration? Probably around 15 epochs
- [ ] Try with Vaca's parameters