Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 26, 2019
1 parent 6cae984 commit 2d87ed0
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
# semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))

# Length mask
big = torch.zeros(
ssize,
batch,
bin_N,
K,
C,
C,
dtype=log_potentials.dtype,
device=log_potentials.device,
)
# big = torch.zeros(
# ssize,
# batch,
# bin_N,
# K,
# C,
# C,
# dtype=log_potentials.dtype,
# device=log_potentials.device,
# )
# big[:, :, : N - 1] = log_potentials
# c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
# lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
Expand All @@ -64,7 +64,7 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
# for k in range(1, K - 1):
# semiring.one_(init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1))

K_1 = K - 1
# K_1 = K - 1

# Order n, n-1
chart = (
Expand All @@ -73,8 +73,8 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
.view(-1, batch, bin_N, K_1 * C, K_1 * C)
)

for n in range(1, log_N + 1):
chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])
# for n in range(1, log_N + 1):
# chart = semiring.matmul(chart[:, :, 1::2], chart[:, :, 0::2])

final = chart.view(-1, batch, 1, K_1, C, K_1, C)
v = semiring.sum(semiring.sum(final[:, :, 0, 0, :, 0, :]))
Expand Down

0 comments on commit 2d87ed0

Please sign in to comment.