-
Notifications
You must be signed in to change notification settings - Fork 5.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flax safety checker #825
Flax safety checker #825
Conversation
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
The documentation is not available anymore as the PR was closed or merged. |
We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode.
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
…s into flax-safety-checker
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super cool that you got it to work @pcuenca !
In my opinion in the long run we could / should strive for a solution that would allow us to wrap everything into a single pmap(...)
call, e.g. not having to call pmap(...)
from inside.
For today's release I think it's totally fine, however! Just two things I'd advocate for to change:
- Let's move all sharding, replicate and random.split outside of
generate
=> the user should have control over these (it's also safer in terms of backwards comp). Only internal shard/replicate should be called with if-statements if something likeif jit and jax.device_count > 1 => then unshard/shard in
def run_safety_checker` - Also I'd advocate to not run
pmap
by default but only if a flag calledjit=False
is passed viajit=True
by the user because it fits better with JAX API (e.g. things are not jitted by default). In the long run we will then remove thisjit=True
flag and change the internals so that one canpmap(...)
the whole function end-to-end. - Let's make
generate
andrun_safety_checker
private methods
Sounds great, thanks a lot for the fast review! Making it work in a single function was much harder, so I opted for this intermediate solution. This way we don't require users to run the two steps themselves (generation and safety checker). Totally agree that we should try to wrap everything inside a single I also agree with the other comments, it didn't feel right to take decision to use |
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
Outdated
Show resolved
Hide resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, very cool !
special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params) | ||
special_cos_dist = unshard(special_cos_dist) | ||
cos_dist = unshard(cos_dist) | ||
safety_model_params = unreplicate(safety_model_params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need to do unreplicate
here ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because if we are using jit
, safety_model_params
is extracted from the params
dict which is already replicated. We use the replicated version in _p_get_safety_scores
a couple of lines above, but then we need the unreplicated one to compute the scores in self.safety_checker.filtered_with_scores
# TODO: maybe use a config dict instead of so many static argnums | ||
@partial(jax.pmap, static_broadcasted_argnums=(0, 4, 5, 6, 7, 9)) | ||
def _p_generate( | ||
pipe, prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug | ||
): | ||
return pipe._generate( | ||
prompt_ids, params, prng_seed, num_inference_steps, height, width, guidance_scale, latents, debug | ||
) | ||
|
||
|
||
@partial(jax.pmap, static_broadcasted_argnums=(0,)) | ||
def _p_get_safety_scores(pipe, features, params): | ||
return pipe._get_safety_scores(features, params) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit)
maybe have this as pipeline
methods.
|
||
return FlaxStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) | ||
def unshard(x: jnp.ndarray): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's maybe also make it private
* Remove set_format in Flax pipeline. * Remove DummyChecker. * Run safety_checker in pipeline. * Don't pmap on every call. We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode. * Remove commented line Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Replicate outside __call__, prepare for optional jitting. * Remove unnecessary clipping. As suggested by @kashif. * Do not jit unless requested. * Send all args to generate. * make style * Remove unused imports. * Fix docstring. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Remove set_format in Flax pipeline. * Remove DummyChecker. * Run safety_checker in pipeline. * Don't pmap on every call. We could have decorated `generate` with `pmap`, but I wanted to keep it in case someone wants to invoke it in non-parallel mode. * Remove commented line Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Replicate outside __call__, prepare for optional jitting. * Remove unnecessary clipping. As suggested by @kashif. * Do not jit unless requested. * Send all args to generate. * make style * Remove unused imports. * Fix docstring. Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
I managed to change the pipeline to run the slow portion of the safety checker in
pmap
mode. The way it works is that__call__
now invokes agenerate
function first, and then computes the safety scores. Bothgenerate
andget_safety_scores
are explicitly pmapped inside__call__
, and therefore arguments to__call__
are sharded too.This is how it works from a user's point of view:
This breaks the test recently added by @patrickvonplaten.
Please, let me know if this is acceptable and I'll fix the test and finalize a couple of TODOs.