Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flax safety checker #825

Merged
merged 15 commits into from
Oct 13, 2022
Merged

Flax safety checker #825

merged 15 commits into from
Oct 13, 2022

Conversation

pcuenca
Copy link
Member

@pcuenca pcuenca commented Oct 13, 2022

I managed to change the pipeline to run the slow portion of the safety checker in pmap mode. The way it works is that __call__ now invokes a generate function first, and then computes the safety scores. Both generate and get_safety_scores are explicitly pmapped inside __call__, and therefore arguments to __call__ are sharded too.

This is how it works from a user's point of view:

from diffusers import FlaxStableDiffusionPipeline

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "/home/pedro/code/diffusers/sd-v1-4-flax",
    dtype=jnp.bfloat16
)

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)

prng_seed = jax.random.PRNGKey(0)

# Replication done by the pipeline
output = pipeline(prompt_ids, params, prng_seed)

This breaks the test recently added by @patrickvonplaten.

Please, let me know if this is acceptable and I'll fix the test and finalize a couple of TODOs.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Oct 13, 2022

The documentation is not available anymore as the PR was closed or merged.

pcuenca and others added 3 commits October 13, 2022 12:57
We could have decorated `generate` with `pmap`, but I wanted to keep it
in case someone wants to invoke it in non-parallel mode.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Copy link
Contributor

@patrickvonplaten patrickvonplaten left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super cool that you got it to work @pcuenca !

In my opinion in the long run we could / should strive for a solution that would allow us to wrap everything into a single pmap(...) call, e.g. not having to call pmap(...) from inside.

For today's release I think it's totally fine, however! Just two things I'd advocate for to change:

  • Let's move all sharding, replicate and random.split outside of generate => the user should have control over these (it's also safer in terms of backwards comp). Only internal shard/replicate should be called with if-statements if something like if jit and jax.device_count > 1 => then unshard/shard in def run_safety_checker`
  • Also I'd advocate to not run pmap by default but only if a flag called jit=False is passed via jit=True by the user because it fits better with JAX API (e.g. things are not jitted by default). In the long run we will then remove this jit=True flag and change the internals so that one can pmap(...) the whole function end-to-end.
  • Let's make generate and run_safety_checker private methods

@pcuenca
Copy link
Member Author

pcuenca commented Oct 13, 2022

Super cool that you got it to work @pcuenca !

In my opinion in the long run we could / should strive for a solution that would allow us to wrap everything into a single pmap(...) call, e.g. not having to call pmap(...) from inside.

For today's release I think it's totally fine, however! Just two things I'd advocate for to change:

  • Let's move all sharding, replicate and random.split outside of generate => the user should have control over these (it's also safer in terms of backwards comp)
  • Also I'd advocate to not run pmap by default but only if a flag called jit=False is passed via jit=True by the user because it fits better with JAX API (e.g. things are not jitted by default). In the long run we will then remove this jit=True flag and change the internals so that one can pmap(...) the whole function end-to-end.

Sounds great, thanks a lot for the fast review!

Making it work in a single function was much harder, so I opted for this intermediate solution. This way we don't require users to run the two steps themselves (generation and safety checker). Totally agree that we should try to wrap everything inside a single pmap call.

I also agree with the other comments, it didn't feel right to take decision to use pmap on our own. Thanks!

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, very cool !

special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
special_cos_dist = unshard(special_cos_dist)
cos_dist = unshard(cos_dist)
safety_model_params = unreplicate(safety_model_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need to do unreplicate here ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because if we are using jit, safety_model_params is extracted from the params dict which is already replicated. We use the replicated version in _p_get_safety_scores a couple of lines above, but then we need the unreplicated one to compute the scores in self.safety_checker.filtered_with_scores

Comment on lines +279 to +291
# TODO: maybe use a config dict instead of so many static argnums
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9))
def _p_generate(
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
):
return pipe._generate(
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug
)


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_get_safety_scores(pipe, features, params):
return pipe._get_safety_scores(features, params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit)

maybe have this as pipeline methods.


return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def unshard(x: jnp.ndarray):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's maybe also make it private

@patrickvonplaten patrickvonplaten merged commit 78db11d into main Oct 13, 2022
@patil-suraj patil-suraj deleted the flax-safety-checker branch October 13, 2022 15:05
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* Remove set_format in Flax pipeline.

* Remove DummyChecker.

* Run safety_checker in pipeline.

* Don't pmap on every call.

We could have decorated `generate` with `pmap`, but I wanted to keep it
in case someone wants to invoke it in non-parallel mode.

* Remove commented line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Replicate outside __call__, prepare for optional jitting.

* Remove unnecessary clipping.

As suggested by @kashif.

* Do not jit unless requested.

* Send all args to generate.

* make style

* Remove unused imports.

* Fix docstring.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* Remove set_format in Flax pipeline.

* Remove DummyChecker.

* Run safety_checker in pipeline.

* Don't pmap on every call.

We could have decorated `generate` with `pmap`, but I wanted to keep it
in case someone wants to invoke it in non-parallel mode.

* Remove commented line

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Replicate outside __call__, prepare for optional jitting.

* Remove unnecessary clipping.

As suggested by @kashif.

* Do not jit unless requested.

* Send all args to generate.

* make style

* Remove unused imports.

* Fix docstring.

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants