Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Flax stable diffusion img2img pipeline #1355

Merged
merged 28 commits into from
Dec 20, 2022

Conversation

dhruvrnaik
Copy link
Contributor

@dhruvrnaik dhruvrnaik commented Nov 21, 2022

Added Flax SD pipeline for Image 2 Image

TODO:

  • Update README
  • Fix generation bug

The pipeline is working, but the result is not as expected.

Input image:download (1)
Prompt = "A fantasy landscape, trending on artstation"
Outputs: download (2)

Params:
strength=0.75,
num_inference_steps=50,
guidance_scale = 7.5

How to run this:

import requests
from io import BytesIO
from PIL import Image

url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"

response = requests.get(url)
init_img = Image.open(BytesIO(response.content)).convert("RGB")
init_img = init_img.resize((768, 512))
prompts = "A fantasy landscape, trending on artstation"
dtype=jnp.bfloat16
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax",
    dtype=dtype,
)
def create_key(seed=0):
    return jax.random.PRNGKey(seed)
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

prompt_ids, imgs = pipeline.prepare_inputs(prompt=[prompts]*jax.device_count(), init_image = [init_img]*jax.device_count())
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
imgs = shard(imgs)

output = pipeline(
    prompt_ids=prompt_ids, 
    init_images=imgs, 
    params=p_params, 
    prng_seed=rng, 
    strength=0.75, 
    num_inference_steps=50, 
    jit=True, 
    guidance_scale=7.5,
height=512,width=768).images

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Nov 21, 2022

The documentation is not available anymore as the PR was closed or merged.

@dhruvrnaik dhruvrnaik marked this pull request as draft November 21, 2022 15:56
@dhruvrnaik
Copy link
Contributor Author

cc: @patil-suraj if you or someone else could review this, It'd be great. Trying to fix the image result, but a review can help fix any other issues.

@dhruvrnaik dhruvrnaik marked this pull request as ready for review November 25, 2022 03:09
@dhruvrnaik
Copy link
Contributor Author

dhruvrnaik commented Nov 25, 2022

Fixed it. Forgot to update the for loop with the right timesteps.

download (3)

@patrickvonplaten
Copy link
Contributor

Cool ! @dhruvrnaik do you need help finishing this PR?

@dhruvrnaik
Copy link
Contributor Author

@patrickvonplaten I am mostly done with this. I just need to update the docs, which I will finish tonight. I have tested this on a TPU v3-8 pod and it works as expected. Ready for a review, otherwise.
I will fix the black/isort test fails too.

@pcuenca
Copy link
Member

pcuenca commented Dec 2, 2022

@patrickvonplaten I am mostly done with this. I just need to update the docs, which I will finish tonight. I have tested this on a TPU v3-8 pod and it works as expected. Ready for a review, otherwise.
I will fix the black/isort test fails too.

Excellent! I'll check it out too. Please, let me know when you are done with your changes :)
Thanks a lot!

Copy link
Member

@pcuenca pcuenca left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! However, it still doesn't work for me. We also need to expose the class name up through the __init__.py chain so it's visible in diffusers.

Please, make sure you push any other changes you may have locally, and let us know if you need help driving this home :)

README.md Show resolved Hide resolved
README.md Outdated Show resolved Hide resolved
Comment on lines 250 to 255
if debug:
# run with python for loop
for i in range(len(timesteps)):
latents, timesteps, scheduler_state = loop_body(i, (latents, timesteps, scheduler_state))
else:
latents, _, _ = jax.lax.fori_loop(0, len(timesteps), loop_body, (latents, timesteps, scheduler_state))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure about this. Wouldn't it be possible to start the loop at the initial effective timestep (instead of 0), and avoid having to pass it around?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also done. Please have a look when you can

dhruvrnaik and others added 3 commits December 4, 2022 21:18
…diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
…diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@patrickvonplaten
Copy link
Contributor

Hey @dhruvrnaik,

Thanks a lot for the PR! Could you maybe run make fix-copies once and correct the:

 ❱  45 class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):       │
│    46 │   r"""                                                               │
│    47 │   Pipeline for image-to-image generation using Stable Diffusion.     │
│    48                                                                        │
│                                                                              │
│ /usr/local/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/ │
│ pipeline_flax_stable_diffusion_img2img.py:270 in                             │
│ FlaxStableDiffusionImg2ImgPipeline                                           │
│                                                                              │
│   267 │   │   prompt_ids: jnp.array,                                         │
│   268 │   │   image: jnp.array,                                              │
│   269 │   │   params: Union[Dict, FrozenDict],                               │
│ ❱ 270 │   │   prng_seed: Union[jax.random.KeyArray, jax.Array],              │
│   271 │   │   num_inference_steps: int = 50,                                 │
│   272 │   │   height: int = 512,                                             │
│   273 │   │   width: int = 512,                                              │
╰──────────────────────────────────────────────────────────────────────────────╯
AttributeError: module 'jax' has no attribute 'Array'

error message of the PR Documentation to get this merged? :-)

@dhruvrnaik
Copy link
Contributor Author

Should be good to go @patrickvonplaten

@pcuenca
Copy link
Member

pcuenca commented Dec 19, 2022

@dhruvrnaik I'll test it today.

@patrickvonplaten
Copy link
Contributor

@pcuenca, feel free to merge whenever :-)

@pcuenca
Copy link
Member

pcuenca commented Dec 20, 2022

Thanks a lot @dhruvrnaik!

@pcuenca pcuenca merged commit a9190ba into huggingface:main Dec 20, 2022
sliard pushed a commit to sliard/diffusers that referenced this pull request Dec 21, 2022
* add flax img2img pipeline

* update pipeline

* black format file

* remove argg from get_timesteps

* update get_timesteps

* fix bug: make use of timesteps for for_loop

* black file

* black, isort, flake8

* update docstring

* update readme

* update flax img2img readme

* update sd pipeline init

* Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* update inits

* revert change

* update var name to image, typo

* update readme

* return new t_start instead of modified timestep

* black format

* isort files

* update docs

* fix-copies

* update prng_seed typing

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

t_start = self.get_timestep_start(num_inference_steps, strength, scheduler_state)
latent_timestep = scheduler_state.timesteps[t_start : t_start + 1].repeat(batch_size)
init_latents = self.scheduler.add_noise(init_latents, noise, latent_timestep)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of #1661, the add_noise function now requires the scheduler params:

init_latents = self.scheduler.add_noise(params["scheduler"], init_latents, noise, latent_timestep)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @skirsten! Would you like to open a PR for that?

In addition, we should create some tests for this new pipeline too.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can work on the tests. Will check out the PR too

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is the PR that also adds the params to add_noise: #1824

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* add flax img2img pipeline

* update pipeline

* black format file

* remove argg from get_timesteps

* update get_timesteps

* fix bug: make use of timesteps for for_loop

* black file

* black, isort, flake8

* update docstring

* update readme

* update flax img2img readme

* update sd pipeline init

* Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* update inits

* revert change

* update var name to image, typo

* update readme

* return new t_start instead of modified timestep

* black format

* isort files

* update docs

* fix-copies

* update prng_seed typing

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants