Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Mar 31, 2020
1 parent 6043eb0 commit 84ec1a2
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test_convoys.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_output_shapes(c=0.3, lambd=0.1, n=1000, k=5):
B, T = generate_censored_data(N, E, C)

# Fit model with ci
model = convoys.regression.Exponential(ci=True)
model = convoys.regression.Exponential(mcmc=True)
model.fit(X, B, T)

# Generate output without ci
Expand All @@ -91,7 +91,7 @@ def test_output_shapes(c=0.3, lambd=0.1, n=1000, k=5):
assert model.cdf(X[0], 0, ci=0.8).shape == (3,)

# Fit model without ci (should be the same)
model = convoys.regression.Exponential(ci=False)
model = convoys.regression.Exponential(mcmc=False)
model.fit(X, B, T)
assert model.predict(X[0], 0).shape == ()
assert model.predict([X[0], X[1]], [0, 1]).shape == (2,)
Expand All @@ -104,7 +104,7 @@ def test_exponential_regression_model(c=0.3, lambd=0.1, n=10000):
N = scipy.stats.uniform.rvs(scale=5./lambd, size=(n,)) # time now
E = scipy.stats.expon.rvs(scale=1./lambd, size=(n,)) # time of event
B, T = generate_censored_data(N, E, C)
model = convoys.regression.Exponential(ci=True)
model = convoys.regression.Exponential(mcmc=True)
model.fit(X, B, T)
assert 0.80*c < model.predict([1], float('inf')) < 1.30*c
for t in [1, 3, 10]:
Expand All @@ -126,7 +126,7 @@ def test_exponential_regression_model(c=0.3, lambd=0.1, n=10000):
assert 0.70*d < (convert_times < t).mean() < 1.30*d

# Fit a linear model
model = convoys.regression.Exponential(ci=False, flavor='linear')
model = convoys.regression.Exponential(mcmc=False, flavor='linear')
model.fit(X, B, T)
model_c = model.params['map']['b'] + model.params['map']['beta'][0]
assert 0.9*c < model_c < 1.1*c
Expand Down Expand Up @@ -162,7 +162,7 @@ def test_weibull_regression_model(cs=[0.3, 0.5, 0.7],
assert 0.80 * c < model.predict(x, float('inf')) < 1.30 * c

# Fit a linear model
model = convoys.regression.Weibull(ci=False, flavor='linear')
model = convoys.regression.Weibull(mcmc=False, flavor='linear')
model.fit(X, B, T)
model_cs = model.params['map']['b'] + model.params['map']['beta']
for model_c, c in zip(model_cs, cs):
Expand All @@ -184,7 +184,7 @@ def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=10000):
assert 0.80*k < numpy.mean(model.params['map']['k']) < 1.30*k

# Fit a linear model
model = convoys.regression.Gamma(ci=False, flavor='linear')
model = convoys.regression.Gamma(mcmc=False, flavor='linear')
model.fit(X, B, T)
model_c = model.params['map']['b'] + model.params['map']['beta'][0]
assert 0.9*c < model_c < 1.1*c
Expand All @@ -201,7 +201,7 @@ def test_linear_model(n=10000, m=5, k=3.0, lambd=0.1):
E = numpy.array([sample_weibull(k, lambd) for r in range(n)])
B, T = generate_censored_data(N, E, C)

model = convoys.regression.Weibull(ci=False, flavor='linear')
model = convoys.regression.Weibull(mcmc=False, flavor='linear')
model.fit(X, B, T)

# Check the fitted parameters
Expand Down Expand Up @@ -365,7 +365,7 @@ def _test_plot_cohorts(model='weibull', extra_model=None):
def test_plot_cohorts_model():
df = _generate_dataframe()
unit, groups, (G, B, T) = convoys.utils.get_arrays(df)
model = convoys.multi.Exponential(ci=None)
model = convoys.multi.Exponential(mcmc=None)
model.fit(G, B, T)
matplotlib.pyplot.clf()
convoys.plotting.plot_cohorts(G, B, T, model=model, groups=groups)
Expand Down

0 comments on commit 84ec1a2

Please sign in to comment.