Skip to content

Commit

Permalink
update install reqs if op is compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffra committed Nov 3, 2020
1 parent 1377e7f commit 0815d11
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 11 deletions.
12 changes: 12 additions & 0 deletions op_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,18 @@ def is_compatible(self):
'''
return True

def python_requirements(self):
'''
Override if op wants to define special dependencies, otherwise will
take self.name and load requirements-<op-name>.txt if it exists.
'''
path = f'requirements/requirements-{self.name}.txt'
requirements = []
if os.path.isfile():
with open(path, 'r') as fd:
requirements = [r.strip() for r in fd.readlines()]
return requirements

def command_exists(self, cmd):
if '|' in cmd:
cmds = cmd.split("|")
Expand Down
1 change: 1 addition & 0 deletions op_builder/cpu_adam.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import torch
import warnings
from .builder import CUDAOpBuilder


Expand Down
11 changes: 1 addition & 10 deletions op_builder/sparse_attn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import warnings
from .builder import OpBuilder #, command_exists
from .builder import OpBuilder


class SparseAttnBuilder(OpBuilder):
Expand Down Expand Up @@ -33,14 +33,5 @@ def is_compatible(self):
f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
)

try:
import triton
triton_installed = True
except ImportError:
triton_installed = False
if not triton_installed:
self.warning(
f"{self.NAME} requires the python package 'triton' to be installed")

return super().is_compatible(
) and deps_compatible and torch_compatible and triton_installed
File renamed without changes.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def fetch_requirements(path):

install_requires = fetch_requirements('requirements/requirements.txt')
extras_require = {
'sparse_attn': fetch_requirements('requirements/requirements-sparse-attn.txt'),
'1bit_adam': fetch_requirements('requirements/requirements-1bit-adam.txt'),
'readthedocs': fetch_requirements('requirements/requirements-readthedocs.txt'),
'dev': fetch_requirements('requirements/requirements-dev.txt'),
}

Expand Down Expand Up @@ -91,6 +91,8 @@ def op_enabled(op_name):
if builder.is_compatible():
install_ops[op_name] = op_enabled(op_name)
ext_modules.append(builder.builder())
# Add any necessary python requirements for this op
install_requires += builder.python_requirements()

compatible_ops = {op_name: op.is_compatible() for (op_name, op) in ALL_OPS.items()}

Expand Down

0 comments on commit 0815d11

Please sign in to comment.