Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 26, 2019
1 parent b0dc178 commit 5ab2078
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torch_struct/semimarkov.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def _check_potentials(self, edge, lengths=None):
edge = self.semiring.convert(edge)
N = N_1 + 1
if lengths is None:
lengths = torch.LongTensor([N] * batch)
lengths = torch.LongTensor([N] * batch, device=edge.device)
assert max(lengths) <= N, "Length longer than edge scores"
assert max(lengths) == N, "At least one in batch must be length N"
assert C == C2, "Transition shape doesn't match"
Expand Down Expand Up @@ -50,7 +50,8 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):
big[:, :, : N - 1] = log_potentials
c = init[:, :, :].view(ssize, batch * bin_N, K - 1, K - 1, C, C)
lp = big[:, :, :].view(ssize, batch * bin_N, K, C, C)
mask = torch.arange(bin_N, device=lp.device).view(1, bin_N).expand(batch, bin_N)
mask = torch.arange(bin_N, device=lp.device) \
.view(1, bin_N).expand(batch, bin_N)
mask = mask >= (lengths - 1).view(batch, 1)

semiring.zero_mask_(lp.data, mask.view(-1))
Expand Down

0 comments on commit 5ab2078

Please sign in to comment.