Commit 6b21221

outside

srush committed Sep 8, 2019
1 parent d60fed3 commit 6b21221

Showing 3 changed files with 90 additions and 64 deletions.
130 changes: 76 additions & 54 deletions torch_struct/deptree.py
@@ -1,7 +1,6 @@
 import torch
 import itertools
-from torch.autograd import Function
-from .helpers import _Struct, roll, roll2
+from .helpers import _Struct, roll2


 def _convert(logits):
@@ -33,18 +32,6 @@ def _unconvert(logits):
 # Constants
 A, B, R, C, L, I = 0, 1, 1, 1, 0, 0

-class MySlice(Function):
-    @staticmethod
-    def forward(ctx, input, e, a, b, c, d):
-        output = input.clone().zero_()
-        ctx.save_for_backward(output, torch.tensor([e, a, b, c, d]))
-        return input[e, :, a:b, c:d]
-
-    @staticmethod
-    def backward(ctx, grad_v):
-        output, a = ctx.saved_tensors
-        output[a[0], :, a[1]:a[2], a[3]:a[4]] = grad_v
-        return output, None, None, None, None, None

 class DepTree(_Struct):
     """
@@ -60,19 +47,35 @@ def _dp(self, arc_scores, lengths=None, force_grad=False):
         batch, N, lengths = self._check_potentials(arc_scores, lengths)

         DIRS = 2
-        s = MySlice.apply
-        #def s(input, e, a, b, c, d):
-        #    return input[e, :, a:b, c:d]
-
-        alpha = [
-            self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
-            for _ in range(2)
-        ]
+        # Want to fix this slicing function.
+        # class MySlice(Function):
+        #     @staticmethod
+        #     def forward(ctx, alpha, beta, s1, s2, e, a, b, c, d):
+        #         indices = torch.tensor([s1, s2, e, a, b, c, d])
+        #         ctx.save_for_backward(indices)
+        #         return alpha[e, :, a:b, c:d]
+
+        #     @staticmethod
+        #     def backward(ctx, grad_v):
+        #         a, = ctx.saved_tensors
+        #         beta[a[0]][a[1]][a[2], :, a[3]:a[4], a[5]:a[6]] += grad_v
+        #         return None, None, None, None, None, None, None, None, None
+
+        # s = MySlice.apply
+        def s(input, e, a, b, c, d):
+            return input[e, :, a:b, c:d]

         def stack(a, b):
             return torch.stack([a, b])

         def sstack(a):
             return torch.stack([a, a])

+        alpha = [
+            self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
+            for _ in range(2)
+        ]
         arcs = self._make_chart(N, (DIRS, batch, N), arc_scores, force_grad)

         # Inside step. assumes first token is root symbol
@@ -82,25 +85,30 @@ def sstack(a):
         for k in range(1, N):
             f = torch.arange(N - k), torch.arange(k, N)
             arcs[k] = semiring.times(
-                sstack(semiring.sum(semiring.times(
-                    s(alpha[A][C], R, 0, N - k, 0, k),
-                    s(alpha[B][C], L, k, N, N - k, N)))),
-                stack(arc_scores[:, f[1], f[0]],
-                      arc_scores[:, f[0], f[1]])
+                sstack(
+                    semiring.sum(
+                        semiring.times(
+                            s(alpha[A][C], R, 0, N - k, 0, k),
+                            s(alpha[B][C], L, k, N, N - k, N),
+                        )
+                    )
+                ),
+                stack(arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]]),
             )
             alpha[A][I][:, :, : N - k, k] = arcs[k]
             alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]
             alpha[A][C][:, :, : N - k, k] = semiring.dot(
                 stack(
-                    s(alpha[A][C], L,0, N - k, 0, k),
-                    s(alpha[A][I], R,0, N - k, 1, k + 1),
+                    s(alpha[A][C], L, 0, N - k, 0, k),
+                    s(alpha[A][I], R, 0, N - k, 1, k + 1),
                 ),
                 stack(
                     s(alpha[B][I], L, k, N, N - k - 1, N - 1),
                     s(alpha[B][C], R, k, N, N - k, N),
                 ),
             )
             alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]

         v = torch.stack([alpha[A][C][R, i, 0, l] for i, l in enumerate(lengths)])
         return (v, arcs[1:], alpha)
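
For orientation, the inside pass above is Eisner's first-order projective recursion, vectorized over spans with both directions stacked. A minimal unbatched sketch of the same recursion in the log semiring; inside_eisner and scores are illustrative names, not torch_struct API:

    import torch

    def inside_eisner(scores):
        # scores[h, m]: log-potential of an arc from head h to modifier m;
        # token 0 is the root symbol, matching the assumption above.
        N = scores.shape[0]
        C = torch.full((2, N, N), float("-inf"))  # complete spans, [L, R]
        I = torch.full((2, N, N), float("-inf"))  # incomplete spans, [L, R]
        C[:, range(N), range(N)] = 0.0            # semiring one on the diagonal
        for k in range(1, N):
            for i in range(N - k):
                j = i + k
                # attach: both incomplete items share one split-point sum
                split = torch.logsumexp(C[1, i, i:j] + C[0, i + 1 : j + 1, j], dim=0)
                I[0, i, j] = split + scores[j, i]  # left arc, j heads i
                I[1, i, j] = split + scores[i, j]  # right arc, i heads j
                # complete: absorb a finished incomplete item on one side
                C[0, i, j] = torch.logsumexp(C[0, i, i:j] + I[0, i:j, j], dim=0)
                C[1, i, j] = torch.logsumexp(
                    I[1, i, i + 1 : j + 1] + C[1, i + 1 : j + 1, j], dim=0
                )
        return C[1, 0, N - 1]  # log-partition over projective trees

    log_Z = inside_eisner(torch.randn(5, 5))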

@@ -143,16 +151,25 @@ def sstack(a):
             # I -> C* C
             # I -> C* C
             # C -> I C*
-            a = semiring.sum(semiring.times(
-                *roll2(
-                    stack(stack(alpha[A][I][R], alpha[A][I][L]),
-                          sstack(alpha_in[B][C][R]), dim=0),
-                    stack(sstack(alpha_in[A][C][L]),
-                          stack(alpha[B][I][L], alpha[B][I][R]), dim=0),
-                    N,
-                    k,
-                    1,
-                )).view(2, batch, N-k-1, -1), dim=-1
+            a = semiring.sum(
+                semiring.times(
+                    *roll2(
+                        stack(
+                            stack(alpha[A][I][R], alpha[A][I][L]),
+                            sstack(alpha_in[B][C][R]),
+                            dim=0,
+                        ),
+                        stack(
+                            sstack(alpha_in[A][C][L]),
+                            stack(alpha[B][I][L], alpha[B][I][R]),
+                            dim=0,
+                        ),
+                        N,
+                        k,
+                        1,
+                    )
+                ).view(2, batch, N - k - 1, -1),
+                dim=-1,
             )
             alpha[A][C][L, :, 1 : N - k, k] = a[1]
             alpha[A][C][R, :, : N - k - 1, k] = a[0]
@@ -162,13 +179,17 @@ def sstack(a):
                 alpha[A][C][R, b, 0, l] = semiring.one()
                 alpha[B][C][R, b, l, N - l - 1] = semiring.one()

-            c = semiring.sum(semiring.times(
-                *roll2(stack(alpha[A][C][L],
-                             alpha_in[B][I][R],
-                             dim=0),
-                       stack(alpha_in[A][I][L],
-                             alpha[B][C][R], dim=0),
-                       N, k, 0)))
+            c = semiring.sum(
+                semiring.times(
+                    *roll2(
+                        stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
+                        stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
+                        N,
+                        k,
+                        0,
+                    )
+                )
+            )
             alpha[A][C][:, :, : N - k, k] = semiring.plus(
                 alpha[A][C][:, :, : N - k, k], c
             )
@@ -179,17 +200,18 @@ def sstack(a):
             if k > 0:
                 f = torch.arange(N - k), torch.arange(k, N)
                 alpha[A][I][:, :, : N - k, k] = semiring.dot(
-                    stack(arc_scores[:, f[1], f[0]],
-                          arc_scores[:, f[0], f[1]], dim=0).unsqueeze(-1),
-                    *roll2(stack(alpha_in[B][C][L],
-                                 alpha[A][C][R], dim=0),
-                           stack(alpha[B][C][L],
-                                 alpha_in[A][C][R], dim=0),
-                           N, k)
+                    stack(
+                        arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
+                    ).unsqueeze(-1),
+                    *roll2(
+                        stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
+                        stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
+                        N,
+                        k,
+                    )
                 )
                 alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]

-
         v = alpha[A][C][R, :, 0, 0]
         left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
         right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
@@ -199,7 +221,7 @@ def sstack(a):
             ret[:, f[1], k] = left[:, k, f[0]]
             ret[:, k, f[1]] = right[:, k, f[0]]

-            ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
+        ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
         return _unconvert(ret)

     def _arrange_marginals(self, grads):
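Aside: the reason MySlice could be dropped is that plain tensor slicing is already differentiable in PyTorch: backward scatters the incoming gradient into the sliced window of a zeroed tensor, which is exactly what MySlice.backward did by hand. A quick self-contained check (shapes illustrative):

    import torch

    x = torch.randn(2, 3, 5, 5, requires_grad=True)

    def s(input, e, a, b, c, d):  # same signature as the closure in the diff
        return input[e, :, a:b, c:d]

    s(x, 1, 0, 2, 1, 4).sum().backward()
    # gradient is 1 exactly on the sliced window, 0 everywhere else
    assert x.grad[1, :, 0:2, 1:4].eq(1).all()
    assert x.grad.sum() == x.grad[1, :, 0:2, 1:4].sum()
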
9 changes: 8 additions & 1 deletion torch_struct/helpers.py
@@ -6,6 +6,7 @@
 def roll(a, b, N, k, gap=0):
     return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)])

+
 def roll2(a, b, N, k, gap=0):
     return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)])
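
For intuition, roll2 is roll with an extra leading (direction) dimension: it trims the two charts so that row i of the first slice lines up with row i + k + gap of the second. A small check of the alignment (shapes illustrative):

    import torch

    def roll2(a, b, N, k, gap=0):  # copied from the hunk above
        return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)])

    N, k = 4, 1
    a = torch.arange(2 * 1 * N * N).view(2, 1, N, N)
    b = -a
    left, right = roll2(a, b, N, k)
    assert left.shape == right.shape == (2, 1, N - k, N - k)
    assert left[0, 0, 2, 1] == a[0, 0, 2, k + 1]   # left keeps rows i, columns k + j
    assert right[0, 0, 2, 1] == b[0, 0, k + 2, 1]  # right keeps rows k + i, columns j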

@@ -33,7 +34,13 @@ def backward(ctx, grad_v):
         with torch.no_grad():
             marginals = ctx.obj._dp_backward(input, ctx.lengths, ctx.alpha)

-        return None, marginals.mul(grad_v.view((grad_v.shape[0],) + tuple([1]*marginals.dim()))), None
+        return (
+            None,
+            marginals.mul(
+                grad_v.view((grad_v.shape[0],) + tuple([1] * marginals.dim()))
+            ),
+            None,
+        )


 class _Struct:
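The return value reformatted above follows the standard pattern for a Function whose forward computes a log-partition: the gradient of that value with respect to the potentials is the marginals, so backward just broadcasts grad_v over them. A minimal standalone sketch of the same pattern with a toy one-step inside pass (LogPartition is a hypothetical name, not torch_struct API):

    import torch
    from torch.autograd import Function

    class LogPartition(Function):
        @staticmethod
        def forward(ctx, scores):
            v = torch.logsumexp(scores, dim=-1)  # toy "inside" pass
            ctx.save_for_backward(scores, v)
            return v

        @staticmethod
        def backward(ctx, grad_v):
            scores, v = ctx.saved_tensors
            marginals = (scores - v.unsqueeze(-1)).exp()  # here: softmax
            # broadcast grad_v over the marginal dims, as in the diff above
            return marginals.mul(
                grad_v.view(grad_v.shape + (1,) * (marginals.dim() - grad_v.dim()))
            )

    scores = torch.randn(3, 5, requires_grad=True)
    LogPartition.apply(scores).sum().backward()
    assert torch.allclose(scores.grad, torch.softmax(scores.detach(), dim=-1))
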
15 changes: 6 additions & 9 deletions torch_struct/test_algorithms.py
@@ -38,15 +38,11 @@ def test_fb(data):
     )
     marginals2 = model().marginals(vals, lengths=lengths, _autograd=True)
     v, _, alpha = model()._dp(vals, lengths=lengths)
-    print(v)
-    print(marginals2)
     marginals = model()._dp_backward(vals, lengths, alpha, v)

     if isinstance(marginals, tuple):
         for i, (m1, m2) in enumerate(zip(marginals[:], marginals2[:])):
-            assert torch.isclose(m1, m2).all(), (
-                not torch.isclose(m1, m2)
-            ).nonzero()
+            assert torch.isclose(m1, m2).all(), (not torch.isclose(m1, m2)).nonzero()
     else:
         assert torch.isclose(marginals, marginals2).all()
@@ -128,15 +124,16 @@ def test_params(data, seed):
     # torch.autograd.set_detect_anomaly(True)
     semiring = LogSemiring
     alpha = model(semiring).sum(vals)
-    x = alpha.sum().backward()
+    alpha.sum().backward()

     if not isinstance(vals, tuple):
         b = vals.grad.detach()
         vals.grad.zero_()
         alpha = model(semiring).sum(vals, _autograd=False)
-        x2 = alpha.sum().backward()
+        alpha.sum().backward()
         c = vals.grad.detach()
-        assert(torch.isclose(b, c).all())
+        assert torch.isclose(b, c).all()

+
 def test_hmm():
     C, V, batch, N = 5, 20, 2, 5
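The same check as test_params, outside the test harness: marginals obtained by autograd through the inside pass should match those from the hand-written outside pass selected by _autograd=False. A usage sketch, assuming DepTree and LogSemiring are importable from the package root and vals is a batch of arc-score matrices:

    import torch
    from torch_struct import DepTree, LogSemiring  # import path assumed

    vals = torch.rand(2, 6, 6, requires_grad=True)  # batch x N x N arc scores

    alpha = DepTree(LogSemiring).sum(vals)  # inside pass, autograd backward
    alpha.sum().backward()
    autograd_marginals = vals.grad.detach().clone()

    vals.grad.zero_()
    alpha = DepTree(LogSemiring).sum(vals, _autograd=False)  # manual outside pass
    alpha.sum().backward()
    assert torch.isclose(autograd_marginals, vals.grad).all()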
