This repository was archived by the owner on Apr 8, 2025. It is now read-only.
30 changes: 30 additions & 0 deletions Dockerfile-onnxruntime
@@ -0,0 +1,30 @@
# Adapted from ONNXRuntime CUDA Dockerfile at https://github.com/microsoft/onnxruntime/blob/master/dockerfiles/Dockerfile.cuda

FROM nvidia/cuda:10.1-cudnn7-devel

ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=master

RUN apt-get update &&\
apt-get install -y sudo git bash

WORKDIR /code
ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/code/cmake-3.14.3-Linux-x86_64/bin:/opt/miniconda/bin:${PATH}
ENV LD_LIBRARY_PATH /opt/miniconda/lib:$LD_LIBRARY_PATH

# Prepare onnxruntime repository & build onnxruntime with CUDA
RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\
/bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\
cp onnxruntime/docs/Privacy.md /code/Privacy.md &&\
cp onnxruntime/ThirdPartyNotices.txt /code/ThirdPartyNotices.txt &&\
cp onnxruntime/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt &&\
cd onnxruntime &&\
/bin/sh ./build.sh --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_cuda --config Release --build_wheel --update --build --cmake_extra_defines ONNXRUNTIME_VERSION=$(cat ./VERSION_NUMBER) &&\
pip install /code/onnxruntime/build/Linux/Release/dist/*.whl &&\
cd .. &&\
rm -rf onnxruntime cmake-3.14.3-Linux-x86_64

# Clone the FARM repository (explicitly into FARM/ so the paths below resolve) and install the requirements
RUN git clone --depth 1 --branch 0.4.3 https://github.com/deepset-ai/farm.git FARM
RUN pip install -e FARM
RUN pip install -r FARM/test/requirements.txt
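
# The image can then be built from the repository root with, for example:
#   docker build -f Dockerfile-onnxruntime -t deepset/farm-onnxruntime-gpu:0.4.3 .
# (illustrative tag, chosen to match the image referenced in test/benchmarks/README.md)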
5 changes: 5 additions & 0 deletions readme.rst
@@ -251,6 +251,11 @@ Much of the heavy lifting is then handled behind the scenes to make it fast & simple.

.. image:: https://raw.githubusercontent.com/deepset-ai/FARM/master/docs/img/data_silo_no_bg_small.jpg

Inference Time Benchmarks
##########################

FARM has a configurable `test suite <https://github.com/deepset-ai/FARM/blob/master/test/benchmarks/README.md>`__ for benchmarking inference times across combinations of inference engine (PyTorch, `ONNXRuntime <https://github.com/microsoft/onnxruntime>`__), batch size, document length, maximum sequence length, and other parameters. `Here <https://docs.google.com/spreadsheets/d/1ak9Cxj1zcNBDtjf7qn2j_ydKDDzpBgWiyJ7cO-7BPvA/edit?usp=sharing>`__ is a benchmark for Question Answering inference with the current FARM version.
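
The benchmarks can also be triggered programmatically. A minimal sketch, assuming the suite is run from FARM's ``test`` directory with ``pytest-benchmark`` installed:

.. code-block:: python

    import pytest

    # Equivalent to: pytest benchmarks/question_answering.py --benchmark-json result.json
    pytest.main(["benchmarks/question_answering.py", "--benchmark-json", "result.json"])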

FAQ
####
**1. What language model shall I use for non-English NLP?**
31 changes: 31 additions & 0 deletions test/benchmarks/README.md
@@ -0,0 +1,31 @@
# Inference Speed Benchmarks

FARM provides an automated speed benchmarking pipeline with options to parameterize the benchmarks by batch size,
maximum sequence length, document size, and more.

The pipeline is implemented using [pytest-benchmark](https://github.com/ionelmc/pytest-benchmark). The warmup rounds and
timed iterations for each benchmark are configurable, and the results can be exported to a JSON file.
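
For illustration, here is a minimal sketch of how such a benchmark is defined with pytest-benchmark's pedantic mode, mirroring the pattern used in `question_answering.py` (the workload here is a placeholder, not FARM's actual model call):

```python
import pytest


@pytest.mark.parametrize("batch_size", [1, 4])
def test_inference_speed(benchmark, batch_size):
    # `benchmark` is the fixture provided by pytest-benchmark.
    def run():
        return sum(range(100_000))  # placeholder for the model call

    # Pedantic mode gives explicit control over warmup rounds and the
    # number of timed iterations; results land in the JSON report when
    # pytest is invoked with --benchmark-json.
    benchmark.pedantic(target=run, warmup_rounds=1, iterations=3)
```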



## Question Answering

The `benchmarks/question_answering.py` file contains tests for inference with PyTorch (`test_question_answering_pytorch`)
and ONNXRuntime (`test_question_answering_onnx`).

The benchmarks are available [here](https://docs.google.com/spreadsheets/d/1ak9Cxj1zcNBDtjf7qn2j_ydKDDzpBgWiyJ7cO-7BPvA/edit?usp=sharing).

### Running the Benchmarks with Docker

#### GPU
To run the benchmarks on a GPU, open a shell in the Docker image with `docker run -it --gpus all deepset/farm-onnxruntime-gpu:0.4.3 bash`.
Once inside the container, execute `cd FARM/test && pytest benchmarks/question_answering.py -k test_question_answering_pytorch --use_gpu --benchmark-json result.json`.

#### CPU
Open a shell in the Docker container with `docker run -it deepset/farm-inference-api:0.4.3 bash` and then execute
`cd test && pytest benchmarks/question_answering.py -k test_question_answering_pytorch --benchmark-json result.json`.

### Exporting Results in CSV Format

The benchmark results are exported to a `result.json` file in the `test` folder. To convert the results to CSV format,
execute `python benchmarks/convert_result_to_csv.py`.
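
As a quick, illustrative check (assuming the converter has been run in the `test` folder), the exported CSV can be read back with the standard library; each row holds one parameter combination plus the measured total time:

```python
import csv

with open("result.csv", newline="") as f:
    for row in csv.DictReader(f):
        # Columns are the pytest parameters (e.g. batch_size, max_seq_len,
        # document_size, num_processes) plus the "time" column.
        print(row["time"], row)
```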
23 changes: 23 additions & 0 deletions test/benchmarks/conftest.py
@@ -0,0 +1,23 @@
from pathlib import Path

import pytest

from farm.infer import Inferencer
from farm.modeling.adaptive_model import AdaptiveModel


@pytest.fixture(scope="session")
def onnx_adaptive_model_qa(use_gpu, num_processes):
    model_name_or_path = "deepset/bert-base-cased-squad2"
    onnx_model_export_path = Path("benchmarks/onnx-export")

    # Convert the transformers model to ONNX once; the export is cached on
    # disk and reused for the rest of the benchmark session.
    if not (onnx_model_export_path / "model.onnx").is_file():
        model = AdaptiveModel.convert_from_transformers(
            model_name_or_path, device="cpu", task_type="question_answering"
        )
        model.convert_to_onnx(onnx_model_export_path)

    # Load an Inferencer backed by the exported ONNX model.
    model = Inferencer.load(
        onnx_model_export_path, task_type="question_answering", batch_size=1, num_processes=num_processes, gpu=use_gpu
    )

    return model
14 changes: 14 additions & 0 deletions test/benchmarks/convert_result_to_csv.py
@@ -0,0 +1,14 @@
import csv
import json

# Load the results exported by pytest-benchmark via --benchmark-json.
with open("result.json") as f:
    results = json.load(f)

# Write one CSV row per benchmark run: the parameter combination of the
# run plus the total measured time. newline="" avoids blank lines on Windows.
with open("result.csv", "w", newline="") as f:
    fieldnames = list(results["benchmarks"][0]["params"].keys())
    fieldnames.append("time")
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()

    for benchmark in results["benchmarks"]:
        writer.writerow({"time": benchmark["stats"]["total"], **benchmark["params"]})
50 changes: 50 additions & 0 deletions test/benchmarks/question_answering.py
@@ -0,0 +1,50 @@
import logging

import pytest
import torch

logger = logging.getLogger(__name__)


@pytest.mark.parametrize("max_seq_len", [128, 256, 384])
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("document_size", [10_000, 100_000])
@pytest.mark.parametrize("num_processes", [0], scope="session")
def test_question_answering_pytorch(adaptive_model_qa, benchmark, max_seq_len, batch_size, use_gpu, document_size):
    if use_gpu and not torch.cuda.is_available():
        pytest.skip("Skipping GPU benchmark since no CUDA device is available.")

    if not use_gpu and document_size > 10_000:
        pytest.skip("Skipping large documents on CPU to keep runtimes manageable.")

    # Truncate the sample document to the desired size.
    with open("benchmarks/sample_file.txt") as f:
        context = f.read()[:document_size]
    QA_input = [{"qas": ["When were the first traces of Human life found in France?"], "context": context}]

    adaptive_model_qa.batch_size = batch_size
    adaptive_model_qa.max_seq_len = max_seq_len
    benchmark.pedantic(
        target=adaptive_model_qa.inference_from_dicts, args=(QA_input,), warmup_rounds=1, iterations=3,
    )


@pytest.mark.parametrize("max_seq_len", [128, 256, 384])
@pytest.mark.parametrize("batch_size", [1, 4, 16, 64])
@pytest.mark.parametrize("document_size", [10_000, 100_000])
@pytest.mark.parametrize("num_processes", [0], scope="session")
def test_question_answering_onnx(onnx_adaptive_model_qa, benchmark, max_seq_len, batch_size, use_gpu, document_size):
    if use_gpu and not torch.cuda.is_available():
        pytest.skip("Skipping GPU benchmark since no CUDA device is available.")

    if not use_gpu and document_size > 10_000:
        pytest.skip("Skipping large documents on CPU to keep runtimes manageable.")

    # Truncate the sample document to the desired size.
    with open("benchmarks/sample_file.txt") as f:
        context = f.read()[:document_size]
    QA_input = [{"qas": ["When were the first traces of Human life found in France?"], "context": context}]

    onnx_adaptive_model_qa.batch_size = batch_size
    onnx_adaptive_model_qa.max_seq_len = max_seq_len
    benchmark.pedantic(
        target=onnx_adaptive_model_qa.inference_from_dicts, args=(QA_input,), warmup_rounds=1, iterations=3
    )