Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Sep 9, 2019
1 parent 7340e0e commit 0bbf279
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ def _convert(logits):
def _unconvert(logits):
"Move root arcs to diagonal"
new_logits = torch.zeros(
logits.size(0), logits.size(1) - 1, logits.size(2) - 1
).type_as(logits.data)
logits.size(0), logits.size(1) - 1, logits.size(2) - 1,
dtype=logits.dtype
)

new_logits.fill_(-1e9)
new_logits[:, :, :] = logits[:, 1:, 1:]
N = new_logits.size(1)
Expand Down Expand Up @@ -206,7 +208,7 @@ def sstack(a):
v = alpha[A][C][R, :, 0, 0]
left = semiring.times(alpha[A][I][L, :, :, :], alpha_in[A][I][L, :, :, :])
right = semiring.times(alpha[A][I][R, :, :, :], alpha_in[A][I][R, :, :, :])
ret = torch.zeros(batch, N, N).type_as(left)
ret = torch.zeros(batch, N, N, dtype=left.dtype)
for k in torch.arange(N):
f = torch.arange(N - k), torch.arange(k, N)
ret[:, f[1], k] = left[:, k, f[0]]
Expand All @@ -219,7 +221,7 @@ def _arrange_marginals(self, grads):
batch, N = grads[0][0].shape
N = N + 1

ret = torch.zeros(batch, N, N).type_as(grads[0][0])
ret = torch.zeros(batch, N, N, dtype=grads[0][0].dtype)
# for k in torch.arange(N):
# f = torch.arange(N - k), torch.arange(k, N)
# ret[:, f[1], k] = grad[L][:, k, f[0]]
Expand Down

0 comments on commit 0bbf279

Please sign in to comment.