From 5333895b3d2c6f5c2835869625b27927592b6d18 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 1 Nov 2022 15:56:15 +0000 Subject: [PATCH 1/6] Do not recompile when guidance_scale changes. --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index fe0e284c6720..d635f8df3331 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -308,8 +308,7 @@ def __call__( return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -# TODO: maybe use a config dict instead of so many static argnums -@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) +@partial(jax.pmap, in_axes=(None, 0, 0, 0, None, None, None, None, 0, None), static_broadcasted_argnums=(0, 4, 5, 6, 9)) def _p_generate( pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug ): From e1a7d264ebfb1d8ef4570bc3b3862072bf448e50 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 1 Nov 2022 16:12:16 +0000 Subject: [PATCH 2/6] Remove debug for simplicity. --- .../pipeline_flax_stable_diffusion.py | 20 ++++++------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index d635f8df3331..8bebab5d88da 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -142,7 +142,6 @@ def _generate( width: int = 512, guidance_scale: float = 7.5, latents: Optional[jnp.array] = None, - debug: bool = False, ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -201,13 +200,7 @@ def loop_body(step, args): # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - - if debug: - # run with python for loop - for i in range(num_inference_steps): - latents, scheduler_state = loop_body(i, (latents, scheduler_state)) - else: - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents @@ -228,7 +221,6 @@ def __call__( latents: jnp.array = None, return_dict: bool = True, jit: bool = False, - debug: bool = False, **kwargs, ): r""" @@ -276,11 +268,11 @@ def __call__( """ if jit: images = _p_generate( - self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + self, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents ) else: images = self._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents ) if self.safety_checker is not None: @@ -308,12 +300,12 @@ def __call__( return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -@partial(jax.pmap, in_axes=(None, 0, 0, 0, None, None, None, None, 0, None), static_broadcasted_argnums=(0, 4, 5, 6, 9)) +@partial(jax.pmap, in_axes=(None, 0, 0, 0, None, None, None, None, 0), static_broadcasted_argnums=(0, 4, 5, 6)) def _p_generate( - pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents ): return pipe._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug + prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents ) From 3e9b7535ad9443bcff64baab8bffd6da0d996ef9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 1 Nov 2022 16:15:20 +0000 Subject: [PATCH 3/6] make style --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 8bebab5d88da..b58f6edd45db 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -301,12 +301,8 @@ def __call__( @partial(jax.pmap, in_axes=(None, 0, 0, 0, None, None, None, None, 0), static_broadcasted_argnums=(0, 4, 5, 6)) -def _p_generate( - pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents -): - return pipe._generate( - prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents - ) +def _p_generate(pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents): + return pipe._generate(prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents) @partial(jax.pmap, static_broadcasted_argnums=(0,)) From 759094cd820be2b36df16e797dc4819d1a36eb13 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Tue, 6 Dec 2022 13:15:27 +0000 Subject: [PATCH 4/6] Make guidance_scale an array. --- .../pipeline_flax_stable_diffusion.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 150847cef277..335d46ef09ea 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -276,7 +276,7 @@ def __call__( num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, - guidance_scale: float = 7.5, + guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, return_dict: bool = True, jit: bool = False, @@ -326,6 +326,12 @@ def __call__( height = height or self.unet.config.sample_size * self.vae_scale_factor width = width or self.unet.config.sample_size * self.vae_scale_factor + if isinstance(guidance_scale, float): + guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) + if len(prompt_ids.shape) > 2: + # Assume sharded + guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2]) + if jit: images = _p_generate( self, @@ -379,7 +385,11 @@ def __call__( # TODO: maybe use a config dict instead of so many static argnums -@partial(jax.pmap, in_axes=(None, 0, 0, 0, None, None, None, None, 0, 0), static_broadcasted_argnums=(0, 4, 5, 6)) +@partial( + jax.pmap, + in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0), + static_broadcasted_argnums=(0, 4, 5, 6), +) def _p_generate( pipe, prompt_ids, From bc80eed18067ac3d4e3ac87ced16cefe95a993e8 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 7 Dec 2022 13:23:15 +0000 Subject: [PATCH 5/6] Make DEBUG a constant to avoid passing it down. --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 335d46ef09ea..eb271c0289cb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -42,6 +42,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Set to True to use python for loop instead of jax.fori_loop for easier debugging +DEBUG = False class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): r""" @@ -259,7 +261,12 @@ def loop_body(step, args): # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma - latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) + if DEBUG: + # run with python for loop + for i in range(num_inference_steps): + latents, scheduler_state = loop_body(i, (latents, scheduler_state)) + else: + latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state)) # scale and decode the image latents with vae latents = 1 / 0.18215 * latents From b55a8a19d3f897b4cdea2f2ddced5bdbf4c93eb8 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Wed, 7 Dec 2022 13:23:33 +0000 Subject: [PATCH 6/6] Add comments for clarification. --- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index eb271c0289cb..912a4381d032 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -45,6 +45,7 @@ # Set to True to use python for loop instead of jax.fori_loop for easier debugging DEBUG = False + class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): r""" Pipeline for text-to-image generation using Stable Diffusion. @@ -334,6 +335,8 @@ def __call__( width = width or self.unet.config.sample_size * self.vae_scale_factor if isinstance(guidance_scale, float): + # Convert to a tensor so each device gets a copy. Follow the prompt_ids for + # shape information, as they may be sharded (when `jit` is `True`), or not. guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0]) if len(prompt_ids.shape) > 2: # Assume sharded @@ -391,7 +394,8 @@ def __call__( return FlaxStableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -# TODO: maybe use a config dict instead of so many static argnums +# Static argnums are pipe, num_inference_steps, height, width. A change would trigger recompilation. +# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`). @partial( jax.pmap, in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0),