From 05fa02984950ca804872ab6eb8a7df73d1ef306f Mon Sep 17 00:00:00 2001 From: Alex Rush Date: Sun, 8 Sep 2019 14:40:48 -0400 Subject: [PATCH] . --- torch_struct/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py index e15e0c02..d51ff9e8 100644 --- a/torch_struct/helpers.py +++ b/torch_struct/helpers.py @@ -86,7 +86,7 @@ def marginals(self, edge, lengths=None, _autograd=True): marginals: b x (N-1) x C x C table """ - v, edge, alpha = self._dp(edge, lengths=lengths, force_grad=True) + v, edges, alpha = self._dp(edge, lengths=lengths, force_grad=True) if ( _autograd or self.semiring is not LogSemiring @@ -94,7 +94,7 @@ def marginals(self, edge, lengths=None, _autograd=True): ): marg = torch.autograd.grad( v.sum(dim=0), - edge, + edges, create_graph=True, only_inputs=True, allow_unused=False,