2 changes: 1 addition & 1 deletion dockerfiles/pytorch/cpu/environment.yaml
@@ -8,6 +8,6 @@ dependencies:
   - transformers[sklearn,sentencepiece,audio,vision]==4.31.0
   - sentence_transformers==2.2.2
   - torchvision==0.14.1
-  - diffusers==0.19.3
+  - diffusers==0.20.0
   - accelerate==0.21.0
   - safetensors
2 changes: 1 addition & 1 deletion dockerfiles/pytorch/gpu/environment.yaml
@@ -9,6 +9,6 @@ dependencies:
   - transformers[sklearn,sentencepiece,audio,vision]==4.31.0
   - sentence_transformers==2.2.2
   - torchvision==0.14.1
-  - diffusers==0.19.3
+  - diffusers==0.20.0
   - accelerate==0.21.0
   - safetensors
10 changes: 4 additions & 6 deletions src/huggingface_inference_toolkit/diffusers_utils.py
@@ -1,6 +1,8 @@
 import importlib.util
 import logging
 
+from transformers.utils.import_utils import is_torch_bf16_gpu_available
+
 logger = logging.getLogger(__name__)
 logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO)
 
@@ -20,7 +22,7 @@ class IEAutoPipelineForText2Image:
     def __init__(self, model_dir: str, device: str = None):  # needs "cuda" for GPU
         dtype = torch.float32
         if device == "cuda":
-            dtype = torch.float16
+            dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
         device_map = "auto" if device == "cuda" else None
 
         self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map)
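
For context, the dtype selection this hunk introduces, shown as a standalone sketch. `is_torch_bf16_gpu_available` is the same helper imported in the diff above; the surrounding scaffolding is illustrative only:

```python
import torch
from transformers.utils.import_utils import is_torch_bf16_gpu_available

# bfloat16 keeps float32's exponent range, so it avoids the overflow/NaN
# issues float16 can hit in diffusion UNets; fall back to float16 on
# older GPUs where bf16 is not supported.
dtype = torch.float32
if torch.cuda.is_available():
    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
print(f"selected dtype: {dtype}")
```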
@@ -43,11 +45,7 @@ def __call__(
             logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.")
 
         # Call pipeline with parameters
-        if self.pipeline.device.type == "cuda":
-            with torch.autocast("cuda"):
-                out = self.pipeline(prompt, num_images_per_prompt=1)
-        else:
-            out = self.pipeline(prompt, num_images_per_prompt=1)
+        out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs)
         return out.images[0]
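
With the autocast branch gone, `**kwargs` are now forwarded straight to the underlying diffusers pipeline. A hedged usage sketch — the checkpoint path and parameter values are made up; only `IEAutoPipelineForText2Image` and its kwargs forwarding come from this diff:

```python
from huggingface_inference_toolkit.diffusers_utils import IEAutoPipelineForText2Image

# Hypothetical local model directory containing a text-to-image checkpoint.
pipe = IEAutoPipelineForText2Image("/opt/model", device="cuda")

# Extra keyword arguments such as num_inference_steps and guidance_scale
# now pass through to the AutoPipelineForText2Image call.
image = pipe("a watercolor fox", num_inference_steps=25, guidance_scale=7.5)
image.save("fox.png")
```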


1 change: 0 additions & 1 deletion src/huggingface_inference_toolkit/webservice_starlette.py
@@ -75,7 +75,6 @@ async def predict(request):
     # check for query parameter and add them to the body
     if request.query_params and "parameters" not in deserialized_body:
         deserialized_body["parameters"] = convert_params_to_int_or_bool(dict(request.query_params))
-    print(deserialized_body)
 
     # tracks request time
     start_time = perf_counter()
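
For reference, query parameters always arrive as strings, so the route coerces them before merging into the request body. A minimal sketch of what `convert_params_to_int_or_bool` is assumed to do (the real helper lives in the toolkit's utils; this re-implementation is an assumption, not its actual code):

```python
def convert_params_to_int_or_bool(params: dict) -> dict:
    """Coerce string query parameters to int or bool where possible (assumed behavior)."""
    converted = {}
    for key, value in params.items():
        if value.lstrip("-").isdigit():
            converted[key] = int(value)  # "25" -> 25
        elif value.lower() in ("true", "false"):
            converted[key] = value.lower() == "true"  # "false" -> False
        else:
            converted[key] = value  # leave other strings as-is
    return converted

# GET /predict?num_inference_steps=25&use_safety_checker=false
# -> {"num_inference_steps": 25, "use_safety_checker": False}
```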