Skip to content

Commit

Permalink
Merge 8dbcdf1 into 1c9b038
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 14, 2019
2 parents 1c9b038 + 8dbcdf1 commit 86b8813
Showing 1 changed file with 18 additions and 31 deletions.
49 changes: 18 additions & 31 deletions torch_struct/cky_crf.py
Expand Up @@ -12,47 +12,34 @@ def _dp(self, scores, lengths=None, force_grad=False):
scores = semiring.convert(scores)
if lengths is None:
lengths = torch.LongTensor([N] * batch)
beta = self._make_chart(2, (batch, N, N, NT), scores, force_grad)
span = self._make_chart(N, (batch, N, NT), scores, force_grad)
rule_use = [
self._make_chart(1, (batch, N - w, NT), scores, force_grad)[0]
for w in range(N)
]
scores.requires_grad_(True)
beta = self._make_chart(2, (batch, N, N), scores, force_grad)

# Initialize
reduced_scores = semiring.sum(scores)
ns = torch.arange(N)
rule_use[0][:] = scores[:, :, ns, ns]
rule_use[0].requires_grad_(True)
beta[A][:, :, ns, 0] = rule_use[0]
beta[B][:, :, ns, N - 1] = rule_use[0]
rule_use = reduced_scores[:, :, ns, ns]
beta[A][:, :, ns, 0] = rule_use
beta[B][:, :, ns, N - 1] = rule_use

# Run
for w in range(1, N):
Y = beta[A][:, :, : N - w, :w].view(ssize, batch, N - w, 1, w, NT, 1)
Z = beta[B][:, :, w:, N - w :].view(ssize, batch, N - w, 1, w, 1, NT)
f = torch.arange(N - w), torch.arange(w, N)
X = scores[:, :, f[0], f[1]].view(ssize, batch, N - w, NT)
merge = semiring.times(Y, Z).view(ssize, batch, N - w, 1, -1)
rule_use[w][:] = semiring.times(semiring.sum(merge), X)
Y = beta[A][:, :, : N - w, :w]
Z = beta[B][:, :, w:, N - w :]
f = torch.arange(N - w)
X = reduced_scores[:, :, f, f + w]

span[w] = rule_use[w].view(ssize, batch, N - w, NT)
beta[A][:, :, : N - w, w] = span[w]
beta[A][:, :, : N - w, w] = semiring.times(
semiring.sum(semiring.times(Y, Z)), X
)
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]

final = semiring.sum(beta[A][:, :, 0, :])
final = beta[A][:, :, 0]
log_Z = torch.stack([final[:, b, l - 1] for b, l in enumerate(lengths)], dim=1)
return log_Z, rule_use, beta
return log_Z, [scores], beta

def _arrange_marginals(self, grads):
semiring = self.semiring
_, batch, N, NT = grads[0].shape
rules = torch.zeros(
batch, N, N, NT, dtype=grads[0].dtype, device=grads[0].device
)

for w, grad in enumerate(grads):
grad = semiring.unconvert(grad)
f = torch.arange(N - w), torch.arange(w, N)
rules[:, f[0], f[1]] = self.semiring.unconvert(grad)
return rules
return grads[0]

def enumerate(self, scores):
semiring = self.semiring
Expand Down

0 comments on commit 86b8813

Please sign in to comment.