Skip to content

Commit

Permalink
more doc
Browse files Browse the repository at this point in the history
  • Loading branch information
Erik Bernhardsson committed May 30, 2018
1 parent a142291 commit 3bd18a4
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 8 deletions.
31 changes: 24 additions & 7 deletions convoys/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,15 @@ class MultiModel:

class RegressionToMulti(MultiModel):
def __init__(self, *args, **kwargs):
self.base_model = self.base_model_cls(*args, **kwargs)
self.base_model = self._base_model_cls(*args, **kwargs)

def fit(self, G, B, T):
''' Fits the model
:param G: numpy vector of shape :math:`n`
:param B: numpy vector of shape :math:`n`
:param T: numpy vector of shape :math:`n`
'''
self._n_groups = max(G) + 1
X = numpy.zeros((len(G), self._n_groups))
for i, group in enumerate(G):
Expand All @@ -34,9 +40,15 @@ def rvs(self, group, *args, **kwargs):

class SingleToMulti(MultiModel):
def __init__(self, *args, **kwargs):
self.base_model_init = lambda: self.base_model_cls(*args, **kwargs)
self.base_model_init = lambda: self._base_model_cls(*args, **kwargs)

def fit(self, G, B, T):
''' Fits the model
:param G: numpy vector of shape :math:`n`
:param B: numpy vector of shape :math:`n`
:param T: numpy vector of shape :math:`n`
'''
group2bt = {}
for g, b, t in zip(G, B, T):
group2bt.setdefault(g, []).append((b, t))
Expand All @@ -50,20 +62,25 @@ def cdf(self, group, t, *args, **kwargs):


class Exponential(RegressionToMulti):
base_model_cls = regression.Exponential
''' Multi-group version of :class:`convoys.regression.Exponential`. '''
_base_model_cls = regression.Exponential


class Weibull(RegressionToMulti):
base_model_cls = regression.Weibull
''' Multi-group version of :class:`convoys.regression.Weibull`. '''
_base_model_cls = regression.Weibull


class Gamma(RegressionToMulti):
base_model_cls = regression.Gamma
''' Multi-group version of :class:`convoys.regression.Gamma`. '''
_base_model_cls = regression.Gamma


class GeneralizedGamma(RegressionToMulti):
base_model_cls = regression.GeneralizedGamma
''' Multi-group version of :class:`convoys.regression.GeneralizedGamma`. '''
_base_model_cls = regression.GeneralizedGamma


class KaplanMeier(SingleToMulti):
base_model_cls = single.KaplanMeier
''' Multi-group version of :class:`convoys.single.KaplanMeier`. '''
_base_model_cls = single.KaplanMeier
9 changes: 8 additions & 1 deletion convoys/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,14 @@ def __init__(self, ci=False):
self._ci = ci

def fit(self, X, B, T, W=None, fix_k=None, fix_p=None):
# Sanity check input:
'''Fits the model.
:param X: numpy matrix of shape :math:`k \cdot n`
:param B: numpy vector of shape :math:`n`
:param T: numpy vector of shape :math:`n`
:param W: (optional) numpy vector of shape :math:`n`
'''

if W is None:
W = [1] * len(X)
XBTW = [(x, b, t, w) for x, b, t, w in zip(X, B, T, W)
Expand Down
6 changes: 6 additions & 0 deletions convoys/single.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ class SingleModel:


class KaplanMeier(SingleModel):
''' Implementation of the Kaplan-Meier nonparametric method. '''
def fit(self, B, T):
''' Fits the model
:param B: numpy vector of shape :math:`n`
:param T: numpy vector of shape :math:`n`
'''
# See https://www.math.wustl.edu/~sawyer/handouts/greenwood.pdf
n = len(T)
self._ts = []
Expand Down
2 changes: 2 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ Full API documentation

.. automodule:: convoys.regression
:members:
:inherited-members:

.. automodule:: convoys.single
:members:

.. automodule:: convoys.multi
:members:
:inherited-members:

.. automodule:: convoys.utils
:members:
Expand Down

0 comments on commit 3bd18a4

Please sign in to comment.