In [1]:
# https://sci-hub.se/10.1002/sim.1203
# https://journals.sagepub.com/doi/pdf/10.1177/1536867x0900900206

# Dataset

In [2]:
from lifelines.datasets import load_rossi

data = load_rossi()
data.dropna(inplace=True)
print(data.shape)
data.head()

(432, 9)


Unnamed: 0,week,arrest,fin,age,race,wexp,mar,paro,prio
0,20,1,0,27,1,0,0,1,3
1,17,1,0,18,1,0,0,1,8
2,25,1,0,19,0,1,0,1,13
3,52,0,1,23,1,1,1,1,1
4,52,0,0,19,0,1,0,1,3


In [3]:
event_col = "arrest"
duration_col = "week"

X = data.drop(columns=[event_col, duration_col])
X.shape

(432, 7)

In [4]:
from sklearn.model_selection import train_test_split

train_idx, test_idx = train_test_split(
    range(data.shape[0]), test_size=0.15, random_state=42, stratify=data[event_col]
)
data_train, data_test = data.iloc[train_idx], data.iloc[test_idx]
X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
data_train.shape, data_test.shape

((367, 9), (65, 9))

# Parameter initialization 

In [5]:
import numpy as np 

penalizer = 0.1
n_baseline_knots = 3

initial_point = np.random.random((data.shape[1] + n_baseline_knots - 1))

# Fitting centralized model

In [6]:
from lifelines import CoxPHFitter

fit_options = {"maxiter": 5000}

model = CoxPHFitter(baseline_estimation_method="spline", penalizer=penalizer, n_baseline_knots=n_baseline_knots)
model.fit(
    data, 
    duration_col=duration_col, 
    event_col=event_col, 
    fit_options=fit_options, 
    initial_point=initial_point
)

print(model.score(data_train, scoring_method="concordance_index"))
print(model.score(data_test, scoring_method="concordance_index"))

0.5
0.5






In [7]:
def predict(X, model):
    intercept = model.params_.beta_.values[-1]
    weights = model.params_.beta_.values[:-1]
    return intercept + X @ weights


from sksurv.metrics import concordance_index_censored

y_hat_train = predict(X_train, model)
y_hat_test = predict(X_test, model)

print(concordance_index_censored(data_train[event_col].astype(bool), data_train[duration_col], y_hat_train))
print(concordance_index_censored(data_test[event_col].astype(bool), data_test[duration_col], y_hat_test))

(0.4325325878490394, 13285, 17436, 42, 1080)
(0.29915878023133546, 284, 666, 1, 0)


# Fitting de-centralized model

In [8]:
N_CLIENTS = 3
CLIENT_DATA_IDX = np.array_split(np.arange(data_train.shape[0]), N_CLIENTS)

models = []
for i, idx in enumerate(CLIENT_DATA_IDX):
    models.append(CoxPHFitter(
        baseline_estimation_method="spline", 
        penalizer=penalizer, 
        n_baseline_knots=n_baseline_knots
    ))

In [9]:
import pandas as pd 

global_params = initial_point.copy()

for epoch in range(10):

    for i, idx in enumerate(CLIENT_DATA_IDX):
        models[i].fit(
            data.iloc[idx], 
            duration_col=duration_col, 
            event_col=event_col, 
            fit_options=fit_options, 
            initial_point=global_params
        )

    # Aggregate parameters 
    agg_params = pd.concat([model.params_ for model in models], axis=1).mean(axis=1)

    # Check for convergence 
    if np.linalg.norm(agg_params - global_params) / np.linalg.norm(global_params) < 1e-4:
        print(f"Convergence: {epoch}")
        break 
    
    # Update global params 
    global_params = agg_params.values

Convergence: 2


In [10]:
for i, idx in enumerate(CLIENT_DATA_IDX):
    # Share final model params 
    models[i].params_ = agg_params

In [11]:
print(models[i].score(data_train, scoring_method="concordance_index"))
print(models[i].score(data_test, scoring_method="concordance_index"))

0.5
0.5






In [12]:
from sksurv.metrics import concordance_index_censored

y_hat_train = predict(X_train, models[i])
y_hat_test = predict(X_test, models[i])

print(concordance_index_censored(data_train[event_col].astype(bool), data_train[duration_col], y_hat_train))
print(concordance_index_censored(data_test[event_col].astype(bool), data_test[duration_col], y_hat_test))

(0.5, 0, 0, 30763, 1080)
(0.5, 0, 0, 951, 0)
