Skip to content

Commit

Permalink
Merge 95d0f53 into 4c4c01c
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 16, 2019
2 parents 4c4c01c + 95d0f53 commit 2a10f83
Show file tree
Hide file tree
Showing 10 changed files with 147 additions and 125 deletions.
7 changes: 1 addition & 6 deletions torch_struct/cky_crf.py
Expand Up @@ -28,18 +28,13 @@ def _dp(self, scores, lengths=None, force_grad=False):
f = torch.arange(N - w)
X = reduced_scores[:, :, f, f + w]

beta[A][:, :, : N - w, w] = semiring.times(
semiring.sum(semiring.times(Y, Z)), X
)
beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), X)
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]

final = beta[A][:, :, 0]
log_Z = torch.stack([final[:, b, l - 1] for b, l in enumerate(lengths)], dim=1)
return log_Z, [scores], beta

def _arrange_marginals(self, grads):
return self.semiring.unconvert(grads[0])

def enumerate(self, scores):
semiring = self.semiring
batch, N, _, NT = scores.shape
Expand Down
2 changes: 1 addition & 1 deletion torch_struct/deptree.py
Expand Up @@ -129,7 +129,7 @@ def _check_potentials(self, arc_scores, lengths=None):
return arc_scores, batch, N, lengths

def _arrange_marginals(self, grads):
return _unconvert(self.semiring.unconvert(grads[0]))
return self.semiring.convert(_unconvert(self.semiring.unconvert(grads[0])))

@staticmethod
def to_parts(sequence, extra=None, lengths=None):
Expand Down
11 changes: 11 additions & 0 deletions torch_struct/distributions.py
Expand Up @@ -78,6 +78,17 @@ def argmax(self):
"""
return self.struct(MaxSemiring).marginals(self.log_potentials, self.lengths)

def kmax(self, k):
    r"""
    Compute the k-max for distribution :math:`k\max p(z)`.

    Args:
        k (int): number of top-scoring structures to extract.

    Returns:
        kmax (*k x batch_shape x event_shape*)
    """
    # Marginals under the k-max semiring yield the k best structures;
    # _raw keeps the leading k dimension instead of collapsing it.
    kmax_struct = self.struct(KMaxSemiring(k))
    return kmax_struct.marginals(self.log_potentials, self.lengths, _raw=True)

@lazy_property
def mode(self):
return self.argmax
Expand Down
32 changes: 23 additions & 9 deletions torch_struct/helpers.py
Expand Up @@ -74,7 +74,7 @@ def backward(ctx, grad_v):

return DPManual.apply(edge)

def marginals(self, edge, lengths=None, _autograd=True):
def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
"""
Compute the marginals of a structured model.
Expand All @@ -91,14 +91,28 @@ def marginals(self, edge, lengths=None, _autograd=True):
or not hasattr(self, "_dp_backward")
):
v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True)
marg = torch.autograd.grad(
self.semiring.unconvert(v).sum(dim=0),
edges,
create_graph=True,
only_inputs=True,
allow_unused=False,
)
return self.semiring.unconvert(self._arrange_marginals(marg))
if _raw:
all_m = []
print(v)
for k in range(v.shape[0]):
obj = v[k].sum(dim=0)

marg = torch.autograd.grad(
obj,
edges,
create_graph=True,
only_inputs=True,
allow_unused=False,
)
all_m.append(self.semiring.unconvert(self._arrange_marginals(marg)))
return torch.stack(all_m, dim=0)
else:
obj = self.semiring.unconvert(v).sum(dim=0)
marg = torch.autograd.grad(
obj, edges, create_graph=True, only_inputs=True, allow_unused=False
)
a_m = self._arrange_marginals(marg)
return self.semiring.unconvert(a_m)
else:
v, _, alpha = self._dp(edge, lengths=lengths, force_grad=True)
return self._dp_backward(edge, lengths, alpha)
Expand Down
13 changes: 6 additions & 7 deletions torch_struct/linearchain.py
Expand Up @@ -48,11 +48,9 @@ def root(x):
return x[:, :, 0]

def merge(x, y, size):
return semiring.sum(
semiring.times(
x.transpose(3, 4).view(ssize, batch, size, 1, C, C),
y.view(ssize, batch, size, C, 1, C),
)
return semiring.dot(
x.transpose(3, 4).view(ssize, batch, size, 1, C, C),
y.view(ssize, batch, size, C, 1, C),
)

chart = self._make_chart(
Expand All @@ -78,6 +76,7 @@ def merge(x, y, size):
chart[n][:, :, :size] = merge(
left(chart[n - 1], size), right(chart[n - 1], size), size
)
print(root(chart[-1][:]))
v = semiring.sum(semiring.sum(root(chart[-1][:])))

return v, [log_potentials], None
Expand Down Expand Up @@ -182,9 +181,9 @@ def hmm(transition, emission, init, observations):
return scores

@staticmethod
def _rand():
def _rand(min_n=2):
b = torch.randint(2, 4, (1,))
N = torch.randint(2, 4, (1,))
N = torch.randint(min_n, 4, (1,))
C = torch.randint(2, 4, (1,))
return torch.rand(b, N, C, C), (b.item(), (N + 1).item())

Expand Down
1 change: 1 addition & 0 deletions torch_struct/networks/NeuralCFG.py
Expand Up @@ -22,6 +22,7 @@ class NeuralCFG(torch.nn.Module):
"""
NeuralCFG From Kim et al
"""

def __init__(self, V, T, NT, H):
super().__init__()
self.NT = NT
Expand Down
1 change: 1 addition & 0 deletions torch_struct/networks/SpanLSTM.py
Expand Up @@ -22,6 +22,7 @@ class SpanLSTM(torch.nn.Module):
"""
SpanLSTM model.
"""

def __init__(self, NT, V, H):
super().__init__()
self.H = H
Expand Down
15 changes: 5 additions & 10 deletions torch_struct/semimarkov.py
Expand Up @@ -34,21 +34,19 @@ def _dp(self, edge, lengths=None, force_grad=False):

# Main.
for n in range(1, N):
alpha[:, :, n - 1] = semiring.sum(
semiring.times(
beta[n - 1].view(ssize, batch, 1, 1, C),
edge[:, :, n - 1].view(ssize, batch, K, C, C),
)
alpha[:, :, n - 1] = semiring.dot(
beta[n - 1].view(ssize, batch, 1, 1, C),
edge[:, :, n - 1].view(ssize, batch, K, C, C),
)

t = max(n - K, -1)
f1 = torch.arange(n - 1, t, -1)
f2 = torch.arange(1, len(f1) + 1)
beta[n][:] = semiring.sum(
torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
)
v = semiring.sum(
torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1),
dim=2,
torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
)
return v, [edge], beta

Expand All @@ -60,9 +58,6 @@ def _rand():
C = torch.randint(2, 4, (1,))
return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())

def _arrange_marginals(self, marg):
return self.semiring.unconvert(marg[0])

@staticmethod
def to_parts(sequence, extra, lengths=None):
"""
Expand Down

0 comments on commit 2a10f83

Please sign in to comment.