diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 5c74f8ea..e0541d96 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -104,7 +104,7 @@ def to_parts(sequence, extra, lengths=None): batch, N = sequence.shape labels = torch.zeros(batch, N - 1, C, C).long() if lengths is None: - lengths = torch.LongTensor([N] * batch) + lengths = torch.LongTensor([N] * batch).to(edge.device) for n in range(1, N): labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1 for b in range(batch):