In [1]:
#import tensorflow as tf 
# https://sci-hub.se/10.1002/sim.1203
# https://journals.sagepub.com/doi/pdf/10.1177/1536867x0900900206

# Dataset

In [2]:
import numpy as np 
import matplotlib.pyplot as plt 

from sksurv.metrics import concordance_index_censored
from lifelines.datasets import load_rossi

data = load_rossi()
data.dropna(inplace=True)
print(data.shape)
data.head()

(432, 9)


Unnamed: 0,week,arrest,fin,age,race,wexp,mar,paro,prio
0,20,1,0,27,1,0,0,1,3
1,17,1,0,18,1,0,0,1,8
2,25,1,0,19,0,1,0,1,13
3,52,0,1,23,1,1,1,1,1
4,52,0,0,19,0,1,0,1,3


In [3]:
# TODO Experiments:
# - Performance of global model on data aggregated across clients 
# - Performance of global model on separate local datasets 
# - Performance of local models on separate local datasets 

In [4]:
event_col = "arrest"
duration_col = "week"

In [5]:
from utils.client import Client 

client = Client(data, n_knots=5, n_epochs=5, event_col=event_col, duration_col=duration_col)
client.preprocess_data()
client.init_model()
client.fit_model()

2025-04-24 08:11:15.499991: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [19]:
client.risk_score(client.X_train).shape

(108, 1)

In [21]:
# Unpack structured array 
event_train, duration_train = zip(*client.y_train)

# Cast to numpy array 
duration_train = np.array(duration_train)
event_train = np.array(event_train)

client.survival_curve(client.X_train, duration_train).shape

(108, 1)

# Instantiate clients 

In [6]:
from utils.client import Client 

N_CLIENTS = 3 

data_idx = np.array_split(np.arange(data.shape[0]), N_CLIENTS)

participants = [] 
for idx in data_idx:
    # Mutually exclusive data samples 
    client = Client(data.iloc[idx], n_knots=5, n_epochs=5, event_col=event_col, duration_col=duration_col)
    
    # Apply data pre-processing
    client.preprocess_data()
    # Initialize model and parameters
    client.init_model()
    
    participants.append(client)

In [7]:
X_train, y_train, X_test, y_test = [], [], [], []
for client in participants:
    X_train.append(client.X_train)
    y_train.append(client.y_train)

    X_test.append(client.X_test)
    y_test.append(client.y_test)

X_train = np.vstack(X_train)
X_test = np.vstack(X_test)

y_train = np.hstack(y_train)
y_test = np.hstack(y_test)

X_train.shape, y_train.shape, X_test.shape, y_test.shape

((324, 7), (324,), (108, 7), (108,))

In [13]:
# Unpack structured array 
event_train, duration_train = zip(*y_train)
event_test, duration_test = zip(*y_test)

event_train = np.array(event_train)
event_test = np.array(event_test)

duration_train = np.array(duration_train)
duration_test = np.array(duration_test)

In [14]:
for client in participants:
    print(client.risk_score(X_train).shape)

(324, 1)
(324, 1)
(324, 1)


In [15]:
# check client performance *before* fitting models
# check client performance *after* fitting models

In [17]:
from sksurv.metrics import concordance_index_censored
print(concordance_index_censored(event_train.astype(bool), duration_train, client.risk_score(X_train).squeeze()))
print(concordance_index_censored(event_test.astype(bool), duration_test, client.risk_score(X_test).squeeze()))

(0.64695735306343, 15559, 8488, 11, 952)
(0.48678667177326695, 1270, 1339, 2, 0)
