Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Sep 3, 2019
1 parent f2c9b6a commit 60abbfb
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,23 @@ def deptree_fromseq(sequence):
return _convert(labels)


def deptree_toseq(arcs):
"""
Convert a arc representation to sequence
Parameters:
arcs : b x N x N arc indicators
Returns:
sequence : b x N long tensor in [0, N-1]
"""
batch, N, _ = arcs.shape
labels = torch.zeros(batch, N).long()
on = arcs.nonzero()
for i in range(on.shape[0]):
labels[on[i][0], on[i][2]] = on[i][1]
return labels


def deptree_nonproj(arc_scores, eps=1e-5):
"""
Compute the marginals of a non-projective dependency tree using the
Expand Down

0 comments on commit 60abbfb

Please sign in to comment.