Commit 809a71d

Sasha committed Nov 20, 2019
1 parent a64688e commit 809a71d
Showing 5 changed files with 28 additions and 50 deletions.
2 changes: 1 addition & 1 deletion torch_struct/__init__.py
@@ -70,5 +70,5 @@
     Alignment,
     CheckpointSemiring,
     CheckpointShardSemiring,
-    TempMax
+    TempMax,
 ]
35 changes: 10 additions & 25 deletions torch_struct/cky.py
@@ -26,8 +26,7 @@ def _dp(self, scores, lengths=None, force_grad=False):
         lengths = torch.LongTensor([N] * batch)
 
         # Charts
-        beta = [Chart((batch, N, N, NT), rules, semiring)
-                for _ in range(2)]
+        beta = [Chart((batch, N, N, NT), rules, semiring) for _ in range(2)]
         span = [None for _ in range(N)]
         v = (ssize, batch)
         term_use = terms + 0.0
@@ -36,11 +35,9 @@ def _dp(self, scores, lengths=None, force_grad=False):
         NTs = slice(0, NT)
         Ts = slice(NT, S)
         rules = rules.view(ssize, batch, 1, NT, S, S)
+
         def arr(a, b):
-            return rules[..., a, b] \
-                .contiguous() \
-                .view(*v + (NT, -1)) \
-                .transpose(-2, -1)
+            return rules[..., a, b].contiguous().view(*v + (NT, -1)).transpose(-2, -1)
 
         matmul = semiring.matmul
         times = semiring.times
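
For intuition about the reshape in arr, here is a small standalone sketch with made-up sizes; the X_Y_Z name mirrors the rule tensors used in the loop below, and treating it as arr(NTs, NTs) is an assumption:

import torch

ssize, batch, NT, T = 1, 2, 3, 4
S = NT + T
v = (ssize, batch)
rules = torch.randn(ssize, batch, 1, NT, S, S)
NTs, Ts = slice(0, NT), slice(NT, S)

# arr(NTs, NTs): select rule scores whose left and right children are both
# nonterminals, flatten the (left, right) child pairs, and move them to the
# second-to-last axis so semiring.matmul can contract over them.
X_Y_Z = rules[..., NTs, NTs].contiguous().view(*v + (NT, -1)).transpose(-2, -1)
print(X_Y_Z.shape)  # torch.Size([1, 2, 9, 3]): NT*NT child pairs x NT parents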
@@ -51,45 +48,33 @@ def arr(a, b):
 
         for w in range(1, N):
             all_span = []
-            v2 = v +(N - w, -1)
+            v2 = v + (N - w, -1)
 
             Y = beta[A][: N - w, :w, :]
             Z = beta[B][w:, N - w :, :]
-            X1 = matmul(
-                matmul(Y.transpose(-2, -1), Z).view(*v2), X_Y_Z
-            )
+            X1 = matmul(matmul(Y.transpose(-2, -1), Z).view(*v2), X_Y_Z)
             all_span.append(X1)
 
             Y_term = term_use[..., : N - w, :, None]
             Z_term = term_use[..., w:, None, :]
 
-            Y = Y[...,-1, :].unsqueeze(-1)
-            X2 = matmul(
-                times(Y, Z_term).view(*v2), X_Y_Z1
-            )
+            Y = Y[..., -1, :].unsqueeze(-1)
+            X2 = matmul(times(Y, Z_term).view(*v2), X_Y_Z1)
 
             Z = Z[..., 0, :].unsqueeze(-2)
-            X3 = matmul(
-                times(Y_term, Z).view(*v2), X_Y1_Z
-            )
+            X3 = matmul(times(Y_term, Z).view(*v2), X_Y1_Z)
             all_span += [X2, X3]
 
             if w == 1:
-                X4 = matmul(
-                    times(Y_term, Z_term).view(*v2),
-                    X_Y1_Z1,
-                )
+                X4 = matmul(times(Y_term, Z_term).view(*v2), X_Y1_Z1)
                 all_span.append(X4)
 
             span[w] = semiring.sum(torch.stack(all_span, dim=-1))
             beta[A][: N - w, w, :] = span[w]
             beta[B][w:N, N - w - 1, :] = span[w]
 
         final = beta[A][0, :, NTs]
-        top = torch.stack(
-            [final[:, i, l-1]
-             for i, l in enumerate(lengths)], dim=1
-        )
+        top = torch.stack([final[:, i, l - 1] for i, l in enumerate(lengths)], dim=1)
         log_Z = semiring.dot(top, roots)
         return semiring.unconvert(log_Z), (term_use, rules, top, span[1:]), beta

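For readers following the loop above: X1 handles binary rules whose children are both nonterminal spans, X2/X3 the cases where one child is a single terminal, and X4 (width 1) the case where both are. Below is a minimal, unbatched sketch of the same inside recursion in the log semiring that collapses those four cases into one sum; it is a hypothetical reference, not the library's API:

import torch

def cky_inside(terms, rules, roots):
    # terms: (N, T) terminal log-potentials; rules: (NT, S, S) with S = NT + T,
    # indexed [A, B, C] for A -> B C; roots: (NT,) root log-potentials.
    N, T = terms.shape
    NT, S, _ = rules.shape
    beta = torch.full((N, N, S), torch.finfo(terms.dtype).min)
    beta[torch.arange(N), torch.arange(N), NT:] = terms  # width-1 spans
    for w in range(1, N):                # span width, as in the loop above
        for i in range(N - w):
            j = i + w
            # logsumexp over split point k and child symbols (B, C)
            parts = []
            for k in range(i, j):
                bc = beta[i, k, :, None] + beta[k + 1, j, None, :]  # (S, S)
                parts.append((rules + bc[None]).reshape(NT, -1))
            beta[i, j, :NT] = torch.logsumexp(torch.cat(parts, -1), -1)
    return torch.logsumexp(beta[0, N - 1, :NT] + roots, -1)  # log partition

# e.g. log_Z = cky_inside(torch.randn(5, 4), torch.randn(3, 7, 7), torch.randn(3))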
13 changes: 4 additions & 9 deletions torch_struct/cky_crf.py
@@ -31,10 +31,7 @@
 # return None, None, z
 
 
-
-
 class CKY_CRF(_Struct):
-
     def _check_potentials(self, edge, lengths=None):
         batch, N, _, NT = edge.shape
         edge.requires_grad_(True)
@@ -48,24 +45,22 @@ def _dp(self, scores, lengths=None, force_grad=False):
         semiring = self.semiring
         scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)
 
-
-        beta = [Chart((batch, N, N), scores, semiring)
-                for _ in range(2)]
+        beta = [Chart((batch, N, N), scores, semiring) for _ in range(2)]
         L_DIM, R_DIM = 2, 3
 
         # Initialize
         reduced_scores = semiring.sum(scores)
         term = reduced_scores.diagonal(0, L_DIM, R_DIM)
         ns = torch.arange(N)
         beta[A][ns, 0] = term
-        beta[B][ns, N-1] = term
+        beta[B][ns, N - 1] = term
 
         # Run
         for w in range(1, N):
-            left = slice(None, N-w)
+            left = slice(None, N - w)
             right = slice(w, None)
             Y = beta[A][left, :w]
-            Z = beta[B][right, N-w:]
+            Z = beta[B][right, N - w :]
             score = reduced_scores.diagonal(w, L_DIM, R_DIM)
             new = semiring.times(semiring.dot(Y, Z), score)
             beta[A][left, w] = new
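The _dp loop above is the zeroth-order CKY-CRF inside recursion: each span carries a score, and beta accumulates over split points. A hedged plain-PyTorch rendition for a single sentence in the log semiring (illustrative names, not the library's API):

import torch

def cky_crf_inside(scores):
    # scores[i, j]: log-potential of span [i, j] (i <= j), batch size 1.
    N = scores.size(0)
    beta = torch.full((N, N), torch.finfo(scores.dtype).min)
    idx = torch.arange(N)
    beta[idx, idx] = scores[idx, idx]        # "Initialize": width-0 spans
    for w in range(1, N):                    # "Run": widen spans, as above
        for i in range(N - w):
            j = i + w
            # semiring.dot(Y, Z) over split points k, then times the span score
            splits = beta[i, i:j] + beta[i + 1 : j + 1, j]
            beta[i, j] = torch.logsumexp(splits, 0) + scores[i, j]
    return beta[0, N - 1]                    # log partition for the full span

# e.g. log_Z = cky_crf_inside(torch.randn(6, 6))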
22 changes: 10 additions & 12 deletions torch_struct/deptree.py
@@ -52,17 +52,19 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
         arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)
         arc_scores.requires_grad_(True)
         DIRS = 2
-        alpha = [[Chart((batch, DIRS, N, N), arc_scores, semiring)
-                  for _ in range(2)] for _ in range(2)]
+        alpha = [
+            [Chart((batch, DIRS, N, N), arc_scores, semiring) for _ in range(2)]
+            for _ in range(2)
+        ]
         semiring.one_(alpha[A][C].data[:, :, :, :, 0].data)
         semiring.one_(alpha[B][C].data[:, :, :, :, -1].data)
 
         def stack(a, b=None):
             if b is None:
                 return torch.stack([a, a], dim=2)
             else:
                 return torch.stack([a, b], dim=2)
 
-
         for k in range(1, N):
             f = torch.arange(N - k), torch.arange(k, N)
             AC = alpha[A][C][:, : N - k, :k]
@@ -71,28 +73,24 @@ def stack(a, b=None):
             BC = alpha[B][C][:, k:, N - k :]
             BCL, BCR = BC.unbind(2)
             arcs = semiring.dot(
-                semiring.times(
-                    stack(ACR), stack(BCL)),
-                stack(arc_scores[:, :, f[1], f[0]],
-                      arc_scores[:, :, f[0], f[1]]).unsqueeze(-1),
+                semiring.times(stack(ACR), stack(BCL)),
+                stack(
+                    arc_scores[:, :, f[1], f[0]], arc_scores[:, :, f[0], f[1]]
+                ).unsqueeze(-1),
             )
             alpha[A][I][:, : N - k, k] = arcs
             alpha[B][I][:, k:N, N - k - 1] = arcs
 
             AIR = alpha[A][I][R, : N - k, 1 : k + 1]
             BIL = alpha[B][I][L, k:, N - k - 1 : N - 1]
-            new = semiring.dot(
-                stack(ACL, AIR),
-                stack(BIL, BCR),
-            )
+            new = semiring.dot(stack(ACL, AIR), stack(BIL, BCR))
             alpha[A][C][:, : N - k, k] = new
             alpha[B][C][:, k:N, N - k - 1] = new
 
         final = alpha[A][C][R, 0]
         v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
         return v, [arc_scores], alpha
 
-
     def _check_potentials(self, arc_scores, lengths=None):
         semiring = self.semiring
         batch, N, N2 = arc_scores.shape
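The deptree loop above is a vectorized form of Eisner's first-order algorithm: I and C are incomplete/complete items, L/R the head direction, and the stack helper runs both directions at once. A slow, unbatched reference version in the log semiring follows, under assumed conventions (position 0 is the root, no single-root constraint); it is not the library's API:

import torch

def eisner_inside(arc_scores):
    # arc_scores[h, m]: log-potential of the arc h -> m; row/column 0 is root.
    N = arc_scores.size(0)
    neg = torch.finfo(arc_scores.dtype).min
    L, R = 0, 1
    I = torch.full((2, N, N), neg)   # incomplete: arc placed, inside pending
    C = torch.full((2, N, N), neg)   # complete: head has all its dependents
    C[:, torch.arange(N), torch.arange(N)] = 0.0   # width-0 spans are "one"
    for k in range(1, N):            # span width, as in `for k in range(1, N)`
        for i in range(N - k):
            j = i + k
            # attach step: two complete halves plus the arc, either direction
            both = C[R, i, i:j] + C[L, i + 1 : j + 1, j]
            I[L, i, j] = torch.logsumexp(both, 0) + arc_scores[j, i]
            I[R, i, j] = torch.logsumexp(both, 0) + arc_scores[i, j]
            # complete step: absorb an incomplete item into a complete one
            C[L, i, j] = torch.logsumexp(C[L, i, i:j] + I[L, i:j, j], 0)
            C[R, i, j] = torch.logsumexp(
                I[R, i, i + 1 : j + 1] + C[R, i + 1 : j + 1, j], 0
            )
    return C[R, 0, N - 1]   # log partition; cf. final = alpha[A][C][R, 0] above

# e.g. log_Z = eisner_inside(torch.randn(5, 5))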
6 changes: 3 additions & 3 deletions torch_struct/helpers.py
@@ -18,6 +18,7 @@ def backward(ctx, grad_output):
         grad_chart[ctx.indices] += grad_output
         return grad_chart, None, None
 
+
 class Set(torch.autograd.Function):
     @staticmethod
     def forward(ctx, chart, indices, vals):
@@ -32,8 +33,7 @@ def backward(ctx, grad_output):
 
 
 class Chart:
-    def __init__(self, size, potentials, semiring,
-                 cache=True):
+    def __init__(self, size, potentials, semiring, cache=True):
         self.data = semiring.zero_(
             torch.zeros(
                 *((semiring.size(),) + size),
@@ -58,13 +58,13 @@ def __setitem__(self, ind, new):
         else:
             self.data[(I, I) + ind] = new
 
-
     def get(self, ind):
         return Get.apply(self.data, self.grad, ind)
 
     def set(self, ind, new):
         self.data = Set.apply(self.data, ind, new)
 
+
 class _Struct:
     def __init__(self, semiring=LogSemiring):
         self.semiring = semiring
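Why the Get/Set autograd Functions exist: Chart writes into a preallocated tensor slice by slice, and a custom Function keeps those writes differentiable without tripping autograd's in-place checks. A standalone miniature of the Set half (illustrative only; the library's actual backward differs in details such as caching):

import torch

class SetSlice(torch.autograd.Function):
    @staticmethod
    def forward(ctx, chart, indices, vals):
        chart = chart.clone()          # write on a copy so autograd stays happy
        chart[indices] = vals
        ctx.indices = indices
        return chart

    @staticmethod
    def backward(ctx, grad_output):
        # the written slice's gradient flows to vals; the rest flows to chart
        grad_chart = grad_output.clone()
        grad_vals = grad_output[ctx.indices]
        grad_chart[ctx.indices] = 0.0
        return grad_chart, None, grad_vals

chart = torch.zeros(4, 4, requires_grad=True)
vals = torch.randn(2, requires_grad=True)
out = SetSlice.apply(chart, (0, slice(0, 2)), vals)
out.sum().backward()                   # vals.grad == ones(2)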
