@@ -376,6 +376,7 @@ def __call__(

# 2. Define call parameters
batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device

if editing_prompt:
enable_edit_guidance = True
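`self._execution_device` differs from `self.device` once accelerate offload hooks are attached: the modules are registered on CPU (so `self.device` reports `cpu`), while the hooks know which device the forward pass actually runs on. A rough sketch of how that property resolves the device, assuming diffusers/accelerate hook names; the exact implementation varies across versions:

```python
import torch

class PipelineDeviceMixin:
    """Illustrative sketch only, not the verbatim diffusers implementation."""

    @property
    def _execution_device(self) -> torch.device:
        # No offload hooks attached: the device the modules are registered on is correct.
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        # With CPU offload, modules sit on CPU until needed; the accelerate hook
        # records the device the forward pass is dispatched to.
        for module in self.unet.modules():
            hook = getattr(module, "_hf_hook", None)
            if hook is not None and getattr(hook, "execution_device", None) is not None:
                return torch.device(hook.execution_device)
        return self.device
```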
@@ -405,7 +406,7 @@ def __call__(
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+ text_embeddings = self.text_encoder(text_input_ids.to(device))[0]

# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed, seq_len, _ = text_embeddings.shape
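The duplication right below this hunk uses the "mps friendly" pattern of `repeat` followed by `view` rather than `repeat_interleave`. A self-contained sketch of what that does, with illustrative shapes and the assumed `num_images_per_prompt` argument:

```python
import torch

# Repeat along dim 1, then fold num_images_per_prompt into the batch dimension;
# this avoids repeat_interleave, which has been problematic on the mps backend.
num_images_per_prompt = 2
text_embeddings = torch.randn(1, 77, 768)  # (batch, seq_len, hidden) from the text encoder

bs_embed, seq_len, _ = text_embeddings.shape
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
assert text_embeddings.shape == (2, 77, 768)
```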
@@ -433,9 +434,9 @@ def __call__(
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length]
- edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0]
+ edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0]
else:
- edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1)
+ edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1)

# duplicate text embeddings for each generation per prompt, using mps friendly method
bs_embed_edit, seq_len_edit, _ = edit_concepts.shape
@@ -476,7 +477,7 @@ def __call__(
truncation=True,
return_tensors="pt",
)
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0]

# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
@@ -493,7 +494,7 @@ def __call__(
# get the initial random noise unless the user supplied it

# 4. Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device=self.device)
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps

# 5. Prepare latent variables
@@ -504,7 +505,7 @@ def __call__(
height,
width,
text_embeddings.dtype,
- self.device,
+ device,
generator,
latents,
)
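For readers who have not seen the helper called here: `prepare_latents` draws the initial noise on the target device (or moves user-supplied latents there) and scales it by the scheduler's initial sigma. The sketch below is an assumption about its behaviour, with an invented signature; the real helper also handles generator lists and dtype edge cases:

```python
import torch

# Hypothetical, simplified stand-in for self.prepare_latents(...) in SD-style pipelines.
def prepare_latents(batch_size, num_channels, height, width, dtype, device,
                    generator, latents, scheduler, vae_scale_factor=8):
    shape = (batch_size, num_channels, height // vae_scale_factor, width // vae_scale_factor)
    if latents is None:
        # Initial random noise on the execution device.
        latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
    else:
        latents = latents.to(device)
    # Scale by the scheduler's initial noise sigma (1.0 for DDIM-like schedulers).
    return latents * scheduler.init_noise_sigma
```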
@@ -562,12 +563,12 @@ def __call__(
if enable_edit_guidance:
concept_weights = torch.zeros(
(len(noise_pred_edit_concepts), noise_guidance.shape[0]),
- device=self.device,
+ device=device,
dtype=noise_guidance.dtype,
)
noise_guidance_edit = torch.zeros(
(len(noise_pred_edit_concepts), *noise_guidance.shape),
- device=self.device,
+ device=device,
dtype=noise_guidance.dtype,
)
# noise_guidance_edit = torch.zeros_like(noise_guidance)
@@ -644,21 +645,19 @@ def __call__(

# noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp

- warmup_inds = torch.tensor(warmup_inds).to(self.device)
+ warmup_inds = torch.tensor(warmup_inds).to(device)
if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
concept_weights = concept_weights.to("cpu") # Offload to cpu
noise_guidance_edit = noise_guidance_edit.to("cpu")

- concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
+ concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds)
concept_weights_tmp = torch.where(
concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
)
concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
# concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)

- noise_guidance_edit_tmp = torch.index_select(
-     noise_guidance_edit.to(self.device), 0, warmup_inds
- )
+ noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds)
noise_guidance_edit_tmp = torch.einsum(
"cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
)
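The `"cb,cbijk->bijk"` einsum in this hunk (and again further down) is a per-batch weighted sum over the edit concepts: for each batch element `b`, the per-concept guidance tensors are combined using the concept weights. A small self-contained check of that reading, with illustrative shapes:

```python
import torch

# c = edit concepts, b = batch, ijk = latent channels/height/width.
num_concepts, batch, ch, h, w = 3, 2, 4, 8, 8
weights = torch.rand(num_concepts, batch)
guidance = torch.randn(num_concepts, batch, ch, h, w)

combined = torch.einsum("cb,cbijk->bijk", weights, guidance)

# The same computation written as an explicit sum for batch element 0.
manual = sum(weights[c, 0] * guidance[c, 0] for c in range(num_concepts))
assert torch.allclose(combined[0], manual, atol=1e-5)
```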
@@ -669,8 +668,8 @@ def __call__(

del noise_guidance_edit_tmp
del concept_weights_tmp
- concept_weights = concept_weights.to(self.device)
- noise_guidance_edit = noise_guidance_edit.to(self.device)
+ concept_weights = concept_weights.to(device)
+ noise_guidance_edit = noise_guidance_edit.to(device)

concept_weights = torch.where(
concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
@@ -679,6 +678,7 @@ def __call__(
concept_weights = torch.nan_to_num(concept_weights)

noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
+ noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)

noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
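The added `.to(edit_momentum.device)` guards against a device mismatch: under CPU offload the momentum buffer and the freshly combined guidance can end up on different devices, and the addition above would then fail. A minimal sketch of the momentum step, with dummy tensors and illustrative values for the pipeline's `edit_momentum_scale`/`edit_mom_beta` parameters:

```python
import torch

edit_momentum_scale, edit_mom_beta = 0.1, 0.4
noise_guidance_edit = torch.randn(2, 4, 8, 8)
edit_momentum = torch.zeros(2, 4, 8, 8)  # may live on a different device under offload

# Align devices first, otherwise the addition raises
# "Expected all tensors to be on the same device".
noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device)
noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit
```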

@@ -689,7 +689,7 @@ def __call__(
self.sem_guidance[i] = noise_guidance_edit.detach().cpu()

if sem_guidance is not None:
- edit_guidance = sem_guidance[i].to(self.device)
+ edit_guidance = sem_guidance[i].to(device)
noise_guidance = noise_guidance + edit_guidance

noise_pred = noise_pred_uncond + noise_guidance
@@ -705,7 +705,7 @@ def __call__(
# 8. Post-processing
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
- image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
else:
image = latents
has_nsfw_concept = None
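Taken together, these changes let the pipeline run when model CPU offload is enabled, where `self.device` no longer reflects where computation happens. A rough usage sketch, assuming the `SemanticStableDiffusionPipeline` API in diffusers; the model id and editing arguments are illustrative:

```python
import torch
from diffusers import SemanticStableDiffusionPipeline

pipe = SemanticStableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # pipe.device may now report "cpu"

# With offload active, the call path relies on _execution_device as in this PR.
image = pipe(
    prompt="a photo of a castle next to a river",
    editing_prompt=["snowy winter"],
    reverse_editing_direction=[False],
    edit_guidance_scale=[5.0],
).images[0]
image.save("castle_winter.png")
```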