Skip to content

Commit

Permalink
wider acceptance interval + pep8
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed May 17, 2018
1 parent 4fe0bd4 commit 51469d9
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,34 +73,38 @@ def test_exponential_regression_model(c=0.3, lambd=0.1, n=10000):
model = convoys.regression.Exponential()
model.fit(X, B, T)
assert model.cdf([1], float('inf')).shape == ()
assert 0.90*c < model.cdf([1], float('inf')) < 1.20*c
assert 0.80*c < model.cdf([1], float('inf')) < 1.20*c
assert model.cdf([1], 0).shape == ()
assert model.cdf([1], [0, 1, 2, 3]).shape == (4,)
for t in [1, 3, 10]:
d = 1 - numpy.exp(-lambd*t)
assert 0.90*c*d < model.cdf([1], t) < 1.20*c*d
assert 0.80*c*d < model.cdf([1], t) < 1.20*c*d

# Check the confidence intervals
assert model.cdf([1], float('inf'), ci=0.95).shape == (3,)
assert model.cdf([1], [0, 1, 2, 3], ci=0.95).shape == (4, 3)
y, y_lo, y_hi = model.cdf([1], float('inf'), ci=0.95)
assert 0.90*c < y < 1.20*c
assert 0.80*c < y < 1.20*c

# Check the random variates
will_convert, convert_at = model.rvs([1], n_curves=1, n_samples=10000)
assert 0.90*c < numpy.mean(will_convert) < 1.20*c
assert 0.80*c < numpy.mean(will_convert) < 1.20*c
convert_times = convert_at[will_convert]
for t in [1, 3, 10]:
d = 1 - numpy.exp(-lambd*t)
assert 0.70*d < (convert_times < t).mean() < 1.30*d


@flaky.flaky
def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=10000):
X = numpy.array([[r % len(cs) == j for j in range(len(cs))] for r in range(n)])
C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
def test_weibull_regression_model(cs=[0.3, 0.5, 0.7],
lambd=0.1, k=0.5, n=10000):
X = numpy.array([[r % len(cs) == j for j in range(len(cs))]
for r in range(n)])
C = numpy.array([bool(random.random() < cs[r % len(cs)])
for r in range(n)])
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,))
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
E = numpy.array([sample_weibull(k, lambd)
for r in range(n)])
B, T = generate_censored_data(N, E, C)

model = convoys.regression.Weibull()
Expand All @@ -118,7 +122,7 @@ def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=10000)
# Check results
for r, c in enumerate(cs):
x = [int(r == j) for j in range(len(cs))]
assert 0.90 * c < model.cdf(x, float('inf')) < 1.20 * c
assert 0.80 * c < model.cdf(x, float('inf')) < 1.20 * c
expected_time = 1./lambd * scipy.special.gamma(1 + 1/k)


Expand All @@ -133,11 +137,12 @@ def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=10000):

model = convoys.regression.Gamma()
model.fit(X, B, T)
assert 0.90*c < model.cdf([1], float('inf')) < 1.20*c
assert 0.90*k < numpy.mean(model.params['k']) < 1.20*k
assert 0.80*c < model.cdf([1], float('inf')) < 1.20*c
assert 0.80*k < numpy.mean(model.params['k']) < 1.20*k


def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=1000, model='weibull', extra_model=None):
def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=1000,
model='weibull', extra_model=None):
C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
N = scipy.stats.expon.rvs(scale=10./lambd, size=(n,))
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
Expand All @@ -152,12 +157,13 @@ def _test_plot_cohorts(cs=[0.3, 0.5, 0.7], k=0.5, lambd=0.1, n=1000, model='weib

matplotlib.pyplot.clf()
_, result = convoys.plot_cohorts(data, model=model, extra_model=extra_model)
matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model) if extra_model is not None else '%s.png' % model)
matplotlib.pyplot.savefig('%s-%s.png' % (model, extra_model)
if extra_model is not None else '%s.png' % model)
group, y, y_lo, y_hi = result[0]
c = cs[0]
assert group == 'Group 0'
if model != 'kaplan-meier':
assert 0.90*c < y < 1.20 * c
assert 0.80*c < y < 1.20 * c


@flaky.flaky
Expand Down

0 comments on commit 51469d9

Please sign in to comment.