Skip to content

Commit

Permalink
[wip] vq diffusion classifier free sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
williamberman committed Nov 15, 2022
1 parent 57525bb commit d0d5beb
Showing 1 changed file with 73 additions and 30 deletions.
103 changes: 73 additions & 30 deletions src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
Expand Up @@ -75,11 +75,69 @@ def __init__(
scheduler=scheduler,
)

def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size

max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)

# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

return text_embeddings

@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 100,
guidance_scale: float = 5.0,
truncation_rate: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
Expand Down Expand Up @@ -137,6 +195,12 @@ def __call__(

batch_size = batch_size * num_images_per_prompt

do_classifier_free_guidance = guidance_scale > 1.0

text_embeddings = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance
)

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
Expand All @@ -145,35 +209,6 @@ def __call__(
f" {type(callback_steps)}."
)

# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids

if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]

# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)

# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)

# get the initial completely masked latents unless the user supplied it

latents_shape = (batch_size, self.transformer.num_latent_pixels)
Expand All @@ -198,9 +233,17 @@ def __call__(
sample = latents

for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# expand the sample if we are doing classifier free guidance
latent_model_input = torch.cat([sample] * 2) if do_classifier_free_guidance else sample

# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
model_output = self.transformer(latent_model_input, encoder_hidden_states=text_embeddings, timestep=t).sample

if do_classifier_free_guidance:
model_output_uncond, model_output_text = model_output.chunk(2)
model_output = model_output_uncond + guidance_scale * (model_output_text - model_output_uncond)
model_output -= torch.logsumexp(model_output, dim=1, keepdim=True)

model_output = self.truncate(model_output, truncation_rate)

Expand Down

0 comments on commit d0d5beb

Please sign in to comment.