Skip to content

Commit

Permalink
speed up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed May 17, 2018
1 parent 299e4bb commit 4fe0bd4
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_gammainc(k=2.5, x=4.2, g_eps=1e-7):


@flaky.flaky
def test_exponential_regression_model(c=0.3, lambd=0.1, n=100000):
def test_exponential_regression_model(c=0.3, lambd=0.1, n=10000):
X = numpy.ones((n, 1))
C = scipy.stats.bernoulli.rvs(c, size=(n,)) # did it convert
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,)) # time now
Expand Down Expand Up @@ -92,11 +92,11 @@ def test_exponential_regression_model(c=0.3, lambd=0.1, n=100000):
convert_times = convert_at[will_convert]
for t in [1, 3, 10]:
d = 1 - numpy.exp(-lambd*t)
assert 0.80*d < (convert_times < t).mean() < 1.30*d
assert 0.70*d < (convert_times < t).mean() < 1.30*d


@flaky.flaky
def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=100000):
def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=10000):
X = numpy.array([[r % len(cs) == j for j in range(len(cs))] for r in range(n)])
C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,))
Expand All @@ -123,7 +123,7 @@ def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=100000


@flaky.flaky
def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=100000):
def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=10000):
# TODO: this one seems very sensitive to large values for N (i.e. less censoring)
X = numpy.ones((n, 1))
C = scipy.stats.bernoulli.rvs(c, size=(n,))
Expand All @@ -134,10 +134,10 @@ def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=100000):
model = convoys.regression.Gamma()
model.fit(X, B, T)
assert 0.90*c < model.cdf([1], float('inf')) < 1.20*c
assert 0.90*k < numpy.mean(model.params['k']) < 1.10*k
assert 0.90*k < numpy.mean(model.params['k']) < 1.20*k


def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=10000, model='weibull', extra_model=None):
def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=1000, model='weibull', extra_model=None):
C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
N = scipy.stats.expon.rvs(scale=10./lambd, size=(n,))
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
Expand All @@ -157,7 +157,7 @@ def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=10000, model='wei
c = cs[0]
assert group == 'Group 0'
if model != 'kaplan-meier':
assert 0.90*c < y < 1.10 * c
assert 0.90*c < y < 1.20 * c


@flaky.flaky
Expand All @@ -172,4 +172,4 @@ def test_plot_cohorts_weibull():

@flaky.flaky
def test_plot_cohorts_two_models():
_test_plot_cohorts(model='kaplan-meier', extra_model='nonparametric')
_test_plot_cohorts(model='kaplan-meier', extra_model='weibull')

0 comments on commit 4fe0bd4

Please sign in to comment.