[img2img, inpainting] fix fp16 inference (#769)
* handle dtype in vae and image2image pipeline

* fix inpaint in fp16

* dtype should be handled in add_noise

* style

* address review comments

* add simple fast tests to check fp16

* fix test name

* put mask in fp16
patil-suraj authored and patrickvonplaten committed Oct 11, 2022
1 parent 7a6cf89 commit b2c9b54
Showing 4 changed files with 186 additions and 60 deletions.
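For context, the failure this commit fixes: with fp16 weights, the user-supplied init image (and the noise derived from it) stayed in float32, so the fp16 vae and unet received mismatched dtypes. A minimal sketch of the now-working call, assuming a CUDA machine (model id, file name, and prompt are illustrative):

import torch
from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

# load the pipeline in half precision; before this commit the call below
# failed because the encoded init image remained in float32
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("sketch.png").convert("RGB").resize((512, 512))
image = pipe(prompt="A fantasy landscape", init_image=init_image, strength=0.75).images[0]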
8 changes: 6 additions & 2 deletions src/diffusers/models/vae.py
@@ -337,12 +337,16 @@ def __init__(self, parameters, deterministic=False):
         self.std = torch.exp(0.5 * self.logvar)
         self.var = torch.exp(self.logvar)
         if self.deterministic:
-            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
+            self.var = self.std = torch.zeros_like(
+                self.mean, device=self.parameters.device, dtype=self.parameters.dtype
+            )

     def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         device = self.parameters.device
         sample_device = "cpu" if device.type == "mps" else device
-        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device).to(device)
+        sample = torch.randn(self.mean.shape, generator=generator, device=sample_device)
+        # make sure sample is on the same device as the parameters and has same dtype
+        sample = sample.to(device=device, dtype=self.parameters.dtype)
         x = self.mean + self.std * sample
         return x

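The effect of the vae.py change, as a small sketch (the 8-channel tensor is a stand-in for a real encoder output, which DiagonalGaussianDistribution chunks into mean and logvar; fp16 ops may require a GPU depending on your PyTorch build):

import torch
from diffusers.models.vae import DiagonalGaussianDistribution

device = "cuda" if torch.cuda.is_available() else "cpu"
# stand-in for the vae encoder output: mean and logvar stacked channel-wise
parameters = torch.randn(1, 8, 64, 64, dtype=torch.float16, device=device)
dist = DiagonalGaussianDistribution(parameters)
latent = dist.sample()
# the gaussian noise is now cast to the parameters' dtype before use, so the
# sampled latent stays fp16 instead of being silently promoted to fp32
assert latent.dtype == torch.float16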
46 changes: 25 additions & 21 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -217,26 +217,6 @@ def __call__(
         if isinstance(init_image, PIL.Image.Image):
             init_image = preprocess(init_image)

-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-        init_latents = 0.18215 * init_latents
-
-        # expand init_latents for batch_size
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -297,6 +277,28 @@ def __call__(
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # expand init_latents for batch_size
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
@@ -341,7 +343,9 @@ def __call__(
         image = image.cpu().permute(0, 2, 3, 1).numpy()

         safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        image, has_nsfw_concept = self.safety_checker(
+            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+        )

         if output_type == "pil":
             image = self.numpy_to_pil(image)
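The pattern is identical in both pipelines: text_embeddings.dtype reflects the precision the models were loaded in and becomes the single source of truth, and everything that meets the fp16 models (init image, noise, safety-checker input) is cast to it. Note that although one commit message reads "dtype should be handled in add_noise", the final diff simply creates the noise tensor in latents_dtype so that scheduler.add_noise operates on matching dtypes. A distilled, standalone sketch of the casting pattern (shapes are illustrative):

import torch

# stand-ins; in the pipeline these come from the text encoder and the user
text_embeddings = torch.randn(2, 77, 32).half()
latents_dtype = text_embeddings.dtype  # single source of truth for precision

init_image = torch.rand(1, 3, 64, 64)            # user input, typically fp32
init_image = init_image.to(dtype=latents_dtype)  # cast before vae.encode
noise = torch.randn(1, 4, 8, 8).to(latents_dtype)  # matches latents for add_noise
assert init_image.dtype == noise.dtype == latents_dtype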
74 changes: 37 additions & 37 deletions src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -234,43 +234,6 @@ def __call__(
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        # preprocess image
-        if not isinstance(init_image, torch.FloatTensor):
-            init_image = preprocess_image(init_image)
-        init_image = init_image.to(self.device)
-
-        # encode the init image into latents and scale the latents
-        init_latent_dist = self.vae.encode(init_image).latent_dist
-        init_latents = init_latent_dist.sample(generator=generator)
-
-        init_latents = 0.18215 * init_latents
-
-        # Expand init_latents for batch_size and num_images_per_prompt
-        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
-        init_latents_orig = init_latents
-
-        # preprocess mask
-        if not isinstance(mask_image, torch.FloatTensor):
-            mask_image = preprocess_mask(mask_image)
-        mask_image = mask_image.to(self.device)
-        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
-        # check sizes
-        if not mask.shape == init_latents.shape:
-            raise ValueError("The mask and init_image should be the same size!")
-
-        # get the original timestep using init_timestep
-        offset = self.scheduler.config.get("steps_offset", 0)
-        init_timestep = int(num_inference_steps * strength) + offset
-        init_timestep = min(init_timestep, num_inference_steps)
-
-        timesteps = self.scheduler.timesteps[-init_timestep]
-        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
-        # add noise to latents using the timesteps
-        noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
-        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
         # get prompt text embeddings
         text_inputs = self.tokenizer(
             prompt,
@@ -335,6 +298,43 @@ def __call__(
             # to avoid doing two forward passes
             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

+        # preprocess image
+        if not isinstance(init_image, torch.FloatTensor):
+            init_image = preprocess_image(init_image)
+
+        # encode the init image into latents and scale the latents
+        latents_dtype = text_embeddings.dtype
+        init_image = init_image.to(device=self.device, dtype=latents_dtype)
+        init_latent_dist = self.vae.encode(init_image).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+        init_latents = 0.18215 * init_latents
+
+        # Expand init_latents for batch_size and num_images_per_prompt
+        init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
+        init_latents_orig = init_latents
+
+        # preprocess mask
+        if not isinstance(mask_image, torch.FloatTensor):
+            mask_image = preprocess_mask(mask_image)
+        mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
+        mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
+
+        # check sizes
+        if not mask.shape == init_latents.shape:
+            raise ValueError("The mask and init_image should be the same size!")
+
+        # get the original timestep using init_timestep
+        offset = self.scheduler.config.get("steps_offset", 0)
+        init_timestep = int(num_inference_steps * strength) + offset
+        init_timestep = min(init_timestep, num_inference_steps)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
+
+        # add noise to latents using the timesteps
+        noise = torch.randn(init_latents.shape, generator=generator, device=self.device, dtype=latents_dtype)
+        init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
+
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
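Inpainting needs the same dtype plumbing plus one extra tensor: the mask is also cast to latents_dtype (the final "put mask in fp16" commit), since it multiplies fp16 latents during the masked blending step. A usage sketch under the same assumptions as the img2img example above (model id and file names are illustrative):

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask_image = Image.open("mask.png").convert("RGB").resize((512, 512))
# both inputs are preprocessed and cast to fp16 inside the pipeline
image = pipe(prompt="A red couch", init_image=init_image, mask_image=mask_image).images[0]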
118 changes: 118 additions & 0 deletions tests/test_pipelines.py
@@ -1005,6 +1005,124 @@ def test_stable_diffusion_inpaint_num_images_per_prompt(self):

         assert images.shape == (batch_size * num_images_per_prompt, 32, 32, 3)

+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_fp16(self):
+        """Test that stable diffusion works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="np").images
+
+        assert image.shape == (1, 128, 128, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_img2img_fp16(self):
+        """Test that stable diffusion img2img works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        init_image = self.dummy_image.to(torch_device)
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionImg2ImgPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+
+    @unittest.skipIf(torch_device == "cpu", "This test requires a GPU")
+    def test_stable_diffusion_inpaint_fp16(self):
+        """Test that stable diffusion inpaint works with fp16"""
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
+        init_image = Image.fromarray(np.uint8(image)).convert("RGB")
+        mask_image = Image.fromarray(np.uint8(image + 4)).convert("RGB").resize((128, 128))
+
+        # put models in fp16
+        unet = unet.half()
+        vae = vae.half()
+        bert = bert.half()
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionInpaintPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
+        image = sd_pipe(
+            [prompt],
+            generator=generator,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+        ).images
+
+        assert image.shape == (1, 32, 32, 3)
+

 class PipelineTesterMixin(unittest.TestCase):
     def tearDown(self):
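The three new fast tests skip on CPU-only machines; on a GPU box they can be selected with pytest's standard -k filter, e.g. "python -m pytest tests/test_pipelines.py -k fp16".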