Skip to content

Commit

Permalink
fix if lenghts bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chijames committed Jan 15, 2021
1 parent 10ac289 commit 6f55e83
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions torch_struct/deptree.py
Expand Up @@ -220,7 +220,8 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
laplacian = input.exp() + eps
lap = laplacian.masked_fill(eye != 0, 0)
lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
lap += det_offset
if lengths is not None:
lap += det_offset

if multi_root:
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
Expand Down Expand Up @@ -266,7 +267,8 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
laplacian = input.exp() + eps
lap = laplacian.masked_fill(eye != 0, 0)
lap = -lap + torch.diag_embed(lap.sum(1), offset=0, dim1=-2, dim2=-1)
lap += det_offset
if lengths is not None:
lap += det_offset

if multi_root:
rss = torch.diagonal(input, 0, -2, -1).exp() # root selection scores
Expand Down

0 comments on commit 6f55e83

Please sign in to comment.