Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fp16 inference support (trt) #875

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ jobs:
pip install --no-cache-dir "server/[onnx]"
pip install --no-cache-dir "server/[transformers]"
pip install --no-cache-dir "server/[search]"
pip install open-clip-torch==2.7.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please specify the open-clip-torch and tensorrt version in setup

- name: Test
id: test
run: |
Expand Down Expand Up @@ -159,11 +160,13 @@ jobs:
pip install -e "server/[tensorrt]"
pip install -e "server/[onnx]"
pip install -e "server/[transformers]"
pip install nvidia-tensorrt==8.4.1.5
{
pip install -e "server/[flash-attn]"
} || {
echo "flash attention was not installed."
}
pip install open-clip-torch==2.7.0
- name: Test
id: test
run: |
Expand Down
6 changes: 5 additions & 1 deletion server/clip_server/executors/clip_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
num_worker_preprocess: int = 4,
minibatch_size: int = 32,
access_paths: str = '@r',
dtype: Optional[str] = 'fp32',
**kwargs,
):
"""
Expand All @@ -36,6 +37,7 @@ def __init__(
number if you encounter OOM errors.
:param access_paths: The access paths to traverse on the input documents to get the images and texts to be
processed. Visit https://docarray.jina.ai/fundamentals/documentarray/access-elements for more details.
:param dtype: inference data type, defaults to 'fp32'.
"""
super().__init__(**kwargs)

Expand All @@ -51,6 +53,7 @@ def __init__(
self._access_paths = kwargs['traversal_paths']

self._device = device
self._dtype = dtype

import torch

Expand All @@ -63,7 +66,7 @@ def __init__(
torch.cuda.is_available()
), "CUDA/GPU is not available on Pytorch. Please check your CUDA installation"

self._model = CLIPTensorRTModel(name)
self._model = CLIPTensorRTModel(name=name, dtype=dtype)

self._model.start_engines()

Expand All @@ -85,6 +88,7 @@ def _preproc_images(self, docs: 'DocumentArray', drop_image_content: bool):
device=self._device,
return_np=False,
drop_image_content=drop_image_content,
dtype=self._dtype,
)

def _preproc_texts(self, docs: 'DocumentArray'):
Expand Down
37 changes: 25 additions & 12 deletions server/clip_server/model/clip_trt.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Dict
from typing import Dict, Optional

try:
import tensorrt as trt
Expand Down Expand Up @@ -51,6 +51,7 @@ class CLIPTensorRTModel(BaseCLIPModel):
def __init__(
self,
name: str,
dtype: Optional[str] = 'fp32',
):
super().__init__(name)

Expand All @@ -59,23 +60,35 @@ def __init__(
f'~/.cache/clip/{name.replace("/", "-").replace("::", "-")}'
)

self._textual_path = os.path.join(
cache_dir,
f'textual.{ONNX_MODELS[name][0][1]}.trt',
)
self._visual_path = os.path.join(
cache_dir,
f'visual.{ONNX_MODELS[name][1][1]}.trt',
)
if dtype == 'fp16':
self._textual_path = os.path.join(
cache_dir,
f'textual.{ONNX_MODELS[name][0][1]}.fp16.trt',
)
self._visual_path = os.path.join(
cache_dir,
f'visual.{ONNX_MODELS[name][1][1]}.fp16.trt',
)
else:
self._textual_path = os.path.join(
cache_dir,
f'textual.{ONNX_MODELS[name][0][1]}.trt',
)
self._visual_path = os.path.join(
cache_dir,
f'visual.{ONNX_MODELS[name][1][1]}.trt',
)

if not os.path.exists(self._textual_path) or not os.path.exists(
self._visual_path
):
from clip_server.model.clip_onnx import CLIPOnnxModel

fp16 = dtype == 'fp16'

trt_logger: Logger = trt.Logger(trt.Logger.ERROR)
runtime: Runtime = trt.Runtime(trt_logger)
onnx_model = CLIPOnnxModel(name)
onnx_model = CLIPOnnxModel(name=name, dtype=dtype)

visual_engine = build_engine(
runtime=runtime,
Expand All @@ -95,7 +108,7 @@ def __init__(
onnx_model.image_size,
),
workspace_size=10000 * 1024 * 1024,
fp16=False,
fp16=fp16,
int8=False,
)
save_engine(visual_engine, self._visual_path)
Expand All @@ -108,7 +121,7 @@ def __init__(
optimal_shape=(768, 77),
max_shape=(1024, 77),
workspace_size=10000 * 1024 * 1024,
fp16=False,
fp16=fp16,
int8=False,
)
save_engine(text_engine, self._textual_path)
Expand Down
118 changes: 60 additions & 58 deletions server/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,91 +4,93 @@
from setuptools import find_packages, setup

if sys.version_info < (3, 7, 0):
raise OSError(f'CLIP-as-service requires Python >=3.7, but yours is {sys.version}')
raise OSError(f"CLIP-as-service requires Python >=3.7, but yours is {sys.version}")

try:
pkg_name = 'clip-server'
pkg_name = "clip-server"
libinfo_py = path.join(
path.dirname(__file__), pkg_name.replace('-', '_'), '__init__.py'
path.dirname(__file__), pkg_name.replace("-", "_"), "__init__.py"
)
libinfo_content = open(libinfo_py, 'r', encoding='utf8').readlines()
version_line = [l.strip() for l in libinfo_content if l.startswith('__version__')][
libinfo_content = open(libinfo_py, "r", encoding="utf8").readlines()
version_line = [l.strip() for l in libinfo_content if l.startswith("__version__")][
0
]
exec(version_line) # gives __version__
except FileNotFoundError:
__version__ = '0.0.0'
__version__ = "0.0.0"

try:
with open('../README.md', encoding='utf8') as fp:
with open("../README.md", encoding="utf8") as fp:
_long_description = fp.read()
except FileNotFoundError:
_long_description = ''
_long_description = ""

setup(
name=pkg_name,
packages=find_packages(),
version=__version__,
include_package_data=True,
description='Embed images and sentences into fixed-length vectors via CLIP',
author='Jina AI',
author_email='hello@jina.ai',
license='Apache 2.0',
url='https://github.com/jina-ai/clip-as-service',
download_url='https://github.com/jina-ai/clip-as-service/tags',
description="Embed images and sentences into fixed-length vectors via CLIP",
author="Jina AI",
author_email="hello@jina.ai",
license="Apache 2.0",
url="https://github.com/jina-ai/clip-as-service",
download_url="https://github.com/jina-ai/clip-as-service/tags",
long_description=_long_description,
long_description_content_type='text/markdown',
long_description_content_type="text/markdown",
zip_safe=False,
setup_requires=['setuptools>=18.0', 'wheel'],
setup_requires=["setuptools>=18.0", "wheel"],
install_requires=[
'ftfy',
'torch',
'regex',
'torchvision<=0.13.0' if sys.version_info <= (3, 7, 2) else 'torchvision',
'jina>=3.12.0',
'prometheus-client',
'open_clip_torch>=2.7.0',
"ftfy==6.1.1",
"torch==1.13.0",
"regex==2022.10.31",
"torchvision<=0.13.0"
if sys.version_info <= (3, 7, 2)
else "torchvision==0.14.0",
"jina==3.12.0",
"prometheus-client==0.15.0",
"open_clip_torch==2.7.0",
],
extras_require={
'onnx': [
'onnxruntime',
'onnx',
'onnxmltools',
"onnx": [
"onnxruntime==1.13.1",
"onnx==1.12.0",
"onnxmltools==1.11.1",
]
+ (['onnxruntime-gpu>=1.8.0'] if sys.platform != 'darwin' else []),
'tensorrt': ['nvidia-tensorrt'],
'transformers': ['transformers>=4.16.2'],
'search': ['annlite>=0.3.10'],
'flash-attn': ['flash-attn'],
+ (["onnxruntime-gpu==1.13.1"] if sys.platform != "darwin" else []),
"tensorrt": ["nvidia-tensorrt==8.4.1.5"],
"transformers": ["transformers==4.25.1"],
"search": ["annlite>=0.3.10"],
"flash-attn": ["flash-attn==0.2.4"],
},
classifiers=[
'Development Status :: 5 - Production/Stable',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Unix Shell',
'Environment :: Console',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
'Topic :: Database :: Database Engines/Servers',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Internet :: WWW/HTTP :: Indexing/Search',
'Topic :: Scientific/Engineering :: Image Recognition',
'Topic :: Multimedia :: Video',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Mathematics',
'Topic :: Software Development',
'Topic :: Software Development :: Libraries',
'Topic :: Software Development :: Libraries :: Python Modules',
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Intended Audience :: Education",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Unix Shell",
"Environment :: Console",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
"Topic :: Database :: Database Engines/Servers",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Internet :: WWW/HTTP :: Indexing/Search",
"Topic :: Scientific/Engineering :: Image Recognition",
"Topic :: Multimedia :: Video",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Software Development",
"Topic :: Software Development :: Libraries",
"Topic :: Software Development :: Libraries :: Python Modules",
],
project_urls={
'Documentation': 'https://clip-as-service.jina.ai',
'Source': 'https://github.com/jina-ai/clip-as-service/',
'Tracker': 'https://github.com/jina-ai/clip-as-service/issues',
"Documentation": "https://clip-as-service.jina.ai",
"Source": "https://github.com/jina-ai/clip-as-service/",
"Tracker": "https://github.com/jina-ai/clip-as-service/issues",
},
keywords='jina openai clip deep-learning cross-modal multi-modal neural-search',
keywords="jina openai clip deep-learning cross-modal multi-modal neural-search",
)
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ def make_trt_flow(port_generator, request):
yield f


@pytest.fixture(scope='session', params=['tensorrt'])
def make_trt_flow_fp16(port_generator, request):
from clip_server.executors.clip_tensorrt import CLIPEncoder

f = Flow(port=port_generator()).add(
name=request.param, uses=CLIPEncoder, uses_with={'dtype': 'fp16'}
)
with f:
yield f


@pytest.fixture(params=['torch'])
def make_search_flow(tmpdir, port_generator, request):
from clip_server.executors.clip_torch import CLIPEncoder
Expand Down
33 changes: 33 additions & 0 deletions tests/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ def test_docarray_inputs(make_trt_flow, inputs):
assert inputs[0] is r[0]


@pytest.mark.gpu
@pytest.mark.parametrize(
'inputs',
[
[Document(text='hello, world'), Document(text='goodbye, world')],
DocumentArray([Document(text='hello, world'), Document(text='goodbye, world')]),
lambda: (Document(text='hello, world') for _ in range(10)),
DocumentArray(
[
Document(uri='https://docarray.jina.ai/_static/favicon.png'),
Document(
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
),
Document(text='hello, world'),
Document(
uri=f'{os.path.dirname(os.path.abspath(__file__))}/img/00000.jpg'
).load_uri_to_image_tensor(),
]
),
DocumentArray.from_files(
f'{os.path.dirname(os.path.abspath(__file__))}/**/*.jpg'
),
],
)
def test_docarray_inputs_fp16(make_trt_flow_fp16, inputs):
c = Client(server=f'grpc://0.0.0.0:{make_trt_flow_fp16.port}')
r = c.encode(inputs if not callable(inputs) else inputs())
assert isinstance(r, DocumentArray)
assert r.embeddings.shape
if hasattr(inputs, '__len__'):
assert inputs[0] is r[0]


@pytest.mark.gpu
@pytest.mark.asyncio
@pytest.mark.parametrize(
Expand Down