# First steps with Mixtral

## Goal

Verify that I can use the Mixtral model locally.

## Imports

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import gc
import time
import re

## Code

In [None]:
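bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # 4-bit NormalFloat quantization
    bnb_4bit_compute_dtype=torch.float16,  # dtype used for the matmuls
    bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
    llm_int8_enable_fp32_cpu_offload=True,  # allow modules offloaded to CPU to stay in fp32
)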

torch.cuda.empty_cache()
gc.collect()
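Quantizing to 4 bits is what makes the model practical to run locally. Below is a back-of-the-envelope estimate of the weight memory. It assumes ~46.7B total parameters (the figure published by Mistral AI for Mixtral 8x7B) and ignores quantization constants, layers kept in higher precision, and activations.

```python
# Rough estimate of the memory needed just for Mixtral's weights.
# Assumption: ~46.7B total parameters; quantization constants, layers kept in
# higher precision and activations are ignored.
N_PARAMS = 46.7e9

def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Return the memory taken by the weights in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

for name, bits in [("float16", 16), ("int8", 8), ("nf4", 4)]:
    print(f"{name}: {weight_memory_gb(N_PARAMS, bits):.1f} GB")
```

Under these assumptions the NF4 weights take about a quarter of the float16 footprint (roughly 23 GB instead of roughly 93 GB).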

In [None]:
# Copy of the model on the HDD (much slower to load, see timings below)
# model_path = '/mnt/hdd0/Kaggle/llm_prompt_recovery/models/mixtral-8x7b-instruct-v0.1-hf'
model_path = '/home/gbarbadillo/data/mixtral-8x7b-instruct-v0.1-hf/'
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=bnb_config,
    device_map="auto",  # spread the layers across the available devices
    trust_remote_code=True,
)

Loading time:

- 24 min when loading from HDD (reading at 62 MB/s)
- 1 min when loading from SSD (reading at 1.5 GB/s)
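The loading times above are consistent with a simple disk-read estimate. This sketch assumes the checkpoint stores ~46.7B parameters in 16-bit precision (about 93 GB) and that all of it has to be read from disk before it is quantized to 4 bits.

```python
# Back-of-the-envelope check of the loading times.
# Assumption: ~46.7B parameters stored in 16-bit precision (2 bytes each).
CHECKPOINT_BYTES = 46.7e9 * 2

def load_time_minutes(n_bytes: float, read_speed_bytes_per_s: float) -> float:
    """Time needed to read n_bytes sequentially at the given speed, in minutes."""
    return n_bytes / read_speed_bytes_per_s / 60

print(f"HDD (62 MB/s): {load_time_minutes(CHECKPOINT_BYTES, 62e6):.1f} min")
print(f"SSD (1.5 GB/s): {load_time_minutes(CHECKPOINT_BYTES, 1.5e9):.1f} min")
```

Both estimates land close to the measured 24 min and 1 min. This suggests that loading is dominated by reading the weights from disk.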

In [None]:
from transformers import pipeline, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    trust_remote_code=True)
tokenizer.pad_token_id = tokenizer.eos_token_id # this is needed to do batch inference
gc.collect()
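A pad token is needed for batch inference because prompts of different lengths must be padded to the same length. For decoder-only models the usual choice is left padding, so the last position of every row is a real prompt token and generation continues from there. In `transformers` the side is controlled by `tokenizer.padding_side`, which this notebook leaves at its default. The sketch below is a pure-Python illustration of left padding with toy token ids. It assumes pad id 2, the EOS id reused as pad token above. `left_pad` is a hypothetical helper, not a library function.

```python
from typing import List, Tuple

PAD_ID = 2  # assumption: EOS id reused as pad token

def left_pad(batch: List[List[int]], pad_id: int = PAD_ID) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (input_ids, attention_mask) with every row left-padded to the same length."""
    max_len = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = max_len - len(ids)
        input_ids.append([pad_id] * n_pad + ids)
        # 0 marks padding positions the model must ignore, 1 marks real tokens
        attention_mask.append([0] * n_pad + [1] * len(ids))
    return input_ids, attention_mask

# Toy ids: 1 stands for BOS, the rest are made up
ids, mask = left_pad([[1, 10, 11], [1, 12]])
print(ids)
print(mask)
```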

In [None]:
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer)

def chat_with_mixtral(prompt, max_new_tokens=200, verbose=True, do_sample=False, temperature=0.7, top_p=0.95):
    if not prompt.startswith('<s>[INST]'):
        print('Formatting the prompt to match the Mixtral instruction format.')
        prompt = f'<s>[INST] {prompt} [/INST]'
    start = time.time()

    if do_sample:
        sampling_kwargs = dict(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        sampling_kwargs = dict(do_sample=False)

    sequences = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        # https://www.reddit.com/r/LocalLLaMA/comments/184g120/mistral_fine_tuning_eos_and_padding/
        # https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1/discussions/106
        pad_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
        return_full_text=False,
    )
    response = sequences[0]['generated_text']
    response = re.sub(r'[\'"]', '', response)
    if verbose:
        stop = time.time()
        time_taken = stop-start
        n_tokens = len(tokenizer.tokenize(response))
        print(f"Execution Time : {time_taken:.1f} s, tokens per second: {n_tokens/time_taken:.1f}")
    return response
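Two steps in `chat_with_mixtral` don't need the model: wrapping the prompt and stripping quotes from the response. They can be checked separately. Below is a minimal sketch that pulls them out into pure functions. `format_prompt` and `clean_response` are hypothetical names, not part of the notebook.

```python
import re

def format_prompt(prompt: str) -> str:
    """Wrap a raw prompt in Mixtral's instruction tags unless it is already formatted."""
    if prompt.startswith('<s>[INST]'):
        return prompt
    return f'<s>[INST] {prompt} [/INST]'

def clean_response(response: str) -> str:
    """Remove single and double quotes from the generated text, as chat_with_mixtral does."""
    return re.sub(r'[\'"]', '', response)

print(format_prompt('say hi'))                    # <s>[INST] say hi [/INST]
print(format_prompt('<s>[INST] say hi [/INST]'))  # <s>[INST] say hi [/INST]
print(clean_response('"Hala Madrid"'))            # Hala Madrid
```

Stripping `'` also removes apostrophes, so `don't` becomes `dont`. If that matters, the regex could be limited to double quotes.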

## Chatting

In [None]:
for _ in range(2):
    print(chat_with_mixtral('write a poem about real madrid', max_new_tokens=25))

In [None]:
print(chat_with_mixtral('Write an essay about the future of digital identity.', 200))

- Generation runs at 10.4 tokens per second with `bnb_4bit_compute_dtype=torch.float16`.
- With `torch.bfloat16` it runs at 8.9 tokens per second.

## Input formatting

I'm not sure whether I'm using the correct input format. References:

- https://www.kaggle.com/models/mistral-ai/mixtral/frameworks/PyTorch/variations/8x7b-instruct-v0.1-hf/versions/1
- https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1

### Studying the tokenizer

#### Typical case

In [None]:
messages = [
    {"role": "user", "content": "say hi"},
]
tokenizer.apply_chat_template(messages, return_tensors="pt").numpy().tolist()[0]

In [None]:
prompt = '<s>[INST] say hi [/INST]'
tokenizer.encode(prompt, add_special_tokens=False)

In [None]:
tokenizer.tokenize(prompt, add_special_tokens=False)

In [None]:
prompt = '<s>[INST] say hi'
tokenizer.encode(prompt, add_special_tokens=False)

In this case the encoding is exactly the same as the one produced by the chat template. Notice that I had to remove the space between `<s>` and `[INST]` to get a match.

#### Longer conversations

In [None]:
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
]
tokenizer.apply_chat_template(messages, return_tensors="pt").numpy().tolist()[0]

In [None]:
prompt = '<s>[INST] Hi [/INST]Hello'
tokenizer.encode(prompt, add_special_tokens=False)

In [None]:
tokenizer.convert_ids_to_tokens(2)

The only difference is that the chat template closes the assistant turn with the EOS token `</s>` (id 2), as if the bot had finished its reply, while my manual prompt does not.

In [None]:
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
    {"role": "user", "content": "Bye"},
]
tokenizer.apply_chat_template(messages, return_tensors="pt").numpy().tolist()[0][-10:]

In [None]:
prompt = '<s>[INST] Hi [/INST]Hello</s>[INST] Bye[/INST]'
tokenizer.encode(prompt, add_special_tokens=False)[-10:]
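The comparisons above point to a simple format. The prompt starts with `<s>`, each user turn is wrapped in `[INST] ... [/INST]`, and each finished assistant turn is followed by `</s>`. Below is a minimal sketch of a string builder for that format, assuming strictly alternating user/assistant turns. `build_mixtral_prompt` is a hypothetical helper. The spacing inside the tags follows the single-turn case, and `tokenizer.apply_chat_template(messages, tokenize=False)` should remain the reference to check it against.

```python
from typing import Dict, List

BOS, EOS = '<s>', '</s>'

def build_mixtral_prompt(messages: List[Dict[str, str]]) -> str:
    """Build the prompt string for a conversation of alternating user/assistant turns."""
    prompt = BOS
    for i, message in enumerate(messages):
        expected_role = 'user' if i % 2 == 0 else 'assistant'
        if message['role'] != expected_role:
            raise ValueError('Roles must alternate user/assistant, starting with user')
        if message['role'] == 'user':
            prompt += f"[INST] {message['content']} [/INST]"
        else:
            # Finished assistant turns are closed with EOS, like in the chat template output
            prompt += f"{message['content']}{EOS}"
    return prompt

print(build_mixtral_prompt([
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello"},
    {"role": "user", "content": "Bye"},
]))
```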

### Checking the pipeline

In [None]:
pipe('<s>[INST] say hi [/INST]', do_sample=False, return_full_text=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=50)

In [None]:
pipe('[INST] say hi [/INST]', do_sample=False, return_full_text=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=50, add_special_tokens=True)

This example shows that by default the pipeline does not add the BOS token `<s>`. Passing `add_special_tokens=True` with a prompt that lacks `<s>` gives the same result as writing `<s>` manually.
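The two cells above are two ways of getting a single BOS token. The one thing to avoid is combining them, which would presumably produce a duplicated BOS. Below is a minimal string-level sketch. `prompt_for_pipeline` is a hypothetical helper, and the `pipeline_adds_bos` flag stands for whether the pipeline is called with `add_special_tokens=True`.

```python
BOS = '<s>'

def prompt_for_pipeline(prompt: str, pipeline_adds_bos: bool) -> str:
    """Return a prompt that ends up with exactly one BOS token.

    If the pipeline prepends BOS itself (add_special_tokens=True), a manually
    written '<s>' is removed; otherwise it is added when missing.
    """
    if pipeline_adds_bos:
        return prompt[len(BOS):] if prompt.startswith(BOS) else prompt
    return prompt if prompt.startswith(BOS) else BOS + prompt

print(prompt_for_pipeline('<s>[INST] say hi [/INST]', pipeline_adds_bos=True))
print(prompt_for_pipeline('[INST] say hi [/INST]', pipeline_adds_bos=False))
```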

## Batch generation

In [None]:
prompts = [f'<s>[INST] What is the capital of {country}? Do not give any additional information, just say the capital and shut up.[/INST]The capital of {country} is: ' for country in ['Spain', 'France', 'Germany', 'Italy']]
pipe(prompts, do_sample=False, return_full_text=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=50)

In [None]:
prompts = [f'<s>[INST] What is the history of {country}? [/INST]' for country in ['Spain', 'France', 'Germany', 'Italy']]
pipe(prompts, do_sample=False, return_full_text=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=50)

Passing a list of prompts to the pipeline works, but it does not seem to speed up inference at all.

By default the tokenizer adds the BOS token when called directly. The pipeline, however, did not add it in the previous section, so the prompts here include `<s>` explicitly.

In [None]:

pipe_bs4 = pipeline(task="text-generation", model=model, tokenizer=tokenizer, batch_size=4)

In [None]:
prompts = [f'<s>[INST] What is the history of {country}? [/INST]' for country in ['Spain', 'France', 'Germany', 'Italy']]
pipe_bs4(prompts, do_sample=False, return_full_text=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=50)

With `batch_size=4`, GPU usage is higher and generation is faster. Let's try increasing the batch size further.

In [None]:
prompts = [f'<s>[INST] What is the history of {country}? [/INST]' for country in ['Spain']]
max_new_tokens = 25
# Batch sizes of 2048 and 4096 ran out of memory
for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
    new_pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, batch_size=batch_size)
    start = time.time()
    output = new_pipe(prompts*batch_size, do_sample=False, return_full_text=False, pad_token_id=tokenizer.eos_token_id, max_new_tokens=max_new_tokens)
    elapsed = time.time() - start
    # The throughput estimate assumes every sequence generates the full max_new_tokens
    print(f'Batch size: {batch_size}\tExecution time: {elapsed:.1f} s, tokens per second: {max_new_tokens*batch_size/elapsed:.1f}')

```
Batch size: 1	Execution time: 3.1 s, tokens per second: 8.1
Batch size: 2	Execution time: 4.6 s, tokens per second: 10.8
Batch size: 4	Execution time: 4.6 s, tokens per second: 21.9
Batch size: 8	Execution time: 4.6 s, tokens per second: 43.3
Batch size: 16	Execution time: 4.7 s, tokens per second: 85.8
Batch size: 32	Execution time: 4.9 s, tokens per second: 164.0
Batch size: 64	Execution time: 5.1 s, tokens per second: 315.1
Batch size: 128	Execution time: 6.3 s, tokens per second: 511.9
Batch size: 256	Execution time: 8.4 s, tokens per second: 760.3
Batch size: 512	Execution time: 13.3 s, tokens per second: 963.9
Batch size: 1024	Execution time: 23.7 s, tokens per second: 1080.6
```
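The log can be turned into speedups relative to batch size 1. The sketch below recomputes throughput from the logged execution times. Those times are rounded to 0.1 s, so the values differ slightly from the log.

```python
# Execution times copied from the log above, keyed by batch size.
# Each run generated max_new_tokens=25 tokens per prompt.
MAX_NEW_TOKENS = 25
execution_time_s = {
    1: 3.1, 2: 4.6, 4: 4.6, 8: 4.6, 16: 4.7, 32: 4.9,
    64: 5.1, 128: 6.3, 256: 8.4, 512: 13.3, 1024: 23.7,
}

baseline = MAX_NEW_TOKENS / execution_time_s[1]  # tokens/s at batch size 1
speedup = {}
for batch_size, seconds in execution_time_s.items():
    throughput = MAX_NEW_TOKENS * batch_size / seconds
    speedup[batch_size] = throughput / baseline
    print(f'Batch size: {batch_size:5d}\tthroughput: {throughput:7.1f} tokens/s\tspeedup: {speedup[batch_size]:6.1f}x')
```

The trade-off is latency. One batch takes 23.7 s at batch size 1024 versus 3.1 s at batch size 1. Large batches therefore pay off for processing many prompts offline rather than for interactive chatting.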

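These are huge speedups. Throughput grows from 8.1 tokens per second at batch size 1 to 1080.6 tokens per second at batch size 1024, so batching the predictions could make inference more than 100 times faster.
With big batch sizes the GPUs alternate at 100% usage, presumably because `device_map="auto"` splits the layers across GPUs and they run one after another.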

## TODO

- [x] What is the message ``Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.``?
- [x] Try using `tokenizer.apply_chat_template()`, what is the correct input format?
- [x] Can I use batches to speed up generation? GPU usage is around 13% when generating data.
- [ ] Maybe in another notebook: set up a pipeline to evaluate different prompts. This is the way to do prompt engineering: try a prompt, evaluate, iterate.