From 31183ee173882588dbf8a39cf86ae20851e41060 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 2 Mar 2023 13:23:21 +0530 Subject: [PATCH 01/12] ema test cases. --- src/diffusers/training_utils.py | 1 + tests/test_ema.py | 116 ++++++++++++++++++++++++++++++++ 2 files changed, 117 insertions(+) create mode 100644 tests/test_ema.py diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 67a8e48d381f..75c9bac10ad8 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -196,6 +196,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): decay = self.get_decay(self.optimization_step) self.cur_decay_value = decay one_minus_decay = 1 - decay + print(f"step() has been called with {one_minus_decay}.") for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: diff --git a/tests/test_ema.py b/tests/test_ema.py new file mode 100644 index 000000000000..941ae0f2da0e --- /dev/null +++ b/tests/test_ema.py @@ -0,0 +1,116 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import gc +import numpy as np +import torch +from torch import nn + +from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock +from diffusers.models.embeddings import get_timestep_embedding +from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D +from diffusers.models.transformer_2d import Transformer2DModel +from diffusers import UNet2DConditionModel +from diffusers.utils import torch_device +from diffusers.training_utils import EMAModel + +class EMAModelTests(unittest.TestCase): + model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" + + def get_models(self): + unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") + ema_unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") + ema_unet = EMAModel( + ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config + ) + return unet, ema_unet + + def test_optimization_steps_updated(self): + unet, ema_unet = self.get_models() + # Take the first (hypothetical) EMA step. + ema_unet.step(unet.parameters()) + assert ema_unet.optimization == 1 + + # Take two more. + for _ in range(2): + ema_unet.step(unet.parameters()) + assert ema_unet.optimization == 3 + + del unet, ema_unet + + def test_shadow_params_not_updated(self): + unet, ema_unet = self.get_models() + # Since the `unet` is not being updated (i.e., backprop'd) + # there won't be any difference between the `params` of `unet` + # and `ema_unet` even if we call `ema_unet.step(unet.parameters())`. + ema_unet.step(unet.parameters()) + orig_params = list(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert torch.allclose(s_param, param) + + # The above holds true even if we call `ema.step()` multiple times since + # `unet` params are still not being updated. 
+ for _ in range(4): + ema_unet.step(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert torch.allclose(s_param, param) + + del unet, ema_unet + + def test_shadow_params_updated(self): + unet, ema_unet = self.get_models() + # Here we simulate the parameter updates for `unet`. Since there might + # be some parameters which are initialized to zero we take extra care to + # initialize their values to something non-zero before the multiplication. + updated_params = [] + for param in unet.parameters(): + updated_params.append(torch.randn_like(param) + (param * torch.randn_like(param))) + + # Load the updated parameters into `unet`. + updated_state_dict = {} + for i, k in enumerate(unet.state_dict().keys()): + updated_state_dict.update({k: updated_params[i]}) + unet.load_state_dict(updated_state_dict) + + # Take the EMA step. + ema_unet.step(unet.parameters()) + + # Now the EMA'd parameters won't be equal to the original model parameters. + orig_params = list(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert ~torch.allclose(s_param, param) + + # Ensure this is the case when we take multiple EMA steps. + for _ in range(4): + ema_unet.step(unet.parameters()) + for s_param, param in zip(ema_unet.shadow_params, orig_params): + assert ~torch.allclose(s_param, param) + + def test_consecutive_shadow_params_not_updated(self): + # EMA steps are supposed to be taken after we have taken a backprop step. + # If that is not the case shadown params after two consecutive steps should + # be one and the same + pass + + def test_consecutive_shadow_params_updated(self): + pass + + + + def tearDown(self): + super().tearDown() + gc.collect() From 313414ac8e53d98f1f4f37f20788da001c1ff3d6 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 2 Mar 2023 13:28:32 +0530 Subject: [PATCH 02/12] debugging maessages. --- src/diffusers/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 75c9bac10ad8..58218f10525b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -196,7 +196,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): decay = self.get_decay(self.optimization_step) self.cur_decay_value = decay one_minus_decay = 1 - decay - print(f"step() has been called with {one_minus_decay}.") + print(f"step() has been called with one_minus_decay: {one_minus_decay}, decay: {decay}.") for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: From f47487fac6ce6000c87564fbb997b16484a45fca Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 2 Mar 2023 13:28:58 +0530 Subject: [PATCH 03/12] debugging maessages. --- src/diffusers/training_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 58218f10525b..1a4f6ddb262d 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -200,6 +200,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: + print(f"Subtraction quantity: {one_minus_decay * (s_param - param)}") s_param.sub_(one_minus_decay * (s_param - param)) else: s_param.copy_(param) From 2d0f7b574ef31b1a90067cdb2e4cf4057d160ba1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 2 Mar 2023 14:24:36 +0530 Subject: [PATCH 04/12] add: tests for ema. 
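
These tests lean on the update that `EMAModel.step()` applies to every trainable
shadow parameter, `s_param.sub_(one_minus_decay * (s_param - param))`, with the
decay coming from `get_decay(self.optimization_step)`. As a reading aid, here is a
rough standalone sketch of that rule on plain tensors — not the actual `EMAModel`
API, and the fixed `decay` value is an assumption (the real class typically ramps
the effective decay up per step):

    import torch

    decay = 0.9999                # assumed fixed here; get_decay() ramps it up per step
    param = torch.randn(4)        # a "live" model parameter
    shadow = param.clone()        # the EMA copy starts out identical

    # pretend a backprop step perturbed the live parameter
    # (mirrors the backprop-simulation helper this patch adds)
    param = torch.randn_like(param) + param * torch.randn_like(param)

    # one EMA step: shadow <- shadow - (1 - decay) * (shadow - param)
    shadow.sub_((1 - decay) * (shadow - param))

    # decay = 0.0 would overwrite the shadow with the live parameter outright,
    # while decay close to 1.0 barely moves it -- the two regimes the tests probe.
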
--- src/diffusers/training_utils.py | 2 - tests/test_ema.py | 83 +++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 33 deletions(-) diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 1a4f6ddb262d..67a8e48d381f 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -196,11 +196,9 @@ def step(self, parameters: Iterable[torch.nn.Parameter]): decay = self.get_decay(self.optimization_step) self.cur_decay_value = decay one_minus_decay = 1 - decay - print(f"step() has been called with one_minus_decay: {one_minus_decay}, decay: {decay}.") for s_param, param in zip(self.shadow_params, parameters): if param.requires_grad: - print(f"Subtraction quantity: {one_minus_decay * (s_param - param)}") s_param.sub_(one_minus_decay * (s_param - param)) else: s_param.copy_(param) diff --git a/tests/test_ema.py b/tests/test_ema.py index 941ae0f2da0e..24c394403063 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -13,32 +13,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc import unittest -import gc -import numpy as np import torch -from torch import nn -from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock -from diffusers.models.embeddings import get_timestep_embedding -from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D -from diffusers.models.transformer_2d import Transformer2DModel from diffusers import UNet2DConditionModel -from diffusers.utils import torch_device from diffusers.training_utils import EMAModel + class EMAModelTests(unittest.TestCase): model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" - def get_models(self): + def get_models(self, decay=0.9999): unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") ema_unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") ema_unet = EMAModel( - ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config + ema_unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config ) - return unet, ema_unet - + return unet, ema_unet + + def similuate_backprop(self, unet): + updated_state_dict = {} + for k, param in unet.state_dict().items(): + updated_param = torch.randn_like(param) + (param * torch.randn_like(param)) + updated_state_dict.update({k: updated_param}) + unet.load_state_dict(updated_state_dict) + return unet + def test_optimization_steps_updated(self): unet, ema_unet = self.get_models() # Take the first (hypothetical) EMA step. @@ -68,7 +70,7 @@ def test_shadow_params_not_updated(self): ema_unet.step(unet.parameters()) for s_param, param in zip(ema_unet.shadow_params, orig_params): assert torch.allclose(s_param, param) - + del unet, ema_unet def test_shadow_params_updated(self): @@ -76,21 +78,13 @@ def test_shadow_params_updated(self): # Here we simulate the parameter updates for `unet`. Since there might # be some parameters which are initialized to zero we take extra care to # initialize their values to something non-zero before the multiplication. - updated_params = [] - for param in unet.parameters(): - updated_params.append(torch.randn_like(param) + (param * torch.randn_like(param))) - - # Load the updated parameters into `unet`. 
- updated_state_dict = {} - for i, k in enumerate(unet.state_dict().keys()): - updated_state_dict.update({k: updated_params[i]}) - unet.load_state_dict(updated_state_dict) + unet_pseudo_updated_step_one = self.similuate_backprop(unet) # Take the EMA step. - ema_unet.step(unet.parameters()) + ema_unet.step(unet_pseudo_updated_step_one.parameters()) # Now the EMA'd parameters won't be equal to the original model parameters. - orig_params = list(unet.parameters()) + orig_params = list(unet_pseudo_updated_step_one.parameters()) for s_param, param in zip(ema_unet.shadow_params, orig_params): assert ~torch.allclose(s_param, param) @@ -100,16 +94,43 @@ def test_shadow_params_updated(self): for s_param, param in zip(ema_unet.shadow_params, orig_params): assert ~torch.allclose(s_param, param) - def test_consecutive_shadow_params_not_updated(self): - # EMA steps are supposed to be taken after we have taken a backprop step. - # If that is not the case shadown params after two consecutive steps should - # be one and the same - pass - def test_consecutive_shadow_params_updated(self): - pass + # If we call EMA step after a backpropagation consecutively for two times, + # the shadow params from those two steps should be different. + unet, ema_unet = self.get_models() + # First backprop + EMA + unet_step_one = self.similuate_backprop(unet) + ema_unet.step(unet_step_one.parameters()) + step_one_shadow_params = ema_unet.shadow_params + # Second backprop + EMA + unet_step_two = self.similuate_backprop(unet_step_one) + ema_unet.step(unet_step_two.parameters()) + step_two_shadow_params = ema_unet.shadow_params + + for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): + assert ~torch.allclose(step_one, step_two) + + del unet, ema_unet + + def test_zero_decay(self): + # If there's no decay even if there are backprops, EMA steps + # won't take any effect i.e., the shadow params would remain the + # same. + unet, ema_unet = self.get_models(decay=0.0) + unet_step_one = self.similuate_backprop(unet) + ema_unet.step(unet_step_one.parameters()) + step_one_shadow_params = ema_unet.shadow_params + + unet_step_two = self.similuate_backprop(unet_step_one) + ema_unet.step(unet_step_two.parameters()) + step_two_shadow_params = ema_unet.shadow_params + + for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): + assert torch.allclose(step_one, step_two) + + del unet, ema_unet def tearDown(self): super().tearDown() From 17994e561a04abcf70f02aac82f1bad324bbfcbf Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 2 Mar 2023 14:32:21 +0530 Subject: [PATCH 05/12] fix: optimization_step arg, --- tests/test_ema.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index 24c394403063..2a9364419866 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -45,12 +45,12 @@ def test_optimization_steps_updated(self): unet, ema_unet = self.get_models() # Take the first (hypothetical) EMA step. ema_unet.step(unet.parameters()) - assert ema_unet.optimization == 1 + assert ema_unet.optimization_step == 1 # Take two more. for _ in range(2): ema_unet.step(unet.parameters()) - assert ema_unet.optimization == 3 + assert ema_unet.optimization_step == 3 del unet, ema_unet From 49aed340abf30c57de0e569f652417300fdcd27d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Mar 2023 15:57:39 +0530 Subject: [PATCH 06/12] handle device placement. 
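
The EMA model keeps its own copies of the parameters it was built from, so the
shadow weights sit on whatever device the UNet occupied at construction time, and
the in-place update in `step()` expects both sides on the same device — hence the
explicit moves in this patch. `torch_device` comes from
`diffusers.utils.testing_utils`; roughly (a simplified assumption that ignores the
MPS branch) it resolves like:

    import torch

    # simplified sketch of what the test helper picks as the device
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
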
--- tests/test_ema.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index 2a9364419866..499b81c5c4b2 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -20,6 +20,7 @@ from diffusers import UNet2DConditionModel from diffusers.training_utils import EMAModel +from diffusers.utils.testing_utils import torch_device class EMAModelTests(unittest.TestCase): @@ -31,7 +32,7 @@ def get_models(self, decay=0.9999): ema_unet = EMAModel( ema_unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config ) - return unet, ema_unet + return unet.to(torch_device), ema_unet.to(torch_device) def similuate_backprop(self, unet): updated_state_dict = {} From 258bbe11ed4b61a7c4bf65de9bea30322f280ca1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Mar 2023 15:59:00 +0530 Subject: [PATCH 07/12] Apply suggestions from code review Co-authored-by: Will Berman --- tests/test_ema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index 499b81c5c4b2..2b4616056fa7 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -34,7 +34,7 @@ def get_models(self, decay=0.9999): ) return unet.to(torch_device), ema_unet.to(torch_device) - def similuate_backprop(self, unet): + def simulate_backprop(self, unet): updated_state_dict = {} for k, param in unet.state_dict().items(): updated_param = torch.randn_like(param) + (param * torch.randn_like(param)) From f4c4e0c31068238f01e2e9e8219f851fd706dd34 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Mar 2023 16:01:55 +0530 Subject: [PATCH 08/12] remove del and gc. --- tests/test_ema.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index 2b4616056fa7..68d01af59851 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -53,8 +53,6 @@ def test_optimization_steps_updated(self): ema_unet.step(unet.parameters()) assert ema_unet.optimization_step == 3 - del unet, ema_unet - def test_shadow_params_not_updated(self): unet, ema_unet = self.get_models() # Since the `unet` is not being updated (i.e., backprop'd) @@ -72,8 +70,6 @@ def test_shadow_params_not_updated(self): for s_param, param in zip(ema_unet.shadow_params, orig_params): assert torch.allclose(s_param, param) - del unet, ema_unet - def test_shadow_params_updated(self): unet, ema_unet = self.get_models() # Here we simulate the parameter updates for `unet`. Since there might @@ -113,8 +109,6 @@ def test_consecutive_shadow_params_updated(self): for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): assert ~torch.allclose(step_one, step_two) - del unet, ema_unet - def test_zero_decay(self): # If there's no decay even if there are backprops, EMA steps # won't take any effect i.e., the shadow params would remain the @@ -130,9 +124,3 @@ def test_zero_decay(self): for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): assert torch.allclose(step_one, step_two) - - del unet, ema_unet - - def tearDown(self): - super().tearDown() - gc.collect() From 885027543f9a67536d65de3a80c634401789daf1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Mar 2023 16:08:00 +0530 Subject: [PATCH 09/12] address PR feedback. 
--- tests/test_ema.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index 68d01af59851..ef7db47d503b 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import gc import unittest import torch From d6241d2025f350d0c76b8739e70d8b2224f8ef80 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Mar 2023 17:17:02 +0530 Subject: [PATCH 10/12] add: tests for serialization. --- tests/test_ema.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/test_ema.py b/tests/test_ema.py index ef7db47d503b..2e4fa5b0b888 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tempfile import unittest import torch @@ -24,6 +25,12 @@ class EMAModelTests(unittest.TestCase): model_id = "hf-internal-testing/tiny-stable-diffusion-pipe" + batch_size = 1 + prompt_length = 77 + text_encoder_hidden_dim = 32 + num_in_channels = 4 + latent_height = latent_width = 64 + generator = torch.manual_seed(0) def get_models(self, decay=0.9999): unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") @@ -33,6 +40,16 @@ def get_models(self, decay=0.9999): ) return unet.to(torch_device), ema_unet.to(torch_device) + def get_dummy_inputs(self): + noisy_latents = torch.randn( + self.batch_size, self.num_in_channels, self.latent_height, self.latent_width, generator=self.generator + ).to(torch_device) + timesteps = torch.randint(0, 1000, size=(self.batch_size,), generator=self.generator).to(torch_device) + encoder_hidden_states = torch.randn( + self.batch_size, self.prompt_length, self.text_encoder_hidden_dim, generator=self.generator + ).to(torch_device) + return noisy_latents, timesteps, encoder_hidden_states + def simulate_backprop(self, unet): updated_state_dict = {} for k, param in unet.state_dict().items(): @@ -123,3 +140,17 @@ def test_zero_decay(self): for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): assert torch.allclose(step_one, step_two) + + def test_serialization(self): + unet, ema_unet = self.get_models() + noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs() + + with tempfile.TemporaryDirectory() as tmpdir: + ema_unet.save_pretrained(tmpdir) + loaded_unet = UNet2DConditionModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel) + + # Since no EMA step has been performed the outputs should match. + output = unet(noisy_latents, timesteps, encoder_hidden_states).sample + output_loaded = loaded_unet(noisy_latents, timesteps, encoder_hidden_states).sample + + assert torch.allclose(output, output_loaded) From aee2846144b6d0df329f474129915e6b98343516 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 13 Mar 2023 17:32:32 +0530 Subject: [PATCH 11/12] fix: typos. 
--- tests/test_ema.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index 2e4fa5b0b888..d0af2055841d 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -33,12 +33,12 @@ class EMAModelTests(unittest.TestCase): generator = torch.manual_seed(0) def get_models(self, decay=0.9999): - unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") + unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet", device=torch_device) ema_unet = UNet2DConditionModel.from_pretrained(self.model_id, subfolder="unet") ema_unet = EMAModel( ema_unet.parameters(), decay=decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config ) - return unet.to(torch_device), ema_unet.to(torch_device) + return unet, ema_unet def get_dummy_inputs(self): noisy_latents = torch.randn( @@ -91,7 +91,7 @@ def test_shadow_params_updated(self): # Here we simulate the parameter updates for `unet`. Since there might # be some parameters which are initialized to zero we take extra care to # initialize their values to something non-zero before the multiplication. - unet_pseudo_updated_step_one = self.similuate_backprop(unet) + unet_pseudo_updated_step_one = self.simulate_backprop(unet) # Take the EMA step. ema_unet.step(unet_pseudo_updated_step_one.parameters()) @@ -113,12 +113,12 @@ def test_consecutive_shadow_params_updated(self): unet, ema_unet = self.get_models() # First backprop + EMA - unet_step_one = self.similuate_backprop(unet) + unet_step_one = self.simulate_backprop(unet) ema_unet.step(unet_step_one.parameters()) step_one_shadow_params = ema_unet.shadow_params # Second backprop + EMA - unet_step_two = self.similuate_backprop(unet_step_one) + unet_step_two = self.simulate_backprop(unet_step_one) ema_unet.step(unet_step_two.parameters()) step_two_shadow_params = ema_unet.shadow_params @@ -130,11 +130,11 @@ def test_zero_decay(self): # won't take any effect i.e., the shadow params would remain the # same. unet, ema_unet = self.get_models(decay=0.0) - unet_step_one = self.similuate_backprop(unet) + unet_step_one = self.simulate_backprop(unet) ema_unet.step(unet_step_one.parameters()) step_one_shadow_params = ema_unet.shadow_params - unet_step_two = self.similuate_backprop(unet_step_one) + unet_step_two = self.simulate_backprop(unet_step_one) ema_unet.step(unet_step_two.parameters()) step_two_shadow_params = ema_unet.shadow_params From 8f2a75f396b2783101cc962f13c38a8bb00ff36d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Tue, 14 Mar 2023 10:19:28 +0530 Subject: [PATCH 12/12] skip_mps to serialization. --- tests/test_ema.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_ema.py b/tests/test_ema.py index d0af2055841d..9f99457080d5 100644 --- a/tests/test_ema.py +++ b/tests/test_ema.py @@ -20,7 +20,7 @@ from diffusers import UNet2DConditionModel from diffusers.training_utils import EMAModel -from diffusers.utils.testing_utils import torch_device +from diffusers.utils.testing_utils import skip_mps, torch_device class EMAModelTests(unittest.TestCase): @@ -141,6 +141,7 @@ def test_zero_decay(self): for step_one, step_two in zip(step_one_shadow_params, step_two_shadow_params): assert torch.allclose(step_one, step_two) + @skip_mps def test_serialization(self): unet, ema_unet = self.get_models() noisy_latents, timesteps, encoder_hidden_states = self.get_dummy_inputs()
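
A closing note on how these pieces fit together in a training loop: the tests above
pin down that `step()` tracks `optimization_step`, that the shadow weights only
drift when the live weights actually change, and that the EMA state can be saved
with `save_pretrained` and reloaded. A minimal sketch of the intended usage — with
a stand-in `nn.Linear` instead of the tiny UNet, a dummy objective, and SGD, all of
which are assumptions for illustration only:

    import torch
    from torch import nn

    from diffusers.training_utils import EMAModel

    model = nn.Linear(4, 4)                              # stand-in for the UNet
    ema = EMAModel(model.parameters(), decay=0.9999)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(3):
        loss = model(torch.randn(2, 4)).pow(2).mean()    # dummy objective
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.step(model.parameters())                     # EMA step only after a real update

    # evaluate / checkpoint from the averaged weights rather than the raw ones
    with torch.no_grad():
        for s_param, param in zip(ema.shadow_params, model.parameters()):
            param.copy_(s_param)                         # EMAModel also exposes copy_to() for this
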