diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py index 105322eb..31694fcc 100644 --- a/torch_struct/distributions.py +++ b/torch_struct/distributions.py @@ -93,9 +93,10 @@ def topk(self, k): Returns: kmax (*k x batch_shape x event_shape*) """ - return self._struct(KMaxSemiring(k)).marginals( - self.log_potentials, self.lengths, _raw=True - ) + with torch.enable_grad(): + return self._struct(KMaxSemiring(k)).marginals( + self.log_potentials, self.lengths, _raw=True + ) @lazy_property def mode(self):