diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py index 3f4a4942..eda04753 100644 --- a/torch_struct/semimarkov.py +++ b/torch_struct/semimarkov.py @@ -60,21 +60,15 @@ def _dp(self, log_potentials, lengths=None, force_grad=False): chart[0][:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3) ) + K_1 = K-1 # Scan def merge(x, size): - return semiring.sum( - semiring.sum( - semiring.times( - x[:, :, 0 : size * 2 : 2] - .transpose(-1, -2) - .transpose(-3, -4) - .view(ssize, batch, size, 1, K - 1, K - 1, 1, C, C), - x[:, :, 1 : size * 2 : 2].view( - ssize, batch, size, K - 1, 1, K - 1, C, 1, C - ), - ) - ).transpose(-1, 5) - ).transpose(-1, -2) + left = x[:, :, 0 : size * 2 : 2].permute(0, 1, 2, 4, 6, 3, 5) + right = x[:, :, 1 : size * 2 : 2].permute(0, 1, 2, 3, 5, 4, 6) + return semiring.dot( + left.view(ssize, batch, size, 1, K_1, 1, C, K_1* C), + right.view(ssize, batch, size, K_1, 1, C, 1, K_1* C), + ) size = bin_N for n in range(1, log_N + 1):