Commit

Sasha committed Nov 18, 2019
1 parent f05300a commit 97e4d6b
Showing 1 changed file with 63 additions and 20 deletions.
83 changes: 63 additions & 20 deletions torch_struct/cky_crf.py
@@ -4,7 +4,41 @@
 A, B = 0, 1
 
 
+
+class Get(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, chart, grad_chart, indices):
+        out = chart[indices]
+        ctx.save_for_backward(grad_chart)
+        ctx.indices = indices
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        grad_chart, = ctx.saved_tensors
+        grad_chart[ctx.indices] = grad_output
+        return grad_chart, None, None
+
+# class Set(torch.autograd.Function):
+#     @staticmethod
+#     def forward(ctx, chart, grad_chart, indices, vals):
+#         chart[indices] = vals
+#         ctx.save_for_backward(grad_chart)
+#         ctx.indices = indices
+#         return None
+
+#     @staticmethod
+#     def backward(ctx, grad_output):
+#         grad_chart, = ctx.saved_tensors
+#         # grad_chart[ctx.indices] += grad_output
+#         # size, = ctx.saved_tensors
+#         # grad_chart = torch.zeros(*size.shape)
+#         grad_chart[ctx.indices] = grad_output
+#         return grad_chart, None, None
+
+
 class CKY_CRF(_Struct):
+
     def _check_potentials(self, edge, lengths=None):
         batch, N, _, NT = edge.shape
         edge.requires_grad_(True)
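
A note on the new Get function above: the rewritten _dp preallocates its beta charts and writes into them in place, instead of growing them with torch.cat each iteration as the old code did. Reading a slice with ordinary indexing would allocate a fresh full-size zero gradient inside every backward call; Get instead writes the incoming gradient into the preallocated grad_chart buffer and hands that buffer back, presumably to keep backward allocation-free. A minimal usage sketch, assuming the Get class from this commit is in scope (the tensor names and sizes here are illustrative):

import torch

N = 4
chart = torch.zeros(N, N, requires_grad=True)
grad_chart = torch.zeros(N, N)  # preallocated buffer reused by backward

# Read chart[:2, 1:3] through Get so the gradient lands in grad_chart.
idx = (slice(None, 2), slice(1, 3))
out = Get.apply(chart, grad_chart, idx)
out.sum().backward()

print(grad_chart)                           # ones in [:2, 1:3], zeros elsewhere
print(torch.equal(chart.grad, grad_chart))  # True: the buffer is the gradient

Note that backward assigns into grad_chart with = rather than +=; the commented-out Set class below it shows the accumulating variant was also being considered.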
@@ -17,39 +51,48 @@ def _check_potentials(self, edge, lengths=None):
     def _dp(self, scores, lengths=None, force_grad=False):
         semiring = self.semiring
         scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)
-        beta = self._make_chart(2, (batch, N, 1), scores, force_grad)
-        zero = self._make_chart(1, (batch, N, 1), scores, force_grad)[0]
+        beta = self._make_chart(2, (batch, N, N), scores, force_grad)
+        grad_beta = self._make_chart(2, (batch, N, N), scores, force_grad)
+        grad_beta[A].fill_(0.0)
+        grad_beta[B].fill_(0.0)
         L_DIM, R_DIM = 2, 3
 
         # Initialize
         reduced_scores = semiring.sum(scores)
         term = reduced_scores.diagonal(0, L_DIM, R_DIM)
         ns = torch.arange(N)
         beta[A][:, :, ns, 0] = term
-        beta[B][:, :, ns, 0] = term
-        zero.requires_grad_(False)
 
-        def c(a, b):
-            return torch.cat([a, b], dim=-1)
-        def p(a, b):
-            return torch.cat([a, b], dim=2)
+        beta[B][:, :, ns, N-1] = term
 
+        ai = slice(None)
         # Run
         for w in range(1, N):
-            Y = beta[A][:, :, : N - w, :]
-            Z = beta[B][:, :, w:, :]
-            z = zero[:, :, : w]
+            # Y = beta[A][:, :, : N - w, :]
+            # Z = beta[B][:, :, w:, :]
+            # z = zero[:, :, : w]
+            # score = reduced_scores.diagonal(w, L_DIM, R_DIM)
+            # new = semiring.times(semiring.dot(Y, Z), score).unsqueeze(-1)
+            # print(new.shape, z.shape, beta[A].shape)
+            # beta[A] = c(beta[A], p(new, z))
+            # beta[B] = c(p(z, new), beta[B])
+
+            Y = Get.apply(beta[A], grad_beta[A],
+                          (ai, ai, slice(None, N - w), slice(None, w)))
+            Z = Get.apply(beta[B], grad_beta[B],
+                          (ai, ai, slice(w, None), slice(N-w, None)))
             score = reduced_scores.diagonal(w, L_DIM, R_DIM)
-            new = semiring.times(semiring.dot(Y, Z), score).unsqueeze(-1)
-            beta[A] = c(beta[A], p(new, z))
-            beta[B] = c(p(z, new), beta[B])
+            new = semiring.times(semiring.dot(Y, Z), score)
+            # Set.apply(beta[A], grad_beta[A], (ai, ai, slice(None, N - w), w), new)
+
+            beta[A][(ai, ai, slice(None, N - w), w)] = new
+            # beta[A][:, :, : N - w, w] =
+            # beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, :N-w, w]
+
+            beta[B][:, :, w:N, N - w - 1:N-w] = beta[A][(ai, ai, slice(None, N-w), slice(w, w+1))]
+            # Get.apply(beta[A], grad_beta[A],
+            #           (ai, ai, slice(None, N-w), slice(w, w+1)))
 
-            # Y = beta[A][:, :, : N - w, :w]
-            # Z = beta[B][:, :, w:, N - w :]
-            # score = reduced_scores.diagonal(w, L_DIM, R_DIM)
-
-            # beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), score)
-            # beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]
 
         final = beta[A][:, :, 0]
         log_Z = final[:, torch.arange(batch), lengths - 1]
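
For reference, the band layout the loop maintains: beta[A][:, :, i, w] holds the inside score of the span starting at position i with width w, while beta[B] stores the same values right-aligned by end position, so the span ending at j with width w lives at beta[B][:, :, j, N - 1 - w]. That alignment is what lets each width's sum over split points be a single contiguous semiring.dot. A stripped-down sketch of the same recursion in plain log space, dropping torch-struct's semiring and batch dimensions (the function and argument names are illustrative):

import torch

def inside_log_partition(span_scores):
    # span_scores[i, j]: log-potential of span [i..j], already summed
    # over nonterminals (the role of reduced_scores above).
    N = span_scores.size(0)
    A = torch.full((N, N), -1e9)  # A[i, w]: span starting at i, width w
    B = torch.full((N, N), -1e9)  # B[j, N - 1 - w]: same value, keyed by end j
    A[:, 0] = span_scores.diagonal()
    B[:, N - 1] = span_scores.diagonal()
    for w in range(1, N):
        # Left children start at i with widths 0..w-1; right children end
        # at i + w with the complementary widths. The right-aligned B makes
        # the two slices line up, so the split sum is one reduction.
        Y = A[: N - w, :w]   # Y[i, k] = score of span (i, i + k)
        Z = B[w:, N - w:]    # Z[i, k] = score of span (i + k + 1, i + w)
        new = torch.logsumexp(Y + Z, dim=-1) + span_scores.diagonal(w)
        A[: N - w, w] = new
        B[w:, N - w - 1] = new
    return A[0, N - 1]  # full-span inside score, i.e. log Z

The two writes at the bottom of the loop mirror the beta[A] assignment and the beta[B][:, :, w:N, N - w - 1:N-w] copy in the commit, and the return value corresponds to final[:, torch.arange(batch), lengths - 1] for full-length sentences.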