Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Oct 21, 2019
1 parent 55296e2 commit 46a43c3
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions torch_struct/semimarkov.py
Expand Up @@ -60,21 +60,15 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
chart[0][:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3)
)

K_1 = K-1
# Scan
def merge(x, size):
return semiring.sum(
semiring.sum(
semiring.times(
x[:, :, 0 : size * 2 : 2]
.transpose(-1, -2)
.transpose(-3, -4)
.view(ssize, batch, size, 1, K - 1, K - 1, 1, C, C),
x[:, :, 1 : size * 2 : 2].view(
ssize, batch, size, K - 1, 1, K - 1, C, 1, C
),
)
).transpose(-1, 5)
).transpose(-1, -2)
left = x[:, :, 0 : size * 2 : 2].permute(0, 1, 2, 4, 6, 3, 5)
right = x[:, :, 1 : size * 2 : 2].permute(0, 1, 2, 3, 5, 4, 6)
return semiring.dot(
left.view(ssize, batch, size, 1, K_1, 1, C, K_1* C),
right.view(ssize, batch, size, K_1, 1, C, 1, K_1* C),
)

size = bin_N
for n in range(1, log_N + 1):
Expand Down

0 comments on commit 46a43c3

Please sign in to comment.