From d921f3070f6cc3515a127a42df4480254a8c98eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Y=C3=A4n=2EPnG?= Date: Tue, 10 Mar 2020 20:43:40 +0000 Subject: [PATCH] enable gradient calculation in dist.topk() by default (#56) --- torch_struct/distributions.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 105322eb..31694fcc 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -93,9 +93,10 @@ def topk(self, k): Returns: kmax (*k x batch_shape x event_shape*) """ - return self._struct(KMaxSemiring(k)).marginals( - self.log_potentials, self.lengths, _raw=True - ) + with torch.enable_grad(): + return self._struct(KMaxSemiring(k)).marginals( + self.log_potentials, self.lengths, _raw=True + ) @lazy_property def mode(self):