Commit

srush committed Sep 22, 2019
1 parent 4cd39be commit aa95f45
Showing 5 changed files with 56 additions and 200 deletions.
127 changes: 5 additions & 122 deletions torch_struct/cky.py
@@ -1,52 +1,11 @@
import torch
from .helpers import _Struct
from .semirings import LogSemiring
from torch.autograd import Function

A, B = 0, 1


# class DPManual2(Function):
# @staticmethod
# def forward(ctx, obj, terms, rules, roots, lengths):
# with torch.no_grad():
# v, _, alpha = obj._dp((terms, rules, roots), lengths, False)
# ctx.obj = obj
# ctx.lengths = lengths
# ctx.alpha = alpha
# ctx.v = v
# ctx.save_for_backward(terms, rules, roots)
# return v

# @staticmethod
# def backward(ctx, grad_v):
# terms, rules, roots = ctx.saved_tensors
# with torch.no_grad():
# marginals = ctx.obj._dp_backward(
# (terms, rules, roots), ctx.lengths, ctx.alpha, ctx.v
# )
# return None, marginals[0], marginals[1].sum(1).sum(1), marginals[2], None


class CKY(_Struct):
# def sum(self, scores, lengths=None, force_grad=False, _autograd=True):
# """
# Compute the inside pass of a CFG using CKY.

# Parameters:
# terms : b x n x T
# rules : b x NT x (NT+T) x (NT+T)
# root: b x NT

# Returns:
# v: b tensor of total sum
# spans: list of N, b x N x (NT+t)
# """
# if _autograd or self.semiring is not LogSemiring:
# return self._dp(scores, lengths)[0]
# else:
# return DPManual2.apply(self, *scores, lengths=lengths)

def _dp(self, scores, lengths=None, force_grad=False):
terms, rules, roots = scores
semiring = self.semiring
@@ -111,84 +70,6 @@ def _dp(self, scores, lengths=None, force_grad=False):
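# Combine the top-of-chart scores with the root potentials to obtain the partition value.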
log_Z = semiring.dot(top, roots)
return semiring.unconvert(log_Z), (term_use, rule_use, top), beta

# def _dp_backward(self, scores, lengths, alpha_in, v, force_grad=False):
# terms, rules, roots = scores
# semiring = self.semiring
# batch, N, T = terms.shape
# _, NT, _, _ = rules.shape
# S = NT + T
# if lengths is None:
# lengths = torch.LongTensor([N] * batch)

# beta = self._make_chart(2, (batch, N, N, NT + T), rules, force_grad)
# span_l = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
# span_r = self._make_chart(N, (batch, N, NT + T), rules, force_grad)
# term_use = self._make_chart(1, (batch, N, T), terms, force_grad)[0]

# ssum = semiring.sum
# st = semiring.times
# X_Y_Z = rules.view(batch, 1, NT, S, S)

# for w in range(N - 1, -1, -1):
# for b, l in enumerate(lengths):
# beta[A][b, 0, l - 1, :NT] = roots[b]
# beta[B][b, l - 1, N - (l), :NT] = roots[b]

# # LEFT
# # all bigger on the left.
# X = beta[A][:, : N - w - 1, w + 1 :, :NT].view(
# batch, N - w - 1, N - w - 1, NT, 1, 1
# )
# Z = alpha_in[A][:, w + 1 : N, 0 : N - w - 1].view(
# batch, N - w - 1, N - w - 1, 1, 1, S
# )
# t = st(ssum(st(X, Z), dim=2), X_Y_Z)
# # sum out x and y
# span_l[w] = ssum(ssum(t, dim=-3), dim=-1)

# # RIGHT
# X = beta[B][:, w + 1 :, : N - 1 - w, :NT].view(
# batch, N - w - 1, N - w - 1, NT, 1, 1
# )
# Y = alpha_in[B][:, : N - w - 1, w + 1 :, :].view(
# batch, N - w - 1, N - w - 1, 1, S, 1
# )
# t = st(ssum(st(X, Y), dim=2), X_Y_Z)

# span_r[w] = ssum(ssum(t, dim=-3), dim=-2)

# beta[A][:, : N - w - 1, w, :] = span_l[w]
# beta[A][:, 1 : N - w, w, :] = ssum(
# torch.stack([span_r[w], beta[A][:, 1 : N - w, w, :]]), dim=0
# )
# beta[B][:, w:, N - w - 1, :] = beta[A][:, : N - w, w, :]

# term_use[:, :, :] = st(beta[A][:, :, 0, NT:], terms)
# term_marginals = self._make_chart(1, (batch, N, T), terms, force_grad=False)[0]
# for n in range(N):
# term_marginals[:, n] = semiring.div_exp(term_use[:, n], v.view(batch, 1))

# root_marginals = self._make_chart(1, (batch, NT), terms, force_grad=False)[0]
# for b in range(batch):
# root_marginals[b] = semiring.div_exp(
# st(alpha_in[A][b, 0, lengths[b] - 1, :NT], roots[b]), v[b].view(1)
# )
# edge_marginals = self._make_chart(
# 1, (batch, N, N, NT, S, S), terms, force_grad=False
# )[0]
# edge_marginals.fill_(0)
# for w in range(1, N):
# Y = alpha_in[A][:, : N - w, :w, :].view(batch, N - w, w, 1, S, 1)
# Z = alpha_in[B][:, w:, N - w :, :].view(batch, N - w, w, 1, 1, S)
# score = semiring.times(semiring.sum(semiring.times(Y, Z), dim=2), X_Y_Z)
# score = st(score, beta[A][:, : N - w, w, :NT].view(batch, N - w, NT, 1, 1))
# edge_marginals[:, : N - w, w - 1] = semiring.div_exp(
# score, v.view(batch, 1, 1, 1, 1)
# )
# edge_marginals = edge_marginals.transpose(1, 2)

# return (term_marginals, edge_marginals, root_marginals)

def marginals(self, scores, lengths=None, _autograd=True):
"""
Compute the marginals of a CFG using CKY.
@@ -219,7 +100,9 @@ def marginals(self, scores, lengths=None, _autograd=True):
allow_unused=False,
)
rule_use = marg[:-2]
rules = torch.zeros(batch, N, N, NT, S, S, dtype=scores[1].dtype, device=scores[1].device)
rules = torch.zeros(
batch, N, N, NT, S, S, dtype=scores[1].dtype, device=scores[1].device
)
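# Gather the per-width gradient slices back into one dense batch x N x N x NT x S x S marginal tensor.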
for w in range(len(rule_use)):
rules[:, w, : N - w - 1] = self.semiring.unconvert(rule_use[w])

@@ -332,8 +215,8 @@ def to_networkx(cls, spans):
topo = [[] for _ in range(N)]
for n in ordered:
batch, i, j, _ = n.tolist()
#G.add_node(cur, label=A)
if i-j != 0:
# G.add_node(cur, label=A)
if i - j != 0:
a.append(left[(batch, i)][0])
a.append(right[(batch, j)][0])
b.append(cur)
27 changes: 13 additions & 14 deletions torch_struct/cky_crf.py
@@ -1,10 +1,9 @@
import torch
from .helpers import _Struct
from .semirings import LogSemiring
from torch.autograd import Function

A, B = 0, 1


class CKY_CRF(_Struct):
def _dp(self, scores, lengths=None, force_grad=False):
semiring = self.semiring
@@ -27,28 +26,27 @@ def _dp(self, scores, lengths=None, force_grad=False):
beta[A][:, :, ns, 0] = rule_use[0]
beta[B][:, :, ns, N - 1] = rule_use[0]
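# Inside pass over span widths: combine left (Y) and right (Z) sub-span scores, then multiply in the span's own potential X.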
for w in range(1, N):
Y = beta[A][:, :, : N - w, :w].view(ssize, batch, N - w, 1, w, NT, 1)
Z = beta[B][:, :, w:, N - w :].view(ssize, batch, N - w, 1, w, 1, NT)
f = torch.arange(N-w), torch.arange(w, N)
X = scores[:, :, f[0], f[1]].view(ssize, batch, N-w, NT)
Y = beta[A][:, :, : N - w, :w].view(ssize, batch, N - w, 1, w, NT, 1)
Z = beta[B][:, :, w:, N - w :].view(ssize, batch, N - w, 1, w, 1, NT)
f = torch.arange(N - w), torch.arange(w, N)
X = scores[:, :, f[0], f[1]].view(ssize, batch, N - w, NT)
merge = semiring.times(Y, Z).view(ssize, batch, N - w, 1, -1)
rule_use[w ][:] = semiring.times(
semiring.sum(merge), X)
rule_use[w][:] = semiring.times(semiring.sum(merge), X)

span[w] = rule_use[w].view(ssize, batch, N - w, NT)
beta[A][:, :, : N - w, w] = span[w]
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]

final = semiring.sum(beta[A][:, :, 0, :])
log_Z = torch.stack(
[final[:, b, l - 1] for b, l in enumerate(lengths)], dim=1
)
log_Z = torch.stack([final[:, b, l - 1] for b, l in enumerate(lengths)], dim=1)
return log_Z, rule_use, beta

def _arrange_marginals(self, grads):
semiring = self.semiring
_, batch, N, NT = grads[0].shape
rules = torch.zeros(batch, N, N, NT, dtype=grads[0].dtype, device=grads[0].device)
rules = torch.zeros(
batch, N, N, NT, dtype=grads[0].dtype, device=grads[0].device
)

for w, grad in enumerate(grads):
grad = semiring.unconvert(grad)
@@ -58,7 +56,6 @@ def _arrange_marginals(self, grads):

def enumerate(self, scores):
semiring = self.semiring
ssize = semiring.size()
batch, N, _, NT = scores.shape

def enumerate(x, start, end):
@@ -71,7 +68,9 @@ def enumerate(x, start, end):
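# Pair every left sub-derivation (y over [start, w)) with every right one (z over [w, end)) and multiply in the span potential.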
for m1, y1 in enumerate(y, start, w):
for m2, z1 in enumerate(z, w, end):
yield (
semiring.times(m1, m2, scores[:, start, end-1, x]),
semiring.times(
m1, m2, scores[:, start, end - 1, x]
),
[(x, start, w, end)] + y1 + z1,
)

3 changes: 1 addition & 2 deletions torch_struct/deptree.py
@@ -238,12 +238,11 @@ def _check_potentials(self, arc_scores, lengths=None):
# return _unconvert(ret)

def _arrange_marginals(self, grads):
ssize = self.semiring.size()
_, batch, N = grads[0][0].shape
N = N + 1

ret = torch.zeros(
batch, N, N, dtype=grads[0][0].dtype, device=grads[0][0].device
batch, N, N, dtype=grads[0][0].dtype, device=grads[0][0].device
)
# for k in torch.arange(N):
# f = torch.arange(N - k), torch.arange(k, N)
60 changes: 36 additions & 24 deletions torch_struct/semirings.py
@@ -80,17 +80,11 @@ def unconvert(xs):
@staticmethod
def sum(xs, dim=-1):
assert dim != 0
eps= 1e-6
d = dim - 1 if dim > 0 else dim
part = torch.logsumexp(xs[0], dim=d)
log_sm = xs[0] - part.unsqueeze(d)
sm = log_sm.exp()
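# First component: log-sum-exp of the scores; second: xs[1] reweighted by the softmax plus the entropy of that softmax.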
return torch.stack(
(
part,
torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d),
)
)
return torch.stack((part, torch.sum(xs[1].mul(sm) - log_sm.mul(sm), dim=d)))

@staticmethod
def mul(a, b):
@@ -171,19 +165,26 @@ def backward(ctx, grad_output):
logits, dim = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:

def sample(ls):
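# Draw a single one-hot sample per position from the softmax of the logits.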
pre_shape = ls.shape
draws = torch.multinomial(ls.softmax(-1).view(-1, pre_shape[-1]), 1, True)
draws = torch.multinomial(
ls.softmax(-1).view(-1, pre_shape[-1]), 1, True
)
draws.squeeze(1)
return torch.nn.functional.one_hot(draws, pre_shape[-1]).view(*pre_shape).type_as(ls)
return (
torch.nn.functional.one_hot(draws, pre_shape[-1])
.view(*pre_shape)
.type_as(ls)
)

if dim == -1:
s=sample(logits)
s = sample(logits)
else:
dim = dim if dim >= 0 else logits.dim() + dim
perm = [i for i in range(logits.dim()) if i != dim] + [dim]
rev_perm = [a for a,b in sorted(enumerate(perm), key=lambda a:a[1])]
s= sample(logits.permute(perm)).permute(rev_perm)
rev_perm = [a for a, b in sorted(enumerate(perm), key=lambda a: a[1])]
s = sample(logits.permute(perm)).permute(rev_perm)

grad_input = grad_output.unsqueeze(dim).mul(s)
return grad_input, None
@@ -197,6 +198,7 @@ def sum(xs, dim=-1):

bits = torch.tensor([pow(2, i) for i in range(1, 18)])


class _MultiSampledLogSumExp(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim):
@@ -206,35 +208,46 @@ def forward(ctx, input, dim):

@staticmethod
def backward(ctx, grad_output):
#assert ((grad_output == 64) + (grad_output == 0) + (grad_output ==1)).all()
# assert ((grad_output == 64) + (grad_output == 0) + (grad_output ==1)).all()

logits, part, dim = ctx.saved_tensors
grad_input = None
if ctx.needs_input_grad[0]:

def sample(ls):
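# Draw 16 one-hot samples per position from the softmax; they are later bit-packed, one power of two per sample.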
pre_shape = ls.shape
draws = torch.multinomial(ls.softmax(-1).view(-1, pre_shape[-1]), 16, True)
draws = torch.multinomial(
ls.softmax(-1).view(-1, pre_shape[-1]), 16, True
)
draws = draws.transpose(0, 1)
return torch.nn.functional.one_hot(draws, pre_shape[-1]).view(16, *pre_shape).type_as(ls)
return (
torch.nn.functional.one_hot(draws, pre_shape[-1])
.view(16, *pre_shape)
.type_as(ls)
)

if dim == -1:
s = sample(logits)
else:
dim = dim if dim >= 0 else logits.dim() + dim
perm = [i for i in range(logits.dim()) if i != dim] + [dim]
rev_perm =[0] + [a+1 for a,b in sorted(enumerate(perm), key=lambda a:a[1])]
s= sample(logits.permute(perm)).permute(rev_perm)

rev_perm = [0] + [
a + 1 for a, b in sorted(enumerate(perm), key=lambda a: a[1])
]
s = sample(logits.permute(perm)).permute(rev_perm)

dim = dim if dim >= 0 else logits.dim() + dim
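# grad_output is bit-packed: bit 0 is shared and the higher bits mark which sample slots received gradient. Unpack the bits, zero fresh samples for slots with no incoming bit, and re-encode each kept sample as its own power of two.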
final = (grad_output % 2).unsqueeze(0)
mbits = bits[:].type_as(grad_output)
on = grad_output.unsqueeze(0) % mbits.view(17, * [1]*grad_output.dim())
on = grad_output.unsqueeze(0) % mbits.view(17, *[1] * grad_output.dim())
on = on[1:] - on[:-1]
old_bits = (on + final == 0).unsqueeze(dim+1)
old_bits = (on + final == 0).unsqueeze(dim + 1)

grad_input = mbits[:-1].view(16, *[1]*(s.dim()-1)).mul(
s.masked_fill_(old_bits,0))
grad_input = (
mbits[:-1]
.view(16, *[1] * (s.dim() - 1))
.mul(s.masked_fill_(old_bits, 0))
)

return torch.sum(grad_input, dim=0), None

@@ -249,5 +262,4 @@ def to_discrete(xs, j):
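# Decode sample j from the bit-packed encoding: nonzero wherever sample j's bit or the shared low bit of xs is set.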
i = j
final = xs % 2
mbits = bits.type_as(xs)
return (((xs % mbits[i + 1]) - (xs % mbits[i]) + final)!= 0).type_as(xs)

return (((xs % mbits[i + 1]) - (xs % mbits[i]) + final) != 0).type_as(xs)
