consistency decoder #5694
def set_timesteps(
    self,
    num_inference_steps: Optional[int] = None,
    device: Union[str, torch.device] = None,
):
    assert num_inference_steps == 2, "TODO"
    self.timesteps = torch.tensor([1008, 512], dtype=torch.long, device=device)
    self.sqrt_alphas_cumprod = self.sqrt_alphas_cumprod.to(device)
    self.sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod.to(device)
    self.c_skip = self.c_skip.to(device)
    self.c_out = self.c_out.to(device)
    self.c_in = self.c_in.to(device)
Need to implement more timestep options, but I'm OK shipping without them and adding a better error message for now.
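The "better error message" mentioned in this comment could look roughly like the sketch below, which swaps the bare `assert` for an explicit `ValueError`. The class name `TwoStepScheduler` is hypothetical, and the timesteps are kept in a plain list rather than a `torch.tensor` so the snippet runs without PyTorch; only the values 1008 and 512 and the two-step restriction come from the diff.

```python
from typing import List, Optional


class TwoStepScheduler:
    """Hypothetical stand-in for the scheduler shown in the diff."""

    # The diff hard-codes exactly these two timesteps.
    SUPPORTED_TIMESTEPS = (1008, 512)

    def __init__(self) -> None:
        self.timesteps: List[int] = []

    def set_timesteps(self, num_inference_steps: Optional[int] = None) -> None:
        supported = len(self.SUPPORTED_TIMESTEPS)
        if num_inference_steps != supported:
            # Unlike `assert`, an explicit exception is not stripped by `python -O`
            # and tells the caller what is actually supported.
            raise ValueError(
                f"Only num_inference_steps={supported} is currently supported, "
                f"got {num_inference_steps!r}."
            )
        self.timesteps = list(self.SUPPORTED_TIMESTEPS)
```

Calling `set_timesteps(2)` populates `timesteps`; any other value fails with a message naming the supported step count.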
src/diffusers/pipelines/consistency_models/pipeline_consistency_decoder.py
@williamberman summary of the changes:
LGTM 👍🏽 Had some comments on a few small things.
Very nice job on porting the checkpoint this quickly, @williamberman. I feel quite strongly that we should allow this decoder to be a drop-in replacement for existing SD15 pipelines. E.g. the following should work:

+from diffusers import StableDiffusionPipeline, ConsistencyDecoderVae
-from diffusers import StableDiffusionPipeline
 import torch

+vae = ConsistencyDecoderVae.from_pretrained("openai/consistency_decoder", torch_dtype=torch.float16)
 model_id = "runwayml/stable-diffusion-v1-5"
+pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, vae=vae)
-pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
 pipe = pipe.to("cuda")

 prompt = "a photo of an astronaut riding a horse on mars"
 image = pipe(prompt).images[0]

 image.save("astronaut_rides_horse.png")

Therefore, I would do the following:
class ConsistencyDecoderVae(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, <all_config_params_of_unet_and_vae_encoder>):
        # same encoder as standard VAE
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
        )
        self.decoder_unet = UNet2DModel(<all_config_params_on_unet>)
        self.decoder_scheduler = ConsistencyDecoderScheduler()

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.tiled_encode(x, return_dict=return_dict)

        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        # here call the scheduler and the unet
        ...

For now, let's maybe just not allow saving and loading the scheduler, and always default to the correct default values until we have more usage (we can always add it later IMO).
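To make the placeholder comment in `decode` concrete, here is a minimal sketch of the control flow it describes: start from noise, iterate over the scheduler's timesteps, call the denoiser conditioned on the latents, and let the scheduler produce the next sample. This is not the merged implementation. The function `consistency_decode` and its `denoise` and `scheduler_step` callables are hypothetical stand-ins for the UNet and the scheduler, and the toy usage works on plain floats rather than tensors.

```python
from typing import Any, Callable, Iterable


def consistency_decode(
    z: Any,
    initial_noise: Any,
    timesteps: Iterable[int],
    denoise: Callable[[Any, Any, int], Any],
    scheduler_step: Callable[[Any, int, Any], Any],
) -> Any:
    """Run a short, fixed denoising loop conditioned on the latents `z`."""
    x_t = initial_noise
    for t in timesteps:
        # Stand-in for the UNet call: it sees the current sample and the latents.
        model_output = denoise(x_t, z, t)
        # Stand-in for the scheduler step: it maps (output, t, sample) to the next sample.
        x_t = scheduler_step(model_output, t, x_t)
    return x_t


# Toy usage with floats; the two timesteps mirror the values hard-coded in the PR.
decoded = consistency_decode(
    z=1.0,
    initial_noise=0.0,
    timesteps=[1008, 512],
    denoise=lambda x_t, z, t: z - x_t,
    scheduler_step=lambda out, t, x_t: x_t + 0.5 * out,
)
```

Keeping the loop inside `decode` is what lets the class behave like a regular VAE from the pipeline's point of view: the caller passes latents and gets an image back without knowing a scheduler is involved.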
But @patrickvonplaten, won't that be restrictive, since users won't be able to control the number of timesteps for the decoder scheduler (I know it's currently not supported)? Do you think most users would stick to 2 timesteps? I think they would.
Yeah, chatted with Patrick here; I have no strong opinions on pipeline vs. model class.
-image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
+    0
+]
This helps the tests for noise generation in the decoder. I also added an unused generator arg to AutoencoderKL.decode, so it should be fine.
Perf!
-def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+def decode(
+    self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+) -> Union[DecoderOutput, torch.FloatTensor]:
Helps keep the interface the same as the consistency decoder's, so the generator can be passed through from the pipeline.
Yes perf
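The point of the shared `decode` signature can be illustrated with a small sketch: when both decoders accept `generator`, the pipeline can issue the same call regardless of which one is plugged in. The classes `PlainDecoder` and `NoisyDecoder` and the function `run_pipeline` are hypothetical, and `random.Random` stands in for `torch.Generator` so the snippet runs without PyTorch.

```python
import random
from typing import Optional, Tuple


class PlainDecoder:
    """Hypothetical deterministic decoder: accepts `generator` but ignores it."""

    def decode(
        self, z: float, return_dict: bool = True, generator: Optional[random.Random] = None
    ) -> Tuple[float]:
        # `return_dict` is accepted for signature parity; this toy always returns a tuple.
        return (z * 2.0,)


class NoisyDecoder:
    """Hypothetical stochastic decoder: draws its noise from `generator`."""

    def decode(
        self, z: float, return_dict: bool = True, generator: Optional[random.Random] = None
    ) -> Tuple[float]:
        rng = generator if generator is not None else random.Random()
        return (z * 2.0 + rng.gauss(0.0, 1.0),)


def run_pipeline(vae, latents: float, scaling_factor: float, generator=None) -> float:
    # One call shape for either decoder, mirroring the pipeline line in the diff.
    return vae.decode(latents / scaling_factor, return_dict=False, generator=generator)[0]
```

With a seeded generator the stochastic decoder becomes reproducible, while the deterministic one is unaffected by the extra argument.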
-init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+if self.forward_requires_fresh_args:
+    model = self.model_class(**self.init_dict)
+else:
+    init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+    model = self.model_class(**init_dict)
Needed to pass generator with same seed for multiple forward calls
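Why fresh arguments are needed can be shown with a small sketch: a generator is stateful, so reusing the same instance across two forward calls yields different noise, whereas rebuilding it from the same seed before each call makes the outputs comparable. The names `forward` and `fresh_inputs` are hypothetical, and `random.Random` again stands in for `torch.Generator`.

```python
import random


def forward(x: float, generator: random.Random) -> float:
    """Hypothetical stochastic forward pass: adds noise drawn from `generator`."""
    return x + generator.gauss(0.0, 1.0)


def fresh_inputs(seed: int = 0) -> dict:
    # Rebuild the generator on every call so each forward starts from the same state.
    return {"x": 1.0, "generator": random.Random(seed)}


# Reusing one inputs dict: the second call sees an advanced generator state.
shared = fresh_inputs()
first = forward(**shared)
second = forward(**shared)

# Requesting fresh inputs per call: both calls start from the same seed.
fresh_first = forward(**fresh_inputs())
fresh_second = forward(**fresh_inputs())
```

Here `first` and `second` differ, while `fresh_first` and `fresh_second` match, which is the property a test comparing two forward passes relies on.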
Failing tests are unrelated.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py
…erman/diffusers into vae_consistency_decoder
…vae_consistency_decoder
Some bug fixes here: #5722
* consistency decoder * rename * Apply suggestions from code review Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update src/diffusers/pipelines/consistency_models/pipeline_consistency_models.py * uP * Apply suggestions from code review * uP * uP * uP --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>