diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 65220313..086659d5 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -79,7 +79,6 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False): c.data[:, (~mask).view(-1)] = semiring.zero c[:] = semiring.sum(torch.stack([c, lp], dim=-1)) - # Scan for n in range(1, log_N + 1): chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2]) diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index e6e18b4e..1fef2dce 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -35,8 +35,6 @@ def _dp(self, log_potentials, lengths=None, force_grad=False): # Init. semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1)) - # for k in range(1, K - 1): - # semiring.one_(init[:, :, : , k - 1, k].diagonal(0, -2, -1)) # Length mask big = torch.zeros( @@ -50,33 +48,19 @@ def _dp(self, log_potentials, lengths=None, force_grad=False): device=log_potentials.device, ) big[:, :, : N - 1] = log_potentials - c = init[:, :, :].view(ssize, batch * bin_N, K -1, K-1, C, C) + c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C) lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C) mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N) mask = mask >= (lengths - 1).view(batch, 1) lp.data[:, mask.view(-1)] = semiring.zero c.data[:, (~mask).view(-1)] = semiring.zero - c[:, :, :K-1, 0] = semiring.sum(torch.stack([c[:, :, :K-1, 0], - lp[:, :, 1:K]], dim=-1)) - - # for b in range(lengths.shape[0]): - # end = lengths[b] - 1 - # for k in range(1, K - 1): - # semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, -2, -1)) + c[:, :, : K - 1, 0] = semiring.sum( + torch.stack([c[:, :, : K - 1, 0], lp[:, :, 1:K]], dim=-1) + ) - # ks = torch.arange(1, K - 1) for k in range(1, K - 1): semiring.one_(init[:, :, : (k - 1), k - 1, k].diagonal(0, -2, -1)) - # init[:, :, :end, : (K - 1), 0] = log_potentials[:, :, :end, 1:K] - - # init[:, :, :N-1, : (K - 1), 0] = log_potentials[:, :, :N-1, 1:K] - # for b in range(lengths.shape[0]): - # end = lengths[b] - 1 - # semiring.one_(init[:, b, end:, 0, 0].diagonal(0, 2, 3)) - # init[:, b, :end, : (K - 1), 0] = log_potentials[:, b, :end, 1:K] - # for k in range(1, K - 1): - # semiring.one_(init[:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3)) K_1 = K - 1 # Order n, n-1 diff --git a/torch_struct/semirings/semirings.py b/torch_struct/semirings/semirings.py index c8419ed5..baf3ec82 100644 --- a/torch_struct/semirings/semirings.py +++ b/torch_struct/semirings/semirings.py @@ -275,6 +275,7 @@ class EntropySemiring(Semiring): * Parameter estimation for probabilistic finite-state transducers :cite:`eisner2002parameter` * First-and second-order expectation semirings with applications to minimum-risk training on translation forests :cite:`li2009first` """ + zero = 0 @staticmethod def size(): diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index 90580ade..82fe10cf 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -142,7 +142,8 @@ def test_cky(data): @settings(max_examples=50, deadline=None) def test_generic_a(data): model = data.draw( - sampled_from([SemiMarkov]) #Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree]) + sampled_from( + [SemiMarkov, Alignment , LinearChain, SemiMarkov, CKY, CKY_CRF, DepTree]) ) semiring = data.draw(sampled_from([LogSemiring, MaxSemiring])) struct = model(semiring)