diff --git a/docs/source/en/hybrid_inference/overview.md b/docs/source/en/hybrid_inference/overview.md
index b44393c77cbd..6a5c3313070a 100644
--- a/docs/source/en/hybrid_inference/overview.md
+++ b/docs/source/en/hybrid_inference/overview.md
@@ -48,6 +48,7 @@ Hybrid Inference offers a fast and simple way to offload local generation requir
 
 ## Changelog
 
+- March 11 2025: Added Wan 2.1 VAE decode
 - March 10 2025: Added VAE encode
 - March 2 2025: Initial release with VAE decoding
diff --git a/docs/source/en/hybrid_inference/vae_decode.md b/docs/source/en/hybrid_inference/vae_decode.md
index 1457090550c7..5320608ddc2d 100644
--- a/docs/source/en/hybrid_inference/vae_decode.md
+++ b/docs/source/en/hybrid_inference/vae_decode.md
@@ -54,6 +54,7 @@ For the majority of these GPUs the memory usage % dictates other models (text en
 | **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
 | **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
 | **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
+| **Wan2.1** | [https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud](https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud) | [`Wan-AI/Wan2.1-T2V-1.3B-Diffusers`](https://hf.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers) |
 
 > [!TIP]
diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py
index fa12318f4714..638678ef78d7 100644
--- a/src/diffusers/utils/constants.py
+++ b/src/diffusers/utils/constants.py
@@ -62,7 +62,7 @@
 DECODE_ENDPOINT_SD_XL = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"
 DECODE_ENDPOINT_FLUX = "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/"
 DECODE_ENDPOINT_HUNYUAN_VIDEO = "https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/"
-
+DECODE_ENDPOINT_WAN_2_1 = "https://lafotb093i5cnx2w.us-east-1.aws.endpoints.huggingface.cloud/"
 
 ENCODE_ENDPOINT_SD_V1 = "https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/"
 ENCODE_ENDPOINT_SD_XL = "https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud/"
diff --git a/src/diffusers/utils/remote_utils.py b/src/diffusers/utils/remote_utils.py
index fbce33d97f54..2df7a19f68b4 100644
--- a/src/diffusers/utils/remote_utils.py
+++ b/src/diffusers/utils/remote_utils.py
@@ -80,13 +80,6 @@ def check_inputs_decode(
         and not isinstance(processor, (VaeImageProcessor, VideoProcessor))
     ):
         raise ValueError("`processor` is required.")
-    if do_scaling and scaling_factor is None:
-        deprecate(
-            "do_scaling",
-            "1.0.0",
-            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
-            standard_warn=False,
-        )
 
 
 def postprocess_decode(
diff --git a/tests/remote/test_remote_decode.py b/tests/remote/test_remote_decode.py
index cec96e729a48..e1f3435d33ef 100644
--- a/tests/remote/test_remote_decode.py
+++ b/tests/remote/test_remote_decode.py
@@ -26,6 +26,7 @@
     DECODE_ENDPOINT_HUNYUAN_VIDEO,
     DECODE_ENDPOINT_SD_V1,
     DECODE_ENDPOINT_SD_XL,
+    DECODE_ENDPOINT_WAN_2_1,
 )
 from diffusers.utils.remote_utils import (
     remote_decode,
@@ -176,18 +177,6 @@ def test_output_type_pt_partial_postprocess_return_type_pt(self):
             f"{output_slice}",
         )
 
-    def test_do_scaling_deprecation(self):
-        inputs = self.get_dummy_inputs()
-        inputs.pop("scaling_factor", None)
-        inputs.pop("shift_factor", None)
-        with self.assertWarns(FutureWarning) as warning:
-            _ = remote_decode(output_type="pt", partial_postprocess=True, **inputs)
-        self.assertEqual(
-            str(warning.warnings[0].message),
-            "`do_scaling` is deprecated, pass `scaling_factor` and `shift_factor` if required.",
-            str(warning.warnings[0].message),
-        )
-
     def test_input_tensor_type_base64_deprecation(self):
         inputs = self.get_dummy_inputs()
         with self.assertWarns(FutureWarning) as warning:
@@ -209,7 +198,7 @@ def test_output_tensor_type_base64_deprecation(self):
         )
 
 
-class RemoteAutoencoderKLHunyuanVideoMixin(RemoteAutoencoderKLMixin):
+class RemoteAutoencoderKLVideoMixin(RemoteAutoencoderKLMixin):
     def test_no_scaling(self):
         inputs = self.get_dummy_inputs()
         if inputs["scaling_factor"] is not None:
@@ -221,7 +210,6 @@ def test_no_scaling(self):
             processor = self.processor_cls()
             output = remote_decode(
                 output_type="pt",
-                # required for now, will be removed in next update
                 do_scaling=False,
                 processor=processor,
                 **inputs,
@@ -337,6 +325,8 @@ def test_output_type_mp4(self):
         inputs = self.get_dummy_inputs()
         output = remote_decode(output_type="mp4", return_type="mp4", **inputs)
         self.assertTrue(isinstance(output, bytes), f"Expected `bytes` output, got {type(output)}")
+        with open("test.mp4", "wb") as f:
+            f.write(output)
 
 
 class RemoteAutoencoderKLSDv1Tests(
@@ -442,7 +432,7 @@ class RemoteAutoencoderKLFluxPackedTests(
 
 
 class RemoteAutoencoderKLHunyuanVideoTests(
-    RemoteAutoencoderKLHunyuanVideoMixin,
+    RemoteAutoencoderKLVideoMixin,
     unittest.TestCase,
 ):
     shape = (
@@ -467,6 +457,31 @@ class RemoteAutoencoderKLHunyuanVideoTests(
     return_pt_slice = torch.tensor([0.1656, 0.2661, 0.3157, 0.0693, 0.1755, 0.2252, 0.0127, 0.1221, 0.1708])
 
 
+class RemoteAutoencoderKLWanTests(
+    RemoteAutoencoderKLVideoMixin,
+    unittest.TestCase,
+):
+    shape = (
+        1,
+        16,
+        3,
+        40,
+        64,
+    )
+    out_hw = (
+        320,
+        512,
+    )
+    endpoint = DECODE_ENDPOINT_WAN_2_1
+    dtype = torch.float16
+    processor_cls = VideoProcessor
+    output_pt_slice = torch.tensor([203, 174, 178, 204, 171, 177, 209, 183, 182], dtype=torch.uint8)
+    partial_postprocess_return_pt_slice = torch.tensor(
+        [206, 209, 221, 202, 199, 222, 207, 210, 217], dtype=torch.uint8
+    )
+    return_pt_slice = torch.tensor([0.6196, 0.6382, 0.7310, 0.5869, 0.5625, 0.7373, 0.6240, 0.6465, 0.7002])
+
+
 class RemoteAutoencoderKLSlowTestMixin:
     channels: int = 4
     endpoint: str = None
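
For reference, a minimal sketch (not part of the diff above) of exercising the new Wan 2.1 decode endpoint directly, mirroring the latent shape and dtype used in `RemoteAutoencoderKLWanTests`. The random latent is a stand-in for one produced by a Wan 2.1 denoising loop, the output filename is arbitrary, and it assumes the endpoint applies VAE scaling by default, as the test mixin's `do_scaling=False` opt-out suggests:

```python
# Sketch: decode a Wan 2.1 latent via the new remote VAE decode endpoint.
import torch

from diffusers.utils.constants import DECODE_ENDPOINT_WAN_2_1
from diffusers.utils.remote_utils import remote_decode

# Stand-in latent with the test's shape/dtype: (batch, channels, frames, h, w).
# A real latent would come from a Wan 2.1 denoising loop.
latent = torch.randn(1, 16, 3, 40, 64, dtype=torch.float16)

# Request the decoded video as MP4 bytes, as in test_output_type_mp4.
video = remote_decode(
    endpoint=DECODE_ENDPOINT_WAN_2_1,
    tensor=latent,
    output_type="mp4",
    return_type="mp4",
)

with open("wan_decode_test.mp4", "wb") as f:  # arbitrary output path
    f.write(video)
```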