Fix AutoencoderTiny encoder scaling convention
  * Add [-1, 1] -> [0, 1] rescaling to EncoderTiny (this fixes huggingface#4676)

  * Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny
    (i.e. immediately after the final conv, as early as possible)

  * Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward

  * Update AutoencoderTinyIntegrationTests to protect against scaling issues.
    The new test constructs a simple image, round-trips it through AutoencoderTiny,
    and confirms the decoded result is approximately equal to the source image.
    This test will fail if new AutoencoderTiny scaling issues are introduced.

  * Context: Raw TAESD weights expect images in [0, 1], but diffusers'
    convention represents images with zero-centered values in [-1, 1],
    so AutoencoderTiny needs to scale / unscale images at the start of
    encoding and at the end of decoding in order to work with diffusers
    (see the sketch below).
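
For illustration, a minimal sketch of the two rescalings described in the
bullets above (standalone helpers with hypothetical names, not diffusers API):

    import torch

    def to_taesd(images: torch.Tensor) -> torch.Tensor:
        # diffusers convention [-1, 1] -> raw TAESD convention [0, 1]
        return images.add(1).div(2)

    def from_taesd(images: torch.Tensor) -> torch.Tensor:
        # raw TAESD convention [0, 1] -> diffusers convention [-1, 1]
        return images.mul(2).sub(1)

    images = torch.rand(1, 3, 64, 64) * 2 - 1  # fake image batch in [-1, 1]
    assert torch.allclose(from_taesd(to_taesd(images)), images, atol=1e-6)
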
madebyollin committed Aug 19, 2023
1 parent 74d902e commit 0601a74
Showing 3 changed files with 16 additions and 13 deletions.
5 changes: 1 addition & 4 deletions src/diffusers/models/autoencoder_tiny.py
@@ -312,9 +312,6 @@ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
             output = torch.cat(output)
         else:
             output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
-        # Refer to the following discussion to know why this is needed.
-        # https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
-        output = output.mul_(2).sub_(1)
 
         if not return_dict:
             return (output,)
@@ -334,7 +331,7 @@ def forward(
         """
         enc = self.encode(sample).latents
         scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
-        unscaled_enc = self.unscale_latents(scaled_enc)
+        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
         dec = self.decode(unscaled_enc)
 
         if not return_dict:
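
To see why the `/ 255.0` matters: forward() round-trips latents through uint8
quantization, so the bytes in [0, 255] must be mapped back to [0, 1] before
unscale_latents sees them. A standalone sketch, using hypothetical affine
stand-ins for scale_latents / unscale_latents rather than the real methods:

    import torch

    def scale_latents(x):    # assumed affine map of latents into [0, 1]
        return x.div(6).add(0.5).clamp(0, 1)

    def unscale_latents(x):  # assumed inverse map from [0, 1] back to latents
        return x.sub(0.5).mul(6)

    enc = torch.rand(1, 4, 8, 8) * 6 - 3                      # latents in [-3, 3]
    quantized = scale_latents(enc).mul_(255).round_().byte()  # uint8 in [0, 255]
    restored = unscale_latents(quantized / 255.0)             # the fix: back to [0, 1] first
    assert (restored - enc).abs().max() <= 3 / 255 + 1e-4     # within quantization error
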
6 changes: 4 additions & 2 deletions src/diffusers/models/vae.py
@@ -732,7 +732,8 @@ def custom_forward(*inputs):
             x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
 
         else:
-            x = self.layers(x)
+            # scale image from [-1, 1] to [0, 1] to match TAESD convention
+            x = self.layers(x.add(1).div(2))
 
         return x

@@ -790,4 +791,5 @@ def custom_forward(*inputs):
         else:
             x = self.layers(x)
 
-        return x
+        # scale image from [0, 1] to [-1, 1] to match diffusers convention
+        return x.mul(2).sub(1)
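
With this change, the convention boundary lives inside the Tiny modules
themselves: the encoder rescales its input first, the decoder rescales its
output last. A toy sketch of that structure (ToyEncoder / ToyDecoder are
illustrative only, not the real EncoderTiny / DecoderTiny):

    import torch
    from torch import nn

    class ToyEncoder(nn.Module):
        # accepts diffusers-style images in [-1, 1]; layers see [0, 1]
        def __init__(self):
            super().__init__()
            self.layers = nn.Identity()  # stand-in for the real conv stack
        def forward(self, x):
            return self.layers(x.add(1).div(2))

    class ToyDecoder(nn.Module):
        # layers produce [0, 1]; callers receive diffusers-style [-1, 1]
        def __init__(self):
            super().__init__()
            self.layers = nn.Identity()  # stand-in for the real conv stack
        def forward(self, x):
            return self.layers(x).mul(2).sub(1)

    x = torch.rand(1, 3, 8, 8) * 2 - 1
    assert torch.allclose(ToyDecoder()(ToyEncoder()(x)), x, atol=1e-6)
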
18 changes: 11 additions & 7 deletions tests/models/test_models_vae.py
@@ -302,19 +302,23 @@ def test_tae_tiling(self, in_shape, out_shape):
             dec = model.decode(zeros).sample
         assert dec.shape == out_shape
 
-    def test_stable_diffusion(self):
+    def test_roundtrip(self):
+        # load the autoencoder
         model = self.get_sd_vae_model()
-        image = self.get_sd_image(seed=33)
 
+        # make a black image with white square in the middle
+        image = -torch.ones(1, 3, 512, 512, device=torch_device)
+        image[..., 128:384, 128:384] = 1.0
+
+        # round-trip the image through the autoencoder
         with torch.no_grad():
             sample = model(image).sample
 
         assert sample.shape == image.shape
-
-        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
-        expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
-
-        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)
+        # the autoencoder reconstruction should match original image, sorta
+        def downscale(x):
+            return torch.nn.functional.avg_pool2d(x, 8)
+        assert torch_all_close(downscale(sample), downscale(image), atol=0.125)
 
 
 @slow
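
Why the test compares after avg_pool2d: averaging 8x8 blocks smooths per-pixel
reconstruction noise while leaving a global scale or offset bug fully visible,
so the 0.125 tolerance stays tight. A standalone illustration (not part of the
test file):

    import torch
    import torch.nn.functional as F

    image = -torch.ones(1, 3, 512, 512)  # black image...
    image[..., 128:384, 128:384] = 1.0   # ...with a white square in the middle

    def downscale(x):
        return F.avg_pool2d(x, 8)

    noisy = image + 0.05 * torch.randn_like(image)  # benign per-pixel noise
    rescaled_wrong = image.add(1).div(2)            # a [0, 1] vs [-1, 1] bug

    print((downscale(noisy) - downscale(image)).abs().max())           # small: noise averages out
    print((downscale(rescaled_wrong) - downscale(image)).abs().max())  # 1.0: the bug is obvious
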
