From 926b94cfb0e4e21d0dc3a87cb998fc7adf1b9310 Mon Sep 17 00:00:00 2001 From: srush Date: Sat, 13 Feb 2021 21:34:46 -0500 Subject: [PATCH] update setup --- setup.py | 6 ++++++ torch_struct/semirings/checkpoint.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f9e514cd..5af04f24 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,8 @@ from setuptools import setup +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + setup( name="torch_struct", version="0.5", @@ -9,9 +12,12 @@ "torch_struct", "torch_struct.semirings", ], + long_description=long_description, package_data={"torch_struct": []}, + long_description_content_type="text/markdown", url="https://github.com/harvardnlp/pytorch_struct", install_requires=["torch"], setup_requires=["pytest-runner"], tests_require=["pytest"], + python_requires='>=3.6', ) diff --git a/torch_struct/semirings/checkpoint.py b/torch_struct/semirings/checkpoint.py index b2dacba5..c4e10c4f 100644 --- a/torch_struct/semirings/checkpoint.py +++ b/torch_struct/semirings/checkpoint.py @@ -1,8 +1,10 @@ import torch +has_genbmm = False try: import genbmm from genbmm import BandedMatrix + has_genbmm = True except ImportError: pass @@ -52,7 +54,7 @@ def backward(ctx, grad_output): class _CheckpointSemiring(cls): @staticmethod def matmul(a, b): - if isinstance(a, genbmm.BandedMatrix): + if has_genbmm and isinstance(a, genbmm.BandedMatrix): lu = a.lu + b.lu ld = a.ld + b.ld c = _CheckBand.apply(a.data, a.lu, a.ld, b.data, b.lu, b.ld)