From 4ee1e066e3d2871d914ce0841001e7e81a841089 Mon Sep 17 00:00:00 2001
From: William Berman
Date: Mon, 14 Nov 2022 22:49:02 -0800
Subject: [PATCH] vq diffusion classifier free sampling

---
 scripts/convert_vq_diffusion_to_diffusers.py  |  36 +++++-
 src/diffusers/__init__.py                     |   2 +-
 src/diffusers/models/__init__.py              |   1 +
 src/diffusers/models/embeddings.py            |  14 +++
 .../vq_diffusion/pipeline_vq_diffusion.py     | 117 +++++++++++++-----
 5 files changed, 134 insertions(+), 36 deletions(-)

diff --git a/scripts/convert_vq_diffusion_to_diffusers.py b/scripts/convert_vq_diffusion_to_diffusers.py
index ae105e30362e..877d156513ae 100644
--- a/scripts/convert_vq_diffusion_to_diffusers.py
+++ b/scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,8 +39,7 @@
 import yaml
 from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.models.attention import Transformer2DModel
+from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel, Transformer2DModel, LearnedClassifierFreeSamplingEmbeddings
 from transformers import CLIPTextModel, CLIPTokenizer
 from yaml.loader import FullLoader
 
@@ -826,6 +825,21 @@ def read_config_file(filename):
         transformer_model, checkpoint
     )
 
+    # classifier free sampling embeddings interlude
+
+    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
+    # model, so we pull them off the checkpoint before the checkpoint is deleted.
+
+    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf
+
+    if learnable_classifier_free_sampling_embeddings:
+        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
+    else:
+        learned_classifier_free_sampling_embeddings_embeddings = None
+
+
+    # done classifier free sampling embeddings interlude
+
     with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
         torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
         del diffusers_transformer_checkpoint
@@ -871,6 +885,23 @@ def read_config_file(filename):
 
     # done scheduler
 
+    # learned classifier free sampling embeddings
+
+    with init_empty_weights():
+        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(learnable_classifier_free_sampling_embeddings)
+
+    learned_classifier_free_sampling_checkpoint = {
+        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
+    }
+
+    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
+        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
+        del learned_classifier_free_sampling_checkpoint
+        del learned_classifier_free_sampling_embeddings_embeddings
+        load_checkpoint_and_dispatch(learned_classifier_free_sampling_embeddings_model, learned_classifier_free_sampling_checkpoint_file.name, device_map="auto")
+
+    # done learned classifier free sampling embeddings
+
     print(f"saving VQ diffusion model, path: {args.dump_path}")
 
     pipe = VQDiffusionPipeline(
         vqvae=vqvae_model,
         transformer=transformer_model,
         tokenizer=tokenizer_model,
         text_encoder=text_encoder_model,
+        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
         scheduler=scheduler_model,
     )
     pipe.save_pretrained(args.dump_path)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 86eda7371fe9..1bc6340c0de3 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -18,7 +18,7 @@
 
 if is_torch_available():
     from .modeling_utils import ModelMixin
-    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
+    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel, LearnedClassifierFreeSamplingEmbeddings
     from .optimization import (
         get_constant_schedule,
         get_constant_schedule_with_warmup,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 5b101d169148..d8bdc473ff95 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -21,6 +21,7 @@
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .vae import AutoencoderKL, VQModel
+    from .embeddings import LearnedClassifierFreeSamplingEmbeddings
 
 if is_flax_available():
     from .unet_2d_condition_flax import FlaxUNet2DConditionModel
diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 0221d891f171..5b5916f2ac01 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -17,6 +17,9 @@
 import torch
 from torch import nn
 
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.modeling_utils import ModelMixin
+
 
 def get_timestep_embedding(
     timesteps: torch.Tensor,
@@ -198,3 +201,14 @@ def forward(self, index):
         emb = emb + pos_emb[:, : emb.shape[1], :]
 
         return emb
+
+
+class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
+    @register_to_config
+    def __init__(self, learnable: bool):
+        super().__init__()
+
+        if learnable:
+            self.embeddings = torch.nn.Parameter(torch.empty(77, 512))
+        else:
+            self.embeddings = None
diff --git a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
index 6e5325ba7ef5..86c2b7427347 100644
--- a/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
+++ b/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -16,17 +16,17 @@
 import torch
 
-from diffusers import Transformer2DModel, VQModel
-from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
+from diffusers import Transformer2DModel, VQModel, LearnedClassifierFreeSamplingEmbeddings
 from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
+
 
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...utils import logging
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-
 class VQDiffusionPipeline(DiffusionPipeline):
     r"""
     Pipeline for text-to-image generation using VQ Diffusion
@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
     text_encoder: CLIPTextModel
     tokenizer: CLIPTokenizer
     transformer: Transformer2DModel
+    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
     scheduler: VQDiffusionScheduler
 
     def __init__(
@@ -64,6 +65,7 @@ def __init__(
         tokenizer: CLIPTokenizer,
         transformer: Transformer2DModel,
         scheduler: VQDiffusionScheduler,
+        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
     ):
         super().__init__()
 
@@ -73,13 +75,77 @@ def __init__(
             text_encoder=text_encoder,
             tokenizer=tokenizer,
             scheduler=scheduler,
+            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
         )
 
+    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
+        batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+        # get prompt text embeddings
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
+            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+
+        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
+        # While CLIP does normalize the pooled output of the text transformer when combining
+        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
+        #
+        # CLIP normalizing the pooled output.
+        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
+        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
+
+        # duplicate text embeddings for each generation per prompt
+        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            if self.learned_classifier_free_sampling_embeddings.embeddings is None:
+                uncond_tokens = [""] * batch_size
+
+                max_length = text_input_ids.shape[-1]
+                uncond_input = self.tokenizer(
+                    uncond_tokens,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+                # TODO we might have to normalize the unconditional embeddings as well
+                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            else:
+                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
+                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
+
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = uncond_embeddings.shape[1]
+            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # For classifier free guidance, we need to do two forward passes.
+            # Here we concatenate the unconditional and text embeddings into a single batch
+            # to avoid doing two forward passes
+            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
+
+        return text_embeddings
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: Union[str, List[str]],
         num_inference_steps: int = 100,
+        guidance_scale: float = 5.0,
         truncation_rate: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[torch.Generator] = None,
@@ -137,6 +203,12 @@ def __call__(
 
         batch_size = batch_size * num_images_per_prompt
 
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        text_embeddings = self._encode_prompt(
+            prompt, num_images_per_prompt, do_classifier_free_guidance
+        )
+
         if (callback_steps is None) or (
             callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
         ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
-        # get prompt text embeddings
-        text_inputs = self.tokenizer(
-            prompt,
-            padding="max_length",
-            max_length=self.tokenizer.model_max_length,
-            return_tensors="pt",
-        )
-        text_input_ids = text_inputs.input_ids
-
-        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-            logger.warning(
-                "The following part of your input was truncated because CLIP can only handle sequences up to"
-                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-            )
-            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-        # While CLIP does normalize the pooled output of the text transformer when combining
-        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
-        #
-        # CLIP normalizing the pooled output.
-        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-        # duplicate text embeddings for each generation per prompt
-        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
         # get the initial completely masked latents unless the user supplied it
         latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +241,17 @@ def __call__(
         sample = latents
 
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
+            # expand the sample if we are doing classifier free guidance
+            latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample
+
             # predict the un-noised image
             # model_output == `log_p_x_0`
-            model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+            model_output = self.transformer(latent_model_input, encoder_hidden_states=text_embeddings, timestep=t).sample
+
+            if do_classifier_free_guidance:
+                model_output_uncond, model_output_text = model_output.chunk(2)
+                model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
+                model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)
 
             model_output = self.truncate(model_output, truncation_rate)
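
The hunk above mixes the conditional and unconditional predictions in log space: the transformer's output is `log_p_x_0`, a log-probability over the codebook for every latent pixel, so after applying the guidance scale the result is re-normalized with `logsumexp`. A minimal standalone sketch of that arithmetic (shapes and values below are assumed for illustration, not taken from the patch):

import torch

guidance_scale = 5.0

# assumed shape: (batch, codebook_size, num_latent_pixels)
log_p_uncond = torch.log_softmax(torch.randn(2, 4096, 1024), dim=1)
log_p_text = torch.log_softmax(torch.randn(2, 4096, 1024), dim=1)

# same guidance arithmetic as in __call__ above
guided = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
# re-normalize in log space so each latent pixel is again a distribution over codebook entries
guided = guided - torch.logsumexp(guided, dim=1, keepdim=True)

assert torch.allclose(guided.exp().sum(dim=1), torch.ones(2, 1024), atol=1e-4)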
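
For context, a usage sketch of the new `guidance_scale` argument once a checkpoint has been converted with scripts/convert_vq_diffusion_to_diffusers.py; the dump path and prompt below are illustrative placeholders, not values from the patch:

import torch

from diffusers import VQDiffusionPipeline

# assumes the conversion script was run with a dump path of ./vq-diffusion-dump (illustrative)
pipe = VQDiffusionPipeline.from_pretrained("./vq-diffusion-dump")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1.0 takes the classifier free guidance path added in this patch;
# values <= 1.0 fall back to purely conditional sampling
image = pipe("a teddy bear playing in the pool", num_inference_steps=100, guidance_scale=5.0).images[0]
image.save("teddy_bear.png")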