-
Notifications
You must be signed in to change notification settings - Fork 4.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Flax stable diffusion img2img pipeline #1355
Conversation
The documentation is not available anymore as the PR was closed or merged. |
cc: @patil-suraj if you or someone else could review this, It'd be great. Trying to fix the image result, but a review can help fix any other issues. |
Cool ! @dhruvrnaik do you need help finishing this PR? |
@patrickvonplaten I am mostly done with this. I just need to update the docs, which I will finish tonight. I have tested this on a TPU v3-8 pod and it works as expected. Ready for a review, otherwise. |
Excellent! I'll check it out too. Please, let me know when you are done with your changes :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! However, it still doesn't work for me. We also need to expose the class name up through the __init__.py
chain so it's visible in diffusers
.
Please, make sure you push any other changes you may have locally, and let us know if you need help driving this home :)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
Outdated
Show resolved
Hide resolved
if debug: | ||
# run with python for loop | ||
for i in range(len(timesteps)): | ||
latents, timesteps, scheduler_state = loop_body(i, (latents, timesteps, scheduler_state)) | ||
else: | ||
latents, _, _ = jax.lax.fori_loop(0, len(timesteps), loop_body, (latents, timesteps, scheduler_state)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about this. Wouldn't it be possible to start the loop at the initial effective timestep (instead of 0), and avoid having to pass it around?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is also done. Please have a look when you can
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py
Outdated
Show resolved
Hide resolved
…diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
…diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Hey @dhruvrnaik, Thanks a lot for the PR! Could you maybe run
error message of the PR Documentation to get this merged? :-) |
Should be good to go @patrickvonplaten |
@dhruvrnaik I'll test it today. |
@pcuenca, feel free to merge whenever :-) |
Thanks a lot @dhruvrnaik! |
* add flax img2img pipeline * update pipeline * black format file * remove argg from get_timesteps * update get_timesteps * fix bug: make use of timesteps for for_loop * black file * black, isort, flake8 * update docstring * update readme * update flax img2img readme * update sd pipeline init * Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * update inits * revert change * update var name to image, typo * update readme * return new t_start instead of modified timestep * black format * isort files * update docs * fix-copies * update prng_seed typing Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
|
||
t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state) | ||
latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size) | ||
init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As of #1661, the add_noise
function now requires the scheduler params:
init_latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @skirsten! Would you like to open a PR for that?
In addition, we should create some tests for this new pipeline too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can work on the tests. Will check out the PR too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is the PR that also adds the params to add_noise
: #1824
* add flax img2img pipeline * update pipeline * black format file * remove argg from get_timesteps * update get_timesteps * fix bug: make use of timesteps for for_loop * black file * black, isort, flake8 * update docstring * update readme * update flax img2img readme * update sd pipeline init * Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * update inits * revert change * update var name to image, typo * update readme * return new t_start instead of modified timestep * black format * isort files * update docs * fix-copies * update prng_seed typing Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Added Flax SD pipeline for Image 2 Image
TODO:
The pipeline is working, but the result is not as expected.
Input image:![download (1)](https://user-images.githubusercontent.com/22565320/203099668-92f8eb0f-4f19-4a16-a51d-f7d0c4038e4f.png)
![download (2)](https://user-images.githubusercontent.com/22565320/203099672-75eced2f-5a2b-4b03-bfa7-a2de342ca0cb.png)
Prompt = "A fantasy landscape, trending on artstation"
Outputs:
Params:
strength=0.75,
num_inference_steps=50,
guidance_scale = 7.5
How to run this: