
Commit

major revamp of MLE using SurpyvalData and better loglike functions for each class of observation and truncation.
derrynknife committed Apr 12, 2023
1 parent 4a76b2b commit 0e484ac
Showing 3 changed files with 161 additions and 146 deletions.
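
This commit introduces one log-likelihood method per class of observation (ll_observed, ll_right_censored, ll_left_censored, ll_interval_or_truncated), which _neg_ll_func combines over a SurpyvalData container, in place of the monolithic neg_ll (still present in the diff but flagged for replacement). The sketch below illustrates that decomposition with a plain SciPy Weibull; the function name, parameterisation and data layout are assumptions made for illustration only and are not surpyval's API.

import numpy as np
from scipy.stats import weibull_min


def neg_log_likelihood(params, x_o, x_r, x_l, x_i, t):
    """Hypothetical per-class NLL; x_i and t are (k, 2) interval arrays."""
    shape, scale = params
    dist = weibull_min(shape, scale=scale)

    ll = 0.0
    # Exactly observed failures contribute log f(x)
    ll += np.sum(dist.logpdf(x_o))
    # Right censored observations contribute log S(x)
    ll += np.sum(dist.logsf(x_r))
    # Left censored observations contribute log F(x)
    ll += np.sum(dist.logcdf(x_l))
    # Interval censored observations contribute log(F(xr) - F(xl))
    ll += np.sum(np.log(dist.cdf(x_i[:, 1]) - dist.cdf(x_i[:, 0])))
    # Truncated observations are renormalised by the probability mass in
    # the truncation window, i.e. subtract log(F(tr) - F(tl))
    ll -= np.sum(np.log(dist.cdf(t[:, 1]) - dist.cdf(t[:, 0])))
    return -ll


# Two exact failures, one right censored at 10, no intervals or truncation.
empty_intervals = np.empty((0, 2))
print(
    neg_log_likelihood(
        (1.5, 10.0),
        np.array([4.0, 7.5]),
        np.array([10.0]),
        np.array([]),
        empty_intervals,
        empty_intervals,
    )
)
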
204 changes: 68 additions & 136 deletions surpyval/parametric/fitters/mle.py
@@ -1,51 +1,29 @@
import copy
import sys

from autograd import hessian, jacobian
from autograd.numpy.linalg import inv
from scipy.optimize import approx_fprime, minimize
from scipy.optimize import minimize

from surpyval import np


def _create_censor_flags(x_mle, gamma, c, dist):
if 2 in c:
l_flag = x_mle[:, 0] <= np.min([gamma, dist.support[0]])
r_flag = x_mle[:, 1] >= dist.support[1]
mask = np.vstack([l_flag, r_flag]).T
inf_c_flags = (mask).astype(int)
x_mle[(c == 2).reshape(-1, 1) & mask] = 1
else:
inf_c_flags = np.zeros_like(x_mle)

return inf_c_flags, x_mle


def mle(model):
"""
Maximum Likelihood Estimation (MLE)
"""
dist = model.dist
x, c, n, t = (
model.data["x"],
model.data["c"],
model.data["n"],
model.data["t"],
)
# Function that adds in any fixed parameters
const = model.fitting_info["const"]
trans = model.fitting_info["transform"]
# Inverse transform function for parameters. i.e. from (None, None) to
# correct bounded values
inv_trans = model.fitting_info["inv_trans"]
# Initial guess
init = model.fitting_info["init"]
fixed_idx = model.fitting_info["fixed_idx"]
offset = model.offset
lfp = model.lfp
zi = model.zi

if hasattr(dist, "mle"):
return dist.mle(
x, c, n, t, const, trans, inv_trans, init, fixed_idx, offset
)
# Offset, Limited Failure Population, Zero Inflated logic.
offset, lfp, zi = model.offset, model.lfp, model.zi

if hasattr(model.dist, "mle"):
return model.dist.mle(model.surv_data)

results = {}

@@ -54,95 +32,57 @@ def mle(model):
doesn't fail. Autograd fails if it encounters any inf, nan or -inf,
even if they don't affect the gradient, so this is a must for autograd.
"""
t_flags = np.ones_like(t)
t_mle = copy.copy(t)
# Create flags to indicate where the truncation values are infinite
t_flags[:, 0] = np.where(np.isfinite(t[:, 0]), 1, 0)
t_flags[:, 1] = np.where(np.isfinite(t[:, 1]), 1, 0)
# Convert the infinite values to a finite value to ensure
# the autodiff functions don't fail
t_mle[:, 0] = np.where(t_flags[:, 0] == 1, t[:, 0], 1)
t_mle[:, 1] = np.where(t_flags[:, 1] == 1, t[:, 1], 1)

results["t_flags"] = t_flags
results["t_mle"] = t_mle

# Create the objective function

def fun(
params, offset=False, lfp=False, zi=False, transform=True, gamma=0.0
params,
offset=False,
lfp=False,
zi=False,
transform=True,
gamma=0,
f0=0,
p=1,
):
x_mle = np.copy(x)

# Transform parameters from the (-Inf, Inf) range back
# to their correctly bounded values
if transform:
params = inv_trans(const(params))

# Unpack offset, zi, lfp parameters
if offset:
gamma = params[0]
params = params[1:]
else:
# Use the assumed value
pass
gamma, *params = params

if zi:
f0 = params[-1]
params = params[0:-1]
else:
f0 = 0.0
*params, f0 = params

if lfp:
p = params[-1]
params = params[0:-1]
else:
p = 1.0

inf_c_flags, x_mle = _create_censor_flags(x_mle, gamma, c, dist)
return dist.neg_ll(
x_mle, c, n, inf_c_flags, t_mle, t_flags, gamma, p, f0, *params
)
*params, p = params

return model.dist._neg_ll_func(model.surv_data, *params, gamma, f0, p)

old_err_state = np.seterr(all="ignore")
use_initial = False

if zi:

def jac(x, offset, lfp, zi, transform):
return approx_fprime(
x,
fun,
np.sqrt(np.finfo(float).eps),
offset,
lfp,
zi,
transform,
)

hess = None
else:
jac = jacobian(fun)
hess = hessian(fun)

res = minimize(
fun,
init,
args=(offset, lfp, zi, True),
method="Newton-CG",
jac=jac,
hess=hess,
)

if (res.success is False) or (np.isnan(res.x).any()):
jac = jacobian(fun)
hess = hessian(fun)

# Try the easiest to most complex optimisations
for method, jac_i, hess_i in [
("Nelder-Mead", None, None),
("BFGS", None, None),
("TNC", jac, None),
("Newton-CG", jac, hess),
]:
res = minimize(
fun, init, args=(offset, lfp, zi, True), method="BFGS", jac=jac
fun,
init,
args=(offset, lfp, zi, True),
method=method,
jac=jac_i,
hess=hess_i,
)

if (res.success is False) or (np.isnan(res.x).any()):
res = minimize(fun, init, args=(offset, lfp, zi, True))

if "Desired error " in res["message"]:
res_tmp = minimize(
fun, res.x, args=(offset, lfp, zi, True), method="TNC", jac=jac
)
if res_tmp.success:
res = res_tmp
if res.success:
break

if "Desired error not necessarily" in res["message"]:
print(
@@ -155,60 +95,52 @@ def jac(x, offset, lfp, zi, transform):

elif (not res.success) | (np.isnan(res.x).any()):
print(
"MLE Failed: Try making the values of the data closer to "
"MLE Failed, using MPP results instead. "
+ "Try making the values of the data closer to "
+ "1 by dividing or multiplying by some constant."
+ "\n\nAlternately try setting the `init` keyword in the `fit()`"
+ " method to a value you believe is closer."
+ "A good way to do this is to set any shape parameter to 1. "
+ "and any scale parameter to be the mean of the data "
+ "(or it's inverse)"
+ "\n\nModel returned with inital guesses",
+ "\n\nModel returned with inital guesses (MPP)",
file=sys.stderr,
)

use_initial = True

if use_initial:
p_hat = inv_trans(const(init))
params = inv_trans(const(init))
else:
p_hat = inv_trans(const(res.x))
params = inv_trans(const(res.x))

if offset:
results["gamma"] = p_hat[0]
params = p_hat[1:]
parameters_for_hessian = copy.copy(params)
gamma = params[0]
params = params[1:]
else:
results["gamma"] = 0
params = p_hat
parameters_for_hessian = copy.copy(params)
gamma = 0.0

results["gamma"] = gamma

if zi:
results["f0"] = params[-1]
f0 = params[-1]
params = params[0:-1]
else:
results["f0"] = 0.0
params = params
f0 = 0.0
results["f0"] = f0

if lfp:
results["p"] = params[-1]
results["params"] = params[0:-1]
p = params[-1]
params = params[0:-1]
else:
results["p"] = 1.0
results["params"] = params

try:
if zi or lfp:
results["hess_inv"] = None
else:
results["hess_inv"] = inv(
hess(
parameters_for_hessian,
*(False, lfp, zi, False, results["gamma"])
)
)
except np.linalg.LinAlgError:
results["hess_inv"] = None
p = 1.0

results["p"] = p
results["params"] = params
# Do not account for variation of gamma, f0, p in confidence bounds.
results["hess_inv"] = inv(
hess(params, *(False, False, False, False, gamma, f0, p))
)
results["_neg_ll"] = res["fun"]
results["res"] = res

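
The other notable change in mle.py is the optimiser fallback loop: try cheap, derivative-free methods first and escalate to gradient- and Hessian-based ones only if the previous attempt fails or returns NaNs. Because the diff view above has lost its indentation, the self-contained sketch below restates that pattern; the helper name fit_with_fallback and the toy objective are illustrative assumptions, not surpyval code.

import numpy as np
from autograd import hessian, jacobian
from scipy.optimize import minimize


def fit_with_fallback(fun, init):
    # The objective must be written with autograd-compatible operations
    # so that jacobian/hessian can differentiate it.
    jac = jacobian(fun)
    hess = hessian(fun)
    res = None
    # Try the cheapest optimisers first, escalating only on failure.
    for method, jac_i, hess_i in [
        ("Nelder-Mead", None, None),
        ("BFGS", None, None),
        ("TNC", jac, None),
        ("Newton-CG", jac, hess),
    ]:
        res = minimize(fun, init, method=method, jac=jac_i, hess=hess_i)
        if res.success and not np.isnan(res.x).any():
            break
    return res


# Toy usage: minimise the Rosenbrock function from a poor starting point.
def rosenbrock(p):
    return (1.0 - p[0]) ** 2 + 100.0 * (p[1] - p[0] ** 2) ** 2


print(fit_with_fallback(rosenbrock, np.array([-1.2, 1.0])).x)
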
87 changes: 85 additions & 2 deletions surpyval/parametric/parametric_fitter.py
@@ -2,9 +2,11 @@

import pandas as pd
from scipy.integrate import quad
from scipy.special import expit

import surpyval
from surpyval import np
from surpyval.utils import check_x_not_empty

from ..nonparametric import plotting_positions as pp
from .fitters import bounds_convert, fix_idx_and_function
@@ -121,7 +123,78 @@ def like_t(self, t, t_flags, *params):
t_denom = tr_denom - tl_denom
return t_denom

@check_x_not_empty
def ll_observed(self, x, n, *params):
*params, gamma, f0, p = params
x = x - gamma
n_zeros = np.sum(n[x == 0])
zero_weight = n_zeros * np.log(f0) if n_zeros != 0 else 0
non_zero_mask = x != 0
N = np.sum(n[non_zero_mask])
return (
(n[non_zero_mask] * self.log_df(x[non_zero_mask], *params)).sum()
+ zero_weight
+ N * np.log(p - f0)
)

@check_x_not_empty
def ll_right_censored(self, x, n, *params):
*params, gamma, f0, p = params
x = x - gamma
if p == 1:
return np.sum(n * (np.log(1 - f0) + self.log_sf(x, *params)))
else:
F = self.ff(x, *params)
# Also could be:
# np.sum(n * np.log((1 - p + (p - f0)*self.sf(x, *params))))
return np.sum(n * np.log(1 - f0 - (p - f0) * F))

@check_x_not_empty
def ll_left_censored(self, x, n, *params):
*params, gamma, f0, p = params
x = x - gamma
if f0 == 1:
return np.sum(n * self.log_ff(x, *params)) + n.sum() * np.log(p)
else:
return np.sum(n * np.log(f0 + (p - f0) * self.ff(x, *params)))

@check_x_not_empty
def ll_interval_or_truncated(self, xl, xr, n, *params):
*params, gamma, f0, p = params
xr = xr - gamma
xl = xl - gamma
right = np.where(np.isfinite(xr), self.ff(xr, *params), 1)
left = np.where(np.isfinite(xl), self.ff(xl, *params), 0)
return np.sum(n * np.log(right - left)) + n.sum() * np.log(p - f0)

def parameter_transform(self, x_min, params):
*params, gamma, f0, p = params
p = expit(p)
f0 = expit(f0)
gamma = x_min - np.exp(gamma) if gamma < 0 else x_min - 1 - gamma
params = self._parameter_transform(*params)
return (*params, gamma, f0, p)

def _neg_ll_func(self, data, *params):
# params = self.parameter_transform(data.x_min, *params)
return -(
self.ll_observed(data.x_o, data.n_o, *params)
+ self.ll_right_censored(data.x_r, data.n_r, *params)
+ self.ll_left_censored(data.x_l, data.n_l, *params)
+ self.ll_interval_or_truncated(
data.x_il, data.x_ir, data.n_i, *params
)
- self.ll_interval_or_truncated(
data.x_tl, data.x_tr, data.n_t, *params
)
)

def neg_ll(self, x, c, n, inf_c_flags, t, t_flags, gamma, p, f0, *params):
"""
This is absolutely awful and needs to be replaced with the
above function.
"""

x = copy(x) - gamma

if 2 in c:
@@ -420,9 +493,18 @@ def fit(
)
raise ValueError()

x, c, n, t = surpyval.xcnt_handler(
x=x, c=c, n=n, t=t, tl=tl, tr=tr, xl=xl, xr=xr
surv_data = surpyval.xcnt_handler(
x=x,
c=c,
n=n,
t=t,
tl=tl,
tr=tr,
xl=xl,
xr=xr,
as_surpyval_dataset=True,
)
x, c, n, t = surv_data.x, surv_data.c, surv_data.n, surv_data.t

if (c == 1).all():
raise ValueError("Cannot have only right censored data")
@@ -513,6 +595,7 @@ def fit(
data = {"x": x, "c": c, "n": n, "t": t}

model = Parametric(self, how, data, offset, lfp, zi)
model.surv_data = surv_data
fitting_info = {}

if how == "MPS":

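For orientation, here is a hedged end-to-end example of the public fitting entry point that this commit rewires onto the SurpyvalData path. The data values are made up; the keyword names (x, c, how) follow the fit() signature visible in the diff, and c=1 marks a right censored observation in surpyval's convention.

import surpyval

x = [17.88, 28.92, 33.0, 41.52, 42.12, 45.6, 48.4, 51.84, 51.96, 54.12]
c = [0, 0, 0, 0, 1, 0, 1, 0, 0, 1]  # 0 = observed failure, 1 = right censored

model = surpyval.Weibull.fit(x=x, c=c, how="MLE")
print(model.params)  # fitted (alpha, beta), i.e. scale and shape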