From 9ebb411d4671c08c5940a0f295c4a642bd09fce6 Mon Sep 17 00:00:00 2001 From: Darren Hsu <35377472+darhsu@users.noreply.github.com> Date: Fri, 20 Sep 2024 02:49:22 +0000 Subject: [PATCH 1/5] Support bfloat16 for Upsample2D --- src/diffusers/models/upsampling.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index fd5ed28c7070..d60797f43432 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -17,11 +17,15 @@ import torch import torch.nn as nn import torch.nn.functional as F +from packaging import version from ..utils import deprecate from .normalization import RMSNorm +is_torch_less_than_2_1 = version.parse(version.parse(torch.__version__).base_version) < version.parse("2.1") + + class Upsample1D(nn.Module): """A 1D upsampling layer with an optional convolution. @@ -151,11 +155,10 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None if self.use_conv_transpose: return self.conv(hidden_states) - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch - # https://github.com/pytorch/pytorch/issues/86679 + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 + # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 dtype = hidden_states.dtype - if dtype == torch.bfloat16: + if dtype == torch.bfloat16 and is_torch_less_than_2_1: hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 @@ -170,8 +173,8 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None else: hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: + # Cast back to original dtype + if dtype == torch.bfloat16 and is_torch_less_than_2_1: hidden_states = hidden_states.to(dtype) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed From 6e28deca80bb0b470dd5c51045afd47a5cd672af Mon Sep 17 00:00:00 2001 From: Darren Hsu <35377472+darhsu@users.noreply.github.com> Date: Fri, 20 Sep 2024 04:08:37 +0000 Subject: [PATCH 2/5] Add test and use is_torch_version --- src/diffusers/models/upsampling.py | 4 ++-- tests/models/test_layers_utils.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index d60797f43432..e121109ae5e8 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -17,13 +17,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from packaging import version from ..utils import deprecate +from ..utils.import_utils import is_torch_version from .normalization import RMSNorm -is_torch_less_than_2_1 = version.parse(version.parse(torch.__version__).base_version) < version.parse("2.1") +is_torch_less_than_2_1 = is_torch_version("<", "2.1") class Upsample1D(nn.Module): diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index 66e142f8c66a..4a4dc48e5002 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -120,6 +120,18 @@ def test_upsample_default(self): expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254]) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + def test_upsample_bfloat16(self): + torch.manual_seed(0) + sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16) + upsample = Upsample2D(channels=32, use_conv=False) + with torch.no_grad(): + upsampled = upsample(sample) + + assert upsampled.shape == (1, 32, 64, 64) + output_slice = upsampled[0, -1, -3:, -3:] + expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16) + assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + def test_upsample_with_conv(self): torch.manual_seed(0) sample = torch.randn(1, 32, 32, 32) From 59c0bf38a280911f2551ce2d662f6465275cbfa9 Mon Sep 17 00:00:00 2001 From: Darren Hsu <35377472+darhsu@users.noreply.github.com> Date: Fri, 20 Sep 2024 08:47:59 +0000 Subject: [PATCH 3/5] Resolve comments and add decorator --- src/diffusers/models/upsampling.py | 7 ++----- src/diffusers/utils/testing_utils.py | 14 ++++++++++++++ tests/models/test_layers_utils.py | 2 ++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index e121109ae5e8..cf07e45b0c5c 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -23,9 +23,6 @@ from .normalization import RMSNorm -is_torch_less_than_2_1 = is_torch_version("<", "2.1") - - class Upsample1D(nn.Module): """A 1D upsampling layer with an optional convolution. @@ -158,7 +155,7 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1 # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767 dtype = hidden_states.dtype - if dtype == torch.bfloat16 and is_torch_less_than_2_1: + if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): hidden_states = hidden_states.to(torch.float32) # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 @@ -174,7 +171,7 @@ def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") # Cast back to original dtype - if dtype == torch.bfloat16 and is_torch_less_than_2_1: + if dtype == torch.bfloat16 and is_torch_version("<", "2.1"): hidden_states = hidden_states.to(dtype) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index be3e9983c80f..7e12cf3c6150 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -252,6 +252,20 @@ def require_torch_2(test_case): ) +def require_torch_version_greater_equal(torch_version): + """Decorator marking a test that requires torch with a specific version or greater.""" + + def decorator(test_case): + correct_torch_version = version.parse( + version.parse(torch.__version__).base_version + ) >= version.parse(torch_version) + return unittest.skipUnless( + correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}" + )(test_case) + + return decorator + + def require_torch_gpu(test_case): """Decorator marking a test that requires CUDA and PyTorch.""" return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index 4a4dc48e5002..561113258e8f 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -27,6 +27,7 @@ from diffusers.utils.testing_utils import ( backend_manual_seed, require_torch_accelerator_with_fp64, + require_torch_version_greater_equal, torch_device, ) @@ -120,6 +121,7 @@ def test_upsample_default(self): expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254]) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) + @require_torch_version_greater_equal("2.1") def test_upsample_bfloat16(self): torch.manual_seed(0) sample = torch.randn(1, 32, 32, 32).to(torch.bfloat16) From 7525b0a2bfb7dfb6ba2bae20c20fa92ed2dfd526 Mon Sep 17 00:00:00 2001 From: Darren Hsu <35377472+darhsu@users.noreply.github.com> Date: Fri, 20 Sep 2024 21:48:16 +0000 Subject: [PATCH 4/5] Simplify require_torch_version_greater_equal decorator --- src/diffusers/utils/testing_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 7e12cf3c6150..7dc3f414d55c 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -256,9 +256,7 @@ def require_torch_version_greater_equal(torch_version): """Decorator marking a test that requires torch with a specific version or greater.""" def decorator(test_case): - correct_torch_version = version.parse( - version.parse(torch.__version__).base_version - ) >= version.parse(torch_version) + correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version) return unittest.skipUnless( correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}" )(test_case) From 7b9553252a369e8f000aa1dd3a47836d4fdaf966 Mon Sep 17 00:00:00 2001 From: Darren Hsu <35377472+darhsu@users.noreply.github.com> Date: Thu, 26 Sep 2024 00:23:08 +0000 Subject: [PATCH 5/5] Run make style --- tests/models/test_layers_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py index 561113258e8f..415bb12b73c6 100644 --- a/tests/models/test_layers_utils.py +++ b/tests/models/test_layers_utils.py @@ -131,7 +131,9 @@ def test_upsample_bfloat16(self): assert upsampled.shape == (1, 32, 64, 64) output_slice = upsampled[0, -1, -3:, -3:] - expected_slice = torch.tensor([-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16) + expected_slice = torch.tensor( + [-0.2173, -1.2079, -1.2079, 0.2952, 1.1254, 1.1254, 0.2952, 1.1254, 1.1254], dtype=torch.bfloat16 + ) assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3) def test_upsample_with_conv(self):