# Prompt tuning Mistral

## Goal

Instead of fine-tuning the model, let's use prompt tuning.

## Imports

In [None]:
import torch
import gc
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from tqdm.auto import tqdm
import wandb
import yaml
import os
import hashlib
import glob

import logging

from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    pipeline, TrainingArguments
)
from peft import LoraConfig, prepare_model_for_kbit_training, PromptTuningConfig, TaskType, PromptTuningInit
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import Dataset

# Timestamped INFO-level output for everything emitted through the logging module.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# Create an empty plot and close every figure before changing rcParams.
# NOTE(review): presumably this forces the notebook plotting backend to initialize
# so the settings below are not overwritten later - TODO confirm.
plt.plot()
plt.close('all')
# Default plot appearance: wide figures, thick lines and a larger font.
plt.rcParams["figure.figsize"] = (20, 5)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 16

# Let pandas display up to 200 characters per column before truncating.
pd.set_option('display.max_colwidth', 200)

## Code

In [None]:
# Short model names used throughout the notebook, mapped to the local directory
# each checkpoint is loaded from. The paths are machine specific.
model_to_path = {
    'mistral-7b': '/home/gbarbadillo/data/Mistral-7B-Instruct-v0.2',
    'mixtral-8x7b': '/home/gbarbadillo/data/mixtral-8x7b-instruct-v0.1-hf',
    'llama-13b': '/home/gbarbadillo/data/llama2-13b-chat-hf',
}

def load_model_and_tokenizer(model_name):
    """
    Load a 4-bit quantized causal language model together with its tokenizer.

    Args:
        model_name: key of `model_to_path` selecting the local checkpoint.
            An unknown name raises KeyError.

    Returns:
        A `(model, tokenizer)` tuple. The tokenizer gets `<pad>` registered as
        its pad token and pads on the right; the model embeddings are resized
        to the tokenizer length.
    """
    logging.info(f'Loading {model_name}...')
    checkpoint_dir = model_to_path[model_name]

    quantization_options = dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        llm_int8_enable_fp32_cpu_offload=True,
        llm_int8_skip_modules=['gate', 'lm_head'],
    )
    quantization = BitsAndBytesConfig(**quantization_options)

    # Free whatever can be freed before the new weights are loaded
    torch.cuda.empty_cache()
    gc.collect()
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        quantization_config=quantization,
        device_map=get_device_map(model_name),
        trust_remote_code=True,
    )
    if model_name == 'llama-13b':
        # quirks to update the model default configuration
        model.config.pretraining_tp = 1 # otherwise the quantized model does not work
        generation_config = model.generation_config
        generation_config.temperature = None
        generation_config.top_p = None

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir, trust_remote_code=True)
    tokenizer.add_special_tokens({'pad_token': '<pad>'})
    # by default is left, for training right seems to be better
    tokenizer.padding_side = 'right'
    # The new pad token enlarges the vocabulary, so the embeddings must follow
    vocabulary_size = len(tokenizer)
    model.resize_token_embeddings(vocabulary_size)

    gc.collect()
    print_gpu_memory()
    return model, tokenizer

def empty_gpu_vram():
    """
    Drop the notebook-level model objects and release the GPU memory they hold.

    Removes the globals `model`, `tokenizer`, `pipe` and `trainer`, runs the
    garbage collector, empties the CUDA cache and prints the memory usage.
    Globals that do not exist (for example `trainer` when a run failed before
    the trainer was built, or when this function is called twice) are skipped
    instead of raising NameError, so the cleanup always runs to completion.
    """
    logging.info('Emptying GPU VRAM...')
    namespace = globals()
    for name in ('model', 'tokenizer', 'pipe', 'trainer'):
        # pop with a default removes the binding when present and is a no-op otherwise
        namespace.pop(name, None)
    # Collected twice, as before; presumably a second pass picks up objects that
    # only became unreachable after the first one - TODO confirm it is needed.
    gc.collect()
    gc.collect()
    torch.cuda.empty_cache()
    print_gpu_memory()

def print_gpu_memory():
    """Print the current and peak allocated memory, in GB, of every CUDA device."""
    bytes_per_gb = 1024**3
    for device in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(device)/bytes_per_gb
        peak = torch.cuda.max_memory_allocated(device)/bytes_per_gb
        print(f'GPU {device} memory allocated: {allocated:.1f} GB, max memory allocated: {peak:.1f} GB')


# Reference placement of the model modules on two GPUs: the embeddings and
# decoder layers 0-13 on device 0, decoder layers 14-31, the final norm and
# the LM head on device 1. The insertion order (embeddings, layers in
# ascending order, norm, head) matters because the helper functions that
# build other device maps enumerate this dictionary.
auto_device_map = {
    'model.embed_tokens': 0,
    **{f'model.layers.{layer}': 0 if layer < 14 else 1 for layer in range(32)},
    'model.norm': 1,
    'lm_head': 1,
}

def create_shared_device_map(transition_layer=16):
    """
    Build a two-GPU device map that switches device at a configurable position.

    Args:
        transition_layer: last position, counted over the entries of
            `auto_device_map` in order, that is placed on device 0. Because the
            embeddings occupy position 0, decoder layer k sits at position k + 1.

    Returns:
        dict mapping every module name of `auto_device_map` to device 0 or 1.
    """
    return {
        module_name: 0 if position <= transition_layer else 1
        for position, module_name in enumerate(auto_device_map)
    }

def create_intertwined_device_map():
    """
    Build a two-GPU device map that alternates consecutive modules between devices.

    The first entry of `auto_device_map` goes to device 1, entries from
    position 33 onwards go to device 0, and every entry in between goes to
    device `position % 2`.

    Returns:
        dict mapping every module name of `auto_device_map` to device 0 or 1.
    """
    def device_for(position):
        if position == 0:
            return 1
        if position >= 33:
            return 0
        return position % 2

    device_map = {}
    for position, module_name in enumerate(auto_device_map):
        device_map[module_name] = device_for(position)
    return device_map


def get_device_map(model_name):
    """
    Choose the device map used to load `model_name`.

    Returns:
        The explicit map from `create_shared_device_map()` for 'mixtral-8x7b',
        and the string 'auto' for any other model name.
    """
    if model_name != 'mixtral-8x7b':
        return 'auto'
    return create_shared_device_map()

# Sanity checks run when the cell executes: unknown models fall back to 'auto',
# while Mixtral receives an explicit device map.
assert get_device_map('foo') == 'auto'
assert get_device_map('mixtral-8x7b') != 'auto'

In [None]:
def chat_with_mixtral(prompt, max_new_tokens=200, verbose=True, do_sample=False, temperature=0.7, top_p=0.95):
    """
    Generate a completion for `prompt` with the global text-generation `pipe`.

    Args:
        prompt: text to complete. When it does not start with '<s>[INST]' it is
            wrapped in the instruction format first.
        max_new_tokens: maximum number of tokens to generate.
        verbose: when True, print the elapsed time and the generation speed.
        do_sample: sample using `temperature` and `top_p` when True; otherwise
            sampling is disabled and both values are ignored.
        temperature: sampling temperature, used only when `do_sample` is True.
        top_p: nucleus sampling threshold, used only when `do_sample` is True.

    Returns:
        The generated text only (the prompt is not included).
    """
    already_formatted = prompt.startswith('<s>[INST]')
    if not already_formatted:
        print('Formatting the prompt to Mixtral needs.')
        prompt = f'<s>[INST] {prompt} [/INST]'
    started_at = time.time()

    generation_kwargs = dict(do_sample=False)
    if do_sample:
        generation_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p)

    # The end-of-sequence token doubles as padding token during generation, see
    # https://www.reddit.com/r/LocalLLaMA/comments/184g120/mistral_fine_tuning_eos_and_padding/
    # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/discussions/106
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        **generation_kwargs,
        return_full_text=False,
    )
    response = outputs[0]['generated_text']
    #response = re.sub(r'[\'"]', '', response)
    if not verbose:
        return response

    elapsed = time.time() - started_at
    generated_tokens = len(tokenizer.tokenize(response))
    print(f"Execution Time : {elapsed:.1f} s, tokens per second: {generated_tokens/elapsed:.1f}")
    return response

### Prompt template

In [None]:
# Constrained template for the prompt-recovery task. The instruction, its rules
# and both texts go inside [INST] ... [/INST]; the fixed answer prefix and the
# target prompt follow. Placeholders: {original_text}, {rewritten_text}, {response}.
# NOTE: the following cell assigns prompt_template and response_template again,
# so whichever of the two cells is executed last defines the template in use.
prompt_template = """<s>[INST][prompt-recovery]
Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, Update etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

[/INST] The most likely text prompt given to transform the original text into the rewritten text was: {response} </s>"""
# Answer prefix. It has to appear verbatim in prompt_template: the training
# function splits the text on it to build inference prompts and hands it to the
# completion-only data collator as the response marker.
response_template = "The most likely text prompt given to transform the original text into the rewritten text was:"

Below is a less constrained prompt template.

In [None]:
# Less constrained template for the prompt-recovery task: only the task
# description and both texts go inside [INST] ... [/INST], without extra rules.
# Placeholders: {original_text}, {rewritten_text}, {response}.
prompt_template = """<s>[INST][prompt-recovery]
Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

[/INST] The most likely text prompt given to transform the original text into the rewritten text was: {response} </s>"""
# Answer prefix. It has to appear verbatim in prompt_template: the training
# function splits the text on it to build inference prompts and hands it to the
# completion-only data collator as the response marker.
response_template = "The most likely text prompt given to transform the original text into the rewritten text was:"

### Data

In [None]:
def prepare_dataframe_for_training(filepath, target_col='rewrite_prompt'):
    """
    Read a CSV dataset and render every row with the global `prompt_template`.

    Args:
        filepath: path of a CSV file with the columns 'original_text',
            'rewritten_text' and `target_col`.
        target_col: name of the column holding the prompt to be recovered.

    Returns:
        The loaded DataFrame with two extra columns: 'text', the formatted
        training sample, and 'n_tokens', its length according to the global
        `tokenizer`.
    """
    def count_tokens(text):
        return len(tokenizer.tokenize(text))

    df = pd.read_csv(filepath)
    df['text'] = [
        prompt_template.format(original_text=row['original_text'],
                               rewritten_text=row['rewritten_text'],
                               response=row[target_col])
        for _, row in df.iterrows()
    ]
    df['n_tokens'] = df['text'].apply(count_tokens)
    return df

In [None]:
def filter_too_long_samples(train_df, eval_df, max_seq_length, safe_margin = 10):
    """
    Drop the samples whose token count does not fit in the sequence length.

    Args:
        train_df: training DataFrame with an 'n_tokens' column.
        eval_df: evaluation DataFrame with an 'n_tokens' column.
        max_seq_length: maximum sequence length used for training.
        safe_margin: tokens kept free below `max_seq_length`.

    Returns:
        A `(train_df, eval_df)` tuple keeping only the rows with
        `n_tokens <= max_seq_length - safe_margin`. The number of samples
        before and after filtering is printed.
    """
    def keep_short(frame):
        return frame[frame['n_tokens'] <= token_limit]

    logging.info(f'Filtering samples with more than {max_seq_length} tokens.')
    print(f'There were {len(train_df)} samples for training and {len(eval_df)} samples for evaluation.')
    token_limit = max_seq_length - safe_margin
    kept_train = keep_short(train_df)
    kept_eval = keep_short(eval_df)
    print(f'There are {len(kept_train)} samples for training and {len(kept_eval)} samples for evaluation.')
    return kept_train, kept_eval

### Training

In [None]:
def prepare_model_for_training(model, tokenizer):
    """
    Configure a quantized model for training.

    Runs `prepare_model_for_kbit_training` on the model, copies the pad token id
    of the tokenizer into the model configuration and disables the generation
    cache.

    Args:
        model: quantized model to prepare.
        tokenizer: tokenizer providing `pad_token_id`.

    Returns:
        The prepared model. Previously the result of
        `prepare_model_for_kbit_training` was bound to a local name and
        discarded, which only worked as long as that helper modified its
        argument in place. Callers that ignore the return value keep working.
    """
    model = prepare_model_for_kbit_training(model)
    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.use_cache = False # Gradient checkpointing is used by default but not compatible with caching
    return model

In [None]:
def train_model_on_datasets(model_name, dataset_filepaths, experiment_name, peft_config,
                            max_seq_length=640, learning_rate=2e-5, shuffle=True,
                            batch_size=16, max_steps=250):
    """
    Train a PEFT adapter on the given CSV datasets and log the run to W&B.

    The function loads the model, builds the training and evaluation datasets,
    prints a few sample inferences of the untrained model, trains with
    `SFTTrainer` computing the loss on the completion only, and finally
    releases the GPU memory.

    Args:
        model_name: key of `model_to_path` selecting the model to load.
        dataset_filepaths: CSV files that are concatenated into the train set.
        experiment_name: name of the W&B run and of the output subfolder.
        peft_config: PEFT configuration forwarded to `SFTTrainer`.
        max_seq_length: maximum sequence length; longer samples are dropped.
        learning_rate: initial learning rate of the linear schedule.
        shuffle: shuffle the training dataset (seed 42) when True.
        batch_size: per-device batch size for training and evaluation.
        max_steps: number of optimization steps.

    Side effects:
        Sets the globals `model`, `tokenizer`, `pipe` and `trainer` while
        running and removes them through `empty_gpu_vram()` at the end.
        Checkpoints are saved and evaluation is run every 25 steps.

    Raises:
        Any exception raised by `trainer.train()` is propagated, after the W&B
        run has been closed as failed and the GPU memory has been released.
    """
    global model, tokenizer, pipe, trainer
    model, tokenizer = load_model_and_tokenizer(model_name)
    pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

    train_df = pd.concat([prepare_dataframe_for_training(filepath) for filepath in dataset_filepaths])
    eval_df = prepare_dataframe_for_training('/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v1.csv')
    train_df, eval_df = filter_too_long_samples(train_df, eval_df, max_seq_length)
    train_dataset = Dataset.from_pandas(train_df)
    if shuffle:
        logging.info('Shuffling the training dataset.')
        train_dataset = train_dataset.shuffle(seed=42)
    eval_dataset = Dataset.from_pandas(eval_df)
    print(f'Train sample:\n{train_df.text.values[0]}')
    print(f'Test sample:\n{eval_df.text.values[0]}')
    logging.info('Making a few sample inferences')
    for text in train_df['text'].values[:5]:
        # Keep everything up to and including the answer prefix, so the model writes the answer
        print(chat_with_mixtral(text.split(response_template)[0] + response_template))

    prepare_model_for_training(model, tokenizer)
    logging_steps = 25
    output_dir = f"/mnt/hdd0/Kaggle/llm_prompt_recovery/trainings/2024-04-14_prompt_tuning/{experiment_name}"
    training_arguments = TrainingArguments(
            output_dir=output_dir,
            evaluation_strategy="steps",
            do_eval=True,
            optim="paged_adamw_8bit",
            per_device_train_batch_size=batch_size, # 4-16 should be fine for lora.
            gradient_accumulation_steps=1,
            per_device_eval_batch_size=batch_size,
            log_level="debug",
            save_steps=logging_steps, #50,
            logging_steps=logging_steps//4, #50,
            learning_rate=learning_rate, # maybe we can increase this
            eval_steps=logging_steps, #50,
            max_steps=max_steps, #300,
            warmup_steps=0, #30 since we are training the prompt I don't believe we need warmup
            lr_scheduler_type="linear",
    )
    # https://docs.wandb.ai/guides/track/environment-variables
    # https://stackoverflow.com/questions/71106179/log-two-model-runs-with-keras-wandb
    w = wandb.init(reinit=True, project='prompt_tuning', name=experiment_name)
    data_collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, response_template=response_template)
    trainer = SFTTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        dataset_text_field="text",
        max_seq_length=max_seq_length,
        data_collator=data_collator,
        args=training_arguments,
    )
    try:
        trainer.train()
    except BaseException:
        # Also covers KeyboardInterrupt: close the run as failed instead of leaving it open
        w.finish(exit_code=1)
        raise
    else:
        w.finish()
        # logging.info('Making a few sample inferences')
        # for text in train_df['text'].values[:5]:
        #     print(chat_with_mixtral(text.split(response_template)[0] + response_template))
    finally:
        # Release the GPU on success and on failure alike; otherwise the next
        # experiment would load its model while this one still occupies the VRAM
        empty_gpu_vram()

In [None]:
# NOTE(review): a bare `raise` with no active exception fails with RuntimeError.
# Presumably it is here on purpose, to stop "Run All" before the experiment
# cells below - confirm.
raise

## More experiments with my data

In [None]:
# CSV datasets that the experiment cells below pass to train_model_on_datasets,
# where they are concatenated into the training set.
filepaths = [
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_imitation_of_leaked_v1.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v2.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v3.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v4_with_hints.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v5_gpt4.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_multi_instruction_v1.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/mooney_test_with_gpt4_v2.csv',
    '/mnt/hdd0/Kaggle/llm_prompt_recovery/data/gemma_suppl_rewrite_curated_with_gpt4.csv',
 ]

In [None]:
# Experiment: prompt tuning of Mistral-7B with 8 randomly initialized virtual
# tokens, learning rate 1e-1, 1000 steps and the default batch size.
model_name = 'mistral-7b'
num_virtual_tokens = 8
prompt_tuning_conf = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM, #This type indicates the model will generate text.
    prompt_tuning_init=PromptTuningInit.RANDOM,  #The added virtual tokens are initialized with random numbers
    num_virtual_tokens=num_virtual_tokens, #Number of virtual tokens to be added and trained.
    tokenizer_name_or_path=model_to_path[model_name] #The pre-trained model.
)
learning_rate = 1e-1
# The learning rate is rendered in exponent notation, e.g. 1e-1 becomes 'lr1e-01'
experiment_name = f'{model_name}_prompt_tuning_vtk{num_virtual_tokens}_lr{learning_rate:.0e}_1000steps'
train_model_on_datasets(
    model_name,
    filepaths,
    experiment_name,
    prompt_tuning_conf,
    max_seq_length=640,
    learning_rate=learning_rate,
    max_steps=1000,
)

In [None]:
# Experiment: prompt tuning of Mixtral-8x7B with 8 randomly initialized virtual
# tokens, learning rate 5e-2, 1000 steps and a batch size of 8.
model_name = 'mixtral-8x7b'
num_virtual_tokens = 8
prompt_tuning_conf = PromptTuningConfig(
    task_type=TaskType.CAUSAL_LM, #This type indicates the model will generate text.
    prompt_tuning_init=PromptTuningInit.RANDOM,  #The added virtual tokens are initialized with random numbers
    num_virtual_tokens=num_virtual_tokens, #Number of virtual tokens to be added and trained.
    tokenizer_name_or_path=model_to_path[model_name] #The pre-trained model.
)
learning_rate = 5e-2
# The learning rate is rendered in exponent notation, e.g. 5e-2 becomes 'lr5e-02'
experiment_name = f'{model_name}_prompt_tuning_vtk{num_virtual_tokens}_lr{learning_rate:.0e}_1000steps'
train_model_on_datasets(
    model_name,
    filepaths,
    experiment_name,
    prompt_tuning_conf,
    max_seq_length=640,
    learning_rate=learning_rate,
    batch_size=8,
    max_steps=1000,
)

A learning rate of 1e-1 gives better results than both 2e-2 and 1, so it is close to the optimum.

The trained prompt has just 32k parameters (8 virtual tokens × 4096 embedding dimensions), which is only 131kB!

Inference with Mistral-7B

```
 "Find a way to motivate the team to exceed the quarterly goals by emphasizing the incentives and the team effort."
 "Make the language more urgent and emphasize the importance of addressing any discrepancies or issues related to the acre allocations for the project as soon as possible to prevent any delays or setbacks."
 "Rewrite this sentence in a more playful and engaging way, using metaphors and exaggeration to make the task of doing the dishes sound more exciting and adventurous."
 "Rewrite this statement to focus on the positive aspects of the project timeline extension and frame it as an opportunity rather than a challenge."
 "Critique the optimistic view on sustainability and discuss the challenges in implementing it effectively." This prompt would have encouraged the writer to explore the limitations and complexities of sustainability, as well as the need for systemic changes beyond technological solutions.
```



The predictions from `mistral-7b_prompt_tuning_vtk8_lr2e-03` do not look good: the format is different and we can see a note at the end.

```
 "Find a way to motivate the team to exceed the quarterly goals by emphasizing the incentives and the team effort."
 "Make the language more urgent and emphasize the importance of addressing any discrepancies or issues related to the acre allocations for the project as soon as possible to prevent any delays or setbacks."
 "Rewrite this sentence in a more playful and engaging way, using metaphors and exaggeration to make the task of doing the dishes sound like an exciting adventure."
 "Rewrite this statement to focus on the positive aspects of the project timeline extension and frame it as an opportunity rather than a challenge."
 "Critique the optimistic view on sustainability and discuss the challenges in implementing it effectively." This prompt would have encouraged the writer to explore the limitations and complexities of sustainability, as well as the need for systemic changes beyond technological solutions.
```

Maybe I'm not using the model correctly.

## TODO

- [x] Verify that I can clean the memory at the end of the training
- [x] Weird message about too long texts