1- from typing import Optional , Union , List , Callable , Dict , Any
1+ from typing import Any , Callable , Dict , List , Optional , Union
22
33import numpy as np
44import PIL
55import torch
6+
67from diffusers import StableDiffusionImg2ImgPipeline
78from diffusers .pipelines .stable_diffusion import StableDiffusionPipelineOutput
89
910
1011class MaskedStableDiffusionImg2ImgPipeline (StableDiffusionImg2ImgPipeline ):
11-
1212 debug_save = False
1313
1414 @torch .no_grad ()
@@ -38,13 +38,13 @@ def __call__(
3838 callback_steps : int = 1 ,
3939 cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
4040 mask : Union [
41- torch .FloatTensor ,
42- PIL .Image .Image ,
43- np .ndarray ,
44- List [torch .FloatTensor ],
45- List [PIL .Image .Image ],
46- List [np .ndarray ],
47- ] = None ,
41+ torch .FloatTensor ,
42+ PIL .Image .Image ,
43+ np .ndarray ,
44+ List [torch .FloatTensor ],
45+ List [PIL .Image .Image ],
46+ List [np .ndarray ],
47+ ] = None ,
4848 ):
4949 r"""
5050 The call function to the pipeline for generation.
@@ -158,7 +158,8 @@ def __call__(
158158
159159 # mean of the latent distribution
160160 init_latents = [
161- self .vae .encode (image .to (device = device , dtype = prompt_embeds .dtype )[i : i + 1 ]).latent_dist .mean for i in range (batch_size )
161+ self .vae .encode (image .to (device = device , dtype = prompt_embeds .dtype )[i : i + 1 ]).latent_dist .mean
162+ for i in range (batch_size )
162163 ]
163164 init_latents = torch .cat (init_latents , dim = 0 )
164165
@@ -194,7 +195,7 @@ def __call__(
194195 latents = torch .lerp (init_latents * self .vae .config .scaling_factor , latents , latent_mask )
195196 noise_pred = torch .lerp (torch .zeros_like (noise_pred ), noise_pred , latent_mask )
196197
197- # compute the previous noisy sample x_t -> x_t-1
198+ # compute the previous noisy sample x_t -> x_t-1
198199 latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs , return_dict = False )[0 ]
199200
200201 # call the callback, if provided
@@ -236,7 +237,7 @@ def __call__(
236237
237238 def _make_latent_mask (self , latents , mask ):
238239 if mask is not None :
239- latent_mask = list ()
240+ latent_mask = []
240241 if not isinstance (mask , list ):
241242 tmp_mask = [mask ]
242243 else :
@@ -250,7 +251,7 @@ def _make_latent_mask(self, latents, mask):
250251 m = m / 255.0
251252 m = self .image_processor .numpy_to_pil (m )[0 ]
252253 if m .mode != "L" :
253- m = m .convert ('L' )
254+ m = m .convert ("L" )
254255 resized = self .image_processor .resize (m , l_height , l_width )
255256 if self .debug_save :
256257 resized .save ("latent_mask.png" )
0 commit comments