From 043c0edfc3d4024b740b58097449ed77467dbeb5 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 11 Jun 2024 05:08:09 +0200 Subject: [PATCH 1/2] fix --- .../pipeline_semantic_stable_diffusion.py | 32 ++++++++++--------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index e068387b6162..f9714341cf02 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -376,6 +376,7 @@ def __call__( # 2. Define call parameters batch_size = 1 if isinstance(prompt, str) else len(prompt) + device = self._execution_device if editing_prompt: enable_edit_guidance = True @@ -405,7 +406,7 @@ def __call__( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] + text_embeddings = self.text_encoder(text_input_ids.to(device))[0] # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed, seq_len, _ = text_embeddings.shape @@ -433,9 +434,9 @@ def __call__( f" {self.tokenizer.model_max_length} tokens: {removed_text}" ) edit_concepts_input_ids = edit_concepts_input_ids[:, : self.tokenizer.model_max_length] - edit_concepts = self.text_encoder(edit_concepts_input_ids.to(self.device))[0] + edit_concepts = self.text_encoder(edit_concepts_input_ids.to(device))[0] else: - edit_concepts = editing_prompt_embeddings.to(self.device).repeat(batch_size, 1, 1) + edit_concepts = editing_prompt_embeddings.to(device).repeat(batch_size, 1, 1) # duplicate text embeddings for each generation per prompt, using mps friendly method bs_embed_edit, seq_len_edit, _ = edit_concepts.shape @@ -476,7 +477,7 @@ def __call__( truncation=True, return_tensors="pt", ) - uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(device))[0] # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = uncond_embeddings.shape[1] @@ -493,7 +494,7 @@ def __call__( # get the initial random noise unless the user supplied it # 4. Prepare timesteps - self.scheduler.set_timesteps(num_inference_steps, device=self.device) + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables @@ -504,7 +505,7 @@ def __call__( height, width, text_embeddings.dtype, - self.device, + device, generator, latents, ) @@ -562,12 +563,12 @@ def __call__( if enable_edit_guidance: concept_weights = torch.zeros( (len(noise_pred_edit_concepts), noise_guidance.shape[0]), - device=self.device, + device=device, dtype=noise_guidance.dtype, ) noise_guidance_edit = torch.zeros( (len(noise_pred_edit_concepts), *noise_guidance.shape), - device=self.device, + device=device, dtype=noise_guidance.dtype, ) # noise_guidance_edit = torch.zeros_like(noise_guidance) @@ -644,12 +645,12 @@ def __call__( # noise_guidance_edit = noise_guidance_edit + noise_guidance_edit_tmp - warmup_inds = torch.tensor(warmup_inds).to(self.device) + warmup_inds = torch.tensor(warmup_inds).to(device) if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0: concept_weights = concept_weights.to("cpu") # Offload to cpu noise_guidance_edit = noise_guidance_edit.to("cpu") - concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds) + concept_weights_tmp = torch.index_select(concept_weights.to(device), 0, warmup_inds) concept_weights_tmp = torch.where( concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp ) @@ -657,7 +658,7 @@ def __call__( # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) noise_guidance_edit_tmp = torch.index_select( - noise_guidance_edit.to(self.device), 0, warmup_inds + noise_guidance_edit.to(device), 0, warmup_inds ) noise_guidance_edit_tmp = torch.einsum( "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp @@ -669,8 +670,8 @@ def __call__( del noise_guidance_edit_tmp del concept_weights_tmp - concept_weights = concept_weights.to(self.device) - noise_guidance_edit = noise_guidance_edit.to(self.device) + concept_weights = concept_weights.to(device) + noise_guidance_edit = noise_guidance_edit.to(device) concept_weights = torch.where( concept_weights < 0, torch.zeros_like(concept_weights), concept_weights @@ -679,6 +680,7 @@ def __call__( concept_weights = torch.nan_to_num(concept_weights) noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit) + noise_guidance_edit = noise_guidance_edit.to(edit_momentum.device) noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum @@ -689,7 +691,7 @@ def __call__( self.sem_guidance[i] = noise_guidance_edit.detach().cpu() if sem_guidance is not None: - edit_guidance = sem_guidance[i].to(self.device) + edit_guidance = sem_guidance[i].to(device) noise_guidance = noise_guidance + edit_guidance noise_pred = noise_pred_uncond + noise_guidance @@ -705,7 +707,7 @@ def __call__( # 8. Post-processing if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype) + image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype) else: image = latents has_nsfw_concept = None From 135dfd05d68834d6825ade6868434ff2ce60a487 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Tue, 11 Jun 2024 03:11:20 +0000 Subject: [PATCH 2/2] style --- .../pipeline_semantic_stable_diffusion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py index f9714341cf02..8f620b64327e 100644 --- a/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +++ b/src/diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py @@ -657,9 +657,7 @@ def __call__( concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0) # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp) - noise_guidance_edit_tmp = torch.index_select( - noise_guidance_edit.to(device), 0, warmup_inds - ) + noise_guidance_edit_tmp = torch.index_select(noise_guidance_edit.to(device), 0, warmup_inds) noise_guidance_edit_tmp = torch.einsum( "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp )