From a60df6bb0201bfafd1b0bd8d057db5643d876492 Mon Sep 17 00:00:00 2001 From: Sasha Date: Mon, 25 Nov 2019 16:52:20 -0500 Subject: [PATCH] . --- torch_struct/linearchain.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 4a0af35b..94b64607 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -71,13 +71,13 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False): device=log_potentials.device, ) big[:, :, : N - 1] = log_potentials - c = chart[:, :, :].view(chart.shape[0], (batch * bin_N), C, C) + c = chart[:, :, :].view(chart.shape[0], batch * bin_N, C, C) lp = big[:, :, :].view(chart.shape[0], batch * bin_N, C, C) mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N) mask = mask < (lengths - 1).view(batch, 1) - # c[:, mask.view(-1)] = lp[:, mask.view(-1)] - c[:, :] = lp[:, :] + c[:, mask.view(-1)] = lp[:, mask.view(-1)] + # c[:, :] = lp[:, :] # Scan for n in range(1, log_N + 1):