Skip to content

Commit cca76a6

Browse files
up
1 parent dbe0719 commit cca76a6

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/test_pipelines_common.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,22 @@ def test_save_load_local(self):
9090
max_diff = np.abs(output - output_loaded).max()
9191
self.assertLess(max_diff, 1e-5)
9292

93+
def test_signature(self):
94+
assert hasattr(self.pipeline_class, "__call__"), f"{self.pipeline_class} should have a `__call__` method"
95+
parameters = inspect.signature(self.pipeline_class.__call__).parameters
96+
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
97+
required_parameters.pop("self")
98+
99+
allowed_required_params = ["prompt", "image", "mask_image", "example_image"]
100+
for param in required_parameters.keys():
101+
assert param in allowed_required_params
102+
103+
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
104+
105+
required_optional_params = ["generator", "num_inference_steps"]
106+
for param in required_optional_params:
107+
assert param in optional_parameters
108+
93109
def test_dict_tuple_outputs_equivalent(self):
94110
if torch_device == "mps" and self.pipeline_class in (
95111
DanceDiffusionPipeline,

0 commit comments

Comments
 (0)