Skip to content

Commit

Permalink
Merge 55296e2 into b770f21
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Oct 21, 2019
2 parents b770f21 + 55296e2 commit 9b0a25b
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 80 deletions.
10 changes: 4 additions & 6 deletions torch_struct/cky_crf.py
Expand Up @@ -16,23 +16,21 @@ def _dp(self, scores, lengths=None, force_grad=False):

# Initialize
reduced_scores = semiring.sum(scores)
rule_use = reduced_scores.diagonal(0, 2, 3)
ns = torch.arange(N)
rule_use = reduced_scores[:, :, ns, ns]
beta[A][:, :, ns, 0] = rule_use
beta[B][:, :, ns, N - 1] = rule_use

# Run
for w in range(1, N):
Y = beta[A][:, :, : N - w, :w]
Z = beta[B][:, :, w:, N - w :]
f = torch.arange(N - w)
X = reduced_scores[:, :, f, f + w]

beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), X)
score = reduced_scores.diagonal(w, 2, 3)
beta[A][:, :, : N - w, w] = semiring.times(semiring.dot(Y, Z), score)
beta[B][:, :, w:N, N - w - 1] = beta[A][:, :, : N - w, w]

final = beta[A][:, :, 0]
log_Z = torch.stack([final[:, b, l - 1] for b, l in enumerate(lengths)], dim=1)
log_Z = final[:, torch.arange(batch), lengths - 1]
return log_Z, [scores], beta

def enumerate(self, scores):
Expand Down
8 changes: 7 additions & 1 deletion torch_struct/distributions.py
Expand Up @@ -6,7 +6,13 @@
from .semimarkov import SemiMarkov
from .deptree import DepTree, deptree_nonproj, deptree_part
from .cky_crf import CKY_CRF
from .semirings import LogSemiring, MaxSemiring, EntropySemiring, MultiSampledSemiring, KMaxSemiring
from .semirings import (
LogSemiring,
MaxSemiring,
EntropySemiring,
MultiSampledSemiring,
KMaxSemiring,
)


class StructDistribution(Distribution):
Expand Down
85 changes: 37 additions & 48 deletions torch_struct/linearchain.py
Expand Up @@ -28,6 +28,7 @@ def _dp(self, log_potentials, lengths=None, force_grad=False):

def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
"Compute forward pass by linear scan"
# Setup
semiring = self.semiring
log_potentials.requires_grad_(True)
ssize = semiring.size()
Expand All @@ -36,23 +37,6 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
)
log_N = int(math.ceil(math.log(N - 1, 2)))
bin_N = int(math.pow(2, log_N))

# setup scan
def left(x, size):
return x[:, :, 0 : size * 2 : 2]

def right(x, size):
return x[:, :, 1 : size * 2 : 2]

def root(x):
return x[:, :, 0]

def merge(x, y, size):
return semiring.dot(
x.transpose(3, 4).view(ssize, batch, size, 1, C, C),
y.view(ssize, batch, size, C, 1, C),
)

chart = self._make_chart(
log_N + 1, (batch, bin_N, C, C), log_potentials, force_grad
)
Expand All @@ -61,54 +45,59 @@ def merge(x, y, size):
for b in range(lengths.shape[0]):
end = lengths[b] - 1
semiring.zero_(chart[0][:, b, end:])
chart[0][:, b, end:, torch.arange(C), torch.arange(C)] = semiring.one_(
chart[0][:, b, end:, torch.arange(C), torch.arange(C)]
cs = torch.arange(C)
chart[0][:, b, end:, cs, cs] = semiring.one_(
chart[0][:, b, end:].diagonal(0, 2, 3)
)

for b in range(lengths.shape[0]):
end = lengths[b] - 1
chart[0][:, b, :end] = log_potentials[:, b, :end]

# Scan
def merge(x, size):
return semiring.dot(
x[:, :, 0 : size * 2 : 2]
.transpose(3, 4)
.view(ssize, batch, size, 1, C, C),
x[:, :, 1 : size * 2 : 2].view(ssize, batch, size, C, 1, C),
)

size = bin_N
for n in range(1, log_N + 1):
size = int(size / 2)
chart[n][:, :, :size] = merge(
left(chart[n - 1], size), right(chart[n - 1], size), size
)
print(root(chart[-1][:]))
v = semiring.sum(semiring.sum(root(chart[-1][:])))

chart[n][:, :, :size] = merge(chart[n - 1], size)
v = semiring.sum(semiring.sum(chart[-1][:, :, 0]))
return v, [log_potentials], None

def _dp_standard(self, edge, lengths=None, force_grad=False):
semiring = self.semiring
ssize = semiring.size()
edge, batch, N, C, lengths = self._check_potentials(edge, lengths)
# def _dp_standard(self, edge, lengths=None, force_grad=False):
# semiring = self.semiring
# ssize = semiring.size()
# edge, batch, N, C, lengths = self._check_potentials(edge, lengths)

alpha = self._make_chart(N, (batch, C), edge, force_grad)
edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad)
# alpha = self._make_chart(N, (batch, C), edge, force_grad)
# edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad)

semiring.one_(alpha[0].data)
# semiring.one_(alpha[0].data)

for n in range(1, N):
edge_store[n - 1][:] = semiring.times(
alpha[n - 1].view(ssize, batch, 1, C),
edge[:, :, n - 1].view(ssize, batch, C, C),
)
alpha[n][:] = semiring.sum(edge_store[n - 1])
# for n in range(1, N):
# edge_store[n - 1][:] = semiring.times(
# alpha[n - 1].view(ssize, batch, 1, C),
# edge[:, :, n - 1].view(ssize, batch, C, C),
# )
# alpha[n][:] = semiring.sum(edge_store[n - 1])

for n in range(1, N):
edge_store[n - 1][:] = semiring.times(
alpha[n - 1].view(ssize, batch, 1, C),
edge[:, :, n - 1].view(ssize, batch, C, C),
)
alpha[n][:] = semiring.sum(edge_store[n - 1])
# for n in range(1, N):
# edge_store[n - 1][:] = semiring.times(
# alpha[n - 1].view(ssize, batch, 1, C),
# edge[:, :, n - 1].view(ssize, batch, C, C),
# )
# alpha[n][:] = semiring.sum(edge_store[n - 1])

ret = [alpha[lengths[i] - 1][:, i] for i in range(batch)]
ret = torch.stack(ret, dim=1)
v = semiring.sum(ret)
return v, edge_store, alpha
# ret = [alpha[lengths[i] - 1][:, i] for i in range(batch)]
# ret = torch.stack(ret, dim=1)
# v = semiring.sum(ret)
# return v, edge_store, alpha

@staticmethod
def to_parts(sequence, extra, lengths=None):
Expand Down
115 changes: 90 additions & 25 deletions torch_struct/semimarkov.py
@@ -1,4 +1,5 @@
import torch
import math
from .helpers import _Struct


Expand All @@ -18,37 +19,101 @@ def _check_potentials(self, edge, lengths=None):
assert C == C2, "Transition shape doesn't match"
return edge, batch, N, K, C, lengths

def _dp(self, edge, lengths=None, force_grad=False):
def _dp(self, log_potentials, lengths=None, force_grad=False):
"Compute forward pass by linear scan"

# Setup
semiring = self.semiring
log_potentials.requires_grad_(True)
ssize = semiring.size()
edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
edge.requires_grad_(True)
log_potentials, batch, N, K, C, lengths = self._check_potentials(
log_potentials, lengths
)
log_N = int(math.ceil(math.log(N - 1, 2)))
bin_N = int(math.pow(2, log_N))
chart = self._make_chart(
log_N + 1, (batch, bin_N, K - 1, K - 1, C, C), log_potentials, force_grad
)

# Init
# All paths starting at N of len K
alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]

# All paths finishing at N with label C
beta = self._make_chart(N, (batch, C), edge, force_grad)
semiring.one_(beta[0].data)

# Main.
for n in range(1, N):
alpha[:, :, n - 1] = semiring.dot(
beta[n - 1].view(ssize, batch, 1, 1, C),
edge[:, :, n - 1].view(ssize, batch, K, C, C),
for b in range(lengths.shape[0]):
end = lengths[b] - 1
semiring.zero_(chart[0][:, b, end:])
cs = torch.arange(C)
chart[0][:, b, end:, 0, 0, cs, cs] = semiring.one_(
chart[0][:, b, end:, 0, 0].diagonal(0, 2, 3)
)

t = max(n - K, -1)
f1 = torch.arange(n - 1, t, -1)
f2 = torch.arange(1, len(f1) + 1)
beta[n][:] = semiring.sum(
torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
)
v = semiring.sum(
torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
)
return v, [edge], beta
for b in range(lengths.shape[0]):
end = lengths[b] - 1
chart[0][:, b, :end, 0, 0] = log_potentials[:, b, :end, 1]

for k in range(K - 1):
if k == 0:
chart[0][:, b, : end - k, 1 + k : (K - 1), k] = log_potentials[
:, b, : end - k, 2 + k : K
]

if k >= 1:
cs = torch.arange(C)
chart[0][:, b, : end - (k - 1), k - 1, k, cs, cs] = semiring.one_(
chart[0][:, b, : end - (k - 1), k - 1, k].diagonal(0, 2, 3)
)

# Scan
def merge(x, size):
return semiring.sum(
semiring.sum(
semiring.times(
x[:, :, 0 : size * 2 : 2]
.transpose(-1, -2)
.transpose(-3, -4)
.view(ssize, batch, size, 1, K - 1, K - 1, 1, C, C),
x[:, :, 1 : size * 2 : 2].view(
ssize, batch, size, K - 1, 1, K - 1, C, 1, C
),
)
).transpose(-1, 5)
).transpose(-1, -2)

size = bin_N
for n in range(1, log_N + 1):
size = int(size / 2)
chart[n][:, :, :size] = merge(chart[n - 1], size)
v = semiring.sum(semiring.sum(chart[-1][:, :, 0, 0, 0, :, :]))
return v, [log_potentials], None

# def _dp_standard(self, edge, lengths=None, force_grad=False):
# semiring = self.semiring
# ssize = semiring.size()
# edge, batch, N, K, C, lengths = self._check_potentials(edge, lengths)
# edge.requires_grad_(True)

# # Init
# # All paths starting at N of len K
# alpha = self._make_chart(1, (batch, N, K, C), edge, force_grad)[0]

# # All paths finishing at N with label C
# beta = self._make_chart(N, (batch, C), edge, force_grad)
# semiring.one_(beta[0].data)

# # Main.
# for n in range(1, N):
# alpha[:, :, n - 1] = semiring.dot(
# beta[n - 1].view(ssize, batch, 1, 1, C),
# edge[:, :, n - 1].view(ssize, batch, K, C, C),
# )

# t = max(n - K, -1)
# f1 = torch.arange(n - 1, t, -1)
# f2 = torch.arange(1, len(f1) + 1)
# beta[n][:] = semiring.sum(
# torch.stack([alpha[:, :, a, b] for a, b in zip(f1, f2)], dim=-1)
# )
# v = semiring.sum(
# torch.stack([beta[l - 1][:, i] for i, l in enumerate(lengths)], dim=1)
# )
# return v, [edge], beta

@staticmethod
def _rand():
Expand Down
11 changes: 11 additions & 0 deletions torch_struct/test_algorithms.py
Expand Up @@ -35,6 +35,17 @@ def test_simple_a(batch, N, C):
LinearChain(MultiSampledSemiring).marginals(vals)


@given(smint, smint, smint, smint)
@settings(max_examples=50, deadline=None)
def test_simple_b(batch, N, K, C):
    """Smoke-test SemiMarkov marginals under the sampled semirings.

    Only checks that the forward/backward machinery runs without error on a
    uniform potential tensor; no marginal values are asserted.
    """
    # NOTE(review): the hypothesis-drawn N is deliberately overridden and the
    # drawn K is ignored (duration dimension hard-coded to 5) — confirm
    # whether the scan-based SemiMarkov DP is meant to be exercised with the
    # drawn sizes instead.
    N = 14
    vals = torch.ones(batch, N, 5, C, C)
    SemiMarkov(SampledSemiring).marginals(vals)
    SemiMarkov(MultiSampledSemiring).marginals(vals)


@given(data())
@settings(max_examples=50, deadline=None)
def test_networkx(data):
Expand Down

0 comments on commit 9b0a25b

Please sign in to comment.