From f93ac54a3e5a5195bf5fedb08383afec6222068f Mon Sep 17 00:00:00 2001 From: srush Date: Sat, 11 Apr 2020 14:20:19 -0400 Subject: [PATCH] Update linearchain.py --- torch_struct/linearchain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 5c74f8ea..e0541d96 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -104,7 +104,7 @@ def to_parts(sequence, extra, lengths=None): batch, N = sequence.shape labels = torch.zeros(batch, N - 1, C, C).long() if lengths is None: - lengths = torch.LongTensor([N] * batch) + lengths = torch.LongTensor([N] * batch).to(edge.device) for n in range(1, N): labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1 for b in range(batch):