# Generate data with Gemma

## Goal

Train other models such as Phi-2 and Gemma and see the validation loss and leaderboard score.

## Imports

In [None]:
import torch
import gc
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
from tqdm.auto import tqdm
import wandb
import yaml
import os
import hashlib
import glob

import logging

from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    pipeline, TrainingArguments
)
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import Dataset

# Emit log records of level INFO and above, prefixed with a timestamp and the level name.
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

# NOTE(review): drawing and immediately closing an empty figure looks like a
# workaround so the rcParams set below take effect in the notebook — confirm
# before removing or reordering.
plt.plot()
plt.close('all')
# Default plot appearance: wide figures, thicker lines, larger font.
plt.rcParams["figure.figsize"] = (20, 5)
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['font.size'] = 16

# Allow up to 200 characters per column when pandas displays a DataFrame.
pd.set_option('display.max_colwidth', 200)

## Code

### Model

In [None]:
# Maps the short model name accepted by load_model_and_tokenizer to the
# filesystem path that is handed to `from_pretrained`.
model_to_path = {
    'mistral-7b': '/home/gbarbadillo/data/Mistral-7B-Instruct-v0.2',
    'mixtral-8x7b': '/home/gbarbadillo/data/mixtral-8x7b-instruct-v0.1-hf',
    'llama-13b': '/home/gbarbadillo/data/llama2-13b-chat-hf',
    'phi-2': '/home/gbarbadillo/data/phi-2',
    'gemma-2b': '/home/gbarbadillo/data/gemma-2b-it',
    'gemma-7b': '/home/gbarbadillo/data/gemma-7b-it',
    'mistral-22b': '/home/gbarbadillo/data/Mistral-22B-v0.2',
}

def load_model_and_tokenizer(model_name):
    """
    Load the causal language model and tokenizer registered under `model_name`.

    The checkpoint path is looked up in `model_to_path` (an unknown name raises
    KeyError). 'phi-2', 'gemma-2b' and 'gemma-7b' are loaded without a
    quantization config; every other model is loaded with a 4-bit NF4
    BitsAndBytes configuration. The device placement comes from
    `get_device_map(model_name)`.

    For every model whose name does not contain 'gemma', a '<pad>' token is
    added to the tokenizer and the model embeddings are resized to match. The
    tokenizer padding side is always set to 'right'.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    logging.info(f'Loading {model_name}...')
    model_path = model_to_path[model_name]

    unquantized_models = ('phi-2', 'gemma-2b', 'gemma-7b')
    if model_name not in unquantized_models:
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            llm_int8_enable_fp32_cpu_offload=True,
            llm_int8_skip_modules=['gate', 'lm_head'],
        )
    else:
        bnb_config = None

    # Free whatever can be freed before the new weights are loaded.
    torch.cuda.empty_cache()
    gc.collect()
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map=get_device_map(model_name),
        trust_remote_code=True,
    )
    if model_name == 'llama-13b':
        # Overrides of the default configuration shipped with the checkpoint.
        # Without pretraining_tp = 1 the quantized model does not work.
        model.config.pretraining_tp = 1
        model.generation_config.temperature = None
        model.generation_config.top_p = None

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    needs_pad_token = 'gemma' not in model_name
    if needs_pad_token:
        tokenizer.add_special_tokens({'pad_token': '<pad>'})
        # The vocabulary may have grown, so the embedding matrix has to follow.
        model.resize_token_embeddings(len(tokenizer))
    # Default is left padding; right padding seems to be better for training.
    tokenizer.padding_side = 'right'

    gc.collect()
    print_gpu_memory()
    return model, tokenizer

def empty_gpu_vram():
    """
    Drop the module-level `model`, `tokenizer` and `pipe` objects and release GPU memory.

    Each of the three global names is removed if it exists; names that were
    never created (or were already removed) are skipped, so the garbage
    collection and CUDA cache release below always run. Finishes by printing
    the per-GPU memory usage via `print_gpu_memory`.
    """
    logging.info('Emptying GPU VRAM...')
    module_globals = globals()
    for name in ('model', 'tokenizer', 'pipe'):
        # pop() with a default tolerates a missing name, unlike `del name`,
        # which would raise NameError and abort the cleanup half-way.
        module_globals.pop(name, None)
    # NOTE(review): collecting twice is presumably meant to also reclaim
    # objects released by the first pass — confirm.
    gc.collect()
    gc.collect()
    torch.cuda.empty_cache()
    print_gpu_memory()

def print_gpu_memory():
    """
    Print, for each CUDA device reported by torch, the currently allocated
    memory and the maximum allocated memory (bytes divided by 1024**3, one decimal).
    """
    bytes_per_gb = 1024**3
    for gpu_idx in range(torch.cuda.device_count()):
        allocated_gb = torch.cuda.memory_allocated(gpu_idx) / bytes_per_gb
        peak_gb = torch.cuda.max_memory_allocated(gpu_idx) / bytes_per_gb
        print(f'GPU {gpu_idx} memory allocated: {allocated_gb:.1f} GB, max memory allocated: {peak_gb:.1f} GB')


# Reference module -> GPU index layout: the token embeddings and decoder
# layers 0-13 on GPU 0, decoder layers 14-31, the final norm and the LM head
# on GPU 1. Insertion order matters: the device-map builders below enumerate
# this dict and decide the placement from each entry's position.
auto_device_map = {
    'model.embed_tokens': 0,
    **{
        f'model.layers.{layer}': int(layer >= 14)
        for layer in range(32)
    },
    'model.norm': 1,
    'lm_head': 1,
}

def create_shared_device_map(transition_layer):
    """
    Build a two-GPU device map with a single split point.

    The keys of `auto_device_map` are visited in order; entries at positions
    0..transition_layer (inclusive) are assigned to GPU 0 and all later entries
    to GPU 1. Position 0 is the first key of `auto_device_map`, so positions
    count dict entries, not decoder layer numbers.

    Returns:
        A new dict mapping each module name to 0 or 1.
    """
    return {
        module_name: 0 if position <= transition_layer else 1
        for position, module_name in enumerate(auto_device_map)
    }

def create_intertwined_device_map():
    """
    Build a device map that alternates the entries of `auto_device_map` between two GPUs.

    Placement is decided by the position of each key in `auto_device_map`:
    position 0 goes to GPU 1, positions 33 and above go to GPU 0, and every
    position in between goes to GPU `position % 2`.

    Returns:
        A new dict mapping each module name to 0 or 1.
    """
    def pick_gpu(position):
        # Guard clauses for the two special ranges, parity for everything else.
        if position == 0:
            return 1
        if position >= 33:
            return 0
        return position % 2

    return {
        module_name: pick_gpu(position)
        for position, module_name in enumerate(auto_device_map)
    }


def get_device_map(model_name):
    """
    Return the `device_map` value to use when loading `model_name`.

    'mixtral-8x7b' gets the explicit map built by
    `create_intertwined_device_map`; every other name gets the string 'auto'.
    """
    needs_custom_map = model_name == 'mixtral-8x7b'
    return create_intertwined_device_map() if needs_custom_map else 'auto'

# Sanity checks: an unknown model name falls back to 'auto', while
# 'mixtral-8x7b' gets something other than the 'auto' string.
assert get_device_map('foo') == 'auto'
assert get_device_map('mixtral-8x7b') != 'auto'

In [None]:
def chat_with_mixtral(prompt, max_new_tokens=400, do_sample=False, temperature=0.7, top_p=0.95):
    """
    Generate a completion for `prompt` with the module-level `pipe` and return its text.

    Args:
        prompt: Text passed to the pipeline as is.
        max_new_tokens: Forwarded to the pipeline as `max_new_tokens`.
        do_sample: When truthy, sampling is enabled and `temperature` / `top_p`
            are forwarded; otherwise `do_sample=False` is passed and both are ignored.
        temperature: Sampling temperature, only used when `do_sample` is truthy.
        top_p: Sampling top-p value, only used when `do_sample` is truthy.

    Returns:
        The 'generated_text' field of the first returned sequence. The pipeline
        is called with `return_full_text=False`.
    """
    generation_kwargs = {'do_sample': False}
    if do_sample:
        generation_kwargs = {'do_sample': True, 'temperature': temperature, 'top_p': top_p}

    # The EOS token id is passed as the pad token id; background:
    # https://www.reddit.com/r/LocalLLaMA/comments/184g120/mistral_fine_tuning_eos_and_padding/
    # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/discussions/106
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        **generation_kwargs,
        return_full_text=False,
    )
    return outputs[0]['generated_text']

### Prompt template

Special tokens from different models:

```python
# gemma
{'bos_token': '<bos>',
 'eos_token': '<eos>',
 'unk_token': '<unk>',
 'pad_token': '<pad>',
 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}
# phi-2
{'bos_token': '<|endoftext|>',
 'eos_token': '<|endoftext|>',
 'unk_token': '<|endoftext|>'}
# mistral-7b
{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}

# llama-13b
{'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>'}
```

In [None]:
# Sentence that opens the model's answer in every training template; each
# template cell asserts that it contains this exact text.
response_template = "The most likely text prompt given to transform the original text into the rewritten text was:"

In [None]:
# Training sample template using Gemma turn markers: a user turn with the task
# instructions and both texts, then a model turn made of the response prefix
# and the answer. Placeholders: {original_text}, {rewritten_text}, {response}.
gemma_template_v1 = """<start_of_turn>user
Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, Update etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

<end_of_turn>
<start_of_turn>model
The most likely text prompt given to transform the original text into the rewritten text was: {response}
<end_of_turn><eos>"""
# The template must contain the shared response prefix verbatim.
assert response_template in gemma_template_v1

In [None]:
# Training sample template in "Instruct: ... Output: ..." form, terminated by
# <|endoftext|>. Placeholders: {original_text}, {rewritten_text}, {response}.
phi_template_v1 = """Instruct: Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, Update etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

Output: The most likely text prompt given to transform the original text into the rewritten text was: {response} <|endoftext|>"""
# The template must contain the shared response prefix verbatim.
assert response_template in phi_template_v1

In [None]:
# Training sample template in [INST] ... [/INST] form, wrapped in <s> ... </s>.
# Placeholders: {original_text}, {rewritten_text}, {response}.
mistral_template = """<s>[INST][prompt-recovery]
Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, Update etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

[/INST] The most likely text prompt given to transform the original text into the rewritten text was: {response} </s>"""
# The template must contain the shared response prefix verbatim.
assert response_template in mistral_template

In [None]:
# Training sample template in "### System / ### Human / ### Assistant" form,
# wrapped in <s> ... </s>. Placeholders: {original_text}, {rewritten_text},
# {response}.
guanaco_template = """<s>### System: You are a helpful assistant.
### Human: Analyze the original and rewritten text and answer with the most likely text prompt that was given to rewrite or make stylistic changes to the original text.

- The text prompt should be a single sentence. Reply just with a short sentence and do not add any notes or comments.
- Sometimes the rewritten text will have hints about the text prompt. For example if it starts by
  Reworded, Rephrased, Translated, Update etc. you should include that word in the text prompt.
- Unless necessary do not make reference to details in the original text and keep the text prompt abstract and generic.

## Original text

{original_text}

## Rewritten text

{rewritten_text}

### Assistant: The most likely text prompt given to transform the original text into the rewritten text was: {response} </s>"""
# Check this template (not mistral_template) for the shared response prefix,
# so an edit that breaks the Guanaco prefix is actually caught.
assert response_template in guanaco_template

### Raise

In [None]:
# A bare `raise` with no active exception fails with RuntimeError.
# NOTE(review): presumably a deliberate stop so that running all cells does
# not continue into the data generation section below — confirm.
raise

## Generate data

In [None]:
# Generation prompt using Gemma turn markers: the user turn holds the rewrite
# instruction followed by the text to rewrite, and the prompt ends by opening
# the model turn. Placeholders: {rewrite_prompt}, {original_text}.
gemma_template = """<start_of_turn>user
{rewrite_prompt}

{original_text}

<end_of_turn>
<start_of_turn>model"""

In [None]:
# Load gemma-7b and wrap it in a text-generation pipeline. `model`, `tokenizer`
# and `pipe` are module-level names that chat_with_mixtral and empty_gpu_vram
# rely on. The EOS token id is used as the pad token id.
model, tokenizer = load_model_and_tokenizer('gemma-7b')
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, pad_token_id=tokenizer.eos_token_id)

In [None]:
def create_gemma_dataset(filepath):
    """
    Regenerate the 'rewritten_text' column of a CSV dataset with the loaded model.

    Reads `filepath` with pandas and, for every row, formats `gemma_template`
    with the row's 'rewrite_prompt' and 'original_text', generates a response
    with `chat_with_mixtral` (default generation settings) and prints the
    prompt, the new response and the previous 'rewritten_text' for visual
    comparison. The responses then replace the 'rewritten_text' column and the
    result is saved next to the input, with '_gemma' inserted before the file
    extension (e.g. 'data.csv' -> 'data_gemma.csv').

    Args:
        filepath: Path to a CSV file with the columns 'rewrite_prompt',
            'original_text' and 'rewritten_text'.
    """
    df = pd.read_csv(filepath)
    responses = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        prompt = gemma_template.format(rewrite_prompt=row['rewrite_prompt'], original_text=row['original_text'])
        print(prompt)
        response = chat_with_mixtral(prompt)
        responses.append(response)
        print('-'*80)
        print(response)
        print('-'*80)
        print(row['rewritten_text'])
        print('='*80 + '\n\n')
    df['rewritten_text'] = responses
    # Split off only the final extension: a plain str.replace('.csv', ...) would
    # also rewrite '.csv' inside directory names and, when the path has no
    # '.csv' at all, would leave the name unchanged and overwrite the input file.
    root, extension = os.path.splitext(filepath)
    df.to_csv(f'{root}_gemma{extension}', index=False)

In [None]:
# Generate the Gemma version of the dataset; the output CSV is written next to
# the input file.
create_gemma_dataset('/mnt/hdd0/Kaggle/llm_prompt_recovery/data/high_quality_dataset_v2.csv')

Generating the dataset this way, one prompt at a time through the pipeline, is incredibly slow.