
Inference worse with onnxruntime-gpu than native pytorch for seq2seq model #404

Closed · 2 of 4 tasks
Matthieu-Tinycoaching opened this issue Sep 28, 2022 · 12 comments · Fixed by #461
Labels: bug, inference, onnxruntime

Matthieu-Tinycoaching commented Sep 28, 2022

System Info

Optimum: 1.4.1.dev0
torch: 1.12.1+cu116
onnx: 1.12.0
onnxruntime-gpu: 1.12.1
python: 3.8.13
CUDA: 11.6
cudnn: 8.4.1
RTX 3090

Who can help?

@JingyaHuang @echarlaix

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I compared GPU inference of a native torch Helsinki-NLP/opus-mt-fr-en model against the ONNX model optimized with the Optimum library. To do so, I defined a FastAPI microservice based on the two classes below, one for torch on GPU and one for the optimized ONNX model, respectively:

class Seq2SeqModel:
    tokenizer: Optional[MarianTokenizer]
    model: Optional[MarianMTModel]

    def load_model(self):
        """Loads the model"""
        # model_id="Helsinki-NLP/opus-mt-fr-en"
        model_path = Path("./app/artifacts/HF")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to("cuda")
        self.tokenizer = tokenizer
        self.model = model

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Runs a prediction"""
        if not self.tokenizer or not self.model:
            raise RuntimeError("Model is not loaded")
        tokens = self.tokenizer(input.text, return_tensors="pt").to("cuda")
        translated = self.model.generate(**tokens, num_beams=beam_size)
        return PredictionOutput(translated_text=self.tokenizer.decode(translated[0], skip_special_tokens=True))

class OnnxOptimizedSeq2SeqModel:
    tokenizer: Optional[MarianTokenizer]
    model: Optional[ORTModelForSeq2SeqLM]

    def load_model(self):
        """Loads the model"""
        # model_id="Helsinki-NLP/opus-mt-fr-en"
        onnx_path = Path("./app/artifacts/OL_1")
        tokenizer = AutoTokenizer.from_pretrained(onnx_path)
        optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
            onnx_path,
            encoder_file_name="encoder_model_optimized.onnx",
            decoder_file_name="decoder_model_optimized.onnx",
            decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
            provider="CUDAExecutionProvider"
        )
        self.tokenizer = tokenizer
        self.model = optimized_model
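
    # predict() is omitted here; it is identical to Seq2SeqModel.predict above
    # (see the full listing later in this thread).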

app = FastAPI()
seq2seq_model = Seq2SeqModel()
onnx_optimized_seq2seq_model = OnnxOptimizedSeq2SeqModel()
beam_size = 3

@app.on_event("startup")
async def startup():
    seq2seq_model.load_model()
    onnx_optimized_seq2seq_model.load_model()

@app.post("/prediction")
async def prediction(
    output: PredictionOutput = Depends(seq2seq_model.predict),
) -> PredictionOutput:
    return output

@app.post("/prediction_onnx_optimized")
async def prediction(
    output: PredictionOutput = Depends(onnx_optimized_seq2seq_model.predict),
) -> PredictionOutput:
    return output

Expected behavior

When load testing the model on my local computer, I was surprised by two things:

  1. The GPU performance of the optimized ONNX model is worse than native torch (maybe linked to #365 "Inference performance drop 22X on GPU hardware with optimum[onnxruntime-gpu] (compared with transformer)" and #396 "Optimize ONNX model based on encoder-decoder"?):

[Screenshots: load-test results "GPU_optimized_onnxruntime" and "GPU_torch"]

  2. When running this FastAPI service in a Docker image, I got the following warning:

2022-09-28 08:20:21.214094612 [W:onnxruntime:Default, onnxruntime_pybind_state.cc:566 CreateExecutionProviderInstance] Failed to create CUDAExecutionProvider. Please reference https://onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html#requirements to ensure all dependencies are met.

Does this mean the CUDAExecutionProvider is not working, even though I set it here:

        optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
            onnx_path,
            encoder_file_name="encoder_model_optimized.onnx",
            decoder_file_name="decoder_model_optimized.onnx",
            decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
            provider="CUDAExecutionProvider"
        )

What could have caused that? I saw at https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html that CUDA 11.6 is not mentioned; could that be the reason?

fxmarty (Collaborator) commented Sep 28, 2022

Hi @Matthieu-Tinycoaching, thank you for the report! Check my answer to your question 2 in the linked issue in the onnxruntime repo.

Could you post the output of pip list | grep onnx and pip list | grep optimum?

On my end, I run onnxruntime-gpu==1.12.1 with CUDA 11.7 (cuda_11.7.r11.7).

Maybe try

export CUDA_PATH=/usr/local/cuda
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.7/lib64

with the right path for CUDA.
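
As a quick sanity check (an addition, not part of the original reply), you can also ask ONNX Runtime which execution providers it can actually create, and check which CUDA version torch was built against:

import onnxruntime
import torch

# "CUDAExecutionProvider" must appear in this list, otherwise onnxruntime-gpu
# cannot find the CUDA/cuDNN libraries and silently falls back to CPU.
print(onnxruntime.get_available_providers())

# CUDA version the installed torch build was compiled against (e.g. "11.6").
print(torch.version.cuda)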

For question 1, I am currently looking into it, stay tuned.

fxmarty (Collaborator) commented Sep 28, 2022

Answering your question from the issue in the onnxruntime repo: yes, passing provider="CUDAExecutionProvider" is necessary; see the Optimum documentation: https://huggingface.co/docs/optimum/main/en/onnxruntime/modeling_ort#optimum.onnxruntime.ORTModel.from_pretrained

Given that you can load the model with CUDAExecutionProvider using my code snippet, I am not sure what goes wrong in yours. It's likely not an issue with your CUDA/onnxruntime install. How were your encoder_model_optimized.onnx and the other files saved?

Matthieu-Tinycoaching (Author) commented Sep 28, 2022

Hi @fxmarty,

Thanks for the help. Regarding question 2, it seems that my GPU is working properly. Please find below the steps I used to export and optimize the ONNX models:

## Convert native model to ONNX

from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer
from pathlib import Path

model_id="Helsinki-NLP/opus-mt-fr-en"
onnx_path = Path("/home/matthieu/Deployment/ONNX_runtime/GPU/V3/Optimum/opus-mt-fr-en/onnx1.12.0_onxxruntime-gpu1.12.1/OL_1")

# load vanilla transformers and convert to onnx
model = ORTModelForSeq2SeqLM.from_pretrained(model_id, from_transformers=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# save onnx checkpoint and tokenizer
model.save_pretrained(onnx_path)
tokenizer.save_pretrained(onnx_path)

## Optimize the model for GPU using ORTOptimizer

from optimum.onnxruntime import ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

# Create ORTOptimizer
optimizer = ORTOptimizer.from_pretrained(model)

# Define the optimization strategy by creating the appropriate configuration
optimization_config = OptimizationConfig(optimization_level=1,
                                        optimize_for_gpu=True,
                                        fp16=True
                                        )

# Optimize the model
optimizer.optimize(save_dir=onnx_path, optimization_config=optimization_config)

Please find below the output of the following commands:

pip list | grep onnx
onnx                      1.12.0
onnxruntime-gpu           1.12.1
onnxruntime-tools         1.7.0
pip list | grep optimum
optimum                   1.4.1.dev0

1/ Did I do something wrong? Could you reproduce my results using FastAPI?

2/ Is there a way to get the prediction output by combining ONNX InferenceSession with the 3 optimized ONNX models directly, instead of going through Optimum? That way I could check with a deployment framework other than FastAPI.

3/ How long does it take to export the ONNX model from torch? I tried to convert the model from my first code block above by adding the provider="CUDAExecutionProvider" flag and it has been running for 8 minutes...

fxmarty (Collaborator) commented Sep 28, 2022

Concerning 3/: the from_transformers argument is used when loading the model from a transformers checkpoint. If your model is already stored as ONNX, you don't need to pass the argument, or you can pass from_transformers=False. You can refer to https://huggingface.co/docs/optimum/main/en/onnxruntime/modeling_ort#optimum.onnxruntime.ORTModel.from_pretrained for details. Having the conversion run for 8 minutes is not normal (unless you are downloading a model from the Hub that is not cached).
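
For illustration, loading an already-exported local checkpoint could look like the sketch below (the path, file names, and keyword spellings are taken from the snippets earlier in this thread):

from pathlib import Path
from optimum.onnxruntime import ORTModelForSeq2SeqLM

# Directory that already contains the exported/optimized .onnx files.
onnx_path = Path("./app/artifacts/OL_1")

# No from_transformers=True here: the model is loaded directly from the ONNX
# files instead of being re-converted from a transformers checkpoint.
model = ORTModelForSeq2SeqLM.from_pretrained(
    onnx_path,
    encoder_file_name="encoder_model_optimized.onnx",
    decoder_file_name="decoder_model_optimized.onnx",
    decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
    provider="CUDAExecutionProvider",
)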

2/: I'm not sure what you mean. Do you mean using encoder_model_optimized.onnx, decoder_model_optimized.onnx, etc. independently of Optimum?

First of all, do you no longer get the warning 2022-09-28 08:20:21.214094612 [W:onnxruntime:Default, onnxruntime_pybind_state.cc:566 CreateExecutionProviderInstance] Failed to create CUDAExecutionProvider. Please reference https://onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html#requirements to ensure all dependencies are met.?

Matthieu-Tinycoaching (Author) commented Sep 28, 2022

Hi @fxmarty,

OK, I regenerated all the ONNX models with this script:

import onnxruntime
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import MarianTokenizer, pipeline
from pathlib import Path
from optimum.onnxruntime import ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig

onnx_path = Path("/home/matthieu/Deployment/ONNX_runtime/GPU/V3/Optimum/opus-mt-fr-en/github_issue_onnxruntime")

tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en")

options = onnxruntime.SessionOptions()
options.log_severity_level = 0  # verbose, to see which execution provider is used

#### 1. Convert an opus-mt model to ONNX for inference ####

# load vanilla transformers and convert to onnx
ort_model = ORTModelForSeq2SeqLM.from_pretrained(
    "Helsinki-NLP/opus-mt-fr-en",
    from_transformers=True,
    provider="CUDAExecutionProvider",
    session_options=options,
)

print(ort_model.providers)

# save onnx checkpoint and tokenizer
ort_model.save_pretrained(onnx_path)
tokenizer.save_pretrained(onnx_path)

#### 2. Optimize model for GPU using ORTOptimizer ####

# Create ORTOptimizer
optimizer = ORTOptimizer.from_pretrained(ort_model)

# Define the optimization strategy by creating the appropriate configuration
optimization_config = OptimizationConfig(optimization_level=1,
                                        optimize_for_gpu=True,
                                        fp16=True
                                        )

# Optimize the model
optimizer.optimize(save_dir=onnx_path, optimization_config=optimization_config)

And I got both triplets (encoder_model.onnx, decoder_with_past_model.onnx, decoder_model.onnx) and (encoder_model_optimized.onnx, decoder_with_past_model_optimized.onnx, decoder_model_optimized.onnx) in the onnx_path folder, the second one with GPU optimization.

Then, I tried to compare load-test performance between the native torch model and the ONNX models (with and without optimization) using FastAPI:

import os
from typing import List, Optional, Tuple
from fastapi import FastAPI, Depends
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers.pipelines.text2text_generation import TranslationPipeline
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers.models.marian.tokenization_marian import MarianTokenizer
from transformers.models.marian.modeling_marian import MarianMTModel
from pathlib import Path
# from easynmt import EasyNMT

class PredictionInput(BaseModel):
    text: str

class PredictionOutput(BaseModel):
    translated_text: str

class Seq2SeqModel:
    tokenizer: Optional[MarianTokenizer]
    model: Optional[MarianMTModel]

    def load_model(self):
        """Loads the model"""
        # model_id="Helsinki-NLP/opus-mt-fr-en"
        model_path = Path("./app/artifacts/HF")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to("cuda")
        self.tokenizer = tokenizer
        self.model = model

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Runs a prediction"""
        if not self.tokenizer or not self.model:
            raise RuntimeError("Model is not loaded")
        tokens = self.tokenizer(input.text, return_tensors="pt").to("cuda")
        translated = self.model.generate(**tokens, num_beams=beam_size)
        return PredictionOutput(translated_text=self.tokenizer.decode(translated[0], skip_special_tokens=True))

class OnnxSeq2SeqModel:
    tokenizer: Optional[MarianTokenizer]
    model: Optional[ORTModelForSeq2SeqLM]

    def load_model(self):
        """Loads the model"""
        # model_id="Helsinki-NLP/opus-mt-fr-en"
        onnx_path = Path("./app/artifacts/OL_1")
        tokenizer = AutoTokenizer.from_pretrained(onnx_path)
        optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
            onnx_path,
            encoder_file_name="encoder_model.onnx",
            decoder_file_name="decoder_model.onnx",
            decoder_file_with_past_name="decoder_with_past_model.onnx",
            provider="CUDAExecutionProvider"
        )
        self.tokenizer = tokenizer
        self.model = optimized_model

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Runs a prediction"""
        if not self.tokenizer or not self.model:
            raise RuntimeError("Model is not loaded")
        tokens = self.tokenizer(input.text, return_tensors="pt").to("cuda")
        translated = self.model.generate(**tokens, num_beams=beam_size)
        return PredictionOutput(translated_text=self.tokenizer.decode(translated[0], skip_special_tokens=True))

class OnnxOptimizedSeq2SeqModel:
    tokenizer: Optional[MarianTokenizer]
    model: Optional[ORTModelForSeq2SeqLM]

    def load_model(self):
        """Loads the model"""
        # model_id="Helsinki-NLP/opus-mt-fr-en"
        onnx_path = Path("./app/artifacts/OL_1")
        tokenizer = AutoTokenizer.from_pretrained(onnx_path)
        optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
            onnx_path,
            encoder_file_name="encoder_model_optimized.onnx",
            decoder_file_name="decoder_model_optimized.onnx",
            decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
            provider="CUDAExecutionProvider"
        )
        self.tokenizer = tokenizer
        self.model = optimized_model

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Runs a prediction"""
        if not self.tokenizer or not self.model:
            raise RuntimeError("Model is not loaded")
        tokens = self.tokenizer(input.text, return_tensors="pt").to("cuda")
        translated = self.model.generate(**tokens, num_beams=beam_size)
        return PredictionOutput(translated_text=self.tokenizer.decode(translated[0], skip_special_tokens=True))

class OnnxOptimizedSeq2SeqPipeline:
    pipeline: Optional[TranslationPipeline]

    def load_pipeline(self):
        """Loads the pipeline"""
        # model_id="Helsinki-NLP/opus-mt-fr-en"
        onnx_path = Path("./app/artifacts/OL_1")
        tokenizer = AutoTokenizer.from_pretrained(onnx_path)
        optimized_model = ORTModelForSeq2SeqLM.from_pretrained(
            onnx_path,
            encoder_file_name="encoder_model_optimized.onnx",
            decoder_file_name="decoder_model_optimized.onnx",
            decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
            provider="CUDAExecutionProvider"
        )
        self.pipeline = pipeline("translation_fr_to_en", model=optimized_model, tokenizer=tokenizer, num_beams=beam_size, device=0)

    async def predict(self, input: PredictionInput) -> PredictionOutput:
        """Runs a prediction"""
        if not self.pipeline:
            raise RuntimeError("Pipeline is not loaded")
        translated = self.pipeline(input.text)
        return PredictionOutput(translated_text=translated[0]['translation_text'])

app = FastAPI()
seq2seq_model = Seq2SeqModel()
onnx_seq2seq_model = OnnxSeq2SeqModel()
onnx_optimized_seq2seq_model = OnnxOptimizedSeq2SeqModel()
onnx_optimized_seq2seq_pipeline = OnnxOptimizedSeq2SeqPipeline()
beam_size = 3

@app.on_event("startup")
async def startup():
    seq2seq_model.load_model()
    onnx_seq2seq_model.load_model()
    onnx_optimized_seq2seq_model.load_model()
    onnx_optimized_seq2seq_pipeline.load_pipeline()

@app.post("/prediction")
async def prediction(
    output: PredictionOutput = Depends(seq2seq_model.predict),
) -> PredictionOutput:
    return output

@app.post("/prediction_onnx")
async def prediction(
    output: PredictionOutput = Depends(onnx_seq2seq_model.predict),
) -> PredictionOutput:
    return output

@app.post("/prediction_onnx_optimized")
async def prediction(
    output: PredictionOutput = Depends(onnx_optimized_seq2seq_model.predict),
) -> PredictionOutput:
    return output

@app.post("/prediction_onnx_optimized_pipeline")
async def prediction(
    output: PredictionOutput = Depends(onnx_optimized_seq2seq_pipeline.predict),
) -> PredictionOutput:
    return output

[Screenshots: load-test results "GPU_onn_optimized_pipeline", "GPU_onnx_optimized", "GPU_onnx", "GPU_native_torch"]

1/ The best results are the last ones, for the native PyTorch model. Could you reproduce this underperformance of the ONNX models (with or without optimization) compared to the native model?

2/ "Do you mean, could you use the encoder_model_optimized.onnx, decoder_model_optimized.onnx independently of Optimum?"

Yes, is it possible to call the triplet of ONNX models sequentially?
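
For illustration, here is a minimal sketch of running just the encoder with a raw onnxruntime.InferenceSession, assuming the export uses the usual input_ids/attention_mask input names (sess.get_inputs() will confirm what the file actually expects):

import onnxruntime
from transformers import AutoTokenizer

onnx_dir = "./app/artifacts/OL_1"  # same folder as in the FastAPI snippets above
tokenizer = AutoTokenizer.from_pretrained(onnx_dir)

# Raw InferenceSession on the exported encoder, bypassing Optimum entirely.
sess = onnxruntime.InferenceSession(
    f"{onnx_dir}/encoder_model_optimized.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

inputs = tokenizer("Bonjour le monde", return_tensors="np")
encoder_outputs = sess.run(
    None,
    {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
)
print(encoder_outputs[0].shape)  # last_hidden_state: (batch, seq_len, hidden_size)

Chaining the decoder and decoder-with-past files by hand would also mean reimplementing the generation loop (past key values, beam search), which is what ORTModelForSeq2SeqLM.generate otherwise handles.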

3/ "First of all, do you no longer get the warning [W:onnxruntime:Default, onnxruntime_pybind_state.cc:566 CreateExecutionProviderInstance] Failed to create CUDAExecutionProvider [...]?"

I'll test tomorrow with the Docker image, but since performance is still worse with ONNX, I suppose it is still there...

michaelbenayoun added the onnxruntime and inference labels on Oct 14, 2022
JingyaHuang self-assigned this on Oct 14, 2022
JingyaHuang (Collaborator) commented:

Hi @Matthieu-Tinycoaching,

For your concern about the inference speed, it could be caused by the overhead of copying data between CPU and GPU, as ONNX Runtime puts inputs and outputs on the CPU by default. I just merged a PR which adds IO binding support to avoid this issue. Do you want to test with our main branch (no code change needed, as use_io_binding=True by default) to see if this solves the problem?

You can build optimum from source with

python -m pip install git+https://github.com/huggingface/optimum.git#egg=optimum[onnxruntime-gpu]
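
Once on the main branch, loading with IO binding explicitly enabled could look like this sketch (same file names and argument spellings as the snippets above; use_io_binding is the same argument shown in the next comment):

from pathlib import Path
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForSeq2SeqLM

onnx_path = Path("./app/artifacts/OL_1")
tokenizer = AutoTokenizer.from_pretrained(onnx_path)

model = ORTModelForSeq2SeqLM.from_pretrained(
    onnx_path,
    encoder_file_name="encoder_model_optimized.onnx",
    decoder_file_name="decoder_model_optimized.onnx",
    decoder_file_with_past_name="decoder_with_past_model_optimized.onnx",
    provider="CUDAExecutionProvider",
    use_io_binding=True,  # keep inputs/outputs on the GPU instead of copying through the CPU
)

tokens = tokenizer("Bonjour le monde", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**tokens, num_beams=3)[0], skip_special_tokens=True))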

soocheolnoh commented:

In my case, with a T5 seq2seq model:

| model | use_io_binding | num_beams | average time for 10 loops [sec] |
|---|---|---|---|
| pytorch | - | 1 | 0.89 |
| onnx | True | 1 | 0.54 |
| onnx | False | 1 | 0.77 |
| pytorch | - | 5 | 1.17 |
| onnx | True | 5 | 3.17 |
| onnx | False | 5 | 2.13 |

When num_beams > 1, the average time with IO binding is even slower than without it (use_io_binding=False).
All results were generated using these functions:

# Imports required by the helpers below.
import torch
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import MT5ForConditionalGeneration, T5Tokenizer


def load_pytorch(saved_path, device="cuda:0"):
    tokenizer = T5Tokenizer.from_pretrained("google/mt5-base")
    net = MT5ForConditionalGeneration.from_pretrained("google/mt5-base").to(device)
    net.eval()
    state_dict = torch.load(saved_path, map_location=device)["model_state_dict"]
    net.load_state_dict(state_dict)
    return tokenizer, net


def load_onnx(saved_path, device="cuda:0"):
    tokenizer = T5Tokenizer.from_pretrained(saved_path)
    net = ORTModelForSeq2SeqLM.from_pretrained(
        saved_path,
        encoder_file_name="encoder_model.onnx",
        decoder_file_name="decoder_model.onnx",
        decoder_with_past_file_name="decoder_with_past_model.onnx",
        use_io_binding=False,
        provider="CUDAExecutionProvider",
    ).to(device)
    return tokenizer, net


def inference(tokenizer, net, sentence, device="cuda:0"):
    input_ids = tokenizer.encode(sentence)
    input_ids = torch.tensor(input_ids).to(dtype=torch.int64, device=device)
    input_ids = torch.unsqueeze(input_ids, dim=0)
    output_ids = net.generate(input_ids, max_length=256, num_beams=5)
    translated = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return translated
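
For context, a timing loop in the spirit of the table above might look like the sketch below (not the original measurement script; the path and sentence are placeholders):

import time

tokenizer, net = load_onnx("path/to/onnx/export")  # or load_pytorch("path/to/checkpoint.pt")
sentence = "..."  # placeholder test sentence

# One warm-up call so CUDA initialization and session setup are not measured.
inference(tokenizer, net, sentence)

start = time.perf_counter()
for _ in range(10):
    inference(tokenizer, net, sentence)
print(f"average time over 10 loops: {(time.perf_counter() - start) / 10:.2f} s")

(In the table, num_beams was varied; in the inference helper above it is hard-coded to 5, so it would need to become a parameter to reproduce the num_beams=1 rows.)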

JingyaHuang (Collaborator) commented:

Hi @soocheolnoh,

Thanks for testing the IO binding feature and sharing your results!

I did a quick test with beam search (num_beams=5) on the checkpoint facebook/m2m100_418M (the snippet was borrowed from #365):

# -*- coding: utf-8 -*-

import logging
import time

import torch
from transformers import AutoTokenizer
from tqdm import tqdm


logging.basicConfig(level=logging.INFO)


model_checkpoint = "facebook/m2m100_418M"

loop = 100


chinese_text = "机器学习是人工智能的一个分支。人工智能的研究历史有着一条从以“推理”为重点,到以“知识”为重点,再到以“学习”为重点的自然、清晰的脉络。显然,机器学习是实现人工智能的一个途径,即以机器学习为手段解决人工智能中的问题。机器学习在近30多年已发展为一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析(英语:Convex analysis)、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。机器学习算法是一类从数据中自动分析获得规律,并利用规律对未知数据进行预测的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。很多推论问题属于无程序可循难度,所以部分的机器学习研究是开发容易处理的近似算法。机器学习已广泛应用于数据挖掘、计算机视觉、自然语言处理、生物特征识别、搜索引擎、医学诊断、检测信用卡欺诈、证券市场分析、DNA序列测序、语音和手写识别、战略游戏和机器人等领域"
logging.info(f"chinese_text is {chinese_text}")
logging.info(f"chinese_text length is {len(chinese_text)}")

device = torch.device("cuda:0")
logging.info(f"This test will use device: {device}")


def get_transformer_model():
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    model = M2M100ForConditionalGeneration.from_pretrained(model_checkpoint).to(device)
    tokenizer = M2M100Tokenizer.from_pretrained(model_checkpoint)
    return (model, tokenizer)


def get_optimum_onnx_model():
    from optimum.onnxruntime import ORTModelForSeq2SeqLM

    model = ORTModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_transformers=True).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return (model, tokenizer)


def translate():

    # (model, tokenizer) = get_optimum_onnx_model()
    (model, tokenizer) = get_transformer_model()

    # Warm-up
    for i in range(10):
        encoded_zh = tokenizer(chinese_text, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_zh,
            forced_bos_token_id=tokenizer.get_lang_id("en"),
            max_length=256,
            num_beams=5,
        )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        logging.debug(f"#{i}: {result}")

    start = time.time()
    for i in tqdm(range(loop)):
        encoded_zh = tokenizer(chinese_text, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_zh,
            forced_bos_token_id=tokenizer.get_lang_id("en"),
            max_length=256,
            num_beams=5,
        )
    end = time.time()
    total_time = end - start

    logging.info(f"total: {total_time}")
    logging.info(f"loop: {loop}")
    logging.info(f"avg(s): {total_time / loop}")
    logging.info(f"throughput(translation/s): {loop / total_time}")


if __name__ == "__main__":
    translate()

Here are the results I got on a T4, ORTModel with IO binding vs. PyTorch:

|  | PyTorch | Optimum (w/ IO binding) |
|---|---|---|
| total (s) | 442.00 | 323.40 |
| loop | 100 | 100 |
| avg (s) | 4.420 | 3.234 |
| throughput (translation/s) | 0.2262 | 0.3092 |

Can you test the snippet to see if you can get something similar on your end, or share your entire script so that I can try to reproduce your experiment?

soocheolnoh commented:

@JingyaHuang Thanks for your response!

I tried your script and got similar results (PyTorch: 4.074 s, Optimum: 2.7825 s average time).
I also ran the following script for the mT5 model I'm currently using:

# -*- coding: utf-8 -*-

import logging
import time

import torch
from tqdm import tqdm
from transformers import AutoTokenizer

logging.basicConfig(level=logging.INFO)


# model_checkpoint = "facebook/m2m100_418M"
model_checkpoint = "K024/mt5-zh-ja-en-trimmed"

loop = 100


chinese_text = "zh2en: 机器学习是人工智能的一个分支。人工智能的研究历史有着一条从以“推理”为重点,到以“知识”为重点,再到以“学习”为重点的自然、清晰的脉络。显然,机器学习是实现人工智能的一个途径,即以机器学习为手段解决人工智能中的问题。机器学习在近30多年已发展为一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析(英语:Convex analysis)、计算复杂性理论等多门学科。机器学习理论主要是设计和分析一些让计算机可以自动“学习”的算法。机器学习算法是一类从数据中自动分析获得规律,并利用规律对未知数据进行预测的算法。因为学习算法中涉及了大量的统计学理论,机器学习与推断统计学联系尤为密切,也被称为统计学习理论。算法设计方面,机器学习理论关注可以实现的,行之有效的学习算法。很多推论问题属于无程序可循难度,所以部分的机器学习研究是开发容易处理的近似算法。机器学习已广泛应用于数据挖掘、计算机视觉、自然语言处理、生物特征识别、搜索引擎、医学诊断、检测信用卡欺诈、证券市场分析、DNA序列测序、语音和手写识别、战略游戏和机器人等领域"
logging.info(f"chinese_text is {chinese_text}")
logging.info(f"chinese_text length is {len(chinese_text)}")

device = torch.device("cuda:0")
logging.info(f"This test will use device: {device}")


def get_transformer_model():
    # from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    # model = M2M100ForConditionalGeneration.from_pretrained(model_checkpoint).to(device)
    # tokenizer = M2M100Tokenizer.from_pretrained(model_checkpoint)

    from transformers import MT5ForConditionalGeneration, T5Tokenizer

    model = MT5ForConditionalGeneration.from_pretrained(model_checkpoint).to(device)
    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
    return (model, tokenizer)


def get_optimum_onnx_model():
    from optimum.onnxruntime import ORTModelForSeq2SeqLM
    from transformers import T5Tokenizer

    model = ORTModelForSeq2SeqLM.from_pretrained(model_checkpoint, from_transformers=True).to(device)
    tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)
    return (model, tokenizer)


def translate():

    (model, tokenizer) = get_optimum_onnx_model()
    # (model, tokenizer) = get_transformer_model()

    # Warm-up
    for i in range(10):
        encoded_zh = tokenizer(chinese_text, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_zh,
            # forced_bos_token_id=tokenizer.get_lang_id("en"),
            max_length=256,
            num_beams=5,
        )
        result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        logging.debug(f"#{i}: {result}")

    start = time.time()
    for i in tqdm(range(loop)):
        encoded_zh = tokenizer(chinese_text, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_zh,
            # forced_bos_token_id=tokenizer.get_lang_id("en"),
            max_length=256,
            num_beams=5,
        )
    end = time.time()
    total_time = end - start

    logging.info(f"total: {total_time}")
    logging.info(f"loop: {loop}")
    logging.info(f"avg(s): {total_time / loop}")
    logging.info(f"throughput(translation/s): {loop / total_time}")


if __name__ == "__main__":
    translate()

The results (on a V100):

|  | pytorch | optimum |
|---|---|---|
| total (s) | 289.4229 | 298.1077 |
| loop | 100 | 100 |
| avg (s) | 2.8942 | 2.9811 |
| throughput | 0.3455 | 0.3354 |

When I ran the PyTorch model I had no warnings, but with Optimum I got the following warnings:

UserWarning: The sentencepiece tokenizer that you are converting to a fast tokenizer uses the byte fallback option which is not implemented in the fast tokenizers. In practice this means that the fast version of the tokenizer can produce unknown tokens whereas the sentencepiece version would have converted these unknown tokens into a sequence of byte tokens matching the original piece of text.

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if causal_mask.shape[1] < attention_mask.shape[1]:

Also, the translation results are different (the Optimum result seems strange):

  • pytorch Machine learning is a branch of artificial intelligence. The history of artificial intelligence has a natural, clear line from "rationality" to "knowledge" to "learning" which is clearly a way to achieve artificial intelligence. Machine learning has been developed for nearly 30 years as a cross-disciplinary discipline involving probability theory, statistics, proximity theory, convex analysis, and computational complexity theory. Machine learning theory primarily designs and analyzes algorithms that allow computers to "learn" automatically.
  • optimum Machine learning is a branch of artificial intelligence. It learning has been learning has been learning has learning is a learning Machine learning algorithms are learning algorithms are algorithms that automatically analyze patterns from data and use patterns to predict unknown data. Machine algorithms are a learning algorithms, learning algorithms are algorithms that are learning is learning learning is learning is learning is learning is learning is learning is learning learning learning learning is Machine learning Machine learning is a type Machine learning learning algorithms are algorithm learning algorithms learning learning Machine learning is learning Machine Machine learning learning learning learning algorithms are algorithm learning algorithms learning algorithms are algorithms that automatically analyze patterns from data and use patterns to predict unknown data. Since learning is a class of algorithms which learning algorithms that are learning is learning algorithms are learning algorithms are algorithms learning algorithms, learning algorithms are algorithms learning algorithms learning algorithm

I think the difference between the average times could simply come from the outputs being different (because of the length of the generated tokens?) or from other causes. But I'm not sure why the result from Optimum is strange.

JingyaHuang (Collaborator) commented:

Hi @soocheolnoh, thanks for testing.

On my side, for the mT5 model the generated results differ with vs. without IO binding, which is not normal: IO binding is not supposed to change the result (only where the data is placed should differ). It might be a bug in the post-processing of the outputs. I will take a closer look and fix the beam search ASAP.

JingyaHuang (Collaborator) commented Nov 10, 2022

Hi @soocheolnoh

The fix is done; there was a bug in how the outputs were populated. Thanks for pointing it out.

Now you should get the same translation result with or without IO binding.

# Transcript w/ IO binding
['Machine learning is a branch of artificial intelligence. The history of artificial intelligence has a natural, clear line from "rationality" to "knowledge" to "learning" which is clearly a way to achieve artificial intelligence. Machine learning has been developed for nearly 30 years as a cross-disciplinary discipline involving probability theory, statistics, proximity theory, convex analysis, and computational complexity theory. Machine learning theory primarily designs and analyzes algorithms that allow computers to "learn" automatically.']

Also, here are some performance numbers tested with the previous snippet (PyTorch vs. Optimum, T4, warm_up_steps=10, loop=100, num_beams=5, max_length=256):

|  | PyTorch | Optimum (w/ IO binding) |
|---|---|---|
| total (s) | 260.7315 | 145.2276 |
| loop | 100 | 100 |
| avg (s) / seq | 2.6073 | 1.4523 |
| throughput (translation/s) | 0.3835 | 0.6886 |

The issue is closed, but feel free to reopen it or ping me if you have further questions about IO binding. @soocheolnoh @Matthieu-Tinycoaching, thanks again for helping us improve Optimum.

soocheolnoh commented:

Thanks for the quick response!! @JingyaHuang
