Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Oct 30, 2019
1 parent 595bd5d commit 7c18ada
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions torch_struct/semimarkov.py
Expand Up @@ -31,9 +31,14 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
)
log_N = int(math.ceil(math.log(N - 1, 2)))
bin_N = int(math.pow(2, log_N))
chart = self._make_chart(
log_N + 1, (batch, bin_N, K - 1, K - 1, C, C), log_potentials, force_grad
)
chart = [
self._make_chart(
1, (batch, bin_N, K - 1, K - 1, C, C), log_potentials, force_grad
)[0]
if i == 0
else None
for i in range(log_N + 1)
]

# Init
for b in range(lengths.shape[0]):
Expand Down Expand Up @@ -67,7 +72,7 @@ def merge(x, size):
size = bin_N
for n in range(1, log_N + 1):
size = int(size / 2)
chart[n][:, :, :size] = merge(chart[n - 1], size)
chart[n] = merge(chart[n - 1], size)
v = semiring.sum(semiring.sum(chart[-1][:, :, 0, 0, 0, :, :]))
return v, [log_potentials], None

Expand Down

0 comments on commit 7c18ada

Please sign in to comment.