## WIP: Mistral-7B

In [None]:
!python3 -m pip install -q transformers optimum[onnxruntime]
!python3 -m pip install -q torch --index-url https://download.pytorch.org/whl/cpu

Do the conversion once. It takes around 15 min to complete on a c5.12xlarge instance.
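Because the export is slow, it helps to skip it when the output already exists. Below is a minimal sketch. It assumes the exported directory contains `model.onnx`, which is the file loaded further down. `ensure_onnx_export` is a hypothetical helper that shells out to the same `optimum-cli` command. Its `run` parameter is injectable, so the skip logic can be checked without a real export.

```python
import subprocess
from pathlib import Path


def ensure_onnx_export(model_id, out_dir, run=subprocess.run):
    """Run `optimum-cli export onnx` only if out_dir has no model.onnx yet.

    Hypothetical helper. Returns True if the export command was launched,
    False if it was skipped because the export already exists.
    """
    out_path = Path(out_dir)
    if (out_path / "model.onnx").exists():
        return False  # already converted: skip the slow export
    cmd = ["optimum-cli", "export", "onnx", "--model", model_id, str(out_path)]
    run(cmd, check=True)  # raises CalledProcessError if the export fails
    return True
```

Usage: `ensure_onnx_export("mistralai/Mistral-7B-v0.1", "./onnx/mistral-7b-onnx")`.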

In [None]:
!optimum-cli export onnx --model "mistralai/Mistral-7B-v0.1" "./onnx/mistral-7b-onnx"

In [20]:
from functools import partial
from transformers import AutoTokenizer
from max import engine

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
tokenizer.pad_token = tokenizer.eos_token

max_seq_len = 128
INPUT = "Why did the chicken cross the road?"
inputs = tokenizer(INPUT, return_tensors="pt", return_token_type_ids=False, max_length=max_seq_len, truncation=True)
print(f"inputs: {inputs}")
# max_tokenizer = partial(tokenizer, return_tensors="pt", padding="max_length", max_length=max_seq_len, truncation=True)
# inputs = max_tokenizer(INPUT)
input_spec_list = [
    engine.TorchInputSpec(shape=tensor.size(), dtype=engine.DType.int64)
    for tensor in inputs.values()
]
options = engine.TorchLoadOptions(input_spec_list)
print(f"options: {options}")



inputs: {'input_ids': tensor([[    1,  4315,   863,   272, 13088,  3893,   272,  3878, 28804]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}
options: TorchLoadOptions(input_specs=[<max.engine.api.TorchInputSpec object at 0x7f23559a5720>, <max.engine.api.TorchInputSpec object at 0x7f23555b97e0>], type='torch')
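The input specs above take their shapes from this particular prompt (`[1, 9]`). The commented-out `max_tokenizer` variant instead pads every prompt to `max_seq_len`, so the shapes would always be `[1, 128]`. The pure-Python sketch below shows what that padding does to `input_ids` and `attention_mask`. `pad_to_length` is hypothetical. The pad id `2` is an assumption: it is Mistral's `</s>` id, which matches `pad_token = eos_token`. The side the real tokenizer pads on is governed by its `padding_side` setting.

```python
def pad_to_length(input_ids, max_len, pad_id, side="right"):
    """Truncate/pad one sequence of token ids to exactly max_len (hypothetical helper).

    Returns (ids, mask): mask is 1 for real tokens and 0 for padding, mirroring
    the attention_mask a tokenizer returns with padding="max_length".
    """
    ids = list(input_ids)[:max_len]  # truncation: keep the first max_len tokens
    n_pad = max_len - len(ids)
    pad, pad_mask = [pad_id] * n_pad, [0] * n_pad
    mask = [1] * len(ids)
    if side == "right":
        return ids + pad, mask + pad_mask
    if side == "left":
        return pad + ids, pad_mask + mask
    raise ValueError(f"unknown side: {side!r}")
```

For example, `pad_to_length([1, 4315, 863, 272, 13088, 3893, 272, 3878, 28804], 12, pad_id=2)` returns the nine ids followed by three `2`s, with a mask of nine 1s and three 0s.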


The first compilation takes around 10 min on a c5.12xlarge; subsequent compilations are faster (about 2-5 min).

In [15]:
%%time
session = engine.InferenceSession()
model = session.load("./onnx/mistral-7b-onnx/model.onnx", options)

CPU times: user 22.7 s, sys: 29.9 s, total: 52.6 s
Wall time: 2min 48s


Compiling model.    
Done!


In [None]:
# inspect
# for tensor in model.input_metadata:
#     print(f'name: {tensor.name}, shape: {tensor.shape}, dtype: {tensor.dtype}')

In [21]:
import torch

inputs['position_ids'] = torch.arange(inputs['input_ids'].size(1), dtype=torch.long, device=inputs['input_ids'].device).unsqueeze(0)
# TODO: remove the hardcoded 32
for i in range(32):
    inputs[f"past_key_values.{i}.key"] = torch.zeros([1,8,0,128], dtype=torch.float)
    inputs[f"past_key_values.{i}.value"] = torch.zeros([1,8,0,128], dtype=torch.float)
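The hardcoded `32`, `8`, and `128` are Mistral-7B's layer count, its number of key/value heads (grouped-query attention), and its per-head dimension. The sketch below derives them from the model config instead. It assumes the `transformers` config attribute names `num_hidden_layers`, `num_key_value_heads`, `hidden_size`, and `num_attention_heads`. `empty_decoder_inputs` is a hypothetical helper that takes a plain dict, for example `AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1").to_dict()`, so it runs without a download. It returns NumPy arrays; wrap them with `torch.from_numpy` to match the cell above.

```python
import numpy as np

# Values from the published Mistral-7B-v0.1 config.
mistral_7b = {
    "num_hidden_layers": 32,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "hidden_size": 4096,
}


def empty_decoder_inputs(config, input_ids):
    """Build position_ids and an empty KV cache for the first forward pass.

    `config` uses transformers config key names; `input_ids` has shape [batch, seq_len].
    """
    batch, seq_len = input_ids.shape
    num_layers = config["num_hidden_layers"]
    num_kv_heads = config["num_key_value_heads"]
    head_dim = config["hidden_size"] // config["num_attention_heads"]  # 4096 // 32 = 128

    feeds = {"position_ids": np.tile(np.arange(seq_len, dtype=np.int64), (batch, 1))}
    # Past length 0: nothing is cached before the first step.
    empty = np.zeros((batch, num_kv_heads, 0, head_dim), dtype=np.float32)
    for i in range(num_layers):
        feeds[f"past_key_values.{i}.key"] = empty.copy()
        feeds[f"past_key_values.{i}.value"] = empty.copy()
    return feeds
```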

In [22]:
%%time

output = model.execute(**inputs)

CPU times: user 7.52 s, sys: 44 ms, total: 7.56 s
Wall time: 366 ms


In [23]:
%%time

logits = torch.tensor(output["logits"], dtype=torch.float)
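# argmax at every position: token t is the prediction for position t + 1,
# so only the last one extends the prompt
generated_token_ids = torch.argmax(logits, dim=-1)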
generated_text = tokenizer.decode(generated_token_ids.squeeze().tolist(), skip_special_tokens=True)
print("First iteration generated text:", generated_text)

First iteration generated text: # do you chicken cross the road?

CPU times: user 13.7 ms, sys: 0 ns, total: 13.7 ms
Wall time: 1.33 ms
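Because the argmax is taken over all nine positions, the decoded string is the model's next-token guess after each prefix of the prompt, not a continuation of it. Standard KV-cache decoding keeps only the last position's prediction. It then feeds just that one token on the next call, with `position_ids` equal to the number of cached tokens and an attention mask covering the cache plus the new token.

The sketch below shows that loop. It assumes the input/output names used in this notebook (`past_key_values.{i}.*`, `present.{i}.*`, `logits`). `execute` stands in for `model.execute`, and `greedy_decode` is hypothetical. The follow-up feeds are NumPy arrays; convert them with `torch.from_numpy` if the engine requires tensors.

```python
import numpy as np


def greedy_decode(execute, feeds, num_layers, max_new_tokens, eos_id=None):
    """Greedy decoding with a KV cache (batch size 1 assumed).

    `execute(**feeds)` must return "logits" [1, seq, vocab] and
    "present.{i}.key"/"present.{i}.value" [1, kv_heads, past + seq, head_dim].
    `feeds` are the first-step inputs (full prompt, empty cache).
    Returns the newly generated token ids.
    """
    new_tokens = []
    for _ in range(max_new_tokens):
        output = execute(**feeds)
        # Only the last position predicts a token not already in the input.
        next_id = int(np.argmax(output["logits"][0, -1]))
        new_tokens.append(next_id)
        if eos_id is not None and next_id == eos_id:
            break
        past_len = output["present.0.key"].shape[2]  # tokens now in the cache
        feeds = {
            "input_ids": np.array([[next_id]], dtype=np.int64),
            # The new token sits right after everything cached so far.
            "position_ids": np.array([[past_len]], dtype=np.int64),
            # Attend to all cached tokens plus the new one.
            "attention_mask": np.ones((1, past_len + 1), dtype=np.int64),
        }
        for i in range(num_layers):
            feeds[f"past_key_values.{i}.key"] = output[f"present.{i}.key"]
            feeds[f"past_key_values.{i}.value"] = output[f"present.{i}.value"]
    return new_tokens
```

The full text is then `tokenizer.decode(prompt_ids + new_tokens, skip_special_tokens=True)`. Feeding one token per step keeps the cache, the positions, and the mask consistent with each other.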


In [24]:
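# TODO: cleanup after examination
# NOTE: each step re-tokenizes the whole decoded text but also feeds the previous
# step's cache, and position_ids restart at 0, so the inputs and past_key_values
# are out of sync; this likely explains the garbled output below.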
for j in range(20):
    inputs = tokenizer(generated_text, return_tensors="pt", return_token_type_ids=False)
    inputs['position_ids'] = torch.arange(inputs['input_ids'].size(1), dtype=torch.long, device=inputs['input_ids'].device).unsqueeze(0)
    for i in range(32):
        inputs[f"past_key_values.{i}.key"] = torch.tensor(output[f"present.{i}.key"], dtype=torch.float)
        inputs[f"past_key_values.{i}.value"] = torch.tensor(output[f"present.{i}.value"], dtype=torch.float)

    inputs["attention_mask"] = torch.ones([1, inputs[f"past_key_values.0.key"].size(2) + inputs["input_ids"].size(1)], dtype=torch.int64)
    # print(inputs["attention_mask"].shape)
    output = model.execute(**inputs)
    logits = torch.tensor(output["logits"], dtype=torch.float)
    generated_token_ids = torch.argmax(logits, dim=-1)
    generated_text = tokenizer.decode(generated_token_ids.squeeze().tolist(), skip_special_tokens=True)
    print("=" * 80)
    print(f"Generated text iter {j}: {generated_text}")


Generated text iter 0: #1 we cross cross the road?


Generated text iter 1: #1? cross the the road?

To
Generated text iter 2: #1:
 the road road?

To get
Generated text iter 3: #1:

 road??

To get to
Generated text iter 4: #1: To

??

 get to the
Generated text iter 5: #1: To get

?

 to the other
Generated text iter 6: #1: To get to



To the other side
Generated text iter 7: #1: To get to the


To the other side.
Generated text iter 8: #1: To get to the road
ToWhy the other side.

Generated text iter 9: #1: To get to the other?
 get did other side.


Generated text iter 10: #1: To get to the other side
Why tod side.

Why
Generated text iter 11: #1: To get to the other side.Why didays.


 did
Generated text iter 12: #1: To get to the other side.
 did theide


Why the
Generated text iter 13: #1: To get to the other side.

 the other.
WhyWhy did chicken
Generated text iter 14: #1: To get to the other side.


 other side

 did did the #
Generated text iter 15: #1: To get to the other s