
consistency decoder #5694

Merged

Conversation

@williamberman (Contributor) commented Nov 8, 2023

from diffusers import StableDiffusionPipeline, ConsistencyDecoderVae
import torch

# Baseline: decode with the stock SD v1-5 VAE.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.to("cuda")
pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("orig.png")

# Same pipeline, with the consistency decoder swapped in as the VAE.
# Load the VAE in fp16 so it matches the rest of the pipeline.
vae = ConsistencyDecoderVae.from_pretrained("../consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, vae=vae
)
pipe.to("cuda")

pipe("horse", generator=torch.Generator("cpu").manual_seed(0)).images[0].save("new.png")

(Image results: orig.png vs. new.png, generated from the same seed.)

import torch
from diffusers import ConsistencyDecoderVae, AutoencoderKL
from diffusers.utils.loading_utils import load_image
import numpy as np
from PIL import Image

# Original SD v1-5 VAE, for comparison.
vae_orig = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, subfolder="vae")
vae_orig.to("cuda")

# Consistency decoder (uses the same encoder as the original VAE).
vae_new = ConsistencyDecoderVae.from_pretrained("../consistency-decoder", torch_dtype=torch.float16)
vae_new.to("cuda")

# Load a test image and normalize to [-1, 1], NCHW, fp16.
image = load_image("https://raw.githubusercontent.com/openai/consistencydecoder/main/assets/gt1.png").resize((256, 256))
image = torch.from_numpy(np.array(image).transpose(2, 0, 1).astype(np.float32) / 127.5 - 1)[None, :, :, :].half().cuda()

# Encode once, then decode the same latent with both decoders.
latent = vae_new.encode(image).latent_dist.mean
sample_orig = vae_orig.decode(latent).sample
sample_new = vae_new.decode(latent).sample

# De-normalize back to uint8 and save.
Image.fromarray(((sample_orig.detach()[0].cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).clip(0, 255).astype(np.uint8)).save("orig.png")
Image.fromarray(((sample_new.detach()[0].cpu().numpy().transpose(1, 2, 0) + 1) * 127.5).clip(0, 255).astype(np.uint8)).save("new.png")
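To put a rough number on the difference between the two reconstructions, one could append a quick comparison to the script above (illustrative only, not part of the PR):

# Mean squared error between the two decoded samples (values in [-1, 1]).
mse = torch.mean((sample_orig.float() - sample_new.float()) ** 2).item()
print(f"MSE between decoders: {mse:.5f}")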

(Image results: original vs. consistency-decoder reconstruction of gt1.png.)

@HuggingFaceDocBuilderDev commented Nov 8, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines 62 to 110
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
):
    assert num_inference_steps == 2, "TODO"
    self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
    self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
    self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
    self.c_skip = self.c_skip.to(device)
    self.c_out = self.c_out.to(device)
    self.c_in = self.c_in.to(device)
@williamberman (Contributor, Author)

Need to implement more timestep options, but I'm OK shipping without them and adding a better error message for now.
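For illustration, a minimal sketch of that interim "better error message", assuming the two-step schedule stays hard-coded (not necessarily the merged code):

def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
):
    # Replace the bare assert with a descriptive error until more
    # step counts are supported.
    if num_inference_steps != 2:
        raise ValueError(
            f"ConsistencyDecoderScheduler currently only supports 2 inference steps, got {num_inference_steps}."
        )
    self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)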

@sayakpaul marked this pull request as ready for review on November 8, 2023, 09:53
@sayakpaul (Member)

@williamberman summary of the changes:

  • Docs
  • Tests (including a slow one)

@DN6 (Collaborator) left a comment

LGTM 👍🏽 Had some comments on a few small things.

@patrickvonplaten (Contributor) commented Nov 8, 2023

Very nice job on porting the checkpoint this quickly, @williamberman.

I feel quite strongly that we should allow this decoder to be a drop-in replacement for existing SD15 pipelines. E.g. the following should work:

-from diffusers import StableDiffusionPipeline
+from diffusers import StableDiffusionPipeline, ConsistencyDecoderVae
import torch

+vae = ConsistencyDecoderVae.from_pretrained("openai/consistency_decoder", torch_dtype=torch.float16)
+
model_id = "runwayml/stable-diffusion-v1-5"
-pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, vae=vae)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]

image.save("astronaut_rides_horse.png")

Therefore, I would do the following:

    1. Instead, create a ConsistencyDecoderVae class that has exactly the same signatures as AutoencoderKL.
    2. Store the unet, the scheduler, and the encoder (same as the SD v1-5 VAE) as class attributes.
    3. Add the SD v1-5 VAE to the checkpoint as well so that Img2Img pipelines work.
class ConsistencyDecoderVae(ModelMixin, ConfigMixin):
    
    @register_to_config
    def __init__(self, <all_config_params_of_unet_and_vae_encoder>):
        # same encoder as standard VAE
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )

        self.decoder_unet = UNet2DModel(<all_config_params_on_unet>)
        self.decoder_scheduler = ConsistencyDecoderScheduler()

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        # here call the scheduler and the unet ...

For now, let's maybe just not allow saving and loading the scheduler, and always default to the correct default values until we have more usage (we can always add it later IMO).
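For readers following along, a minimal sketch of what that decode could look like, assuming diffusers' randn_tensor helper, nearest-neighbor upsampling of the latents, and a scheduler exposing init_noise_sigma, scale_model_input, and step; an illustration, not the merged implementation:

    @apply_forward_hook
    def decode(self, z, generator=None, return_dict=True, num_inference_steps=2):
        # Upsample latents to pixel resolution so the UNet can condition on
        # them (assumed upsampling factor of 8 for the SD v1-5 VAE).
        z = torch.nn.functional.interpolate(z, mode="nearest", scale_factor=8)

        self.decoder_scheduler.set_timesteps(num_inference_steps, device=z.device)

        # Start from noise scaled by the scheduler's initial sigma.
        x_t = self.decoder_scheduler.init_noise_sigma * randn_tensor(
            (z.shape[0], 3, z.shape[2], z.shape[3]),
            generator=generator,
            dtype=z.dtype,
            device=z.device,
        )

        # Consistency steps: condition the UNet on the latents, then let the
        # scheduler produce the next (or final) sample.
        for t in self.decoder_scheduler.timesteps:
            model_input = torch.cat(
                [self.decoder_scheduler.scale_model_input(x_t, t), z], dim=1
            )
            model_output = self.decoder_unet(model_input, t).sample[:, :3, :, :]
            x_t = self.decoder_scheduler.step(model_output, t, x_t, generator=generator).prev_sample

        if not return_dict:
            return (x_t,)
        return DecoderOutput(sample=x_t)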

@sayakpaul (Member)

But @patrickvonplaten, won't that be restrictive? Users won't be able to control the number of timesteps for the decoder scheduler (I know it's currently not supported).

Do you think most users would stick to 2 timesteps? I think they would.

@williamberman (Contributor, Author)

Yeah, chatted with Patrick about this; I have no strong opinions on pipeline vs. model class.

Comment on lines -841 to +843
-image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+    0
+]
@williamberman (Contributor, Author)

This helps with the tests for noise generation in the decoder. I also added an (unused) generator arg to AutoencoderKL.decode, so this should be fine.

Reply from a contributor:

Perf!

Comment on lines -297 to +299
-    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
@williamberman (Contributor, Author)

This keeps the interface the same as the consistency decoder's, and the generator can be passed through from the pipeline.

Reply from a contributor:

Yes perf

Comment on lines -201 to +206
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        if self.forward_requires_fresh_args:
+            model = self.model_class(**self.init_dict)
+        else:
+            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            model = self.model_class(**init_dict)
@williamberman (Contributor, Author)

Needed so that a generator with the same seed can be passed for multiple forward calls.
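A hedged sketch of the pattern this enables in the common tests (get_fresh_inputs, the input shapes, and model are illustrative placeholders, not the real test helpers):

import torch

def get_fresh_inputs(seed=0):
    # Rebuild the inputs with an identically seeded generator on each call, so
    # two passes through a stochastic decoder draw the same noise.
    generator = torch.Generator("cpu").manual_seed(seed)
    return {"sample": torch.randn(1, 3, 32, 32), "generator": generator}

out_first = model(**get_fresh_inputs())   # `model` is the model under test
out_second = model(**get_fresh_inputs())  # same seed -> same noise -> comparable outputs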

@williamberman (Contributor, Author)

Failing tests are unrelated.

@patrickvonplaten merged commit 2fd4640 into huggingface:main on Nov 9, 2023
12 of 13 checks passed
@patrickvonplaten (Contributor)

Some bug fixes here: #5722

kashif pushed a commit to kashif/diffusers that referenced this pull request Nov 11, 2023
* consistency decoder

* rename

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py

* uP

* Apply suggestions from code review

* uP

* uP

* uP

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023 (same commit message as above)
AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024 (same commit message as above)