Skip to content

Commit

Permalink
Merge a64a418 into 86a1718
Browse files Browse the repository at this point in the history
  • Loading branch information
erikbern committed Mar 15, 2018
2 parents 86a1718 + a64a418 commit 05e5c23
Showing 1 changed file with 28 additions and 15 deletions.
43 changes: 28 additions & 15 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@ def sample_weibull(k, lambd):
# exp(-(x * lambda)^k) = y
return (-numpy.log(random.random())) ** (1.0/k) / lambd

def generate_censored_data(N, E, C):
B = numpy.array([random.random() < c and e < n for n, e, c in zip(N, E, C)])
T = numpy.array([e if b else n for e, b, n in zip(E, B, N)])
return B, T


def test_exponential_regression_model(c=0.3, lambd=0.1, n=100000):
# With a really long observation window, the rate should converge to the measured
X = numpy.ones((n, 1))
B = numpy.array([bool(random.random() < c) for x in range(n)])
T = numpy.array([scipy.stats.expon.rvs(scale=1.0/lambd) if b else 1000.0 for b in B])
c = numpy.mean(B)
C = scipy.stats.bernoulli.rvs(c, size=(n,)) # did it convert
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,)) # time now
E = scipy.stats.expon.rvs(scale=1./lambd, size=(n,)) # time of event
B, T = generate_censored_data(N, E, C)
model = convoys.regression.ExponentialRegression()
model.fit(X, B, T)
assert 0.95*c < model.predict_final([1]) < 1.05*c
Expand All @@ -32,15 +37,17 @@ def test_exponential_regression_model(c=0.3, lambd=0.1, n=100000):
y, y_lo, y_hi = model.predict_final([1], ci=0.95)
c_lo = scipy.stats.beta.ppf(0.025, n*c, n*(1-c))
c_hi = scipy.stats.beta.ppf(0.975, n*c, n*(1-c))
assert 0.95*c < y < 1.05 * c
assert 0.95*c_lo < y_lo < 1.05 * c_lo
assert 0.95*c_hi < y_hi < 1.05 * c_hi
assert 0.95*c < y < 1.05*c
assert 0.70*(c_hi-c_lo) < (y_hi-y_lo) < 1.30*(c_hi-c_lo)


def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=100000):
X = numpy.array([[1] + [r % len(cs) == j for j in range(len(cs))] for r in range(n)])
B = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
T = numpy.array([b and sample_weibull(k, lambd) or 1000 for b in B])
C = numpy.array([bool(random.random() < cs[r % len(cs)]) for r in range(n)])
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,))
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
B, T = generate_censored_data(N, E, C)

model = convoys.regression.WeibullRegression()
model.fit(X, B, T)
for r, c in enumerate(cs):
Expand All @@ -50,23 +57,29 @@ def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], lambd=0.1, k=0.5, n=100000

def test_weibull_regression_model_ci(c=0.3, lambd=0.1, k=0.5, n=100000):
X = numpy.ones((n, 1))
B = numpy.array([bool(random.random() < c) for r in range(n)])
c = numpy.mean(B)
T = numpy.array([b and sample_weibull(k, lambd) or 1000 for b in B])
C = scipy.stats.bernoulli.rvs(c, size=(n,))
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,))
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
B, T = generate_censored_data(N, E, C)

model = convoys.regression.WeibullRegression()
model.fit(X, B, T)
y, y_lo, y_hi = model.predict_final([1], ci=0.95)
c_lo = scipy.stats.beta.ppf(0.025, n*c, n*(1-c))
c_hi = scipy.stats.beta.ppf(0.975, n*c, n*(1-c))
assert 0.95*c < y < 1.05 * c
assert 0.95*(c_hi-c_lo) < (y_hi-y_lo) < 1.05 * (c_hi-c_lo)
assert 0.70*(c_hi-c_lo) < (y_hi-y_lo) < 1.30*(c_hi-c_lo)


def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=100000):
# Something is a bit wacky with this one.
# If I replace N with a smaller observation window, it breaks.
X = numpy.ones((n, 1))
B = numpy.array([bool(random.random() < c) for r in range(n)])
T = numpy.array([b and scipy.stats.gamma.rvs(a=k, scale=1.0/lambd) or 1000 for b in B])
C = scipy.stats.bernoulli.rvs(c, size=(n,))
N = numpy.ones((n,)) * 1000.
E = scipy.stats.gamma.rvs(a=k, scale=1.0/lambd, size=(n,))
B, T = generate_censored_data(N, E, C)

model = convoys.regression.GammaRegression()
model.fit(X, B, T)
assert 0.95*c < model.predict_final([1]) < 1.05*c
Expand Down

0 comments on commit 05e5c23

Please sign in to comment.