Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 25, 2019
1 parent 4c90826 commit a60df6b
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
device=log_potentials.device,
)
big[:, :, : N - 1] = log_potentials
c = chart[:, :, :].view(chart.shape[0], (batch * bin_N), C, C)
c = chart[:, :, :].view(chart.shape[0], batch * bin_N, C, C)
lp = big[:, :, :].view(chart.shape[0], batch * bin_N, C, C)

mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
mask = mask < (lengths - 1).view(batch, 1)
# c[:, mask.view(-1)] = lp[:, mask.view(-1)]
c[:, :] = lp[:, :]
c[:, mask.view(-1)] = lp[:, mask.view(-1)]
# c[:, :] = lp[:, :]

# Scan
for n in range(1, log_N + 1):
Expand Down

0 comments on commit a60df6b

Please sign in to comment.