Issue with converting to ONNX format for LLM models #95

Closed
ratan opened this issue Jan 30, 2024 · 3 comments
Labels: question (Further information is requested)

Comments


ratan commented Jan 30, 2024

Attachment: crystal_chat_logs.txt
I have tried to convert a few models to ONNX format and I am facing the issues below:

  1. I am trying to convert to ONNX format and the script keeps running; please see the attached log file. It keeps creating new folders and keeps running for different input shapes. Is this expected?
  2. I tried https://huggingface.co/LLM360/CrystalChat and https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2 and found the same issue with both.
  3. Can you share some sample scripts where the above models work?
@jeremyfowers
Collaborator

Hi @ratan, thank you for trying out turnkey! We currently have a limitation with respect to auto-regressive models: when you call model.generate(prompt), the model is invoked many times, each with a different input size. In turn, this triggers many builds, and therefore many folders with one ONNX file each.

I took the Mistral-7B-Instruct example and replaced the model.generate() call with a model() call, which generates a single token. This worked as expected and generated a single ONNX file. Here is my example code:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

# Tokenize the chat prompt into input IDs
encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")

# Single forward pass (instead of model.generate()), so only one input shape is seen
result = model(encodeds)

Here is an additional example from turnkey's model corpus showing how to get a single ONNX file with a desired input shape: https://github.com/onnx/turnkeyml/blob/main/models/transformers/mistral_7b.py
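
For reference, the corpus-style approach boils down to instantiating the model and calling it exactly once with dummy inputs of the shape you want baked into the ONNX file. Here is a minimal sketch in that spirit (the batch size, sequence length, and the use of a from-config model with random weights are illustrative choices on my part, not details copied from the linked script):

import torch
import transformers

# Illustrative shape: pick whatever you want the exported ONNX inputs to be
batch_size, seq_length = 1, 128

# Build the model from its config (random weights), which avoids downloading the checkpoint
config = transformers.AutoConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model = transformers.AutoModelForCausalLM.from_config(config)

# Dummy inputs with the desired, fixed shape
inputs = {
    "input_ids": torch.ones(batch_size, seq_length, dtype=torch.long),
    "attention_mask": torch.ones(batch_size, seq_length, dtype=torch.long),
}

# A single forward pass, so turnkey discovers exactly one input shape and builds one ONNX file
model(**inputs)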

@ratan does this workaround work for your use case?

cc @danielholanda for visibility

jeremyfowers added the question (Further information is requested) label on Jan 30, 2024
Author

ratan commented Feb 1, 2024

Hi @jeremyfowers, thanks for the quick reply. I tried the suggestions you mentioned and they worked!

By replacing the model.generate() call with a model() call, I am able to generate a single ONNX file.

The sample below works for Mistral-7B-Instruct-v0.2:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# This is the minimum setup required to convert an LLM (PyTorch) into ONNX format
# The converted ONNX file can be viewed in the Netron app

# Load the tokenizer and model
torch.set_default_device("cpu")
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

text_prompt = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

# Tokenize the prompt and convert to PyTorch tensors
inputs = tokenizer.apply_chat_template(text_prompt, return_tensors="pt")

# Run a single forward pass with model() instead of model.generate(),
# so the auto-regressive model is invoked only once (one build, one ONNX file)
# https://github.com/onnx/turnkeyml/issues/95

outputs = model(inputs)
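
For completeness, the same single forward call can also be handed directly to torch.onnx.export if you want an ONNX file without going through turnkey. Below is a minimal sketch continuing from the script above; the output file name, tensor names, and opset version are illustrative choices, and since a float32 7B model is far beyond onnx's 2 GB protobuf limit, expect the weights to end up in external data files next to the .onnx:

import torch

# `model` and `inputs` are the objects defined in the script above
torch.onnx.export(
    model,                      # the PyTorch module to trace
    (inputs,),                  # the tokenized prompt; its shape is baked into the graph
    "mistral_instruct.onnx",    # output path (illustrative), viewable in Netron
    input_names=["input_ids"],  # illustrative tensor names
    output_names=["logits"],
    opset_version=17,           # illustrative opset choice
)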

The sample below works for the LLM360/CrystalChat, microsoft/phi-2, and adept/fuyu-8b models:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# This is the minimum setup required to convert an LLM (PyTorch) into ONNX format
# The converted ONNX file can be viewed in the Netron app

# Load the tokenizer and model
torch.set_default_device("cpu")
model_id = "adept/fuyu-8b"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32, device_map="cpu", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

text_prompt = "<s> <|sys_start|> You are an AI assistant. You will be given a task. You must generate a detailed and long answer. <|sys_end|> <|im_start|> Write a python function that takes a list of integers and returns the squared sum of the list. <|im_end|>"

# Tokenize the prompt and convert to PyTorch tensors
inputs = tokenizer(text_prompt, return_tensors="pt")

# Run a single forward pass with model() instead of model.generate(),
# so the auto-regressive model is invoked only once (one build, one ONNX file)
# https://github.com/onnx/turnkeyml/issues/95

outputs = model(**inputs)

You may close this issue. Thanks again.

@jeremyfowers
Collaborator

I'm glad to have been able to help! Please reach out if you have any more questions :)
