In [33]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from xgbsurv.models.utils import transform_back, transform


In [34]:
df = pd.read_csv('/Users/JUSC/Documents/xgbsurv/xgbsurv/tests/simulation_data/survival_simulation_1000.csv')
risks = pd.read_csv('/Users/JUSC/Documents/xgbsurv/xgbsurv/tests/simulation_data/survival_simulation_preds1000.csv')
#X_train, X_test, y_train, y_test = train_test_split(df, risks, test_size=0.33, random_state=42)

In [35]:
def breslow_estimator(log_hazard, time, event):
    #time, event = transform_back(y)
    risk_score = np.exp(log_hazard)

    is_sorted = lambda a: np.all(a[:-1] <= a[1:])

    if is_sorted(time) == False:
        order = np.argsort(time, kind="mergesort")
        time = time[order]
        event = event[order]
        risk_score = risk_score[order]

    uniq_times = np.unique(time)
    idx = np.digitize(time, np.unique(time))
    breaks = np.flatnonzero(np.concatenate(([1], np.diff(idx))))
    # numpy diff nth discrete difference over index, add 1 at the beginning
    # flatnonzero return indices that are nonzero in flattened version
    n_events = np.add.reduceat(event, breaks, axis=0)

    # consider removing zero rows, would this be the right approach?
    risk_matrix = np.unique((np.outer(time,time)>=np.square(time)).astype(int).T, axis=0)
    denominator = np.sum(risk_score[None,:]*risk_matrix,axis=1)[::-1]     

    cum_hazard_baseline = np.cumsum(n_events / denominator)
    baseline_survival = np.exp(-cum_hazard_baseline)
    return uniq_times, cum_hazard_baseline, baseline_survival

In [36]:
def breslow_estimator_loop(    
    predictor: np.array,
    time: np.array,
    event: np.array

):
    exp_predictor: np.array = np.exp(predictor)
    local_risk_set: float = np.sum(exp_predictor)
    event_mask: np.array = event.astype(np.bool_)
    n_unique_events: int = np.unique(time[event_mask]).shape[0]
    cumulative_baseline_hazards: np.array = np.zeros(n_unique_events)
    n_events_counted: int = 0
    local_death_set: int = 0
    accumulated_risk_set: float = 0
    previous_time: float = time[0]

    for _ in range(len(time)):
        sample_time: float = time[_]
        sample_event: int = event[_]
        sample_predictor: float = exp_predictor[_]

        if sample_time > previous_time and local_death_set:
            cumulative_baseline_hazards[n_events_counted] = local_death_set / (
                local_risk_set
            )

            local_death_set = 0
            local_risk_set -= accumulated_risk_set
            accumulated_risk_set = 0
            n_events_counted += 1

        if sample_event:
            local_death_set += 1
        accumulated_risk_set += sample_predictor
        previous_time = sample_time

    cumulative_baseline_hazards[n_events_counted] = local_death_set / (
        local_risk_set
    )

    return (
        np.unique(time[event_mask]),
        np.cumsum(cumulative_baseline_hazards),
    )

In [37]:
log_hazard, time, event = risks.to_numpy(), df.time.to_numpy(), df.event.to_numpy()

In [38]:
res_loop = breslow_estimator_loop( log_hazard[:800], time[:800], event[:800])
res = breslow_estimator(log_hazard, time, event)


In [24]:
# times comparison
times_loop = res_loop[0]
print(times_loop.shape)
times = res[0]
print(times.shape)

(296,)
(1000,)


In [25]:
# hazards comparison
cum_hazard_loop = res_loop[1]
print(cum_hazard_loop.shape)
cum_hazard = res[1]
print(cum_hazard.shape)

(296,)
(1000,)


In [26]:
def get_cumulative_hazard_function(X_train: np.array, 
        X_test: np.array, y_train: np.array, y_test: np.array,
        predictor_train: np.array, predictor_test: np.array
    #self, X: np.array, time: np.array
    ) -> pd.DataFrame:
    # inputs necessary: train_time, train_event, train_preds, 
    time_train, event_train = transform_back(y_train)
    time_test, event_test = transform_back(y_test)
    if np.min(time_test) < 0:
        raise ValueError(
            "Times for survival and cumulative hazard prediction must be greater than or equal to zero."
            + f"Minimum time found was {np.min(time_test)}."
            + "Please remove any times strictly less than zero."
        )
    cumulative_baseline_hazards_times: np.array
    cumulative_baseline_hazards: np.array
    (
        cumulative_baseline_hazards_times,
        cumulative_baseline_hazards,
    ) = breslow_estimator_loop(
        time=time_train, event=event_train, predictor=predictor_train
    )
    cumulative_baseline_hazards = np.concatenate(
        [np.array([0.0]), cumulative_baseline_hazards]
    )
    cumulative_baseline_hazards_times: np.array = np.concatenate(
        [np.array([0.0]), cumulative_baseline_hazards_times]
    )
    cumulative_baseline_hazards: np.array = np.tile(
        A=cumulative_baseline_hazards[
            np.digitize(
                x=time_test, bins=cumulative_baseline_hazards_times, right=False
            )
            - 1
        ],
        reps=X_test.shape[0],
    ).reshape((X_test.shape[0], time_test.shape[0]))
    log_hazards: np.array = (
        np.tile(
            A= predictor_test, #self.predict(X),
            reps=time_test.shape[0],
        )
        .reshape((time_test.shape[0], X_test.shape[0]))
        .T
    )
    cumulative_hazard_function: pd.DataFrame = pd.DataFrame(
        cumulative_baseline_hazards * np.exp(log_hazards),
        columns=time_test,
    )
    return cumulative_hazard_function

In [40]:
X_train = df[['x_1', 'x_2', 'x_3', 'x_4', 'x_5']][:800]
X_test = df[['x_1', 'x_2', 'x_3', 'x_4', 'x_5']][800:]
y_train = transform(time[:800], event[:800])
y_test = transform(time[200:], event[200:])
predictor_train,predictor_test = log_hazard[:800],log_hazard[800:]

get_cumulative_hazard_function(X_train, 
        X_test, y_train, y_test,
        predictor_train, predictor_test
    )

Unnamed: 0,23.001357,35.031683,28.161253,23.634923,23.662349,0.432474,16.520627,31.427819,34.634629,31.020862,...,2.078888,28.222006,13.863866,23.995064,1.573249,34.345211,24.376159,11.159887,32.158197,11.889380
0,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
1,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
2,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
3,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
4,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
196,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
197,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574
198,0.055914,0.055914,0.055914,0.055914,0.07331,0.003479,0.07331,0.07331,1.448213,1.448213,...,0.062513,0.519923,2.044217,2.31035,0.22938,2.31035,0.022355,0.014164,0.022355,0.015574


In [28]:
df.columns

Index(['x_1', 'x_2', 'x_3', 'x_4', 'x_5', 'time', 'event'], dtype='object')

In [32]:
X_train.shape

(800, 5)