In [None]:
%load_ext autoreload
%autoreload 2
from functools import partial
from copy import deepcopy
from survlimepy import SurvLimeExplainer
from survlimepy.utils.neighbours_generator import NeighboursGenerator
from survlimepy.load_datasets import RandomSurvivalData
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.nonparametric import nelson_aalen_estimator
import sklearn
from sklearn.utils import check_random_state
import numpy as np
import cvxpy as cp
from cvxpy.atoms.affine.binary_operators import MulExpression as MulExp

In [None]:
# Simulation settings for the synthetic survival data set.
n_points = 500
# 13-feature alternative: true_coef = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
true_coef = [1, 1]
n_features = len(true_coef)
r = 1
# 13-feature alternative: center = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
center = [0, 0]
prob_event = 0.9
lambda_weibull = 10**(-6)
v_weibull = 2

# Collect the generator arguments in one place so they are easy to inspect.
simulation_settings = dict(
    center=center,
    radius=r,
    coefficients=true_coef,
    prob_event=prob_event,
    lambda_weibull=lambda_weibull,
    v_weibull=v_weibull,
    time_cap=None,
)
rsd = RandomSurvivalData(**simulation_settings)

# Draw the training sample and pack (event, time) pairs into the structured
# array that is later passed to the Cox model's `fit`.
X, time_to_event, delta = rsd.random_survival_data(num_points=n_points)
z = list(zip(delta, time_to_event))
y = np.array(z, dtype=[("delta", np.bool_), ("time_to_event", np.float32)])
total_row_train = X.shape[0]
print('total_row_train:', total_row_train)
# np.unique already returns its result in ascending order.
unique_times = np.unique(time_to_event)

In [None]:
# Fit a Cox model on the simulated training data and show its coefficients.
cox = CoxPHSurvivalAnalysis().fit(X, y)
print(cox.coef_)

In [None]:
# SurvLime for COX: explain the Cox model at the point `x_new`.
num_samples = 1000
max_difference_time_allowed = None
max_hazard_value_allowed = None
x_new = center

# Events and times are pulled out of the structured array field by field.
explainer_cox = SurvLimeExplainer(
    training_features=X,
    training_events=list(y["delta"]),
    training_times=list(y["time_to_event"]),
    model_output_times=cox.event_times_,
    sample_around_instance=True,
    random_state=10,
)

explain_settings = dict(
    num_samples=num_samples,
    max_difference_time_allowed=max_difference_time_allowed,
    max_hazard_value_allowed=max_hazard_value_allowed,
    verbose=False,
)
b_cox = explainer_cox.explain_instance(
    data_row=x_new,
    predict_fn=cox.predict_cumulative_hazard_function,
    **explain_settings,
)

print(b_cox)

In [None]:
# Kernel
# The kernel width is 0.75 times the square root of the number of entries in
# `x_new`; it is bound into `custom_kernel_width` further down in this cell.
kernel_width = np.sqrt(len(x_new)) * 0.75
def custom_kernel_width(d: np.ndarray, kernel_width: float) -> np.ndarray:
    """Turn distances into kernel weights.

    Computes ``sqrt(exp(-d**2 / kernel_width**2))`` element-wise, so a distance
    of zero maps to a weight of one and larger distances map to smaller weights.

    Args:
        d: distances to convert.
        kernel_width: scale of the kernel; larger values decay more slowly.

    Returns:
        Weights with the same shape as ``d``.
    """
    scaled_sq_distance = -(d**2) / kernel_width**2
    return np.sqrt(np.exp(scaled_sq_distance))

# Bind the width so that `kernel_fn` only needs the distances as its argument.
kernel_fn = partial(custom_kernel_width, kernel_width=kernel_width)

In [None]:
# Generate the neighbours around the point being explained.
# The instance is turned into a single-row 2-D array before it is handed over.
data_point = np.asarray(x_new).reshape(1, -1)
neighbours_rng = check_random_state(10)
neighbours_generator = NeighboursGenerator(
    training_features=X,
    data_row=data_point,
    random_state=neighbours_rng,
)
scaled_data = neighbours_generator.generate_neighbours(
    num_samples=num_samples,
    sample_around_instance=True,
)
print(scaled_data)

In [None]:
def compute_nelson_aalen_estimator(event: np.ndarray, time: np.ndarray) -> np.ndarray:
    """Return the Nelson-Aalen estimate as a column vector.

    Args:
        event: event indicators, forwarded to ``nelson_aalen_estimator``.
        time: observed times, forwarded to ``nelson_aalen_estimator``.

    Returns:
        Array of shape ``(m, 1)`` holding the second element returned by
        ``nelson_aalen_estimator``, where ``m`` is the length of that element.
    """
    nelson_aalen = nelson_aalen_estimator(event, time)
    # Only the second element of the returned pair is used here.
    H0 = np.asarray(nelson_aalen[1])
    # The reshape method avoids the `newshape=` keyword of np.reshape, which
    # newer NumPy releases deprecate.
    return H0.reshape(-1, 1)

In [None]:
def _print_summary(label: str, values: np.ndarray) -> None:
    """Print the shape, maximum and minimum of ``values`` under ``label``."""
    print(f"{label}.shape:", values.shape)
    print(f"max {label}:", np.max(values))
    print(f"min {label}:", np.min(values))


# Evaluate the weighted objective
#     sum_i w_i * sum_j logs_ij**2 * (LnH_ij - LnH0_j - x_i . b)**2 * delta_t_j
# either at a fixed numpy vector `b` (diagnostic prints only) or, when `b` is
# a cvxpy variable, minimise it over `b`.
#
# `b` needs one row per feature of `x_new`, otherwise `scaled_data @ b` fails.
# To optimise instead of evaluating, use: b = cp.Variable((len(x_new), 1))
b = np.array(true_coef, dtype=float).reshape(-1, 1)
training_events = [tp[0] for tp in y]
training_times = [tp[1] for tp in y]
kernel_distance = "euclidean"
epsilon = 10 ** (-6)
num_features = len(x_new)
if b.shape != (num_features, 1):
    raise ValueError(
        f"b has shape {b.shape} but x_new has {num_features} features"
    )

# Build the time grid from the times stored in `y` (the values the Cox model
# and the Nelson-Aalen estimator are given), so that `m` agrees with them.
unique_times_to_event = np.unique(training_times)
m = unique_times_to_event.shape[0]
print("m: ", m)

FN_pred = cox.predict_cumulative_hazard_function(scaled_data, return_array=True)
_print_summary("FN_pred", FN_pred)
if FN_pred.shape != (num_samples, m):
    raise ValueError(
        f"predicted hazards have shape {FN_pred.shape}, expected {(num_samples, m)}"
    )

H0 = compute_nelson_aalen_estimator(training_events, training_times)
_print_summary("H0", H0)

distances = sklearn.metrics.pairwise_distances(
    scaled_data, data_point, metric=kernel_distance
).ravel()

# np.clip returns a new array, so FN_pred itself is left untouched.
H_score = np.clip(a=FN_pred, a_min=None, a_max=10)
_print_summary("H_score", H_score)
log_H = np.log(H_score + epsilon)
_print_summary("log_H", log_H)
log_correction = np.divide(H_score, log_H)
_print_summary("log_correction", log_correction)
H = np.asarray(H_score).reshape(num_samples, m)
_print_summary("H", H)
LnH = np.log(H + epsilon)
_print_summary("LnH", LnH)
LnH0 = np.log(H0 + epsilon)
_print_summary("LnH0", LnH0)
logs = log_correction.reshape(num_samples, m)
_print_summary("logs", logs)
weights = kernel_fn(distances)
_print_summary("weights", weights)
w = weights.reshape(num_samples, 1)
_print_summary("w", w)

# Time differences: append one extra point `epsilon` after the last time so
# that every grid time has a following interval.
t = np.empty(shape=(m + 1, 1))
t[:m, 0] = unique_times_to_event
t[m, 0] = t[m - 1, 0] + epsilon
delta_t = np.diff(t, axis=0)
# The cap is optional; comparing against None would raise a TypeError.
if max_difference_time_allowed is not None:
    delta_t = np.minimum(delta_t, max_difference_time_allowed)
_print_summary("delta_t", delta_t)

# Matrices to produce the proper sizes
is_numpy = isinstance(b, np.ndarray)

ones_N = np.ones(shape=(num_samples, 1))
ones_m_1 = np.ones(shape=(m, 1))
B = np.dot(ones_N, LnH0.T)  # LnH0 repeated on every sample row
C = LnH - B
print("b.shape", b.shape)
Z = scaled_data @ b
if is_numpy:
    _print_summary("Z", Z)
D = Z @ ones_m_1.T  # linear predictor repeated on every time column
E = C - D
if is_numpy:
    _print_summary("E", E)
V_sq = np.square(log_correction)
if is_numpy:
    _print_summary("V_sq", V_sq)
E_norm = np.power(E, 2) if is_numpy else cp.power(E, 2)
if is_numpy:
    _print_summary("E_norm", E_norm)
F = np.multiply(E_norm, V_sq) if is_numpy else cp.multiply(E_norm, V_sq)
if is_numpy:
    _print_summary("F", F)
G = F @ delta_t
if is_numpy:
    _print_summary("G", G)
funct = G.T @ weights
if is_numpy:
    print("funct:", funct)
else:
    objective = cp.Minimize(funct)
    prob = cp.Problem(objective)
    result = prob.solve(verbose=True)
    print(b.value)

In [None]:
# Display the dimensions of the generated neighbours array.
scaled_data.shape