Merge 4705eca into 5328ec5
da03 committed Oct 14, 2021
2 parents 5328ec5 + 4705eca · commit ef6a196
Showing 2 changed files with 27 additions and 4 deletions.
20 changes: 20 additions & 0 deletions tests/test_algorithms.py
@@ -522,3 +522,23 @@ def test_hsmm(model_test, semiring):

assert torch.isclose(partition1, partition2).all()
assert torch.isclose(partition2, partition3).all()


@given(data())
@pytest.mark.parametrize("model_test", ["SemiMarkov"])
@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
def test_batching_lengths(model_test, semiring, data):
"Test batching"
gen = Gen(model_test, data, LogSemiring)
model, vals, N, batch = gen.model, gen.vals, gen.N, gen.batch
lengths = torch.tensor(
[data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
)
# first way: batched implementation
partition = model(semiring).logpartition(vals, lengths=lengths)[0][0]
# second way: unbatched implementation
for b in range(batch):
vals_b = vals[b:(b + 1), :(lengths[b] - 1)]
lengths_b = lengths[b:(b + 1)]
partition_b = model(semiring).logpartition(vals_b, lengths=lengths_b)[0][0]
assert torch.isclose(partition[b], partition_b).all()
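(Aside, not part of the diff.) The test above follows a common consistency pattern: compute the log-partition once over the padded batch with explicit lengths, then recompute it example by example and compare. A minimal sketch of the same pattern in plain PyTorch (made-up shapes, masked logsumexp instead of torch_struct):

import torch

torch.manual_seed(0)
batch, N = 3, 6
scores = torch.randn(batch, N)              # padded per-position scores
lengths = torch.tensor([4, 6, 5])           # true length of each example

# Batched: mask padded positions before reducing over time.
pad_mask = torch.arange(N).unsqueeze(0) < lengths.unsqueeze(1)
batched = torch.logsumexp(scores.masked_fill(~pad_mask, float("-inf")), dim=1)

# Unbatched: slice each example to its own length and reduce.
for b in range(batch):
    single = torch.logsumexp(scores[b, : lengths[b]], dim=0)
    assert torch.isclose(batched[b], single)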
11 changes: 7 additions & 4 deletions torch_struct/semimarkov.py
@@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
)

# Init.
- mask = torch.zeros(*init.shape).bool()
+ mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True)
init = semiring.fill(init, mask, semiring.one)

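(Aside, not part of the diff.) The hunk above pins the boolean mask to the device of log_potentials. A minimal sketch of the failure mode being fixed, with made-up shapes and names: torch.zeros allocates on the CPU by default, so the old mask could not be combined with CUDA potentials.

import torch

if torch.cuda.is_available():
    potentials = torch.randn(2, 4, device="cuda")
    cpu_mask = torch.zeros(2, 4).bool()                             # old: defaults to CPU
    cuda_mask = torch.zeros(2, 4, device=potentials.device).bool()  # new: follows the potentials
    filled = potentials.masked_fill(cuda_mask, 0.0)                 # works: same device
    # potentials.masked_fill(cpu_mask, 0.0)  # raises a device-mismatch RuntimeError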
@@ -61,10 +61,13 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False):
c[:, :, : K - 1, 0] = semiring.sum(
torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1)
)
- end = torch.min(lengths) - 1
- mask = torch.zeros(*init.shape).bool()
+ mask = torch.zeros(*init.shape, device=log_potentials.device).bool()
+ mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C)
+ mask_length = mask_length.to(log_potentials.device)
for k in range(1, K - 1):
-     mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True)
+     mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1)
+     mask_length_k = semiring.convert(mask_length_k)
+     mask[:, :, :, k - 1, k].diagonal(0, -2, -1).masked_fill_(mask_length_k, True)
init = semiring.fill(init, mask, semiring.one)

K_1 = K - 1
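(Aside, not part of the diff.) The second hunk replaces the single cutoff end = torch.min(lengths) - 1, which truncated every example at the shortest length in the batch, with a per-example mask built by comparing positions against each example's own length. A minimal sketch of the difference, with made-up shapes:

import torch

lengths = torch.tensor([5, 3, 4])                  # per-example lengths
batch, bin_N = lengths.shape[0], 6                 # padded time dimension

positions = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)

# Old style: one cutoff for the whole batch (shortest example wins).
end = torch.min(lengths) - 1
old_mask = positions < end

# New style: each row is masked against its own length.
new_mask = positions < (lengths - 1).view(batch, 1)

print(old_mask.int())   # every row cut at the same position
print(new_mask.int())   # rows differ per example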
