diff --git a/.github/workflows/torch-struct.yml b/.github/workflows/torch-struct.yml
index ab02d06d..312228d8 100644
--- a/.github/workflows/torch-struct.yml
+++ b/.github/workflows/torch-struct.yml
@@ -26,10 +26,10 @@ jobs:
     - name: Lint with flake8
       run: |
         # stop the build if there are Python syntax errors or undefined names
-        flake8 --ignore "N801, E203, E266, E501, W503, F812, E741, N803, N802, N806" torch_struct/
+        flake8 --ignore "N801, E203, E266, E501, W503, F812, E741, N803, N802, N806" torch_struct/ tests/
     - name: Test with pytest
       run: |
-        pytest --cov=torch_struct --cov-report annotate:annotate --cov-report term-missing torch_struct/
+        pytest --cov=torch_struct --cov-report annotate:annotate --cov-report term-missing tests/
     - name: Coveralls
       env:
         GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
diff --git a/torch_struct/data/__init__.py b/examples/data/__init__.py
similarity index 100%
rename from torch_struct/data/__init__.py
rename to examples/data/__init__.py
diff --git a/torch_struct/data/data.py b/examples/data/data.py
similarity index 100%
rename from torch_struct/data/data.py
rename to examples/data/data.py
diff --git a/torch_struct/data/trees.py b/examples/data/trees.py
similarity index 100%
rename from torch_struct/data/trees.py
rename to examples/data/trees.py
diff --git a/torch_struct/networks/NeuralCFG.py b/examples/networks/NeuralCFG.py
similarity index 100%
rename from torch_struct/networks/NeuralCFG.py
rename to examples/networks/NeuralCFG.py
diff --git a/torch_struct/networks/SpanLSTM.py b/examples/networks/SpanLSTM.py
similarity index 100%
rename from torch_struct/networks/SpanLSTM.py
rename to examples/networks/SpanLSTM.py
diff --git a/torch_struct/networks/TreeLSTM.py b/examples/networks/TreeLSTM.py
similarity index 100%
rename from torch_struct/networks/TreeLSTM.py
rename to examples/networks/TreeLSTM.py
diff --git a/torch_struct/networks/__init__.py b/examples/networks/__init__.py
similarity index 100%
rename from torch_struct/networks/__init__.py
rename to examples/networks/__init__.py
diff --git a/torch_struct/rl.py b/examples/rl.py
similarity index 100%
rename from torch_struct/rl.py
rename to examples/rl.py
diff --git a/setup.cfg b/setup.cfg
index 31ad82b6..60361141 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,7 @@
 [aliases]
-test = pytest
+test = pytest tests
+style = flake8 --ignore "N801, E203, E266, E501, W503, F812, E741, N803, N802, N806" torch_struct tests
+[darglint]
+ignore_regex=((^_(.*))|(.*map)|(.*zip)|(.*reduce)|(test.*)|(tensor_.*))
+docstring_style=google
+strictness=short
diff --git a/setup.py b/setup.py
index 3db23560..f9e514cd 100644
--- a/setup.py
+++ b/setup.py
@@ -2,13 +2,11 @@
 
 setup(
     name="torch_struct",
-    version="0.4",
+    version="0.5",
    author="Alexander Rush",
    author_email="arush@cornell.edu",
    packages=[
        "torch_struct",
-        "torch_struct.data",
-        "torch_struct.networks",
        "torch_struct.semirings",
    ],
    package_data={"torch_struct": []},
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/extensions.py b/tests/extensions.py
new file mode 100644
index 00000000..de5b09a0
--- /dev/null
+++ b/tests/extensions.py
@@ -0,0 +1,355 @@
+import torch_struct
+import torch
+from torch_struct import LogSemiring
+import itertools
+
+
+class LinearChainTest:
+    def __init__(self, semiring=LogSemiring):
+        self.semiring = semiring
+
+    @staticmethod
+    def _rand(min_n=2):
+        b = torch.randint(2, 4, (1,))
+        N = torch.randint(min_n, 4, (1,))
+        C = torch.randint(2, 4, (1,))
+        return torch.rand(b, N, C, C), (b.item(), (N + 1).item())
+
+    ### Tests
+
+    def enumerate(self, edge, lengths=None):
+        model = torch_struct.LinearChain(self.semiring)
+        semiring = self.semiring
+        ssize = semiring.size()
+        edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
+        chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
+
+        enum_lengths = torch.LongTensor(lengths.shape)
+        for n in range(1, N):
+            new_chains = []
+            for chain, score in chains[-1]:
+                for c in range(C):
+                    new_chains.append(
+                        (
+                            chain + [c],
+                            semiring.mul(score, edge[:, :, n - 1, c, chain[-1]]),
+                        )
+                    )
+            chains.append(new_chains)
+
+            for b in range(lengths.shape[0]):
+                if lengths[b] == n + 1:
+                    enum_lengths[b] = len(new_chains)
+
+        edges = model.to_parts(
+            torch.stack([torch.tensor(c) for (c, _) in chains[-1]]), C
+        )
+        # Sum out non-batch
+        a = torch.einsum("ancd,sbncd->sbancd", edges.float(), edge)
+        a = semiring.prod(a.view(*a.shape[:3] + (-1,)), dim=3)
+        a = semiring.sum(a, dim=2)
+        ret = semiring.sum(torch.stack([s for (_, s) in chains[-1]], dim=1), dim=1)
+        assert torch.isclose(a, ret).all(), "%s %s" % (a, ret)
+
+        edges = torch.zeros(len(chains[-1]), batch, N - 1, C, C)
+        for b in range(lengths.shape[0]):
+            edges[: enum_lengths[b], b, : lengths[b] - 1] = model.to_parts(
+                torch.stack([torch.tensor(c) for (c, _) in chains[lengths[b] - 1]]), C
+            )
+
+        return (
+            semiring.unconvert(ret),
+            [s for (_, s) in chains[-1]],
+            edges,
+            enum_lengths,
+        )
+
+
+class DepTreeTest:
+    def __init__(self, semiring=LogSemiring):
+        self.semiring = semiring
+
+    @staticmethod
+    def _rand():
+        b = torch.randint(2, 4, (1,))
+        N = torch.randint(2, 4, (1,))
+        return torch.rand(b, N, N), (b.item(), N.item())
+
+    def enumerate(self, arc_scores, non_proj=False, multi_root=True):
+        semiring = self.semiring
+        parses = []
+        q = []
+        arc_scores = torch_struct.convert(arc_scores)
+        batch, N, _ = arc_scores.shape
+
+        # arc_scores = arc_scores.sum(-1)
+        for mid in itertools.product(range(N + 1), repeat=N - 1):
+            parse = [-1] + list(mid)
+            if not _is_spanning(parse):
+                continue
+            if not non_proj and not _is_projective(parse):
+                continue
+
+            if not multi_root and _is_multi_root(parse):
+                continue
+
+            q.append(parse)
+            parses.append(
+                semiring.times(*[arc_scores[:, parse[i], i] for i in range(1, N, 1)])
+            )
+        return semiring.sum(torch.stack(parses, dim=-1)), None
+
+
+class SemiMarkovTest:
+    def __init__(self, semiring=LogSemiring):
+        self.semiring = semiring
+
+    # Tests
+
+    @staticmethod
+    def _rand():
+        b = torch.randint(2, 4, (1,))
+        N = torch.randint(2, 4, (1,))
+        K = torch.randint(2, 4, (1,))
+        C = torch.randint(2, 4, (1,))
+        return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())
+
+    def enumerate(self, edge):
+        semiring = self.semiring
+        ssize = semiring.size()
+        batch, N, K, C, _ = edge.shape
+        edge = semiring.convert(edge)
+        chains = {}
+        chains[0] = [
+            ([(c, 0)], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)
+        ]
+
+        for n in range(1, N + 1):
+            chains[n] = []
+            for k in range(1, K):
+                if n - k not in chains:
+                    continue
+                for chain, score in chains[n - k]:
+                    for c in range(C):
+                        chains[n].append(
+                            (
+                                chain + [(c, k)],
+                                semiring.mul(
+                                    score, edge[:, :, n - k, k, c, chain[-1][0]]
+                                ),
+                            )
+                        )
+        ls = [s for (_, s) in chains[N]]
+        return semiring.unconvert(semiring.sum(torch.stack(ls, dim=1), dim=1)), ls
+
+
+### Tests
+
+
+def _is_spanning(parse):
+    """
+    Is the parse tree a valid spanning tree?
+    Returns
+    --------
+    spanning : bool
+        True if a valid spanning tree.
+    """
+    d = {}
+    for m, h in enumerate(parse):
+        if m == h:
+            return False
+        d.setdefault(h, [])
+        d[h].append(m)
+    stack = [0]
+    seen = set()
+    while stack:
+        cur = stack[0]
+        if cur in seen:
+            return False
+        seen.add(cur)
+        stack = d.get(cur, []) + stack[1:]
+    if len(seen) != len(parse) - len([1 for p in parse if p is None]):
+        return False
+    return True
+
+
+def _is_multi_root(parse):
+    root_count = 0
+    for m, h in enumerate(parse):
+        if h == 0:
+            root_count += 1
+    return root_count > 1
+
+
+def _is_projective(parse):
+    """
+    Is the parse tree projective?
+    Returns
+    --------
+    projective : bool
+        True if a projective tree.
+    """
+    for m, h in enumerate(parse):
+        for m2, h2 in enumerate(parse):
+            if m2 == m:
+                continue
+            if m < h:
+                if (
+                    m < m2 < h < h2
+                    or m < h2 < h < m2
+                    or m2 < m < h2 < h
+                    or h2 < m < m2 < h
+                ):
+                    return False
+            if h < m:
+                if (
+                    h < m2 < m < h2
+                    or h < h2 < m < m2
+                    or m2 < h < h2 < m
+                    or h2 < h < m2 < m
+                ):
+                    return False
+    return True
+
+
+class CKY_CRFTest:
+    def __init__(self, semiring=LogSemiring):
+        self.semiring = semiring
+
+    # For testing
+    def enumerate(self, scores):
+        semiring = self.semiring
+        batch, N, _, NT = scores.shape
+
+        def enumerate(x, start, end):
+            if start + 1 == end:
+                yield (scores[:, start, start, x], [(start, x)])
+            else:
+                for w in range(start + 1, end):
+                    for y in range(NT):
+                        for z in range(NT):
+                            for m1, y1 in enumerate(y, start, w):
+                                for m2, z1 in enumerate(z, w, end):
+                                    yield (
+                                        semiring.times(
+                                            m1, m2, scores[:, start, end - 1, x]
+                                        ),
+                                        [(x, start, w, end)] + y1 + z1,
+                                    )
+
+        ls = []
+        for nt in range(NT):
+            ls += [s for s, _ in enumerate(nt, 0, N)]
+
+        return semiring.sum(torch.stack(ls, dim=-1)), None
+
+    @staticmethod
+    def _rand():
+        batch = torch.randint(2, 5, (1,))
+        N = torch.randint(2, 5, (1,))
+        NT = torch.randint(2, 5, (1,))
+        scores = torch.rand(batch, N, N, NT)
+        return scores, (batch.item(), N.item())
+
+
+class CKYTest:
+    def __init__(self, semiring=LogSemiring):
+        self.semiring = semiring
+
+    def enumerate(self, scores):
+        terms, rules, roots = scores
+        semiring = self.semiring
+        batch, N, T = terms.shape
+        _, NT, _, _ = rules.shape
+
+        def enumerate(x, start, end):
+            if start + 1 == end:
+                yield (terms[:, start, x - NT], [(start, x - NT)])
+            else:
+                for w in range(start + 1, end):
+                    for y in range(NT) if w != start + 1 else range(NT, NT + T):
+                        for z in range(NT) if w != end - 1 else range(NT, NT + T):
+                            for m1, y1 in enumerate(y, start, w):
+                                for m2, z1 in enumerate(z, w, end):
+                                    yield (
+                                        semiring.times(
+                                            semiring.times(m1, m2), rules[:, x, y, z]
+                                        ),
+                                        [(x, start, w, end)] + y1 + z1,
+                                    )
+
+        ls = []
+        for nt in range(NT):
+            ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
+        return semiring.sum(torch.stack(ls, dim=-1)), None
+
+    @staticmethod
+    def _rand():
+        batch = torch.randint(2, 5, (1,))
+        N = torch.randint(2, 5, (1,))
+        NT = torch.randint(2, 5, (1,))
+        T = torch.randint(2, 5, (1,))
+        terms = torch.rand(batch, N, T)
+        rules = torch.rand(batch, NT, (NT + T), (NT + T))
+        roots = torch.rand(batch, NT)
+        return (terms, rules, roots), (batch.item(), N.item())
+
+
+class AlignmentTest:
+    def __init__(self, semiring=LogSemiring):
+        self.semiring = semiring
+
+    @staticmethod
+    def _rand(min_n=2):
+        b = torch.randint(2, 4, (1,))
+        N = torch.randint(min_n, 4, (1,))
+        M = torch.randint(min_n, 4, (1,))
+        N = torch.min(M, N)
+        return torch.rand(b, N, M, 3), (b.item(), (N).item())
+
+    def enumerate(self, edge, lengths=None):
+        semiring = self.semiring
+        edge, batch, N, M, lengths = torch_struct.Alignment(semiring)._check_potentials(
+            edge, lengths
+        )
+        d = {}
+        d[0, 0] = [([(0, 0)], edge[:, :, 0, 0, 1])]
+        # enum_lengths = torch.LongTensor(lengths.shape)
+        for i in range(N):
+            for j in range(M):
+                d.setdefault((i + 1, j + 1), [])
+                d.setdefault((i, j + 1), [])
+                d.setdefault((i + 1, j), [])
+                for chain, score in d[i, j]:
+                    if i + 1 < N and j + 1 < M:
+                        d[i + 1, j + 1].append(
+                            (
+                                chain + [(i + 1, j + 1)],
+                                semiring.mul(score, edge[:, :, i + 1, j + 1, 1]),
+                            )
+                        )
+                    if i + 1 < N:
+                        d[i + 1, j].append(
+                            (
+                                chain + [(i + 1, j)],
+                                semiring.mul(score, edge[:, :, i + 1, j, 2]),
+                            )
+                        )
+                    if j + 1 < M:
+                        d[i, j + 1].append(
+                            (
+                                chain + [(i, j + 1)],
+                                semiring.mul(score, edge[:, :, i, j + 1, 0]),
+                            )
+                        )
+        all_val = torch.stack([x[1] for x in d[N - 1, M - 1]], dim=-1)
+        return semiring.unconvert(semiring.sum(all_val)), None
+
+
+test_lookup = {
+    torch_struct.LinearChain: LinearChainTest,
+    torch_struct.SemiMarkov: SemiMarkovTest,
+    torch_struct.DepTree: DepTreeTest,
+    torch_struct.CKY_CRF: CKY_CRFTest,
+    torch_struct.CKY: CKYTest,
+    torch_struct.Alignment: AlignmentTest,
+}
diff --git a/torch_struct/test_algorithms.py b/tests/test_algorithms.py
similarity index 85%
rename from torch_struct/test_algorithms.py
rename to tests/test_algorithms.py
index fc8a7715..9573c323 100644
--- a/torch_struct/test_algorithms.py
+++ b/tests/test_algorithms.py
@@ -1,10 +1,5 @@
-from .cky import CKY
-from .cky_crf import CKY_CRF
-from .deptree import DepTree, deptree_nonproj, deptree_part
-from .linearchain import LinearChain
-from .semimarkov import SemiMarkov
-from .alignment import Alignment
-from .semirings import (
+from torch_struct import CKY, CKY_CRF, DepTree, LinearChain, SemiMarkov, Alignment
+from torch_struct import (
     LogSemiring,
     CheckpointSemiring,
     CheckpointShardSemiring,
@@ -16,6 +11,7 @@
     EntropySemiring,
     MultiSampledSemiring,
 )
+from .extensions import test_lookup
 import torch
 from hypothesis import given, settings
 from hypothesis.strategies import integers, data, sampled_from
@@ -84,11 +80,13 @@ def test_entropy(data):
     model = data.draw(sampled_from([LinearChain, SemiMarkov]))
     semiring = EntropySemiring
     struct = model(semiring)
-    vals, (batch, N) = model._rand()
+    test = test_lookup[model](LogSemiring)
+    vals, (batch, N) = test._rand()
     alpha = struct.sum(vals)
     log_z = model(LogSemiring).sum(vals)
-    log_probs = model(LogSemiring).enumerate(vals)[1]
+
+    log_probs = test.enumerate(vals)[1]
     log_probs = torch.stack(log_probs, dim=1) - log_z
     print(log_probs.shape, log_z.shape, log_probs.exp().sum(1))
     entropy = -log_probs.mul(log_probs.exp()).sum(1).squeeze(0)
@@ -102,7 +100,8 @@ def test_kmax(data):
     K = 2
     semiring = KMaxSemiring(K)
     struct = model(semiring)
-    vals, (batch, N) = model._rand()
+    test = test_lookup[model](LogSemiring)
+    vals, (batch, N) = test._rand()
     max1 = model(MaxSemiring).sum(vals)
     alpha = struct.sum(vals, _raw=True)
     assert (alpha[0] == max1).all()
@@ -115,7 +114,7 @@ def test_kmax(data):
         assert (topk[1] != topk[0]).any()
 
     if model != DepTree:
-        log_probs = model(MaxSemiring).enumerate(vals)[1]
+        log_probs = test_lookup[model](MaxSemiring).enumerate(vals)[1]
         tops = torch.topk(torch.cat(log_probs, dim=0), 5, 0)[0]
         assert torch.isclose(struct.score(topk[1], vals), alpha[1]).all()
         for k in range(K):
@@ -128,9 +127,10 @@ def test_cky(data):
     model = data.draw(sampled_from([CKY]))
     semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
     struct = model(semiring)
-    vals, (batch, N) = model._rand()
+    test = test_lookup[model](semiring)
+    vals, (batch, N) = test._rand()
     alpha = struct.sum(vals)
-    count = struct.enumerate(vals)[0]
+    count = test.enumerate(vals)[0]
 
     assert alpha.shape[0] == batch
     assert count.shape[0] == batch
@@ -149,16 +149,17 @@ def test_generic_a(data):
 
     semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
     struct = model(semiring)
-    vals, (batch, N) = model._rand()
+    test = test_lookup[model](semiring)
+    vals, (batch, N) = test._rand()
     alpha = struct.sum(vals)
-    count = struct.enumerate(vals)[0]
+    count = test.enumerate(vals)[0]
     # assert(False)
     assert alpha.shape[0] == batch
     assert count.shape[0] == batch
     assert alpha.shape == count.shape
     assert torch.isclose(count[0], alpha[0])
 
-    vals, _ = model._rand()
+    vals, _ = test._rand()
     struct = model(MaxSemiring)
     score = struct.sum(vals)
     marginals = struct.marginals(vals)
@@ -173,7 +174,7 @@ def test_labeled_proj_deptree(data):
     semiring = data.draw(sampled_from([LogSemiring, MaxSemiring]))
     struct = DepTree(semiring)
     arc_scores = torch.rand(3, 5, 5, 7)
-    count = struct.enumerate(semiring.sum(arc_scores))[0]
+    count = test_lookup[DepTree](semiring).enumerate(semiring.sum(arc_scores))[0]
     alpha = struct.sum(arc_scores)
     assert torch.isclose(count, alpha).all()
 
@@ -184,38 +185,38 @@ def test_labeled_proj_deptree(data):
     assert torch.isclose(max_score, struct.score(arc_scores, argmax)).all()
 
 
-@given(data())
-@settings(max_examples=50, deadline=None)
-def test_non_proj(data):
-    model = data.draw(sampled_from([DepTree]))
-    semiring = data.draw(sampled_from([LogSemiring]))
-    struct = model(semiring)
-    vals, (batch, N) = model._rand()
-    alpha = deptree_part(vals)
-    count = struct.enumerate(vals, non_proj=True, multi_root=False)[0]
-
-    assert alpha.shape[0] == batch
-    assert count.shape[0] == batch
-    assert alpha.shape == count.shape
-    assert torch.isclose(count[0], alpha[0])
-
-    marginals = deptree_nonproj(vals)
-    print(marginals.sum(1))
-    # assert(False)
-    # vals, _ = model._rand()
-    # struct = model(MaxSemiring)
-    # score = struct.sum(vals)
-    # marginals = struct.marginals(vals)
-    # assert torch.isclose(score, struct.score(vals, marginals)).all()
+# @given(data())
+# @settings(max_examples=50, deadline=None)
+# def test_non_proj(data):
+#     model = data.draw(sampled_from([DepTree]))
+#     semiring = data.draw(sampled_from([LogSemiring]))
+#     struct = model(semiring)
+#     vals, (batch, N) = model._rand()
+#     alpha = deptree_part(vals)
+#     count = struct.enumerate(vals, non_proj=True, multi_root=False)[0]
+
+#     assert alpha.shape[0] == batch
+#     assert count.shape[0] == batch
+#     assert alpha.shape == count.shape
+#     assert torch.isclose(count[0], alpha[0])
+
+#     marginals = deptree_nonproj(vals)
+#     print(marginals.sum(1))
+#     # assert(False)
+#     # vals, _ = model._rand()
+#     # struct = model(MaxSemiring)
+#     # score = struct.sum(vals)
+#     # marginals = struct.marginals(vals)
+#     # assert torch.isclose(score, struct.score(vals, marginals)).all()
 
 
 @given(data(), integers(min_value=1, max_value=20))
 def test_parts_from_marginals(data, seed):
     # todo: add CKY, DepTree too?
     model = data.draw(sampled_from([LinearChain, SemiMarkov]))
-    struct = model()
+    test = test_lookup[model]()
     torch.manual_seed(seed)
-    vals, (batch, N) = struct._rand()
+    vals, (batch, N) = test._rand()
 
     edge = model(MaxSemiring).marginals(vals).long()
@@ -234,8 +235,9 @@ def test_parts_from_sequence(data, seed):
     model = data.draw(sampled_from([LinearChain, SemiMarkov]))
     struct = model()
+    test = test_lookup[model]()
     torch.manual_seed(seed)
-    vals, (batch, N) = struct._rand()
+    vals, (batch, N) = test._rand()
     C = vals.size(-1)
     if isinstance(struct, LinearChain):
         K = 2
@@ -268,12 +270,11 @@
 
 @given(data(), integers(min_value=1, max_value=10))
 @settings(max_examples=50, deadline=None)
 def test_generic_lengths(data, seed):
-    model = data.draw(
-        sampled_from([CKY, LinearChain, SemiMarkov, CKY_CRF, DepTree])
-    )
+    model = data.draw(sampled_from([CKY, LinearChain, SemiMarkov, CKY_CRF, DepTree]))
     struct = model()
     torch.manual_seed(seed)
-    vals, (batch, N) = struct._rand()
+    test = test_lookup[model]()
+    vals, (batch, N) = test._rand()
     lengths = torch.tensor(
         [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
     )
@@ -319,12 +320,10 @@
 
 @settings(max_examples=50, deadline=None)
 @given(data(), integers(min_value=1, max_value=10))
 def test_params(data, seed):
-    model = data.draw(
-        sampled_from([DepTree, SemiMarkov, DepTree, CKY, CKY_CRF])
-    )
-    struct = model()
+    model = data.draw(sampled_from([DepTree, SemiMarkov, DepTree, CKY, CKY_CRF]))
     torch.manual_seed(seed)
-    vals, (batch, N) = struct._rand()
+    test = test_lookup[model]()
+    vals, (batch, N) = test._rand()
     if isinstance(vals, tuple):
         vals = tuple((v.requires_grad_(True) for v in vals))
     else:
@@ -379,10 +378,11 @@ def ignore_alignment(data):
     model = data.draw(sampled_from([Alignment]))
     semiring = data.draw(sampled_from([StdSemiring]))
+    test = test_lookup[model](semiring)
     struct = model(semiring, sparse_rounds=10)
-    vals, (batch, N) = model._rand()
+    vals, (batch, N) = test._rand()
     alpha = struct.sum(vals)
-    count = struct.enumerate(vals)[0]
+    count = test.enumerate(vals)[0]
     assert torch.isclose(count, alpha).all()
 
     model = data.draw(sampled_from([Alignment]))
@@ -390,7 +390,7 @@
     struct = model(semiring, sparse_rounds=10)
     vals, (batch, N) = model._rand()
     alpha = struct.sum(vals)
-    count = struct.enumerate(vals)[0]
+    count = test_lookup[model](semiring).enumerate(vals)[0]
     assert torch.isclose(count, alpha).all()
 
     # model = data.draw(sampled_from([Alignment]))
@@ -409,12 +409,13 @@
     semiring = data.draw(sampled_from([MaxSemiring]))
     struct = model(semiring, local=True)
-    vals, (batch, N) = model._rand()
+    test = test_lookup[model](semiring)
+    vals, (batch, N) = test._rand()
     vals[..., 0] = -2 * vals[..., 0].abs()
     vals[..., 1] = vals[..., 1].abs()
     vals[..., 2] = -2 * vals[..., 2].abs()
     alpha = struct.sum(vals)
-    count = struct.enumerate(vals)[0]
+    count = test.enumerate(vals)[0]
     mx = struct.marginals(vals)
     print(alpha, count)
     print(mx[0].nonzero())
@@ -437,7 +438,8 @@ def test_hmm():
 def test_sparse_max(data):
     model = data.draw(sampled_from([LinearChain]))
     semiring = SparseMaxSemiring
-    vals, (batch, N) = model._rand()
+    test = test_lookup[model]()
+    vals, (batch, N) = test._rand()
     vals.requires_grad_(True)
     model(semiring).sum(vals)
     sparsemax = model(semiring).marginals(vals)
diff --git a/torch_struct/test_cky.py b/tests/test_cky.py
similarity index 100%
rename from torch_struct/test_cky.py
rename to tests/test_cky.py
diff --git a/torch_struct/test_distributions.py b/tests/test_distributions.py
similarity index 89%
rename from torch_struct/test_distributions.py
rename to tests/test_distributions.py
index f98cfcb4..ad5c7113 100644
--- a/torch_struct/test_distributions.py
+++ b/tests/test_distributions.py
@@ -1,15 +1,29 @@
-from .distributions import LinearChainCRF
-from .autoregressive import Autoregressive
-from .semirings import KMaxSemiring
+from torch_struct import LinearChainCRF, Autoregressive, KMaxSemiring
 import torch
 from hypothesis import given, settings
 from hypothesis.strategies import integers, data, sampled_from
+from .extensions import test_lookup
 
 smint = integers(min_value=2, max_value=4)
 tint = integers(min_value=1, max_value=2)
 lint = integers(min_value=2, max_value=10)
 
 
+def enumerate_support(self, expand=True):
+    """
+    Compute the full exponential enumeration set.
+
+    Returns:
+        (enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*)
+    """
+    _, _, edges, enum_lengths = test_lookup[self.struct]().enumerate(
+        self.log_potentials, self.lengths
+    )
+    # if expand:
+    #     edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:])
+    return edges, enum_lengths
+
+
 @given(data(), integers(min_value=1, max_value=20))
 @settings(max_examples=50, deadline=None)
 def test_simple(data, seed):
@@ -22,7 +36,7 @@
         [data.draw(integers(min_value=2, max_value=N)) for b in range(batch - 1)] + [N]
     )
     dist = model(vals, lengths)
-    edges, enum_lengths = dist.enumerate_support()
+    edges, enum_lengths = enumerate_support(dist)
     log_probs = dist.log_prob(edges)
     for b in range(lengths.shape[0]):
         log_probs[enum_lengths[b] :, b] = -1e9
@@ -36,7 +50,7 @@
     cross_entropy = dist.cross_entropy(other=dist2)
     kl = dist.kl(other=dist2)
 
-    edges2, enum_lengths2 = dist2.enumerate_support()
+    edges2, enum_lengths2 = enumerate_support(dist2)
     log_probs2 = dist2.log_prob(edges2)
     for b in range(lengths.shape[0]):
         log_probs2[enum_lengths2[b] :, b] = -1e9
@@ -112,7 +126,9 @@ def forward(self, inputs, state):
     print(auto.log_prob(v.unsqueeze(0)))
     print(crf.struct().score(crf.argmax, values2))
     assert (
-        torch.isclose(auto.log_prob(v.unsqueeze(0)), crf.struct().score(crf.argmax, values2))
+        torch.isclose(
+            auto.log_prob(v.unsqueeze(0)), crf.struct().score(crf.argmax, values2)
+        )
     ).all()
 
     assert auto.sample((7,)).shape == (7, batch, n_length, n_classes)
diff --git a/torch_struct/semirings/test_semirings.py b/tests/test_semirings.py
similarity index 98%
rename from torch_struct/semirings/test_semirings.py
rename to tests/test_semirings.py
index d42e5013..ab83a1cc 100644
--- a/torch_struct/semirings/test_semirings.py
+++ b/tests/test_semirings.py
@@ -3,7 +3,7 @@
 from hypothesis.strategies import integers
 
 
-from . import (
+from torch_struct import (
     LogSemiring,
     CheckpointSemiring,
     CheckpointShardSemiring,
diff --git a/torch_struct/__init__.py b/torch_struct/__init__.py
index 2e58fbbc..57071ead 100644
--- a/torch_struct/__init__.py
+++ b/torch_struct/__init__.py
@@ -1,74 +1,16 @@
-from .cky import CKY
-from .distributions import (
-    StructDistribution,
-    LinearChainCRF,
-    SemiMarkovCRF,
-    DependencyCRF,
-    NonProjectiveDependencyCRF,
-    TreeCRF,
-    SentCFG,
-    AlignmentCRF,
-    HMM,
-)
-from .autoregressive import Autoregressive, AutoregressiveModel
-from .cky_crf import CKY_CRF
-from .deptree import DepTree
-from .linearchain import LinearChain
-from .semimarkov import SemiMarkov
-from .alignment import Alignment
-from .rl import SelfCritical
-from .semirings import (
-    LogSemiring,
-    FastLogSemiring,
-    TempMax,
-    FastMaxSemiring,
-    FastSampleSemiring,
-    StdSemiring,
-    KMaxSemiring,
-    SparseMaxSemiring,
-    SampledSemiring,
-    MaxSemiring,
-    EntropySemiring,
-    MultiSampledSemiring,
-    CheckpointSemiring,
-    CheckpointShardSemiring,
-)
+# Models
+from .cky import *  # noqa: F401,F403
+from .cky_crf import *  # noqa: F401,F403
+from .deptree import *  # noqa: F401,F403
+from .linearchain import *  # noqa: F401,F403
+from .semimarkov import *  # noqa: F401,F403
+from .alignment import *  # noqa: F401,F403
 
+# Semirings
+from .semirings import *  # noqa: F401,F403
 
-version = "0.4"
+# Distributions
+from .distributions import *  # noqa: F401,F403
+from .autoregressive import *  # noqa: F401,F403
 
-# For flake8 compatibility.
-__all__ = [
-    CKY,
-    CKY_CRF,
-    DepTree,
-    LinearChain,
-    SemiMarkov,
-    LogSemiring,
-    StdSemiring,
-    SampledSemiring,
-    MaxSemiring,
-    SparseMaxSemiring,
-    KMaxSemiring,
-    FastLogSemiring,
-    FastMaxSemiring,
-    FastSampleSemiring,
-    EntropySemiring,
-    MultiSampledSemiring,
-    SelfCritical,
-    StructDistribution,
-    Autoregressive,
-    AutoregressiveModel,
-    LinearChainCRF,
-    SemiMarkovCRF,
-    DependencyCRF,
-    NonProjectiveDependencyCRF,
-    TreeCRF,
-    SentCFG,
-    HMM,
-    AlignmentCRF,
-    Alignment,
-    CheckpointSemiring,
-    CheckpointShardSemiring,
-    TempMax,
-]
+version = "0.5"
diff --git a/torch_struct/alignment.py b/torch_struct/alignment.py
index 15d2fade..075033a8 100644
--- a/torch_struct/alignment.py
+++ b/torch_struct/alignment.py
@@ -193,48 +193,3 @@ def pad(v):
             ..., 0, Open, Open, Mid, N - 1, M - N + ((chart.shape[-1] - 1) // 2)
         ]
         return v, [log_potentials], None
-
-    @staticmethod
-    def _rand(min_n=2):
-        b = torch.randint(2, 4, (1,))
-        N = torch.randint(min_n, 4, (1,))
-        M = torch.randint(min_n, 4, (1,))
-        N = torch.min(M, N)
-        return torch.rand(b, N, M, 3), (b.item(), (N).item())
-
-    def enumerate(self, edge, lengths=None):
-        semiring = self.semiring
-        edge, batch, N, M, lengths = self._check_potentials(edge, lengths)
-        d = {}
-        d[0, 0] = [([(0, 0)], edge[:, :, 0, 0, 1])]
-        # enum_lengths = torch.LongTensor(lengths.shape)
-        for i in range(N):
-            for j in range(M):
-                d.setdefault((i + 1, j + 1), [])
-                d.setdefault((i, j + 1), [])
-                d.setdefault((i + 1, j), [])
-                for chain, score in d[i, j]:
-                    if i + 1 < N and j + 1 < M:
-                        d[i + 1, j + 1].append(
-                            (
-                                chain + [(i + 1, j + 1)],
-                                semiring.mul(score, edge[:, :, i + 1, j + 1, 1]),
-                            )
-                        )
-                    if i + 1 < N:
-
-                        d[i + 1, j].append(
-                            (
-                                chain + [(i + 1, j)],
-                                semiring.mul(score, edge[:, :, i + 1, j, 2]),
-                            )
-                        )
-                    if j + 1 < M:
-                        d[i, j + 1].append(
-                            (
-                                chain + [(i, j + 1)],
-                                semiring.mul(score, edge[:, :, i, j + 1, 0]),
-                            )
-                        )
-        all_val = torch.stack([x[1] for x in d[N - 1, M - 1]], dim=-1)
-        return semiring.unconvert(semiring.sum(all_val)), None
diff --git a/torch_struct/cky.py b/torch_struct/cky.py
index acb8fba1..87356e8c 100644
--- a/torch_struct/cky.py
+++ b/torch_struct/cky.py
@@ -5,7 +5,7 @@
 
 
 class CKY(_Struct):
-    def _dp(self, scores, lengths=None, force_grad=False, cache=True):
+    def _dp(self, scores, lengths=None, force_grad=False):
 
         semiring = self.semiring
 
@@ -26,9 +26,7 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True):
             lengths = torch.LongTensor([N] * batch).to(terms.device)
 
         # Charts
-        beta = [
-            Chart((batch, N, N, NT), rules, semiring, cache=cache) for _ in range(2)
-        ]
+        beta = [Chart((batch, N, N, NT), rules, semiring) for _ in range(2)]
         span = [None for _ in range(N)]
         v = (ssize, batch)
         term_use = terms + 0.0
@@ -85,9 +83,10 @@ def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
         Compute the marginals of a CFG using CKY.
 
         Parameters:
-            terms : b x n x T
-            rules : b x NT x (NT+T) x (NT+T)
-            root: b x NT
+            scores : terms : b x n x T
+                     rules : b x NT x (NT+T) x (NT+T)
+                     root: b x NT
+            lengths : batch lengths (optional)
 
         Returns:
             v: b tensor of total sum
@@ -99,7 +98,7 @@ def marginals(self, scores, lengths=None, _autograd=True, _raw=False):
         _, NT, _, _ = rules.shape
 
         v, (term_use, rule_use, root_use, spans), alpha = self._dp(
-            scores, lengths=lengths, force_grad=True, cache=not _raw
+            scores, lengths=lengths, force_grad=True
         )
 
         def marginal(obj, inputs):
@@ -268,43 +267,3 @@ def to_networkx(cls, spans):
             cur += 1
         indices = left
         return (n_nodes, a, b, label), indices, topo
-
-    ###### Test
-
-    def enumerate(self, scores):
-        terms, rules, roots = scores
-        semiring = self.semiring
-        batch, N, T = terms.shape
-        _, NT, _, _ = rules.shape
-
-        def enumerate(x, start, end):
-            if start + 1 == end:
-                yield (terms[:, start, x - NT], [(start, x - NT)])
-            else:
-                for w in range(start + 1, end):
-                    for y in range(NT) if w != start + 1 else range(NT, NT + T):
-                        for z in range(NT) if w != end - 1 else range(NT, NT + T):
-                            for m1, y1 in enumerate(y, start, w):
-                                for m2, z1 in enumerate(z, w, end):
-                                    yield (
-                                        semiring.times(
-                                            semiring.times(m1, m2), rules[:, x, y, z]
-                                        ),
-                                        [(x, start, w, end)] + y1 + z1,
-                                    )
-
-        ls = []
-        for nt in range(NT):
-            ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
-        return semiring.sum(torch.stack(ls, dim=-1)), None
-
-    @staticmethod
-    def _rand():
-        batch = torch.randint(2, 5, (1,))
-        N = torch.randint(2, 5, (1,))
-        NT = torch.randint(2, 5, (1,))
-        T = torch.randint(2, 5, (1,))
-        terms = torch.rand(batch, N, T)
-        rules = torch.rand(batch, NT, (NT + T), (NT + T))
-        roots = torch.rand(batch, NT)
-        return (terms, rules, roots), (batch.item(), N.item())
diff --git a/torch_struct/cky_crf.py b/torch_struct/cky_crf.py
index 8817edcf..c06badbc 100644
--- a/torch_struct/cky_crf.py
+++ b/torch_struct/cky_crf.py
@@ -13,11 +13,11 @@ def _check_potentials(self, edge, lengths=None):
 
         return edge, batch, N, NT, lengths
 
-    def _dp(self, scores, lengths=None, force_grad=False, cache=True):
+    def _dp(self, scores, lengths=None, force_grad=False):
         semiring = self.semiring
         scores, batch, N, NT, lengths = self._check_potentials(scores, lengths)
 
-        beta = [Chart((batch, N, N), scores, semiring, cache=cache) for _ in range(2)]
+        beta = [Chart((batch, N, N), scores, semiring) for _ in range(2)]
         L_DIM, R_DIM = 2, 3
 
         # Initialize
@@ -41,39 +41,3 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True):
         final = beta[A][0, :]
         log_Z = final[:, torch.arange(batch), lengths - 1]
         return log_Z, [scores], beta
-
-    # For testing
-
-    def enumerate(self, scores):
-        semiring = self.semiring
-        batch, N, _, NT = scores.shape
-
-        def enumerate(x, start, end):
-            if start + 1 == end:
-                yield (scores[:, start, start, x], [(start, x)])
-            else:
-                for w in range(start + 1, end):
-                    for y in range(NT):
-                        for z in range(NT):
-                            for m1, y1 in enumerate(y, start, w):
-                                for m2, z1 in enumerate(z, w, end):
-                                    yield (
-                                        semiring.times(
-                                            m1, m2, scores[:, start, end - 1, x]
-                                        ),
-                                        [(x, start, w, end)] + y1 + z1,
-                                    )
-
-        ls = []
-        for nt in range(NT):
-            ls += [s for s, _ in enumerate(nt, 0, N)]
-
-        return semiring.sum(torch.stack(ls, dim=-1)), None
-
-    @staticmethod
-    def _rand():
-        batch = torch.randint(2, 5, (1,))
-        N = torch.randint(2, 5, (1,))
-        NT = torch.randint(2, 5, (1,))
-        scores = torch.rand(batch, N, N, NT)
-        return scores, (batch.item(), N.item())
diff --git a/torch_struct/deptree.py b/torch_struct/deptree.py
index 63949f69..54a8a23c 100644
--- a/torch_struct/deptree.py
+++ b/torch_struct/deptree.py
@@ -1,9 +1,8 @@
 import torch
-import itertools
 
 from .helpers import _Struct, Chart
 
 
-def _convert(logits):
+def convert(logits):
     "move root arcs from diagonal"
     new_shape = list(logits.shape)
     new_shape[1] += 1
@@ -18,7 +17,7 @@ def _convert(logits):
     return new_logits
 
 
-def _unconvert(logits):
+def unconvert(logits):
     "Move root arcs to diagonal"
     new_shape = list(logits.shape)
     new_shape[1] -= 1
@@ -47,14 +46,14 @@ class DepTree(_Struct):
     Note: For single-root case, do not set cache=True for now.
     """
 
-    def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
+    def _dp(self, arc_scores_in, lengths=None, force_grad=False):
         multiroot = getattr(self, "multiroot", True)
         if arc_scores_in.dim() not in (3, 4):
             raise ValueError("potentials must have dim of 3 (unlabeled) or 4 (labeled)")
 
         labeled = arc_scores_in.dim() == 4
         semiring = self.semiring
-        arc_scores_in = _convert(arc_scores_in)
+        arc_scores_in = convert(arc_scores_in)
         arc_scores_in, batch, N, lengths = self._check_potentials(
             arc_scores_in, lengths
         )
@@ -62,10 +61,7 @@ def _dp(self, arc_scores_in, lengths=None, force_grad=False, cache=True):
         arc_scores = semiring.sum(arc_scores_in) if labeled else arc_scores_in
         alpha = [
             [
-                [
-                    Chart((batch, N, N), arc_scores, semiring, cache=multiroot)
-                    for _ in range(2)
-                ]
+                [Chart((batch, N, N), arc_scores, semiring) for _ in range(2)]
                 for _ in range(2)
             ]
             for _ in range(2)
         ]
@@ -130,7 +126,7 @@ def _check_potentials(self, arc_scores, lengths=None):
         return arc_scores, batch, N, lengths
 
     def _arrange_marginals(self, grads):
-        return self.semiring.convert(_unconvert(self.semiring.unconvert(grads[0])))
+        return self.semiring.convert(unconvert(self.semiring.unconvert(grads[0])))
 
     @staticmethod
     def to_parts(sequence, extra=None, lengths=None):
@@ -151,7 +147,7 @@ def to_parts(sequence, extra=None, lengths=None):
         for b in range(batch):
             labels[b, lengths[b] + 1 :, :] = 0
             labels[b, :, lengths[b] + 1 :] = 0
-        return _unconvert(labels)
+        return unconvert(labels)
 
     @staticmethod
     def from_parts(arcs):
@@ -173,34 +169,6 @@ def from_parts(arcs):
             labels[on[i][0], on[i][2]] = on[i][1] + 1
         return labels, None
 
-    @staticmethod
-    def _rand():
-        b = torch.randint(2, 4, (1,))
-        N = torch.randint(2, 4, (1,))
-        return torch.rand(b, N, N), (b.item(), N.item())
-
-    def enumerate(self, arc_scores, non_proj=False, multi_root=True):
-        semiring = self.semiring
-        parses = []
-        q = []
-        arc_scores = _convert(arc_scores)
-        batch, N, _ = arc_scores.shape
-        for mid in itertools.product(range(N + 1), repeat=N - 1):
-            parse = [-1] + list(mid)
-            if not _is_spanning(parse):
-                continue
-            if not non_proj and not _is_projective(parse):
-                continue
-
-            if not multi_root and _is_multi_root(parse):
-                continue
-
-            q.append(parse)
-            parses.append(
-                semiring.times(*[arc_scores[:, parse[i], i] for i in range(1, N, 1)])
-            )
-        return semiring.sum(torch.stack(parses, dim=-1)), None
-
 
 def deptree_part(arc_scores, eps=1e-5):
     input = arc_scores
@@ -252,72 +220,3 @@ def deptree_nonproj(arc_scores, eps=1e-5):
     )
     output = output + torch.diag_embed(roots_output, 0, -2, -1)
     return output
-
-
-### Tests
-
-
-def _is_spanning(parse):
-    """
-    Is the parse tree a valid spanning tree?
-    Returns
-    --------
-    spanning : bool
-        True if a valid spanning tree.
-    """
-    d = {}
-    for m, h in enumerate(parse):
-        if m == h:
-            return False
-        d.setdefault(h, [])
-        d[h].append(m)
-    stack = [0]
-    seen = set()
-    while stack:
-        cur = stack[0]
-        if cur in seen:
-            return False
-        seen.add(cur)
-        stack = d.get(cur, []) + stack[1:]
-    if len(seen) != len(parse) - len([1 for p in parse if p is None]):
-        return False
-    return True
-
-
-def _is_multi_root(parse):
-    root_count = 0
-    for m, h in enumerate(parse):
-        if h == 0:
-            root_count += 1
-    return root_count > 1
-
-
-def _is_projective(parse):
-    """
-    Is the parse tree projective?
-    Returns
-    --------
-    projective : bool
-        True if a projective tree.
-    """
-    for m, h in enumerate(parse):
-        for m2, h2 in enumerate(parse):
-            if m2 == m:
-                continue
-            if m < h:
-                if (
-                    m < m2 < h < h2
-                    or m < h2 < h < m2
-                    or m2 < m < h2 < h
-                    or h2 < m < m2 < h
-                ):
-                    return False
-            if h < m:
-                if (
-                    h < m2 < m < h2
-                    or h < h2 < m < m2
-                    or m2 < h < h2 < m
-                    or h2 < h < m2 < m
-                ):
-                    return False
-    return True
diff --git a/torch_struct/distributions.py b/torch_struct/distributions.py
index 711c90a7..836e16e2 100644
--- a/torch_struct/distributions.py
+++ b/torch_struct/distributions.py
@@ -36,8 +36,6 @@ class StructDistribution(Distribution):
         lengths (long tensor, batch_shape) : integers for length masking
     """
 
-    has_enumerate_support = True
-
     def __init__(self, log_potentials, lengths=None, args={}):
         batch_shape = log_potentials.shape[:1]
         event_shape = log_potentials.shape[1:]
@@ -85,6 +83,9 @@ def cross_entropy(self, other):
         """
         Compute cross-entropy for distribution p(self) and q(other) :math:`H[p, q]`.
 
+        Parameters:
+            other : Comparison distribution
+
         Returns:
             cross entropy (*batch_shape*)
         """
@@ -97,6 +98,9 @@ def kl(self, other):
         """
         Compute KL-divergence for distribution p(self) and q(other) :math:`KL[p || q] = H[p, q] - H[p]`.
 
+        Parameters:
+            other : Comparison distribution
+
         Returns:
             cross entropy (*batch_shape*)
         """
@@ -108,6 +112,7 @@ def kl(self, other):
     def max(self):
         r"""
         Compute an max for distribution :math:`\max p(z)`.
+
         Returns:
             max (*batch_shape*)
         """
@@ -126,6 +131,10 @@ def argmax(self):
     def kmax(self, k):
         r"""
         Compute the k-max for distribution :math:`k\max p(z)`.
+
+        Parameters:
+            k : Number of solutions to return
+
         Returns:
             kmax (*k x batch_shape*)
         """
@@ -138,6 +147,9 @@ def topk(self, k):
         r"""
         Compute the k-argmax for distribution :math:`k\max p(z)`.
 
+        Parameters:
+            k : Number of solutions to return
+
         Returns:
             kmax (*k x batch_shape x event_shape*)
         """
@@ -215,20 +227,6 @@ def from_event(self, event):
         "Convert event to simple representation."
         return self.struct.from_parts(event)
 
-    def enumerate_support(self, expand=True):
-        """
-        Compute the full exponential enumeration set.
-
-        Returns:
-            (enum, enum_lengths) - (*tuple cardinality x batch_shape x event_shape*)
-        """
-        _, _, edges, enum_lengths = self._struct().enumerate(
-            self.log_potentials, self.lengths
-        )
-        # if expand:
-        #     edges = edges.unsqueeze(1).expand(edges.shape[:1] + self.batch_shape[:1] + edges.shape[1:])
-        return edges, enum_lengths
-
     def _struct(self, sr=None):
         return self.struct(sr if sr is not None else LogSemiring)
diff --git a/torch_struct/helpers.py b/torch_struct/helpers.py
index bdad6678..e92ca493 100644
--- a/torch_struct/helpers.py
+++ b/torch_struct/helpers.py
@@ -4,36 +4,8 @@
 from torch.autograd import Function
 
 
-class Get(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, chart, grad_chart, indices):
-        ctx.save_for_backward(grad_chart)
-        out = chart[indices]
-        ctx.indices = indices
-        return out
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        (grad_chart,) = ctx.saved_tensors
-        grad_chart[ctx.indices] += grad_output
-        return grad_chart, None, None
-
-
-class Set(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, chart, indices, vals):
-        chart[indices] = vals
-        ctx.indices = indices
-        return chart
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        z = grad_output[ctx.indices]
-        return None, None, z
-
-
 class Chart:
-    def __init__(self, size, potentials, semiring, cache=True):
+    def __init__(self, size, potentials, semiring):
         self.data = semiring.zero_(
             torch.zeros(
                 *((semiring.size(),) + size),
@@ -42,27 +14,14 @@ def __init__(self, size, potentials, semiring, cache=True):
             )
         )
         self.grad = self.data.detach().clone().fill_(0.0)
-        self.cache = cache
 
     def __getitem__(self, ind):
         I = slice(None)
-        if self.cache:
-            return Get.apply(self.data, self.grad, (I, I) + ind)
-        else:
-            return self.data[(I, I) + ind]
+        return self.data[(I, I) + ind]
 
     def __setitem__(self, ind, new):
         I = slice(None)
-        if self.cache:
-            self.data = Set.apply(self.data, (I, I) + ind, new)
-        else:
-            self.data[(I, I) + ind] = new
-
-    def get(self, ind):
-        return Get.apply(self.data, self.grad, ind)
-
-    def set(self, ind, new):
-        self.data = Set.apply(self.data, ind, new)
+        self.data[(I, I) + ind] = new
 
 
 class _Struct:
@@ -161,9 +120,7 @@ def marginals(self, edge, lengths=None, _autograd=True, _raw=False):
             or self.semiring is not LogSemiring
             or not hasattr(self, "_dp_backward")
         ):
-            v, edges, _ = self._dp(
-                edge, lengths=lengths, force_grad=True, cache=not _raw
-            )
+            v, edges, _ = self._dp(edge, lengths=lengths, force_grad=True)
             if _raw:
                 all_m = []
                 for k in range(v.shape[0]):
diff --git a/torch_struct/linearchain.py b/torch_struct/linearchain.py
index 59151af2..0609ae60 100644
--- a/torch_struct/linearchain.py
+++ b/torch_struct/linearchain.py
@@ -41,7 +41,7 @@ def _check_potentials(self, edge, lengths=None):
         assert C == C2, "Transition shape doesn't match"
         return edge, batch, N, C, lengths
 
-    def _dp(self, log_potentials, lengths=None, force_grad=False, cache=True):
+    def _dp(self, log_potentials, lengths=None, force_grad=False):
         return self._dp_scan(log_potentials, lengths, force_grad)
 
     def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
@@ -161,116 +161,3 @@ def _rand(min_n=2):
         N = torch.randint(min_n, 4, (1,))
         C = torch.randint(2, 4, (1,))
         return torch.rand(b, N, C, C), (b.item(), (N + 1).item())
-
-    ### Tests
-
-    def enumerate(self, edge, lengths=None):
-        semiring = self.semiring
-        ssize = semiring.size()
-        edge, batch, N, C, lengths = self._check_potentials(edge, lengths)
-        chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
-
-        enum_lengths = torch.LongTensor(lengths.shape)
-        for n in range(1, N):
-            new_chains = []
-            for chain, score in chains[-1]:
-                for c in range(C):
-                    new_chains.append(
-                        (
-                            chain + [c],
-                            semiring.mul(score, edge[:, :, n - 1, c, chain[-1]]),
-                        )
-                    )
-            chains.append(new_chains)
-
-            for b in range(lengths.shape[0]):
-                if lengths[b] == n + 1:
-                    enum_lengths[b] = len(new_chains)
-
-        edges = self.to_parts(
-            torch.stack([torch.tensor(c) for (c, _) in chains[-1]]), C
-        )
-        # Sum out non-batch
-        a = torch.einsum("ancd,sbncd->sbancd", edges.float(), edge)
-        a = semiring.prod(a.view(*a.shape[:3] + (-1,)), dim=3)
-        a = semiring.sum(a, dim=2)
-        ret = semiring.sum(torch.stack([s for (_, s) in chains[-1]], dim=1), dim=1)
-        assert torch.isclose(a, ret).all(), "%s %s" % (a, ret)
-
-        edges = torch.zeros(len(chains[-1]), batch, N - 1, C, C)
-        for b in range(lengths.shape[0]):
-            edges[: enum_lengths[b], b, : lengths[b] - 1] = self.to_parts(
-                torch.stack([torch.tensor(c) for (c, _) in chains[lengths[b] - 1]]), C
-            )
-
-        return (
-            semiring.unconvert(ret),
-            [s for (_, s) in chains[-1]],
-            edges,
-            enum_lengths,
-        )
-
-    ## For reference
-    #
-    # def _dp_standard(self, edge, lengths=None, force_grad=False):
-    #     semiring = self.semiring
-    #     ssize = semiring.size()
-    #     edge, batch, N, C, lengths = self._check_potentials(edge, lengths)
-
-    #     alpha = self._make_chart(N, (batch, C), edge, force_grad)
-    #     edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad)
-
-    #     semiring.one_(alpha[0].data)
-
-    #     for n in range(1, N):
-    #         edge_store[n - 1][:] = semiring.times(
-    #             alpha[n - 1].view(ssize, batch, 1, C),
-    #             edge[:, :, n - 1].view(ssize, batch, C, C),
-    #         )
-    #         alpha[n][:] = semiring.sum(edge_store[n - 1])
-
-    #     for n in range(1, N):
-    #         edge_store[n - 1][:] = semiring.times(
-    #             alpha[n - 1].view(ssize, batch, 1, C),
-    #             edge[:, :, n - 1].view(ssize, batch, C, C),
-    #         )
-    #         alpha[n][:] = semiring.sum(edge_store[n - 1])
-
-    #     ret = [alpha[lengths[i] - 1][:, i] for i in range(batch)]
-    #     ret = torch.stack(ret, dim=1)
-    #     v = semiring.sum(ret)
-    #     return v, edge_store, alpha

-    # def _dp_backward(self, edge, lengths, alpha_in, v=None):
-    #     semiring = self.semiring
-    #     batch, N, C, lengths = self._check_potentials(edge, lengths)

-    #     alpha = self._make_chart(N, (batch, C), edge, force_grad=False)
-    #     edge_store = self._make_chart(N - 1, (batch, C, C), edge, force_grad=False)

-    #     for n in range(N - 1, 0, -1):
-    #         for b, l in enumerate(lengths):
-    #             alpha[l - 1][b].data.fill_(semiring.one())

-    #         edge_store[n - 1][:] = semiring.times(
-    #             alpha[n].view(batch, C, 1), edge[:, n - 1]
-    #         )
-    #         alpha[n - 1][:] = semiring.sum(edge_store[n - 1], dim=-2)
-    #     v = semiring.sum(
-    #         torch.stack([alpha[0][i] for i, l in enumerate(lengths)]), dim=-1
-    #     )
-    #     edge_marginals = self._make_chart(
-    #         1, (batch, N - 1, C, C), edge, force_grad=False
-    #     )[0]

-    #     for n in range(N - 1):
-    #         edge_marginals[:, n] = semiring.div_exp(
-    #             semiring.times(
-    #                 alpha_in[n].view(batch, 1, C),
-    #                 edge[:, n],
-    #                 alpha[n + 1].view(batch, C, 1),
-    #             ),
-    #             v.view(batch, 1, 1),
-    #         )

-    #     return edge_marginals
diff --git a/torch_struct/semimarkov.py b/torch_struct/semimarkov.py
index fc11ac9d..4f6ffed9 100644
--- a/torch_struct/semimarkov.py
+++ b/torch_struct/semimarkov.py
@@ -113,14 +113,6 @@ def _dp(self, log_potentials, lengths=None, force_grad=False, cache=True):
     #     )
     #     return v, [edge], beta
 
-    @staticmethod
-    def _rand():
-        b = torch.randint(2, 4, (1,))
-        N = torch.randint(2, 4, (1,))
-        K = torch.randint(2, 4, (1,))
-        C = torch.randint(2, 4, (1,))
-        return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())
-
     @staticmethod
     def to_parts(sequence, extra, lengths=None):
         """
@@ -177,32 +169,3 @@ def from_parts(edge):
             labels[on[i][0], on[i][1] + on[i][2]] = on[i][3]
         # print(edge.nonzero(), labels)
         return labels, (C, K)
-
-    # Tests
-    def enumerate(self, edge):
-        semiring = self.semiring
-        ssize = semiring.size()
-        batch, N, K, C, _ = edge.shape
-        edge = semiring.convert(edge)
-        chains = {}
-        chains[0] = [
-            ([(c, 0)], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)
-        ]
-
-        for n in range(1, N + 1):
-            chains[n] = []
-            for k in range(1, K):
-                if n - k not in chains:
-                    continue
-                for chain, score in chains[n - k]:
-                    for c in range(C):
-                        chains[n].append(
-                            (
-                                chain + [(c, k)],
-                                semiring.mul(
-                                    score, edge[:, :, n - k, k, c, chain[-1][0]]
-                                ),
-                            )
-                        )
-        ls = [s for (_, s) in chains[N]]
-        return semiring.unconvert(semiring.sum(torch.stack(ls, dim=1), dim=1)), ls
diff --git a/torch_struct/semirings/__init__.py b/torch_struct/semirings/__init__.py
index e0e5b9db..a95edbc8 100644
--- a/torch_struct/semirings/__init__.py
+++ b/torch_struct/semirings/__init__.py
@@ -1,40 +1,5 @@
-from .semirings import (
-    LogSemiring,
-    StdSemiring,
-    KMaxSemiring,
-    MaxSemiring,
-    EntropySemiring,
-    CrossEntropySemiring,
-    KLDivergenceSemiring,
-    TempMax,
-)
-
-from .fast_semirings import FastLogSemiring, FastMaxSemiring, FastSampleSemiring
-
-
-from .checkpoint import CheckpointSemiring, CheckpointShardSemiring
-
-from .sparse_max import SparseMaxSemiring
-
-from .sample import MultiSampledSemiring, SampledSemiring
-
-
-# For flake8 compatibility.
-__all__ = [
-    FastLogSemiring,
-    FastMaxSemiring,
-    FastSampleSemiring,
-    LogSemiring,
-    StdSemiring,
-    SampledSemiring,
-    MaxSemiring,
-    SparseMaxSemiring,
-    KMaxSemiring,
-    EntropySemiring,
-    CrossEntropySemiring,
-    KLDivergenceSemiring,
-    MultiSampledSemiring,
-    CheckpointSemiring,
-    CheckpointShardSemiring,
-    TempMax,
-]
+from .semirings import *  # noqa: F401,F403
+from .sparse_max import *  # noqa: F401,F403
+from .fast_semirings import *  # noqa: F401,F403
+from .checkpoint import *  # noqa: F401,F403
+from .sample import *  # noqa: F401,F403
diff --git a/torch_struct/semirings/keops.py b/torch_struct/semirings/keops.py
deleted file mode 100644
index e9c3b5b6..00000000
--- a/torch_struct/semirings/keops.py
+++ /dev/null
@@ -1,65 +0,0 @@
-import torch
-import torch.distributions
-from .semirings import _BaseLog
-
-try:
-    from pykeops.torch import LazyTensor
-except ImportError:
-    pass
-
-
-class LogSemiringKO(_BaseLog):
-    """
-    Implements the log-space semiring (logsumexp, +, -inf, 0).
-
-    Gradients give marginals.
-    """
-
-    @staticmethod
-    def sum(a, dim=-1):
-        a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous())
-        c = a_lazy.sum(-1).logsumexp(a.dim() - 1).squeeze(-1).squeeze(-1)
-        return c
-
-    @classmethod
-    def dot(cls, a, b):
-        """
-        Dot product along last dim. (Faster than calling sum and times.)
-        """
-        a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous())
-        b_lazy = LazyTensor(b.unsqueeze(-1).unsqueeze(-1).contiguous())
-        c = (a_lazy + b_lazy).sum(-1).logsumexp(a.dim() - 1).squeeze(-1).squeeze(-1)
-        return c
-
-
-class _Max(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, a, b):
-        one_hot = b.shape[-1]
-        a_lazy = LazyTensor(a.unsqueeze(-1).unsqueeze(-1).contiguous())
-        b_lazy = LazyTensor(b.unsqueeze(-1).unsqueeze(-1).contiguous())
-        c = (a_lazy + b_lazy).sum(-1).max(a.dim() - 1).squeeze(-1).squeeze(-1)
-        ac = (a_lazy + b_lazy).sum(-1).argmax(a.dim() - 1).squeeze(-1).squeeze(-1)
-        ctx.save_for_backward(ac, torch.tensor(one_hot))
-        return c
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        ac, size = ctx.saved_tensors
-        back = torch.nn.functional.one_hot(ac, size).type_as(grad_output)
-        ret = grad_output.unsqueeze(-1).mul(back)
-        return ret, ret
-
-
-class MaxSemiringKO(_BaseLog):
-    @classmethod
-    def sum(cls, xs, dim=-1):
-        assert dim == -1
-        return cls.dot(xs, xs.clone().fill_(0))
-
-    @classmethod
-    def dot(cls, a, b):
-        """
-        Dot product along last dim. (Faster than calling sum and times.)
-        """
-        return _Max.apply(a, b)
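Note for reviewers: with `enumerate`/`_rand` removed from the structs, the brute-force reference implementations now live in `tests/extensions.py` and are reached through the `test_lookup` table. Below is a minimal sketch of the intended flow, not part of the patch itself; it assumes the repo root is on `sys.path` so the `tests` package is importable, and uses `LinearChainTest.enumerate`, which returns `(log_Z, chain_scores, edges, enum_lengths)` as defined above.

```python
# Sketch only: validate the LinearChain dynamic program against the
# exhaustive enumeration helper relocated to tests/extensions.py.
import torch
from torch_struct import LinearChain, LogSemiring
from tests.extensions import test_lookup

test = test_lookup[LinearChain](LogSemiring)  # brute-force reference for LinearChain
vals, (batch, N) = test._rand()               # random edge potentials + shape info
log_z = LinearChain(LogSemiring).sum(vals)    # log-partition via the dynamic program
enum_z = test.enumerate(vals)[0]              # log-partition via full enumeration
assert torch.isclose(log_z, enum_z).all()
```

Locally, `pytest tests/` (or the `test` alias added to setup.cfg) exercises the relocated suite the same way the updated CI workflow does.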