Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Sep 3, 2019
1 parent 9606323 commit f2c9b6a
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 0 deletions.
16 changes: 16 additions & 0 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,22 @@ def deptree(arc_scores, semiring=LogSemiring, lengths=None):
return _unconvert(ret)


def deptree_fromseq(sequence):
"""
Convert a sequence representation to arcs
Parameters:
sequence : b x N long tensor in [0, N-1]
Returns:
arcs : b x N x N arc indicators
"""
batch, N = sequence.shape
labels = torch.zeros(batch, N + 1, N + 1).long()
for n in range(1, N):
labels[torch.arange(batch), sequence[:, n], n] = 1
return _convert(labels)


def deptree_nonproj(arc_scores, eps=1e-5):
"""
Compute the marginals of a non-projective dependency tree using the
Expand Down
45 changes: 45 additions & 0 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,41 @@ def hmm(transition, emission, init, observations):
return scores


def linearchain_fromseq(sequence, C):
"""
Convert a sequence representation to edges
Parameters:
sequence : b x (N+1) long tensor in [0, C-1]
Returns:
edge : b x N x C x C markov indicators
(t x z_t x z_{t-1})
"""
batch, N = sequence.shape
labels = torch.zeros(batch, N - 1, C, C).long()
for n in range(1, N):
labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1
return labels


def linearchain_toseq(edge):
"""
Convert edges to sequence representation.
Parameters:
edge : b x N x C x C markov indicators
(t x z_t x z_{t-1})
Returns:
sequence : b x (N+1) long tensor in [0, C-1]
"""
batch, N, C, _ = edge.shape
labels = torch.zeros(batch, N + 1).long()
on = edge.nonzero()
for i in range(on.shape[0]):
labels[on[i][0], on[i][1]] = on[i][3]
return labels


### Tests
def linearchain_check(edge, semiring=LogSemiring):
batch, N, C, _ = edge.shape
Expand All @@ -96,4 +131,14 @@ def linearchain_check(edge, semiring=LogSemiring):
)
chains = new_chains

edges = linearchain_fromseq(torch.stack([torch.tensor(c) for (c, _) in chains]), C)
a = (
torch.einsum("ancd,bncd->bancd", edges.float(), edge)
.sum(dim=2)
.sum(dim=2)
.sum(dim=2)
)
a = semiring.sum(a, dim=1)
b = semiring.sum(torch.stack([s for (_, s) in chains]), dim=0)
assert torch.isclose(a, b).all()
return semiring.sum(torch.stack([s for (_, s) in chains]), dim=0)

0 comments on commit f2c9b6a

Please sign in to comment.