Speed ups for linear chain and semi-markov (#42)
srush committed Nov 26, 2019
1 parent 6a55459 commit b3ce3b1
Showing 5 changed files with 123 additions and 51 deletions.
63 changes: 34 additions & 29 deletions torch_struct/deptree.py
@@ -51,43 +51,48 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
         arc_scores = _convert(arc_scores_in)
         arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)
         arc_scores.requires_grad_(True)
-        DIRS = 2
         alpha = [
-            [Chart((batch, DIRS, N, N), arc_scores, semiring) for _ in range(2)]
+            [
+                [Chart((batch, N, N), arc_scores, semiring) for _ in range(2)]
+                for _ in range(2)
+            ]
             for _ in range(2)
         ]
-        semiring.one_(alpha[A][C].data[:, :, :, :, 0].data)
-        semiring.one_(alpha[B][C].data[:, :, :, :, -1].data)
-
-        def stack(a, b=None):
-            if b is None:
-                return torch.stack([a, a], dim=2)
-            else:
-                return torch.stack([a, b], dim=2)
+        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
+        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
+        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
+        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
 
         for k in range(1, N):
             f = torch.arange(N - k), torch.arange(k, N)
-            AC = alpha[A][C][:, : N - k, :k]
-            ACL, ACR = AC.unbind(2)
-
-            BC = alpha[B][C][:, k:, N - k :]
-            BCL, BCR = BC.unbind(2)
-            arcs = semiring.dot(
-                semiring.times(stack(ACR), stack(BCL)),
-                stack(
-                    arc_scores[:, :, f[1], f[0]], arc_scores[:, :, f[0], f[1]]
-                ).unsqueeze(-1),
-            )
-            alpha[A][I][:, : N - k, k] = arcs
-            alpha[B][I][:, k:N, N - k - 1] = arcs
+            ACL = alpha[A][C][L][: N - k, :k]
+            ACR = alpha[A][C][R][: N - k, :k]
+
+            BCL = alpha[B][C][L][k:, N - k :]
+            BCR = alpha[B][C][R][k:, N - k :]
+            x = semiring.dot(ACR, BCL)
+
+            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
+
+            alpha[A][I][L][: N - k, k] = arcs_l
+            alpha[B][I][L][k:N, N - k - 1] = arcs_l
+
+            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
+            alpha[A][I][R][: N - k, k] = arcs_r
+            alpha[B][I][R][k:N, N - k - 1] = arcs_r
+
+            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
+            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
+
+            new = semiring.dot(ACL, BIL)
+            alpha[A][C][L][: N - k, k] = new
+            alpha[B][C][L][k:N, N - k - 1] = new
 
-            AIR = alpha[A][I][R, : N - k, 1 : k + 1]
-            BIL = alpha[B][I][L, k:, N - k - 1 : N - 1]
-            new = semiring.dot(stack(ACL, AIR), stack(BIL, BCR))
-            alpha[A][C][:, : N - k, k] = new
-            alpha[B][C][:, k:N, N - k - 1] = new
+            new = semiring.dot(AIR, BCR)
+            alpha[A][C][R][: N - k, k] = new
+            alpha[B][C][R][k:N, N - k - 1] = new
 
-        final = alpha[A][C][R, 0]
+        final = alpha[A][C][R][(0,)]
         v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
         return v, [arc_scores], alpha

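Note on the dep-tree rewrite: the stacked direction axis and the stack/unbind helper are gone. Left- and right-facing charts now live in separate tensors, and the inner product of right- and left-facing complete items is computed once (x = semiring.dot(ACR, BCL)) and reused for both arc directions, instead of being recomputed on doubled, stacked tensors. A minimal sketch of that reuse in a plain log semiring; this is my own illustration with stand-in shapes, not torch-struct code:

import torch

def log_dot(a, b):
    # log-semiring inner product over the last dimension
    return torch.logsumexp(a + b, dim=-1)

spans, k = 5, 3
ACR = torch.randn(spans, k)    # right-facing complete items
BCL = torch.randn(spans, k)    # left-facing complete items
score_a = torch.randn(spans)   # arc scores for one direction
score_b = torch.randn(spans)   # arc scores for the other

x = log_dot(ACR, BCL)          # shared inner product, computed once
arcs_l = x + score_a           # incomplete items, one arc direction
arcs_r = x + score_b           # incomplete items, the other direction
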
36 changes: 28 additions & 8 deletions torch_struct/linearchain.py
@@ -36,9 +36,10 @@ def _check_potentials(self, edge, lengths=None):

         if lengths is None:
             lengths = torch.LongTensor([N] * batch)
-
-        assert max(lengths) <= N, "Length longer than edge scores"
-        assert max(lengths) == N, "One length must be at least N"
+            # pass
+        else:
+            assert max(lengths) <= N, "Length longer than edge scores"
+            assert max(lengths) == N, "One length must be at least N"
         assert C == C2, "Transition shape doesn't match"
         return edge, batch, N, C, lengths

@@ -49,22 +50,41 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
         "Compute forward pass by linear scan"
         # Setup
         semiring = self.semiring
+        ssize = semiring.size()
         log_potentials, batch, N, C, lengths = self._check_potentials(
             log_potentials, lengths
         )
         log_N, bin_N = self._bin_length(N - 1)
         chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)
 
         # Init
-        for b in range(lengths.shape[0]):
-            end = lengths[b] - 1
-            semiring.one_(chart[:, b, end:].diagonal(0, 2, 3))
-            chart[:, b, :end] = log_potentials[:, b, :end]
+        semiring.one_(chart[:, :, :].diagonal(0, 3, 4))
+
+        # Length mask
+        big = torch.zeros(
+            ssize,
+            batch,
+            bin_N,
+            C,
+            C,
+            dtype=log_potentials.dtype,
+            device=log_potentials.device,
+        )
+        big[:, :, : N - 1] = log_potentials
+        c = chart[:, :, :].view(ssize, batch * bin_N, C, C)
+        lp = big[:, :, :].view(ssize, batch * bin_N, C, C)
+        mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
+        mask = mask >= (lengths - 1).view(batch, 1)
+        mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
+        semiring.zero_mask_(lp.data, mask)
+        semiring.zero_mask_(c.data, (~mask))
+
+        c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))
 
         # Scan
         for n in range(1, log_N + 1):
             chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
-        v = semiring.sum(semiring.sum(chart[:, :, 0]))
+        v = semiring.sum(semiring.sum(chart[:, :, 0].contiguous()))
         return v, [log_potentials], None
 
     # def _dp_standard(self, edge, lengths=None, force_grad=False):
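
Note on the new init: the per-batch Python loop is replaced by a single vectorized masking pass. Chart positions at or past each sequence's end are overwritten with semiring identities, so every batch member rides through the same log-time binary scan unchanged. A standalone sketch of the idea in the plain log semiring; my own illustration, with stand-in names and shapes rather than the library's chart:

import torch

batch, bin_N, C = 2, 4, 3          # bin_N: N - 1 padded to a power of two
lengths = torch.tensor([5, 3])
neg_inf = -1e5                     # stands in for the additive identity

# Identity matrix in log space: 0 on the diagonal, "zero" elsewhere.
eye = torch.full((C, C), neg_inf)
eye.fill_diagonal_(0.0)

# Steps at or past (length - 1) are padding; they must be no-ops under the
# matrix product, so they hold the multiplicative identity.
mask = torch.arange(bin_N).view(1, bin_N) >= (lengths - 1).view(batch, 1)
chart = torch.randn(batch, bin_N, C, C)   # real potentials would go here
chart[mask] = eye

# Log-time scan: repeatedly combine adjacent pairs with a log-space matmul.
while chart.shape[1] > 1:
    a, b = chart[:, 1::2], chart[:, 0::2]
    chart = torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)
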
41 changes: 33 additions & 8 deletions torch_struct/semimarkov.py
@@ -23,6 +23,7 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):

         # Setup
         semiring = self.semiring
+        ssize = semiring.size()
         log_potentials.requires_grad_(True)
         log_potentials, batch, N, K, C, lengths = self._check_potentials(
             log_potentials, lengths
@@ -33,12 +34,36 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
         )
 
         # Init.
-        for b in range(lengths.shape[0]):
-            end = lengths[b] - 1
-            semiring.one_(init[:, b, end:, 0, 0].diagonal(0, 2, 3))
-            init[:, b, :end, : (K - 1), 0] = log_potentials[:, b, :end, 1:K]
-            for k in range(1, K - 1):
-                semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3))
+        semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))
+
+        # Length mask
+        big = torch.zeros(
+            ssize,
+            batch,
+            bin_N,
+            K,
+            C,
+            C,
+            dtype=log_potentials.dtype,
+            device=log_potentials.device,
+        )
+        big[:, :, : N - 1] = log_potentials
+        c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
+        lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
+        mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
+        mask = mask >= (lengths - 1).view(batch, 1)
+        mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
+        semiring.zero_mask_(lp.data, mask)
+        semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
+        c[:, :, : K - 1, 0] = semiring.sum(
+            torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
+        )
+        end = torch.min(lengths) - 1
+        for k in range(1, K - 1):
+            semiring.one_(
+                init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1)
+            )
+
         K_1 = K - 1
 
         # Order n, n-1
@@ -51,8 +76,8 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
         for n in range(1, log_N + 1):
             chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
 
-        final = chart.view(-1, batch, 1, K_1, C, K_1, C)
-        v = semiring.sum(semiring.sum(final[:, :, 0, 0, :, 0, :]))
+        final = chart.view(-1, batch, K_1, C, K_1, C)
+        v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
         return v, [log_potentials], None
 
     # def _dp_standard(self, edge, lengths=None, force_grad=False):
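
Note on the semi-Markov init: as I read the layout above, a segment of length k is simulated by k - 1 countdown steps, so the semi-Markov DP reduces to the same first-order binary scan over a (K-1) x (K-1) grid of C x C blocks: segment potentials enter block (k-1, 0), and the shifted block diagonals (k-1, k) hold multiplicative identities. A rough dense sketch of one step matrix under that reading, in plain log-space tensors rather than the library's Chart:

import torch

K, C = 4, 2                   # segment lengths up to K - 1, C labels
neg_inf = -1e5
phi = torch.randn(K, C, C)    # phi[k]: log-potential of a length-k segment

# One DP step as a (K-1) x (K-1) grid of C x C blocks, additive-zero by default.
T = torch.full((K - 1, K - 1, C, C), neg_inf)
for k in range(1, K):
    T[k - 1, 0] = phi[k]                # start a segment of length k
for k in range(1, K - 1):
    T[k - 1, k].fill_diagonal_(0.0)     # countdown: identity in log space
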
27 changes: 23 additions & 4 deletions torch_struct/semirings/semirings.py
@@ -1,5 +1,11 @@
 import torch
-import genbmm
+
+has_genbmm = False
+try:
+    import genbmm
+    has_genbmm = True
+except ImportError:
+    pass
 
 
 def matmul(cls, a, b):
@@ -63,6 +69,11 @@ def zero_(xs):
         "Fill *ssize x ...* tensor with additive identity."
         raise NotImplementedError()
 
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity."
+        xs.masked_fill_(mask.unsqueeze(0), cls.zero)
+
     @staticmethod
     def one_(xs):
         "Fill *ssize x ...* tensor with multiplicative identity."
@@ -143,7 +154,7 @@ def matmul(cls, a, b, dims=1):
         (Faster than calling sum and times.)
         """
 
-        if isinstance(a, genbmm.BandedMatrix):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
             return b.multiply(a.transpose())
         else:
             return torch.matmul(a, b)
@@ -158,7 +169,7 @@ class LogSemiring(_BaseLog):

     @classmethod
     def matmul(cls, a, b):
-        if isinstance(a, genbmm.BandedMatrix):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
             return b.multiply_log(a.transpose())
         else:
             return _BaseLog.matmul(a, b)
@@ -173,7 +184,7 @@ class MaxSemiring(_BaseLog):

     @classmethod
     def matmul(cls, a, b):
-        if isinstance(a, genbmm.BandedMatrix):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
            return b.multiply_max(a.transpose())
         else:
             return matmul(cls, a, b)
@@ -269,6 +280,8 @@ class EntropySemiring(Semiring):
     * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
     """
 
+    zero = 0
+
     @staticmethod
     def size():
         return 2
@@ -301,6 +314,12 @@ def mul(a, b):
     def prod(cls, xs, dim=-1):
         return xs.sum(dim)
 
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity."
+        xs[0].masked_fill_(mask, -1e5)
+        xs[1].masked_fill_(mask, 0)
+
     @staticmethod
     def zero_(xs):
         xs[0].fill_(-1e5)
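
Note on the new hook: zero_mask_ is what the length-masking passes in linearchain.py and semimarkov.py call. It overwrites masked entries in place with the semiring's additive identity; EntropySemiring overrides it because its two components have different identities. A small usage sketch, assuming LogSemiring is importable from the package root as usual:

import torch
from torch_struct import LogSemiring

x = torch.randn(1, 2, 3)                   # ssize x batch x N
mask = torch.tensor([[True, False, False],
                     [False, True, False]])
LogSemiring.zero_mask_(x, mask)            # masked entries -> additive identity
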
7 changes: 5 additions & 2 deletions torch_struct/test_algorithms.py
@@ -142,14 +142,17 @@ def test_cky(data):
 @settings(max_examples=50, deadline=None)
 def test_generic_a(data):
     model = data.draw(
-        sampled_from([Alignment])
-        # , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
+        sampled_from(
+            [SemiMarkov]
+        )  # , Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
     )
 
     semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
     struct = model(semiring)
     vals, (batch, N) = model._rand()
     alpha = struct.sum(vals)
     count = struct.enumerate(vals)[0]
+
+    # assert(False)
     assert alpha.shape[0] == batch
     assert count.shape[0] == batch
     assert alpha.shape == count.shape
