vq diffusion classifier free sampling
williamberman committed Nov 15, 2022
1 parent 57525bb commit 4ee1e06
Showing 5 changed files with 134 additions and 36 deletions.
36 changes: 34 additions & 2 deletions scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,8 +39,7 @@

import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
-from diffusers.models.attention import Transformer2DModel
+from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel, Transformer2DModel, LearnedClassifierFreeSamplingEmbeddings
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

@@ -826,6 +825,21 @@ def read_config_file(filename):
transformer_model, checkpoint
)

# classifier free sampling embeddings interlude

# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
# model, so we pull them off the checkpoint before the checkpoint is deleted.

learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

if learnable_classifier_free_sampling_embeddings:
learned_classifier_free_sampling_embeddings_embeddings = checkpoint['transformer.empty_text_embed']
else:
learned_classifier_free_sampling_embeddings_embeddings = None


# done classifier free sampling embeddings interlude

with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
del diffusers_transformer_checkpoint
@@ -871,13 +885,31 @@ def read_config_file(filename):

# done scheduler

# learned classifier free sampling embeddings

with init_empty_weights():
learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(learnable_classifier_free_sampling_embeddings)
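# Note: the module above is created under init_empty_weights, so its parameter lives on the "meta"
# device with no real storage; writing the checkpoint to a temporary file and loading it back with
# load_checkpoint_and_dispatch below is what materializes the weights, mirroring the transformer conversion above.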

learned_classifier_free_sampling_checkpoint = {
'embeddings': learned_classifier_free_sampling_embeddings_embeddings.float()
}

with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
del learned_classifier_free_sampling_checkpoint
del learned_classifier_free_sampling_embeddings_embeddings
load_checkpoint_and_dispatch(learned_classifier_free_sampling_embeddings_model, learned_classifier_free_sampling_checkpoint_file.name, device_map='auto')

# done learned classifier free sampling embeddings

print(f"saving VQ diffusion model, path: {args.dump_path}")

pipe = VQDiffusionPipeline(
vqvae=vqvae_model,
transformer=transformer_model,
tokenizer=tokenizer_model,
text_encoder=text_encoder_model,
learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)
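
As a quick sanity check (not part of this commit; the path below is a placeholder for whatever was passed as --dump_path), the saved pipeline can be reloaded to confirm the new component was serialized:

from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("/path/to/dump")  # placeholder for args.dump_path
# When learnable_cf was enabled in the original config, the learned empty-prompt embedding is present:
print(pipe.learned_classifier_free_sampling_embeddings.embeddings.shape)  # torch.Size([77, 512])
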
2 changes: 1 addition & 1 deletion src/diffusers/__init__.py
@@ -18,7 +18,7 @@

if is_torch_available():
from .modeling_utils import ModelMixin
-from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
+from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel, LearnedClassifierFreeSamplingEmbeddings
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -21,6 +21,7 @@
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
from .vae import AutoencoderKL, VQModel
from .embeddings import LearnedClassifierFreeSamplingEmbeddings

if is_flax_available():
from .unet_2d_condition_flax import FlaxUNet2DConditionModel
14 changes: 14 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -17,6 +17,9 @@
import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin


def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -198,3 +201,14 @@ def forward(self, index):
emb = emb + pos_emb[:, : emb.shape[1], :]

return emb


class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
@register_to_config
def __init__(self, learnable: bool):
super().__init__()

if learnable:
self.embeddings = torch.nn.Parameter(torch.empty(77, 512))
else:
self.embeddings = None
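
For orientation, a minimal sketch (not part of the commit) of how this module behaves: with learnable=True it allocates an uninitialized 77x512 parameter, matching the CLIP text encoder's sequence length and hidden size, which the conversion script then fills from the original checkpoint; with learnable=False it stores nothing and the pipeline falls back to encoding the empty string.

from diffusers import LearnedClassifierFreeSamplingEmbeddings

learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True)
print(learned.embeddings.shape)  # torch.Size([77, 512]); values stay uninitialized until loaded

unlearned = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
print(unlearned.embeddings)  # None
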
117 changes: 84 additions & 33 deletions src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -16,17 +16,17 @@

import torch

-from diffusers import Transformer2DModel, VQModel
-from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
+from diffusers import Transformer2DModel, VQModel, LearnedClassifierFreeSamplingEmbeddings
from transformers import CLIPTextModel, CLIPTokenizer

+from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler

from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...utils import logging


logger = logging.get_logger(__name__) # pylint: disable=invalid-name


class VQDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using VQ Diffusion
@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel
tokenizer: CLIPTokenizer
transformer: Transformer2DModel
learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
scheduler: VQDiffusionScheduler

def __init__(
@@ -64,6 +65,7 @@ def __init__(
tokenizer: CLIPTokenizer,
transformer: Transformer2DModel,
scheduler: VQDiffusionScheduler,
learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
):
super().__init__()

@@ -73,13 +75,77 @@ def __init__(
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings
)

def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

if do_classifier_free_guidance:
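# Prefer the learned empty-prompt embedding converted from the original VQ-Diffusion checkpoint;
# if no learned embedding is available (learnable_cf disabled), fall back to encoding the empty string with CLIP.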
if self.learned_classifier_free_sampling_embeddings.embeddings is None:
uncond_tokens = [""] * batch_size

max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
# TODO we might have to normalize the unconditional embeddings as well
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
else:
uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

return text_embeddings

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 100,
guidance_scale: float = 5.0,
truncation_rate: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
@@ -137,6 +203,12 @@ def __call__(

batch_size = batch_size * num_images_per_prompt

do_classifier_free_guidance = guidance_scale > 1.0

text_embeddings = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance
)

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
@@ -145,35 +217,6 @@
f" {type(callback_steps)}."
)

-# get prompt text embeddings
-text_inputs = self.tokenizer(
-prompt,
-padding="max_length",
-max_length=self.tokenizer.model_max_length,
-return_tensors="pt",
-)
-text_input_ids = text_inputs.input_ids
-
-if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
-removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
-logger.warning(
-"The following part of your input was truncated because CLIP can only handle sequences up to"
-f" {self.tokenizer.model_max_length} tokens: {removed_text}"
-)
-text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
-
-# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
-# While CLIP does normalize the pooled output of the text transformer when combining
-# the image and text embeddings, CLIP does not directly normalize the last hidden state.
-#
-# CLIP normalizing the pooled output.
-# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
-text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
-
-# duplicate text embeddings for each generation per prompt
-text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
-
# get the initial completely masked latents unless the user supplied it

latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +241,17 @@ def __call__(
sample = latents

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the sample if we are doing classifier free guidance
latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample

# predict the un-noised image
# model_output == `log_p_x_0`
-model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
+model_output = self.transformer(latent_model_input, encoder_hidden_states=text_embeddings, timestep=t).sample
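# model_output holds per-pixel log-probabilities over the VQ codebook (dim=1). Under classifier-free
# guidance the batch is [unconditional, conditional] (see _encode_prompt), so the two halves are split
# below and recombined in log space; subtracting the logsumexp renormalizes the result back into a
# valid log-distribution before it is truncated.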

if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

model_output = self.truncate(model_output, truncation_rate)

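For completeness, a hedged end-to-end sketch (not part of the commit; the checkpoint path and prompt are placeholders) of the new behaviour: any guidance_scale above 1.0 switches on classifier-free guidance, which uses the learned empty-prompt embeddings, or a CLIP encoding of the empty string as a fallback, for the unconditional half of the batch.

import torch
from diffusers import VQDiffusionPipeline

# placeholder path: a folder produced by convert_vq_diffusion_to_diffusers.py via --dump_path
pipe = VQDiffusionPipeline.from_pretrained("/path/to/converted-vq-diffusion")
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

output = pipe(
    "a placeholder prompt describing the image to generate",
    num_inference_steps=100,
    guidance_scale=5.0,  # > 1.0 enables the classifier-free guidance path added in this commit
    truncation_rate=1.0,
)
output.images[0].save("sample.png")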