Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 19, 2019
1 parent bff0fec commit 1e7e0aa
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 41 deletions.
73 changes: 33 additions & 40 deletions torch_struct/cky_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,20 @@ def forward(ctx, chart, grad_chart, indices):
@staticmethod
def backward(ctx, grad_output):
grad_chart, = ctx.saved_tensors
grad_chart[ctx.indices] = grad_output
grad_chart[ctx.indices] += grad_output
return grad_chart, None, None

# class Set(torch.autograd.Function):
# @staticmethod
# def forward(ctx, chart, grad_chart, indices, vals):
# chart[indices] = vals
# ctx.save_for_backward(grad_chart)
# ctx.indices = indices
# return None
class Set(torch.autograd.Function):
@staticmethod
def forward(ctx, chart, indices, vals):
chart[indices] = vals
ctx.indices = indices
return chart

# @staticmethod
# def backward(ctx, grad_output):
# grad_chart, = ctx.saved_tensors
# # grad_chart[ctx.indices] += grad_output
# # size, = ctx.saved_tensors
# # grad_chart = torch.zeros(*size.shape)
# grad_chart[ctx.indices] = grad_output
# return grad_chart, None, None
@staticmethod
def backward(ctx, grad_output):
z = grad_output[ctx.indices]
return None, None, z


class CKY_CRF(_Struct):
Expand Down Expand Up @@ -64,38 +59,36 @@ def _dp(self, scores, lengths=None, force_grad=False):
beta[A][:, :, ns, 0] = term
beta[B][:, :, ns, N-1] = term

ai = slice(None)
def ind(pos, width):
return (I, I, pos, width)

I = slice(None)
# Run
print("hellO")
for w in range(1, N):
# Y = beta[A][:, :, : N - w, :]
# Z = beta[B][:, :, w:, :]
# z = zero[:, :, : w]
# score = reduced_scores.diagonal(w, L_DIM, R_DIM)
# new = semiring.times(semiring.dot(Y, Z), score).unsqueeze(-1)
# print(new.shape, z.shape, beta[A].shape)
# beta[A] = c(beta[A], p(new, z))
# beta[B] = c(p(z, new), beta[B])
left = slice(None, N-w)
right = slice(w, None)

Y = Get.apply(beta[A], grad_beta[A],
(ai, ai, slice(None, N - w), slice(None, w)))
Z = Get.apply(beta[B], grad_beta[B],
(ai, ai, slice(w, None), slice(N-w, None)))
score = reduced_scores.diagonal(w, L_DIM, R_DIM)
new = semiring.times(semiring.dot(Y, Z), score)
# Set.apply(beta[A], grad_beta[A], (ai, ai, slice(None, N - w), w), new)
# beta[B][(ai, ai, slice(w,N), N - w - 1)] = \
# new

# beta[A][(ai, ai, slice(None, N-w), slice(w, w+1))]

beta[A][(ai, ai, slice(None, N - w), w)] = new
# Get.apply(beta[A], grad_beta[A],
# (ai, ai, slice(None, N-w), slice(w, w+1)))
# beta[A][(ai, ai, slice(None, N - w), w)] = new
# beta[A][:, :, : N - w, w] =
# beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, :N-w, w]

beta[B][:, :, w:N, N - w - 1:N-w] = beta[A][(ai, ai, slice(None, N-w), slice(w, w+1))]
# Get.apply(beta[A], grad_beta[A],
# (ai, ai, slice(None, N-w), slice(w, w+1)))


Y = Get.apply(beta[A], grad_beta[A],
ind(left, slice(None, w)))
Z = Get.apply(beta[B], grad_beta[B],
ind(right, slice(N-w, None)))
score = reduced_scores.diagonal(w, L_DIM, R_DIM)
new = semiring.times(semiring.dot(Y, Z), score)
beta[A] = Set.apply(beta[A], ind(left, w), new)
beta[B] = Set.apply(beta[B], ind(right, N - w - 1), new)

final = beta[A][:, :, 0]
final = Get.apply(beta[A], grad_beta[A], ind(0, I))
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores], beta

Expand Down
2 changes: 1 addition & 1 deletion torch_struct/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_parts_from_sequence(data, seed):
@settings(max_examples=50, deadline=None)
def test_generic_lengths(data, seed):
model = data.draw(
sampled_from([Alignment, LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
sampled_from([CKY_CRF])#, Alignment, LinearChain, SemiMarkov, CKY, DepTree])
)
struct = model()
torch.manual_seed(seed)
Expand Down

0 comments on commit 1e7e0aa

Please sign in to comment.