## Using Mistral-7B with MAX Engine 🏎️ on CPU

**In this notebook we will walk through an example of using the [Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) model with MAX Engine 🏎️ on CPU in float32. Check out the [roadmap](https://docs.modular.com/max/roadmap) for quantization and GPU support.**

The Mistral-7B-v0.1 Large Language Model (LLM) is a pretrained generative text model with 7 billion parameters. Generative text models generate the next token iteratively given a sequence of past tokens representing the input prompt plus already generated response tokens.

Thus the underlying transformer model is invoked in each iteration of this loop until we reach the stopping condition (either the maximum number of generated tokens or a token designated as the end).
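
The loop described above can be sketched without any real model. In the minimal sketch below, `toy_next_token_logits` is a hypothetical stand-in for the transformer's forward pass, and `EOS_ID` and `VOCAB_SIZE` are made-up values chosen only for illustration.

```python
from typing import Callable, List

EOS_ID = 0      # hypothetical end-of-sequence token id
VOCAB_SIZE = 5  # hypothetical tiny vocabulary


def toy_next_token_logits(seq: List[int]) -> List[float]:
    """Hypothetical stand-in for a model's forward pass.

    Scores token (last + 1) % VOCAB_SIZE highest, so the sequence counts
    upward and eventually wraps around to EOS_ID.
    """
    target = (seq[-1] + 1) % VOCAB_SIZE
    return [1.0 if tok == target else 0.0 for tok in range(VOCAB_SIZE)]


def greedy_generate(
    prompt: List[int],
    next_token_logits: Callable[[List[int]], List[float]],
    max_new_tokens: int,
) -> List[int]:
    """Generate tokens one at a time until EOS or the token budget is hit."""
    seq = list(prompt)
    for _ in range(max_new_tokens):
        # One model invocation per generated token
        logits = next_token_logits(seq)
        # Greedy decoding: pick the highest-scoring token
        next_token = max(range(len(logits)), key=logits.__getitem__)
        seq.append(next_token)
        # Stopping condition 1: the designated end token
        if next_token == EOS_ID:
            break
    # Stopping condition 2 is the loop bound (max_new_tokens)
    return seq[len(prompt):]


print(greedy_generate([1, 2], toy_next_token_logits, max_new_tokens=10))  # stops at EOS
print(greedy_generate([1, 2], toy_next_token_logits, max_new_tokens=2))   # stops at budget
```

The first call ends because the end token is produced; the second ends because the token budget runs out first.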

**Caveat: The model size is 28 GB. Please make sure you have enough disk space for both the downloaded model and its converted ONNX counterpart, as we will use both later in this tutorial.**
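
The 28 GB figure follows from storing roughly 7 billion parameters as 4-byte float32 values. Below is a minimal sketch of that arithmetic, plus a free-space check using the standard library. The helper names and the factor of two (original weights plus the ONNX copy) are assumptions made for illustration, not something the tooling enforces.

```python
import shutil

BYTES_PER_FLOAT32 = 4


def model_size_gb(n_params: float, bytes_per_param: int = BYTES_PER_FLOAT32) -> float:
    """Approximate size of the raw weights in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


def has_enough_space(path: str, required_gb: float) -> bool:
    """Check whether the filesystem containing `path` has enough free space."""
    free_gb = shutil.disk_usage(path).free / 1e9
    return free_gb >= required_gb


weights_gb = model_size_gb(7e9)  # nominal 7 billion parameters
required_gb = 2 * weights_gb     # assumption: downloaded weights + ONNX copy
print(f"Estimated weights: {weights_gb:.0f} GB, suggested budget: {required_gb:.0f} GB")
print("Enough free space here:", has_enough_space(".", required_gb))
```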

First, make sure you've installed the `PyTorch` and `transformers` packages:

In [None]:
!python3 -m pip install -q torch --index-url https://download.pytorch.org/whl/cpu
!python3 -m pip install -q transformers onnx

## Vanilla transformers

Let's first see how the model generates a response using the vanilla `transformers` library.

**Note**: The model is about 28 GB, so downloading it for the first time can take a while. Also make sure you have enough disk space.

In [None]:
%%time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

PROMPT = "Why did the chicken cross the road?"

hf_path = "mistralai/Mistral-7B-v0.1"
hfmodel = AutoModelForCausalLM.from_pretrained(hf_path)
hftokenizer = AutoTokenizer.from_pretrained(hf_path)
hftokenizer.pad_token = hftokenizer.eos_token

# Tokenize the text prompt
input_ids = hftokenizer(PROMPT, return_tensors="pt", max_length=128, truncation=True).input_ids

# Run generation
out_ids = hfmodel.generate(input_ids=input_ids, max_new_tokens=15, do_sample=False)

# De-tokenize the generated response
response = hftokenizer.batch_decode(out_ids.numpy(), skip_special_tokens=True)[0][len(PROMPT):]
print("Response:", response)

## Next token generation

Now that we see that the model works, let's decompose its `hfmodel.generate` method, because we will encounter the same pattern later. We should be able to get the same output as before using only the model's `forward` method.

The code below is a simplified version of the actual loop you can find in the `transformers` source code. It starts by initializing the current sequence to the given prompt and then generates 10 subsequent tokens; these tokens constitute the response of the model. Since the earlier `generate` call asked for 15 new tokens, expect the same response cut short after 10.

In [None]:
from time import time
from transformers.generation.logits_process import LogitsProcessorList

logits_processor = LogitsProcessorList()

time_start = time()
current_seq = input_ids
N_TOKENS = 10
for idx in range(N_TOKENS):
    # Run model's `forward` on the current sequence.
    # 'logits' output would let us determine the next token for this sequence
    outputs = hfmodel(current_seq, return_dict=True).logits

    # Get the newly generated next token
    next_token_logits = outputs[:, -1, :]
    next_tokens_scores = logits_processor(current_seq, next_token_logits)
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)

    print(hftokenizer.decode(next_tokens), end=' ', flush=True)

    # Append the new token to our sequence
    current_seq = torch.cat([current_seq, next_tokens[:, None]], dim=-1)

time_finish = time()
print(f"Prompt: {PROMPT}")
print("Response:", hftokenizer.batch_decode(current_seq.numpy(), skip_special_tokens=True)[0][len(PROMPT):])
print(f"Tokens per second: {N_TOKENS / (time_finish - time_start):.2f}\n")
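
Because the same tokens-per-second figure is computed again later for comparison, the measurement can be factored into a small helper. This is only a sketch: `tokens_per_second` and the dummy `fake_step` workload are hypothetical names, and `perf_counter` is used here in place of `time` because it is a monotonic, high-resolution clock.

```python
from time import perf_counter
from typing import Callable


def tokens_per_second(step: Callable[[], None], n_tokens: int) -> float:
    """Call `step` once per token and return the measured throughput."""
    start = perf_counter()
    for _ in range(n_tokens):
        step()
    # Guard against a zero reading for extremely fast steps
    elapsed = max(perf_counter() - start, 1e-9)
    return n_tokens / elapsed


calls = []


def fake_step() -> None:
    """Hypothetical stand-in for one forward pass plus token selection."""
    calls.append(sum(range(10_000)))


rate = tokens_per_second(fake_step, n_tokens=10)
print(f"Tokens per second: {rate:.2f}")
```

Dividing by the number of generated tokens (not by the last loop index) keeps the figure comparable between runs.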

## Convert to ONNX with optimum

Great! We were able to reproduce the same response using only the `forward` method of our model. We're now ready to move to MAX Engine inference.
To do that, we start by getting an ONNX version of the model (more specifically, of its `forward` method).

The easiest way to do this is with the HuggingFace `optimum` tool, which you can install as follows:

In [None]:
!python3 -m pip install -q optimum

Then run the conversion to ONNX. This part can take a while. Also, please make sure you have enough disk space.

In [None]:
!optimum-cli export onnx --model "mistralai/Mistral-7B-v0.1" "./onnx/mistral-7b-onnx"

## Examine the ONNX model

Now let's examine the ONNX model. Note that `onnx.load` can take a while too.

In [None]:
%%time
import onnx

onnxmodel = onnx.load("./onnx/mistral-7b-onnx/model.onnx")

In [None]:
def print_dims(tensor):
    dims = []
    for dim in tensor.type.tensor_type.shape.dim:
        if dim.HasField("dim_value"):
            dims.append(str(dim.dim_value))
        elif dim.HasField("dim_param"):
            dims.append(str(dim.dim_param))
    print(onnx.TensorProto.DataType.Name(tensor.type.tensor_type.elem_type), end=" ")
    print("[", ", ".join(dims), "]")

print("=== Inputs ===")
for input_tensor in onnxmodel.graph.input:
    print(input_tensor.name, end=": ")
    print_dims(input_tensor)

print("\n=== Outputs ===")
for output_tensor in onnxmodel.graph.output:
    print(output_tensor.name, end=": ")
    print_dims(output_tensor)

It might be quite surprising to see so many inputs and outputs in the model!

To decode what they all mean and how they should be used we will need to look into the [documentation of the MistralModel](https://huggingface.co/docs/transformers/v4.38.2/en/model_doc/mistral#transformers.MistralModel).

In short, we have the following inputs:
* input_ids
* position_ids
* attention_mask
* past_key_values

And the outputs will be:
* logits
* present_key_value

**Note** that since ONNX doesn't support dictionaries as an input/output type, the `key_value` is expanded into 32 pairs of individual tensors, one key/value pair per decoder layer (32 is the number of hidden layers in our model, which happens to equal its number of attention heads).
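
To make the flattening concrete, the sketch below enumerates tensor names following the `past_key_values.{i}.key` / `present.{i}.key` pattern, and estimates how much float32 cache each token adds. The config values (32 layers, 8 key/value heads, head dimension 128) are those of Mistral-7B-v0.1; the helper names are hypothetical.

```python
from typing import List

N_LAYERS = 32          # num_hidden_layers in the Mistral-7B-v0.1 config
N_KV_HEADS = 8         # num_key_value_heads (grouped-query attention)
HEAD_DIM = 128         # hidden_size 4096 / 32 attention heads
BYTES_PER_FLOAT32 = 4


def kv_input_names(n_layers: int) -> List[str]:
    """Names of the flattened cache inputs, one key and one value per layer."""
    names = []
    for i in range(n_layers):
        names.append(f"past_key_values.{i}.key")
        names.append(f"past_key_values.{i}.value")
    return names


def kv_output_names(n_layers: int) -> List[str]:
    """Names of the flattened cache outputs, in the same per-layer order."""
    names = []
    for i in range(n_layers):
        names.append(f"present.{i}.key")
        names.append(f"present.{i}.value")
    return names


def cache_bytes_per_token(n_layers: int, n_kv_heads: int, head_dim: int,
                          bytes_per_elem: int = BYTES_PER_FLOAT32) -> int:
    """Bytes added to the cache per token: key + value for every layer."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


print(len(kv_input_names(N_LAYERS)), "cache inputs,",
      len(kv_output_names(N_LAYERS)), "cache outputs")
print(cache_bytes_per_token(N_LAYERS, N_KV_HEADS, HEAD_DIM) // 1024, "KiB per token")
```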

In order to use this model we will need to slightly modify our glue code to correctly weave all these values from each iteration to the next.
Specifically, we will need to pass the `key_values` from the previous iteration to the current one (for the first iteration they are initialized as empty tensors).
We will also need to correctly fill in the `position_ids` and `attention_mask` tensors and update them on each iteration. We will not get into all the details of how exactly these tensors affect the model's behavior and how they should be used. This is an extremely interesting topic, but it lies beyond the scope of this walkthrough.
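
As a rough illustration of that bookkeeping, the sketch below tracks only shapes and index values across iterations, using plain Python lists in place of tensors. The helper `step_inputs` and the token ids are hypothetical. It assumes the usual cache convention: only tokens not yet in the cache are fed to the model, `position_ids` continue from the cached length, and `attention_mask` covers cached plus new positions.

```python
from typing import Dict, List


def step_inputs(past_len: int, new_token_ids: List[int]) -> Dict[str, List[List[int]]]:
    """Build batch-size-1 inputs for one iteration, given the cached length."""
    new_len = len(new_token_ids)
    return {
        # Only the tokens that are not in the cache yet
        "input_ids": [list(new_token_ids)],
        # Positions continue where the cache ends
        "position_ids": [list(range(past_len, past_len + new_len))],
        # One entry per cached position plus one per new position
        "attention_mask": [[1] * (past_len + new_len)],
    }


prompt_ids = [1, 4222, 863]  # made-up token ids standing in for a prompt

# First iteration: empty cache, the whole prompt is new
past_len = 0
first = step_inputs(past_len, prompt_ids)
past_len += len(prompt_ids)

# Second iteration: the prompt is cached, only the generated token is new
second = step_inputs(past_len, [42])

print(first["position_ids"], len(first["attention_mask"][0]))
print(second["position_ids"], len(second["attention_mask"][0]))
```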

To free up some memory, we delete `onnxmodel`:

In [None]:
%xdel onnxmodel

## MAX Engine 🏎️

With that we're finally ready to use the MAX Engine 🏎️ for inference.

The code modifications (apart from the already described changes to how tensors are carried from iteration to iteration) are minimal. All we need to do is load the ONNX model (this can take a while) into an `InferenceSession` object, call `maxmodel.execute` instead of `hfmodel`, and pack the input values into a dictionary.

Make sure the converted ONNX model that we are going to use is ready 👇

In [None]:
%%time

from max import engine
# Create an InferenceSession and load the ONNX model
session = engine.InferenceSession()
maxmodel = session.load("./onnx/mistral-7b-onnx/model.onnx")

We can also quickly inspect the input and output metadata, which should match the ONNX version above:

In [None]:
for tensor in maxmodel.input_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

for tensor in maxmodel.output_metadata:
    print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

Here is how to get the response from our `maxmodel`. Measured in tokens per second, it is about **2X** faster than the PyTorch version.

In [None]:
inputs = {}
# The KV cache holds one key/value pair per decoder layer
N_LAYERS = 32
assert N_LAYERS == hfmodel.config.num_hidden_layers
# Initialize an empty KV cache (past sequence length 0) for the first iteration.
# Shape: [batch, num_key_value_heads, past_seq_len, head_dim]
for i in range(N_LAYERS):
    inputs[f"past_key_values.{i}.key"] = torch.zeros([1,8,0,128], dtype=torch.float).numpy()
    inputs[f"past_key_values.{i}.value"] = torch.zeros([1,8,0,128], dtype=torch.float).numpy()

current_seq = input_ids
# Tokens that are not in the cache yet: the whole prompt on the first iteration
new_tokens = input_ids

time_start = time()
for idx in range(N_TOKENS):
    # Prepare inputs dictionary. Only the tokens missing from the cache are fed,
    # positions continue from the cached length, and the mask covers both.
    past_len = inputs["past_key_values.0.key"].shape[2]
    new_len = new_tokens.shape[1]
    inputs["input_ids"] = new_tokens.numpy()
    inputs["position_ids"] = torch.arange(past_len, past_len + new_len, dtype=torch.long).unsqueeze(0).numpy()
    inputs["attention_mask"] = torch.ones([1, past_len + new_len], dtype=torch.int64).numpy()

    # Run the model with MAX engine
    max_outputs = maxmodel.execute(**inputs)
    outputs = torch.from_numpy(max_outputs["logits"])

    # Get the newly generated next token
    next_token_logits = outputs[:, -1, :]
    next_tokens_scores = logits_processor(current_seq, next_token_logits)
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)

    print(hftokenizer.decode(next_tokens), end=' ', flush=True)

    # Append the new token to our sequence; only this token is fed next time
    new_tokens = next_tokens[:, None]
    current_seq = torch.cat([current_seq, new_tokens], dim=-1)

    # Update the cache for the next iteration
    for i in range(N_LAYERS):
        inputs[f"past_key_values.{i}.key"] = max_outputs[f"present.{i}.key"]
        inputs[f"past_key_values.{i}.value"] = max_outputs[f"present.{i}.value"]

time_finish = time()

print(f"Prompt: {PROMPT}")
print("Response:", hftokenizer.batch_decode(current_seq.numpy(), skip_special_tokens=True)[0][len(PROMPT):])
# Divide by the number of generated tokens, not by the last loop index
print(f"Tokens per second: {N_TOKENS / (time_finish - time_start):.2f}\n")

That is it! 🎉

Serving an LLM has historically not been an easy task, but hopefully this example lifts the curtain on how it can be done. MAX Engine 🏎️ doesn't (yet) make this process easier; however, if you've already gone down this path with ONNX or TorchScript, switching to MAX should be trivial and bring easy performance wins.