
Commit

srush committed Sep 7, 2019
1 parent fa30814 commit 068880d
Showing 7 changed files with 165 additions and 137 deletions.
74 changes: 40 additions & 34 deletions torch_struct/cky.py
@@ -5,6 +5,7 @@

A, B = 0, 1


class DPManual2(Function):
@staticmethod
def forward(ctx, obj, terms, rules, roots, lengths):
@@ -21,12 +22,12 @@ def forward(ctx, obj, terms, rules, roots, lengths):
def backward(ctx, grad_v):
terms, rules, roots = ctx.saved_tensors
with torch.no_grad():
marginals = ctx.obj._dp_backward((terms, rules, roots),
ctx.lengths, ctx.alpha, ctx.v)
marginals = ctx.obj._dp_backward(
(terms, rules, roots), ctx.lengths, ctx.alpha, ctx.v
)
return None, marginals[0], marginals[1].sum(1).sum(1), marginals[2], None



class CKY(_Struct):
def sum(self, scores, lengths=None, force_grad=False, _autograd=False):
"""
@@ -41,7 +42,7 @@ def sum(self, scores, lengths=None, force_grad=False, _autograd=False):
v: b tensor of total sum
spans: list of N, b x N x (NT+t)
"""
if _autograd or not self.semiring is LogSemiring:
if _autograd or self.semiring is not LogSemiring:
return self._dp(scores, lengths)[0]
else:
return DPManual2.apply(self, *scores, lengths)
@@ -95,66 +96,71 @@ def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
top = self._make_chart(1, (batch, NT), rules, force_grad)[0]
term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]


ssum = semiring.sum
st = semiring.times
X_Y_Z = rules.view(batch, 1, NT, S, S)

for w in range(N-1, -1, -1):
for w in range(N - 1, -1, -1):
for b, l in enumerate(lengths):
beta[A][b, 0, l-1, :NT] = roots[b]
beta[B][b, l-1, N-(l), :NT] = roots[b]
beta[A][b, 0, l - 1, :NT] = roots[b]
beta[B][b, l - 1, N - (l), :NT] = roots[b]

# LEFT
# all bigger on the left.
X = beta[A][:, :N-w-1, w+1:, :NT].view(batch, N-w-1, N-w-1, NT, 1, 1)
Z = alpha_in[A][:, w+1:N, 0:N-w-1].view(batch, N-w-1, N-w-1, 1, 1, S)
X = beta[A][:, : N - w - 1, w + 1 :, :NT].view(
batch, N - w - 1, N - w - 1, NT, 1, 1
)
Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view(
batch, N - w - 1, N - w - 1, 1, 1, S
)
t = st(ssum(st(X, Z), dim=2), X_Y_Z)
# sum out x and y
span_l[w] = ssum(ssum(t, dim =-3), dim=-1)
span_l[w] = ssum(ssum(t, dim=-3), dim=-1)

# RIGHT
X = beta[B][:, w+1:, :N-1-w, :NT].view(batch, N-w-1, N-w-1, NT, 1, 1)
Y = alpha_in[B][:, :N-w-1, w+1:, :].view(batch, N-w-1, N-w-1, 1, S, 1)
X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view(
batch, N - w - 1, N - w - 1, NT, 1, 1
)
Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view(
batch, N - w - 1, N - w - 1, 1, S, 1
)
t = st(ssum(st(X, Y), dim=2), X_Y_Z)

span_r[w] = ssum(ssum(t, dim=-3), dim=-2)

beta[A][:, :N-w-1, w, :] = span_l[w]
beta[A][:, 1:N-w, w, :] = ssum(torch.stack([span_r[w],
beta[A][:, 1:N-w, w, :]]), dim=0)
beta[B][:, w:, N-w-1, :] = beta[A][:, :N - w, w, :]

beta[A][:, : N - w - 1, w, :] = span_l[w]
beta[A][:, 1 : N - w, w, :] = ssum(
torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0
)
beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :]

term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms)
term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0]
for n in range(N):
term_marginals[:, n] = semiring.div_exp(term_use[:, n],
v.view(batch, 1))
term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1))

root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0]
for b in range(batch):
root_marginals[b] = semiring.div_exp(st(alpha_in[A][b, 0, lengths[b]-1, :NT], roots[b]),
v[b].view(1))
edge_marginals = self._make_chart(1, (batch, N, N, NT, S, S), terms, force_grad=False)[0]
root_marginals[b] = semiring.div_exp(
st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1)
)
edge_marginals = self._make_chart(
1, (batch, N, N, NT, S, S), terms, force_grad=False
)[0]
edge_marginals.fill_(0)
for w in range(1, N):
Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
score = semiring.times(
semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z
score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z)
score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1))
edge_marginals[:, : N - w, w - 1] = semiring.div_exp(
score, v.view(batch, 1, 1, 1, 1)
)
score = st(score, beta[A][:, :N-w, w, :NT].view(batch, N-w, NT, 1, 1))
edge_marginals[:, :N-w, w-1] = semiring.div_exp(score,
v.view(batch, 1, 1, 1, 1))
edge_marginals = edge_marginals.transpose(1, 2)


return (term_marginals, edge_marginals, root_marginals)

return (term_marginals, edge_marginals, root_marginals)

def marginals(self, scores, lengths=None, _autograd=False):
"""
@@ -177,7 +183,7 @@ def marginals(self, scores, lengths=None, _autograd=False):
v, (term_use, rule_use, top), alpha = self._dp(
scores, lengths=lengths, force_grad=True
)
if _autograd or not self.semiring is LogSemiring:
if _autograd or self.semiring is not LogSemiring:
marg = torch.autograd.grad(
v.sum(dim=0),
tuple(rule_use) + (top, term_use),
@@ -193,7 +199,7 @@ def marginals(self, scores, lengths=None, _autograd=False):
assert marg[-2].shape == (batch, NT)
return (marg[-1], rules, marg[-2])
else:
return self._dp_backward(edge, lengths, alpha, v)
return self._dp_backward(scores, lengths, alpha, v)

@staticmethod
def to_parts(spans, extra, lengths=None):
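
Not part of the commit: a minimal, self-contained sketch of the pattern the cky.py changes above rely on — dispatching between ordinary autograd and a custom torch.autograd.Function whose backward returns hand-computed marginals (the role DPManual2 plays). The real inside DP is replaced here by a toy logsumexp partition, and the class and function names are hypothetical.

import torch
from torch.autograd import Function


class _ManualBackward(Function):
    # Hypothetical stand-in for DPManual2: forward computes a partition-like
    # value, backward returns hand-computed marginals instead of letting
    # autograd trace the forward pass.

    @staticmethod
    def forward(ctx, scores):
        ctx.save_for_backward(scores)
        return torch.logsumexp(scores.view(scores.shape[0], -1), dim=-1)

    @staticmethod
    def backward(ctx, grad_v):
        (scores,) = ctx.saved_tensors
        with torch.no_grad():
            # For logsumexp the marginals are just softmax probabilities.
            marg = torch.softmax(scores.view(scores.shape[0], -1), dim=-1)
        return grad_v.view(-1, 1, 1) * marg.view_as(scores)


def total_score(scores, use_autograd=True):
    # Mirrors the dispatch in CKY.sum: plain autograd by default,
    # the manual-backward Function otherwise.
    if use_autograd:
        return torch.logsumexp(scores.view(scores.shape[0], -1), dim=-1)
    return _ManualBackward.apply(scores)


scores = torch.randn(2, 4, 4, requires_grad=True)
total_score(scores, use_autograd=False).sum().backward()
print(scores.grad.shape)  # (2, 4, 4): the manually computed marginals
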
139 changes: 75 additions & 64 deletions torch_struct/deptree.py
@@ -1,7 +1,7 @@
import torch
import itertools
from .helpers import _Struct, DPManual, roll
from .semirings import LogSemiring
from .helpers import _Struct, roll


def _convert(logits):
"move root arcs from diagonal"
@@ -36,7 +36,7 @@ def _unconvert(logits):
class DepTree(_Struct):
"""
A projective dependency CRF.
Parameters:
arc_scores : b x N x N arc scores with root scores on diagonal.
"""
@@ -45,8 +45,9 @@ def _dp(self, arc_scores, lengths=None, force_grad=False):
semiring = self.semiring
arc_scores = _convert(arc_scores)
batch, N, lengths = self._check_potentials(arc_scores, lengths)

DIRS = 2

def stack(a, b):
return torch.stack([a, b])

@@ -87,14 +88,9 @@ def sstack(a):
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]
v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
print(v)
return (
v,
arcs[1:],
alpha
)


def _check_potentials(self, arc_scores, lengths = None):
return (v, arcs[1:], alpha)

def _check_potentials(self, arc_scores, lengths=None):
semiring = self.semiring
batch, N, N2 = arc_scores.shape
assert N == N2, "Non-square potentials"
@@ -106,8 +102,8 @@ def _check_potentials(self, arc_scores, lengths = None):
arc_scores[b, :, lengths[b] + 1 :] = semiring.zero()

return batch, N, lengths
def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False):

def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):

# This function is super complicated.
semiring = self.semiring
@@ -119,87 +115,102 @@ def _dp_backward(self, arc_scores, lengths, alpha_in, v = None, force_grad=False
self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
for _ in range(2)
]
arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad)

def stack(a, b):
return torch.stack([a, b], dim=-1)

def sstack(a):
return torch.stack([a, a], dim=-1)

for k in range(N-1, -1, -1):
for k in range(N - 1, -1, -1):
# Initialize
for b, l in enumerate(lengths):
alpha[A][C][R, b, 0, l] = semiring.one()
alpha[B][C][R, b, l, N-l-1] = semiring.one()

alpha[B][C][R, b, l, N - l - 1] = semiring.one()

# R completes
#I -> C* C
#I -> C* C
#C -> I C*
a = semiring.dot(*roll(stack(alpha[A][I][R], alpha[A][I][L]),
sstack(alpha_in[A][C][L]), N, k, 1))
# I -> C* C
# I -> C* C
# C -> I C*
a = semiring.dot(
*roll(
stack(alpha[A][I][R], alpha[A][I][L]),
sstack(alpha_in[A][C][L]),
N,
k,
1,
)
)

c = semiring.dot(*roll(alpha_in[B][I][R],
alpha[B][C][R], N, k, 0))
c = semiring.dot(*roll(alpha_in[B][I][R], alpha[B][C][R], N, k, 0))

alpha[A][C][R, :, :N-k-1, k] = semiring.plus(semiring.sum(a),
alpha[A][C][R, :, :N-k-1, k])
alpha[A][C][R, :, : N - k - 1, k] = semiring.plus(
semiring.sum(a), alpha[A][C][R, :, : N - k - 1, k]
)

alpha[A][C][R][:, : N - k, k] = semiring.plus(
alpha[A][C][R][:, : N - k, k], c
)

alpha[A][C][R][:, :N-k, k] = \
semiring.plus(alpha[A][C][R][:, :N-k, k], c)

# L completes
#I -> C* C
#I -> C* C
#C -> I C*
a = semiring.dot(*roll(sstack(alpha_in[B][C][R]),
stack(alpha[B][I][L], alpha[B][I][R]),
N, k, 1))


c = semiring.dot(*roll(alpha[A][C][L],
alpha_in[A][I][L], N, k, 0))

alpha[A][C][L, :, 1:N-k, k] = \
semiring.plus(semiring.sum(a), alpha[A][C][L, :, 1:N-k, k])
alpha[A][C][L][:, :N-k, k] = \
semiring.plus(c, alpha[A][C][L][:, :N-k, k])
# I -> C* C
# I -> C* C
# C -> I C*
a = semiring.dot(
*roll(
sstack(alpha_in[B][C][R]),
stack(alpha[B][I][L], alpha[B][I][R]),
N,
k,
1,
)
)

c = semiring.dot(*roll(alpha[A][C][L], alpha_in[A][I][L], N, k, 0))

alpha[A][C][L, :, 1 : N - k, k] = semiring.plus(
semiring.sum(a), alpha[A][C][L, :, 1 : N - k, k]
)
alpha[A][C][L][:, : N - k, k] = semiring.plus(
c, alpha[A][C][L][:, : N - k, k]
)

# Compute reverses.
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]

if k > 0:
f = torch.arange(N-k), torch.arange(k, N)
f = torch.arange(N - k), torch.arange(k, N)

# Incomplete
alpha[A][I][R][:, :N-k, k] = semiring.dot(
alpha[A][I][R][:, : N - k, k] = semiring.dot(
arc_scores[:, f[0], f[1]].unsqueeze(-1),
*roll(alpha[A][C][R],
alpha_in[A][C][R], N, k))
*roll(alpha[A][C][R], alpha_in[A][C][R], N, k)
)

#C -> C I
alpha[A][I][L][:, :N-k, k] = semiring.dot(
# C -> C I
alpha[A][I][L][:, : N - k, k] = semiring.dot(
arc_scores[:, f[1], f[0]].unsqueeze(-1),
*roll(alpha_in[B][C][L],
alpha[B][C][L], N, k))
*roll(alpha_in[B][C][L], alpha[B][C][L], N, k)
)

# Compute reverses
alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]

v = alpha[A][C][R, :, 0, 0]
left = semiring.times(alpha[A][I][L, :, :, :],
alpha_in[A][I][L, :, :, :])
right = semiring.times(alpha[A][I][R, :, :, :],
alpha_in[A][I][R, :, :, :])
left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])

ret = torch.zeros(batch, N, N)
for k in range(N):
for d in range(N-k):
ret[:, k+d, k] = semiring.div_exp(left[:, k, d] - arc_scores[:, k+d, k], v.view(batch))
ret[:, k, k+d] = semiring.div_exp(right[:, k, d]- arc_scores[:, k, k+d], v.view(batch))
for d in range(N - k):
ret[:, k + d, k] = semiring.div_exp(
left[:, k, d] - arc_scores[:, k + d, k], v.view(batch)
)
ret[:, k, k + d] = semiring.div_exp(
right[:, k, d] - arc_scores[:, k, k + d], v.view(batch)
)
return _unconvert(ret)


def _arrange_marginals(self, grads):
batch, N = grads[0][0].shape
N = N + 1
@@ -209,7 +220,7 @@ def _arrange_marginals(self, grads):
ret[:, f[0], f[1]] = grad[R].cpu()
ret[:, f[1], f[0]] = grad[L].cpu()
return _unconvert(ret)

@staticmethod
def to_parts(sequence, extra=None, lengths=None):
"""
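
Not part of the commit: both files above normalize marginals with semiring.div_exp(score, v). Assuming LogSemiring.div_exp(a, b) computes exp(a - b), the normalization is simply exponentiating the difference between a log-space score and the log partition function. A small sketch with hypothetical names:

import torch

def log_marginals(log_scores, log_partition):
    # exp(score - logZ): divide out the partition function in log space.
    shape = (-1,) + (1,) * (log_scores.dim() - 1)
    return torch.exp(log_scores - log_partition.view(shape))

scores = torch.randn(2, 5, 5)
logZ = torch.logsumexp(scores.view(2, -1), dim=-1)
p = log_marginals(scores, logZ)
print(p.view(2, -1).sum(-1))  # each batch entry sums to ~1.0
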
