Skip to content

Commit

Permalink
Merge 2d87ed0 into 6a55459
Browse files Browse the repository at this point in the history
  • Loading branch information
srush authored Nov 26, 2019
2 parents 6a55459 + 2d87ed0 commit 22d5c2d
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 50 deletions.
63 changes: 34 additions & 29 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,43 +51,48 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
arc_scores = _convert(arc_scores_in)
arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)
arc_scores.requires_grad_(True)
DIRS = 2
alpha = [
[Chart((batch, DIRS, N, N), arc_scores, semiring) for _ in range(2)]
[
[Chart((batch, N, N), arc_scores, semiring) for _ in range(2)]
for _ in range(2)
]
for _ in range(2)
]
semiring.one_(alpha[A][C].data[:, :, :, :, 0].data)
semiring.one_(alpha[B][C].data[:, :, :, :, -1].data)

def stack(a, b=None):
if b is None:
return torch.stack([a, a], dim=2)
else:
return torch.stack([a, b], dim=2)
semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)

for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)
AC = alpha[A][C][:, : N - k, :k]
ACL, ACR = AC.unbind(2)

BC = alpha[B][C][:, k:, N - k :]
BCL, BCR = BC.unbind(2)
arcs = semiring.dot(
semiring.times(stack(ACR), stack(BCL)),
stack(
arc_scores[:, :, f[1], f[0]], arc_scores[:, :, f[0], f[1]]
).unsqueeze(-1),
)
alpha[A][I][:, : N - k, k] = arcs
alpha[B][I][:, k:N, N - k - 1] = arcs
ACL = alpha[A][C][L][: N - k, :k]
ACR = alpha[A][C][R][: N - k, :k]

BCL = alpha[B][C][L][k:, N - k :]
BCR = alpha[B][C][R][k:, N - k :]
x = semiring.dot(ACR, BCL)

arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])

alpha[A][I][L][: N - k, k] = arcs_l
alpha[B][I][L][k:N, N - k - 1] = arcs_l

arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
alpha[A][I][R][: N - k, k] = arcs_r
alpha[B][I][R][k:N, N - k - 1] = arcs_r

AIR = alpha[A][I][R][: N - k, 1 : k + 1]
BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]

new = semiring.dot(ACL, BIL)
alpha[A][C][L][: N - k, k] = new
alpha[B][C][L][k:N, N - k - 1] = new

AIR = alpha[A][I][R, : N - k, 1 : k + 1]
BIL = alpha[B][I][L, k:, N - k - 1 : N - 1]
new = semiring.dot(stack(ACL, AIR), stack(BIL, BCR))
alpha[A][C][:, : N - k, k] = new
alpha[B][C][:, k:N, N - k - 1] = new
new = semiring.dot(AIR, BCR)
alpha[A][C][R][: N - k, k] = new
alpha[B][C][R][k:N, N - k - 1] = new

final = alpha[A][C][R, 0]
final = alpha[A][C][R][(0,)]
v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
return v, [arc_scores], alpha

Expand Down
34 changes: 27 additions & 7 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@ def _check_potentials(self, edge, lengths=None):

if lengths is None:
lengths = torch.LongTensor([N] * batch)

assert max(lengths) <= N, "Length longer than edge scores"
assert max(lengths) == N, "One length must be at least N"
# pass
else:
assert max(lengths) <= N, "Length longer than edge scores"
assert max(lengths) == N, "One length must be at least N"
assert C == C2, "Transition shape doesn't match"
return edge, batch, N, C, lengths

Expand All @@ -49,17 +50,36 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
"Compute forward pass by linear scan"
# Setup
semiring = self.semiring
ssize = semiring.size()
log_potentials, batch, N, C, lengths = self._check_potentials(
log_potentials, lengths
)
log_N, bin_N = self._bin_length(N - 1)
chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)

# Init
for b in range(lengths.shape[0]):
end = lengths[b] - 1
semiring.one_(chart[:, b, end:].diagonal(0, 2, 3))
chart[:, b, :end] = log_potentials[:, b, :end]
semiring.one_(chart[:, :, :].diagonal(0, 3, 4))

# Length mask
big = torch.zeros(
ssize,
batch,
bin_N,
C,
C,
dtype=log_potentials.dtype,
device=log_potentials.device,
)
big[:, :, : N - 1] = log_potentials
c = chart[:, :, :].view(ssize, batch * bin_N, C, C)
lp = big[:, :, :].view(ssize, batch * bin_N, C, C)
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
mask = mask >= (lengths - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data, (~mask))

c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))

# Scan
for n in range(1, log_N + 1):
Expand Down
43 changes: 34 additions & 9 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):

# Setup
semiring = self.semiring
ssize = semiring.size()
log_potentials.requires_grad_(True)
log_potentials, batch, N, K, C, lengths = self._check_potentials(
log_potentials, lengths
Expand All @@ -33,13 +34,37 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
)

# Init.
for b in range(lengths.shape[0]):
end = lengths[b] - 1
semiring.one_(init[:, b, end:, 0, 0].diagonal(0, 2, 3))
init[:, b, :end, : (K - 1), 0] = log_potentials[:, b, :end, 1:K]
for k in range(1, K - 1):
semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3))
K_1 = K - 1
# semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))

# Length mask
# big = torch.zeros(
# ssize,
# batch,
# bin_N,
# K,
# C,
# C,
# dtype=log_potentials.dtype,
# device=log_potentials.device,
# )
# big[:, :, : N - 1] = log_potentials
# c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
# lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
# mask = torch.arange(bin_N) \
# .view(1, bin_N).expand(batch, bin_N)
# mask = mask >= (lengths - 1).view(batch, 1)
# mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
# semiring.zero_mask_(lp.data, mask)
# semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
# c[:, :, : K - 1, 0] = semiring.sum(
# torch.stack([c.data[:, :, : K - 1, 0],
# lp[:, :, 1:K]], dim=-1)
# )
# end = torch.min(lengths) - 1
# for k in range(1, K - 1):
# semiring.one_(init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1))

# K_1 = K - 1

# Order n, n-1
chart = (
Expand All @@ -48,8 +73,8 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
.view(-1, batch, bin_N, K_1 * C, K_1 * C)
)

for n in range(1, log_N + 1):
chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
# for n in range(1, log_N + 1):
# chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])

final = chart.view(-1, batch, 1, K_1, C, K_1, C)
v = semiring.sum(semiring.sum(final[:, :, 0, 0, :, 0, :]))
Expand Down
33 changes: 29 additions & 4 deletions torch_struct/semirings/semirings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import torch
import genbmm

has_genbmm = False
try:
import genbmm
has_genbmm = True
except ImportError:
pass


def matmul(cls, a, b):
Expand Down Expand Up @@ -63,6 +69,17 @@ def zero_(xs):
"Fill *ssize x ...* tensor with additive identity."
raise NotImplementedError()

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
# xs.masked_fill_(mask[0], cls.zero)
# print(mask
# xs.masked_fill_(mask[0], cls.zero)
# xs[0, mask] = cls.zero
xs[0].masked_fill_(mask, cls.zero)
# print(mask.shape, xs.shape)
# assert False

@staticmethod
def one_(xs):
"Fill *ssize x ...* tensor with multiplicative identity."
Expand Down Expand Up @@ -143,7 +160,7 @@ def matmul(cls, a, b, dims=1):
(Faster than calling sum and times.)
"""

if isinstance(a, genbmm.BandedMatrix):
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
return b.multiply(a.transpose())
else:
return torch.matmul(a, b)
Expand All @@ -158,7 +175,7 @@ class LogSemiring(_BaseLog):

@classmethod
def matmul(cls, a, b):
if isinstance(a, genbmm.BandedMatrix):
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
return b.multiply_log(a.transpose())
else:
return _BaseLog.matmul(a, b)
Expand All @@ -173,7 +190,7 @@ class MaxSemiring(_BaseLog):

@classmethod
def matmul(cls, a, b):
if isinstance(a, genbmm.BandedMatrix):
if has_genbmm and isinstance(a, genbmm.BandedMatrix):
return b.multiply_max(a.transpose())
else:
return matmul(cls, a, b)
Expand Down Expand Up @@ -269,6 +286,8 @@ class EntropySemiring(Semiring):
* First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
"""

zero = 0

@staticmethod
def size():
return 2
Expand Down Expand Up @@ -301,6 +320,12 @@ def mul(a, b):
def prod(cls, xs, dim=-1):
return xs.sum(dim)

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
xs[0].masked_fill_(mask, -1e5)
xs[1].masked_fill_(mask, 0)

@staticmethod
def zero_(xs):
xs[0].fill_(-1e5)
Expand Down
4 changes: 3 additions & 1 deletion torch_struct/test_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def test_cky(data):
@settings(max_examples=50, deadline=None)
def test_generic_a(data):
model = data.draw(
sampled_from([Alignment]) # , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
sampled_from(
[SemiMarkov]
) # , Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
)
semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
struct = model(semiring)
Expand Down

0 comments on commit 22d5c2d

Please sign in to comment.