Skip to content

Commit

Permalink
[FIX] Check LSTM proj_size only for torch>=1.8.0
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Jun 30, 2021
1 parent 5726cc1 commit 4f126b0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
6 changes: 4 additions & 2 deletions backpack/core/derivatives/lstm.py
Expand Up @@ -5,6 +5,7 @@
from torch.nn import LSTM

from backpack.core.derivatives.basederivatives import BaseParameterDerivatives
from backpack.utils import TORCH_VERSION, VERSION_1_8_0


class LSTMDerivatives(BaseParameterDerivatives):
Expand Down Expand Up @@ -47,8 +48,9 @@ def _check_parameters(module: LSTM) -> None:
raise NotImplementedError("only dropout = 0 is supported")
if module.bidirectional is not False:
raise NotImplementedError("only bidirectional = False is supported")
if module.proj_size != 0:
raise NotImplementedError("only proj_size = 0 is supported")
if TORCH_VERSION >= VERSION_1_8_0:
if module.proj_size != 0:
raise NotImplementedError("only proj_size = 0 is supported")

@staticmethod
def _forward_pass(
Expand Down
5 changes: 5 additions & 0 deletions backpack/utils/__init__.py
@@ -1 +1,6 @@
"""Contains utility functions."""

from pkg_resources import get_distribution, packaging

TORCH_VERSION = packaging.version.parse(get_distribution("torch").version)
VERSION_1_8_0 = packaging.version.parse("1.8.0")

0 comments on commit 4f126b0

Please sign in to comment.