vq diffusion classifier free sampling
williamberman committed Nov 15, 2022
1 parent 57525bb commit f5fbbbe
Showing 6 changed files with 167 additions and 34 deletions.
47 changes: 45 additions & 2 deletions scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,8 +39,13 @@

import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.models.attention import Transformer2DModel
from diffusers import (
    LearnedClassifierFreeSamplingEmbeddings,
    Transformer2DModel,
    VQDiffusionPipeline,
    VQDiffusionScheduler,
    VQModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

@@ -826,6 +831,20 @@ def read_config_file(filename):
        transformer_model, checkpoint
    )

    # classifier free sampling embeddings interlude

    # The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
    # model, so we pull them off the checkpoint before the checkpoint is deleted.

    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

    if learnable_classifier_free_sampling_embeddings:
        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
    else:
        learned_classifier_free_sampling_embeddings_embeddings = None

    # done classifier free sampling embeddings interlude

    with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
        torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
        del diffusers_transformer_checkpoint
@@ -871,13 +890,37 @@ def read_config_file(filename):

    # done scheduler

    # learned classifier free sampling embeddings

    with init_empty_weights():
        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
            learnable_classifier_free_sampling_embeddings
        )

    learned_classifier_free_sampling_checkpoint = {
        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
    }

    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
        del learned_classifier_free_sampling_checkpoint
        del learned_classifier_free_sampling_embeddings_embeddings
        load_checkpoint_and_dispatch(
            learned_classifier_free_sampling_embeddings_model,
            learned_classifier_free_sampling_checkpoint_file.name,
            device_map="auto",
        )

    # done learned classifier free sampling embeddings

    print(f"saving VQ diffusion model, path: {args.dump_path}")

    pipe = VQDiffusionPipeline(
        vqvae=vqvae_model,
        transformer=transformer_model,
        tokenizer=tokenizer_model,
        text_encoder=text_encoder_model,
        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
        scheduler=scheduler_model,
    )
    pipe.save_pretrained(args.dump_path)
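
For context, here is a minimal usage sketch of the converted pipeline. It is not part of this commit: the hub id `microsoft/vq-diffusion-ithq` is an assumption, and locally you would point `from_pretrained` at the directory passed to `--dump_path`. Passing `guidance_scale > 1.0` is what activates the classifier free sampling path added by this commit.

```python
# Hypothetical usage sketch, assuming a converted checkpoint is available either on the
# Hugging Face Hub (the repo id below is an assumption) or at the local --dump_path directory.
import torch
from diffusers import VQDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq").to(device)

# guidance_scale > 1.0 switches on the classifier free guidance branch added in this commit;
# guidance_scale <= 1.0 falls back to plain conditional sampling.
image = pipe("teddy bear playing in the pool", guidance_scale=5.0, num_inference_steps=100).images[0]
image.save("teddy_bear.png")
```
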
10 changes: 9 additions & 1 deletion src/diffusers/__init__.py
@@ -18,7 +18,15 @@

if is_torch_available():
    from .modeling_utils import ModelMixin
    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
    from .models import (
        AutoencoderKL,
        LearnedClassifierFreeSamplingEmbeddings,
        Transformer2DModel,
        UNet1DModel,
        UNet2DConditionModel,
        UNet2DModel,
        VQModel,
    )
    from .optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -17,6 +17,7 @@

if is_torch_available():
    from .attention import Transformer2DModel
    from .embeddings import LearnedClassifierFreeSamplingEmbeddings
    from .unet_1d import UNet1DModel
    from .unet_2d import UNet2DModel
    from .unet_2d_condition import UNet2DConditionModel
14 changes: 14 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -17,6 +17,9 @@
import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin


def get_timestep_embedding(
    timesteps: torch.Tensor,
@@ -198,3 +201,14 @@ def forward(self, index):
        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb


class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, learnable: bool):
        super().__init__()

        if learnable:
            self.embeddings = torch.nn.Parameter(torch.empty(77, 512))
        else:
            self.embeddings = None
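
To make the new module concrete, here is a small illustrative snippet (not part of the diff). With `learnable=True` the module carries a trainable `(77, 512)` parameter matching CLIP's sequence length and hidden size; with `learnable=False` it exposes `embeddings = None`, which the pipeline treats as "encode the empty prompt with CLIP instead".

```python
# Illustrative sketch, not part of the commit. Requires torch and this version of diffusers.
from diffusers import LearnedClassifierFreeSamplingEmbeddings

learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True)
print(learned.embeddings.shape)  # torch.Size([77, 512]); real values are loaded from the checkpoint by the converter

unlearned = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
print(unlearned.embeddings)  # None -> the pipeline encodes "" with the CLIP text encoder instead
```
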
114 changes: 83 additions & 31 deletions src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -16,7 +16,7 @@

import torch

from diffusers import Transformer2DModel, VQModel
from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer

@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
    text_encoder: CLIPTextModel
    tokenizer: CLIPTokenizer
    transformer: Transformer2DModel
    learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
    scheduler: VQDiffusionScheduler

    def __init__(
@@ -64,6 +65,7 @@ def __init__(
        tokenizer: CLIPTokenizer,
        transformer: Transformer2DModel,
        scheduler: VQDiffusionScheduler,
        learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
    ):
        super().__init__()

@@ -73,13 +75,78 @@ def __init__(
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            scheduler=scheduler,
            learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
        )

    def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
        # While CLIP does normalize the pooled output of the text transformer when combining
        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
        #
        # CLIP normalizing the pooled output.
        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            if self.learned_classifier_free_sampling_embeddings.embeddings is None:
                uncond_tokens = [""] * batch_size

                max_length = text_input_ids.shape[-1]
                uncond_input = self.tokenizer(
                    uncond_tokens,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
                # See comment for normalizing text embeddings
                uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
            else:
                uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
                uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        num_inference_steps: int = 100,
        guidance_scale: float = 5.0,
        truncation_rate: float = 1.0,
        num_images_per_prompt: int = 1,
        generator: Optional[torch.Generator] = None,
Expand Down Expand Up @@ -137,6 +204,10 @@ def __call__(

batch_size = batch_size * num_images_per_prompt

do_classifier_free_guidance = guidance_scale > 1.0

text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
@@ -145,35 +216,6 @@ def __call__(
                f" {type(callback_steps)}."
            )

        # get prompt text embeddings
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
            removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )
            text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

        # NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
        # While CLIP does normalize the pooled output of the text transformer when combining
        # the image and text embeddings, CLIP does not directly normalize the last hidden state.
        #
        # CLIP normalizing the pooled output.
        # https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

        # duplicate text embeddings for each generation per prompt
        text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

        # get the initial completely masked latents unless the user supplied it

        latents_shape = (batch_size, self.transformer.num_latent_pixels)
Expand All @@ -198,9 +240,19 @@ def __call__(
sample = latents

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the sample if we are doing classifier free guidance
latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample

# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
model_output = self.transformer(
latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
).sample

if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

model_output = self.truncate(model_output, truncation_rate)

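
The guidance arithmetic above operates on `log_p_x_0`, a log-distribution over the VQ codebook for every latent pixel, so after mixing the conditional and unconditional predictions the result has to be renormalized with `logsumexp`. A toy example (illustrative only; the codebook and pixel sizes are made up) shows why:

```python
import torch

# Toy illustration of the guidance step in __call__ above; the sizes here are made up.
torch.manual_seed(0)
num_classes, num_pixels, guidance_scale = 5, 3, 5.0

# Pretend conditional and unconditional log p(x_0), each normalized over the codebook dimension.
log_p_uncond = torch.log_softmax(torch.randn(1, num_classes, num_pixels), dim=1)
log_p_text = torch.log_softmax(torch.randn(1, num_classes, num_pixels), dim=1)

guided = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
print(guided.exp().sum(dim=1))  # no longer sums to 1 per latent pixel

guided = guided - torch.logsumexp(guided, dim=1, keepdim=True)
print(guided.exp().sum(dim=1))  # ~1.0 everywhere: a valid distribution again
```
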
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class LearnedClassifierFreeSamplingEmbeddings(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class Transformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]
