Skip to content

Commit

Permalink
fixes and tested
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra committed Nov 3, 2020
1 parent 0815d11 commit 769fa1c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 13 deletions.
2 changes: 1 addition & 1 deletion op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def python_requirements(self):
'''
path = f'requirements/requirements-{self.name}.txt'
requirements = []
if os.path.isfile():
if os.path.isfile(path):
with open(path, 'r') as fd:
requirements = [r.strip() for r in fd.readlines()]
return requirements
Expand Down
3 changes: 1 addition & 2 deletions op_builder/sparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,4 @@ def is_compatible(self):
f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
)

return super().is_compatible(
) and deps_compatible and torch_compatible and triton_installed
return super().is_compatible() and deps_compatible and torch_compatible
22 changes: 12 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,18 @@ def op_enabled(op_name):


install_ops = dict.fromkeys(ALL_OPS.keys(), False)
for op_name in ALL_OPS.keys():
# Is op disabled from environment variable
if op_enabled(op_name):
builder = ALL_OPS[op_name]
# Is op compatible with machine arch/deps
if builder.is_compatible():
install_ops[op_name] = op_enabled(op_name)
ext_modules.append(builder.builder())
# Add any necessary python requirements for this op
install_requires += builder.python_requirements()
for op_name, builder in ALL_OPS.items():
op_compatible = builder.is_compatible()

# If op is compatible update install reqs so it can potentially build/run later
if op_compatible:
reqs = builder.python_requirements()
install_requires += builder.python_requirements()

# If op install enabled, add builder to extensions
if op_enabled(op_name) and op_compatible:
install_ops[op_name] = op_enabled(op_name)
ext_modules.append(builder.builder())

compatible_ops = {op_name: op.is_compatible() for (op_name, op) in ALL_OPS.items()}

Expand Down

0 comments on commit 769fa1c

Please sign in to comment.