## Impressions

[Mark Saroufim](https://twitter.com/marksaroufim?lang=en) and I spent an hour trying to load Mistral-7b into the Max Engine, to no avail.

Here are my initial impressions.

- The Modular team is super engaging and friendly!  I recommend going to [their Discord](https://discord.gg/e6kQTD4k) with questions.
- PyTorch feels like (and, as far as we can tell, is) a second-class citizen.
    - You have to do many more steps to potentially load a PyTorch model vs. a TensorFlow one, and those steps are not clear.[^1]
- I'm not sure why in 2024 you would lead with BERT/TensorFlow.  This is an odd choice, and it makes onboarding much less exciting for me.  I would like to see paved paths for modern LLMs like Llama or Mistral; the fact that there are none gives me pause.
- Model compilation and loading took about 6 minutes (372 seconds in my run).  Compilation needs a progress bar given that it takes so long.
- TorchScript as a serialization format is older and in maintenance mode compared to the more recent `torch.compile` or `torch.export`, but Max doesn't support those yet.  Discussion [is here](https://discord.com/channels/1087530497313357884/1212827597323509870/1212889053821796382).
- Printing the model is not informative: unlike printing a torch model, it doesn't show you all the layers and shapes.
- We couldn't quite understand the output of the model, and we eventually hypothesized that TorchScript is not the right serialization path for Mistral, but we aren't sure.  I think users will get quite confused by this.

I'm hoping that the above changes soon, as I'm pretty bullish on the talent level and team working on these things overall.

[^1]: The [documentation states](https://docs.modular.com/engine/python/get-started): _"This example uses is a TensorFlow model (which must be converted to SavedModel format), and it's just as easy to load a model from PyTorch (which must be converted to TorchScript format)."_ The **just as easy** part raised my expectations a bit too high; I discovered that the PyTorch path is _much_ more onerous.

## Attempting To Load Mistral In The Max Engine

Today, the Modular team released the [Max Inference Engine](https://docs.modular.com/engine/overview):

> MAX Engine is a next-generation compiler and runtime system for neural network graphs. It supercharges the execution of AI models in any format (including TensorFlow, PyTorch, and ONNX), on a wide variety of hardware. MAX Engine also allows you to extend these models with custom ops that MAX Engine can analyze and optimize with other ops in the graph.

[These docs](https://docs.modular.com/engine/python/get-started) show how to load a TensorFlow model, but I want to load a PyTorch LLM like Mistral-7b.  The Modular team [helped me figure out](https://discord.com/channels/1087530497313357884/1212827597323509870/1212844710817824848) how to do this.  I wrote down all the steps here.

### 1. Serialize Model as TorchScript

To load a PyTorch model in the Max Engine, we must first serialize it as TorchScript.  We can do this by tracing the model graph with `torch.jit.trace` and then saving the result with `torch.jit.save`.

In [None]:
#|warning: false
import time
from functools import partial
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# load model artifacts from the hub
hf_path="mistralai/Mistral-7B-v0.1"
model = AutoModelForCausalLM.from_pretrained(hf_path,torchscript=True)
tokenizer = AutoTokenizer.from_pretrained(hf_path)
tokenizer.pad_token = tokenizer.eos_token

# trace the model and save it with torchscript
max_seq_len=128
model_path="mistral.pt"
text = "This is text to be used for tracing"
# I'm setting the arguments for tokenizer once so I can reuse it later (they need to be consistent)
max_tokenizer = partial(tokenizer, return_tensors="pt",padding="max_length", max_length=max_seq_len)
inputs = max_tokenizer(text)
traced_model = torch.jit.trace(model, [inputs['input_ids'], inputs['attention_mask']])
torch.jit.save(traced_model, model_path)


  if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
  if past_key_values_length > 0:
  if query_length > 1 and not is_tracing:
  if seq_len > self.max_seq_len_cached:
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
Tensor-likes are not close!

Mismatched elements: 3807957 / 4096000 (93.0%)
Greatest absolute difference: 8.579869270324707 at index (0, 118, 17554) (up to 1e-05 allowed)
Greatest relative difference: 1273150.52 at index (0, 73, 19823) (up to 1e-05 allowed)
  _check_trace(
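
As a rough sanity check on the numbers in the trace warning above, the element count is what you would expect if the traced model returns one row of logits per input position. The sketch below assumes a batch size of 1 and a vocabulary size of 32,000 for Mistral-7B; the vocabulary size is my assumption and is not stated in the output above.

```python
# Assumed shape of the logits tensor: (batch, sequence length, vocabulary size).
batch_size = 1        # one traced example
max_seq_len = 128     # same value used for tracing above
vocab_size = 32_000   # assumption: vocabulary size of Mistral-7B-v0.1

total_elements = batch_size * max_seq_len * vocab_size
mismatched = 3_807_957  # copied from the warning above

print(total_elements)                        # 4096000
print(f"{mismatched / total_elements:.1%}")  # 93.0%

# The indices reported in the warning should fall inside that shape.
shape = (batch_size, max_seq_len, vocab_size)
for index in [(0, 118, 17554), (0, 73, 19823)]:
    assert all(i < dim for i, dim in zip(index, shape))
```

If that reading is right, the traced graph and the eager model disagreed on about 93% of the logits during the tracer's own check, which hints that the trouble may already start at tracing time. We did not confirm this.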


### 2. Specify Input Shape

Compilation requires a fixed input shape.

This next bit is adapted from [this code](https://github.com/modularml/max/blob/main/examples/inference/bert-python-torchscript/simple-inference.py#L37-L40).  Apparently there is a way to specify dynamic values for the sequence length and batch size, but we couldn't easily figure that out from the docs.

In [None]:
from max import engine
input_spec_list = [
    engine.TorchInputSpec(shape=tensor.size(), dtype=engine.DType.int64)
    for tensor in inputs.values()
]
options = engine.TorchLoadOptions(input_spec_list)
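
To illustrate why padding to a maximum length yields a fixed shape, here is a minimal pure-Python sketch. `pad_to_max_length` is a hypothetical helper written for this post, not part of `transformers` or Max; it only mimics what `padding="max_length"` does to `input_ids` and `attention_mask`. Which side the real tokenizer pads on depends on its configuration, so the `side` argument is an assumption too.

```python
def pad_to_max_length(token_ids, max_length, pad_id, side="right"):
    """Return (input_ids, attention_mask), each exactly max_length long."""
    if len(token_ids) > max_length:
        raise ValueError("prompt is longer than max_length; truncate it first")
    n_pad = max_length - len(token_ids)
    ids_pad = [pad_id] * n_pad    # filler tokens
    mask_pad = [0] * n_pad        # 0 = position should be ignored
    mask_real = [1] * len(token_ids)
    if side == "left":
        return ids_pad + list(token_ids), mask_pad + mask_real
    return list(token_ids) + ids_pad, mask_real + mask_pad


# Two prompts of different lengths end up with the same shape.
short_ids, short_mask = pad_to_max_length([1, 5, 9], max_length=128, pad_id=2)
long_ids, long_mask = pad_to_max_length(list(range(1, 41)), max_length=128, pad_id=2)
print(len(short_ids), len(long_ids))  # 128 128
```

Because every prompt comes out with the same length, the shape captured from the tracing inputs also describes later inference inputs, as long as the same `max_length` is reused.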

### 3. Compile and Load Model

In [None]:
#|warning: false
start_time = time.time()
session = engine.InferenceSession()
model = session.load(model_path, options)

end_time = time.time()

Compiling model..   [2349485:2367566:20240229,165403.360772:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2367565:20240229,165403.381422:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
Compiling model..   [2349485:2373801:20240229,165656.931613:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2373802:20240229,165656.933624:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2373803:20240229,165656.933650:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2373804:20240229,165656.933685:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2373805:20240229,165656.933712:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2373806:20240229,165656.933731:ERROR crashpad_client_linux.cc:632] sigaltstack: Cannot allocate memory (12)
[2349485:2373807:2024022

Wow! The model takes over 6 minutes (about 372 seconds, per the timing below) to compile and load.  Subsequent compilations are faster, but NOT if I restart the Jupyter kernel.

In [None]:
elapsed_time = end_time - start_time
print(f"Elapsed time: {elapsed_time} seconds")

Elapsed time: 371.86387610435486 seconds
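
Until compilation reports its own progress, a user-side workaround could look like the sketch below: a context manager that prints a heartbeat from a background thread while a blocking call runs. `heartbeat` and `format_elapsed` are hypothetical helpers, not part of Max, and the heartbeat only fires if the blocking call releases Python's GIL, which I have not verified for `session.load`.

```python
import threading
import time
from contextlib import contextmanager


@contextmanager
def heartbeat(label, interval=10.0, out=print):
    """Report elapsed time every `interval` seconds while the body runs."""
    start = time.perf_counter()
    done = threading.Event()

    def tick():
        # Event.wait returns False on timeout and True once `done` is set.
        while not done.wait(interval):
            out(f"{label}: {time.perf_counter() - start:.0f}s elapsed")

    worker = threading.Thread(target=tick, daemon=True)
    worker.start()
    try:
        yield
    finally:
        done.set()
        worker.join()
        out(f"{label}: finished in {time.perf_counter() - start:.1f}s")


def format_elapsed(seconds):
    """Render a duration in seconds as minutes and seconds."""
    minutes, rest = divmod(seconds, 60)
    return f"{int(minutes)}m {rest:.1f}s"


# Stand-in for a slow call; with Max this would wrap `session.load(...)`.
with heartbeat("Compiling model", interval=0.02):
    time.sleep(0.1)

print(format_elapsed(371.86387610435486))  # 6m 11.9s
```

This does not show real progress, only that the process is still alive, but that alone would have been reassuring during a six-minute wait.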


### 4. Inference

:::{.callout-note}
#### We failed to get this to work
Even though we could call `model.execute`, the outputs we got didn't make much sense to us, even after some investigation.  Our hypothesis is that `execute` is not calling `model.generate`, but this is where we gave up.

:::

Be sure to set `return_token_type_ids=False`.  Note that I'm using the same `padding` and `max_length` arguments that I used for tracing the model (via the `max_tokenizer` I defined earlier), so the input shape is consistent.

In [None]:
INPUT="Why did the chicken cross the road?"
inp = max_tokenizer(INPUT, return_token_type_ids=False)
out = model.execute(**inp)

Get the token ids (predictions) and decode them:

In [None]:
preds = out['result0'].argmax(axis=-1)

We tried to debug this but could not figure out what was wrong, so we gave up here. We aren't sure why the output looks like this.  See what the output is supposed to look like in [this section](#huggingface-comparison).

(Scroll to the right to see the full output)

In [None]:
#| code-overflow: wrap
' '.join(tokenizer.batch_decode(preds, skip_special_tokens=False))

'ммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммммм # do you chicken cross the road?\n'

We were intrigued by the `M`, and Mark joked that it is some kind of Illuminati secret code injected into the model (i.e. `M` for `Modular`), which I thought was funny :)

Our theory is that TorchScript is not the right way to serialize this model and that this is some kind of silent failure, but it is hard to know.
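
One reading that seems consistent with the output above is that `execute` runs a single forward pass, which for a causal language model yields one next-token prediction per input position rather than a generated continuation. The toy sketch below illustrates the difference. `forward`, `per_position_predictions`, and `greedy_generate` are hypothetical stand-ins written for this example, not Max or `transformers` APIs, and the "model" is a trivial rule (next token = current token + 1).

```python
VOCAB_SIZE = 8


def forward(token_ids):
    """Toy causal LM: return one row of logits per input position."""
    logits = []
    for tok in token_ids:
        row = [0.0] * VOCAB_SIZE
        row[(tok + 1) % VOCAB_SIZE] = 1.0  # toy rule: next token is tok + 1
        logits.append(row)
    return logits


def argmax(row):
    """Index of the largest value in a list."""
    return max(range(len(row)), key=row.__getitem__)


def per_position_predictions(token_ids):
    """What argmax over a single forward pass gives: one guess per position."""
    return [argmax(row) for row in forward(token_ids)]


def greedy_generate(token_ids, max_new_tokens):
    """Greedy decoding: feed each predicted token back into the model."""
    ids = list(token_ids)
    for _ in range(max_new_tokens):
        next_id = argmax(forward(ids)[-1])  # only the last position is used
        ids.append(next_id)
    return ids[len(token_ids):]


prompt = [2, 3, 4]
print(per_position_predictions(prompt))  # [3, 4, 5]
print(greedy_generate(prompt, 3))        # [5, 6, 7]
```

In the toy, taking the argmax at every position just echoes the prompt shifted by one token, which resembles the `do you chicken cross the road?` fragment above. Producing new text needs a loop that feeds each predicted token back in, which is roughly what `generate` does in `transformers`. We did not verify how this would be done with Max.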

## HuggingFace Comparison

As a sanity check, let's do inference with `transformers`:

In [None]:
#|warning: false
from transformers import AutoTokenizer, AutoModelForCausalLM
hf_path="mistralai/Mistral-7B-v0.1"
hfmodel = AutoModelForCausalLM.from_pretrained(hf_path,torchscript=True).cuda()
hftokenizer = AutoTokenizer.from_pretrained(hf_path)
hftokenizer.pad_token = hftokenizer.eos_token

_p="Why did the chicken cross the road?"
input_ids = hftokenizer(_p, return_tensors="pt", 
                      padding="max_length",
                      truncation=True).input_ids.cuda()
out_ids = hfmodel.generate(input_ids=input_ids, max_new_tokens=15, 
                          do_sample=False)

Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


In [None]:
out = hftokenizer.batch_decode(out_ids.detach().cpu().numpy(), 
                                  skip_special_tokens=True)[0][len(_p):]
print(out)



To get to the other side.

Why did the chicken
