diff --git a/utils/fetch_torch_cuda_pipeline_test_matrix.py b/utils/fetch_torch_cuda_pipeline_test_matrix.py index 41a9c1c8270d..302898789728 100644 --- a/utils/fetch_torch_cuda_pipeline_test_matrix.py +++ b/utils/fetch_torch_cuda_pipeline_test_matrix.py @@ -34,8 +34,11 @@ def filter_pipelines(usage_dict, usage_cutoff=10000): if usage < usage_cutoff: continue - if "Pipeline" in diffusers_object: - output.append(diffusers_object) + is_diffusers_pipeline = hasattr(diffusers.pipelines, diffusers_object) + if not is_diffusers_pipeline: + continue + + output.append(diffusers_object) return output @@ -71,6 +74,7 @@ def fetch_pipeline_modules_to_test(): test_modules = [] for pipeline_name in pipeline_objects: module = getattr(diffusers, pipeline_name) + test_module = module.__module__.split(".")[-2].strip() test_modules.append(test_module)