Skip to content

Commit

Permalink
Merge a6179d8 into 6f06de9
Browse files · Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 9, 2019
2 parents 6f06de9 + a6179d8 commit a2f89c3
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 148 deletions.
4 changes: 2 additions & 2 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def backward(ctx, grad_v):


class CKY(_Struct):
def sum(self, scores, lengths=None, force_grad=False, _autograd=False):
def sum(self, scores, lengths=None, force_grad=False, _autograd=True):
"""
Compute the inside pass of a CFG using CKY.
Expand Down Expand Up @@ -162,7 +162,7 @@ def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):

return (term_marginals, edge_marginals, root_marginals)

def marginals(self, scores, lengths=None, _autograd=False):
def marginals(self, scores, lengths=None, _autograd=True):
"""
Compute the marginals of a CFG using CKY.
Expand Down
246 changes: 141 additions & 105 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import itertools
from .helpers import _Struct, roll
from .helpers import _Struct, roll2


def _convert(logits):
Expand All @@ -10,22 +10,27 @@ def _convert(logits):
).type_as(logits.data)
new_logits.fill_(-1e9)
new_logits[:, 1:, 1:] = logits
for i in range(0, logits.size(1)):
new_logits[:, 0, i + 1] = logits[:, i, i]
new_logits[:, i + 1, i + 1] = -1e9

N = logits.size(1)
new_logits[:, 0, 1:] = logits[:, torch.arange(N), torch.arange(N)]
new_logits[:, torch.arange(1, N), torch.arange(1, N)] = -1e9
return new_logits


def _unconvert(logits):
"Move root arcs to diagonal"
new_logits = torch.zeros(
logits.size(0), logits.size(1) - 1, logits.size(2) - 1
).type_as(logits.data)
logits.size(0),
logits.size(1) - 1,
logits.size(2) - 1,
dtype=logits.dtype,
device=logits.device,
)

new_logits.fill_(-1e9)
new_logits[:, :, :] = logits[:, 1:, 1:]
for i in range(0, new_logits.size(1)):
new_logits[:, i, i] = logits[:, 0, i + 1]

N = new_logits.size(1)
new_logits[:, torch.arange(N), torch.arange(N)] = logits[:, 0, 1:]
return new_logits


Expand All @@ -48,46 +53,77 @@ def _dp(self, arc_scores, lengths=None, force_grad=False):

DIRS = 2

alpha = [
self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
for _ in range(2)
]

def stack(a, b):
return torch.stack([a, b])

def sstack(a):
return torch.stack([a, a])

alpha = [
self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
for _ in range(2)
arcs = [
self._make_chart(1, (DIRS, batch, N - k), arc_scores, force_grad)[0]
for k in range(N)
]
arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad)

# Inside step. assumes first token is root symbol
alpha[A][C][:, :, :, 0].data.fill_(semiring.one())
alpha[B][C][:, :, :, -1].data.fill_(semiring.one())
k = 0

AIR = alpha[A][I][R, :, : N - k, 1:k]
BIL = alpha[B][I][L, :, k:N, N - k : N - 1]
k = 1
AC2 = alpha[A][C][:, :, : N - k, :k]
BC2 = alpha[B][C][:, :, k:, N - k :]
AC, BC, AC_next = None, None, None

ends = [None]
for k in range(1, N):

def tf(a):
return torch.narrow(a, 2, 0, N - k)

def tb(a):
return torch.narrow(a, 2, 1, N - k)

f = torch.arange(N - k), torch.arange(k, N)
arcs[k] = semiring.dot(
sstack(alpha[A][C][R, :, : N - k, :k]),
sstack(alpha[B][C][L, :, k:, N - k :]),
stack(arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]]).unsqueeze(
-1
),
if k > 1:
AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=3)
if k > 1:
BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=3)

ACL, ACR = AC2.unbind()
BCL, BCR = BC2.unbind()
start = semiring.dot(BCL, ACR)
# if k == 1:
arcs[k] = stack(
semiring.times(start, arc_scores[:, f[1], f[0]]),
semiring.times(start, arc_scores[:, f[0], f[1]]),
)
alpha[A][I][:, :, : N - k, k] = arcs[k]
alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]
alpha[A][C][:, :, : N - k, k] = semiring.dot(
stack(
alpha[A][C][L, :, : N - k, :k],
alpha[A][I][R, :, : N - k, 1 : k + 1],
),
stack(
alpha[B][I][L, :, k:, N - k - 1 : N - 1],
alpha[B][C][R, :, k:, N - k :],
),
arcsL, arcR = arcs[k].unbind()
# else:
# arcs[k] = stack(semiring.times(start), #, arc_scores[:, f[1], f[0]]),
# semiring.times(start)) #, arc_scores[:, f[0], f[1]]))

AIR2 = torch.cat(
[torch.narrow(AIR, 1, 0, N - k), arcR.unsqueeze(-1)], dim=2
)
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]
v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
print(v)
BIL2 = torch.cat(
[arcsL.unsqueeze(-1), torch.narrow(BIL, 1, 1, N - k)], dim=2
)
AC_next = stack(semiring.dot(ACL, BIL2), semiring.dot(AIR2, BCR))

ends.append(AC_next[R, :, 0])
AC = AC2
BC = BC2
AIR = AIR2
BIL = BIL2
v = torch.stack([ends[l][i] for i, l in enumerate(lengths)])
# v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
return (v, arcs[1:], alpha)

def _check_potentials(self, arc_scores, lengths=None):
Expand Down Expand Up @@ -116,109 +152,109 @@ def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):
for _ in range(2)
]

def stack(a, b):
return torch.stack([a, b], dim=-1)
def stack(a, b, dim=-1):
return torch.stack([a, b], dim=dim)

def sstack(a):
return torch.stack([a, a], dim=-1)

for k in range(N - 1, -1, -1):
# Initialize
for b, l in enumerate(lengths):
alpha[A][C][R, b, 0, l] = semiring.one()
alpha[B][C][R, b, l, N - l - 1] = semiring.one()

# R completes
# I -> C* C
# I -> C* C
# C -> I C*
a = semiring.dot(
*roll(
stack(alpha[A][I][R], alpha[A][I][L]),
sstack(alpha_in[A][C][L]),
N,
k,
1,
if N - k - 1 > 0:
# R completes
# I -> C* C
# I -> C* C
# C -> I C*
a = semiring.sum(
semiring.times(
*roll2(
stack(
stack(alpha[A][I][R], alpha[A][I][L]),
sstack(alpha_in[B][C][R]),
dim=0,
),
stack(
sstack(alpha_in[A][C][L]),
stack(alpha[B][I][L], alpha[B][I][R]),
dim=0,
),
N,
k,
1,
)
).view(2, batch, N - k - 1, -1),
dim=-1,
)
)

c = semiring.dot(*roll(alpha_in[B][I][R], alpha[B][C][R], N, k, 0))

alpha[A][C][R, :, : N - k - 1, k] = semiring.plus(
semiring.sum(a), alpha[A][C][R, :, : N - k - 1, k]
)
alpha[A][C][L, :, 1 : N - k, k] = a[1]
alpha[A][C][R, :, : N - k - 1, k] = a[0]

alpha[A][C][R][:, : N - k, k] = semiring.plus(
alpha[A][C][R][:, : N - k, k], c
)

# L completes
# I -> C* C
# I -> C* C
# C -> I C*
a = semiring.dot(
*roll(
sstack(alpha_in[B][C][R]),
stack(alpha[B][I][L], alpha[B][I][R]),
N,
k,
1,
for b, l in enumerate(lengths):
if l == k:
alpha[A][C][R, b, 0, l] = semiring.one()
alpha[B][C][R, b, l, N - l - 1] = semiring.one()

c = semiring.sum(
semiring.times(
*roll2(
stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
N,
k,
0,
)
)
)

c = semiring.dot(*roll(alpha[A][C][L], alpha_in[A][I][L], N, k, 0))

alpha[A][C][L, :, 1 : N - k, k] = semiring.plus(
semiring.sum(a), alpha[A][C][L, :, 1 : N - k, k]
)
alpha[A][C][L][:, : N - k, k] = semiring.plus(
c, alpha[A][C][L][:, : N - k, k]
alpha[A][C][:, :, : N - k, k] = semiring.plus(
alpha[A][C][:, :, : N - k, k], c
)

# Compute reverses.
alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]

if k > 0:
f = torch.arange(N - k), torch.arange(k, N)

# Incomplete
alpha[A][I][R][:, : N - k, k] = semiring.dot(
arc_scores[:, f[0], f[1]].unsqueeze(-1),
*roll(alpha[A][C][R], alpha_in[A][C][R], N, k)
alpha[A][I][:, :, : N - k, k] = semiring.dot(
stack(
arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
).unsqueeze(-1),
*roll2(
stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
N,
k,
)
)

# C -> C I
alpha[A][I][L][:, : N - k, k] = semiring.dot(
arc_scores[:, f[1], f[0]].unsqueeze(-1),
*roll(alpha_in[B][C][L], alpha[B][C][L], N, k)
)

# Compute reverses
alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]

v = alpha[A][C][R, :, 0, 0]
left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
ret = torch.zeros(batch, N, N, dtype=left.dtype)
for k in torch.arange(N):
f = torch.arange(N - k), torch.arange(k, N)
ret[:, f[1], k] = left[:, k, f[0]]
ret[:, k, f[1]] = right[:, k, f[0]]

ret = torch.zeros(batch, N, N).type_as(left)
for k in range(N):
for d in range(N - k):
ret[:, k + d, k] = semiring.div_exp(
left[:, k, d] - arc_scores[:, k + d, k], v.view(batch)
)
ret[:, k, k + d] = semiring.div_exp(
right[:, k, d] - arc_scores[:, k, k + d], v.view(batch)
)
ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
return _unconvert(ret)

def _arrange_marginals(self, grads):
    """Assemble per-span arc gradients into a full (batch, N, N) marginal matrix.

    Parameters:
        grads: sequence where ``grads[k-1]`` holds the left/right arc
            gradients for spans of width ``k``; each entry is indexable by the
            module-level direction constants ``R`` and ``L`` and has shape
            ``(batch, N - k)`` — TODO confirm against caller.

    Returns:
        The result of ``_unconvert`` applied to the assembled matrix.
    """
    batch, N = grads[0][0].shape
    N = N + 1
    # Allocate on the gradients' own dtype/device; the old code forced
    # everything through `.cpu()`, which broke GPU use. (Scraped diff
    # interleaved both versions; this keeps the device-aware one and drops
    # the dead commented-out loop.)
    ret = torch.zeros(
        batch, N, N, dtype=grads[0][0].dtype, device=grads[0][0].device
    )
    for k, grad in enumerate(grads, 1):
        # Indices of all spans of width k: (i, i + k).
        f = torch.arange(N - k), torch.arange(k, N)
        ret[:, f[0], f[1]] = grad[R]  # rightward arcs i -> i+k
        ret[:, f[1], f[0]] = grad[L]  # leftward arcs i+k -> i
    return _unconvert(ret)

@staticmethod
Expand Down
Loading

0 comments on commit a2f89c3

Please sign in to comment.