Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 19, 2019
1 parent e70201c commit a64688e
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 126 deletions.
106 changes: 45 additions & 61 deletions torch_struct/cky.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from .helpers import _Struct
from .helpers import _Struct, Chart

A, B = 0, 1

Expand All @@ -18,93 +18,77 @@ def _dp(self, scores, lengths=None, force_grad=False):
S = NT + T

terms, rules, roots = (
semiring.convert(terms),
semiring.convert(rules),
semiring.convert(roots),
semiring.convert(terms).requires_grad_(True),
semiring.convert(rules).requires_grad_(True),
semiring.convert(roots).requires_grad_(True),
)
if lengths is None:
lengths = torch.LongTensor([N] * batch)

# Charts
beta = self._make_chart(2, (batch, N, N, NT), rules, force_grad)
span = self._make_chart(N, (batch, N, NT), rules, force_grad)

# Terminals and Tops
top = self._chart((batch, NT), rules, force_grad)
term_use = self._chart((batch, N, T), terms, force_grad)
term_use[:] = terms + 0.0
beta = [Chart((batch, N, N, NT), rules, semiring)
for _ in range(2)]
span = [None for _ in range(N)]
v = (ssize, batch)
term_use = terms + 0.0

# Split into NT/T groups
NTs = slice(0, NT)
Ts = slice(NT, S)
rules = rules.view(ssize, batch, 1, NT, S, S)
X_Y_Z = (
rules[..., NTs, NTs]
.contiguous()
.view(ssize, batch, NT, NT * NT)
.transpose(-2, -1)
)
X_Y1_Z = (
rules[..., Ts, NTs]
.contiguous()
.view(ssize, batch, NT, T * NT)
.transpose(-2, -1)
)
X_Y_Z1 = (
rules[..., NTs, Ts]
.contiguous()
.view(ssize, batch, NT, NT * T)
.transpose(-2, -1)
)
X_Y1_Z1 = (
rules[..., Ts, Ts]
.contiguous()
.view(ssize, batch, NT, T * T)
.transpose(-2, -1)
)
def arr(a, b):
return rules[..., a, b] \
.contiguous() \
.view(*v + (NT, -1)) \
.transpose(-2, -1)

matmul = semiring.matmul
times = semiring.times
X_Y_Z = arr(NTs, NTs)
X_Y1_Z = arr(Ts, NTs)
X_Y_Z1 = arr(NTs, Ts)
X_Y1_Z1 = arr(Ts, Ts)

# Here
for w in range(1, N):
all_span = []
v2 = v +(N - w, -1)

Y = beta[A][..., : N - w, :w, :].transpose(-2, -1)
Z = beta[B][..., w:, N - w :, :]
all_span.append(
semiring.matmul(
semiring.matmul(Y, Z).view(ssize, batch, N - w, NT * NT), X_Y_Z
)
Y = beta[A][: N - w, :w, :]
Z = beta[B][w:, N - w :, :]
X1 = matmul(
matmul(Y.transpose(-2, -1), Z).view(*v2), X_Y_Z
)
all_span.append(X1)

Y_term = term_use[..., : N - w, :, None]
Z_term = term_use[..., w:, None, :]

Y = beta[A][..., : N - w, w - 1, :, None]
all_span.append(
semiring.matmul(
semiring.times(Y, Z_term).view(ssize, batch, N - w, T * NT), X_Y_Z1
Y = Y[...,-1, :].unsqueeze(-1)
X2 = matmul(
times(Y, Z_term).view(*v2), X_Y_Z1
)
)

Z = beta[B][..., w:, N - w, None, :]
all_span.append(
semiring.matmul(
semiring.times(Y_term, Z).view(ssize, batch, N - w, NT * T), X_Y1_Z
Z = Z[..., 0, :].unsqueeze(-2)
X3 = matmul(
times(Y_term, Z).view(*v2), X_Y1_Z
)
)
all_span += [X2, X3]

if w == 1:
all_span.append(
semiring.matmul(
semiring.times(Y_term, Z_term).view(ssize, batch, N - w, T * T),
X_Y1_Z1,
)
X4 = matmul(
times(Y_term, Z_term).view(*v2),
X_Y1_Z1,
)
all_span.append(X4)

span[w] = semiring.sum(torch.stack(all_span, dim=-1))
beta[A][..., : N - w, w, :] = span[w]
beta[B][..., w:N, N - w - 1, :] = beta[A][..., : N - w, w, :]
beta[A][: N - w, w, :] = span[w]
beta[B][w:N, N - w - 1, :] = span[w]

top[:] = torch.stack(
[beta[A][:, i, 0, l - 1, NTs] for i, l in enumerate(lengths)], dim=1
final = beta[A][0, :, NTs]
top = torch.stack(
[final[:, i, l-1]
for i, l in enumerate(lengths)], dim=1
)
log_Z = semiring.dot(top, roots)
return semiring.unconvert(log_Z), (term_use, rules, top, span[1:]), beta
Expand Down
7 changes: 4 additions & 3 deletions torch_struct/cky_crf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ def _dp(self, scores, lengths=None, force_grad=False):
scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)


beta = [Chart((batch, N, N), scores, semiring) for _ in range(2)]
beta = [Chart((batch, N, N), scores, semiring)
for _ in range(2)]
L_DIM, R_DIM = 2, 3

# Initialize
reduced_scores = semiring.sum(scores)
term = reduced_scores.diagonal(0, L_DIM, R_DIM)

ns = torch.arange(N)
beta[A][ns, 0] = term
beta[B][ns, N-1] = term

Expand All @@ -70,7 +71,7 @@ def _dp(self, scores, lengths=None, force_grad=False):
beta[A][left, w] = new
beta[B][right, N - w - 1] = new

final = beta[A][0, I]
final = beta[A][0, :]
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores], beta

Expand Down
93 changes: 36 additions & 57 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import itertools
from .helpers import _Struct
from .helpers import _Struct, Chart


def _convert(logits):
Expand Down Expand Up @@ -50,69 +50,48 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
semiring = self.semiring
arc_scores = _convert(arc_scores_in)
arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)

arc_scores.requires_grad_(True)
DIRS = 2
alpha = [
self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
for _ in range(2)
]

def stack(a, b):
return torch.stack([a, b], dim=1)

def sstack(a):
return torch.stack([a, a], dim=1)

# Inside step. assumes first token is root symbol
semiring.one_(alpha[A][C][:, :, :, :, 0].data)
semiring.one_(alpha[B][C][:, :, :, :, -1].data)
k = 0

AIR = alpha[A][I][:, R, :, : N - k, 1:k]
BIL = alpha[B][I][:, L, :, k:N, N - k : N - 1]
k = 1
AC2 = alpha[A][C][:, :, :, : N - k, :k]
BC2 = alpha[B][C][:, :, :, k:, N - k :]
AC, BC, AC_next = None, None, None

ends = [None]
for k in range(1, N):

def tf(a):
return torch.narrow(a, 3, 0, N - k)
alpha = [[Chart((batch, DIRS, N, N), arc_scores, semiring)
for _ in range(2)] for _ in range(2)]
semiring.one_(alpha[A][C].data[:, :, :, :, 0].data)
semiring.one_(alpha[B][C].data[:, :, :, :, -1].data)
def stack(a, b=None):
if b is None:
return torch.stack([a, a], dim=2)
else:
return torch.stack([a, b], dim=2)

def tb(a):
return torch.narrow(a, 3, 1, N - k)

for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)
if k > 1:
AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=4)
if k > 1:
BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=4)

ACL, ACR = AC2.unbind(dim=1)
BCL, BCR = BC2.unbind(dim=1)
start = semiring.dot(BCL, ACR)

arcsL = semiring.times(start, arc_scores[:, :, f[1], f[0]])
arcsR = semiring.times(start, arc_scores[:, :, f[0], f[1]])

AIR2 = torch.cat(
[torch.narrow(AIR, 2, 0, N - k), arcsR.unsqueeze(-1)], dim=3
AC = alpha[A][C][:, : N - k, :k]
ACL, ACR = AC.unbind(2)

BC = alpha[B][C][:, k:, N - k :]
BCL, BCR = BC.unbind(2)
arcs = semiring.dot(
semiring.times(
stack(ACR), stack(BCL)),
stack(arc_scores[:, :, f[1], f[0]],
arc_scores[:, :, f[0], f[1]]).unsqueeze(-1),
)
BIL2 = torch.cat(
[arcsL.unsqueeze(-1), torch.narrow(BIL, 2, 1, N - k)], dim=3
alpha[A][I][:, : N - k, k] = arcs
alpha[B][I][:, k:N, N - k - 1] = arcs

AIR = alpha[A][I][R, : N - k, 1 : k + 1]
BIL = alpha[B][I][L, k:, N - k - 1 : N - 1]
new = semiring.dot(
stack(ACL, AIR),
stack(BIL, BCR),
)
AC_next = stack(semiring.dot(ACL, BIL2), semiring.dot(AIR2, BCR))

ends.append(AC_next[:, R, :, 0])
AC = AC2
BC = BC2
AIR = AIR2
BIL = BIL2
v = torch.stack([ends[l][:, i] for i, l in enumerate(lengths)], dim=1)
return (v, [arc_scores], alpha)
alpha[A][C][:, : N - k, k] = new
alpha[B][C][:, k:N, N - k - 1] = new

final = alpha[A][C][R, 0]
v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
return v, [arc_scores], alpha


def _check_potentials(self, arc_scores, lengths=None):
semiring = self.semiring
Expand Down
17 changes: 13 additions & 4 deletions torch_struct/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
class Get(torch.autograd.Function):
@staticmethod
def forward(ctx, chart, grad_chart, indices):
out = chart[indices]
ctx.save_for_backward(grad_chart)
out = chart[indices]
ctx.indices = indices
return out

Expand All @@ -32,7 +32,8 @@ def backward(ctx, grad_output):


class Chart:
def __init__(self, size, potentials, semiring):
def __init__(self, size, potentials, semiring,
cache=True):
self.data = semiring.zero_(
torch.zeros(
*((semiring.size(),) + size),
Expand All @@ -41,13 +42,21 @@ def __init__(self, size, potentials, semiring):
)
)
self.grad = self.data.detach().clone().fill_(0.0)
self.cache = cache

def __getitem__(self, ind):
I = slice(None)
return Get.apply(self.data, self.grad, (I, I) + ind)
if self.cache:
return Get.apply(self.data, self.grad, (I, I) + ind)
else:
return self.data[(I, I) + ind]

def __setitem__(self, ind, new):
I = slice(None)
self.data = Set.apply(self.data, (I, I) + ind, new)
if self.cache:
self.data = Set.apply(self.data, (I, I) + ind, new)
else:
self.data[(I, I) + ind] = new


def get(self, ind):
Expand Down
5 changes: 4 additions & 1 deletion torch_struct/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def test_generic_a(data):
struct = model(MaxSemiring)
score = struct.sum(vals)
marginals = struct.marginals(vals)
# print(marginals)
# # assert(False)
assert torch.isclose(score, struct.score(vals, marginals)).all()


Expand Down Expand Up @@ -247,7 +249,7 @@ def test_parts_from_sequence(data, seed):
@settings(max_examples=50, deadline=None)
def test_generic_lengths(data, seed):
model = data.draw(
sampled_from([CKY_CRF])#, Alignment, LinearChain, SemiMarkov, CKY, DepTree])
sampled_from([CKY, Alignment, LinearChain, SemiMarkov, CKY_CRF, DepTree])
)
struct = model()
torch.manual_seed(seed)
Expand All @@ -259,6 +261,7 @@ def test_generic_lengths(data, seed):
m = model(MaxSemiring).marginals(vals, lengths=lengths)
maxes = struct.score(vals, m)
part = model().sum(vals, lengths=lengths)
print(maxes, part)
assert (maxes <= part).all()
m_part = model(MaxSemiring).sum(vals, lengths=lengths)
assert (torch.isclose(maxes, m_part)).all(), maxes - m_part
Expand Down

0 comments on commit a64688e

Please sign in to comment.