diff --git a/convoys/regression.py b/convoys/regression.py index 7dc3f4e..e82b008 100644 --- a/convoys/regression.py +++ b/convoys/regression.py @@ -42,7 +42,7 @@ def generalized_gamma_loss(x, X, B, T, W, fix_k, fix_p, elif flavor == 'linear': # L2 loss, linear c = dot(X, beta)+b LL_observed = -(1 - c)**2 + log_pdf - LL_censored = -c**2 * cdf - (1 - c)**2 * (1 - cdf) + LL_censored = -c**2 * cdf LL_data = sum( W * B * LL_observed + diff --git a/test_convoys.py b/test_convoys.py index bc39824..dc432ac 100644 --- a/test_convoys.py +++ b/test_convoys.py @@ -127,6 +127,13 @@ def test_weibull_regression_model(cs=[0.3, 0.5, 0.7], x = [int(r == j) for j in range(len(cs))] assert 0.80 * c < model.cdf(x, float('inf')) < 1.30 * c + # Fit a linear model + model = convoys.regression.Weibull(ci=False, flavor='linear') + model.fit(X, B, T) + model_cs = model.params['map']['b'] + model.params['map']['beta'] + for model_c, c in zip(model_cs, cs): + assert 0.8 * c < model_c < 1.2 * c + @flaky.flaky def test_gamma_regression_model(c=0.3, lambd=0.1, k=3.0, n=10000):