|
24 | 24 | import torch
|
25 | 25 | from requests.exceptions import HTTPError
|
26 | 26 |
|
27 |
| -from diffusers.models import ModelMixin, UNet2DConditionModel |
| 27 | +from diffusers.models import UNet2DConditionModel |
28 | 28 | from diffusers.models.attention_processor import AttnProcessor
|
29 | 29 | from diffusers.training_utils import EMAModel
|
30 | 30 | from diffusers.utils import torch_device
|
@@ -119,11 +119,6 @@ def test_from_save_pretrained(self):
|
119 | 119 | new_model.to(torch_device)
|
120 | 120 |
|
121 | 121 | with torch.no_grad():
|
122 |
| - # Warmup pass when using mps (see #372) |
123 |
| - if torch_device == "mps" and isinstance(model, ModelMixin): |
124 |
| - _ = model(**self.dummy_input) |
125 |
| - _ = new_model(**self.dummy_input) |
126 |
| - |
127 | 122 | image = model(**inputs_dict)
|
128 | 123 | if isinstance(image, dict):
|
129 | 124 | image = image.sample
|
@@ -161,11 +156,6 @@ def test_from_save_pretrained_variant(self):
|
161 | 156 | new_model.to(torch_device)
|
162 | 157 |
|
163 | 158 | with torch.no_grad():
|
164 |
| - # Warmup pass when using mps (see #372) |
165 |
| - if torch_device == "mps" and isinstance(model, ModelMixin): |
166 |
| - _ = model(**self.dummy_input) |
167 |
| - _ = new_model(**self.dummy_input) |
168 |
| - |
169 | 159 | image = model(**inputs_dict)
|
170 | 160 | if isinstance(image, dict):
|
171 | 161 | image = image.sample
|
@@ -203,10 +193,6 @@ def test_determinism(self):
|
203 | 193 | model.eval()
|
204 | 194 |
|
205 | 195 | with torch.no_grad():
|
206 |
| - # Warmup pass when using mps (see #372) |
207 |
| - if torch_device == "mps" and isinstance(model, ModelMixin): |
208 |
| - model(**self.dummy_input) |
209 |
| - |
210 | 196 | first = model(**inputs_dict)
|
211 | 197 | if isinstance(first, dict):
|
212 | 198 | first = first.sample
|
@@ -377,10 +363,6 @@ def recursive_check(tuple_object, dict_object):
|
377 | 363 | model.eval()
|
378 | 364 |
|
379 | 365 | with torch.no_grad():
|
380 |
| - # Warmup pass when using mps (see #372) |
381 |
| - if torch_device == "mps" and isinstance(model, ModelMixin): |
382 |
| - model(**self.dummy_input) |
383 |
| - |
384 | 366 | outputs_dict = model(**inputs_dict)
|
385 | 367 | outputs_tuple = model(**inputs_dict, return_dict=False)
|
386 | 368 |
|
|
0 commit comments