Skip to content

Commit

Permalink
style
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Jan 16, 2021
1 parent 725f8e5 commit 131960d
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 52 deletions.
48 changes: 32 additions & 16 deletions tests/extensions.py
Expand Up @@ -6,16 +6,17 @@
from hypothesis.extra.numpy import arrays
import numpy as np

class LinearChainTest:


class LinearChainTest:
@staticmethod
@composite
def logpotentials(draw, min_n=2):
b = draw(integers(min_value=2, max_value=3))
N = draw(integers(min_value=min_n, max_value=3))
C = draw(integers(min_value=2, max_value=3))
logp = draw(arrays(np.float, (b, N, C, C), floats(min_value=-100.0, max_value=100.0)))
logp = draw(
arrays(np.float, (b, N, C, C), floats(min_value=-100.0, max_value=100.0))
)
return torch.tensor(logp), (b, (N + 1))

### Tests
Expand Down Expand Up @@ -69,13 +70,14 @@ def enumerate(semiring, edge, lengths=None):


class DepTreeTest:

@staticmethod
@composite
def logpotentials(draw):
b = draw(integers(min_value=2, max_value=3))
N = draw(integers(min_value=2, max_value=3))
logp = draw(arrays(np.float, (b, N, N), floats(min_value=-10.0, max_value=10.0)))
logp = draw(
arrays(np.float, (b, N, N), floats(min_value=-10.0, max_value=10.0))
)
return torch.tensor(logp), (b, N)

@staticmethod
Expand Down Expand Up @@ -105,8 +107,6 @@ def enumerate(semiring, arc_scores, non_proj=False, multi_root=True):

class SemiMarkovTest:



# Tests

@staticmethod
Expand All @@ -116,7 +116,9 @@ def logpotentials(draw):
N = draw(integers(min_value=2, max_value=3))
K = draw(integers(min_value=2, max_value=3))
C = draw(integers(min_value=2, max_value=3))
logp = draw(arrays(np.float, (b, N, K, C, C), floats(min_value=-100.0, max_value=100.0)))
logp = draw(
arrays(np.float, (b, N, K, C, C), floats(min_value=-100.0, max_value=100.0))
)
return torch.tensor(logp), (b, (N + 1))

@staticmethod
Expand Down Expand Up @@ -224,7 +226,11 @@ def logpotentials(draw):
batch = draw(integers(min_value=2, max_value=4))
N = draw(integers(min_value=2, max_value=4))
NT = draw(integers(min_value=2, max_value=4))
logp = draw(arrays(np.float, (batch, N, N, NT), floats(min_value=-100.0, max_value=100.0)))
logp = draw(
arrays(
np.float, (batch, N, N, NT), floats(min_value=-100.0, max_value=100.0)
)
)
return torch.tensor(logp), (batch, N)

@staticmethod
Expand Down Expand Up @@ -255,7 +261,6 @@ def enumerate(x, start, end):
return semiring.sum(torch.stack(ls, dim=-1)), None



class CKYTest:
@staticmethod
@composite
Expand All @@ -264,11 +269,23 @@ def logpotentials(draw):
N = draw(integers(min_value=2, max_value=4))
NT = draw(integers(min_value=2, max_value=3))
T = draw(integers(min_value=2, max_value=3))
terms = draw(arrays(np.float, (batch, N, T), floats(min_value=-100.0, max_value=100.0)))
rules = draw(arrays(np.float, (batch, NT, NT + T, NT +T), floats(min_value=-100.0, max_value=100.0)))
roots = draw(arrays(np.float, (batch, NT), floats(min_value=-100.0, max_value=100.0)))
return (torch.tensor(terms), torch.tensor(rules), torch.tensor(roots)), (batch, N)

terms = draw(
arrays(np.float, (batch, N, T), floats(min_value=-100.0, max_value=100.0))
)
rules = draw(
arrays(
np.float,
(batch, NT, NT + T, NT + T),
floats(min_value=-100.0, max_value=100.0),
)
)
roots = draw(
arrays(np.float, (batch, NT), floats(min_value=-100.0, max_value=100.0))
)
return (torch.tensor(terms), torch.tensor(rules), torch.tensor(roots)), (
batch,
N,
)

@staticmethod
def enumerate(semiring, scores):
Expand Down Expand Up @@ -298,7 +315,6 @@ def enumerate(x, start, end):
return semiring.sum(torch.stack(ls, dim=-1)), None



class AlignmentTest:
def __init__(self, semiring=LogSemiring):
self.semiring = semiring
Expand Down
79 changes: 47 additions & 32 deletions tests/test_algorithms.py
@@ -1,24 +1,38 @@
from torch_struct import CKY, CKY_CRF, DepTree, LinearChain, SemiMarkov, Alignment, deptree_nonproj, deptree_part
from torch_struct import (
CKY,
CKY_CRF,
DepTree,
LinearChain,
SemiMarkov,
Alignment,
deptree_nonproj,
deptree_part,
)
from torch_struct import (
LogSemiring,
CheckpointSemiring,
CheckpointShardSemiring,
GumbelCRFSemiring,
KMaxSemiring,
SparseMaxSemiring,
MaxSemiring,
StdSemiring,
SampledSemiring,
EntropySemiring,
MultiSampledSemiring,
)
from .extensions import LinearChainTest, SemiMarkovTest, DepTreeTest, CKYTest, CKY_CRFTest, CKYTest
from .extensions import (
LinearChainTest,
SemiMarkovTest,
DepTreeTest,
CKYTest,
CKY_CRFTest,
test_lookup,
)
import torch
from hypothesis import given, settings
from hypothesis import given
from hypothesis.strategies import integers, data, sampled_from
import pytest

from hypothesis import settings

settings.register_profile("ci", max_examples=50, deadline=None)

settings.load_profile("ci")
Expand All @@ -30,15 +44,17 @@


algorithms = {
"LinearChain" : (LinearChain, LinearChainTest),
"SemiMarkov" : (SemiMarkov, SemiMarkovTest),
"Dep" : (DepTree, DepTreeTest),
"CKY_CRF" : (CKY_CRF, CKY_CRFTest),
"CKY" : (CKY, CKYTest),
"LinearChain": (LinearChain, LinearChainTest),
"SemiMarkov": (SemiMarkov, SemiMarkovTest),
"Dep": (DepTree, DepTreeTest),
"CKY_CRF": (CKY_CRF, CKY_CRFTest),
"CKY": (CKY, CKYTest),
}


class Gen:
"Helper class for tests"

def __init__(self, model_test, data, semiring):
model_test = algorithms[model_test]
self.data = data
Expand All @@ -50,13 +66,16 @@ def __init__(self, model_test, data, semiring):
if not isinstance(self.vals, tuple):
self.vals = self.vals + 1e-6 * torch.rand(*self.vals.shape)
self.semiring = semiring

def enum(self, semiring=None):
return self.test.enumerate(semiring if semiring is not None else self.semiring,
self.vals)
return self.test.enumerate(
semiring if semiring is not None else self.semiring, self.vals
)


# Model specific tests.


@given(smint, smint, smint)
@settings(max_examples=50, deadline=None)
def test_linear_chain_counting(batch, N, C):
Expand All @@ -69,6 +88,7 @@ def test_linear_chain_counting(batch, N, C):

# Semiring tests


@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep"])
@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
Expand All @@ -82,7 +102,7 @@ def test_log_shapes(model_test, semiring, data):
assert alpha.shape == count.shape
assert torch.isclose(count[0], alpha[0])


@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov"])
def test_entropy(model_test, data):
Expand All @@ -107,7 +127,6 @@ def test_sparse_max(model_test, data):
sparsemax = gen.struct.marginals(gen.vals)
sparsemax.sum().backward()



@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep"])
Expand Down Expand Up @@ -151,7 +170,6 @@ def test_cky(model_test, semiring, data):
assert torch.isclose(count[0], alpha[0])



@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "CKY_CRF", "Dep"])
def test_max(model_test, data):
Expand Down Expand Up @@ -181,13 +199,12 @@ def test_labeled_proj_deptree(model_test, semiring, data):
assert torch.isclose(max_score, struct.score(arc_scores, argmax)).all()



# todo: add CKY, DepTree too?
@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep", "CKY_CRF"])
def test_parts_from_marginals(model_test, data):
gen = Gen(model_test, data, MaxSemiring)

edge = gen.struct.marginals(gen.vals).long()
sequence, extra = gen.model.from_parts(edge)
edge_ = gen.model.to_parts(sequence, extra)
Expand All @@ -202,7 +219,7 @@ def test_parts_from_marginals(model_test, data):
@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov"])
def test_parts_from_sequence(model_test, data):
gen = Gen(model_test, data, LogSemiring)
gen = Gen(model_test, data, LogSemiring)
C = gen.vals.size(-1)
if isinstance(gen.struct, LinearChain):
K = 2
Expand Down Expand Up @@ -264,11 +281,13 @@ def test_generic_lengths(model_test, data):


@given(data())
@pytest.mark.parametrize("model_test", ["LinearChain", "SemiMarkov", "Dep", "CKY", "CKY_CRF"])
@pytest.mark.parametrize(
"model_test", ["LinearChain", "SemiMarkov", "Dep", "CKY", "CKY_CRF"]
)
def test_params(model_test, data):
gen = Gen(model_test, data, LogSemiring)
model, struct, vals, N, batch = gen.model, gen.struct, gen.vals, gen.N, gen.batch
_, struct, vals, _, _ = gen.model, gen.struct, gen.vals, gen.N, gen.batch

if isinstance(vals, tuple):
vals = tuple((v.requires_grad_(True) for v in vals))
else:
Expand All @@ -287,8 +306,6 @@ def test_gumbel(model_test, data):
print(torch.autograd.grad(alpha, gen.vals, alpha.detach())[0][0])




def test_hmm():
C, V, batch, N = 5, 20, 2, 5
transition = torch.rand(C, C)
Expand All @@ -299,8 +316,6 @@ def test_hmm():
LinearChain().sum(out)




def test_sparse_max2():
print(LinearChain(SparseMaxSemiring).sum(torch.rand(1, 8, 3, 3)))
print(LinearChain(SparseMaxSemiring).marginals(torch.rand(1, 8, 3, 3)))
Expand Down Expand Up @@ -373,14 +388,15 @@ def test_lc_custom():
def test_non_proj(model_test, semiring, data):
gen = Gen(model_test, data, semiring)
alpha = deptree_part(gen.vals, False)
count = gen.test.enumerate(LogSemiring, gen.vals, non_proj=True, multi_root=False)[0]
count = gen.test.enumerate(LogSemiring, gen.vals, non_proj=True, multi_root=False)[
0
]

assert alpha.shape[0] == gen.batch
assert count.shape[0] == gen.batch
assert alpha.shape == count.shape
# assert torch.isclose(count[0], alpha[0], 1e-2)


alpha = deptree_part(gen.vals, True)
count = gen.test.enumerate(LogSemiring, gen.vals, non_proj=True, multi_root=True)[0]

Expand All @@ -389,20 +405,20 @@ def test_non_proj(model_test, semiring, data):
assert alpha.shape == count.shape
# assert torch.isclose(count[0], alpha[0], 1e-2)



marginals = deptree_nonproj(gen.vals, multi_root=False)
print(marginals.sum(1))
marginals = deptree_nonproj(gen.vals, multi_root=True)
print(marginals.sum(1))


# # assert(False)
# # vals, _ = model._rand()
# # struct = model(MaxSemiring)
# # score = struct.sum(vals)
# # marginals = struct.marginals(vals)
# # assert torch.isclose(score, struct.score(vals, marginals)).all()


@given(data())
@settings(max_examples=50, deadline=None)
def ignore_alignment(data):
Expand Down Expand Up @@ -483,4 +499,3 @@ def ignore_alignment(data):
# assert torch.isclose(count, alpha).all()
struct = model(semiring, max_gap=1)
alpha = struct.sum(vals)

5 changes: 1 addition & 4 deletions torch_struct/helpers.py
@@ -1,7 +1,6 @@
import torch
import math
from .semirings import LogSemiring
from torch.autograd import Function


class Chart:
Expand Down Expand Up @@ -81,9 +80,7 @@ def sum(self, logpotentials, lengths=None, _raw=False):
return v
return self.semiring.unconvert(v)

def marginals(
self, logpotentials, lengths=None, _raw=False
):
def marginals(self, logpotentials, lengths=None, _raw=False):
"""
Compute the marginals of a structured model.
Expand Down

0 comments on commit 131960d

Please sign in to comment.