README.md (2 changes: 2 additions & 0 deletions)
@@ -40,6 +40,8 @@ docker build -t starlette-transformers:gpu -f dockerfiles/tensorflow/gpu/Dockerf
```bash
docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering starlette-transformers:cpu
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text starlette-transformers:gpu
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=echarlaix/tiny-random-stable-diffusion-xl -e HF_TASK=text-to-image starlette-transformers:gpu
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=stabilityai/stable-diffusion-xl-base-1.0 -e HF_TASK=text-to-image starlette-transformers:gpu
docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository starlette-transformers:cpu
```
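For reference, a minimal client sketch for the text-to-image containers above. It assumes the toolkit's default behavior of serving predictions on `POST /` at port 5000 with a JSON `{"inputs": ...}` payload and returning the generated image as raw bytes; adjust the route and payload if your deployment differs.

```python
import requests

# Hypothetical client: assumes the container started above listens on
# localhost:5000 and accepts a JSON payload on POST /.
response = requests.post(
    "http://localhost:5000/",
    json={"inputs": "a photo of an astronaut riding a horse on the moon"},
)
response.raise_for_status()

# For text-to-image the response body is expected to be the image bytes.
with open("generation.png", "wb") as f:
    f.write(response.content)
```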

dockerfiles/pytorch/cpu/Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ RUN apt-get update \
# install micromamba
ENV MAMBA_ROOT_PREFIX=/opt/conda
ENV PATH=/opt/conda/bin:$PATH
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
&& touch /root/.bashrc \
&& ./bin/micromamba shell init -s bash -p /opt/conda \
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc
dockerfiles/pytorch/cpu/environment.yaml (7 changes: 4 additions & 3 deletions)
@@ -5,8 +5,9 @@ dependencies:
- python=3.9.13
- pytorch::pytorch=1.13.1=py3.9_cpu_0
- pip:
- transformers[sklearn,sentencepiece,audio,vision]==4.27.2
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
- sentence_transformers==2.2.2
- torchvision==0.14.1
- diffusers==0.14.0
- accelerate==0.17.1
- diffusers==0.19.3
- accelerate==0.21.0
- safetensors
dockerfiles/pytorch/gpu/Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -25,7 +25,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/conda
ENV PATH=/opt/conda/bin:$PATH
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"

RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
&& touch /root/.bashrc \
&& ./bin/micromamba shell init -s bash -p /opt/conda \
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc
dockerfiles/pytorch/gpu/environment.yaml (5 changes: 3 additions & 2 deletions)
@@ -9,5 +9,6 @@ dependencies:
- transformers[sklearn,sentencepiece,audio,vision]==4.31.0
- sentence_transformers==2.2.2
- torchvision==0.14.1
- diffusers==0.18.2
- accelerate==0.21.0
- diffusers==0.19.3
- accelerate==0.21.0
- safetensors
dockerfiles/tensorflow/cpu/Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -23,7 +23,7 @@ RUN apt-get update \
# install micromamba
ENV MAMBA_ROOT_PREFIX=/opt/conda
ENV PATH=/opt/conda/bin:$PATH
RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
&& touch /root/.bashrc \
&& ./bin/micromamba shell init -s bash -p /opt/conda \
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc
dockerfiles/tensorflow/gpu/Dockerfile (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/conda
ENV PATH=/opt/conda/bin:$PATH
ENV LD_LIBRARY_PATH="/opt/conda/lib:${LD_LIBRARY_PATH}"

RUN curl -L https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
RUN curl -L https://micro.mamba.pm/api/micromamba/linux-64/latest | tar -xj "bin/micromamba" \
&& touch /root/.bashrc \
&& ./bin/micromamba shell init -s bash -p /opt/conda \
&& grep -v '[ -z "\$PS1" ] && return' /root/.bashrc > /opt/conda/bashrc
src/huggingface_inference_toolkit/diffusers_utils.py (66 changes: 27 additions & 39 deletions)
@@ -1,6 +1,8 @@
import importlib.util
import json
import os
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)

_diffusers = importlib.util.find_spec("diffusers") is not None

@@ -11,60 +13,46 @@ def is_diffusers_available():

if is_diffusers_available():
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline


def check_supported_pipeline(model_dir):
try:
with open(os.path.join(model_dir, "model_index.json")) as json_file:
data = json.load(json_file)
if data["_class_name"] == "StableDiffusionPipeline":
return True
else:
return False
except Exception:
return False
from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline


class DiffusersPipelineImageToText:
class IEAutoPipelineForText2Image:
def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
self.pipeline = StableDiffusionPipeline.from_pretrained(model_dir, torch_dtype=torch.float16)
dtype = torch.float32
if device == "cuda":
dtype = torch.float16
device_map = "auto" if device == "cuda" else None

self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
# try to use DPMSolverMultistepScheduler
try:
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
except Exception:
pass
if isinstance(self.pipeline, StableDiffusionPipeline):
try:
self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config)
except Exception:
pass
self.pipeline.to(device)

def __call__(
self,
prompt,
num_inference_steps=25,
guidance_scale=7.5,
num_images_per_prompt=1,
height=None,
width=None,
negative_prompt=None,
**kwargs,
):
# TODO: add support for more images (Reason is correct output)
num_images_per_prompt = 1
if "num_images_per_prompt" in kwargs:
kwargs.pop("num_images_per_prompt")
logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")

# Call pipeline with parameters
out = self.pipeline(
prompt,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
num_images_per_prompt=num_images_per_prompt,
negative_prompt=negative_prompt,
height=height,
width=width,
)

if self.pipeline.device.type == "cuda":
with torch.autocast("cuda"):
out = self.pipeline(prompt, num_images_per_prompt=1)
else:
out = self.pipeline(prompt, num_images_per_prompt=1)
return out.images[0]


DIFFUSERS_TASKS = {
"text-to-image": DiffusersPipelineImageToText,
"text-to-image": IEAutoPipelineForText2Image,
}
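To make the new wrapper concrete, a short usage sketch (the model directory is a placeholder; in the toolkit the class is constructed for you by `get_diffusers_pipeline`, defined further down this file):

```python
from huggingface_inference_toolkit.diffusers_utils import IEAutoPipelineForText2Image

# On device="cuda" the wrapper loads fp16 weights with device_map="auto";
# on CPU it falls back to fp32. "/repository" is a hypothetical model dir.
pipe = IEAutoPipelineForText2Image("/repository", device="cuda")

# num_images_per_prompt is forced to 1, so a single PIL image comes back.
image = pipe("a watercolor painting of a lighthouse at dawn")
image.save("out.png")
```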


src/huggingface_inference_toolkit/utils.py (32 changes: 14 additions & 18 deletions)
@@ -4,14 +4,13 @@
from pathlib import Path
from typing import Optional, Union

from huggingface_hub import login, snapshot_download
from huggingface_hub import HfApi, login, snapshot_download
from transformers import WhisperForConditionalGeneration, pipeline
from transformers.file_utils import is_tf_available, is_torch_available
from transformers.pipelines import Conversation, Pipeline

from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
from huggingface_inference_toolkit.diffusers_utils import (
check_supported_pipeline,
get_diffusers_pipeline,
is_diffusers_available,
)
@@ -46,11 +45,12 @@ def is_optimum_available():
"pt": "pytorch*",
"flax": "flax*",
"rust": "rust*",
"onnx": "*onnx",
"onnx": "*onnx*",
"safetensors": "*safetensors",
"coreml": "*mlmodel",
"tflite": "*tflite",
"savedmodel": "*tar.gz",
"openvino": "*openvino*",
"ckpt": "*ckpt",
}

@@ -59,18 +59,8 @@ def create_artifact_filter(framework):
"""
Returns a list of regex patterns based on the DL framework, which will be used to ignore files when downloading
"""
ignore_regex_list = [
"pytorch*",
"tf*",
"flax*",
"rust*",
"*onnx",
"*safetensors",
"*mlmodel",
"*tflite",
"*tar.gz",
"*ckpt",
]
ignore_regex_list = list(set(framework2weight.values()))

pattern = framework2weight.get(framework, None)
if pattern in ignore_regex_list:
ignore_regex_list.remove(pattern)
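For intuition, a quick check of what the rewritten filter produces (a sketch; it assumes the function still returns `ignore_regex_list` as before):

```python
# set() collapses duplicate patterns from aliases like "pt"/"pytorch",
# then the requested framework's own pattern is removed from the ignore list.
ignore = create_artifact_filter("safetensors")
assert "*safetensors" not in ignore  # keep the weights we asked for
assert "pytorch*" in ignore          # every other weight format is skipped
```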
@@ -157,6 +147,12 @@ def _load_repository_from_hf(
if not target_dir.exists():
target_dir.mkdir(parents=True)

# check if safetensors weights are available
if framework == "pytorch":
files = HfApi().model_info(repository_id).siblings
if any(f.rfilename.endswith("safetensors") for f in files):
framework = "safetensors"

# create regex to only include the framework specific weights
ignore_regex = create_artifact_filter(framework)
logger.info(f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }")
@@ -259,7 +255,7 @@ def get_pipeline(task: str, model_dir: Path, **kwargs) -> Pipeline:
"sentence-ranking",
]:
hf_pipeline = get_sentence_transformers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
elif is_diffusers_available() and check_supported_pipeline(model_dir) and task == "text-to-image":
elif is_diffusers_available() and task == "text-to-image":
hf_pipeline = get_diffusers_pipeline(task=task, model_dir=model_dir, device=device, **kwargs)
else:
hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
@@ -287,8 +283,8 @@ def convert_params_to_int_or_bool(params):
for k, v in params.items():
if v.isnumeric():
params[k] = int(v)
if v == 'false':
if v == "false":
params[k] = False
if v == 'true':
if v == "true":
params[k] = True
return params
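For example, query parameters arriving as strings are coerced in place:

```python
params = {"num_inference_steps": "25", "use_cache": "false", "stream": "true"}
convert_params_to_int_or_bool(params)
# -> {"num_inference_steps": 25, "use_cache": False, "stream": True}
```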
tests/unit/test_diffusers.py (6 changes: 3 additions & 3 deletions)
@@ -3,8 +3,8 @@
from PIL import Image
from transformers.testing_utils import require_torch, slow

from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler
from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, DiffusersPipelineImageToText

from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, IEAutoPipelineForText2Image
from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline


@@ -15,7 +15,7 @@ def test_get_diffusers_pipeline():
"hf-internal-testing/tiny-stable-diffusion-torch", tmpdirname, framework="pytorch"
)
pipe = get_pipeline("text-to-image", storage_dir.as_posix())
assert isinstance(pipe, DiffusersPipelineImageToText)
assert isinstance(pipe, IEAutoPipelineForText2Image)


@slow