2 changes: 1 addition & 1 deletion .github/workflows/gpu-integ-test.yaml
@@ -69,7 +69,7 @@ jobs:
  tensorflow-integration-test:
    needs:
      - start-runner
      # - pytorch-integration-test
      - pytorch-integration-test
    runs-on: ${{ needs.start-runner.outputs.label }} # run the job on the newly created runner
    env:
      AWS_REGION: us-east-1
1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ docker build -t starlette-transformers:gpu -f dockerfiles/starlette/tensorflow/D
docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering starlette-transformers:cpu
```


3. Send request. The API schema is the same as the [inference API](https://huggingface.co/docs/api-inference/detailed_parameters).

```bash
# … (request example elided in this diff view)
```
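For reference, the same kind of request from Python (a sketch: the host/port follow the `docker run` above, and the payload shape mirrors the hosted inference API for question answering; the route path is an assumption):

```python
import requests

# Hypothetical question-answering request against the locally running container.
payload = {
    "inputs": {
        "question": "What is used as the model server?",
        "context": "The toolkit uses Starlette and Uvicorn to serve transformers models.",
    }
}
response = requests.post("http://localhost:5000", json=payload)
print(response.json())
```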
4 changes: 2 additions & 2 deletions dockerfiles/starlette/pytorch/Dockerfile.cpu
@@ -4,8 +4,8 @@ FROM huggingface/transformers-inference:4.24.0-pt1.13-cpu
COPY starlette_requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt

# Think about a better solution -> base container has PyTorch 1.11, so torchvision must stay below 0.13
RUN pip install --no-cache-dir sentence_transformers torchvision~="0.12.0"
# Think about a better solution -> base container has PyTorch 1.13, so torchvision must stay on the matching 0.14 series
RUN pip install --no-cache-dir sentence_transformers torchvision~="0.14.0" diffusers=="0.8.1" accelerate=="0.14.0"

# copy application
COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
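As a sanity check on the version pairing the comment describes, something like the following could run inside the built image (a sketch; torchvision 0.14.x is the release series matched to PyTorch 1.13.x):

```python
# Version sanity check (sketch): torch 1.13.x pairs with torchvision 0.14.x.
import torch
import torchvision

assert torch.__version__.startswith("1.13"), torch.__version__
assert torchvision.__version__.startswith("0.14"), torchvision.__version__
```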
7 changes: 4 additions & 3 deletions dockerfiles/starlette/pytorch/Dockerfile.gpu
@@ -4,12 +4,13 @@ FROM huggingface/transformers-inference:4.24.0-pt1.13-cuda11.6
COPY starlette_requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt && rm /tmp/requirements.txt

# Think about a better solution -> base container has PyTorch 1.11, so torchvision must stay below 0.13
RUN pip install --no-cache-dir sentence_transformers torchvision~="0.12.0"
# Think about a better solution -> base container has PyTorch 1.13, so torchvision must stay on the matching 0.14 series
RUN pip install --no-cache-dir sentence_transformers torchvision~="0.14.0" diffusers=="0.8.1" accelerate=="0.14.0"

# copy application
COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py

# run app
ENTRYPOINT ["uvicorn", "webservice_starlette:app", "--host", "0.0.0.0", "--port", "5000"]
ENTRYPOINT ["uvicorn", "webservice_starlette:app", "--host", "0.0.0.0", "--port", "5000"]

3 changes: 2 additions & 1 deletion setup.py
@@ -31,12 +31,13 @@
extras = {}

extras["st"] = ["sentence_transformers"]
extras["diffusers"] = ["diffusers==0.8.1", "accelerate==0.14.0"]


# Hugging Face specific dependencies
# framework specific dependencies
extras["torch"] = ["torch>=1.8.0", "torchaudio"]
extras["tensorflow"] = ["tensorflow>=2.4.0"]
extras["tensorflow"] = ["tensorflow==2.9.0"]
# test and quality
extras["test"] = [
"pytest",
56 changes: 56 additions & 0 deletions src/huggingface_inference_toolkit/diffusers_utils.py
@@ -0,0 +1,56 @@
import importlib.util
import json
import os


_diffusers = importlib.util.find_spec("diffusers") is not None


def is_diffusers_available():
    return _diffusers


if is_diffusers_available():
    import torch

    from diffusers import StableDiffusionPipeline


def check_supported_pipeline(model_dir):
    try:
        with open(os.path.join(model_dir, "model_index.json")) as json_file:
            data = json.load(json_file)
        if data["_class_name"] == "StableDiffusionPipeline":
            return True
        else:
            return False
    except Exception:
        return False


class DiffusersPipelineImageToText:
    def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
        self.pipeline = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float16)
        self.pipeline.to(device)

    def __call__(self, prompt, **kwargs):
        if kwargs:
            out = self.pipeline(prompt, **kwargs)
        else:
            out = self.pipeline(prompt)

        # TODO: return more than 1 image if requested
        return out.images[0]


DIFFUSERS_TASKS = {
    "text-to-image": DiffusersPipelineImageToText,
}


def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
    """Get a pipeline for Diffusers models."""
    device = "cuda" if device == 0 else "cpu"
    pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device)
    return pipeline
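A minimal usage sketch for the new module (the model directory is hypothetical; it assumes a local Stable Diffusion snapshot containing `model_index.json`):

```python
from huggingface_inference_toolkit.diffusers_utils import (
    check_supported_pipeline,
    get_diffusers_pipeline,
)

model_dir = "/opt/ml/model"  # hypothetical local snapshot directory

# Only StableDiffusionPipeline checkpoints are dispatched to diffusers.
if check_supported_pipeline(model_dir):
    pipe = get_diffusers_pipeline(task="text-to-image", model_dir=model_dir, device=0)
    image = pipe("a man on a horse jumps over a broken down airplane.")
    image.save("out.png")  # __call__ returns a single PIL.Image.Image
```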
2 changes: 1 addition & 1 deletion src/huggingface_inference_toolkit/serialization/audio_utils.py
@@ -4,5 +4,5 @@ def deserialize(body):
        return {"inputs": bytes(body)}

    @staticmethod
    def serialize(body):
    def serialize(body, accept=None):
        raise NotImplementedError("Audio serialization not implemented")
9 changes: 9 additions & 0 deletions src/huggingface_inference_toolkit/serialization/base.py
@@ -46,3 +46,12 @@ def get_deserializer(content_type):
            raise Exception(
                f'Content type "{content_type}" not supported. Supported content types are: {", ".join(list(content_type_mapping.keys()))}'
            )

    @staticmethod
    def get_serializer(accept):
        if accept in content_type_mapping:
            return content_type_mapping[accept]
        else:
            raise Exception(
                f'Accept type "{accept}" not supported. Supported accept types are: {", ".join(list(content_type_mapping.keys()))}'
            )
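A short sketch of the intended lookup (assuming `content_type_mapping` maps MIME types such as `application/json` and `image/png` to the serializer classes, as on the existing deserialization path):

```python
# Hypothetical usage: resolve a serializer from the request's Accept header.
accept = "application/json"
serializer = ContentType.get_serializer(accept)  # raises on unsupported types
body = serializer.serialize({"generated_text": "hello"}, accept)
```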
10 changes: 8 additions & 2 deletions src/huggingface_inference_toolkit/serialization/image_utils.py
@@ -10,5 +10,11 @@ def deserialize(body):
        return {"inputs": image}

    @staticmethod
    def serialize(body):
        raise NotImplementedError("Image serialization not implemented")
    def serialize(image, accept=None):
        if isinstance(image, Image.Image):
            img_byte_arr = BytesIO()
            image.save(img_byte_arr, format=accept.split("/")[-1].upper())
            img_byte_arr = img_byte_arr.getvalue()
            return img_byte_arr
        else:
            raise ValueError(f"Can only serialize PIL.Image.Image, got {type(image)}")
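A round-trip sketch for the new serializer (the import path for `Imager` is an assumption based on the module layout):

```python
from io import BytesIO

from PIL import Image

from huggingface_inference_toolkit.serialization.image_utils import Imager  # assumed path

# PIL image -> PNG bytes; the format is derived from the MIME subtype.
img = Image.new("RGB", (64, 64), color="blue")
png_bytes = Imager.serialize(img, accept="image/png")
assert Image.open(BytesIO(png_bytes)).size == (64, 64)
```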
2 changes: 1 addition & 1 deletion src/huggingface_inference_toolkit/serialization/json_utils.py
@@ -20,7 +20,7 @@ def deserialize(body):
        return orjson.loads(body)

    @staticmethod
    def serialize(body):
    def serialize(body, accept=None):
        return orjson.dumps(body, option=orjson.OPT_SERIALIZE_NUMPY, default=default)
13 changes: 11 additions & 2 deletions src/huggingface_inference_toolkit/utils.py
@@ -5,14 +5,19 @@
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import HfApi
from huggingface_hub import HfApi, login
from huggingface_hub.file_download import cached_download, hf_hub_url
from huggingface_hub.utils import filter_repo_objects
from transformers import pipeline
from transformers.file_utils import is_tf_available, is_torch_available
from transformers.pipelines import Conversation, Pipeline

from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
from huggingface_inference_toolkit.diffusers_utils import (
    check_supported_pipeline,
    get_diffusers_pipeline,
    is_diffusers_available,
)
from huggingface_inference_toolkit.sentence_transformers_utils import (
    get_sentence_transformers_pipeline,
    is_sentence_transformers_available,
@@ -127,6 +132,9 @@ def _load_repository_from_hf(
    """
    Load a model from huggingface hub.
    """
    if hf_hub_token is not None:
        login(token=hf_hub_token)

    if framework is None:
        framework = _get_framework()

@@ -146,7 +154,6 @@
        repo_id=repository_id,
        repo_type="model",
        revision=revision,
        token=hf_hub_token,
    )
    # apply regex to filter out non-framework specific weights if args.framework is set
    filtered_repo_files = filter_repo_objects(
@@ -267,6 +274,8 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
        "sentence-ranking",
    ]:
        hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
    elif is_diffusers_available() and check_supported_pipeline(model_dir) and task == "text-to-image":
        hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
    else:
        hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
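With the `login()` change above, the token is applied once globally instead of per download call; a hedged sketch of a call site (the repo is the public test model used in the integration tests, the target directory is hypothetical):

```python
from huggingface_inference_toolkit.utils import _load_repository_from_hf

storage_dir = _load_repository_from_hf(
    "hf-internal-testing/tiny-stable-diffusion-torch",  # public test repo
    "/tmp/model",                                       # hypothetical target dir
    framework="pytorch",
    # hf_hub_token="hf_...",  # now triggers huggingface_hub.login() for private repos
)
print(storage_dir)
```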
27 changes: 7 additions & 20 deletions src/huggingface_inference_toolkit/webservice_starlette.py
@@ -81,8 +81,14 @@ async def predict(request):
        # log request time
        # TODO: replace with middleware
        logger.info(f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms")

        # read the Accept header from the request to pick the response content type
        accept = request.headers.get("accept", None)
        if accept is None or accept == "*/*":
            accept = "application/json"
        # serialize the prediction and respond with json
        return Response(Jsoner.serialize(pred), media_type="application/json")
        serialized_response_body = ContentType.get_serializer(accept).serialize(pred, accept)
        return Response(serialized_response_body, media_type=accept)
    except Exception as e:
        logger.error(e)
        return Response(Jsoner.serialize({"error": str(e)}), status_code=400, media_type="application/json")
@@ -98,22 +104,3 @@ async def predict(request):
    ],
    on_startup=[some_startup_task],
)


# for pegasus it was async
# 1.2rps at 20 with 17s latency
# 1rps at 1 user with 930ms latency

# for pegasus it was sync
# 1.2rps at 20 with 17s latency
# 1rps at 1 user with 980ms latency
# health is blocking with 17s latency


# for tiny it was async
# 107.7rps at 500 with 4.7s latency
# 8.5rps at 1 user with 120ms latency

# for tiny it was sync
# 109rps at 500 with 4.6s latency
# 8.5rps at 1 user with 120ms latency
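On the client side, the accept-header negotiation above means the same endpoint can now return raw image bytes; a sketch (the local URL and root path are assumptions matching the integration tests):

```python
import requests

# Hypothetical local request: ask for PNG bytes instead of JSON.
response = requests.post(
    "http://localhost:5000",
    json={"inputs": "a man on a horse jumps over a broken down airplane."},
    headers={"accept": "image/png"},
)
with open("prediction.png", "wb") as f:
    f.write(response.content)
```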
2 changes: 1 addition & 1 deletion starlette_requirements.txt
@@ -2,4 +2,4 @@ orjson
starlette
uvicorn
pandas
huggingface_hub>=0.9.0
huggingface_hub>=0.11.0
8 changes: 8 additions & 0 deletions tests/integ/config.py
@@ -11,6 +11,7 @@
    validate_summarization,
    validate_text2text_generation,
    validate_text_generation,
    validate_text_to_image,
    validate_translation,
    validate_zero_shot_classification,
)
@@ -101,6 +102,10 @@
        "pytorch": "cross-encoder/ms-marco-MiniLM-L-6-v2",
        "tensorflow": None,
    },
    "text-to-image": {
        "pytorch": "hf-internal-testing/tiny-stable-diffusion-torch",
        "tensorflow": None,
    },
}

@@ -156,6 +161,7 @@
    },
    "sentence-embeddings": {"inputs": "Lets create an embedding"},
    "sentence-ranking": {"inputs": ["Lets create an embedding", "Lets create an embedding"]},
    "text-to-image": {"inputs": "a man on a horse jumps over a broken down airplane."},
}

task2output = {
@@ -204,6 +210,7 @@
    "sentence-similarity": {"similarities": ""},
    "sentence-embeddings": {"embeddings": ""},
    "sentence-ranking": {"scores": ""},
    "text-to-image": bytes,
}

@@ -229,4 +236,5 @@
    "sentence-similarity": validate_zero_shot_classification,
    "sentence-embeddings": validate_zero_shot_classification,
    "sentence-ranking": validate_zero_shot_classification,
    "text-to-image": validate_text_to_image,
}
7 changes: 7 additions & 0 deletions tests/integ/test_container.py
@@ -58,6 +58,8 @@ def verify_task(container: DockerClient, task: str, port: int = 5000, framework:
        prediction = requests.post(
            f"{BASE_URL}", data=task2input[task], headers={"content-type": "audio/x-audio"}
        ).json()
    elif task == "text-to-image":
        prediction = requests.post(f"{BASE_URL}", json=input, headers={"accept": "image/png"}).content
    else:
        prediction = requests.post(f"{BASE_URL}", json=input).json()
    assert task2validation[task](result=prediction, snapshot=task2output[task]) is True
@@ -90,6 +92,8 @@ def verify_task(container: DockerClient, task: str, port: int = 5000, framework:
        "sentence-similarity",
        "sentence-embeddings",
        "sentence-ranking",
        # diffusers
        "text-to-image",
    ],
)
def test_pt_container_remote_model(task) -> None:
@@ -111,6 +115,7 @@ def test_pt_container_remote_model(task) -> None:
        device_requests=device_request,
    )
    # time.sleep(5)

    verify_task(container, task, port)
    container.stop()
    container.remove()
@@ -143,6 +148,8 @@ def test_pt_container_remote_model(task) -> None:
        "sentence-similarity",
        "sentence-embeddings",
        "sentence-ranking",
        # diffusers
        "text-to-image",
    ],
)
def test_pt_container_local_model(task) -> None:
5 changes: 5 additions & 0 deletions tests/integ/utils.py
@@ -79,3 +79,8 @@ def validate_automatic_speech_recognition(result=None, snapshot=None):
def validate_object_detection(result=None, snapshot=None):
    assert result[0].keys() == snapshot[0].keys()
    return True


def validate_text_to_image(result=None, snapshot=None):
    assert isinstance(result, snapshot)
    return True
40 changes: 40 additions & 0 deletions tests/unit/test_diffusers.py
@@ -0,0 +1,40 @@
import os
import tempfile

from PIL import Image
from transformers.testing_utils import require_torch, slow

from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler
from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, DiffusersPipelineImageToText
from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline


@require_torch
def test_get_diffusers_pipeline():
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_dir = _load_repository_from_hf(
            "hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname, framework="pytorch"
        )
        pipe = get_pipeline("text-to-image", storage_dir.as_posix())
        assert isinstance(pipe, DiffusersPipelineImageToText)


@slow
@require_torch
def test_pipe_on_gpu():
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_dir = _load_repository_from_hf(
            "hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname, framework="pytorch"
        )
        pipe = get_pipeline("text-to-image", storage_dir.as_posix())
        assert pipe.device.type == "cuda"


@require_torch
def test_text_to_image_task():
    with tempfile.TemporaryDirectory() as tmpdirname:
        storage_dir = _load_repository_from_hf(
            "hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname, framework="pytorch"
        )
        pipe = get_pipeline("text-to-image", storage_dir.as_posix())
        res = pipe("Lets create an embedding")
        assert isinstance(res, Image.Image)
5 changes: 5 additions & 0 deletions tests/unit/test_serializer.py
@@ -21,6 +21,11 @@ def test_json_image_serialization():
    Jsoner.serialize(t)


def test_image_serialization():
    image = Image.new("RGB", (60, 30), color="red")
    Imager.serialize(image, accept="image/png")


def test_json_deserialization():
    raw_content = b'{\n\t"inputs": "i like you"\n}'
    assert {"inputs": "i like you"} == Jsoner.deserialize(raw_content)