diff --git a/convoys/__init__.py b/convoys/__init__.py index 6bbe093..513fa96 100644 --- a/convoys/__init__.py +++ b/convoys/__init__.py @@ -1,12 +1,11 @@ import abc import datetime -import lifelines import math import numpy import random import seaborn from matplotlib import pyplot -from convoys.multi import Exponential, Weibull, Gamma, KaplanMeier, Nonparametric +from convoys.multi import Exponential, Weibull, Gamma, Nonparametric def get_timescale(t): @@ -77,7 +76,6 @@ def get_groups(data, group_min_size, max_groups): _models = { - 'kaplan-meier': KaplanMeier, 'nonparametric': Nonparametric, 'exponential': Exponential, 'weibull': Weibull, @@ -85,7 +83,7 @@ def get_groups(data, group_min_size, max_groups): } -def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier'): +def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='nonparametric'): # Set x scale if t_max is None: t_max = max(now - created_at for group, created_at, converted_at, now in data) diff --git a/convoys/multi.py b/convoys/multi.py index cdeb438..1eeebad 100644 --- a/convoys/multi.py +++ b/convoys/multi.py @@ -70,9 +70,5 @@ class Gamma(RegressionToMulti): _base_model_cls = regression.Gamma -class KaplanMeier(SingleToMulti): - _base_model_cls = single.KaplanMeier - - class Nonparametric(SingleToMulti): _base_model_cls = single.Nonparametric diff --git a/convoys/single.py b/convoys/single.py index 7b03279..2c47750 100644 --- a/convoys/single.py +++ b/convoys/single.py @@ -1,5 +1,4 @@ import bisect -import lifelines import numpy from scipy.special import expit import tensorflow as tf @@ -10,42 +9,6 @@ class SingleModel: pass # TODO -class KaplanMeier(SingleModel): - def fit(self, B, T): - kmf = lifelines.KaplanMeierFitter() - kmf.fit(T, event_observed=B) - self.ts = kmf.survival_function_.index.values - self.ps = 1.0 - kmf.survival_function_['KM_estimate'].values - self.ps_hi = 1.0 - kmf.confidence_interval_['KM_estimate_lower_0.95'].values - self.ps_lo = 1.0 - kmf.confidence_interval_['KM_estimate_upper_0.95'].values - - def predict(self, ts, ci=None): - # TODO: should also handle scalars - js = [bisect.bisect_left(self.ts, t) for t in ts] - def array_lookup(a): - return numpy.array([a[min(j, len(self.ts)-1)] for j in js]) - if ci is not None: - return (array_lookup(self.ps), array_lookup(self.ps_lo), array_lookup(self.ps_hi)) - else: - return array_lookup(self.ps) - - def predict_final(self, ci=None): - if ci is not None: - return (self.ps[-1], self.ps_lo[-1], self.ps_hi[-1]) - else: - return self.ps[-1] - - def predict_time(self, ci=None): - # TODO: should not use median here, but mean is no good - def median(ps): - i = bisect.bisect_left(ps, 0.5) - return self.ts[min(i, len(ps)-1)] - if ci is not None: - return median(self.ps), median(self.ps_lo), median(self.ps_hi) - else: - return median(self.ps) - - class Nonparametric(SingleModel): def fit(self, B, T, n=1000): # We're going to fit c and p_0, p_1, ... diff --git a/requirements.txt b/requirements.txt index 49aabb2..42ba94d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -lifelines==0.11.2 matplotlib>=2.0.0 numpy scipy