Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 26, 2019
1 parent fd6935f commit 008e0dc
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions torch_struct/semimarkov.py
Expand Up @@ -37,32 +37,32 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
semiring.one_(init.data[:, :, :, 0, 0].diagonal(0, -2, -1))

# Length mask
# big = torch.zeros(
# ssize,
# batch,
# bin_N,
# K,
# C,
# C,
# dtype=log_potentials.dtype,
# device=log_potentials.device,
# )
# big[:, :, : N - 1] = log_potentials
# c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
# lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
# mask = torch.arange(bin_N) \
# .view(1, bin_N).expand(batch, bin_N)
# mask = mask >= (lengths - 1).view(batch, 1)
# mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
# semiring.zero_mask_(lp.data, mask)
# semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
# c[:, :, : K - 1, 0] = semiring.sum(
# torch.stack([c.data[:, :, : K - 1, 0],
# lp[:, :, 1:K]], dim=-1)
# )
# end = torch.min(lengths) - 1
# for k in range(1, K - 1):
# semiring.one_(init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1))
big = torch.zeros(
ssize,
batch,
bin_N,
K,
C,
C,
dtype=log_potentials.dtype,
device=log_potentials.device,
)
big[:, :, : N - 1] = log_potentials
c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
mask = torch.arange(bin_N) \
.view(1, bin_N).expand(batch, bin_N)
mask = mask >= (lengths - 1).view(batch, 1)
mask = mask.view(batch * bin_N, 1, 1, 1).to(lp.device)
semiring.zero_mask_(lp.data, mask)
semiring.zero_mask_(c.data[:, :, :, 0], (~mask))
c[:, :, : K - 1, 0] = semiring.sum(
torch.stack([c.data[:, :, : K - 1, 0],
lp[:, :, 1:K]], dim=-1)
)
end = torch.min(lengths) - 1
for k in range(1, K - 1):
semiring.one_(init.data[:, :, : end - (k - 1), k - 1, k].diagonal(0, -2, -1))

K_1 = K - 1

Expand Down

0 comments on commit 008e0dc

Please sign in to comment.