Skip to content

Commit

Permalink
Faster 0-th order CKY and other clean up (#21)
Browse files Browse the repository at this point in the history
* .

* .

* .

* .
  • Loading branch information
srush committed Oct 15, 2019
1 parent a938e44 commit 4c4c01c
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 177 deletions.
50 changes: 18 additions & 32 deletions torch_struct/cky_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,52 +7,38 @@
class CKY_CRF(_Struct):
def _dp(self, scores, lengths=None, force_grad=False):
semiring = self.semiring
ssize = semiring.size()
batch, N, _, NT = scores.shape
scores = semiring.convert(scores)
if lengths is None:
lengths = torch.LongTensor([N] * batch)
beta = self._make_chart(2, (batch, N, N, NT), scores, force_grad)
span = self._make_chart(N, (batch, N, NT), scores, force_grad)
rule_use = [
self._make_chart(1, (batch, N - w, NT), scores, force_grad)[0]
for w in range(N)
]
scores.requires_grad_(True)
beta = self._make_chart(2, (batch, N, N), scores, force_grad)

# Initialize
reduced_scores = semiring.sum(scores)
ns = torch.arange(N)
rule_use[0][:] = scores[:, :, ns, ns]
rule_use[0].requires_grad_(True)
beta[A][:, :, ns, 0] = rule_use[0]
beta[B][:, :, ns, N - 1] = rule_use[0]
rule_use = reduced_scores[:, :, ns, ns]
beta[A][:, :, ns, 0] = rule_use
beta[B][:, :, ns, N - 1] = rule_use

# Run
for w in range(1, N):
Y = beta[A][:, :, : N - w, :w].view(ssize, batch, N - w, 1, w, NT, 1)
Z = beta[B][:, :, w:, N - w :].view(ssize, batch, N - w, 1, w, 1, NT)
f = torch.arange(N - w), torch.arange(w, N)
X = scores[:, :, f[0], f[1]].view(ssize, batch, N - w, NT)
merge = semiring.times(Y, Z).view(ssize, batch, N - w, 1, -1)
rule_use[w][:] = semiring.times(semiring.sum(merge), X)
Y = beta[A][:, :, : N - w, :w]
Z = beta[B][:, :, w:, N - w :]
f = torch.arange(N - w)
X = reduced_scores[:, :, f, f + w]

span[w] = rule_use[w].view(ssize, batch, N - w, NT)
beta[A][:, :, : N - w, w] = span[w]
beta[A][:, :, : N - w, w] = semiring.times(
semiring.sum(semiring.times(Y, Z)), X
)
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]

final = semiring.sum(beta[A][:, :, 0, :])
final = beta[A][:, :, 0]
log_Z = torch.stack([final[:, b, l - 1] for b, l in enumerate(lengths)], dim=1)
return log_Z, rule_use, beta
return log_Z, [scores], beta

def _arrange_marginals(self, grads):
semiring = self.semiring
_, batch, N, NT = grads[0].shape
rules = torch.zeros(
batch, N, N, NT, dtype=grads[0].dtype, device=grads[0].device
)

for w, grad in enumerate(grads):
grad = semiring.unconvert(grad)
f = torch.arange(N - w), torch.arange(w, N)
rules[:, f[0], f[1]] = self.semiring.unconvert(grad)
return rules
return self.semiring.unconvert(grads[0])

def enumerate(self, scores):
semiring = self.semiring
Expand Down
245 changes: 109 additions & 136 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ class DepTree(_Struct):
arc_scores : b x N x N arc scores with root scores on diagonal.
"""

def _dp(self, arc_scores, lengths=None, force_grad=False):
def _dp(self, arc_scores_in, lengths=None, force_grad=False):
semiring = self.semiring
arc_scores = _convert(arc_scores)
arc_scores = _convert(arc_scores_in)
arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)

arc_scores.requires_grad_(True)
DIRS = 2
alpha = [
self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
Expand All @@ -62,11 +64,6 @@ def stack(a, b):
def sstack(a):
return torch.stack([a, a], dim=1)

arcs = [
self._make_chart(1, (DIRS, batch, N - k), arc_scores, force_grad)[0]
for k in range(N)
]

# Inside step. assumes first token is root symbol
semiring.one_(alpha[A][C][:, :, :, :, 0].data)
semiring.one_(alpha[B][C][:, :, :, :, -1].data)
Expand Down Expand Up @@ -97,18 +94,12 @@ def tb(a):
ACL, ACR = AC2.unbind(dim=1)
BCL, BCR = BC2.unbind(dim=1)
start = semiring.dot(BCL, ACR)
# if k == 1:
arcs[k] = stack(
semiring.times(start, arc_scores[:, :, f[1], f[0]]),
semiring.times(start, arc_scores[:, :, f[0], f[1]]),
)
arcsL, arcR = arcs[k].unbind(dim=1)
# else:
# arcs[k] = stack(semiring.times(start), #, arc_scores[:, f[1], f[0]]),
# semiring.times(start)) #, arc_scores[:, f[0], f[1]]))

arcsL = semiring.times(start, arc_scores[:, :, f[1], f[0]])
arcsR = semiring.times(start, arc_scores[:, :, f[0], f[1]])

AIR2 = torch.cat(
[torch.narrow(AIR, 2, 0, N - k), arcR.unsqueeze(-1)], dim=3
[torch.narrow(AIR, 2, 0, N - k), arcsR.unsqueeze(-1)], dim=3
)
BIL2 = torch.cat(
[arcsL.unsqueeze(-1), torch.narrow(BIL, 2, 1, N - k)], dim=3
Expand All @@ -121,8 +112,7 @@ def tb(a):
AIR = AIR2
BIL = BIL2
v = torch.stack([ends[l][:, i] for i, l in enumerate(lengths)], dim=1)
# v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
return (v, arcs[1:], alpha)
return (v, [arc_scores], alpha)

def _check_potentials(self, arc_scores, lengths=None):
semiring = self.semiring
Expand All @@ -138,124 +128,8 @@ def _check_potentials(self, arc_scores, lengths=None):

return arc_scores, batch, N, lengths

# def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):

# # This function is super complicated and was just too slow to include
# semiring = self.semiring
# arc_scores = _convert(arc_scores)
# batch, N, lengths = self._check_potentials(arc_scores, lengths)
# DIRS = 2

# alpha = [
# self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
# for _ in range(2)
# ]

# def stack(a, b, dim=-1):
# return torch.stack([a, b], dim=dim)

# def sstack(a):
# return torch.stack([a, a], dim=-1)

# for k in range(N - 1, -1, -1):
# # Initialize
# if N - k - 1 > 0:
# # R completes
# # I -> C* C
# # I -> C* C
# # C -> I C*
# a = semiring.sum(
# semiring.times(
# *roll2(
# stack(
# stack(alpha[A][I][R], alpha[A][I][L]),
# sstack(alpha_in[B][C][R]),
# dim=0,
# ),
# stack(
# sstack(alpha_in[A][C][L]),
# stack(alpha[B][I][L], alpha[B][I][R]),
# dim=0,
# ),
# N,
# k,
# 1,
# )
# ).view(2, batch, N - k - 1, -1),
# dim=-1,
# )
# alpha[A][C][L, :, 1 : N - k, k] = a[1]
# alpha[A][C][R, :, : N - k - 1, k] = a[0]

# for b, l in enumerate(lengths):
# if l == k:
# alpha[A][C][R, b, 0, l] = semiring.one()
# alpha[B][C][R, b, l, N - l - 1] = semiring.one()

# c = semiring.sum(
# semiring.times(
# *roll2(
# stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
# stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
# N,
# k,
# 0,
# )
# )
# )
# alpha[A][C][:, :, : N - k, k] = semiring.plus(
# alpha[A][C][:, :, : N - k, k], c
# )

# # Compute reverses.
# alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]

# if k > 0:
# f = torch.arange(N - k), torch.arange(k, N)
# alpha[A][I][:, :, : N - k, k] = semiring.dot(
# stack(
# arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
# ).unsqueeze(-1),
# *roll2(
# stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
# stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
# N,
# k,
# )
# )
# alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]

# v = alpha[A][C][R, :, 0, 0]
# left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
# right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
# ret = torch.zeros(batch, N, N, dtype=left.dtype)
# for k in torch.arange(N):
# f = torch.arange(N - k), torch.arange(k, N)
# ret[:, f[1], k] = left[:, k, f[0]]
# ret[:, k, f[1]] = right[:, k, f[0]]

# ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
# return _unconvert(ret)

def _arrange_marginals(self, grads):
_, batch, N = grads[0][0].shape
N = N + 1

ret = torch.zeros(
batch, N, N, dtype=grads[0][0].dtype, device=grads[0][0].device
)
# for k in torch.arange(N):
# f = torch.arange(N - k), torch.arange(k, N)
# ret[:, f[1], k] = grad[L][:, k, f[0]]
# ret[:, k, f[1]] = grad[L][:, k, f[0]]

# ret = torch.zeros(batch, N, N).cpu()
for k, grad in enumerate(grads, 1):
grad = self.semiring.unconvert(grad)
f = torch.arange(N - k), torch.arange(k, N)
ret[:, f[0], f[1]] = grad[R]
ret[:, f[1], f[0]] = grad[L]
return _unconvert(ret)
return _unconvert(self.semiring.unconvert(grads[0]))

@staticmethod
def to_parts(sequence, extra=None, lengths=None):
Expand Down Expand Up @@ -446,3 +320,102 @@ def _is_projective(parse):
):
return False
return True

# def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):

# # This function is super complicated and was just too slow to include
# semiring = self.semiring
# arc_scores = _convert(arc_scores)
# batch, N, lengths = self._check_potentials(arc_scores, lengths)
# DIRS = 2

# alpha = [
# self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
# for _ in range(2)
# ]

# def stack(a, b, dim=-1):
# return torch.stack([a, b], dim=dim)

# def sstack(a):
# return torch.stack([a, a], dim=-1)

# for k in range(N - 1, -1, -1):
# # Initialize
# if N - k - 1 > 0:
# # R completes
# # I -> C* C
# # I -> C* C
# # C -> I C*
# a = semiring.sum(
# semiring.times(
# *roll2(
# stack(
# stack(alpha[A][I][R], alpha[A][I][L]),
# sstack(alpha_in[B][C][R]),
# dim=0,
# ),
# stack(
# sstack(alpha_in[A][C][L]),
# stack(alpha[B][I][L], alpha[B][I][R]),
# dim=0,
# ),
# N,
# k,
# 1,
# )
# ).view(2, batch, N - k - 1, -1),
# dim=-1,
# )
# alpha[A][C][L, :, 1 : N - k, k] = a[1]
# alpha[A][C][R, :, : N - k - 1, k] = a[0]

# for b, l in enumerate(lengths):
# if l == k:
# alpha[A][C][R, b, 0, l] = semiring.one()
# alpha[B][C][R, b, l, N - l - 1] = semiring.one()

# c = semiring.sum(
# semiring.times(
# *roll2(
# stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
# stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
# N,
# k,
# 0,
# )
# )
# )
# alpha[A][C][:, :, : N - k, k] = semiring.plus(
# alpha[A][C][:, :, : N - k, k], c
# )

# # Compute reverses.
# alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]

# if k > 0:
# f = torch.arange(N - k), torch.arange(k, N)
# alpha[A][I][:, :, : N - k, k] = semiring.dot(
# stack(
# arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
# ).unsqueeze(-1),
# *roll2(
# stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
# stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
# N,
# k,
# )
# )
# alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]

# v = alpha[A][C][R, :, 0, 0]
# left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
# right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
# ret = torch.zeros(batch, N, N, dtype=left.dtype)
# for k in torch.arange(N):
# f = torch.arange(N - k), torch.arange(k, N)
# ret[:, f[1], k] = left[:, k, f[0]]
# ret[:, k, f[1]] = right[:, k, f[0]]

# ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
# return _unconvert(ret)
Loading

0 comments on commit 4c4c01c

Please sign in to comment.