From 769fa1c05db92ef9020d2f6ddf4c630a39546b2b Mon Sep 17 00:00:00 2001 From: Jeff Rasley Date: Tue, 3 Nov 2020 23:10:07 +0000 Subject: [PATCH] fixes and tested --- op_builder/builder.py | 2 +- op_builder/sparse_attn.py | 3 +-- setup.py | 22 ++++++++++++---------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/op_builder/builder.py b/op_builder/builder.py index 6dd3d574d036..ea3f79bc64de 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -102,7 +102,7 @@ def python_requirements(self): ''' path = f'requirements/requirements-{self.name}.txt' requirements = [] - if os.path.isfile(): + if os.path.isfile(path): with open(path, 'r') as fd: requirements = [r.strip() for r in fd.readlines()] return requirements diff --git a/op_builder/sparse_attn.py b/op_builder/sparse_attn.py index 550ad3084f72..4c716f859970 100644 --- a/op_builder/sparse_attn.py +++ b/op_builder/sparse_attn.py @@ -33,5 +33,4 @@ def is_compatible(self): f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}' ) - return super().is_compatible( - ) and deps_compatible and torch_compatible and triton_installed + return super().is_compatible() and deps_compatible and torch_compatible diff --git a/setup.py b/setup.py index f3d114a70bc1..bf1fc086ffff 100755 --- a/setup.py +++ b/setup.py @@ -83,16 +83,18 @@ def op_enabled(op_name): install_ops = dict.fromkeys(ALL_OPS.keys(), False) -for op_name in ALL_OPS.keys(): - # Is op disabled from environment variable - if op_enabled(op_name): - builder = ALL_OPS[op_name] - # Is op compatible with machine arch/deps - if builder.is_compatible(): - install_ops[op_name] = op_enabled(op_name) - ext_modules.append(builder.builder()) - # Add any necessary python requirements for this op - install_requires += builder.python_requirements() +for op_name, builder in ALL_OPS.items(): + op_compatible = builder.is_compatible() + + # If op is compatible update install reqs so it can potentially build/run later + if op_compatible: + reqs = builder.python_requirements() + install_requires += builder.python_requirements() + + # If op install enabled, add builder to extensions + if op_enabled(op_name) and op_compatible: + install_ops[op_name] = op_enabled(op_name) + ext_modules.append(builder.builder()) compatible_ops = {op_name: op.is_compatible() for (op_name, op) in ALL_OPS.items()}