Skip to content

Commit

Permalink
fix length mask bugs. remove incorrect +1 and fix batch determinant c…
Browse files Browse the repository at this point in the history
…alculation. can be derived using minors and cofactors
  • Loading branch information
chijames committed Jan 13, 2021
1 parent 486c467 commit 10ac289
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions torch_struct/deptree.py
Expand Up @@ -201,31 +201,32 @@ def enumerate(self, arc_scores, non_proj=False, multi_root=True):


def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
if lengths:
if lengths is not None:
batch, N, N = arc_scores.shape
x = torch.arange(N, device=arc_scores.device).expand(batch, N)
if not torch.is_tensor(lengths):
lengths = torch.tensor(lengths, device=arc_scores.device)
lengths = lengths.unsqueeze(1)
x = x < lengths
det_offset = torch.diag_embed((~x).float())
x = x.unsqueeze(2).expand(-1, -1, N)
mask = torch.transpose(x, 1, 2) * x
mask = mask.float()
mask[mask==0] = float('-inf')
mask[mask==1] = 0
arc_scores = arc_scores + mask

input = arc_scores
eye = torch.eye(input.shape[1], device=input.device)
laplacian = input.exp() + eps
lap = laplacian.masked_fill(eye != 0, 0)
lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
lap += det_offset

if multi_root:
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
else:
lap[:, 0] = torch.diagonal(input, 0, -2, -1).exp()

return lap.logdet()


Expand All @@ -252,17 +253,21 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
lengths = torch.tensor(lengths, device=arc_scores.device)
lengths = lengths.unsqueeze(1)
x = x < lengths
det_offset = torch.diag_embed((~x).float())
x = x.unsqueeze(2).expand(-1, -1, N)
mask = torch.transpose(x, 1, 2) * x
mask = mask.float()
mask[mask==0] = float('-inf')
mask[mask==1] = 0
arc_scores = arc_scores + mask

input = arc_scores
eye = torch.eye(input.shape[1], device=input.device)
laplacian = input.exp() + eps
lap = laplacian.masked_fill(eye != 0, 0)
lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
lap += det_offset

if multi_root:
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
lap = lap + torch.diag_embed(rss, offset=0, dim1=-2, dim2=-1)
Expand Down

0 comments on commit 10ac289

Please sign in to comment.