Skip to content

Commit

Permalink
Rewrite the tests structure (#92)
Browse files Browse the repository at this point in the history
* update tests

* style
  • Loading branch information
srush committed Jan 16, 2021
1 parent 95b2c19 commit d272745
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 470 deletions.
138 changes: 79 additions & 59 deletions tests/extensions.py
Expand Up @@ -2,24 +2,28 @@
import torch
from torch_struct import LogSemiring
import itertools
from hypothesis.strategies import integers, composite, floats
from hypothesis.extra.numpy import arrays
import numpy as np


class LinearChainTest:
def __init__(self, semiring=LogSemiring):
self.semiring = semiring

@staticmethod
def _rand(min_n=2):
b = torch.randint(2, 4, (1,))
N = torch.randint(min_n, 4, (1,))
C = torch.randint(2, 4, (1,))
return torch.rand(b, N, C, C), (b.item(), (N + 1).item())
@composite
def logpotentials(draw, min_n=2):
b = draw(integers(min_value=2, max_value=3))
N = draw(integers(min_value=min_n, max_value=3))
C = draw(integers(min_value=2, max_value=3))
logp = draw(
arrays(np.float, (b, N, C, C), floats(min_value=-100.0, max_value=100.0))
)
return torch.tensor(logp), (b, (N + 1))

### Tests

def enumerate(self, edge, lengths=None):
model = torch_struct.LinearChain(self.semiring)
semiring = self.semiring
@staticmethod
def enumerate(semiring, edge, lengths=None):
model = torch_struct.LinearChain(semiring)
semiring = semiring
ssize = semiring.size()
edge, batch, N, C, lengths = model._check_potentials(edge, lengths)
chains = [[([c], semiring.one_(torch.zeros(ssize, batch))) for c in range(C)]]
Expand Down Expand Up @@ -66,17 +70,18 @@ def enumerate(self, edge, lengths=None):


class DepTreeTest:
def __init__(self, semiring=LogSemiring):
self.semiring = semiring

@staticmethod
def _rand():
b = torch.randint(2, 4, (1,))
N = torch.randint(2, 4, (1,))
return torch.rand(b, N, N), (b.item(), N.item())
@composite
def logpotentials(draw):
b = draw(integers(min_value=2, max_value=3))
N = draw(integers(min_value=2, max_value=3))
logp = draw(
arrays(np.float, (b, N, N), floats(min_value=-10.0, max_value=10.0))
)
return torch.tensor(logp), (b, N)

def enumerate(self, arc_scores, non_proj=False, multi_root=True):
semiring = self.semiring
@staticmethod
def enumerate(semiring, arc_scores, non_proj=False, multi_root=True):
parses = []
q = []
arc_scores = torch_struct.convert(arc_scores)
Expand All @@ -101,21 +106,23 @@ def enumerate(self, arc_scores, non_proj=False, multi_root=True):


class SemiMarkovTest:
def __init__(self, semiring=LogSemiring):
self.semiring = semiring

# Tests

@staticmethod
def _rand():
b = torch.randint(2, 4, (1,))
N = torch.randint(2, 4, (1,))
K = torch.randint(2, 4, (1,))
C = torch.randint(2, 4, (1,))
return torch.rand(b, N, K, C, C), (b.item(), (N + 1).item())
@composite
def logpotentials(draw):
b = draw(integers(min_value=2, max_value=3))
N = draw(integers(min_value=2, max_value=3))
K = draw(integers(min_value=2, max_value=3))
C = draw(integers(min_value=2, max_value=3))
logp = draw(
arrays(np.float, (b, N, K, C, C), floats(min_value=-100.0, max_value=100.0))
)
return torch.tensor(logp), (b, (N + 1))

def enumerate(self, edge):
semiring = self.semiring
@staticmethod
def enumerate(semiring, edge):
ssize = semiring.size()
batch, N, K, C, _ = edge.shape
edge = semiring.convert(edge)
Expand Down Expand Up @@ -213,12 +220,22 @@ def _is_projective(parse):


class CKY_CRFTest:
def __init__(self, semiring=LogSemiring):
self.semiring = semiring
@staticmethod
@composite
def logpotentials(draw):
batch = draw(integers(min_value=2, max_value=4))
N = draw(integers(min_value=2, max_value=4))
NT = draw(integers(min_value=2, max_value=4))
logp = draw(
arrays(
np.float, (batch, N, N, NT), floats(min_value=-100.0, max_value=100.0)
)
)
return torch.tensor(logp), (batch, N)

# For testing
def enumerate(self, scores):
semiring = self.semiring
@staticmethod
def enumerate(semiring, scores):
semiring = semiring
batch, N, _, NT = scores.shape

def enumerate(x, start, end):
Expand All @@ -243,22 +260,36 @@ def enumerate(x, start, end):

return semiring.sum(torch.stack(ls, dim=-1)), None

@staticmethod
def _rand():
batch = torch.randint(2, 5, (1,))
N = torch.randint(2, 5, (1,))
NT = torch.randint(2, 5, (1,))
scores = torch.rand(batch, N, N, NT)
return scores, (batch.item(), N.item())


class CKYTest:
def __init__(self, semiring=LogSemiring):
self.semiring = semiring
@staticmethod
@composite
def logpotentials(draw):
batch = draw(integers(min_value=2, max_value=3))
N = draw(integers(min_value=2, max_value=4))
NT = draw(integers(min_value=2, max_value=3))
T = draw(integers(min_value=2, max_value=3))
terms = draw(
arrays(np.float, (batch, N, T), floats(min_value=-100.0, max_value=100.0))
)
rules = draw(
arrays(
np.float,
(batch, NT, NT + T, NT + T),
floats(min_value=-100.0, max_value=100.0),
)
)
roots = draw(
arrays(np.float, (batch, NT), floats(min_value=-100.0, max_value=100.0))
)
return (torch.tensor(terms), torch.tensor(rules), torch.tensor(roots)), (
batch,
N,
)

def enumerate(self, scores):
@staticmethod
def enumerate(semiring, scores):
terms, rules, roots = scores
semiring = self.semiring
batch, N, T = terms.shape
_, NT, _, _ = rules.shape

Expand All @@ -283,17 +314,6 @@ def enumerate(x, start, end):
ls += [semiring.times(s, roots[:, nt]) for s, _ in enumerate(nt, 0, N)]
return semiring.sum(torch.stack(ls, dim=-1)), None

@staticmethod
def _rand():
batch = torch.randint(2, 5, (1,))
N = torch.randint(2, 5, (1,))
NT = torch.randint(2, 5, (1,))
T = torch.randint(2, 5, (1,))
terms = torch.rand(batch, N, T)
rules = torch.rand(batch, NT, (NT + T), (NT + T))
roots = torch.rand(batch, NT)
return (terms, rules, roots), (batch.item(), N.item())


class AlignmentTest:
def __init__(self, semiring=LogSemiring):
Expand Down

0 comments on commit d272745

Please sign in to comment.