Skip to content

Commit

Permalink
Merge 2f72847 into 5328ec5
Browse files Browse the repository at this point in the history
  • Loading branch information
da03 committed Oct 14, 2021
2 parents 5328ec5 + 2f72847 commit 4f8a4b3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
20 changes: 20 additions & 0 deletions tests/test_algorithms.py
Expand Up @@ -522,3 +522,23 @@ def test_hsmm(model_test, semiring):

assert torch.isclose(partition1, partition2).all()
assert torch.isclose(partition2, partition3).all()


@given(data())
@pytest.mark.parametrize("model_test", ["SemiMarkov"])
@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
def test_batching_lengths(model_test, semiring, data):
"Test batching"
gen = Gen(model_test, data, LogSemiring)
model, vals, N, batch = gen.model, gen.vals, gen.N, gen.batch
lengths = torch.tensor(
[data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
)
# first way: batched implementation
partition = model(semiring).logpartition(vals, lengths=lengths)[0][0]
# second way: unbatched implementation
for b in range(batch):
vals_b = vals[b:(b + 1), :(lengths[b] - 1)]
lengths_b = lengths[b:(b + 1)]
partition_b = model(semiring).logpartition(vals_b, lengths=lengths_b)[0][0]
assert torch.isclose(partition[b], partition_b).all()
9 changes: 5 additions & 4 deletions torch_struct/semimarkov.py
Expand Up @@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
)

# Init.
mask = torch.zeros(*init.shape).bool()
mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

Expand All @@ -61,10 +61,11 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
c[:, :, : K - 1, 0] = semiring.sum(
torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
)
end = torch.min(lengths) - 1
mask = torch.zeros(*init.shape).bool()
mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
for k in range(1, K - 1):
mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
for b in range(batch):
end = lengths[b] - 1
mask[:, b, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

K_1 = K - 1
Expand Down

0 comments on commit 4f8a4b3

Please sign in to comment.