darglint ignore logpartition docstring mismatch
Ubuntu committed Jan 22, 2021
1 parent 71004b2 commit cded5e1
Showing 1 changed file with 5 additions and 16 deletions.
torch_struct/helpers.py (21 changes: 5 additions & 16 deletions)
@@ -6,11 +6,7 @@
 class Chart:
     def __init__(self, size, potentials, semiring):
         self.data = semiring.zero_(
-            torch.zeros(
-                *((semiring.size(),) + size),
-                dtype=potentials.dtype,
-                device=potentials.device
-            )
+            torch.zeros(*((semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device)
         )
         self.grad = self.data.detach().clone().fill_(0.0)

@@ -49,6 +45,7 @@ def logpartition(self, scores, lengths=None, force_grad=False):
         An exceptional case is the `CKY` struct, which takes log potential parameters from production rules
         for a PCFG, which are by definition independent of position in the sequence.
+        # noqa: DAR401, DAR202
         """
         raise NotImplementedError()

@@ -80,11 +77,7 @@ def _make_chart(self, N, size, potentials, force_grad=False):
         return [
             (
                 self.semiring.zero_(
-                    torch.zeros(
-                        *((self.semiring.size(),) + size),
-                        dtype=potentials.dtype,
-                        device=potentials.device
-                    )
+                    torch.zeros(*((self.semiring.size(),) + size), dtype=potentials.dtype, device=potentials.device)
                 ).requires_grad_(force_grad and not potentials.requires_grad)
             )
             for _ in range(N)
@@ -120,9 +113,7 @@ def marginals(self, logpotentials, lengths=None, _raw=False):
         """
         with torch.autograd.enable_grad():  # in case input potentials don't have grads enabled.
-            v, edges = self.logpartition(
-                logpotentials, lengths=lengths, force_grad=True
-            )
+            v, edges = self.logpartition(logpotentials, lengths=lengths, force_grad=True)
             if _raw:
                 all_m = []
                 for k in range(v.shape[0]):
@@ -139,9 +130,7 @@ def marginals(self, logpotentials, lengths=None, _raw=False):
                 return torch.stack(all_m, dim=0)
             else:
                 obj = self.semiring.unconvert(v).sum(dim=0)
-                marg = torch.autograd.grad(
-                    obj, edges, create_graph=True, only_inputs=True, allow_unused=False
-                )
+                marg = torch.autograd.grad(obj, edges, create_graph=True, only_inputs=True, allow_unused=False)
                 a_m = self._arrange_marginals(marg)
                 return self.semiring.unconvert(a_m)

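For context: darglint lints docstrings against the function signature and body. The base logpartition documents a return value but its body only raises NotImplementedError, so darglint reports DAR401 (an exception raised in the body is missing from the docstring's Raises section) and DAR202 (the docstring documents a return that the body never produces). darglint honors noqa directives placed inside the docstring itself, which is exactly what this commit adds. A minimal standalone sketch of the same pattern (hypothetical class, not from this repository):

class Struct:
    def logpartition(self, scores):
        """Compute the log partition function of a structured distribution.

        Parameters:
            scores: log potentials for the structure.

        Returns:
            v: the log partition value.

        # noqa: DAR401, DAR202
        """
        # Abstract base: subclasses implement the actual dynamic program,
        # so the body raises while the docstring documents the contract.
        raise NotImplementedError()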
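The docstring hunk also mentions the CKY exception: for a PCFG, production-rule scores are by definition position-independent, so the rule tensor carries no sequence dimension. A hedged sketch of what that looks like through the SentCFG wrapper; the shapes follow the torch_struct README convention and should be treated as assumptions:

import torch
from torch_struct import SentCFG

batch, N, NT, T = 2, 5, 3, 4
terms = torch.randn(batch, N, T)                # terminal scores: one per position
rules = torch.randn(batch, NT, NT + T, NT + T)  # rule scores: no position dimension
roots = torch.randn(batch, NT)                  # root-symbol scores
dist = SentCFG((terms, rules, roots), lengths=torch.tensor([5, 4]))
print(dist.partition)  # log partition, one value per batch element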

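Finally, the two marginals hunks touch the identity this helper is built on: the gradient of the log partition function with respect to the log potentials equals the marginals, which is why marginals() calls torch.autograd.grad on the output of logpartition. A one-dimensional illustration of the same identity (standalone demo, not the library API):

import torch

theta = torch.randn(5, requires_grad=True)  # log potentials of 5 configurations
log_z = torch.logsumexp(theta, dim=0)       # log partition function
(marg,) = torch.autograd.grad(log_z, theta, create_graph=True)

# d(log Z)/d(theta_i) = exp(theta_i) / Z, i.e. each configuration's probability.
assert torch.allclose(marg, torch.softmax(theta, dim=0))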