Skip to content

Commit e2f8bca

Browse files
committed
add prompt embeds param to other ONNX pipelines
1 parent 2237975 commit e2f8bca

File tree

4 files changed

+276
-41
lines changed

4 files changed

+276
-41
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,48 @@ def _encode_prompt(
264264

265265
return prompt_embeds
266266

267+
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the prompt-related arguments before running the pipeline.

    Args:
        prompt: Text prompt as a `str` or `list`, or `None` when
            `prompt_embeds` is supplied instead.
        callback_steps: Must be a positive `int`.
        negative_prompt: Only checked for mutual exclusivity with
            `negative_prompt_embeds`; its type is not validated here.
        prompt_embeds: Pre-computed prompt embeddings. Only its `.shape`
            attribute is read by this method.
        negative_prompt_embeds: Pre-computed negative embeddings. Only its
            `.shape` attribute is read by this method.

    Raises:
        ValueError: If `callback_steps` is not a positive integer, if both or
            neither of `prompt` / `prompt_embeds` are given, if `prompt` has
            an unsupported type, if both `negative_prompt` and
            `negative_prompt_embeds` are given, or if the two embeddings
            have different shapes.
    """
    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # The negative prompt is optional, but may be given in one form only.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
308+
267309
def __call__(
268310
self,
269311
prompt: Union[str, List[str]],
@@ -275,6 +317,8 @@ def __call__(
275317
num_images_per_prompt: Optional[int] = 1,
276318
eta: Optional[float] = 0.0,
277319
generator: Optional[np.random.RandomState] = None,
320+
prompt_embeds: Optional[np.ndarray] = None,
321+
negative_prompt_embeds: Optional[np.ndarray] = None,
278322
output_type: Optional[str] = "pil",
279323
return_dict: bool = True,
280324
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -314,6 +358,13 @@ def __call__(
314358
[`schedulers.DDIMScheduler`], will be ignored for others.
315359
generator (`np.random.RandomState`, *optional*):
316360
A np.random.RandomState to make generation deterministic.
361+
prompt_embeds (`np.ndarray`, *optional*):
362+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
363+
provided, text embeddings will be generated from `prompt` input argument.
364+
negative_prompt_embeds (`np.ndarray`, *optional*):
365+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
366+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
367+
argument.
317368
output_type (`str`, *optional*, defaults to `"pil"`):
318369
The output format of the generated image. Choose between
319370
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -334,24 +385,21 @@ def __call__(
334385
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
335386
(nsfw) content, according to the `safety_checker`.
336387
"""
337-
if isinstance(prompt, str):
388+
389+
# check inputs. Raise error if not correct
390+
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
391+
392+
# define call parameters
393+
if prompt is not None and isinstance(prompt, str):
338394
batch_size = 1
339-
elif isinstance(prompt, list):
395+
elif prompt is not None and isinstance(prompt, list):
340396
batch_size = len(prompt)
341397
else:
342-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
398+
batch_size = prompt_embeds.shape[0]
343399

344400
if strength < 0 or strength > 1:
345401
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
346402

347-
if (callback_steps is None) or (
348-
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
349-
):
350-
raise ValueError(
351-
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
352-
f" {type(callback_steps)}."
353-
)
354-
355403
if generator is None:
356404
generator = np.random
357405

@@ -366,7 +414,12 @@ def __call__(
366414
do_classifier_free_guidance = guidance_scale > 1.0
367415

368416
prompt_embeds = self._encode_prompt(
369-
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
417+
prompt,
418+
num_images_per_prompt,
419+
do_classifier_free_guidance,
420+
negative_prompt,
421+
prompt_embeds=prompt_embeds,
422+
negative_prompt_embeds=negative_prompt_embeds,
370423
)
371424

372425
latents_dtype = prompt_embeds.dtype

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,54 @@ def _encode_prompt(
265265

266266
return prompt_embeds
267267

268+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline.check_inputs
269+
def check_inputs(
    self,
    prompt,
    height,
    width,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the size and prompt-related arguments before running the pipeline.

    NOTE(review): this method sits under a `# Copied from ...` marker; confirm
    whether the referenced source method must be updated in step with it.

    Args:
        prompt: Text prompt as a `str` or `list`, or `None` when
            `prompt_embeds` is supplied instead.
        height: Must be divisible by 8.
        width: Must be divisible by 8.
        callback_steps: Must be a positive `int`.
        negative_prompt: Only checked for mutual exclusivity with
            `negative_prompt_embeds`; its type is not validated here.
        prompt_embeds: Pre-computed prompt embeddings. Only its `.shape`
            attribute is read by this method.
        negative_prompt_embeds: Pre-computed negative embeddings. Only its
            `.shape` attribute is read by this method.

    Raises:
        ValueError: If `height` or `width` is not divisible by 8, if
            `callback_steps` is not a positive integer, if both or neither of
            `prompt` / `prompt_embeds` are given, if `prompt` has an
            unsupported type, if both `negative_prompt` and
            `negative_prompt_embeds` are given, or if the two embeddings
            have different shapes.
    """
    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # The negative prompt is optional, but may be given in one form only.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
315+
268316
@torch.no_grad()
269317
def __call__(
270318
self,
@@ -280,6 +328,8 @@ def __call__(
280328
eta: float = 0.0,
281329
generator: Optional[np.random.RandomState] = None,
282330
latents: Optional[np.ndarray] = None,
331+
prompt_embeds: Optional[np.ndarray] = None,
332+
negative_prompt_embeds: Optional[np.ndarray] = None,
283333
output_type: Optional[str] = "pil",
284334
return_dict: bool = True,
285335
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -326,6 +376,13 @@ def __call__(
326376
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
327377
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
328378
tensor will be generated by sampling using the supplied random `generator`.
379+
prompt_embeds (`np.ndarray`, *optional*):
380+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
381+
provided, text embeddings will be generated from `prompt` input argument.
382+
negative_prompt_embeds (`np.ndarray`, *optional*):
383+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
384+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
385+
argument.
329386
output_type (`str`, *optional*, defaults to `"pil"`):
330387
The output format of the generated image. Choose between
331388
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -347,23 +404,18 @@ def __call__(
347404
(nsfw) content, according to the `safety_checker`.
348405
"""
349406

350-
if isinstance(prompt, str):
407+
# check inputs. Raise error if not correct
408+
self.check_inputs(
409+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
410+
)
411+
412+
# define call parameters
413+
if prompt is not None and isinstance(prompt, str):
351414
batch_size = 1
352-
elif isinstance(prompt, list):
415+
elif prompt is not None and isinstance(prompt, list):
353416
batch_size = len(prompt)
354417
else:
355-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
356-
357-
if height % 8 != 0 or width % 8 != 0:
358-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
359-
360-
if (callback_steps is None) or (
361-
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
362-
):
363-
raise ValueError(
364-
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
365-
f" {type(callback_steps)}."
366-
)
418+
batch_size = prompt_embeds.shape[0]
367419

368420
if generator is None:
369421
generator = np.random
@@ -377,7 +429,12 @@ def __call__(
377429
do_classifier_free_guidance = guidance_scale > 1.0
378430

379431
prompt_embeds = self._encode_prompt(
380-
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
432+
prompt,
433+
num_images_per_prompt,
434+
do_classifier_free_guidance,
435+
negative_prompt,
436+
prompt_embeds=prompt_embeds,
437+
negative_prompt_embeds=negative_prompt_embeds,
381438
)
382439

383440
num_channels_latents = NUM_LATENT_CHANNELS

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py

Lines changed: 65 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,48 @@ def _encode_prompt(
250250

251251
return prompt_embeds
252252

253+
def check_inputs(
    self,
    prompt,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
):
    """Validate the prompt-related arguments before running the pipeline.

    Args:
        prompt: Text prompt as a `str` or `list`, or `None` when
            `prompt_embeds` is supplied instead.
        callback_steps: Must be a positive `int`.
        negative_prompt: Only checked for mutual exclusivity with
            `negative_prompt_embeds`; its type is not validated here.
        prompt_embeds: Pre-computed prompt embeddings. Only its `.shape`
            attribute is read by this method.
        negative_prompt_embeds: Pre-computed negative embeddings. Only its
            `.shape` attribute is read by this method.

    Raises:
        ValueError: If `callback_steps` is not a positive integer, if both or
            neither of `prompt` / `prompt_embeds` are given, if `prompt` has
            an unsupported type, if both `negative_prompt` and
            `negative_prompt_embeds` are given, or if the two embeddings
            have different shapes.
    """
    # `None` is not an `int`, so this single test also rejects a missing value.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one of `prompt` / `prompt_embeds` must be provided.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # The negative prompt is optional, but may be given in one form only.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
294+
253295
def __call__(
254296
self,
255297
prompt: Union[str, List[str]],
@@ -262,6 +304,8 @@ def __call__(
262304
num_images_per_prompt: Optional[int] = 1,
263305
eta: Optional[float] = 0.0,
264306
generator: Optional[np.random.RandomState] = None,
307+
prompt_embeds: Optional[np.ndarray] = None,
308+
negative_prompt_embeds: Optional[np.ndarray] = None,
265309
output_type: Optional[str] = "pil",
266310
return_dict: bool = True,
267311
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
@@ -306,6 +350,13 @@ def __call__(
306350
[`schedulers.DDIMScheduler`], will be ignored for others.
307351
generator (`np.random.RandomState`, *optional*):
308352
A np.random.RandomState to make generation deterministic.
353+
prompt_embeds (`np.ndarray`, *optional*):
354+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
355+
provided, text embeddings will be generated from `prompt` input argument.
356+
negative_prompt_embeds (`np.ndarray`, *optional*):
357+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
358+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
359+
argument.
309360
output_type (`str`, *optional*, defaults to `"pil"`):
310361
The output format of the generated image. Choose between
311362
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -326,24 +377,21 @@ def __call__(
326377
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
327378
(nsfw) content, according to the `safety_checker`.
328379
"""
329-
if isinstance(prompt, str):
380+
381+
# check inputs. Raise error if not correct
382+
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
383+
384+
# define call parameters
385+
if prompt is not None and isinstance(prompt, str):
330386
batch_size = 1
331-
elif isinstance(prompt, list):
387+
elif prompt is not None and isinstance(prompt, list):
332388
batch_size = len(prompt)
333389
else:
334-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
390+
batch_size = prompt_embeds.shape[0]
335391

336392
if strength < 0 or strength > 1:
337393
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
338394

339-
if (callback_steps is None) or (
340-
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
341-
):
342-
raise ValueError(
343-
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
344-
f" {type(callback_steps)}."
345-
)
346-
347395
if generator is None:
348396
generator = np.random
349397

@@ -359,7 +407,12 @@ def __call__(
359407
do_classifier_free_guidance = guidance_scale > 1.0
360408

361409
prompt_embeds = self._encode_prompt(
362-
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
410+
prompt,
411+
num_images_per_prompt,
412+
do_classifier_free_guidance,
413+
negative_prompt,
414+
prompt_embeds=prompt_embeds,
415+
negative_prompt_embeds=negative_prompt_embeds,
363416
)
364417

365418
latents_dtype = prompt_embeds.dtype

0 commit comments

Comments
 (0)