Skip to content

Commit

Permalink
remove kaplan-meier
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed Mar 20, 2018
1 parent a5f213b commit f1145ab
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 46 deletions.
6 changes: 2 additions & 4 deletions convoys/__init__.py
@@ -1,12 +1,11 @@
import abc
import datetime
import lifelines
import math
import numpy
import random
import seaborn
from matplotlib import pyplot
from convoys.multi import Exponential, Weibull, Gamma, KaplanMeier, Nonparametric
from convoys.multi import Exponential, Weibull, Gamma, Nonparametric


def get_timescale(t):
Expand Down Expand Up @@ -77,15 +76,14 @@ def get_groups(data, group_min_size, max_groups):


# Registry mapping a model name (string) to the model class imported from
# convoys.multi. Presumably the `model` argument of plot_cohorts is looked up
# here -- TODO confirm against the function body, which is not fully visible.
_models = {
    'kaplan-meier': KaplanMeier,
    'nonparametric': Nonparametric,
    'exponential': Exponential,
    'weibull': Weibull,
    'gamma': Gamma,
}


def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='kaplan-meier'):
def plot_cohorts(data, t_max=None, title=None, group_min_size=0, max_groups=100, model='nonparametric'):
# Set x scale
if t_max is None:
t_max = max(now - created_at for group, created_at, converted_at, now in data)
Expand Down
4 changes: 0 additions & 4 deletions convoys/multi.py
Expand Up @@ -70,9 +70,5 @@ class Gamma(RegressionToMulti):
_base_model_cls = regression.Gamma


class KaplanMeier(SingleToMulti):
    # Selects single.KaplanMeier as the base model class; how SingleToMulti
    # uses `_base_model_cls` is defined outside this view.
    _base_model_cls = single.KaplanMeier


class Nonparametric(SingleToMulti):
    # Selects single.Nonparametric as the base model class; how SingleToMulti
    # uses `_base_model_cls` is defined outside this view.
    _base_model_cls = single.Nonparametric
37 changes: 0 additions & 37 deletions convoys/single.py
@@ -1,5 +1,4 @@
import bisect
import lifelines
import numpy
from scipy.special import expit
import tensorflow as tf
Expand All @@ -10,42 +9,6 @@ class SingleModel:
pass # TODO


class KaplanMeier(SingleModel):
    """Kaplan-Meier estimate of the cumulative conversion curve.

    Wraps ``lifelines.KaplanMeierFitter`` and stores one minus the fitted
    survival function, together with its 0.95 confidence band, on the fitted
    timeline (``self.ts``).

    In every ``predict*`` method the *value* of ``ci`` is not used: any
    non-``None`` value switches the return type to a
    ``(estimate, lower, upper)`` tuple built from the 0.95 band.
    """

    def fit(self, B, T):
        """Fit the estimator.

        ``T`` is passed to lifelines as the durations and ``B`` as
        ``event_observed``.
        """
        kmf = lifelines.KaplanMeierFitter()
        kmf.fit(T, event_observed=B)
        self.ts = kmf.survival_function_.index.values
        self.ps = 1.0 - kmf.survival_function_['KM_estimate'].values
        # Taking 1 - survival flips the band: the lower survival bound
        # becomes the upper conversion bound and vice versa.
        self.ps_hi = 1.0 - kmf.confidence_interval_['KM_estimate_lower_0.95'].values
        self.ps_lo = 1.0 - kmf.confidence_interval_['KM_estimate_upper_0.95'].values

    def predict(self, ts, ci=None):
        """Return the fitted curve evaluated at ``ts``.

        ``ts`` may be an iterable of times (returns numpy arrays, as before)
        or a single scalar time (returns scalars). Each time is mapped to the
        first fitted timeline point that is >= it, clamped to the last point.
        """
        scalar = numpy.isscalar(ts)
        # Wrap a scalar so both cases share the same lookup path; iterables
        # are used untouched to keep the original behavior.
        times = [ts] if scalar else ts
        js = [bisect.bisect_left(self.ts, t) for t in times]

        def array_lookup(a):
            values = numpy.array([a[min(j, len(self.ts)-1)] for j in js])
            return values[0] if scalar else values

        if ci is not None:
            return (array_lookup(self.ps), array_lookup(self.ps_lo), array_lookup(self.ps_hi))
        else:
            return array_lookup(self.ps)

    def predict_final(self, ci=None):
        """Return the curve value at the last fitted timeline point."""
        if ci is not None:
            return (self.ps[-1], self.ps_lo[-1], self.ps_hi[-1])
        else:
            return self.ps[-1]

    def predict_time(self, ci=None):
        """Return the first fitted time at which the curve reaches 0.5.

        If the curve never reaches 0.5, the last fitted time is returned.
        """
        # TODO: should not use median here, but mean is no good
        def median(ps):
            i = bisect.bisect_left(ps, 0.5)
            return self.ts[min(i, len(ps)-1)]

        if ci is not None:
            return median(self.ps), median(self.ps_lo), median(self.ps_hi)
        else:
            return median(self.ps)


class Nonparametric(SingleModel):
def fit(self, B, T, n=1000):
# We're going to fit c and p_0, p_1, ...
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,4 +1,3 @@
lifelines==0.11.2
matplotlib>=2.0.0
numpy
scipy
Expand Down

0 comments on commit f1145ab

Please sign in to comment.