In [None]:
# TODO: 

In [1]:
# https://sci-hub.se/10.1002/sim.1203
# https://journals.sagepub.com/doi/pdf/10.1177/1536867x0900900206

In [2]:
import numpy as np 
from sksurv.datasets import load_aids, load_breast_cancer, load_veterans_lung_cancer
from sklearn.model_selection import train_test_split

# Load dataset 
X, y = load_breast_cancer() # load_veterans_lung_cancer() # load_aids() #load_breast_cancer()

# Drop columns that are not numerical 
X = X[X.columns[X.dtypes != "category"]]
X.head()

Unnamed: 0,X200726_at,X200965_s_at,X201068_s_at,X201091_s_at,X201288_at,X201368_at,X201663_s_at,X201664_at,X202239_at,X202240_at,...,X221028_s_at,X221241_s_at,X221344_at,X221634_at,X221816_s_at,X221882_s_at,X221916_at,X221928_at,age,size
0,10.926361,8.962608,11.630078,10.964107,11.518305,12.038527,9.623518,9.814798,10.016732,7.847383,...,7.12934,5.649573,6.313081,7.842097,10.132635,10.926365,6.477749,5.991885,57.0,3.0
1,12.24209,9.531718,12.626106,11.594716,12.317659,10.776911,10.604577,10.704329,10.161838,8.744875,...,7.189642,7.599788,5.126765,8.780328,10.213467,9.555092,4.96805,7.05113,57.0,3.0
2,11.661716,10.23868,12.572919,9.166088,11.698658,11.353333,9.384927,10.161654,10.032721,8.125487,...,7.222765,4.987613,6.936022,7.855649,10.164514,9.308048,4.283777,6.828986,48.0,2.5
3,12.174021,9.819279,12.109888,9.086937,13.132617,11.859394,8.400839,8.670721,10.727427,8.65081,...,6.584748,7.205051,6.787297,6.678375,10.660092,10.208241,5.713404,6.927251,42.0,1.8
4,11.484011,11.489233,11.779285,8.887616,10.429663,11.401139,7.741092,8.642018,9.556686,8.478862,...,8.05299,6.973316,7.312287,7.358556,11.57033,10.931843,5.817265,6.655448,46.0,3.0


In [3]:
# Unpack structured array 
delta, times = zip(*y)
delta = np.array(delta)[:, None]
times = np.array(times)[:, None]
delta[:5], times[:5]

(array([[ True],
        [False],
        [ True],
        [False],
        [ True]]),
 array([[ 723.],
        [6591.],
        [ 524.],
        [6255.],
        [3822.]]))

# Initialize clients

In [4]:
from scipy.linalg import lstsq

from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored
from sksurv.linear_model import CoxPHSurvivalAnalysis

from sklearn.preprocessing import StandardScaler


def feature_scaling(X_train, X_test):
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    return X_train, X_test


def init_beta_gamma(D, X, y):

    # Fit a Cox PH model
    cox = CoxPHSurvivalAnalysis(alpha=0.01, tol=1e-6)
    cox.fit(X, y)

    # Estimated cumulative hazard conditional on covarites
    cumulative_hazards = cox.predict_cumulative_hazard_function(X, return_array=True)
    # Average log cumulative hazard for each subject over observed times
    mean_hazard = np.log(np.mean(cumulative_hazards, axis=1) + 1e-6)

    # Fit linear regression to log cumulative hazard
    beta, _, _, _ = lstsq(X, mean_hazard)
    gamma, _, _, _ = lstsq(D, mean_hazard)

    beta = np.random.random(X.shape[1])
    gamma = np.random.random(D.shape[1])
    
    return beta[None, :], gamma[None, :]


def spline_basis(x, k_j, k_min, k_max, derivative=False):
    """Computes the basis function S(x; k_j)."""
    phi_j = (k_max - k_j) / (k_max - k_min)
    
    def relu(x):
        return max(x, 0)

    if derivative:
        # Derivative of the spline basis function
        return 3 * relu(x - k_j) ** 2 - 3 * phi_j * relu(x - k_min) ** 2 - 3 * (1 - phi_j) * relu(x - k_max) ** 2
    # Spline basis function 
    return relu(x - k_j) ** 3 - phi_j * relu(x - k_min) ** 3 - (1 - phi_j) * relu(x - k_max) ** 3


def spline_design_matrix(ln_t, knots):
    """Computes the spline function s(x; γ, k)."""
    # Boundary knots
    k_min, k_max = knots[0], knots[-1]
    # Construct basis functions over internal knots 
    basis = [spline_basis(ln_t, k_j, k_min, k_max) for k_j in knots[1:-1]]
    # Design matrix 
    return np.array([ln_t] + basis)


def spline_derivative_design_matrix(ln_t, knots):
    """Computes the spline function s(x; γ, k)."""
    # Boundary knots
    k_min, k_max = knots[0], knots[-1]
    # Construct basis functions over internal knots 
    basis = [spline_basis(ln_t, k_j, k_min, k_max, derivative=True) for k_j in knots[1:-1]]
    # Design matrix 
    return 1 / np.exp(ln_t) * np.array([1] + basis)


def create_splines(log_t, knots):

    D, dDdt = [], []
    for log_time in log_t:
        D.append(spline_design_matrix(log_time, knots))
        dDdt.append(spline_derivative_design_matrix(log_time, knots))
    
    # Cast to <numpy.ndarray>
    return np.array(D), np.array(dDdt)


class Client:

    def __init__(self, cid, data, times, delta):

        self.cid = cid
        self.data = data 
        self.times = times
        self.delta = delta 

    def set_params(self, beta, gamma):
        self.beta = beta
        self.gamma = gamma 

    def _train_test_split(self, test_size=0.25):

        # Train-test splitting 
        train_idx, test_idx = train_test_split(
            np.arange(self.data.shape[0]), 
            test_size=test_size, 
            random_state=42, 
            stratify=self.delta.squeeze().astype(int)
        )
        return train_idx, test_idx

    def train_test_split(self):

        # Create training and test splits 
        train_idx, test_idx = self._train_test_split()

        # Cast to <numpy>
        X = self.data.to_numpy().copy()
        
        # Split data
        self.d_train, self.d_test = self.delta[train_idx], self.delta[test_idx]
        self.t_train, self.t_test = self.times[train_idx], self.times[test_idx]

        # Scale training and test data
        self.X_train, self.X_test = feature_scaling(X[train_idx], X[test_idx])

    def initialize_params(self, knots):
        
        # Spline design matrices of log-time 
        self.D, self.dDdt = create_splines(log_t=np.log(self.t_train.squeeze()), knots=knots)

        # Create structured array
        y_train = Surv.from_arrays(event=self.d_train.squeeze(), time=np.log(self.t_train.squeeze()))
        
        # Initialize model parameters 
        self.beta, self.gamma = init_beta_gamma(self.D, self.X_train, y_train)

    def loss(self):
        phi = self.D @ self.gamma.T + self.X_train @ self.beta.T
    
        uncensored = np.exp(phi - np.exp(phi)) * (self.dDdt @ self.gamma.T)
        censored = np.exp(-1.0 * np.exp(phi))
        
        return float(np.sum(self.d_train * uncensored + (1 - self.d_train) * censored, axis=0))

    def gradient_gamma(self):
    
        phi = self.D @ self.gamma.T + self.X_train @ self.beta.T
        dsdt = self.dDdt @ self.gamma.T
        
        return (np.exp(phi) - self.d_train).T @ self.D - (self.d_train / dsdt).T @ self.dDdt

    def gradient_beta(self):
        phi = self.D @ self.gamma.T + self.X_train @ self.beta.T
        return (np.exp(phi) - self.d_train).T @ self.X_train

    def train_steps(self, local_steps=1, learning_rate=0.01):
        
        for i in range(local_steps):
            # Gradient descent steps 
            self.beta -= learning_rate * self.gradient_beta()
            self.gamma -= learning_rate * self.gradient_gamma()

    def c_score(self, beta=None):
        
        if beta is None:
            beta = self.beta 
    
        train_score, _, _, _, _ = concordance_index_censored(
            self.d_train.squeeze(), 
            np.log(self.t_train.squeeze()), 
            (self.X_train @ beta.T).squeeze()
        )
        test_score, _, _, _, _ = concordance_index_censored(
            self.d_test.squeeze(), 
            np.log(self.t_test.squeeze()), 
            (self.X_test @ beta.T).squeeze()
        )
        return train_score, test_score

In [5]:
N_CLIENTS = 3
CLIENT_DATA_IDX = np.array_split(np.arange(X.shape[0]), N_CLIENTS)

clients = []
for i, idx in enumerate(CLIENT_DATA_IDX):
    participant = Client(cid=i, data=X.iloc[idx], times=times[idx], delta=delta[idx])
    participant.train_test_split()
    clients.append(participant)

# Initialize parameters

In [6]:
# Knot locations: Centiles of the distribution of **uncensored** log event times
# - Boundary knots: placed at the 0th and 100th centiles (min and max values)
# - Internal knots: internal knots are placed at the centiles between the min and max   
knots = np.percentile(np.log(times[delta.astype(int) == 1]), [0, 25, 50, 75, 100])
knots

array([4.82831374, 6.58822688, 7.09506438, 7.91221465, 8.8797508 ])

In [7]:
for participant in clients:
    # Distribute the same knot positions to every client 
    participant.initialize_params(knots.copy())
    print(participant.gradient_gamma())

[[1083.15630289 -937.60875742 -737.75669944 -397.92426227]]
[[ 30362.61576429 -12275.13129708  -9549.89758796  -5171.80414341]]
[[  755.80405751 -1259.87848515 -1104.93214381  -647.50875962]]


# Initialize parameters

In [8]:
def average_parameters(clients):
    
    beta_avg, gamma_avg = 0, 0
    for participant in clients:
        
        beta_avg += participant.beta 
        gamma_avg += participant.gamma  

    beta_avg /= len(clients)
    gamma_avg /= len(clients)

    return beta_avg, gamma_avg


beta0, gamma0 = average_parameters(clients)
gamma0

array([[0.25829183, 0.68146208, 0.1561162 , 0.40536194]])

In [9]:
for participant in clients:
    print(participant.gamma)
    participant.set_params(beta=beta0.copy(), gamma=gamma0.copy())
    print(participant.gamma)
    print()

[[0.23737992 0.71531589 0.37116895 0.15554021]]
[[0.25829183 0.68146208 0.1561162  0.40536194]]

[[0.40962522 0.83773586 0.0944291  0.96165086]]
[[0.25829183 0.68146208 0.1561162  0.40536194]]

[[0.12787035 0.49133449 0.00275056 0.09889474]]
[[0.25829183 0.68146208 0.1561162  0.40536194]]



# Training 

In [10]:
def scores(clients, beta=None):

    all_train_scores, all_test_scores = [], []
    for participant in clients:
        train_score, test_score = participant.c_score(beta=beta)
        all_train_scores.append(train_score)
        all_test_scores.append(test_score)

    return all_train_scores, all_test_scores

In [11]:
# Baseline results (before any training)
from sksurv.metrics import concordance_index_censored

beta_avg0, gamma_avg0 = average_parameters(clients)

train_scores0, test_scores0 = scores(clients, beta=beta_avg0)
np.mean(train_scores0), np.mean(test_scores0)

(0.5994774521653349, 0.4388770433546553)

In [12]:
# NOTE: difference between sending model weights to the server for aggregation or sending 
# weight gradients to the server 

_beta, _gamma = beta_avg0.copy(), gamma_avg0.copy()

for participant in clients:
    participant.set_params(beta=_beta, gamma=_gamma)


def learning_rate_sched(epoch, init_lr=0.01, total_epochs=10):
    return init_lr * 0.5 * (1 + np.cos(np.pi * epoch / total_epochs))


for epoch in range(20):

    learning_rate = learning_rate_sched(epoch)

    for participant in clients:
        participant.train_steps(local_steps=50, learning_rate=learning_rate)

    _beta, _gamma = average_parameters(clients)

    for participant in clients:
        participant.set_params(beta=_beta, gamma=_gamma)

train_scores, test_scores = scores(clients, beta=_beta)
np.mean(train_scores), np.mean(test_scores)

(0.7798841499634367, 0.680912895838269)