diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py index 37bcdec4..c199af3b 100644 --- a/torch_struct/deptree.py +++ b/torch_struct/deptree.py @@ -144,6 +144,23 @@ def deptree_fromseq(sequence): return _convert(labels) +def deptree_toseq(arcs): + """ + Convert a arc representation to sequence + + Parameters: + arcs : b x N x N arc indicators + Returns: + sequence : b x N long tensor in [0, N-1] + """ + batch, N, _ = arcs.shape + labels = torch.zeros(batch, N).long() + on = arcs.nonzero() + for i in range(on.shape[0]): + labels[on[i][0], on[i][2]] = on[i][1] + return labels + + def deptree_nonproj(arc_scores, eps=1e-5): """ Compute the marginals of a non-projective dependency tree using the