In [6]:
# https://sci-hub.se/10.1002/sim.1203
# https://journals.sagepub.com/doi/pdf/10.1177/1536867x0900900206

# Dataset

In [7]:
from lifelines.datasets import load_rossi

data = load_rossi()
data.dropna(inplace=True)
print(data.shape)
data.head()

(432, 9)


Unnamed: 0,week,arrest,fin,age,race,wexp,mar,paro,prio
0,20,1,0,27,1,0,0,1,3
1,17,1,0,18,1,0,0,1,8
2,25,1,0,19,0,1,0,1,13
3,52,0,1,23,1,1,1,1,1
4,52,0,0,19,0,1,0,1,3


In [8]:
event_col = "arrest"
duration_col = "week"

X = data.drop(columns=[event_col, duration_col])
X.shape

(432, 7)

In [9]:
from sklearn.model_selection import train_test_split

train_idx, test_idx = train_test_split(
    range(data.shape[0]), test_size=0.15, random_state=42, stratify=data[event_col]
)
data_train, data_test = data.iloc[train_idx], data.iloc[test_idx]
X_train, X_test = X.iloc[train_idx].values, X.iloc[test_idx].values
data_train.shape, data_test.shape

((367, 9), (65, 9))

# Pre-processing 

In [10]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Knots

In [11]:
import numpy as np

log_t_test = np.log(data_test[duration_col])
log_t_train = np.log(data_train[duration_col])
min(log_t_train), max(log_t_train)

(0.0, 3.9512437185814275)

In [12]:
# Knot locations: Centiles of the distribution of **uncensored** log event times
# - Boundary knots: placed at the 0th and 100th centiles (min and max values)
# - Internal knots: internal knots are placed at the centiles between the min and max   
knots = np.percentile(log_t_train[data_train[event_col] == 1], [0, 25, 50, 75, 100])
min(knots), max(knots)

(0.0, 3.9512437185814275)

# Splines

In [13]:
def relu(x):
    return max(x, 0)


def spline_basis(x, k_j, k_min, k_max, derivative=False):
    """Computes the basis function S(x; k_j)."""
    # Scaling coefficient 
    s = (k_max - k_j) / (k_max - k_min)

    if derivative:
        # Derivative of the spline basis function
        return 3 * relu(x - k_j) ** 2 - 3 * s * relu(x - k_min) ** 2 - 3 * (1 - s) * relu(x - k_max) ** 2

    # Spline basis function 
    return relu(x - k_j) ** 3 - s * relu(x - k_min) ** 3 - (1 - s) * relu(x - k_max) ** 3


def spline_design_matrix(ln_t, knots):
    """Computes the spline function s(x; γ, k)."""
    # Boundary knots
    k_min, k_max = knots[0], knots[-1]
    # Construct basis functions over internal knots 
    basis = [spline_basis(ln_t, k_j, k_min, k_max) for k_j in knots[1:-1]]
    # Design matrix 
    return np.array([ln_t] + basis)


def spline_derivative_design_matrix(ln_t, knots):
    """Computes the spline function s(x; γ, k)."""
    # Boundary knots
    k_min, k_max = knots[0], knots[-1]
    # Construct basis functions over internal knots 
    basis = [spline_basis(ln_t, k_j, k_min, k_max, derivative=True) for k_j in knots[1:-1]]
    # Design matrix 
    return 1 / np.exp(ln_t) * np.array([1] + basis)


def create_splines(log_t, knots):

    D, dDdt = [], []
    for log_time in log_t:
        D.append(spline_design_matrix(log_time, knots))
        dDdt.append(spline_derivative_design_matrix(log_time, knots))
    
    # Cast to <numpy.ndarray>
    return np.array(D), np.array(dDdt)


D, D_prime = create_splines(log_t_train, knots)
D.shape, D_prime.shape

((367, 4), (367, 4))

# Parameter initialization 

In [14]:
delta_train = data_train[event_col].values[:, None]
delta_test = data_test[event_col].values[:, None]
delta_train.shape

(367, 1)

In [15]:
np.random.seed(42)
beta0 = np.random.random((1, X.shape[1]))
gamma0 = np.random.random((1, D.shape[1]))

# Fitting centralized model

In [16]:
import tensorflow as tf 

In [87]:
# Optimization variables 
gamma_var = tf.Variable(gamma0, dtype=tf.float32)
beta_var = tf.Variable(beta0, dtype=tf.float32)  

# Dataset 
X_tf = tf.cast(X_train, dtype=tf.float32)

# Spline design matrices 
D_tf = tf.cast(D, dtype=tf.float32)
D_prime_tf = tf.cast(D_prime, dtype=tf.float32)

delta_tf = tf.cast(delta_train, dtype=tf.float32)

reg_coef = 0


def neg_log_likelihood():
    phi = D_tf @ tf.transpose(gamma_var) + X_tf @ tf.transpose(beta_var)
    
    ds_dt = tf.clip_by_value(D_prime_tf @ tf.transpose(gamma_var), 1e-8, 1e8) 

    # Log-likelihood function for each data sample (N x 1)
    log_likelihood = delta_tf * (phi + tf.math.log(ds_dt)) - tf.math.exp(phi)

    # Regularization
    reg_gamma = reg_coef * tf.norm(gamma_var)
    reg_beta = reg_coef * tf.norm(beta_var)

    return -1.0 * tf.reduce_sum(log_likelihood, axis=0) + reg_gamma + reg_beta


optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

for epoch in tf.range(10):

    with tf.GradientTape() as tape:
        loss_value = neg_log_likelihood()

    # Compute gradients
    gradients = tape.gradient(loss_value, [beta_var, gamma_var])
    # Apply gradients to update gamma and beta
    optimizer.apply_gradients(zip(gradients, [beta_var, gamma_var]))
    
    print("Loss:", float(loss_value.numpy()))

Loss: 2590.5224609375
Loss: 2567.444580078125
Loss: 2544.60791015625
Loss: 2522.0087890625
Loss: 2499.6416015625
Loss: 2477.501708984375
Loss: 2455.5830078125
Loss: 2420.611328125
Loss: 2398.0849609375
Loss: 2376.2333984375


In [88]:
from sksurv.metrics import concordance_index_censored

# Baseline
print(concordance_index_censored(delta_train.astype(bool).squeeze(), np.exp(log_t_train), (X_train @ beta0.T).squeeze()))
print(concordance_index_censored(delta_test.astype(bool).squeeze(), np.exp(log_t_test), (X_test @ beta0.T).squeeze()))

# Fitted model 
print(concordance_index_censored(delta_train.astype(bool).squeeze(), np.exp(log_t_train), (X_train @ beta_var.numpy().T).squeeze()))
print(concordance_index_censored(delta_test.astype(bool).squeeze(), np.exp(log_t_test), (X_test @ beta_var.numpy().T).squeeze()))

(0.41476774046744463, 12742, 17986, 35, 1080)
(0.2907465825446898, 276, 674, 1, 0)
(0.4187985567077333, 12866, 17862, 35, 1080)
(0.30651945320715035, 291, 659, 1, 0)
