Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 3, 2019
1 parent 48cc96e commit 51e0073
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 5 additions & 1 deletion torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def deptree(arc_scores, semiring=LogSemiring, lengths=None):
return _unconvert(ret)


def deptree_fromseq(sequence):
def deptree_fromseq(sequence, lengths=None):
"""
Convert a sequence representation to arcs
Expand All @@ -138,9 +138,13 @@ def deptree_fromseq(sequence):
arcs : b x N x N arc indicators
"""
batch, N = sequence.shape
if lengths is None:
lengths = torch.LongTensor([N] * batch)
labels = torch.zeros(batch, N + 1, N + 1).long()
for n in range(1, N):
labels[torch.arange(batch), sequence[:, n], n] = 1
for b in range(batch):
labels[b, lengths[b]:, lengths[b]:] = 0
return _convert(labels)


Expand Down
8 changes: 7 additions & 1 deletion torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,20 +83,26 @@ def hmm(transition, emission, init, observations):
return scores


def linearchain_fromseq(sequence, C):
def linearchain_fromseq(sequence, C, lengths=None):
"""
Convert a sequence representation to edges
Parameters:
sequence : b x (N+1) long tensor in [0, C-1]
C : number of states
lengths: b long tensor of N values
Returns:
edge : b x N x C x C markov indicators
(t x z_t x z_{t-1})
"""
batch, N = sequence.shape
labels = torch.zeros(batch, N - 1, C, C).long()
if lengths is None:
lengths = torch.LongTensor([N] * batch)
for n in range(1, N):
labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1
for b in range(batch):
labels[b, lengths[b]:, :, :] = 0
return labels


Expand Down

0 comments on commit 51e0073

Please sign in to comment.