Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions src/diffusers/models/autoencoder_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,6 @@ def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[Decode
output = torch.cat(output)
else:
output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)
# Refer to the following discussion to know why this is needed.
# https://github.com/huggingface/diffusers/pull/4384#discussion_r1279401854
output = output.mul_(2).sub_(1)

if not return_dict:
return (output,)
Expand All @@ -333,8 +330,15 @@ def forward(
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
enc = self.encode(sample).latents

# scale latents to be in [0, 1], then quantize latents to a byte tensor,
# as if we were storing the latents in an RGBA uint8 image.
scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
unscaled_enc = self.unscale_latents(scaled_enc)

# unquantize latents back into [0, 1], then unscale latents back to their original range,
# as if we were loading the latents from an RGBA uint8 image.
unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
Comment thread
sayakpaul marked this conversation as resolved.

dec = self.decode(unscaled_enc)

if not return_dict:
Expand Down
6 changes: 4 additions & 2 deletions src/diffusers/models/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,7 +732,8 @@ def custom_forward(*inputs):
x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)

else:
x = self.layers(x)
# scale image from [-1, 1] to [0, 1] to match TAESD convention
x = self.layers(x.add(1).div(2))

return x

Expand Down Expand Up @@ -790,4 +791,5 @@ def custom_forward(*inputs):
else:
x = self.layers(x)

return x
# scale image from [0, 1] to [-1, 1] to match diffusers convention
return x.mul(2).sub(1)
24 changes: 23 additions & 1 deletion tests/models/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,32 @@ def test_stable_diffusion(self):
assert sample.shape == image.shape

output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor([0.9858, 0.9262, 0.8629, 1.0974, -0.091, -0.2485, 0.0936, 0.0604])
expected_output_slice = torch.tensor([0.0093, 0.6385, -0.1274, 0.1631, -0.1762, 0.5232, -0.3108, -0.0382])
Comment thread
sayakpaul marked this conversation as resolved.

assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

@parameterized.expand([(True,), (False,)])
def test_tae_roundtrip(self, enable_tiling):
# load the autoencoder
model = self.get_sd_vae_model()
if enable_tiling:
model.enable_tiling()

# make a black image with a white square in the middle,
# which is large enough to split across multiple tiles
image = -torch.ones(1, 3, 1024, 1024, device=torch_device)
image[..., 256:768, 256:768] = 1.0

# round-trip the image through the autoencoder
with torch.no_grad():
sample = model(image).sample

# the autoencoder reconstruction should match original image, sorta
def downscale(x):
return torch.nn.functional.avg_pool2d(x, model.spatial_scale_factor)

assert torch_all_close(downscale(sample), downscale(image), atol=0.125)


@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
Expand Down