diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 8b52859e972e..8d1052173e66 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -32,10 +32,10 @@ def get_timestep_embedding(
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
 
     half_dim = embedding_dim // 2
-    emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
-    emb = torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
-    emb = torch.exp(emb * emb_coeff)
+    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
+    exponent = exponent / (half_dim - downscale_freq_shift)
+    emb = torch.exp(exponent).to(device=timesteps.device)
 
     emb = timesteps[:, None].float() * emb[None, :]
 
     # scale embeddings
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index 98244261d1cc..15cf6e26a955 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -331,7 +331,9 @@ def __init__(
     def forward(self, x, temb, hey=False):
         h = x
 
-        h = self.norm1(h)
+        # make sure hidden states are in float32
+        # when running in half-precision
+        h = self.norm1(h.float()).type(h.dtype)
         h = self.nonlinearity(h)
 
         if self.upsample is not None:
@@ -347,7 +349,9 @@ def forward(self, x, temb, hey=False):
             temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
             h = h + temb
 
-        h = self.norm2(h)
+        # make sure hidden states are in float32
+        # when running in half-precision
+        h = self.norm2(h.float()).type(h.dtype)
         h = self.nonlinearity(h)
         h = self.dropout(h)
diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py
index 6203d76f2586..db4c33690c9d 100644
--- a/src/diffusers/models/unet_2d.py
+++ b/src/diffusers/models/unet_2d.py
@@ -132,6 +132,9 @@ def forward(
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)
 
+        # broadcast to batch dimension
+        timesteps = timesteps.broadcast_to(sample.shape[0])
+
         t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
 
@@ -166,7 +169,9 @@ def forward(
             sample = upsample_block(sample, res_samples, emb)
 
         # 6. post-process
-        sample = self.conv_norm_out(sample)
+        # make sure hidden states are in float32
+        # when running in half-precision
+        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
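For reference, a quick standalone check (not part of the diff) that the refactored exponent computation in `get_timestep_embedding` matches the old `emb_coeff` formulation; the concrete `embedding_dim` and timestep values below are made up for illustration:

import math

import torch

# hypothetical arguments, just to exercise both formulations
embedding_dim, max_period, downscale_freq_shift = 320, 10000, 1
timesteps = torch.arange(4)
half_dim = embedding_dim // 2

# old formulation: fold the division into a scalar coefficient first
emb_coeff = -math.log(max_period) / (half_dim - downscale_freq_shift)
emb_old = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * emb_coeff)

# new formulation: build the exponent explicitly, then move it to the timestep device
exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32)
exponent = exponent / (half_dim - downscale_freq_shift)
emb_new = torch.exp(exponent).to(device=timesteps.device)

# same frequencies up to floating-point rounding
assert torch.allclose(emb_old, emb_new)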
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index a39223811a00..25c4e37d8a6d 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -133,6 +133,9 @@ def forward(
         elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
 
+        # broadcast to batch dimension
+        timesteps = timesteps.broadcast_to(sample.shape[0])
+
         t_emb = self.time_proj(timesteps)
         emb = self.time_embedding(t_emb)
 
@@ -172,8 +175,9 @@ def forward(
             sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
 
         # 6. post-process
-
-        sample = self.conv_norm_out(sample)
+        # make sure hidden states are in float32
+        # when running in half-precision
+        sample = self.conv_norm_out(sample.float()).type(sample.dtype)
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 9db646af09b6..407d6236e744 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -55,7 +55,13 @@ def __call__(
         self.text_encoder.to(torch_device)
 
         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt")
+        text_input = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer.model_max_length,
+            truncation=True,
+            return_tensors="pt",
+        )
         text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -79,19 +85,25 @@ def __call__(
         latents = torch.randn(
             (batch_size, self.unet.in_channels, height // 8, width // 8),
             generator=generator,
+            device=torch_device,
         )
-        latents = latents.to(torch_device)
+
+        # set timesteps
+        accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
+        extra_set_kwargs = {}
+        if accepts_offset:
+            extra_set_kwargs["offset"] = 1
+
+        self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
 
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
-        extra_kwargs = {}
+        extra_step_kwargs = {}
         if accepts_eta:
-            extra_kwargs["eta"] = eta
-
-        self.scheduler.set_timesteps(num_inference_steps)
+            extra_step_kwargs["eta"] = eta
 
         for t in tqdm(self.scheduler.timesteps):
             # expand the latents if we are doing classifier free guidance
@@ -106,7 +118,7 @@ def __call__(
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
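The `inspect.signature` checks above keep the pipeline agnostic to the scheduler in use: a kwarg is only forwarded if the scheduler's method actually declares it. A minimal sketch of the same pattern outside the pipeline (the helper name and the `eta` default are made up for illustration):

import inspect


def collect_scheduler_kwargs(scheduler, eta=0.0):
    # only pass `offset` to set_timesteps if this scheduler supports it
    extra_set_kwargs = {}
    if "offset" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()):
        extra_set_kwargs["offset"] = 1

    # only pass `eta` to step for schedulers (e.g. DDIM) that accept it
    extra_step_kwargs = {}
    if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
        extra_step_kwargs["eta"] = eta

    return extra_set_kwargs, extra_step_kwargs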
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index ed76873f8a96..d513fa9d5c7f 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -59,6 +59,7 @@ def __init__(
         trained_betas=None,
         timestep_values=None,
         clip_sample=True,
+        clip_alpha_at_one=True,
         tensor_format="pt",
     ):
@@ -75,7 +76,12 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = np.cumprod(self.alphas, axis=0)
-        self.one = np.array(1.0)
+
+        # At every step in ddim, we are looking into the previous alphas_cumprod
+        # For the final step, there is no previous alphas_cumprod because we are already at 0
+        # `clip_alpha_at_one` decides whether we set this parameter simply to one or
+        # whether we use the final alpha of the "non-previous" one.
+        self.final_alpha_cumprod = np.array(1.0) if clip_alpha_at_one else self.alphas_cumprod[0]
 
         # setable values
         self.num_inference_steps = None
@@ -86,7 +92,7 @@ def __init__(
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
@@ -94,11 +100,12 @@ def _get_variance(self, timestep, prev_timestep):
         return variance
 
-    def set_timesteps(self, num_inference_steps):
+    def set_timesteps(self, num_inference_steps, offset=0):
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.arange(
             0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
         )[::-1].copy()
+        self.timesteps += offset
         self.set_format(tensor_format=self.tensor_format)
 
     def step(
@@ -126,7 +133,7 @@ def step(
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
+        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         beta_prod_t = 1 - alpha_prod_t
 
         # 3. compute predicted original sample from predicted noise also called
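To illustrate the two new scheduler knobs, a hypothetical snippet (not from the PR) that builds the scheduler the way the new fast test does and inspects the shifted timestep grid; it assumes the scheduler's default of 1000 training timesteps:

from diffusers import DDIMScheduler

scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    clip_alpha_at_one=False,  # final step falls back to alphas_cumprod[0] instead of 1.0
)

# offset=1 shifts every inference timestep up by one:
# with 50 steps over 1000 training steps the grid becomes [981, 961, ..., 1]
scheduler.set_timesteps(50, offset=1)
print(scheduler.timesteps[0], scheduler.timesteps[-1])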
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 0ff38c08bd4d..894a4294d664 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -37,6 +37,7 @@
     PNDMScheduler,
     ScoreSdeVePipeline,
     ScoreSdeVeScheduler,
+    StableDiffusionPipeline,
     UNet2DModel,
     VQModel,
 )
@@ -45,8 +46,6 @@
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 from diffusers.training_utils import EMAModel
 
-from ..src.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
-
 
 torch.backends.cuda.matmul.allow_tf32 = False
@@ -667,7 +666,7 @@ def test_output_pretrained(self):
         output_slice = output[0, -1, -3:, -3:].flatten()
         # fmt: off
-        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
+        expected_output_slice = torch.tensor([-4.0078e-01, -3.8304e-04, -1.2681e-01, -1.1462e-01, 2.0095e-01, 1.0893e-01, -8.8248e-02, -3.0361e-01, -9.8646e-03])
         # fmt: on
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-2))
@@ -842,38 +841,51 @@ def test_ldm_text2img_fast(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
     def test_stable_diffusion(self):
-        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
 
         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.manual_seed(0)
-        image = pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy")[
-            "sample"
-        ]
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        with torch.autocast("cuda"):
+            output = sd_pipe(
+                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
+            )
+
+        image = output["sample"]
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 512, 512, 3)
-        # fmt: off
-        expected_slice = np.array([0.09609553, 0.09020892, 0.07902172, 0.07634321, 0.08755809, 0.06491277, 0.07687345, 0.07173461, 0.07374045])
-        # fmt: on
+        expected_slice = np.array([0.898, 0.9194, 0.91, 0.8955, 0.915, 0.919, 0.9233, 0.9307, 0.8887])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
 
     @slow
-    def test_stable_diffusion_fast(self):
-        pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    def test_stable_diffusion_fast_ddim(self):
+        sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")
+
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            clip_alpha_at_one=False,
+        )
+        sd_pipe.scheduler = scheduler
 
         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.manual_seed(0)
-        image = pipe([prompt], generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+
+        with torch.autocast("cuda"):
+            output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
+        image = output["sample"]
 
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 512, 512, 3)
-        # fmt: off
-        expected_slice = np.array([0.16537648, 0.17572534, 0.14657784, 0.20084214, 0.19819549, 0.16032678, 0.30438453, 0.22730353, 0.21307352])
-        # fmt: on
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.8354, 0.83, 0.866, 0.838, 0.8315, 0.867, 0.836, 0.8584, 0.869])
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
     @slow
     def test_score_sde_ve_pipeline(self):
@@ -890,6 +902,7 @@ def test_score_sde_ve_pipeline(self):
         image_slice = image[0, -3:, -3:, -1]
 
         assert image.shape == (1, 256, 256, 3)
+        expected_slice = np.array([0.64363, 0.5868, 0.3031, 0.2284, 0.7409, 0.3216, 0.25643, 0.6557, 0.2633])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
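Finally, a rough interactive sketch of what the updated slow test exercises, outside of pytest; a CUDA device and access to the CompVis checkpoint are assumptions, and the prompt is just the one used in the test:

import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1-diffusers")

generator = torch.Generator(device="cuda").manual_seed(0)
with torch.autocast("cuda"):
    output = pipe(
        ["A painting of a squirrel eating a burger"],
        generator=generator,
        guidance_scale=6.0,
        num_inference_steps=20,
        output_type="np",
    )

image = output["sample"]  # numpy array of shape (1, 512, 512, 3), per the shape assert in the test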