torch_struct/cky.py (27 changes: 21 additions & 6 deletions)

@@ -45,7 +45,7 @@ def sum(self, scores, lengths=None, force_grad=False, _autograd=True):
         if _autograd or self.semiring is not LogSemiring:
             return self._dp(scores, lengths)[0]
         else:
-            return DPManual2.apply(self, *scores, lengths)
+            return DPManual2.apply(self, *scores, lengths=lengths)

     def _dp(self, scores, lengths=None, force_grad=False):
         terms, rules, roots = scores
@@ -67,14 +67,29 @@ def _dp(self, scores, lengths=None, force_grad=False):
         term_use[:] = terms + 0.0
         beta[A][:, :, 0, NT:] = term_use
         beta[B][:, :, N - 1, NT:] = term_use
+        X_Y_Z = rules.view(batch, 1, NT, S, S)[:, :, :, :NT, :NT]
+        X_Y_Z1 = rules.view(batch, 1, NT, S, S)[:, :, :, :NT, NT:]
+        X_Y1_Z = rules.view(batch, 1, NT, S, S)[:, :, :, NT:, :NT]
+        X_Y1_Z1 = rules.view(batch, 1, NT, S, S)[:, :, :, NT:, NT:]
         for w in range(1, N):
-            Y = beta[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
-            Z = beta[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
-            Y, Z = Y.clone(), Z.clone()
-            X_Y_Z = rules.view(batch, 1, NT, S, S)
-            rule_use[w - 1][:] = semiring.times(
+            Y = beta[A][:, : N - w, :w, :NT].view(batch, N - w, w, 1, NT, 1)
+            Z = beta[B][:, w:, N - w :, :NT].view(batch, N - w, w, 1, 1, NT)
+            rule_use[w - 1][:, :, :, :NT, :NT] = semiring.times(
                 semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z
             )
+            Y = beta[A][:, : N - w, w - 1, :NT].view(batch, N - w, 1, NT, 1)
+            Z = beta[B][:, w:, N - 1, NT:].view(batch, N - w, 1, 1, T)
+            rule_use[w - 1][:, :, :, :NT, NT:] = semiring.times(Y, Z, X_Y_Z1)
+
+            Y = beta[A][:, : N - w, 0, NT:].view(batch, N - w, 1, T, 1)
+            Z = beta[B][:, w:, N - w, :NT].view(batch, N - w, 1, 1, NT)
+            rule_use[w - 1][:, :, :, NT:, :NT] = semiring.times(Y, Z, X_Y1_Z)
+
+            if w == 1:
+                Y = beta[A][:, : N - w, w - 1, NT:].view(batch, N - w, 1, T, 1)
+                Z = beta[B][:, w:, N - w, NT:].view(batch, N - w, 1, 1, T)
+                rule_use[w - 1][:, :, :, NT:, NT:] = semiring.times(Y, Z, X_Y1_Z1)
+
             rulesmid = rule_use[w - 1].view(batch, N - w, NT, S * S)
             span[w] = semiring.sum(rulesmid, dim=3)
             beta[A][:, : N - w, w, :NT] = span[w]
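Review note on the CKY hunk above: with the S = NT + T symbol layout (nonterminals first, terminals last), the old loop built Y and Z over all S symbols and multiplied against the full S x S rule view at every width. The rewrite slices the rule tensor once into four child-type quadrants and fills only the matching quadrant of `rule_use`: a terminal child always spans exactly one word, so the `X_Y_Z1` and `X_Y1_Z` cases read a single chart edge, and the terminal-terminal quadrant fires only when `w == 1`. Below is a minimal sketch of the quadrant slicing; the sizes (`NT = 3`, `T = 2`) are made up for illustration, not taken from the PR.

import torch

batch, NT, T = 1, 3, 2
S = NT + T  # symbol layout assumed above: NT nonterminals, then T terminals
rules = torch.randn(batch, NT, S, S)  # rules[b, X, Y, Z] scores the rule X -> Y Z

r = rules.view(batch, 1, NT, S, S)
X_Y_Z = r[:, :, :, :NT, :NT]    # both children nonterminal: any split point
X_Y_Z1 = r[:, :, :, :NT, NT:]   # right child terminal: right piece is one word
X_Y1_Z = r[:, :, :, NT:, :NT]   # left child terminal: left piece is one word
X_Y1_Z1 = r[:, :, :, NT:, NT:]  # both children terminal: only the w == 1 step
for name, q in [("X_Y_Z", X_Y_Z), ("X_Y_Z1", X_Y_Z1),
                ("X_Y1_Z", X_Y1_Z), ("X_Y1_Z1", X_Y1_Z1)]:
    print(name, tuple(q.shape))  # NT- or T-sized child blocks of the S x S table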
torch_struct/deptree.py (198 changes: 99 additions & 99 deletions)

@@ -1,6 +1,6 @@
 import torch
 import itertools
-from .helpers import _Struct, roll2
+from .helpers import _Struct


 def _convert(logits):
@@ -139,104 +139,104 @@ def _check_potentials(self, arc_scores, lengths=None):

         return batch, N, lengths

-    def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):
-
-        # This function is super complicated.
-        semiring = self.semiring
-        arc_scores = _convert(arc_scores)
-        batch, N, lengths = self._check_potentials(arc_scores, lengths)
-        DIRS = 2
-
-        alpha = [
-            self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
-            for _ in range(2)
-        ]
-
-        def stack(a, b, dim=-1):
-            return torch.stack([a, b], dim=dim)
-
-        def sstack(a):
-            return torch.stack([a, a], dim=-1)
-
-        for k in range(N - 1, -1, -1):
-            # Initialize
-            if N - k - 1 > 0:
-                # R completes
-                # I -> C* C
-                # I -> C* C
-                # C -> I C*
-                a = semiring.sum(
-                    semiring.times(
-                        *roll2(
-                            stack(
-                                stack(alpha[A][I][R], alpha[A][I][L]),
-                                sstack(alpha_in[B][C][R]),
-                                dim=0,
-                            ),
-                            stack(
-                                sstack(alpha_in[A][C][L]),
-                                stack(alpha[B][I][L], alpha[B][I][R]),
-                                dim=0,
-                            ),
-                            N,
-                            k,
-                            1,
-                        )
-                    ).view(2, batch, N - k - 1, -1),
-                    dim=-1,
-                )
-                alpha[A][C][L, :, 1 : N - k, k] = a[1]
-                alpha[A][C][R, :, : N - k - 1, k] = a[0]
-
-            for b, l in enumerate(lengths):
-                if l == k:
-                    alpha[A][C][R, b, 0, l] = semiring.one()
-                    alpha[B][C][R, b, l, N - l - 1] = semiring.one()
-
-            c = semiring.sum(
-                semiring.times(
-                    *roll2(
-                        stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
-                        stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
-                        N,
-                        k,
-                        0,
-                    )
-                )
-            )
-            alpha[A][C][:, :, : N - k, k] = semiring.plus(
-                alpha[A][C][:, :, : N - k, k], c
-            )
-
-            # Compute reverses.
-            alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]
-
-            if k > 0:
-                f = torch.arange(N - k), torch.arange(k, N)
-                alpha[A][I][:, :, : N - k, k] = semiring.dot(
-                    stack(
-                        arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
-                    ).unsqueeze(-1),
-                    *roll2(
-                        stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
-                        stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
-                        N,
-                        k,
-                    )
-                )
-                alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]
-
-        v = alpha[A][C][R, :, 0, 0]
-        left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
-        right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
-        ret = torch.zeros(batch, N, N, dtype=left.dtype)
-        for k in torch.arange(N):
-            f = torch.arange(N - k), torch.arange(k, N)
-            ret[:, f[1], k] = left[:, k, f[0]]
-            ret[:, k, f[1]] = right[:, k, f[0]]
-
-        ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
-        return _unconvert(ret)
+    # def _dp_backward(self, arc_scores, lengths, alpha_in, v=None, force_grad=False):
+
+    #     # This function is super complicated and was just too slow to include
+    #     semiring = self.semiring
+    #     arc_scores = _convert(arc_scores)
+    #     batch, N, lengths = self._check_potentials(arc_scores, lengths)
+    #     DIRS = 2
+
+    #     alpha = [
+    #         self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
+    #         for _ in range(2)
+    #     ]
+
+    #     def stack(a, b, dim=-1):
+    #         return torch.stack([a, b], dim=dim)
+
+    #     def sstack(a):
+    #         return torch.stack([a, a], dim=-1)
+
+    #     for k in range(N - 1, -1, -1):
+    #         # Initialize
+    #         if N - k - 1 > 0:
+    #             # R completes
+    #             # I -> C* C
+    #             # I -> C* C
+    #             # C -> I C*
+    #             a = semiring.sum(
+    #                 semiring.times(
+    #                     *roll2(
+    #                         stack(
+    #                             stack(alpha[A][I][R], alpha[A][I][L]),
+    #                             sstack(alpha_in[B][C][R]),
+    #                             dim=0,
+    #                         ),
+    #                         stack(
+    #                             sstack(alpha_in[A][C][L]),
+    #                             stack(alpha[B][I][L], alpha[B][I][R]),
+    #                             dim=0,
+    #                         ),
+    #                         N,
+    #                         k,
+    #                         1,
+    #                     )
+    #                 ).view(2, batch, N - k - 1, -1),
+    #                 dim=-1,
+    #             )
+    #             alpha[A][C][L, :, 1 : N - k, k] = a[1]
+    #             alpha[A][C][R, :, : N - k - 1, k] = a[0]
+
+    #         for b, l in enumerate(lengths):
+    #             if l == k:
+    #                 alpha[A][C][R, b, 0, l] = semiring.one()
+    #                 alpha[B][C][R, b, l, N - l - 1] = semiring.one()
+
+    #         c = semiring.sum(
+    #             semiring.times(
+    #                 *roll2(
+    #                     stack(alpha[A][C][L], alpha_in[B][I][R], dim=0),
+    #                     stack(alpha_in[A][I][L], alpha[B][C][R], dim=0),
+    #                     N,
+    #                     k,
+    #                     0,
+    #                 )
+    #             )
+    #         )
+    #         alpha[A][C][:, :, : N - k, k] = semiring.plus(
+    #             alpha[A][C][:, :, : N - k, k], c
+    #         )
+
+    #         # Compute reverses.
+    #         alpha[B][C][:, :, k:N, N - k - 1] = alpha[A][C][:, :, : N - k, k]
+
+    #         if k > 0:
+    #             f = torch.arange(N - k), torch.arange(k, N)
+    #             alpha[A][I][:, :, : N - k, k] = semiring.dot(
+    #                 stack(
+    #                     arc_scores[:, f[1], f[0]], arc_scores[:, f[0], f[1]], dim=0
+    #                 ).unsqueeze(-1),
+    #                 *roll2(
+    #                     stack(alpha_in[B][C][L], alpha[A][C][R], dim=0),
+    #                     stack(alpha[B][C][L], alpha_in[A][C][R], dim=0),
+    #                     N,
+    #                     k,
+    #                 )
+    #             )
+    #             alpha[B][I][:, :, k:N, N - k - 1] = alpha[A][I][:, :, : N - k, k]
+
+    #     v = alpha[A][C][R, :, 0, 0]
+    #     left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
+    #     right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
+    #     ret = torch.zeros(batch, N, N, dtype=left.dtype)
+    #     for k in torch.arange(N):
+    #         f = torch.arange(N - k), torch.arange(k, N)
+    #         ret[:, f[1], k] = left[:, k, f[0]]
+    #         ret[:, k, f[1]] = right[:, k, f[0]]

+    #     ret = semiring.div_exp(ret - arc_scores, v.view(batch, 1, 1))
+    #     return _unconvert(ret)

     def _arrange_marginals(self, grads):
         batch, N = grads[0][0].shape
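Review note: `_dp_backward` was the hand-derived outside (backward) pass for `DepTree`. With it commented out, `hasattr(self, "_dp_backward")` is false, so `marginals` in helpers.py below always takes the autograd branch for this struct. Dropping `roll2` from the import matches that helper being commented out in helpers.py in the same commit.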
torch_struct/helpers.py (14 changes: 7 additions & 7 deletions)

@@ -3,12 +3,12 @@
 from torch.autograd import Function


-def roll(a, b, N, k, gap=0):
-    return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)])
+# def roll(a, b, N, k, gap=0):
+#     return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)])


-def roll2(a, b, N, k, gap=0):
-    return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)])
+# def roll2(a, b, N, k, gap=0):
+#     return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)])


 class _Struct:
@@ -22,8 +22,7 @@ def score(self, potentials, parts):
     def _make_chart(self, N, size, potentials, force_grad=False):
         return [
             (
-                torch.zeros(*size)
-                .type_as(potentials)
+                torch.zeros(*size, dtype=potentials.dtype, device=potentials.device)
                 .fill_(self.semiring.zero())
                 .requires_grad_(force_grad and not potentials.requires_grad)
             )
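Review note: `torch.zeros(*size).type_as(potentials)` first allocates a default-dtype tensor on the CPU and only then casts and copies it to the dtype and device of `potentials`; passing `dtype=` and `device=` to `torch.zeros` allocates once in the right place. A toy illustration of the equivalence; the shapes and the float64 stand-in are made up, not from the PR.

import torch

potentials = torch.randn(2, 3, dtype=torch.float64)  # stand-in for real potentials

# old form: default-dtype CPU allocation, then a cast/copy
a = torch.zeros(4, 4).type_as(potentials)
# new form: a single allocation with the target dtype and device
b = torch.zeros(4, 4, dtype=potentials.dtype, device=potentials.device)
assert a.dtype == b.dtype == torch.float64 and a.device == b.device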
@@ -76,12 +76,12 @@ def marginals(self, edge, lengths=None, _autograd=True):
             marginals: b x (N-1) x C x C table

         """
-        v, edges, alpha = self._dp(edge, lengths=lengths, force_grad=True)
         if (
             _autograd
             or self.semiring is not LogSemiring
             or not hasattr(self, "_dp_backward")
         ):
+            v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True)
             marg = torch.autograd.grad(
                 v.sum(dim=0),
                 edges,
@@ -91,4 +91,5 @@
             )
             return self._arrange_marginals(marg)
         else:
+            v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
             return self._dp_backward(edge, lengths, alpha)
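Review note: the reshuffle moves the `_dp` call inside each branch, so the autograd path keeps the `edges` potentials it differentiates through (discarding the inside chart), while the manual path keeps `alpha` for `_dp_backward`. The autograd branch rests on the standard identity that the gradient of the log-partition value with respect to each potential is that part's marginal. A toy check of that identity, not code from the repo:

import torch

# Two competing "structures" with scores theta: the gradient of
# log-sum-exp (the log-partition) recovers their marginal probabilities.
theta = torch.tensor([0.5, 1.5], requires_grad=True)
log_z = torch.logsumexp(theta, dim=0)
(marginals,) = torch.autograd.grad(log_z, theta)
assert torch.allclose(marginals, torch.softmax(theta, dim=0))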
torch_struct/test_algorithms.py (12 changes: 6 additions & 6 deletions)

@@ -59,11 +59,11 @@ def test_generic_a(data):
     print(alpha, count)
     assert torch.isclose(count[0], alpha[0])

-    # vals, _ = model._rand()
-    # struct = model(MaxSemiring)
-    # score = struct.sum(vals)
-    # marginals = struct.marginals(vals)
-    # assert torch.isclose(score, struct.score(vals, marginals)).all()
+    vals, _ = model._rand()
+    struct = model(MaxSemiring)
+    score = struct.sum(vals)
+    marginals = struct.marginals(vals)
+    assert torch.isclose(score, struct.score(vals, marginals)).all()


 @given(data(), integers(min_value=1, max_value=10))
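Review note: the re-enabled block exercises the max semiring, where `struct.sum` returns the best score and `struct.marginals` is the (sub)gradient of that max, i.e. an indicator of the argmax structure, so re-scoring the indicator must reproduce the same number. A toy analogue of the asserted property, not taken from the test file:

import torch

vals = torch.randn(5, requires_grad=True)
best = vals.max()  # plays the role of struct.sum under MaxSemiring
(indicator,) = torch.autograd.grad(best, vals)  # one-hot at the argmax
assert torch.isclose(best, (indicator * vals).sum())  # re-scoring the argmax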
@@ -114,7 +114,7 @@ def test_generic_lengths(data, seed):
 @given(data(), integers(min_value=1, max_value=10))
 def test_params(data, seed):
     model = data.draw(
-        sampled_from([DepTree])
+        sampled_from([DepTree, CKY])
     )  # LinearChain, SemiMarkov, DepTree, CKY]))
     struct = model()
     torch.manual_seed(seed)