Speed ups for linear chain and semi-markov #42

Merged
41 commits merged on Nov 26, 2019
63 changes: 34 additions & 29 deletions torch_struct/deptree.py
@@ -51,43 +51,48 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
         arc_scores = _convert(arc_scores_in)
         arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)
         arc_scores.requires_grad_(True)
-        DIRS = 2
         alpha = [
-            [Chart((batch, DIRS, N, N), arc_scores, semiring) for _ in range(2)]
+            [
+                [Chart((batch, N, N), arc_scores, semiring) for _ in range(2)]
+                for _ in range(2)
+            ]
             for _ in range(2)
         ]
-        semiring.one_(alpha[A][C].data[:, :, :, :, 0].data)
-        semiring.one_(alpha[B][C].data[:, :, :, :, -1].data)
-
-        def stack(a, b=None):
-            if b is None:
-                return torch.stack([a, a], dim=2)
-            else:
-                return torch.stack([a, b], dim=2)
+        semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
+        semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
+        semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
+        semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
 
         for k in range(1, N):
             f = torch.arange(N - k), torch.arange(k, N)
-            AC = alpha[A][C][:, : N - k, :k]
-            ACL, ACR = AC.unbind(2)
-            BC = alpha[B][C][:, k:, N - k :]
-            BCL, BCR = BC.unbind(2)
-            arcs = semiring.dot(
-                semiring.times(stack(ACR), stack(BCL)),
-                stack(
-                    arc_scores[:, :, f[1], f[0]], arc_scores[:, :, f[0], f[1]]
-                ).unsqueeze(-1),
-            )
-            alpha[A][I][:, : N - k, k] = arcs
-            alpha[B][I][:, k:N, N - k - 1] = arcs
-            AIR = alpha[A][I][R, : N - k, 1 : k + 1]
-            BIL = alpha[B][I][L, k:, N - k - 1 : N - 1]
-            new = semiring.dot(stack(ACL, AIR), stack(BIL, BCR))
-            alpha[A][C][:, : N - k, k] = new
-            alpha[B][C][:, k:N, N - k - 1] = new
+            ACL = alpha[A][C][L][: N - k, :k]
+            ACR = alpha[A][C][R][: N - k, :k]
+            BCL = alpha[B][C][L][k:, N - k :]
+            BCR = alpha[B][C][R][k:, N - k :]
+            x = semiring.dot(ACR, BCL)
+
+            arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])
+            alpha[A][I][L][: N - k, k] = arcs_l
+            alpha[B][I][L][k:N, N - k - 1] = arcs_l
+
+            arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
+            alpha[A][I][R][: N - k, k] = arcs_r
+            alpha[B][I][R][k:N, N - k - 1] = arcs_r
+
+            AIR = alpha[A][I][R][: N - k, 1 : k + 1]
+            BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]
+            new = semiring.dot(ACL, BIL)
+            alpha[A][C][L][: N - k, k] = new
+            alpha[B][C][L][k:N, N - k - 1] = new
+
+            new = semiring.dot(AIR, BCR)
+            alpha[A][C][R][: N - k, k] = new
+            alpha[B][C][R][k:N, N - k - 1] = new
 
-        final = alpha[A][C][R, 0]
+        final = alpha[A][C][R][(0,)]
         v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
         return v, [arc_scores], alpha
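A note on the rewrite above: the old inner loop materialized a direction axis with torch.stack/unbind on every iteration of the O(N) span loop; the new version keeps separate left/right charts and indexes them directly, so each update is a plain semiring.dot plus semiring.times. A minimal sketch of what those two primitives compute in the log semiring (assuming the usual logsumexp definitions; the helper names here are illustrative, not the library's code):

import torch

def log_times(a, b):
    # Semiring multiplication in log space is addition.
    return a + b

def log_dot(a, b):
    # Multiply along the last dimension, then semiring-sum it out:
    # a logsumexp over split points.
    return torch.logsumexp(a + b, dim=-1)

# x = semiring.dot(ACR, BCL) combines a right-facing left span with a
# left-facing right span over every split point (the last dimension).
ACR = torch.randn(4, 6, 7)  # hypothetical (batch, span, split) scores
BCL = torch.randn(4, 6, 7)
x = log_dot(ACR, BCL)                     # (batch, span)
arcs_l = log_times(x, torch.randn(4, 6))  # then times the arc score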

36 changes: 28 additions & 8 deletions torch_struct/linearchain.py
@@ -36,9 +36,10 @@ def _check_potentials(self, edge, lengths=None):

         if lengths is None:
             lengths = torch.LongTensor([N] * batch)
-
-        assert max(lengths) <= N, "Length longer than edge scores"
-        assert max(lengths) == N, "One length must be at least N"
+            # pass
+        else:
+            assert max(lengths) <= N, "Length longer than edge scores"
+            assert max(lengths) == N, "One length must be at least N"
         assert C == C2, "Transition shape doesn't match"
         return edge, batch, N, C, lengths
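Net effect of the reshuffled checks: defaulted lengths are valid by construction, while caller-supplied lengths must fit inside the potentials and at least one sequence must span the full N. A quick sketch of the contract, with hypothetical values:

import torch

N, batch = 5, 2
lengths = torch.LongTensor([5, 3])  # OK: all lengths <= N and max(lengths) == N
# torch.LongTensor([6, 3]) would trip "Length longer than edge scores"
# torch.LongTensor([3, 3]) would trip "One length must be at least N"
assert max(lengths) <= N and max(lengths) == N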

@@ -49,22 +49,41 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
         "Compute forward pass by linear scan"
         # Setup
         semiring = self.semiring
+        ssize = semiring.size()
         log_potentials, batch, N, C, lengths = self._check_potentials(
             log_potentials, lengths
         )
         log_N, bin_N = self._bin_length(N - 1)
         chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)
 
         # Init
-        for b in range(lengths.shape[0]):
-            end = lengths[b] - 1
-            semiring.one_(chart[:, b, end:].diagonal(0, 2, 3))
-            chart[:, b, :end] = log_potentials[:, b, :end]
+        semiring.one_(chart[:, :, :].diagonal(0, 3, 4))
+
+        # Length mask
+        big = torch.zeros(
+            ssize,
+            batch,
+            bin_N,
+            C,
+            C,
+            dtype=log_potentials.dtype,
+            device=log_potentials.device,
+        )
+        big[:, :, : N - 1] = log_potentials
+        c = chart[:, :, :].view(ssize, batch * bin_N, C, C)
+        lp = big[:, :, :].view(ssize, batch * bin_N, C, C)
+        mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
+        mask = mask >= (lengths - 1).view(batch, 1)
+        mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
+        semiring.zero_mask_(lp.data, mask)
+        semiring.zero_mask_(c.data, (~mask))
+
+        c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))
 
         # Scan
         for n in range(1, log_N + 1):
             chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
-        v = semiring.sum(semiring.sum(chart[:, :, 0]))
+        v = semiring.sum(semiring.sum(chart[:, :, 0].contiguous()))
         return v, [log_potentials], None

# def _dp_standard(self, edge, lengths=None, force_grad=False):
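The scan itself is a parallel prefix product: pad the N - 1 transition matrices out to a power-of-two length, make every padded or beyond-length slot the semiring identity (which is what the length mask above arranges per batch element), then repeatedly multiply adjacent pairs so the chart halves at each step, giving O(log N) sequential matmuls instead of an O(N) loop. A self-contained sketch in the ordinary (+, *) semiring, assuming position i's matrix maps states at step i to step i + 1 (a hypothetical helper, not the library's API):

import torch

def scan_matmul(mats):
    # mats: (T, C, C); returns mats[T-1] @ ... @ mats[0].
    T, C, _ = mats.shape
    bin_T = 1
    while bin_T < T:
        bin_T *= 2
    # Pad with identities so the extra slots are no-ops under matmul.
    pad = torch.eye(C).expand(bin_T - T, C, C)
    chart = torch.cat([mats, pad], dim=0)
    while chart.shape[0] > 1:
        # Combine neighbors: new[i] = chart[2i + 1] @ chart[2i].
        chart = torch.matmul(chart[1::2], chart[0::2])
    return chart[0]

mats = torch.randn(5, 3, 3)
expected = mats[4] @ mats[3] @ mats[2] @ mats[1] @ mats[0]
assert torch.allclose(scan_matmul(mats), expected, atol=1e-5)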
41 changes: 33 additions & 8 deletions torch_struct/semimarkov.py
@@ -23,6 +23,7 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):

         # Setup
         semiring = self.semiring
+        ssize = semiring.size()
         log_potentials.requires_grad_(True)
         log_potentials, batch, N, K, C, lengths = self._check_potentials(
             log_potentials, lengths
@@ -33,12 +33,36 @@
         )
 
         # Init.
-        for b in range(lengths.shape[0]):
-            end = lengths[b] - 1
-            semiring.one_(init[:, b, end:, 0, 0].diagonal(0, 2, 3))
-            init[:, b, :end, : (K - 1), 0] = log_potentials[:, b, :end, 1:K]
-            for k in range(1, K - 1):
-                semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3))
+        semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))
+
+        # Length mask
+        big = torch.zeros(
+            ssize,
+            batch,
+            bin_N,
+            K,
+            C,
+            C,
+            dtype=log_potentials.dtype,
+            device=log_potentials.device,
+        )
+        big[:, :, : N - 1] = log_potentials
+        c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
+        lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
+        mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
+        mask = mask >= (lengths - 1).view(batch, 1)
+        mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
+        semiring.zero_mask_(lp.data, mask)
+        semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
+        c[:, :, : K - 1, 0] = semiring.sum(
+            torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
+        )
+        end = torch.min(lengths) - 1
+        for k in range(1, K - 1):
+            semiring.one_(
+                init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1)
+            )
 
         K_1 = K - 1
 
         # Order n, n-1
@@ -51,8 +51,8 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
         for n in range(1, log_N + 1):
             chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
 
-        final = chart.view(-1, batch, 1, K_1, C, K_1, C)
-        v = semiring.sum(semiring.sum(final[:, :, 0, 0, :, 0, :]))
+        final = chart.view(-1, batch, K_1, C, K_1, C)
+        v = semiring.sum(semiring.sum(final[:, :, 0, :, 0, :].contiguous()))
         return v, [log_potentials], None

# def _dp_standard(self, edge, lengths=None, force_grad=False):
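Why the identical scan applies here: a semi-Markov step is one transition over an augmented state (segment slot, label), so a chart entry is a (K-1, C, K-1, C) block that matmul treats as a square matrix over (K-1) * C flat states. A toy sketch of the flattening and the final read-off (shapes illustrative, not the library's internal layout):

import torch

K_1, C = 3, 4             # K - 1 segment slots, C labels
S = K_1 * C               # flattened augmented state count

# Two adjacent chart blocks compose by ordinary matmul over the flat state.
left = torch.randn(K_1, C, K_1, C).reshape(S, S)
right = torch.randn(K_1, C, K_1, C).reshape(S, S)
combined = right @ left

# After the scan the root value is read off at slot 0 on both sides,
# mirroring final[:, :, 0, :, 0, :] above.
final = combined.reshape(K_1, C, K_1, C)
root = final[0, :, 0, :]  # (C, C) table over start/end labels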
27 changes: 23 additions & 4 deletions torch_struct/semirings/semirings.py
@@ -1,5 +1,11 @@
 import torch
-import genbmm
+
+has_genbmm = False
+try:
+    import genbmm
+    has_genbmm = True
+except ImportError:
+    pass
 
 
 def matmul(cls, a, b):
@@ -63,6 +69,11 @@ def zero_(xs):
"Fill *ssize x ...* tensor with additive identity."
raise NotImplementedError()

@classmethod
def zero_mask_(cls, xs, mask):
"Fill *ssize x ...* tensor with additive identity."
xs.masked_fill_(mask.unsqueeze(0), cls.zero)

@staticmethod
def one_(xs):
"Fill *ssize x ...* tensor with multiplicative identity."
@@ -143,7 +154,7 @@ def matmul(cls, a, b, dims=1):
         (Faster than calling sum and times.)
         """
 
-        if isinstance(a, genbmm.BandedMatrix):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
             return b.multiply(a.transpose())
         else:
             return torch.matmul(a, b)
@@ -158,7 +169,7 @@ class LogSemiring(_BaseLog):

     @classmethod
     def matmul(cls, a, b):
-        if isinstance(a, genbmm.BandedMatrix):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
             return b.multiply_log(a.transpose())
         else:
             return _BaseLog.matmul(a, b)
@@ -173,7 +184,7 @@ class MaxSemiring(_BaseLog):

     @classmethod
     def matmul(cls, a, b):
-        if isinstance(a, genbmm.BandedMatrix):
+        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
             return b.multiply_max(a.transpose())
         else:
             return matmul(cls, a, b)
@@ -269,6 +280,8 @@ class EntropySemiring(Semiring):
     * First- and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first`
     """
 
+    zero = 0
+
     @staticmethod
     def size():
         return 2
@@ -301,6 +314,12 @@ def mul(a, b):
     def prod(cls, xs, dim=-1):
         return xs.sum(dim)
 
+    @classmethod
+    def zero_mask_(cls, xs, mask):
+        "Fill *ssize x ...* tensor with additive identity at masked positions."
+        xs[0].masked_fill_(mask, -1e5)
+        xs[1].masked_fill_(mask, 0)
+
     @staticmethod
     def zero_(xs):
         xs[0].fill_(-1e5)
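Why the pair of fills: EntropySemiring operates on pairs (log score, accumulated entropy term), so its additive identity is (log 0, 0); an impossible item gets softmax weight of roughly zero and therefore contributes nothing to either component, with -1e5 standing in for log 0. A sketch of the kind of pairwise sum this identity must be neutral for (in the spirit of the first-order expectation semiring cited above, not a verbatim copy of the class):

import torch

def entropy_sum(xs, dim=-1):
    # xs[0]: log scores, xs[1]: entropy accumulators.
    part = torch.logsumexp(xs[0], dim=dim)
    log_sm = xs[0] - part.unsqueeze(dim)
    sm = log_sm.exp()
    return torch.stack((part, torch.sum(xs[1] * sm - log_sm * sm, dim=dim)))

xs = torch.stack((torch.tensor([[0.5, -1e5]]),  # second entry is the "zero"
                  torch.tensor([[0.3, 0.0]])))
print(entropy_sum(xs))  # ~(0.5, 0.3): the masked entry contributes nothing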
7 changes: 5 additions & 2 deletions torch_struct/test_algorithms.py
@@ -142,14 +142,17 @@ def test_cky(data):
 @settings(max_examples=50, deadline=None)
 def test_generic_a(data):
     model = data.draw(
-        sampled_from([Alignment])  # , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
+        sampled_from(
+            [SemiMarkov]
+        )  # , Alignment, LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree])
     )
 
     semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
     struct = model(semiring)
     vals, (batch, N) = model._rand()
     alpha = struct.sum(vals)
     count = struct.enumerate(vals)[0]
 
+    # assert(False)
     assert alpha.shape[0] == batch
     assert count.shape[0] == batch
     assert alpha.shape == count.shape
Expand Down