Skip to content

Commit

Permalink
update setup
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Feb 14, 2021
1 parent b4638f1 commit 926b94c
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 6 additions & 0 deletions setup.py
@@ -1,5 +1,8 @@
from setuptools import setup

with open("README.md", "r", encoding="utf-8") as fh:
long_description = fh.read()

setup(
name="torch_struct",
version="0.5",
Expand All @@ -9,9 +12,12 @@
"torch_struct",
"torch_struct.semirings",
],
long_description=long_description,
package_data={"torch_struct": []},
long_description_content_type="text/markdown",
url="https://github.com/harvardnlp/pytorch_struct",
install_requires=["torch"],
setup_requires=["pytest-runner"],
tests_require=["pytest"],
python_requires='>=3.6',
)
4 changes: 3 additions & 1 deletion torch_struct/semirings/checkpoint.py
@@ -1,8 +1,10 @@
import torch

has_genbmm = False
try:
import genbmm
from genbmm import BandedMatrix
has_genbmm = True
except ImportError:
pass

Expand Down Expand Up @@ -52,7 +54,7 @@ def backward(ctx, grad_output):
class _CheckpointSemiring(cls):
@staticmethod
def matmul(a, b):
if isinstance(a, genbmm.BandedMatrix):
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
lu = a.lu + b.lu
ld = a.ld + b.ld
c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld)
Expand Down

0 comments on commit 926b94c

Please sign in to comment.