# ---------------------------------------------------------------------------
# torch_struct/deptree.py
# ---------------------------------------------------------------------------


def deptree_fromseq(sequence, lengths=None):
    """
    Convert a sequence representation to arcs

    Parameters:
        sequence : b x N long tensor. For every n >= 1, sequence[:, n] is
                   used as the row index of the indicator set in column n
                   (presumably the head of token n -- confirm against
                   callers). Column 0 of `sequence` is never read.
        lengths : optional, one integer per batch element (tensor or any
                  indexable of ints). When given, the indicator entries
                  whose row AND column index are both >= lengths[b] are
                  cleared for batch element b. When omitted, nothing is
                  masked.
    Returns:
        arcs : b x N x N arc indicators (the result of `_convert` applied
               to the b x (N+1) x (N+1) indicator tensor built here)
    """
    batch, N = sequence.shape
    device = sequence.device

    # Allocate directly as long on the input's device instead of creating a
    # float tensor on the default device and converting it afterwards.
    labels = torch.zeros(
        batch, N + 1, N + 1, dtype=torch.long, device=device
    )

    # One advanced-indexing assignment replaces the loop over n = 1..N-1.
    # `arange(N)[1:]` (rather than `arange(1, N)`) stays valid for N == 0.
    batch_idx = torch.arange(batch, device=device).unsqueeze(1)  # b x 1
    columns = torch.arange(N, device=device)[1:].unsqueeze(0)  # 1 x (N-1)
    labels[batch_idx, sequence[:, 1:], columns] = 1

    if lengths is not None:
        # NOTE(review): only the block where both indices are past the
        # length is cleared; entries with exactly one index in the padded
        # range are left untouched. Kept as-is -- confirm this is intended.
        for b in range(batch):
            cut = int(lengths[b])
            labels[b, cut:, cut:] = 0

    return _convert(labels)


# ---------------------------------------------------------------------------
# torch_struct/linearchain.py
# ---------------------------------------------------------------------------


def linearchain_fromseq(sequence, C, lengths=None):
    """
    Convert a sequence representation to edges

    Parameters:
        sequence : b x (N+1) long tensor in [0, C-1]
        C : number of states
        lengths : optional, one integer per batch element (tensor or any
                  indexable of ints). When given, every edge position
                  t >= lengths[b] is cleared for batch element b. When
                  omitted, nothing is masked.
    Returns:
        edge : b x N x C x C markov indicators
                    (t x z_t x z_{t-1})
    """
    # NB: in the code below `N` is the number of sequence positions, i.e.
    # the docstring's N + 1; the edge axis therefore has N - 1 entries.
    batch, N = sequence.shape
    device = sequence.device

    # Allocate directly as long on the input's device instead of creating a
    # float tensor on the default device and converting it afterwards.
    labels = torch.zeros(
        batch, N - 1, C, C, dtype=torch.long, device=device
    )

    # Edge t connects state z_{t+1} (third axis) to z_t (fourth axis); one
    # advanced-indexing assignment sets all b x (N-1) indicators at once.
    batch_idx = torch.arange(batch, device=device).unsqueeze(1)  # b x 1
    steps = torch.arange(N - 1, device=device).unsqueeze(0)  # 1 x (N-1)
    labels[batch_idx, steps, sequence[:, 1:], sequence[:, :-1]] = 1

    if lengths is not None:
        # NOTE(review): the mask is applied on the edge axis, so lengths[b]
        # acts as the number of edges kept -- confirm the intended units.
        for b in range(batch):
            labels[b, int(lengths[b]):] = 0

    return labels