From 51e0073dbcf01fbdfa931add413f9cedbccb14d6 Mon Sep 17 00:00:00 2001 From: srush Date: Tue, 3 Sep 2019 14:09:09 +0000 Subject: [PATCH] . --- torch_struct/deptree.py | 6 +++++- torch_struct/linearchain.py | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index c199af3b..5fe45e86 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -128,7 +128,7 @@ def deptree(arc_scores, semiring=LogSemiring, lengths=None): return _unconvert(ret) -def deptree_fromseq(sequence): +def deptree_fromseq(sequence, lengths=None): """ Convert a sequence representation to arcs @@ -138,9 +138,13 @@ def deptree_fromseq(sequence): arcs : b x N x N arc indicators """ batch, N = sequence.shape + if lengths is None: + lengths = torch.LongTensor([N] * batch) labels = torch.zeros(batch, N + 1, N + 1).long() for n in range(1, N): labels[torch.arange(batch), sequence[:, n], n] = 1 + for b in range(batch): + labels[b, lengths[b]:, lengths[b]:] = 0 return _convert(labels) diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py index 07a6cf73..31768735 100644 --- a/torch_struct/linearchain.py +++ b/torch_struct/linearchain.py @@ -83,20 +83,26 @@ def hmm(transition, emission, init, observations): return scores -def linearchain_fromseq(sequence, C): +def linearchain_fromseq(sequence, C, lengths=None): """ Convert a sequence representation to edges Parameters: sequence : b x (N+1) long tensor in [0, C-1] + C : number of states + lengths: b long tensor of N values Returns: edge : b x N x C x C markov indicators (t x z_t x z_{t-1}) """ batch, N = sequence.shape labels = torch.zeros(batch, N - 1, C, C).long() + if lengths is None: + lengths = torch.LongTensor([N] * batch) for n in range(1, N): labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1 + for b in range(batch): + labels[b, lengths[b]:, :, :] = 0 return labels