fix semimarkov batching and add tests
da03 committed Oct 14, 2021
1 parent 5328ec5 commit ada7645
Showing 2 changed files with 26 additions and 4 deletions.
20 changes: 20 additions & 0 deletions tests/test_algorithms.py
@@ -522,3 +522,23 @@ def test_hsmm(model_test, semiring):
 
     assert torch.isclose(partition1, partition2).all()
     assert torch.isclose(partition2, partition3).all()
+
+
+@given(data())
+@pytest.mark.parametrize("model_test", ["SemiMarkov"])
+@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
+def test_batching_lengths(model_test, semiring, data):
+    "Test batching"
+    gen = Gen(model_test, data, LogSemiring)
+    model, vals, N, batch = gen.model, gen.vals, gen.N, gen.batch
+    lengths = torch.tensor(
+        [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
+    )
+    # first way: batched implementation
+    partition = model(semiring).logpartition(vals, lengths=lengths)[0][0]
+    # second way: unbatched implementation
+    for b in range(batch):
+        vals_b = vals[b : (b + 1), : (lengths[b] - 1)]
+        lengths_b = lengths[b : (b + 1)]
+        partition_b = model(semiring).logpartition(vals_b, lengths=lengths_b)[0][0]
+        assert torch.isclose(partition[b], partition_b).all()
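
The test draws a random length for each batch element (the last is pinned to N so at least one sequence fills the padded width), then checks that the batched log-partition agrees with computing each example on its own, trimmed to its length. The same invariant can be exercised outside the hypothesis harness; the sketch below is illustrative and not part of the commit, and the (batch, N - 1, K, C, C) potential shape and concrete sizes are assumptions inferred from how the test slices vals:

import torch
from torch_struct import SemiMarkov, LogSemiring

# Illustrative sketch (not from the commit): the potential shape and the
# sizes here are assumptions inferred from the test's slicing of `vals`.
batch, N, K, C = 2, 6, 2, 3
vals = torch.randn(batch, N - 1, K, C, C)
lengths = torch.tensor([4, N])  # per-example lengths; one spans the full width

# Batched partition over the padded inputs...
batched = SemiMarkov(LogSemiring).logpartition(vals, lengths=lengths)[0][0]

# ...must match scoring each example alone, trimmed to its own length.
for b in range(batch):
    single = SemiMarkov(LogSemiring).logpartition(
        vals[b : b + 1, : lengths[b] - 1], lengths=lengths[b : b + 1]
    )[0][0]
    assert torch.isclose(batched[b], single).all()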
10 changes: 6 additions & 4 deletions torch_struct/semimarkov.py
@@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         )
 
         # Init.
-        mask = torch.zeros(*init.shape).bool()
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
         mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
         init = semiring.fill(init, mask, semiring.one)
 
@@ -61,10 +61,12 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
         c[:, :, : K - 1, 0] = semiring.sum(
             torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
         )
-        end = torch.min(lengths) - 1
-        mask = torch.zeros(*init.shape).bool()
+        end = torch.max(lengths) - 1
+        mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
         for k in range(1, K - 1):
-            mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
+            for b in range(batch):
+                end = lengths[b] - 1
+                mask[:, b, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
         init = semiring.fill(init, mask, semiring.one)
 
         K_1 = K - 1
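
The substance of the fix is in the second hunk: the end-of-sequence mask was previously cut at torch.min(lengths) - 1, so every row in the batch was truncated to the shortest sequence and longer sequences lost their valid end positions; the rewritten loop fills each batch row up to its own length. (Both hunks also allocate the masks on log_potentials.device, so CUDA inputs no longer mix CPU masks with GPU tensors.) A standalone sketch of the min-versus-per-example masking difference, with illustrative shapes that are not the library's internal layout:

import torch

# Standalone sketch of the bug (2 rows x N positions; the library's real
# masks have more dimensions). Row b should be valid up to its own length.
lengths = torch.tensor([3, 5])
N = int(lengths.max())

old_mask = torch.zeros(2, N, dtype=torch.bool)
old_mask[:, : int(lengths.min()) - 1] = True  # old: one cutoff, the shortest length

new_mask = torch.zeros(2, N, dtype=torch.bool)
for b in range(len(lengths)):  # fixed: each row uses its own length
    new_mask[b, : int(lengths[b]) - 1] = True

print(old_mask)  # row 1 stops at position 1 although its length is 5
print(new_mask)  # row 1 is now valid through position 3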
