Skip to content

Commit

Permalink
.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sasha committed Nov 25, 2019
1 parent 04ac887 commit 3ea092f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 35 deletions.
59 changes: 30 additions & 29 deletions torch_struct/deptree.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,44 +53,45 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
arc_scores.requires_grad_(True)
DIRS = 2
alpha = [
[[Chart((batch, N, N), arc_scores, semiring) for _ in range(2)]
for _ in range(2)] for _ in range(2)
[
[Chart((batch, N, N), arc_scores, semiring) for _ in range(2)]
for _ in range(2)
]
for _ in range(2)
]
semiring.one_(alpha[A][C][L].data[ :, :, :, 0].data)
semiring.one_(alpha[A][C][R].data[ :, :, :, 0].data)
semiring.one_(alpha[B][C][L].data[ :, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[ :, :, :, -1].data)
semiring.one_(alpha[A][C][L].data[:, :, :, 0].data)
semiring.one_(alpha[A][C][R].data[:, :, :, 0].data)
semiring.one_(alpha[B][C][L].data[:, :, :, -1].data)
semiring.one_(alpha[B][C][R].data[:, :, :, -1].data)

# for k in range(1, N):
# f = torch.arange(N - k), torch.arange(k, N)
# ACL = alpha[A][C][L][: N - k, :k]
# ACR = alpha[A][C][R][: N - k, :k]
for k in range(1, N):
f = torch.arange(N - k), torch.arange(k, N)
ACL = alpha[A][C][L][: N - k, :k]
ACR = alpha[A][C][R][: N - k, :k]

# BCL = alpha[B][C][L][k:, N - k :]
# BCR = alpha[B][C][R][k:, N - k :]
# x = semiring.dot(ACR, BCL)
BCL = alpha[B][C][L][k:, N - k :]
BCR = alpha[B][C][R][k:, N - k :]
x = semiring.dot(ACR, BCL)

# arcs_l = semiring.times(
# x, arc_scores[:, :, f[1], f[0]])
arcs_l = semiring.times(x, arc_scores[:, :, f[1], f[0]])

# # alpha[A][I][L][: N - k, k] = arcs_l
# # alpha[B][I][L][k:N, N - k - 1] = arcs_l
alpha[A][I][L][: N - k, k] = arcs_l
alpha[B][I][L][k:N, N - k - 1] = arcs_l

# arcs_r = semiring.times(
# x , arc_scores[:, :, f[0], f[1]])
# # alpha[A][I][R][: N - k, k] = arcs_r
# # alpha[B][I][R][k:N, N - k - 1] = arcs_r
arcs_r = semiring.times(x, arc_scores[:, :, f[0], f[1]])
alpha[A][I][R][: N - k, k] = arcs_r
alpha[B][I][R][k:N, N - k - 1] = arcs_r

# AIR = alpha[A][I][R][: N - k, 1 : k + 1]
# BIL = alpha[B][I][L][ k:, N - k - 1 : N - 1]
AIR = alpha[A][I][R][: N - k, 1 : k + 1]
BIL = alpha[B][I][L][k:, N - k - 1 : N - 1]

# new = semiring.dot(ACL, BIL)
# # alpha[A][C][L][ : N - k, k] = new
# # alpha[B][C][L][ k:N, N - k - 1] = new
new = semiring.dot(ACL, BIL)
alpha[A][C][L][: N - k, k] = new
alpha[B][C][L][k:N, N - k - 1] = new

# new = semiring.dot(AIR, BCR)
# # alpha[A][C][R][ : N - k, k] = new
# # alpha[B][C][R][ k:N, N - k - 1] = new
new = semiring.dot(AIR, BCR)
alpha[A][C][R][: N - k, k] = new
alpha[B][C][R][k:N, N - k - 1] = new

final = alpha[A][C][R][(0,)]
v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
Expand Down
18 changes: 12 additions & 6 deletions torch_struct/linearchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,24 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):

# Init
m = torch.min(lengths)
semiring.one_(chart[:, :, m-1:].diagonal(0, 3, 4))
semiring.one_(chart[:, :, m - 1 :].diagonal(0, 3, 4))

# Length mask
big = torch.zeros(log_potentials.shape[0], batch, bin_N, C, C,
dtype=log_potentials.dtype,
device=log_potentials.device)
big[:, :, :N-1] = log_potentials
big = torch.zeros(
log_potentials.shape[0],
batch,
bin_N,
C,
C,
dtype=log_potentials.dtype,
device=log_potentials.device,
)
big[:, :, : N - 1] = log_potentials
c = chart[:, :, :].view(chart.shape[0], (batch * bin_N), C, C)
lp = big[:, :, :].view(chart.shape[0], batch * bin_N, C, C)

mask = torch.arange(bin_N).view(1, bin_N).expand(batch, bin_N)
mask = mask < (lengths -1).view(batch, 1)
mask = mask < (lengths - 1).view(batch, 1)
c[:, mask.view(-1)] = lp[:, mask.view(-1)]

# Scan
Expand Down

0 comments on commit 3ea092f

Please sign in to comment.