Skip to content

Commit

Permalink
Update linearchain.py
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Apr 11, 2020
1 parent fba3d5b commit 5e8d662
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torch_struct/linearchain.py
Expand Up @@ -104,7 +104,7 @@ def to_parts(sequence, extra, lengths=None):
batch, N = sequence.shape
labels = torch.zeros(batch, N - 1, C, C).long()
if lengths is None:
lengths = torch.LongTensor([N] * batch).to(edge.device)
lengths = torch.LongTensor([N] * batch)
for n in range(1, N):
labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1
for b in range(batch):
Expand Down

0 comments on commit 5e8d662

Please sign in to comment.