Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Nov 26, 2019
1 parent a60df6b commit e92deaa
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,19 +50,19 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
"Compute forward pass by linear scan"
# Setup
semiring = self.semiring
ssize = semiring.size()
log_potentials, batch, N, C, lengths = self._check_potentials(
log_potentials, lengths
)
log_N, bin_N = self._bin_length(N - 1)
chart = self._chart((batch, bin_N, C, C), log_potentials, force_grad)

# Init
m = torch.min(lengths)
semiring.one_(chart[:, :, m - 1 :].diagonal(0, 3, 4))
semiring.one_(chart[:, :, :].diagonal(0, 3, 4))

# Length mask
big = torch.zeros(
log_potentials.shape[0],
ssize,
batch,
bin_N,
C,
Expand All @@ -71,13 +71,14 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
device=log_potentials.device,
)
big[:, :, : N - 1] = log_potentials
c = chart[:, :, :].view(chart.shape[0], batch * bin_N, C, C)
lp = big[:, :, :].view(chart.shape[0], batch * bin_N, C, C)

c = chart[:, :, :].view(ssize, batch * bin_N, C, C)
lp = big[:, :, :].view(ssize, batch * bin_N, C, C)
mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
mask = mask < (lengths - 1).view(batch, 1)
c[:, mask.view(-1)] = lp[:, mask.view(-1)]
# c[:, :] = lp[:, :]
mask = mask >= (lengths - 1).view(batch, 1)
lp.data[:, mask.view(-1)] = semiring.zero
c.data[:, (1- mask).view(-1)] = semiring.zero
c[:] = semiring.sum(torch.stack([c, lp], dim=-1))


# Scan
for n in range(1, log_N + 1):
Expand Down

0 comments on commit e92deaa

Please sign in to comment.