diff --git a/tests/test_algorithms.py b/tests/test_algorithms.py index 40354a6..79cfb5c 100644 --- a/tests/test_algorithms.py +++ b/tests/test_algorithms.py @@ -522,3 +522,23 @@ def test_hsmm(model_test, semiring): assert torch.isclose(partition1, partition2).all() assert torch.isclose(partition2, partition3).all() + + +@given(data()) +@pytest.mark.parametrize("model_test", ["SemiMarkov"]) +@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring]) +def test_batching_lengths(model_test, semiring, data): + "Test batching" + gen = Gen(model_test, data, LogSemiring) + model, vals, N, batch = gen.model, gen.vals, gen.N, gen.batch + lengths = torch.tensor( + [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N] + ) + # first way: batched implementation + partition = model(semiring).logpartition(vals, lengths=lengths)[0][0] + # second way: unbatched implementation + for b in range(batch): + vals_b = vals[b:(b + 1), :(lengths[b] - 1)] + lengths_b = lengths[b:(b + 1)] + partition_b = model(semiring).logpartition(vals_b, lengths=lengths_b)[0][0] + assert torch.isclose(partition[b], partition_b).all() diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index f614351..498f88d 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -34,7 +34,7 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False): ) # Init. - mask = torch.zeros(*init.shape).bool() + mask = torch.zeros(*init.shape, device=log_potentials.device).bool() mask[:, :, :, 0, 0].diagonal(0, -2, -1).fill_(True) init = semiring.fill(init, mask, semiring.one) @@ -61,10 +61,13 @@ def logpartition(self, log_potentials, lengths=None, force_grad=False): c[:, :, : K - 1, 0] = semiring.sum( torch.stack([c.data[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1) ) - end = torch.min(lengths) - 1 - mask = torch.zeros(*init.shape).bool() + mask = torch.zeros(*init.shape, device=log_potentials.device).bool() + mask_length = torch.arange(bin_N).view(1, bin_N, 1).expand(batch, bin_N, C) + mask_length = mask_length.to(log_potentials.device) for k in range(1, K - 1): - mask[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1).fill_(True) + mask_length_k = mask_length < (lengths - 1 - (k - 1)).view(batch, 1, 1) + mask_length_k = semiring.convert(mask_length_k) + mask[:, :, :, k - 1, k].diagonal(0, -2, -1).masked_fill_(mask_length_k, True) init = semiring.fill(init, mask, semiring.one) K_1 = K - 1