Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 18, 2019
1 parent 905dd78 commit f05300a
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions torch_struct/cky_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,39 @@ def _check_potentials(self, edge, lengths=None):
def _dp(self, scores, lengths=None, force_grad=False):
semiring = self.semiring
scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)
beta = self._make_chart(2, (batch, N, N), scores, force_grad)
beta = self._make_chart(2, (batch, N, 1), scores, force_grad)
zero = self._make_chart(1, (batch, N, 1), scores, force_grad)[0]
L_DIM, R_DIM = 2, 3

# Initialize
reduced_scores = semiring.sum(scores)
term = reduced_scores.diagonal(0, L_DIM, R_DIM)
ns = torch.arange(N)
beta[A][:, :, ns, 0] = term
beta[B][:, :, ns, N - 1] = term
beta[B][:, :, ns, 0] = term
zero.requires_grad_(False)

def c(a, b):
return torch.cat([a, b], dim=-1)
def p(a, b):
return torch.cat([a, b], dim=2)

# Run
for w in range(1, N):
Y = beta[A][:, :, : N - w, :w]
Z = beta[B][:, :, w:, N - w :]
Y = beta[A][:, :, : N - w, :]
Z = beta[B][:, :, w:, :]
z = zero[:, :, : w]
score = reduced_scores.diagonal(w, L_DIM, R_DIM)
beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), score)
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]
new = semiring.times(semiring.dot(Y, Z), score).unsqueeze(-1)
beta[A] = c(beta[A], p(new, z))
beta[B] = c(p(z, new), beta[B])

# Y = beta[A][:, :, : N - w, :w]
# Z = beta[B][:, :, w:, N - w :]
# score = reduced_scores.diagonal(w, L_DIM, R_DIM)

# beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), score)
# beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]

final = beta[A][:, :, 0]
log_Z = final[:, torch.arange(batch), lengths - 1]
Expand Down

0 comments on commit f05300a

Please sign in to comment.