# Fine-tune Mistral

## Goal

What results do we get if we fine-tune Mistral?

## Imports

In [None]:
import torch
import gc
import time
import re
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from tqdm.auto import tqdm
import yaml
import os
import hashlib

from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    pipeline, TrainingArguments
)
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import Dataset

# Touch pyplot once and close the resulting figure before overriding defaults
# (presumably so the notebook backend is initialised first - TODO confirm).
plt.plot()
plt.close('all')

# Notebook-wide plotting defaults: wide figures, thick lines, larger font.
mpl.rcParams.update({
    'figure.figsize': (20, 5),
    'lines.linewidth': 3,
    'font.size': 16,
})

# Show up to 200 characters per cell when displaying DataFrames.
pd.options.display.max_colwidth = 200

## Load model

In [None]:
# Quantization settings used when loading the model: 4-bit weights of type
# "nf4" with double quantization enabled and float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.float16,
    bnb_4bit_use_double_quant= True,
    # NOTE(review): the two llm_int8_* options are named after the 8-bit path;
    # confirm they have the intended effect when loading in 4-bit.
    llm_int8_enable_fp32_cpu_offload= True,
    llm_int8_skip_modules=['gate', 'lm_head'],
)

# Release cached, unused GPU memory before the model is loaded.
torch.cuda.empty_cache()

In [None]:
# Module-to-GPU placement: the embeddings and decoder layers 0-13 live on
# GPU 0; decoder layers 14-31, the final norm and the LM head live on GPU 1.
# Insertion order (embeddings, layers 0..31, norm, head) matters because the
# helper functions in this cell rely on enumerating the keys.
auto_device_map = {'model.embed_tokens': 0}
auto_device_map.update(
    {f'model.layers.{layer}': 0 if layer < 14 else 1 for layer in range(32)}
)
auto_device_map.update({'model.norm': 1, 'lm_head': 1})

def create_shared_device_map(transition_layer):
    """Build a two-GPU device map split at a position in ``auto_device_map``.

    ``transition_layer`` is compared against the enumeration index of the keys
    of ``auto_device_map`` (not against the layer number embedded in the key):
    entries whose index is ``<= transition_layer`` are assigned to device 0,
    all remaining entries to device 1.

    Returns:
        A new dict mapping every key of ``auto_device_map`` to 0 or 1.
    """
    return {
        module_name: 0 if position <= transition_layer else 1
        for position, module_name in enumerate(auto_device_map)
    }

def create_intertwined_device_map():
    """Build a device map that alternates consecutive entries between two GPUs.

    The first entry of ``auto_device_map`` goes to device 1, entries at
    position 33 and beyond go to device 0, and every entry in between goes to
    device ``position % 2``.

    Returns:
        A new dict mapping every key of ``auto_device_map`` to 0 or 1.
    """
    def pick_device(position):
        if position == 0:
            return 1
        if position >= 33:
            return 0
        return position % 2

    return {
        module_name: pick_device(position)
        for position, module_name in enumerate(auto_device_map)
    }

In [None]:
# Load the causal LM from a local checkpoint directory, quantized according to
# bnb_config. device_map='auto' leaves module placement to the library rather
# than using one of the hand-written device maps defined in this notebook.
model_path = '/home/gbarbadillo/data/Mistral-7B-Instruct-v0.2/'
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map='auto',
    # NOTE(review): trust_remote_code=True permits running code shipped with
    # the checkpoint; keep it only for checkpoints you trust.
    trust_remote_code=True,
    )

In [None]:
# Load the tokenizer from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True)
# Register a dedicated '<pad>' token as the padding token.
tokenizer.add_special_tokens({'pad_token': '<pad>'})
tokenizer.padding_side = 'right' # by default is left, for training right seems to be better
# Resize the model's token embeddings to the tokenizer's current length,
# which accounts for any token added above.
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Text-generation pipeline that reuses the already-loaded model and tokenizer.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

def chat_with_mixtral(prompt, max_new_tokens=200, verbose=True, do_sample=False, temperature=0.7, top_p=0.95):
    """Generate a completion for ``prompt`` with the module-level ``pipe``.

    Args:
        prompt: Text to complete. If it does not already start with
            ``<s>[INST]`` it is wrapped as ``<s>[INST] {prompt} [/INST]``.
        max_new_tokens: Upper bound on the number of generated tokens.
        verbose: When true, print the elapsed time and the generation speed.
        do_sample: When true, sample using ``temperature`` and ``top_p``;
            otherwise sampling is disabled and both values are ignored.
        temperature: Sampling temperature (only used when ``do_sample``).
        top_p: Nucleus sampling threshold (only used when ``do_sample``).

    Returns:
        The generated text only (the prompt is not included).
    """
    if not prompt.startswith('<s>[INST]'):
        print('Formatting the prompt to Mixtral needs.')
        prompt = f'<s>[INST] {prompt} [/INST]'
    t_start = time.time()

    generation_kwargs = {'do_sample': False}
    if do_sample:
        generation_kwargs = {'do_sample': True, 'temperature': temperature, 'top_p': top_p}

    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        # The EOS token id is passed as the padding id for generation; see
        # https://www.reddit.com/r/LocalLLaMA/comments/184g120/mistral_fine_tuning_eos_and_padding/
        # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/discussions/106
        pad_token_id=tokenizer.eos_token_id,
        **generation_kwargs,
        return_full_text=False,
    )
    response = outputs[0]['generated_text']
    if not verbose:
        return response

    # The elapsed time is taken before re-tokenizing the response, so the
    # token count below does not inflate the measured duration.
    elapsed = time.time() - t_start
    generated_tokens = len(tokenizer.tokenize(response))
    print(f"Execution Time : {elapsed:.1f} s, tokens per second: {generated_tokens/elapsed:.1f}")
    return response

In [None]:
def print_gpu_memory():
    """Print the current and peak allocated memory of every CUDA device.

    Values come from ``torch.cuda.memory_allocated`` and
    ``torch.cuda.max_memory_allocated`` and are divided by 1024**3 before
    being printed with one decimal.
    """
    bytes_per_unit = 1024**3
    for device in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(device) / bytes_per_unit
        peak = torch.cuda.max_memory_allocated(device) / bytes_per_unit
        print(f'GPU {device} memory allocated: {allocated:.1f} GB, max memory allocated: {peak:.1f} GB')
print_gpu_memory()

If I don't quantize the model, the memory usage is fine for my PC but too high for the submission.

```
GPU 0 memory allocated: 14.5 GB, max memory allocated: 14.5 GB
GPU 1 memory allocated: 14.5 GB, max memory allocated: 14.5 GB
```

With quantization, memory usage is:

```
GPU 0 memory allocated: 3.3 GB, max memory allocated: 3.3 GB
GPU 1 memory allocated: 4.1 GB, max memory allocated: 4.1 GB
```

In [None]:
# Bare expression: the notebook shows the model's representation as cell output.
model

## Prepare data

In [None]:
# Layout of one training sample: the instruction block is wrapped in
# [INST] ... [/INST] and followed by a fixed response prefix, the target
# prompt and the closing </s>.
# Placeholders filled via str.format: {original_text}, {rewritten_text}, {response}.
prompt_template = """<s>[INST][prompt-recovery]
Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, Update etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

[/INST] The most likely text prompt given to transform the original text into the rewritten text was: {response} </s>"""
# Fixed prefix that precedes the response inside prompt_template. The notebook
# uses it to cut prompts for inference and passes it to the completion-only
# data collator, so it must stay identical to the wording in prompt_template.
response_template = "The most likely text prompt given to transform the original text into the rewritten text was:"

In [None]:
def prepare_dataframe_for_training(filepath, target_col='gpt4_normalized_response'):
    """Load a CSV and add the formatted training text and its token count.

    Args:
        filepath: Path of a CSV file containing the ``original_text`` and
            ``rewritten_text`` columns plus the ``target_col`` column.
        target_col: Name of the column holding the prompt that is used as the
            expected response.

    Returns:
        The loaded DataFrame with two extra columns: ``text``
        (``prompt_template`` filled in for each row) and ``n_tokens`` (length
        of ``tokenizer.tokenize`` applied to ``text``).
    """
    def count_tokens(sample_text):
        return len(tokenizer.tokenize(sample_text))

    frame = pd.read_csv(filepath)
    frame['text'] = [
        prompt_template.format(original_text=original,
                               rewritten_text=rewritten,
                               response=target)
        for original, rewritten, target in zip(
            frame['original_text'], frame['rewritten_text'], frame[target_col])
    ]
    frame['n_tokens'] = frame['text'].apply(count_tokens)
    return frame

In [None]:
# Three CSV sources, each with its own column holding the target prompt.
train_df_1 = prepare_dataframe_for_training(
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v1.csv',
    target_col='rewrite_prompt',
)
train_df_2 = prepare_dataframe_for_training(
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/mooney_test_with_gpt4.csv',
    target_col='gpt4_prompt',
)
# Print the rewritten text of three hand-picked rows of the second source,
# then discard those rows.
bad_indices = [164, 181, 235]
print(train_df_2.loc[bad_indices, 'rewritten_text'].values)
train_df_2 = train_df_2.drop(index=bad_indices)
train_df_3 = prepare_dataframe_for_training(
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/gemma_suppl_rewrite_curated_with_gpt4.csv',
    target_col='gpt4_prompt',
)
# ignore_index=True gives the combined frame a fresh 0..n-1 index.
train_df = pd.concat([train_df_1, train_df_2, train_df_3], ignore_index=True)
train_df.head()

In [None]:
# Hold out a reproducible random 10% of the rows for evaluation and remove
# those rows from the training frame.
eval_df = train_df.sample(frac=0.1, random_state=42)
eval_df_indices = eval_df.index
train_df = train_df.drop(index=eval_df_indices)

In [None]:
# Cumulative, normalised histogram of the training token counts: the height at
# x is the fraction of training texts with at most x tokens, which helps to
# choose a maximum sequence length.
plt.hist(train_df['n_tokens'], bins=50, alpha=0.5, label='train', cumulative=True, density=True)
plt.ylim(0, 1)
plt.grid()
plt.legend(loc=0)
plt.xlabel('Number of tokens')
# The trailing semicolon keeps the notebook from echoing the return value.
plt.title('Token distribution of the texts');

In [None]:
print(f'There were {len(train_df)} samples for training and {len(eval_df)} samples for evaluation.')
# Keep only the samples whose token count does not exceed the maximum
# sequence length; the same limit is later handed to the trainer.
max_seq_length = 640
train_df = train_df.loc[train_df['n_tokens'].le(max_seq_length)]
eval_df = eval_df.loc[eval_df['n_tokens'].le(max_seq_length)]
print(f'There are {len(train_df)} samples for training and {len(eval_df)} samples for evaluation.')

In [None]:
# NOTE(review): 16 presumably stands for the effective batch size (per-device
# batch size 8 x 2 gradient accumulation steps in the training arguments
# further down) - keep the two in sync.
print(f'One epoch is {len(train_df)//16} steps.')

In [None]:
# Convert the filtered DataFrames into datasets.Dataset objects, which are
# what the trainer receives as train/eval data.
train_dataset = Dataset.from_pandas(train_df)
eval_dataset = Dataset.from_pandas(eval_df)

## Inference before training

In [None]:
# Prompt the model with the first five training samples, each cut right after
# the response prefix so that the model has to produce the prompt itself.
for text in train_df['text'].head(5):
    prompt_prefix = text.partition(response_template)[0]
    print(chat_with_mixtral(prompt_prefix + response_template))

```
# without quantization
Execution Time : 1.5 s, tokens per second: 11.1
 "Rewrite the speech using the masculine pronoun for the speaker."
Execution Time : 1.5 s, tokens per second: 17.7
 "Rewrite the text about the most loyal friend a person can have, but this time focus on horses instead of dogs."
Execution Time : 1.1 s, tokens per second: 15.1
 "Rewrite the text in Mandarin Chinese for a Chinese audience."
Execution Time : 1.1 s, tokens per second: 15.9
 "Rewrite this text in Spanish for a Spanish-speaking audience."
Execution Time : 0.9 s, tokens per second: 15.8
 "Rewrite this announcement in Portuguese for an international audience."

# with quantization
Execution Time : 1.8 s, tokens per second: 14.2
 "Rewrite the text using 'he' or 'his' instead of'she' or 'her' throughout."
Execution Time : 1.0 s, tokens per second: 19.7
 "Rewrite the text about a loyal friend using a different animal as an example."
Execution Time : 1.2 s, tokens per second: 19.9
 "Rewrite the text in simplified Chinese for a Chinese audience while maintaining the original meaning and style."
Execution Time : 0.6 s, tokens per second: 17.5
 "Rewrite this text in Spanish."
Execution Time : 0.7 s, tokens per second: 19.1
 "Rewrite this announcement in Portuguese for an international audience."
```

In [None]:
# Same check on the first five evaluation samples: keep everything up to and
# including the response prefix and let the model complete it.
for text in eval_df['text'].head(5):
    prompt_prefix = text.partition(response_template)[0]
    print(chat_with_mixtral(prompt_prefix + response_template))

## Train

In [None]:
# Run peft's preparation helper for training a k-bit quantized model.
model = prepare_model_for_kbit_training(model)
# Make the model config use the tokenizer's padding token id.
model.config.pad_token_id = tokenizer.pad_token_id
model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching

In [None]:
# LoRA adapter configuration: rank-16 updates on the four attention
# projections, no bias training, causal-LM task type.
peft_config = LoraConfig(
    # lora_alpha: LoRA scaling factor.
    lora_alpha=64, #64,
    lora_dropout=0.1, # 0.1, although Vaca suggested using 0.05 for big models
    # r: the rank of the update matrices, expressed in int. Lower rank results in smaller update matrices with fewer trainable parameters.
    r=16, #16
    bias="none",
    task_type="CAUSAL_LM",
    # target_modules: The modules (for example, attention blocks) to apply the LoRA update matrices.
    target_modules= ['k_proj', 'q_proj', 'v_proj', 'o_proj']
)

In [None]:
# Batch settings are named once so that the steps-per-epoch computation below
# cannot drift away from the values handed to TrainingArguments.
per_device_train_batch_size = 8 # 4-16 should be fine for lora.
gradient_accumulation_steps = 2
# Samples consumed per optimizer step (assumes a single training process -
# TODO confirm if the launch setup changes).
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps

# Log, evaluate and checkpoint roughly once per epoch. Clamp to at least 1:
# with fewer samples than one effective batch the integer division would give
# 0, which is not a usable step interval.
logging_steps = max(1, len(train_df) // effective_batch_size)
print(f'Logging steps: {logging_steps}')
training_arguments = TrainingArguments(
        output_dir="/mnt/hdd0/Kaggle/llm_prompt_recovery/trainings/2024-04-08_new_trainings/05_mistral",
        evaluation_strategy="steps",
        do_eval=True,
        optim="paged_adamw_8bit",
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        per_device_eval_batch_size=8,
        log_level="debug",
        save_steps=logging_steps, #50,
        logging_steps=logging_steps, #50,
        learning_rate=2e-5, # maybe we can increase this
        eval_steps=logging_steps, #50,
        # Train for about 20 epochs' worth of optimizer steps.
        max_steps=logging_steps*20, #300,
        warmup_steps=30,
        lr_scheduler_type="linear",
)

In [None]:
# Completion-only collator keyed on response_template: per TRL's documentation
# it restricts the loss to the tokens that follow the response template.
data_collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, response_template=response_template)
# Supervised fine-tuning on the "text" column; the LoRA adapters described by
# peft_config are set up by the trainer, and max_seq_length is the same limit
# used to filter the DataFrames earlier.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    data_collator=data_collator,
    args=training_arguments,
)

trainer.train()

## Make a few inferences 

In [None]:
# After training: complete the first five training prompts, each cut right
# after the response prefix.
for text in train_df['text'].head(5):
    prompt_prefix = text.partition(response_template)[0]
    print(chat_with_mixtral(prompt_prefix + response_template))

In [None]:
# After training: complete the first five evaluation prompts, each cut right
# after the response prefix.
for text in eval_df['text'].head(5):
    prompt_prefix = text.partition(response_template)[0]
    print(chat_with_mixtral(prompt_prefix + response_template))

## TODO

- [ ] rslora? https://huggingface.co/docs/peft/main/en/conceptual_guides/lora#common-lora-parameters-in-peft
- [ ] Reduce alpha, or reduce r as well
- [ ] Try with Vaca's parameters