Merge 120842d into 9f93432
chijames committed Jan 6, 2021
2 parents 9f93432 + 120842d commit 5eceb92
Showing 2 changed files with 71 additions and 25 deletions.
86 changes: 65 additions & 21 deletions torch_struct/deptree.py
@@ -200,17 +200,34 @@ def enumerate(self, arc_scores, non_proj=False, multi_root=True):
         return semiring.sum(torch.stack(parses, dim=-1)), None
 
 
-def deptree_part(arc_scores, eps=1e-5):
+def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
+    if lengths is not None:
+        batch, N, N = arc_scores.shape
+        # Build a b x N x N mask: 0 on in-range arcs, -inf on arcs that touch padding.
+        x = torch.arange(N).expand(batch, N)
+        length = torch.tensor(lengths).unsqueeze(1)
+        x = x < length
+        x = x.unsqueeze(2).expand(-1, -1, N)
+        mask = torch.transpose(x, 1, 2) * x
+        mask = mask.float()
+        mask[mask == 0] = float("-inf")
+        mask[mask == 1] = 0.0  # leave in-range arc scores unchanged
+        arc_scores = arc_scores + mask
+
     input = arc_scores
     eye = torch.eye(input.shape[1], device=input.device)
     laplacian = input.exp() + eps
     lap = laplacian.masked_fill(eye != 0, 0)
     lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
-    lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
-    return lap.logdet()
-
-
-def deptree_nonproj(arc_scores, eps=1e-5):
+    if multi_root:
+        rss = torch.diagonal(input, 0, -2, -1).exp()  # root selection scores
+        lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
+    else:
+        lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
+
+    return lap.logdet()
+
+
+def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
     """
     Compute the marginals of a non-projective dependency tree using the
     matrix-tree theorem.
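
The two branches are the standard matrix-tree constructions: multi_root adds the exponentiated root scores to the diagonal of the Laplacian, so the determinant sums over trees with any number of root attachments, while the single-root case overwrites the first row of the Laplacian with the root scores (Koo et al., 2007). A brute-force comparison on a tiny input makes the difference concrete. This is a minimal sketch, assuming the patched deptree_part is importable and that scores[h, d] scores the arc h -> d, with root-attachment scores on the diagonal; brute_force_logZ and reaches_root are illustrative helpers, not library code.

import itertools

import torch

from torch_struct.deptree import deptree_part


def reaches_root(heads, d):
    # Follow head pointers upward; heads[d] == d encodes attachment to the root.
    seen = set()
    while heads[d] != d:
        if d in seen:
            return False  # cycle that never reaches the root
        seen.add(d)
        d = heads[d]
    return True


def brute_force_logZ(scores, multi_root):
    # scores: N x N; scores[h, d] is the arc h -> d, scores[d, d] is root -> d.
    N = scores.shape[0]
    weights = []
    for heads in itertools.product(range(N), repeat=N):
        roots = [d for d in range(N) if heads[d] == d]
        if not roots or (not multi_root and len(roots) > 1):
            continue
        if not all(reaches_root(heads, d) for d in range(N)):
            continue
        weights.append(sum(scores[heads[d], d] for d in range(N)))
    return torch.logsumexp(torch.stack(weights), dim=0)


scores = torch.randn(3, 3)
for multi_root in (False, True):
    exact = brute_force_logZ(scores, multi_root)
    approx = deptree_part(scores.unsqueeze(0), multi_root, None)[0]
    print(multi_root, exact.item(), approx.item())  # agree up to the eps smoothing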
@@ -226,28 +243,55 @@ def deptree_nonproj(arc_scores, eps=1e-5):
     Returns:
         arc_marginals : b x N x N.
     """
+    if lengths is not None:
+        batch, N, N = arc_scores.shape
+        # Build a b x N x N mask: 0 on in-range arcs, -inf on arcs that touch padding.
+        x = torch.arange(N).expand(batch, N)
+        length = torch.tensor(lengths).unsqueeze(1)
+        x = x < length
+        x = x.unsqueeze(2).expand(-1, -1, N)
+        mask = torch.transpose(x, 1, 2) * x
+        mask = mask.float()
+        mask[mask == 0] = float("-inf")
+        mask[mask == 1] = 0.0  # leave in-range arc scores unchanged
+        arc_scores = arc_scores + mask
+
     input = arc_scores
     eye = torch.eye(input.shape[1], device=input.device)
     laplacian = input.exp() + eps
     lap = laplacian.masked_fill(eye != 0, 0)
     lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
-    lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
-    inv_laplacian = lap.inverse()
-    factor = (
-        torch.diagonal(inv_laplacian, 0, -2, -1)
-        .unsqueeze(2)
-        .expand_as(input)
-        .transpose(1, 2)
-    )
-    term1 = input.exp().mul(factor).clone()
-    term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
-    term1[:, :, 0] = 0
-    term2[:, 0] = 0
-    output = term1 - term2
-    roots_output = (
-        torch.diagonal(input, 0, -2, -1).exp().mul(inv_laplacian.transpose(1, 2)[:, 0])
-    )
+    if multi_root:
+        rss = torch.diagonal(input, 0, -2, -1).exp()  # root selection scores
+        lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
+        inv_laplacian = lap.inverse()
+        factor = (
+            torch.diagonal(inv_laplacian, 0, -2, -1)
+            .unsqueeze(2)
+            .expand_as(input)
+            .transpose(1, 2)
+        )
+        term1 = input.exp().mul(factor).clone()
+        term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
+        output = term1 - term2
+        roots_output = (
+            torch.diagonal(input, 0, -2, -1)
+            .exp()
+            .mul(torch.diagonal(inv_laplacian.transpose(1, 2), 0, -2, -1))
+        )
+    else:
+        lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()
+        inv_laplacian = lap.inverse()
+        factor = (
+            torch.diagonal(inv_laplacian, 0, -2, -1)
+            .unsqueeze(2)
+            .expand_as(input)
+            .transpose(1, 2)
+        )
+        term1 = input.exp().mul(factor).clone()
+        term2 = input.exp().mul(inv_laplacian.transpose(1, 2)).clone()
+        term1[:, :, 0] = 0
+        term2[:, 0] = 0
+        output = term1 - term2
+        roots_output = (
+            torch.diagonal(input, 0, -2, -1).exp().mul(inv_laplacian.transpose(1, 2)[:, 0])
+        )
     output = output + torch.diag_embed(roots_output, 0, -2, -1)
     return output

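Because the marginals are the gradient of the log-partition with respect to the arc scores, the two patched functions can be cross-checked against each other with autograd. A minimal consistency sketch, assuming both are importable; the tolerance is loose because of the eps smoothing.

import torch

from torch_struct.deptree import deptree_nonproj, deptree_part

torch.manual_seed(0)
scores = torch.randn(2, 4, 4, requires_grad=True)

for multi_root in (False, True):
    log_z = deptree_part(scores, multi_root, None)
    (grad,) = torch.autograd.grad(log_z.sum(), scores)
    marginals = deptree_nonproj(scores, multi_root, None)
    # Exponential-family identity: marginals should equal d log Z / d scores.
    print(multi_root, torch.allclose(grad, marginals, atol=1e-3))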
10 changes: 6 additions & 4 deletions torch_struct/distributions.py
@@ -461,8 +461,10 @@ class NonProjectiveDependencyCRF(StructDistribution):
     Note: Does not currently implement argmax (Chu-Liu/Edmonds) or sampling.
     """
 
     struct = DepTree
+
+    def __init__(self, log_potentials, lengths=None, args={}, multiroot=False):
+        super(NonProjectiveDependencyCRF, self).__init__(log_potentials, lengths, args)
+        self.multiroot = multiroot
 
     @lazy_property
     def marginals(self):
@@ -474,7 +476,7 @@ def marginals(self):
         Returns:
             marginals (*batch_shape x event_shape*)
         """
-        return deptree_nonproj(self.log_potentials)
+        return deptree_nonproj(self.log_potentials, self.multiroot, self.lengths)
 
     def sample(self, sample_shape=torch.Size()):
         raise NotImplementedError()
@@ -484,7 +486,7 @@ def partition(self):
         """
         Compute the partition function.
         """
-        return deptree_part(self.log_potentials)
+        return deptree_part(self.log_potentials, self.multiroot, self.lengths)
 
     @lazy_property
     def argmax(self):
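
Downstream, the new flag is threaded through the public distribution. A hypothetical usage sketch; the shapes and the multiroot=False default follow the diff above.

import torch

from torch_struct import NonProjectiveDependencyCRF

# batch of 2 sentences, N = 5; arc scores with root scores on the diagonal
log_potentials = torch.randn(2, 5, 5)
dist = NonProjectiveDependencyCRF(log_potentials, multiroot=True)

print(dist.partition.shape)  # torch.Size([2]): one log-partition per sentence
print(dist.marginals.shape)  # torch.Size([2, 5, 5]): arc and root marginals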
