vq diffusion classifier free sampling #1294

Merged
49 changes: 47 additions & 2 deletions scripts/convert_vq_diffusion_to_diffusers.py
@@ -39,8 +39,13 @@

import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.models.attention import Transformer2DModel
from diffusers import (
LearnedClassifierFreeSamplingEmbeddings,
Transformer2DModel,
VQDiffusionPipeline,
VQDiffusionScheduler,
VQModel,
)
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

@@ -826,6 +831,20 @@ def read_config_file(filename):
transformer_model, checkpoint
)

# classifier free sampling embeddings interlude

# The learned embeddings are stored on the transformer in the original VQ-diffusion. We store them on a separate
# model, so we pull them off the checkpoint before the checkpoint is deleted.

learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

if learnable_classifier_free_sampling_embeddings:
learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
else:
learned_classifier_free_sampling_embeddings_embeddings = None

# done classifier free sampling embeddings interlude

with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
del diffusers_transformer_checkpoint
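
As an aside, it can be useful to check up front whether an original checkpoint was trained with `learnable_cf` before converting it. The sketch below is hedged: only the `transformer.empty_text_embed` key comes from the conversion code above; the file name and the assumption that `torch.load` returns a flat state dict are placeholders.

```python
# Hedged sketch: probe an original VQ-diffusion checkpoint for learned classifier-free
# embeddings. The path and the flat-dict assumption are placeholders; only the key name
# is taken from the conversion script above.
import torch

state_dict = torch.load("original_vq_diffusion.pth", map_location="cpu")  # placeholder path
key = "transformer.empty_text_embed"

if key in state_dict:
    print(f"{key}: shape {tuple(state_dict[key].shape)}")  # learned null-text embedding present
else:
    print("no learned embeddings; the pipeline will fall back to encoding the empty prompt")
```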
@@ -871,13 +890,39 @@ def read_config_file(filename):

# done scheduler

# learned classifier free sampling embeddings

with init_empty_weights():
learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
learnable_classifier_free_sampling_embeddings,
hidden_size=text_encoder_model.config.hidden_size,
length=tokenizer_model.model_max_length,
)

learned_classifier_free_sampling_checkpoint = {
"embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
}

with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
del learned_classifier_free_sampling_checkpoint
del learned_classifier_free_sampling_embeddings_embeddings
load_checkpoint_and_dispatch(
learned_classifier_free_sampling_embeddings_model,
learned_classifier_free_sampling_checkpoint_file.name,
device_map="auto",
)

# done learned classifier free sampling embeddings

print(f"saving VQ diffusion model, path: {args.dump_path}")

pipe = VQDiffusionPipeline(
vqvae=vqvae_model,
transformer=transformer_model,
tokenizer=tokenizer_model,
text_encoder=text_encoder_model,
learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)
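
For orientation, a minimal hedged sketch of using the converted output follows. It relies only on the `VQDiffusionPipeline` constructor arguments and `save_pretrained` call shown in this diff plus `from_pretrained`; the dump path, prompt, and `truncation_rate` value are illustrative placeholders, not values from this PR.

```python
# Hedged usage sketch for a pipeline produced by the conversion script above.
# The dump path, prompt, and truncation_rate are placeholders.
import torch
from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("./converted-vq-diffusion")  # args.dump_path from the script
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# guidance_scale > 1.0 enables the classifier-free guidance path added in this PR
output = pipe("a teddy bear playing in the pool", guidance_scale=5.0, truncation_rate=0.86)
output.images[0].save("sample.png")
```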
10 changes: 9 additions & 1 deletion src/diffusers/__init__.py
@@ -18,7 +18,15 @@

if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import (
AutoencoderKL,
LearnedClassifierFreeSamplingEmbeddings,
Transformer2DModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
VQModel,
)
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
1 change: 1 addition & 0 deletions src/diffusers/models/__init__.py
@@ -17,6 +17,7 @@

if is_torch_available():
from .attention import Transformer2DModel
from .embeddings import LearnedClassifierFreeSamplingEmbeddings
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel
26 changes: 26 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional

import numpy as np
import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin


def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -198,3 +202,25 @@ def forward(self, index):
emb = emb + pos_emb[:, : emb.shape[1], :]

return emb


class LearnedClassifierFreeSamplingEmbeddings(ModelMixin, ConfigMixin):
"""
Utility class for storing learned text embeddings for classifier free sampling
"""

@register_to_config
def __init__(self, learnable: bool, hidden_size: Optional[int] = None, length: Optional[int] = None):
super().__init__()

self.learnable = learnable

if self.learnable:
assert hidden_size is not None, "learnable=True requires `hidden_size` to be set"
assert length is not None, "learnable=True requires `length` to be set"

embeddings = torch.zeros(length, hidden_size)
else:
embeddings = None

self.embeddings = nn.Parameter(embeddings)
Contributor (PR author) commented:
Not sure if this is the preferred way to add the learned embeddings to the pipeline. An alternative might be to add the additional vector to the scheduler instead.

Contributor (reviewer) replied:
It's very model specific, so moving it to the pipeline here directly :-)
Think that's a bit cleaner! The model works much better now though - thanks!
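
To make the two configurations of the new module concrete, here is a small hedged sketch using only the constructor shown in this diff; the 77/512 values mirror CLIP's token length and hidden size and are illustrative, not required.

```python
# Hedged sketch of the two states of LearnedClassifierFreeSamplingEmbeddings.
from diffusers import LearnedClassifierFreeSamplingEmbeddings

# Trained with learnable_cf: a (length, hidden_size) parameter is allocated and is later
# filled with the converted `transformer.empty_text_embed` weights.
learned = LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=512, length=77)
print(learned.embeddings.shape)  # torch.Size([77, 512])

# Trained without it: the pipeline instead encodes "" with CLIP at call time, and this
# module only records that no learned embeddings exist.
plain = LearnedClassifierFreeSamplingEmbeddings(learnable=False)
print(plain.learnable)  # False
```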

120 changes: 89 additions & 31 deletions src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -16,7 +16,7 @@

import torch

from diffusers import Transformer2DModel, VQModel
from diffusers import LearnedClassifierFreeSamplingEmbeddings, Transformer2DModel, VQModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer

@@ -55,6 +55,7 @@ class VQDiffusionPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel
tokenizer: CLIPTokenizer
transformer: Transformer2DModel
learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings
scheduler: VQDiffusionScheduler

def __init__(
@@ -64,6 +65,7 @@ def __init__(
tokenizer: CLIPTokenizer,
transformer: Transformer2DModel,
scheduler: VQDiffusionScheduler,
learned_classifier_free_sampling_embeddings: LearnedClassifierFreeSamplingEmbeddings,
Contributor (reviewer) commented:
That's definitely the right way to do it - it's quite specific to vq-diffusion IMO though, so will move it here :-)

):
super().__init__()

@@ -73,13 +75,78 @@ def __init__(
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings,
)

def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

if do_classifier_free_guidance:
if self.learned_classifier_free_sampling_embeddings.learnable:
uncond_embeddings = self.learned_classifier_free_sampling_embeddings.embeddings
uncond_embeddings = uncond_embeddings.unsqueeze(0).repeat(batch_size, 1, 1)
else:
uncond_tokens = [""] * batch_size

max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
# See comment for normalizing text embeddings
uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

return text_embeddings
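
A quick way to see the batch bookkeeping `_encode_prompt` performs is to replay it on dummy tensors; the sketch below is a hedged, standalone illustration of the non-learned branch with toy dimensions, not a call into the pipeline.

```python
# Hedged, standalone replay of the shape handling in _encode_prompt (non-learned branch).
# Dummy tensors stand in for CLIP hidden states; dimensions are illustrative.
import torch

batch_size, num_images_per_prompt, seq_len, hidden = 2, 3, 77, 512

text_embeddings = torch.randn(batch_size, seq_len, hidden)
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

uncond_embeddings = torch.randn(batch_size, seq_len, hidden)
uncond_embeddings = uncond_embeddings / uncond_embeddings.norm(dim=-1, keepdim=True)
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

# unconditional rows come first, matching the chunk(2) order used later in __call__
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
print(text_embeddings.shape)  # torch.Size([12, 77, 512]) == (2 * batch * images_per_prompt, ...)
```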

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 100,
guidance_scale: float = 5.0,
truncation_rate: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
@@ -98,6 +165,12 @@ def __call__(
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2 of the [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
`prompt`, usually at the expense of lower image quality.
truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
@@ -137,6 +210,10 @@ def __call__(

batch_size = batch_size * num_images_per_prompt

do_classifier_free_guidance = guidance_scale > 1.0

text_embeddings = self._encode_prompt(prompt, num_images_per_prompt, do_classifier_free_guidance)

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
@@ -145,35 +222,6 @@ def __call__(
f" {type(callback_steps)}."
)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# get the initial completely masked latents unless the user supplied it

latents_shape = (batch_size, self.transformer.num_latent_pixels)
@@ -198,9 +246,19 @@ def __call__(
sample = latents

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the sample if we are doing classifier free guidance
latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample

# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
model_output = self.transformer(
latent_model_input, encoder_hidden_states=text_embeddings, timestep=t
).sample

if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

model_output = self.truncate(model_output, truncation_rate)

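Because `model_output` holds log-probabilities over the codebook, the guidance step above mixes the two predictions in log space and then renormalizes with `logsumexp` so every pixel's class distribution still sums to one. The sketch below is a hedged, self-contained illustration of just that step on random log-probabilities; the tensor shapes are illustrative stand-ins.

```python
# Hedged sketch of the log-space classifier-free guidance step from the loop above.
# Shapes are illustrative stand-ins for (batch, codebook classes, latent pixels).
import torch

guidance_scale = 5.0
log_p_uncond = torch.log_softmax(torch.randn(1, 16, 8), dim=1)
log_p_text = torch.log_softmax(torch.randn(1, 16, 8), dim=1)

# mix in log space, then renormalize so each pixel's class distribution sums to 1 again
model_output = log_p_uncond + guidance_scale * (log_p_text - log_p_uncond)
model_output = model_output - torch.logsumexp(model_output, dim=1, keepdim=True)

print(torch.allclose(model_output.exp().sum(dim=1), torch.ones(1, 8)))  # True
```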
15 changes: 15 additions & 0 deletions src/diffusers/utils/dummy_pt_objects.py
@@ -34,6 +34,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class LearnedClassifierFreeSamplingEmbeddings(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]

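For completeness: the dummy entry above is what gets imported when PyTorch is absent, so `from diffusers import LearnedClassifierFreeSamplingEmbeddings` still succeeds and only fails, with a message about the missing backend, when the class is actually used. A hedged sketch of that behavior follows; the exact error type and wording are assumptions about `requires_backends`.

```python
# Hedged sketch of the dummy-object fallback when the torch backend is unavailable.
from diffusers.utils import is_torch_available

if not is_torch_available():
    # this resolves to the DummyObject-backed placeholder defined above
    from diffusers import LearnedClassifierFreeSamplingEmbeddings

    try:
        LearnedClassifierFreeSamplingEmbeddings(learnable=True, hidden_size=512, length=77)
    except ImportError as err:  # assumed error type raised by requires_backends
        print(err)
```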