diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index e15e0c02..d51ff9e8 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -86,7 +86,7 @@ def marginals(self, edge, lengths=None, _autograd=True): marginals: b x (N-1) x C x C table """ - v, edge, alpha = self._dp(edge, lengths=lengths, force_grad=True) + v, edges, alpha = self._dp(edge, lengths=lengths, force_grad=True) if ( _autograd or self.semiring is not LogSemiring @@ -94,7 +94,7 @@ def marginals(self, edge, lengths=None, _autograd=True): ): marg = torch.autograd.grad( v.sum(dim=0), - edge, + edges, create_graph=True, only_inputs=True, allow_unused=False,