Commit f34c2e9

Refactor the chart data structure to allow for faster slicing. (#39)

srush committed Nov 20, 2019
1 parent 905dd78 commit f34c2e9
Showing 6 changed files with 168 additions and 135 deletions.
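At a high level, the refactor replaces raw chart tensors, which carry explicit semiring and batch dimensions at every call site, with a Chart wrapper (added to torch_struct/helpers.py below) that prepends those dimensions automatically, so the algorithms slice only over position dimensions. A minimal sketch of the before/after indexing pattern; MiniChart and all shapes here are illustrative stand-ins, not the library API:

import torch

# Hypothetical shapes: semiring size, batch, sentence length, nonterminals.
ssize, batch, N, NT = 1, 2, 5, 3

# Before: a raw chart tensor; every slice spells out the leading dims.
beta_raw = torch.zeros(ssize, batch, N, N, NT)
y_old = beta_raw[:, :, : N - 1, :1]

# After: a wrapper that prepends the semiring/batch dims automatically,
# mirroring Chart.__getitem__ in torch_struct/helpers.py below.
class MiniChart:
    def __init__(self, size):
        self.data = torch.zeros(ssize, batch, *size)

    def __getitem__(self, ind):
        I = slice(None)
        return self.data[(I, I) + ind]

beta = MiniChart((N, N, NT))
y_new = beta[: N - 1, :1]
assert torch.equal(y_old, y_new)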
2 changes: 1 addition & 1 deletion torch_struct/__init__.py
@@ -70,5 +70,5 @@
Alignment,
CheckpointSemiring,
CheckpointShardSemiring,
TempMax
TempMax,
]
101 changes: 35 additions & 66 deletions torch_struct/cky.py
@@ -1,5 +1,5 @@
import torch
from .helpers import _Struct
from .helpers import _Struct, Chart

A, B = 0, 1

@@ -18,94 +18,63 @@ def _dp(self, scores, lengths=None, force_grad=False):
S = NT + T

terms, rules, roots = (
semiring.convert(terms),
semiring.convert(rules),
semiring.convert(roots),
semiring.convert(terms).requires_grad_(True),
semiring.convert(rules).requires_grad_(True),
semiring.convert(roots).requires_grad_(True),
)
if lengths is None:
lengths = torch.LongTensor([N] * batch)

# Charts
beta = self._make_chart(2, (batch, N, N, NT), rules, force_grad)
span = self._make_chart(N, (batch, N, NT), rules, force_grad)

# Terminals and Tops
top = self._chart((batch, NT), rules, force_grad)
term_use = self._chart((batch, N, T), terms, force_grad)
term_use[:] = terms + 0.0
beta = [Chart((batch, N, N, NT), rules, semiring) for _ in range(2)]
span = [None for _ in range(N)]
v = (ssize, batch)
term_use = terms + 0.0

# Split into NT/T groups
NTs = slice(0, NT)
Ts = slice(NT, S)
rules = rules.view(ssize, batch, 1, NT, S, S)
X_Y_Z = (
rules[..., NTs, NTs]
.contiguous()
.view(ssize, batch, NT, NT * NT)
.transpose(-2, -1)
)
X_Y1_Z = (
rules[..., Ts, NTs]
.contiguous()
.view(ssize, batch, NT, T * NT)
.transpose(-2, -1)
)
X_Y_Z1 = (
rules[..., NTs, Ts]
.contiguous()
.view(ssize, batch, NT, NT * T)
.transpose(-2, -1)
)
X_Y1_Z1 = (
rules[..., Ts, Ts]
.contiguous()
.view(ssize, batch, NT, T * T)
.transpose(-2, -1)
)

# Flatten each (left-child, right-child) rule block into a matrix for semiring.matmul.
def arr(a, b):
return rules[..., a, b].contiguous().view(*v + (NT, -1)).transpose(-2, -1)

matmul = semiring.matmul
times = semiring.times
X_Y_Z = arr(NTs, NTs)
X_Y1_Z = arr(Ts, NTs)
X_Y_Z1 = arr(NTs, Ts)
X_Y1_Z1 = arr(Ts, Ts)

for w in range(1, N):
all_span = []
v2 = v + (N - w, -1)

Y = beta[A][: N - w, :w, :]
Z = beta[B][w:, N - w :, :]
X1 = matmul(matmul(Y.transpose(-2, -1), Z).view(*v2), X_Y_Z)
all_span.append(X1)

Y = beta[A][..., : N - w, :w, :].transpose(-2, -1)
Z = beta[B][..., w:, N - w :, :]
all_span.append(
semiring.matmul(
semiring.matmul(Y, Z).view(ssize, batch, N - w, NT * NT), X_Y_Z
)
)
Y_term = term_use[..., : N - w, :, None]
Z_term = term_use[..., w:, None, :]

Y = beta[A][..., : N - w, w - 1, :, None]
all_span.append(
semiring.matmul(
semiring.times(Y, Z_term).view(ssize, batch, N - w, T * NT), X_Y_Z1
)
)
Y = Y[..., -1, :].unsqueeze(-1)
X2 = matmul(times(Y, Z_term).view(*v2), X_Y_Z1)

Z = beta[B][..., w:, N - w, None, :]
all_span.append(
semiring.matmul(
semiring.times(Y_term, Z).view(ssize, batch, N - w, NT * T), X_Y1_Z
)
)
Z = Z[..., 0, :].unsqueeze(-2)
X3 = matmul(times(Y_term, Z).view(*v2), X_Y1_Z)
all_span += [X2, X3]

if w == 1:
all_span.append(
semiring.matmul(
semiring.times(Y_term, Z_term).view(ssize, batch, N - w, T * T),
X_Y1_Z1,
)
)
X4 = matmul(times(Y_term, Z_term).view(*v2), X_Y1_Z1)
all_span.append(X4)

span[w] = semiring.sum(torch.stack(all_span, dim=-1))
beta[A][..., : N - w, w, :] = span[w]
beta[B][..., w:N, N - w - 1, :] = beta[A][..., : N - w, w, :]
beta[A][: N - w, w, :] = span[w]
beta[B][w:N, N - w - 1, :] = span[w]

top[:] = torch.stack(
[beta[A][:, i, 0, l - 1, NTs] for i, l in enumerate(lengths)], dim=1
)
final = beta[A][0, :, NTs]
top = torch.stack([final[:, i, l - 1] for i, l in enumerate(lengths)], dim=1)
log_Z = semiring.dot(top, roots)
return semiring.unconvert(log_Z), (term_use, rules, top, span[1:]), beta
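The rewritten inner loop combines a span's left half Y with its right half Z by a semiring matrix product over split points, then scores the resulting NT * NT pairs against the flattened rule table X_Y_Z (the X1 term above). A self-contained sketch of that contraction in the log semiring, where logsumexp plays the role of semiring.matmul; all shapes are toy values, not torch-struct's:

import torch

def log_matmul(a, b):
    # (..., m, k) x (..., k, n) in the log semiring: logsumexp replaces sum.
    return torch.logsumexp(a.unsqueeze(-1) + b.unsqueeze(-3), dim=-2)

batch, N, NT, w = 2, 6, 4, 2
Y = torch.randn(batch, N - w, w, NT)      # left halves of width-w spans
Z = torch.randn(batch, N - w, w, NT)      # right halves of width-w spans
X_Y_Z = torch.randn(batch, NT * NT, NT)   # rules flattened to (YZ pair, X)

# Sum over the w split points, pairing every (Y, Z) nonterminal combination.
pairs = log_matmul(Y.transpose(-2, -1), Z)                  # (batch, N-w, NT, NT)
# Score the flattened pairs against the rule table.
spans = log_matmul(pairs.reshape(batch, N - w, NT * NT), X_Y_Z)
print(spans.shape)                                          # (batch, N-w, NT)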

49 changes: 40 additions & 9 deletions torch_struct/cky_crf.py
@@ -1,9 +1,36 @@
import torch
from .helpers import _Struct
from .helpers import _Struct, Chart

A, B = 0, 1


# class Get(torch.autograd.Function):
# @staticmethod
# def forward(ctx, chart, grad_chart, indices):
# out = chart[indices]
# ctx.save_for_backward(grad_chart)
# ctx.indices = indices
# return out

# @staticmethod
# def backward(ctx, grad_output):
# grad_chart, = ctx.saved_tensors
# grad_chart[ctx.indices] += grad_output
# return grad_chart, None, None

# class Set(torch.autograd.Function):
# @staticmethod
# def forward(ctx, chart, indices, vals):
# chart[indices] = vals
# ctx.indices = indices
# return chart

# @staticmethod
# def backward(ctx, grad_output):
# z = grad_output[ctx.indices]
# return None, None, z


class CKY_CRF(_Struct):
def _check_potentials(self, edge, lengths=None):
batch, N, _, NT = edge.shape
@@ -17,25 +44,29 @@ def _check_potentials(self, edge, lengths=None):
def _dp(self, scores, lengths=None, force_grad=False):
semiring = self.semiring
scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)
beta = self._make_chart(2, (batch, N, N), scores, force_grad)

beta = [Chart((batch, N, N), scores, semiring) for _ in range(2)]
L_DIM, R_DIM = 2, 3

# Initialize
reduced_scores = semiring.sum(scores)
term = reduced_scores.diagonal(0, L_DIM, R_DIM)
ns = torch.arange(N)
beta[A][:, :, ns, 0] = term
beta[B][:, :, ns, N - 1] = term
beta[A][ns, 0] = term
beta[B][ns, N - 1] = term

# Run
for w in range(1, N):
Y = beta[A][:, :, : N - w, :w]
Z = beta[B][:, :, w:, N - w :]
left = slice(None, N - w)
right = slice(w, None)
Y = beta[A][left, :w]
Z = beta[B][right, N - w :]
score = reduced_scores.diagonal(w, L_DIM, R_DIM)
beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), score)
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]
new = semiring.times(semiring.dot(Y, Z), score)
beta[A][left, w] = new
beta[B][right, N - w - 1] = new

final = beta[A][:, :, 0]
final = beta[A][0, :]
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores], beta
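Both the initialization and the per-width update read span scores off diagonals of the scores tensor via Tensor.diagonal(offset, dim1, dim2): width-w spans (i, i + w) sit on the w-th superdiagonal of the (left, right) index pair. A small demo of that access pattern with toy numbers:

import torch

batch, N = 1, 4
scores = torch.arange(N * N, dtype=torch.float).view(batch, N, N)

L_DIM, R_DIM = 1, 2                        # toy dims; the code above uses 2, 3
print(scores.diagonal(0, L_DIM, R_DIM))    # tensor([[ 0.,  5., 10., 15.]])
print(scores.diagonal(2, L_DIM, R_DIM))    # tensor([[ 2.,  7.]])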

85 changes: 31 additions & 54 deletions torch_struct/deptree.py
@@ -1,6 +1,6 @@
import torch
import itertools
from .helpers import _Struct
from .helpers import _Struct, Chart


def _convert(logits):
@@ -50,69 +50,46 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False):
semiring = self.semiring
arc_scores = _convert(arc_scores_in)
arc_scores, batch, N, lengths = self._check_potentials(arc_scores, lengths)

arc_scores.requires_grad_(True)
DIRS = 2
alpha = [
self._make_chart(2, (DIRS, batch, N, N), arc_scores, force_grad)
[Chart((batch, DIRS, N, N), arc_scores, semiring) for _ in range(2)]
for _ in range(2)
]
semiring.one_(alpha[A][C].data[:, :, :, :, 0].data)
semiring.one_(alpha[B][C].data[:, :, :, :, -1].data)

def stack(a, b):
return torch.stack([a, b], dim=1)

def sstack(a):
return torch.stack([a, a], dim=1)

# Inside step. assumes first token is root symbol
semiring.one_(alpha[A][C][:, :, :, :, 0].data)
semiring.one_(alpha[B][C][:, :, :, :, -1].data)
k = 0

AIR = alpha[A][I][:, R, :, : N - k, 1:k]
BIL = alpha[B][I][:, L, :, k:N, N - k : N - 1]
k = 1
AC2 = alpha[A][C][:, :, :, : N - k, :k]
BC2 = alpha[B][C][:, :, :, k:, N - k :]
AC, BC, AC_next = None, None, None
def stack(a, b=None):
if b is None:
return torch.stack([a, a], dim=2)
else:
return torch.stack([a, b], dim=2)

ends = [None]
for k in range(1, N):

def tf(a):
return torch.narrow(a, 3, 0, N - k)

def tb(a):
return torch.narrow(a, 3, 1, N - k)

f = torch.arange(N - k), torch.arange(k, N)
if k > 1:
AC2 = torch.cat([tf(AC), tf(AC_next).unsqueeze(-1)], dim=4)
if k > 1:
BC2 = torch.cat([tb(AC_next).unsqueeze(-1), tb(BC)], dim=4)

ACL, ACR = AC2.unbind(dim=1)
BCL, BCR = BC2.unbind(dim=1)
start = semiring.dot(BCL, ACR)

arcsL = semiring.times(start, arc_scores[:, :, f[1], f[0]])
arcsR = semiring.times(start, arc_scores[:, :, f[0], f[1]])

AIR2 = torch.cat(
[torch.narrow(AIR, 2, 0, N - k), arcsR.unsqueeze(-1)], dim=3
)
BIL2 = torch.cat(
[arcsL.unsqueeze(-1), torch.narrow(BIL, 2, 1, N - k)], dim=3
AC = alpha[A][C][:, : N - k, :k]
ACL, ACR = AC.unbind(2)

BC = alpha[B][C][:, k:, N - k :]
BCL, BCR = BC.unbind(2)
arcs = semiring.dot(
semiring.times(stack(ACR), stack(BCL)),
stack(
arc_scores[:, :, f[1], f[0]], arc_scores[:, :, f[0], f[1]]
).unsqueeze(-1),
)
AC_next = stack(semiring.dot(ACL, BIL2), semiring.dot(AIR2, BCR))

ends.append(AC_next[:, R, :, 0])
AC = AC2
BC = BC2
AIR = AIR2
BIL = BIL2
v = torch.stack([ends[l][:, i] for i, l in enumerate(lengths)], dim=1)
return (v, [arc_scores], alpha)
alpha[A][I][:, : N - k, k] = arcs
alpha[B][I][:, k:N, N - k - 1] = arcs

AIR = alpha[A][I][R, : N - k, 1 : k + 1]
BIL = alpha[B][I][L, k:, N - k - 1 : N - 1]
new = semiring.dot(stack(ACL, AIR), stack(BIL, BCR))
alpha[A][C][:, : N - k, k] = new
alpha[B][C][:, k:N, N - k - 1] = new

final = alpha[A][C][R, 0]
v = torch.stack([final[:, i, l] for i, l in enumerate(lengths)], dim=1)
return v, [arc_scores], alpha

def _check_potentials(self, arc_scores, lengths=None):
semiring = self.semiring
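The rewritten loop batches the left (L) and right (R) directions by stacking them along a new dimension and splitting them back with unbind, as the stack helper and AC.unbind(2) do above. A minimal illustration of the idiom with toy tensors (the stacking dim is arbitrary here):

import torch

batch, N = 2, 5
left_half = torch.randn(batch, N)
right_half = torch.randn(batch, N)

# Lay both directions along a new dim, compute on both at once...
both = torch.stack([left_half, right_half], dim=1)   # (batch, 2, N)
doubled = both * 2
# ...then split them back apart, as AC.unbind(2) does above.
l2, r2 = doubled.unbind(1)
assert torch.equal(l2, left_half * 2) and torch.equal(r2, right_half * 2)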
61 changes: 57 additions & 4 deletions torch_struct/helpers.py
@@ -4,12 +4,65 @@
from torch.autograd import Function


# def roll(a, b, N, k, gap=0):
# return (a[:, : N - (k + gap), (k + gap) :], b[:, k + gap :, : N - (k + gap)])
class Get(torch.autograd.Function):
@staticmethod
def forward(ctx, chart, grad_chart, indices):
ctx.save_for_backward(grad_chart)
out = chart[indices]
ctx.indices = indices
return out

@staticmethod
def backward(ctx, grad_output):
grad_chart, = ctx.saved_tensors
grad_chart[ctx.indices] += grad_output
return grad_chart, None, None


class Set(torch.autograd.Function):
@staticmethod
def forward(ctx, chart, indices, vals):
chart[indices] = vals
ctx.indices = indices
return chart

@staticmethod
def backward(ctx, grad_output):
z = grad_output[ctx.indices]
return None, None, z
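A quick sanity check of the Get semantics, assuming this module is importable as torch_struct.helpers: indexing through Get routes the slice's gradient into the preallocated grad_chart buffer in-place during backward.

import torch
from torch_struct.helpers import Get  # the Function defined above

chart = torch.randn(3, 4, requires_grad=True)
grad_chart = torch.zeros(3, 4)
out = Get.apply(chart, grad_chart, (slice(0, 2),))
out.sum().backward()
print(grad_chart)  # rows 0-1 are ones: the slice's grad accumulated in-place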


class Chart:
def __init__(self, size, potentials, semiring, cache=True):
self.data = semiring.zero_(
torch.zeros(
*((semiring.size(),) + size),
dtype=potentials.dtype,
device=potentials.device
)
)
self.grad = self.data.detach().clone().fill_(0.0)
self.cache = cache

def __getitem__(self, ind):
I = slice(None)
if self.cache:
return Get.apply(self.data, self.grad, (I, I) + ind)
else:
return self.data[(I, I) + ind]

def __setitem__(self, ind, new):
I = slice(None)
if self.cache:
self.data = Set.apply(self.data, (I, I) + ind, new)
else:
self.data[(I, I) + ind] = new

def get(self, ind):
return Get.apply(self.data, self.grad, ind)

# def roll2(a, b, N, k, gap=0):
# return (a[:, :, : N - (k + gap), (k + gap) :], b[:, :, k + gap :, : N - (k + gap)])
def set(self, ind, new):
self.data = Set.apply(self.data, ind, new)
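Putting the pieces together, a Chart behaves like a tensor whose leading semiring and batch dimensions are implicit in every index. A small usage sketch; StubSemiring is a hypothetical stand-in exposing only the two methods the constructor calls, not one of torch-struct's semirings:

import torch
from torch_struct.helpers import Chart  # the class defined above

class StubSemiring:
    @staticmethod
    def size():
        return 1  # leading semiring dimension

    @staticmethod
    def zero_(x):
        return x.fill_(-1e9)  # log-space zero

batch, N, NT = 2, 5, 3
potentials = torch.randn(batch, N, NT)
beta = Chart((batch, N, N, NT), potentials, StubSemiring)
beta[: N - 1, 0, :] = torch.zeros(1, batch, N - 1, NT)  # write width-0 cells
y = beta[: N - 1, 0, :]                                 # read them back
print(y.shape)  # torch.Size([1, 2, 4, 3])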


class _Struct:
