@@ -33,11 +33,9 @@ class EMAModelTests(unittest.TestCase):
3333 generator = torch .manual_seed (0 )
3434
3535 def get_models (self , decay = 0.9999 ):
36- unet = UNet2DConditionModel .from_pretrained (self .model_id , subfolder = "unet" , device = torch_device )
37- ema_unet = UNet2DConditionModel .from_pretrained (self .model_id , subfolder = "unet" )
38- ema_unet = EMAModel (
39- ema_unet .parameters (), decay = decay , model_cls = UNet2DConditionModel , model_config = ema_unet .config
40- )
36+ unet = UNet2DConditionModel .from_pretrained (self .model_id , subfolder = "unet" )
37+ unet = unet .to (torch_device )
38+ ema_unet = EMAModel (unet .parameters (), decay = decay , model_cls = UNet2DConditionModel , model_config = unet .config )
4139 return unet , ema_unet
4240
4341 def get_dummy_inputs (self ):
@@ -149,6 +147,7 @@ def test_serialization(self):
149147 with tempfile .TemporaryDirectory () as tmpdir :
150148 ema_unet .save_pretrained (tmpdir )
151149 loaded_unet = UNet2DConditionModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel )
150+ loaded_unet = loaded_unet .to (unet .device )
152151
153152 # Since no EMA step has been performed the outputs should match.
154153 output = unet (noisy_latents , timesteps , encoder_hidden_states ).sample
0 commit comments