Remove imperative filling functions _ (#105)
srush committed Sep 12, 2021
1 parent 84ee7cd commit e51fecc
Showing 10 changed files with 98 additions and 141 deletions.
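This commit replaces the imperative, in-place semiring helpers (semiring.one_, semiring.zero_, semiring.zero_mask_) with a single functional call, semiring.fill(tensor, mask, value), used throughout the hunks below. As a rough sketch of the pattern, where the fill helper and the log-semiring constants are stand-ins rather than the actual torch_struct implementation:

    import torch

    # Hypothetical stand-in for the functional fill used in this commit:
    # wherever mask is True, overwrite tensor with value (broadcast as needed).
    def fill(tensor, mask, value):
        return torch.where(mask, value.to(tensor.dtype), tensor)

    log_zero = torch.tensor(-1e5)  # stand-in for semiring.zero
    log_one = torch.tensor(0.0)    # stand-in for semiring.one

    # Example: build a chart of semiring zeros and set its diagonal to the
    # semiring one, mirroring the initialization in linearchain.py below.
    chart = torch.full((1, 2, 4, 4), log_zero.item())
    init = torch.zeros_like(chart).bool()
    init.diagonal(0, -2, -1).fill_(True)
    chart = fill(chart, init, log_one)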
20 changes: 18 additions & 2 deletions tests/extensions.py
@@ -26,7 +26,17 @@ def enumerate(semiring, edge, lengths=None):
semiring = semiring
ssize = semiring.size()
edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
chains = [
[
(
[c],
semiring.fill(
torch.zeros(ssize, batch), torch.tensor(True), semiring.one
),
)
for c in range(C)
]
]

enum_lengths = torch.LongTensor(lengths.shape)
for n in range(1, N):
@@ -128,7 +138,13 @@ def enumerate(semiring, edge):
edge = semiring.convert(edge)
chains = {}
chains[0] = [
([(c, 0)], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)
(
[(c, 0)],
semiring.fill(
torch.zeros(ssize, batch), torch.tensor(True), semiring.one
),
)
for c in range(C)
]

for n in range(1, N + 1):
2 changes: 1 addition & 1 deletion tests/test_algorithms.py
@@ -263,7 +263,7 @@ def test_generic_lengths(model_test, data):
part = model().sum(vals, lengths=lengths)

# Check that max is correct
assert (maxes <= part).all()
assert (maxes <= part + 1e-3).all()
m_part = model(MaxSemiring).sum(vals, lengths=lengths)
assert (torch.isclose(maxes, m_part)).all(), maxes - m_part

6 changes: 4 additions & 2 deletions torch_struct/autoregressive.py
@@ -118,8 +118,10 @@ def log_prob(self, value, sparse=False):
return wrap(scores, sample)

def _beam_search(self, semiring, gumbel=False):
beam = semiring.one_(
torch.zeros((semiring.size(),) + self.batch_shape, device=self.device)
beam = semiring.fill(
torch.zeros((semiring.size(),) + self.batch_shape, device=self.device),
torch.tensor(True),
semiring.one,
)
ssize = semiring.size()

29 changes: 22 additions & 7 deletions torch_struct/deptree.py
@@ -66,10 +66,22 @@ def logpartition(self, arc_scores_in, lengths=None, force_grad=False):
]
for _ in range(2)
]
semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)
mask = torch.zeros(alpha[A][C][L].data.shape).bool()
mask[:, :, :, 0].fill_(True)
alpha[A][C][L].data[:] = semiring.fill(
alpha[A][C][L].data[:], mask, semiring.one
)
alpha[A][C][R].data[:] = semiring.fill(
alpha[A][C][R].data[:], mask, semiring.one
)
mask = torch.zeros(alpha[B][C][L].data[:].shape).bool()
mask[:, :, :, -1].fill_(True)
alpha[B][C][L].data[:] = semiring.fill(
alpha[B][C][L].data[:], mask, semiring.one
)
alpha[B][C][R].data[:] = semiring.fill(
alpha[B][C][R].data[:], mask, semiring.one
)

if multiroot:
start_idx = 0
@@ -119,10 +131,13 @@ def _check_potentials(self, arc_scores, lengths=None):
lengths = torch.LongTensor([N - 1] * batch).to(arc_scores.device)
assert max(lengths) <= N, "Length longer than N"
arc_scores = semiring.convert(arc_scores)
for b in range(batch):
semiring.zero_(arc_scores[:, b, lengths[b] + 1 :, :])
semiring.zero_(arc_scores[:, b, :, lengths[b] + 1 :])

# Set the extra elements of the log-potentials to zero.
keep = torch.ones_like(arc_scores).bool()
for b in range(batch):
keep[:, b, lengths[b] + 1 :, :].fill_(0.0)
keep[:, b, :, lengths[b] + 1 :].fill_(0.0)
arc_scores = semiring.fill(arc_scores, ~keep, semiring.zero)
return arc_scores, batch, N, lengths

def _arrange_marginals(self, grads):
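The deptree change above swaps per-batch in-place zeroing for a single functional fill over a boolean "keep" mask. A minimal sketch of the same length-masking idea, with fill and the log-semiring zero as stand-ins rather than torch_struct API:

    import torch

    def fill(tensor, mask, value):
        # Stand-in for semiring.fill: write value where mask is True.
        return torch.where(mask, value.to(tensor.dtype), tensor)

    log_zero = torch.tensor(-1e5)        # stand-in for semiring.zero
    ssize, batch, N = 1, 2, 5
    arc_scores = torch.randn(ssize, batch, N, N)
    lengths = torch.tensor([4, 3])

    # Keep only the rows/columns inside each sentence; zero out the padding.
    keep = torch.ones_like(arc_scores).bool()
    for b in range(batch):
        keep[:, b, lengths[b] + 1:, :] = False
        keep[:, b, :, lengths[b] + 1:] = False
    arc_scores = fill(arc_scores, ~keep, log_zero)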
1 change: 1 addition & 0 deletions torch_struct/distributions.py
@@ -36,6 +36,7 @@ class StructDistribution(Distribution):
log_potentials (tensor, batch_shape x event_shape) : log-potentials :math:`\phi`
lengths (long tensor, batch_shape) : integers for length masking
"""
validate_args = False

def __init__(self, log_potentials, lengths=None, args={}):
batch_shape = log_potentials.shape[:1]
34 changes: 17 additions & 17 deletions torch_struct/helpers.py
@@ -5,13 +5,14 @@

class Chart:
def __init__(self, size, potentials, semiring):
self.data = semiring.zero_(
torch.zeros(
*((semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
c = torch.zeros(
*((semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
c[:] = semiring.zero.view((semiring.size(),) + len(size) * (1,))

self.data = c
self.grad = self.data.detach().clone().fill_(0.0)

def __getitem__(self, ind):
@@ -50,18 +51,17 @@ def _chart(self, size, potentials, force_grad):
return self._make_chart(1, size, potentials, force_grad)[0]

def _make_chart(self, N, size, potentials, force_grad=False):
return [
(
self.semiring.zero_(
torch.zeros(
*((self.semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
).requires_grad_(force_grad and not potentials.requires_grad)
chart = []
for _ in range(N):
c = torch.zeros(
*((self.semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
for _ in range(N)
]
c[:] = self.semiring.zero.view((self.semiring.size(),) + len(size) * (1,))
c.requires_grad_(force_grad and not potentials.requires_grad)
chart.append(c)
return chart

def sum(self, logpotentials, lengths=None, _raw=False):
"""
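In helpers.py, charts are now initialized by broadcasting the semiring's zero over the chart dimensions instead of calling semiring.zero_ in place. A small sketch of that broadcast, assuming semiring.zero is a 1-D tensor with one entry per semiring component (the two-component values below are hypothetical):

    import torch

    semiring_size = 2
    semiring_zero = torch.tensor([-1e5, 0.0])  # hypothetical per-component zeros
    size = (3, 4)                              # chart shape without the semiring dim

    c = torch.zeros(*((semiring_size,) + size))
    # View the zeros as (2, 1, 1) so each component broadcasts over the chart dims.
    c[:] = semiring_zero.view((semiring_size,) + len(size) * (1,))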
8 changes: 5 additions & 3 deletions torch_struct/linearchain.py
@@ -53,7 +53,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)

# Init
semiring.one_(chart[:, :, :].diagonal(0, 3, 4))
init = torch.zeros(*chart.shape).bool()
init.diagonal(0, 3, 4).fill_(True)
chart = semiring.fill(chart, init, semiring.one)

# Length mask
big = torch.zeros(
@@ -71,8 +73,8 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N).type_as(c)
mask = mask >= (lengths - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data, (~mask))
lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
c.data[:] = semiring.fill(c.data, ~mask, semiring.zero)

c[:] = semiring.sum(torch.stack([c.data, lp], dim=-1))

14 changes: 8 additions & 6 deletions torch_struct/semimarkov.py
@@ -34,7 +34,9 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
)

# Init.
semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))
mask = torch.zeros(*init.shape).bool()
mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

# Length mask
big = torch.zeros(
@@ -54,16 +56,16 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
mask = mask.to(log_potentials.device)
mask = mask >= (lengths - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
lp.data[:] = semiring.fill(lp.data, mask, semiring.zero)
c.data[:, :, :, 0] = semiring.fill(c.data[:, :, :, 0], (~mask), semiring.zero)
c[:, :, : K - 1, 0] = semiring.sum(
torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
)
end = torch.min(lengths) - 1
mask = torch.zeros(*init.shape).bool()
for k in range(1, K - 1):
semiring.one_(
init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1)
)
mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

K_1 = K - 1

1 change: 1 addition & 0 deletions torch_struct/semirings/checkpoint.py
@@ -4,6 +4,7 @@
try:
import genbmm
from genbmm import BandedMatrix

has_genbmm = True
except ImportError:
pass
