From e4c7f368282a35f775d4f319ee86be9dbc2eaa8b Mon Sep 17 00:00:00 2001 From: derrynknife Date: Mon, 15 May 2023 07:41:54 +1000 Subject: [PATCH] added fit_from_surpyval_data function to ParametricFitter, refactored code to suit. --- surpyval/parametric/parametric_fitter.py | 733 ++++++++++++++--------- surpyval/utils/surpyval_data.py | 27 + 2 files changed, 490 insertions(+), 270 deletions(-) diff --git a/surpyval/parametric/parametric_fitter.py b/surpyval/parametric/parametric_fitter.py index d789502..631c445 100755 --- a/surpyval/parametric/parametric_fitter.py +++ b/surpyval/parametric/parametric_fitter.py @@ -286,6 +286,128 @@ def mom_moment_gen(self, *params, offset=False): moments[i] = self._moment(n, *params, offset=offset) return moments + def _validate_fit_inputs( + self, surv_data, how, offset, lfp, zi, heuristic, turnbull_estimator + ): + if offset and (self.support[0] != 0): + detail = "{} distribution cannot be offset".format(self.name) + raise ValueError(detail) + + if how not in PARA_METHODS: + raise ValueError('"how" must be one of: ' + str(PARA_METHODS)) + + if how == "MPP" and self.name == "ExpoWeibull": + detail = ( + "ExpoWeibull distribution does not work" + + " with probability plot fitting" + ) + raise ValueError(detail) + + if np.isfinite(surv_data.t).any() and how == "MSE": + detail = "Mean square error doesn't yet support tuncation" + raise NotImplementedError(detail) + + if np.isfinite(surv_data.t).any() and how == "MOM": + detail = "Maximum product spacing doesn't support tuncation" + raise ValueError(detail) + + if (lfp or zi) & (how != "MLE"): + detail = ( + "Limited failure or zero-inflated models" + + " can only be made with MLE" + ) + raise ValueError(detail) + + if zi & (self.support[0] != 0): + detail = ( + "zero-inflated models can only work" + + "with models starting at 0" + ) + raise ValueError() + + if (surv_data.c == 1).all(): + raise ValueError("Cannot have only right censored data") + + if (surv_data.c == -1).all(): + raise ValueError("Cannot have only left censored data") + + if surpyval.utils.check_no_censoring(surv_data.c) and (how == "MOM"): + raise ValueError("Method of moments doesn't support censoring") + + if ( + (surpyval.utils.no_left_or_int(surv_data.c)) + and (how == "MPP") + and (not heuristic == "Turnbull") + ): + detail = ( + "Probability plotting estimation with left or " + + "interval censoring only works with Turnbull heuristic" + ) + raise ValueError() + + if ( + (heuristic == "Turnbull") + and (not ((-1 in surv_data.c) or (2 in surv_data.c))) + and ((~np.isfinite(surv_data.tr)).any()) + ): + # The Turnbull method is extremely memory intensive. + # So if no left or interval censoring and no right-truncation + # then this is equivalent. + heuristic = turnbull_estimator + + if (not offset) & (not zi): + detail_template = """ + Some of your data is outside support of distribution, observed + values must be within [{lower}, {upper}]. + + Are some of your observed values 0, -Inf, or Inf? + """ + + if surv_data.x.ndim == 2: + if ( + (surv_data.x[:, 0] <= self.support[0]) & (surv_data.c == 0) + ).any(): + detail = detail_template.format( + lower=self.support[0], upper=self.support[1] + ) + raise ValueError(detail) + elif ( + (surv_data.x[:, 1] >= self.support[1]) & (surv_data.c == 0) + ).any(): + detail = detail_template.format( + lower=self.support[0], upper=self.support[1] + ) + raise ValueError(detail) + else: + if ( + (surv_data.x <= self.support[0]) & (surv_data.c == 0) + ).any(): + detail = detail_template.format( + lower=self.support[0], upper=self.support[1] + ) + raise ValueError(detail) + elif ( + (surv_data.x >= self.support[1]) & (surv_data.c == 0) + ).any(): + detail = detail_template.format( + lower=self.support[0], upper=self.support[1] + ) + raise ValueError(detail) + + if (surv_data.tl[0] != surv_data.tl).any() and how == "MPS": + raise ValueError( + "Left truncated value can only be single number \ + when using MPS" + ) + + if (surv_data.tr[0] != surv_data.tr).any() and how == "MPS": + raise ValueError( + "Right truncated value can only be single number \ + when using MPS" + ) + + return True + def fit( self, x=None, @@ -296,10 +418,340 @@ def fit( offset=False, zi=False, lfp=False, - tl=None, - tr=None, - xl=None, - xr=None, + tl=None, + tr=None, + xl=None, + xr=None, + fixed=None, + heuristic="Turnbull", + init=[], + rr="y", + on_d_is_0=False, + turnbull_estimator="Fleming-Harrington", + ): + r""" + + The central feature to SurPyval's capability. This function aimed to + have an API to mimic the simplicity of the scipy API. That is, to use + a simple :code:`fit()` call, with as many or as few parameters as + is needed. + + Parameters + ---------- + + x : array like, optional + Array of observations of the random variables. If x is + :code:`None`, xl and xr must be provided. + c : array like, optional + Array of censoring flag. -1 is left censored, 0 is observed, 1 is + right censored, and 2 is intervally censored. If not provided + will assume all values are observed. + n : array like, optional + Array of counts for each x. If data is proivded as counts, then + this can be provided. If :code:`None` will assume each + observation is 1. + t : 2D-array like, optional + 2D array like of the left and right values at which the + respective observation was truncated. If not provided it assumes + that no truncation occurs. + how : {'MLE', 'MPP', 'MOM', 'MSE', 'MPS'}, optional + Method to estimate parameters, these are: + + - MLE : Maximum Likelihood Estimation + - MPP : Method of Probability Plotting + - MOM : Method of Moments + - MSE : Mean Square Error + - MPS : Maximum Product Spacing + + offset : boolean, optional + If :code:`True` finds the shifted distribution. If not provided + assumes not a shifted distribution. Only works with distributions + that are supported on the half-real line. + + tl : array like or scalar, optional + Values of left truncation for observations. If it is a scalar + value assumes each observation is left truncated at the value. + If an array, it is the respective 'late entry' of the observation + + tr : array like or scalar, optional + Values of right truncation for observations. If it is a scalar + value assumes each observation is right truncated at the value. + If an array, it is the respective right truncation value for each + observation + + xl : array like, optional + Array like of the left array for 2-dimensional input of x. This + is useful for data that is all intervally censored. Must be used + with the :code:`xr` input. + + xr : array like, optional + Array like of the right array for 2-dimensional input of x. This + is useful for data that is all intervally censored. Must be used + with the :code:`xl` input. + + fixed : dict, optional + Dictionary of parameters and their values to fix. Fixes parameter + by name. + + heuristic : {'"Blom", "Median", "ECDF", "Modal", "Midpoint", "Mean", + "Weibull", "Benard", "Beard", "Hazen", "Gringorten", + "None", "Tukey", "DPW", "Fleming-Harrington", + "Kaplan-Meier", "Nelson-Aalen", "Filliben", + "Larsen", "Turnbull"} + Plotting method to use, if using the probability plotting, + MPP, method. + + init : array like, optional + initial guess of parameters. Useful if method is failing. + + rr : ('y', 'x') + The dimension on which to minimise the spacing between the line + and the observation. If 'y' the mean square error between the + line and vertical distance to each point is minimised. If 'x' the + mean square error between the line and horizontal distance to each + point is minimised. + + on_d_is_0 : boolean, optional + For the case when using MPP and the highest value is right + censored, you can choosed to include this value into the + regression analysis or not. That is, if :code:`False`, all values + where there are 0 deaths are excluded from the regression. If + :code:`True` all values regardless of whether there is a death + or not are included in the regression. + + turnbull_estimator : ('Nelson-Aalen', 'Kaplan-Meier', or + 'Fleming-Harrington'), str, optional + If using the Turnbull heuristic, you can elect to use either the + KM, NA, or FH estimator with the Turnbull estimates of r, and d. + Defaults to FH. + + Returns + ------- + + model : Parametric + A parametric model with the fitted parameters and methods for + all functions of the distribution using the fitted parameters. + + Examples + -------- + >>> from surpyval import Weibull + >>> import numpy as np + >>> x = Weibull.random(100, 10, 4) + >>> model = Weibull.fit(x) + >>> print(model) + Parametric SurPyval Model + ========================= + Distribution : Weibull + Fitted by : MLE + Parameters : + alpha: 10.551521182640098 + beta: 3.792549834495306 + >>> Weibull.fit(x, how='MPS', fixed={'alpha' : 10}) + Parametric SurPyval Model + ========================= + Distribution : Weibull + Fitted by : MPS + Parameters : + alpha: 10.0 + beta: 3.4314657446866836 + >>> Weibull.fit(xl=x-1, xr=x+1, how='MPP') + Parametric SurPyval Model + ========================= + Distribution : Weibull + Fitted by : MPP + Parameters : + alpha: 9.943092756713078 + beta: 8.613016934518258 + >>> c = np.zeros_like(x) + >>> c[x > 13] = 1 + >>> x[x > 13] = 13 + >>> c = c[x > 6] + >>> x = x[x > 6] + >>> Weibull.fit(x=x, c=c, tl=6) + Parametric SurPyval Model + ========================= + Distribution : Weibull + Fitted by : MLE + Parameters : + alpha: 10.363725328793413 + beta: 4.9886821457305865 + """ + + surv_data = surpyval.xcnt_handler( + x=x, + c=c, + n=n, + t=t, + tl=tl, + tr=tr, + xl=xl, + xr=xr, + as_surpyval_dataset=True, + ) + return self.fit_from_surpyval_data( + surv_data, + how=how, + offset=offset, + zi=zi, + lfp=lfp, + fixed=fixed, + heuristic=heuristic, + init=init, + rr=rr, + on_d_is_0=on_d_is_0, + turnbull_estimator=turnbull_estimator, + ) + + def fit_from_df( + self, + df, + x=None, + c=None, + n=None, + xl=None, + xr=None, + tl=None, + tr=None, + **fit_options + ): + r""" + The central feature to SurPyval's capability. This function aimed to + have an API to mimic the simplicity of the scipy API. That is, to use + a simple :code:`fit()` call, with as many or as few parameters as + is needed. + + Parameters + ---------- + + df : DataFrame + DataFrame of data to be used to create surpyval model + + x : string, optional + column name for the column in df containing the variable data. + If not provided must provide both xl and xr. + + c : string, optional + column name for the column in df containing the censor flag of x. + If not provided assumes all values of x are observed. + + n : string, optional + column name in for the column in df containing the counts of x. + If not provided assumes each x is one observation. + + tl : string or scalar, optional + If string, column name in for the column in df containing the left + truncation data. If scalar assumes each x is left truncated by + that value. If not provided assumes x is not left truncated. + + tr : string or scalar, optional + If string, column name in for the column in df containing the + right truncation data. If scalar assumes each x is right truncated + by that value. If not provided assumes x is not right truncated. + + xl : string, optional + column name for the column in df containing the left interval for + interval censored data. If left interval is -Inf, assumes left + censored. If xl[i] == xr[i] assumes observed. Cannot be provided + with x, must be provided with xr. + + xr : string, optional + column name for the column in df containing the right interval + for interval censored data. If right interval is Inf, assumes + right censored. If xl[i] == xr[i] assumes observed. Cannot be + provided with x, must be provided with xl. + + fit_options : dict, optional + dictionary of fit options that will be passed to the :code:`fit` + method, see that method for options. + + Returns + ------- + + model : Parametric + A parametric model with the fitted parameters and methods for + all functions of the distribution using the fitted parameters. + + + Examples + -------- + >>> import surpyval as surv + >>> df = surv.datasets.BoforsSteel.df + >>> model = surv.Weibull.fit_from_df(df, x='x', n='n', offset=True) + >>> print(model) + Parametric SurPyval Model + ========================= + Distribution : Weibull + Fitted by : MLE + Offset (gamma) : 39.76562962867477 + Parameters : + alpha: 7.141925216146524 + beta: 2.6204524040137844 + """ + + if not type(df) == pd.DataFrame: + raise ValueError("df must be a pandas DataFrame") + + if (x is not None) and ((xl is not None) or (xr is not None)): + raise ValueError("Cannot use `x` and (`xl` and `xr`) together") + + if x is not None: + x = df[x].astype(float) + else: + xl = df[xl].astype(float) + xr = df[xr].astype(float) + x = np.vstack([xl, xr]).T + + if c is not None: + c = df[c].values.astype(int) + + if n is not None: + n = df[n].values.astype(int) + + if tl is not None: + if type(tl) == str: + tl = df[tl].values.astype(float) + elif np.isscalar(tl): + tl = (np.ones(df.shape[0]) * tl).astype(float) + else: + raise ValueError("`tl` must be scalar or column label string") + else: + tl = np.ones(df.shape[0]) * -np.inf + + if tr is not None: + if type(tr) == str: + tr = df[tr].values.astype(float) + elif np.isscalar(tr): + tr = (np.ones(df.shape[0]) * tr).astype(float) + else: + detail = "`tr` must be scalar or a column label string" + raise ValueError(detail) + else: + tr = np.ones(df.shape[0]) * np.inf + + t = np.vstack([tl, tr]).T + + return self.fit(x=x, c=c, n=n, t=t, **fit_options) + + def fit_from_ecdf(self, x, F): + model = Parametric(self, "given ecdf", None, False, False, False) + res = mpp_from_ecfd(self, x, F) + model.dist = self + model.params = np.array(res["params"]) + model.support = self.support + + return model + + def fit_from_non_parametric(self, non_parametric_model): + x, F = non_parametric_model.x, 1 - non_parametric_model.R + return self.fit_from_ecdf(x, F) + + def fit_from_surpyval_data( + self, + surv_data, + how="MLE", + offset=False, + zi=False, + lfp=False, fixed=None, heuristic="Turnbull", init=[], @@ -454,134 +906,11 @@ def fit( alpha: 10.363725328793413 beta: 4.9886821457305865 """ - - if offset and (self.support[0] != 0): - # self.name in ['Normal', 'Beta', 'Uniform', 'Gumbel', 'Logistic']: - detail = "{} distribution cannot be offset".format(self.name) - raise ValueError(detail) - - if how not in PARA_METHODS: - raise ValueError('"how" must be one of: ' + str(PARA_METHODS)) - - if how == "MPP" and self.name == "ExpoWeibull": - detail = ( - "ExpoWeibull distribution does not work" - + " with probability plot fitting" - ) - raise ValueError(detail) - - if t is not None and how == "MSE": - detail = "Mean square error doesn't yet support tuncation" - raise NotImplementedError(detail) - - if t is not None and how == "MOM": - detail = "Maximum product spacing doesn't support tuncation" - raise ValueError(detail) - - if (lfp or zi) & (how != "MLE"): - detail = ( - "Limited failure or zero-inflated models" - + " can only be made with MLE" - ) - raise ValueError(detail) - - if zi & (self.support[0] != 0): - detail = ( - "zero-inflated models can only work" - + "with models starting at 0" - ) - raise ValueError() - - surv_data = surpyval.xcnt_handler( - x=x, - c=c, - n=n, - t=t, - tl=tl, - tr=tr, - xl=xl, - xr=xr, - as_surpyval_dataset=True, - ) x, c, n, t = surv_data.x, surv_data.c, surv_data.n, surv_data.t - - if (c == 1).all(): - raise ValueError("Cannot have only right censored data") - - if (c == -1).all(): - raise ValueError("Cannot have only left censored data") - - if surpyval.utils.check_no_censoring(c) and (how == "MOM"): - raise ValueError("Method of moments doesn't support censoring") - - if ( - (surpyval.utils.no_left_or_int(c)) - and (how == "MPP") - and (not heuristic == "Turnbull") - ): - detail = ( - "Probability plotting estimation with left or " - + "interval censoring only works with Turnbull heuristic" - ) - raise ValueError() - - if ( - (heuristic == "Turnbull") - and (not ((-1 in c) or (2 in c))) - and ((~np.isfinite(t[:, 1])).any()) - ): - # The Turnbull method is extremely memory intensive. - # So if no left or interval censoring and no right-truncation - # then this is equivalent. - heuristic = turnbull_estimator - - if (not offset) & (not zi): - detail_template = """ - Some of your data is outside support of distribution, observed - values must be within [{lower}, {upper}]. - - Are some of your observed values 0, -Inf, or Inf? - """ - - if x.ndim == 2: - if ((x[:, 0] <= self.support[0]) & (c == 0)).any(): - detail = detail_template.format( - lower=self.support[0], upper=self.support[1] - ) - raise ValueError(detail) - elif ((x[:, 1] >= self.support[1]) & (c == 0)).any(): - detail = detail_template.format( - lower=self.support[0], upper=self.support[1] - ) - raise ValueError(detail) - else: - if ((x <= self.support[0]) & (c == 0)).any(): - detail = detail_template.format( - lower=self.support[0], upper=self.support[1] - ) - raise ValueError(detail) - elif ((x >= self.support[1]) & (c == 0)).any(): - detail = detail_template.format( - lower=self.support[0], upper=self.support[1] - ) - raise ValueError(detail) - # Unpack the truncation tl = t[:, 0] tr = t[:, 1] - if (tl[0] != tl).any() and how == "MPS": - raise ValueError( - "Left truncated value can only be single number \ - when using MPS" - ) - - if (tr[0] != tr).any() and how == "MPS": - raise ValueError( - "Right truncated value can only be single number \ - when using MPS" - ) - # Ensure truncation values move to edge where support is not # -np.inf to np.inf if np.isfinite(self.support[0]): @@ -590,6 +919,11 @@ def fit( if np.isfinite(self.support[1]): tr = np.where(tl > self.support[1], self.support[1], tr) + # Validate inputs + self._validate_fit_inputs( + surv_data, how, offset, lfp, zi, heuristic, turnbull_estimator + ) + # Passed checks data = {"x": x, "c": c, "n": n, "t": t} @@ -600,6 +934,8 @@ def fit( if how == "MPS": # Need to set the scalar truncation values # if the MPS method is used. + # since it has already been checked that they are all the same + # we need only get the first item of each truncation array. model.tl = tl[0] model.tr = tr[0] @@ -736,149 +1072,6 @@ def fit( return model - def fit_from_df( - self, - df, - x=None, - c=None, - n=None, - xl=None, - xr=None, - tl=None, - tr=None, - **fit_options - ): - r""" - The central feature to SurPyval's capability. This function aimed to - have an API to mimic the simplicity of the scipy API. That is, to use - a simple :code:`fit()` call, with as many or as few parameters as - is needed. - - Parameters - ---------- - - df : DataFrame - DataFrame of data to be used to create surpyval model - - x : string, optional - column name for the column in df containing the variable data. - If not provided must provide both xl and xr. - - c : string, optional - column name for the column in df containing the censor flag of x. - If not provided assumes all values of x are observed. - - n : string, optional - column name in for the column in df containing the counts of x. - If not provided assumes each x is one observation. - - tl : string or scalar, optional - If string, column name in for the column in df containing the left - truncation data. If scalar assumes each x is left truncated by - that value. If not provided assumes x is not left truncated. - - tr : string or scalar, optional - If string, column name in for the column in df containing the - right truncation data. If scalar assumes each x is right truncated - by that value. If not provided assumes x is not right truncated. - - xl : string, optional - column name for the column in df containing the left interval for - interval censored data. If left interval is -Inf, assumes left - censored. If xl[i] == xr[i] assumes observed. Cannot be provided - with x, must be provided with xr. - - xr : string, optional - column name for the column in df containing the right interval - for interval censored data. If right interval is Inf, assumes - right censored. If xl[i] == xr[i] assumes observed. Cannot be - provided with x, must be provided with xl. - - fit_options : dict, optional - dictionary of fit options that will be passed to the :code:`fit` - method, see that method for options. - - Returns - ------- - - model : Parametric - A parametric model with the fitted parameters and methods for - all functions of the distribution using the fitted parameters. - - - Examples - -------- - >>> import surpyval as surv - >>> df = surv.datasets.BoforsSteel.df - >>> model = surv.Weibull.fit_from_df(df, x='x', n='n', offset=True) - >>> print(model) - Parametric SurPyval Model - ========================= - Distribution : Weibull - Fitted by : MLE - Offset (gamma) : 39.76562962867477 - Parameters : - alpha: 7.141925216146524 - beta: 2.6204524040137844 - """ - - if not type(df) == pd.DataFrame: - raise ValueError("df must be a pandas DataFrame") - - if (x is not None) and ((xl is not None) or (xr is not None)): - raise ValueError("Cannot use `x` and (`xl` and `xr`) together") - - if x is not None: - x = df[x].astype(float) - else: - xl = df[xl].astype(float) - xr = df[xr].astype(float) - x = np.vstack([xl, xr]).T - - if c is not None: - c = df[c].values.astype(int) - - if n is not None: - n = df[n].values.astype(int) - - if tl is not None: - if type(tl) == str: - tl = df[tl].values.astype(float) - elif np.isscalar(tl): - tl = (np.ones(df.shape[0]) * tl).astype(float) - else: - raise ValueError("`tl` must be scalar or column label string") - else: - tl = np.ones(df.shape[0]) * -np.inf - - if tr is not None: - if type(tr) == str: - tr = df[tr].values.astype(float) - elif np.isscalar(tr): - tr = (np.ones(df.shape[0]) * tr).astype(float) - else: - detail = "`tr` must be scalar or a column label string" - raise ValueError(detail) - else: - tr = np.ones(df.shape[0]) * np.inf - - t = np.vstack([tl, tr]).T - - return self.fit(x=x, c=c, n=n, t=t, **fit_options) - - def fit_from_ecdf(self, x, F): - model = Parametric(self, "given ecdf", None, False, False, False) - res = mpp_from_ecfd(self, x, F) - model.dist = self - model.params = np.array(res["params"]) - model.support = self.support - - return model - - def fit_from_non_parametric(self, non_parametric_model): - x, F = non_parametric_model.x, 1 - non_parametric_model.R - return self.fit_from_ecdf(x, F) - def from_params(self, params, gamma=None, p=None, f0=None): r""" diff --git a/surpyval/utils/surpyval_data.py b/surpyval/utils/surpyval_data.py index 653ebeb..93df31a 100644 --- a/surpyval/utils/surpyval_data.py +++ b/surpyval/utils/surpyval_data.py @@ -4,8 +4,11 @@ class SurpyvalData: def __init__(self, x, c, n, t): self.x, self.c, self.n, self.t = x, c, n, t + self.tl = self.t[:, 0] + self.tr = self.t[:, 1] self.x_min, self.x_max = np.min(x), np.max(x) self.split_to_observation_types() + self._index = 0 def split_to_observation_types(self): self.x_o, self.n_o = self._split_by_mask(self.c == 0) @@ -30,6 +33,30 @@ def _split_by_mask(self, mask, x=None): x = self.x[mask] if self.x.ndim == 1 else self.x[mask, 0] return x, self.n[mask] + def __getitem__(self, index): + return SurpyvalData( + self.x[index], + self.c[index], + self.n[index], + self.t[index], + ) + + def __iter__(self): + return self + + def __next__(self): + if self._index < len(self.x): + result = ( + self.x[self._index], + self.c[self._index], + self.n[self._index], + self.t[self._index], + ) + self._index += 1 + return result + else: + raise StopIteration + def __repr__(self): return f""" SurpyvalData(\nx={self.x},\nc={self.c},\nn={self.n},\nt={self.t})