Skip to content

Commit

Permalink
fix variable names and add tensor check
Browse files Browse the repository at this point in the history
  • Loading branch information
chijames committed Jan 10, 2021
1 parent 120842d commit cb42332
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions torch_struct/deptree.py
Expand Up @@ -204,8 +204,10 @@ def deptree_part(arc_scores, multi_root, lengths, eps=1e-5):
if lengths:
batch, N, N = arc_scores.shape
x = torch.arange(N).expand(batch, N)
length = torch.tensor(lengths).unsqueeze(1)
x = x < length
if not torch.is_tensor(lengths):
lengths = torch.tensor(lengths)
lengths = lengths.unsqueeze(1)
x = x < lengths
x = x.unsqueeze(2).expand(-1, -1, N)
mask = torch.transpose(x, 1, 2) * x
mask = mask.float()
Expand Down Expand Up @@ -243,11 +245,13 @@ def deptree_nonproj(arc_scores, multi_root, lengths, eps=1e-5):
Returns:
arc_marginals : b x N x N.
"""
if lengths:
if lengths is not None:
batch, N, N = arc_scores.shape
x = torch.arange(N).expand(batch, N)
length = torch.tensor(lengths).unsqueeze(1)
x = x < length
if not torch.is_tensor(lengths):
lengths = torch.tensor(lengths)
lengths = lengths.unsqueeze(1)
x = x < lengths
x = x.unsqueeze(2).expand(-1, -1, N)
mask = torch.transpose(x, 1, 2) * x
mask = mask.float()
Expand Down

0 comments on commit cb42332

Please sign in to comment.