# Federated Survival Analysis Simulation with Custom Clustering and Growing Data

This notebook implements our federated survival analysis simulation with GPU acceleration (using CuPy) and a custom clustering step. The clustering minimizes an objective that combines two terms. The first is the Euclidean distance between each node's feature-completeness vector and its cluster centroid. The second is the weighted absolute difference between the node's predicted risk and its cluster's average predicted risk. In addition, local training is modified so that each round uses only a fraction of each node's data, and that fraction grows over the rounds (simulating a warm-up or incremental data-availability scenario).

Before running the notebook in Google Colab, please select a GPU runtime (Runtime → Change runtime type → Hardware accelerator: GPU).

In [None]:
# If needed, install CuPy for your Colab GPU runtime (uncomment if necessary):
# !pip install cupy-cuda11x

!pip install lifelines
!pip install scikit-learn

import os
import random
import glob
import datetime
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index
from sklearn.cluster import KMeans

# Try to import cupy. If not available, fallback to numpy.
try:
    import cupy as cp
except ImportError:
    cp = np


In [None]:
##############################################
# Synthetic Data Generation (HDF5 Files)
##############################################

def generate_synthetic_data(
    num_datasets=10,
    num_rows_range=(100, 200),
    total_features=50,
    common_features=10,
    censoring_percentage=0.30,
    missing_value_fraction=0.1,
    missingness_range=(0, 0.99),
    feature_null_fraction=0.3,
    output_dir="datasets_h5",
    filename=None  # Optional specific filename
):
    """
    Generate synthetic survival analysis datasets with metadata stored in HDF5 format.

    One HDF5 file is written per simulated node. In addition, ``all_features.txt``
    in ``output_dir`` lists the global feature set T
    (``feat_0`` ... ``feat_{total_features-1}``).

    Each dataset file contains:
      - ``data``: the data matrix (node features, then ``time`` and ``death``)
      - a ``metadata`` group with:
          * node_features: the list of features available at the node.
          * gamma_values: the coefficients used to simulate survival times.
          * binary_feature_vector: a 0/1 vector over T marking feature presence.
          * feature_vector: the max-normalized feature completeness vector over T.
          * null_counts: number of null rows per available feature.
          * Attributes: date_created, num_rows, num_features, num_censored.

    Args:
        num_datasets: number of node files to write.
        num_rows_range: inclusive (min, max) row count per node.
        total_features: size of the global feature set T.
        common_features: base number of features per node. Each node draws an
            extra int(10%)..int(50%) of this on top, capped at ``total_features``.
        censoring_percentage: fraction of rows marked as censored (death = 0).
        missing_value_fraction: currently unused; kept for API compatibility.
        missingness_range: (low, high) range for each null-bearing feature's
            missing fraction.
        feature_null_fraction: fraction of a node's features that receive nulls.
        output_dir: directory for the output files (created if missing).
        filename: optional file name. With a single dataset it is used as-is.
            With several datasets, the dataset index is appended to the stem
            (``name_0.h5``, ``name_1.h5``, ...) so files do not overwrite each other.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Define the global feature set T
    all_features = [f"feat_{i}" for i in range(total_features)]
    all_features_path = os.path.join(output_dir, "all_features.txt")
    with open(all_features_path, "w") as f:
        f.write("\n".join(all_features))
    print(f"Global feature set saved to {all_features_path}")

    for i in range(num_datasets):
        num_rows = random.randint(*num_rows_range)
        # Define node-specific feature set F_i. Cap at |T| so that
        # random.sample below cannot ask for more features than exist.
        lower_bound = int(common_features * 0.1)
        upper_bound = int(common_features * 0.5)
        num_features = min(
            common_features + random.randint(lower_bound, upper_bound),
            total_features,
        )
        # Randomly sample F_i from T
        f_i = sorted(random.sample(all_features, num_features))
        f_i_set = set(f_i)  # O(1) membership tests for the per-feature vectors below
        # Generate random values for each feature in F_i
        values = np.random.default_rng().normal(
            loc=0, scale=1/num_features, size=(num_rows, num_features)
        )
        df = pd.DataFrame(values, columns=f_i)

        # Introduce null values into a random subset of the node's features
        null_counts = {}
        num_null_features = int(feature_null_fraction * len(f_i))
        for feature in random.sample(f_i, num_null_features):
            null_percentage = random.uniform(*missingness_range)
            num_missing = int(null_percentage * num_rows)
            null_counts[feature] = num_missing
            missing_indices = np.random.choice(num_rows, num_missing, replace=False)
            df.loc[missing_indices, feature] = np.nan

        # Survival times: exponential with scale 5 * exp(x . gamma), nulls treated as 0
        gamma = [random.uniform(-1, 1) for _ in f_i]
        temp = df[f_i].fillna(0).values @ gamma
        time_exponential = np.random.default_rng().exponential(scale=5 * np.exp(temp))
        time = time_exponential.astype(int) + 1  # integer times, all >= 1

        # Introduce censoring: a censored row's time is redrawn uniformly in [1, time]
        d_nums = int(num_rows * (1 - censoring_percentage))
        d_arr = np.array([1] * d_nums + [0] * (num_rows - d_nums))
        np.random.shuffle(d_arr)
        for j in range(len(d_arr)):
            if d_arr[j] == 0:
                time[j] = random.randint(1, time[j])
        df["time"] = time
        df["death"] = d_arr

        # Binary presence vector and feature-completeness vector over T
        binary_feature_vector = [1 if feature in f_i_set else 0 for feature in all_features]
        feature_vector = [
            df[feature].notnull().sum() / num_rows if feature in f_i_set else 0
            for feature in all_features
        ]
        # Max-normalize. Fall back to 1 when the maximum is 0, so an all-zero
        # vector stays zero instead of dividing by zero.
        max_value = max(feature_vector, default=0) or 1
        feature_vector = [value / max_value for value in feature_vector]

        # Choose the output path without letting several datasets share one file
        if not filename:
            dataset_name = f"dataset_node_{i}.h5"
        elif num_datasets == 1:
            dataset_name = filename
        else:
            stem, ext = os.path.splitext(filename)
            dataset_name = f"{stem}_{i}{ext}"
        dataset_path = os.path.join(output_dir, dataset_name)

        # Save dataset and metadata to HDF5
        with h5py.File(dataset_path, "w") as h5f:
            h5f.create_dataset("data", data=df.values, compression="gzip")
            metadata_group = h5f.create_group("metadata")
            # dtype="S" sizes the fixed-width byte strings to the longest name,
            # so long names are not silently truncated as with a fixed "S10".
            metadata_group.create_dataset("node_features", data=np.array(f_i, dtype="S"))
            metadata_group.create_dataset("gamma_values", data=np.array(gamma))
            metadata_group.create_dataset("binary_feature_vector", data=binary_feature_vector)
            metadata_group.create_dataset("feature_vector", data=feature_vector)
            metadata_group.create_dataset("null_counts", data=[null_counts.get(f, 0) for f in f_i])
            metadata_group.attrs["date_created"] = str(datetime.datetime.now())
            metadata_group.attrs["num_rows"] = num_rows
            metadata_group.attrs["num_features"] = num_features
            metadata_group.attrs["num_censored"] = num_rows - d_nums

        print(f"Dataset {i + 1}/{num_datasets} saved to {dataset_path}")


In [None]:
##############################################
# Helper Functions for Federated Survival Analysis
##############################################

def load_node_dataset(filepath):
    """
    Read a single node's survival data and metadata from an HDF5 file.

    Returns a 4-tuple:
      - df: DataFrame whose columns are the node features followed by "time" and "death"
      - node_features: feature names stored in metadata/node_features, decoded as UTF-8
      - feature_vector: NumPy array read from metadata/feature_vector (feature completeness)
      - num_rows: the ``num_rows`` attribute of the metadata group, as an int
    """
    with h5py.File(filepath, "r") as store:
        matrix = store["data"][:]
        meta = store["metadata"]
        features = [raw.decode("utf-8") for raw in meta["node_features"][:]]
        completeness = np.array(meta["feature_vector"][:])
        row_count = int(meta.attrs["num_rows"])
    # The matrix is already an in-memory array, so the frame can be built after closing the file.
    columns = [*features, "time", "death"]
    return pd.DataFrame(matrix, columns=columns), features, completeness, row_count

def get_global_coef_from_model(cox_model, node_features, global_dim=50):
    """
    Embed the local CoxPH model coefficients into a global vector of size global_dim.

    A feature named ``"<prefix>_<n>"`` maps to position ``n`` of the global vector.
    Positions with no fitted coefficient stay 0.

    Args:
        cox_model: fitted model exposing ``params_`` (a pandas Series indexed by feature name).
        node_features: feature names available at the node.
        global_dim: length of the returned vector.

    Returns:
        NumPy float array of length ``global_dim``. The following features are skipped:
        names without an integer suffix, indices outside ``[0, global_dim)``, and
        features absent from ``params_``.
    """
    coef_full = np.zeros(global_dim)
    fitted = cox_model.params_
    for feat in node_features:
        try:
            idx = int(feat.split("_")[1])
        except (IndexError, ValueError):
            continue  # name without a numeric "_<index>" suffix
        if not 0 <= idx < global_dim:
            # Out-of-range indices would raise IndexError (or wrap, if negative).
            continue
        if feat in fitted.index:
            coef_full[idx] = fitted[feat]
    return coef_full

def evaluate_cindex(coef, node_df, node_features, global_dim=50):
    """
    Evaluate the concordance index (c-index) on a node's dataset given a coefficient vector.

    Each row's linear risk score is ``sum(x_feat * coef[idx(feat)])`` over the node's
    features, where ``idx("feat_12") == 12`` and missing values count as 0.
    The following features are skipped:
      - names without an integer suffix,
      - indices outside ``coef``,
      - features that are not columns of ``node_df``.
    A higher risk should mean a shorter survival time, so the negated risk is
    passed to ``concordance_index``.

    Args:
        coef: global coefficient vector indexed by feature number.
        node_df: DataFrame with the node's feature columns plus "time" and "death".
        node_features: feature names available at the node.
        global_dim: unused; kept for call compatibility. The bound is taken from ``coef``.

    Returns:
        The c-index computed by lifelines' ``concordance_index``.
    """
    node_df = node_df.fillna(0)
    risk_scores = np.zeros(len(node_df))
    for feat in node_features:
        try:
            idx = int(feat.split("_")[1])
        except (IndexError, ValueError):
            continue  # name without a numeric "_<index>" suffix
        if not 0 <= idx < len(coef):
            continue
        if feat in node_df.columns:
            risk_scores += node_df[feat].astype(float).values * coef[idx]
    times = node_df["time"].astype(float).values
    events = node_df["death"].astype(int).values
    return concordance_index(times, -risk_scores, events)

def generate_noise(node, round_num, T_honest, T_ramp, epsilon_max):
    """
    Return this round's noise increment for a (designated noisy) node.

    The first ``T_honest`` rounds are noise-free. After that, rounds cycle with
    period ``T_honest + T_ramp``. In the first ``T_ramp`` rounds of each cycle the
    increment is ``min(phase / T_ramp, epsilon_max) * node['z']``; in the rest of
    the cycle it is a zero vector shaped like ``node['z']``.
    """
    direction = node['z']
    if round_num >= T_honest:
        phase = (round_num - T_honest) % (T_honest + T_ramp)
        if phase < T_ramp:
            scale = min(phase / T_ramp, epsilon_max)
            return scale * direction
    return np.zeros_like(direction)


In [None]:
##############################################
# New Helper: Compute Predicted Risk for Clustering
##############################################

def compute_predicted_risk(node):
    """
    Compute a summary risk for a node using its local model coefficients.

    Each row of the node's data is restricted to the node's available features,
    with missing values treated as 0. The function takes the dot product of each
    row with the matching entries of ``node['local_coef']`` and returns the mean
    over rows.

    Args:
        node: dict with 'data' (DataFrame), 'features' (column names such as
            "feat_12") and 'local_coef' (indexable by the integer feature suffix).

    Returns:
        The mean linear risk. A feature contributes 0 if its name has no integer
        suffix or its index is not valid for ``local_coef``. Returns 0.0 when
        there are no rows or no features, so callers never receive NaN.

    Raises:
        KeyError: if the node dict lacks a required key. Previously a bare
        ``except`` hid this and silently yielded a risk of 0.
    """
    features = node['features']
    df_local = node['data'][features].fillna(0)
    if df_local.empty:
        return 0.0
    local_coef = node['local_coef']
    coefs = np.zeros(len(features))
    for pos, feat in enumerate(features):
        try:
            coefs[pos] = local_coef[int(feat.split('_')[1])]
        except (IndexError, ValueError):
            pass  # deliberate: unusable feature names/indices keep a zero weight
    risk_values = df_local.dot(coefs)
    return np.mean(risk_values)


In [None]:
##############################################
# New Clustering Function (Custom Clustering)
##############################################

def custom_clustering(nodes, n_clusters, lambda_clust, max_iter=10):
    r"""
    Custom clustering that minimizes the objective:
      \(\sum_{i=1}^{c} \sum_{j \in C_i} \|B_j - \mu_i\|_2 + \lambda_{\text{clust}} \cdot |r_j - \bar{r}_i|\),
    where \(B_j\) is the feature vector of node j, \(\mu_i\) is the mean feature vector of
    cluster i, \(r_j\) is node j's predicted risk (``compute_predicted_risk``), and
    \(\bar{r}_i\) is the average risk in cluster i.

    The assignment is initialized with KMeans (``random_state=0``) on the feature
    vectors. It is then refined by alternating two steps: recompute each cluster's
    mean vector and average risk, then move every node to its cheapest cluster.
    Refinement stops after ``max_iter`` passes or when no node moves. Empty clusters
    get a zero mean vector and an average risk of 0.

    Args:
        nodes: list of node dicts. Each needs 'feature_vector' plus the keys read by
            ``compute_predicted_risk``. Each dict's 'cluster' key is set in place.
        n_clusters: requested number of clusters. It is capped at ``len(nodes)``,
            because KMeans cannot form more clusters than samples.
        lambda_clust: weight of the risk-difference term.
        max_iter: maximum number of refinement passes.

    Returns:
        dict mapping each label in ``range(n_clusters)`` (after capping) to a list
        of node indices, possibly empty. Returns ``{}`` when ``nodes`` is empty.
    """
    if not nodes:
        return {}
    n_clusters = min(n_clusters, len(nodes))

    # Initialize cluster assignment using KMeans on feature vectors
    vectors = [np.asarray(node['feature_vector']) for node in nodes]
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(np.array(vectors))
    for node, label in zip(nodes, kmeans.labels_):
        node['cluster'] = int(label)

    # Predicted risk depends only on each node's data, features and local_coef,
    # none of which change here, so compute it once instead of on every pass.
    risks = [compute_predicted_risk(node) for node in nodes]

    for _ in range(max_iter):
        # Current members, mean feature vector and average risk of each cluster
        members = {c: [] for c in range(n_clusters)}
        for j, node in enumerate(nodes):
            members[node['cluster']].append(j)

        centroids = {}
        avg_risks = {}
        for c, idxs in members.items():
            if idxs:
                centroids[c] = np.mean([vectors[j] for j in idxs], axis=0)
                avg_risks[c] = np.mean([risks[j] for j in idxs])
            else:
                centroids[c] = np.zeros_like(vectors[0])
                avg_risks[c] = 0

        # Reassign each node to the cluster that minimizes its cost
        changed = False
        for j, node in enumerate(nodes):
            costs = [
                np.linalg.norm(vectors[j] - centroids[c])
                + lambda_clust * abs(risks[j] - avg_risks[c])
                for c in range(n_clusters)
            ]
            best_cluster = int(np.argmin(costs))  # first minimum wins on ties
            if best_cluster != node['cluster']:
                node['cluster'] = best_cluster
                changed = True
        if not changed:
            break

    cluster_assignment = {c: [] for c in range(n_clusters)}
    for j, node in enumerate(nodes):
        cluster_assignment[node['cluster']].append(j)
    return cluster_assignment


In [None]:
##############################################
# Cluster-wise Federated Averaging
##############################################

def cluster_fedavg(nodes, cluster_assignment):
    """
    Run sample-weighted federated averaging separately inside each cluster.

    For every cluster (``cluster_assignment`` maps label -> list of node indices),
    the members' ``local_coef`` vectors are averaged with weights ``num_rows``.
    Each member's ``global_coef`` is set to its own copy of that average.
    Clusters whose total sample count is 0, including empty clusters, are left
    untouched.
    """
    for members in cluster_assignment.values():
        weights = [nodes[idx]['num_rows'] for idx in members]
        total_samples = sum(weights)
        if total_samples == 0:
            continue
        accumulator = np.zeros_like(nodes[members[0]]['local_coef'])
        for idx, weight in zip(members, weights):
            accumulator += nodes[idx]['local_coef'] * weight
        shared = accumulator / total_samples
        for idx in members:
            nodes[idx]['global_coef'] = shared.copy()


In [None]:
##############################################
# Main Simulation Parameters and Loop
##############################################

def _to_numpy(array):
    """Copy *array* to host memory as a NumPy array.

    ``cp`` is CuPy when it imported successfully and NumPy otherwise. NumPy has
    no ``asnumpy``, so calling ``cp.asnumpy`` directly breaks the CPU fallback.
    """
    to_host = getattr(cp, "asnumpy", None)
    return np.asarray(array) if to_host is None else to_host(array)


def _average_peer_trust(nodes):
    """Return the mean trust each node receives from its peers, excluding self-trust.

    Row i of the stacked matrix is node i's trust vector, so column j holds every
    node's trust in node j. The result is an array on the ``cp`` backend.
    """
    trust_matrix = cp.asarray(np.stack([node['trust'] for node in nodes]))
    num_nodes = trust_matrix.shape[0]
    if num_nodes < 2:
        # No peers to average over: report the initial trust value instead of NaN.
        return cp.ones(num_nodes)
    peer_totals = trust_matrix.sum(axis=0) - cp.diagonal(trust_matrix)
    return peer_totals / (num_nodes - 1)


def _select_trusted_nodes(avg_trusts, num_selected):
    """Sample up to *num_selected* distinct node ids with probability proportional to trust."""
    weights = _to_numpy(avg_trusts).astype(float)
    num_nodes = len(weights)
    total = weights.sum()
    if not np.isfinite(total) or total <= 0:
        # Every node lost all trust: draw uniformly instead of dividing by zero.
        weights = np.ones(num_nodes)
        total = float(num_nodes)
    probabilities = weights / total
    # choice(replace=False) needs at least `size` entries with non-zero probability.
    size = min(num_selected, int(np.count_nonzero(probabilities)))
    return np.random.choice(num_nodes, size=size, replace=False, p=probabilities)


def _node_file_sort_key(path):
    """Sort key that orders ``dataset_node_<n>.h5`` files by the integer ``n``.

    Plain ``sorted`` places ``dataset_node_10`` before ``dataset_node_2``, so node
    ids (list positions) would not match file indices. Names without a numeric
    suffix sort after numbered ones, alphabetically.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    suffix = stem.rsplit("_", 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), path)
    return (1, 0, path)


def _plot_trust_history(average_trust_history, noisy_node_ids, T_honest, T_ramp,
                        max_rounds, epsilon_max, alpha, lambda_clust):
    """Plot each node's average trust per round, shading noise and no-noise phases."""
    plt.figure(figsize=(12, 7))
    for node_id, history in enumerate(average_trust_history):
        is_noisy = node_id in noisy_node_ids
        label = f"Node {node_id+1}" + (" (Noisy)" if is_noisy else "")
        plt.plot(history, label=label, linestyle='--' if is_noisy else '-')

    phase = T_honest
    while phase < max_rounds:
        plt.axvspan(phase, phase + T_ramp, color='red', alpha=0.2, label='Noise Phase' if phase == T_honest else None)
        next_phase = phase + T_ramp + T_honest
        if phase + T_ramp < max_rounds:
            plt.axvspan(phase + T_ramp, min(next_phase, max_rounds), color='blue', alpha=0.1, label='No-Noise Phase' if phase == T_honest else None)
        phase = next_phase

    plt.xlabel('Rounds')
    plt.ylabel('Average Trust Score')
    # Raw f-string: in a plain string "\a" in "$\alpha$" becomes a BEL control
    # character, and "\e" / "\l" are invalid escape sequences.
    plt.title(rf'Average Trust Score Evolution with $\epsilon_{{max}}={epsilon_max}$, $\alpha={alpha}$ and Custom Clustering ($\lambda={lambda_clust}$)')
    plt.legend()
    plt.grid()
    plt.show()


def main_simulation():
    """
    Run the trust-aware, cluster-wise federated CoxPH simulation and plot trust.

    Each ``datasets_h5/dataset_node_*.h5`` file (see generate_synthetic_data) is
    loaded as one node. For each Monte Carlo run and round, the simulation:
      1. samples up to ``num_selected`` nodes, with probability proportional to
         the average trust their peers place in them;
      2. fits a local CoxPH model on a growing prefix of each node's rows, adding
         accumulated noise to the coefficients of designated noisy nodes;
      3. clusters the nodes with ``custom_clustering``;
      4. lets each selected node adjust its trust in each same-cluster peer by the
         change in its own c-index when it averages coefficients with that peer;
      5. runs cluster-wise FedAvg and records each node's average peer trust.
    Finally, it plots the trust history averaged over runs.

    Raises:
        FileNotFoundError: if no node dataset files are found.
    """
    # Global parameters
    global_dim = 50              # Total number of features (as in data generation)
    max_rounds = 100
    T_honest = 10
    T_ramp = 5
    epsilon_max = 0.05
    alpha = 0.1                  # Learning rate for trust update
    monte_carlo_runs = 1         # Number of Monte Carlo runs (adjust for testing)
    n_clusters = 4               # Number of clusters for feature presence clustering
    lambda_clust = 0.6           # Weight for risk difference in clustering
    num_selected = 6             # Nodes sampled per round to update their trust

    # Designate a set of node IDs as "noisy" (adjust as desired)
    noisy_node_ids = {0, 4, 9}

    # (Assume datasets have already been generated and are in the folder "datasets_h5")
    data_dir = "datasets_h5"
    node_files = sorted(
        glob.glob(os.path.join(data_dir, "dataset_node_*.h5")),
        key=_node_file_sort_key,
    )
    if not node_files:
        raise FileNotFoundError(
            f"No dataset_node_*.h5 files found in {data_dir!r}; "
            "run generate_synthetic_data() first."
        )
    # Derive the node count from the files actually loaded, so trust vectors,
    # history arrays and the selection draw all share one dimension.
    num_nodes = len(node_files)
    n_clusters = min(n_clusters, num_nodes)

    nodes = []
    for i, filepath in enumerate(node_files):
        df, node_features, feature_vector, num_rows = load_node_dataset(filepath)
        is_noisy = i in noisy_node_ids
        if is_noisy:
            z = np.random.uniform(-1, 1, size=len(node_features))
        else:
            z = np.zeros(len(node_features))
        nodes.append({
            'id': i,
            'data': df,
            'features': node_features,
            'feature_vector': feature_vector,   # real-valued vector of length global_dim
            'num_rows': num_rows,
            'local_coef': np.zeros(global_dim),   # recomputed every round
            'global_coef': np.zeros(global_dim),  # set by cluster-wise FedAvg
            'trust': np.ones(num_nodes),          # this node's trust in each peer
            'is_noisy': is_noisy,
            'noise': np.zeros(len(node_features)),  # accumulated noise on local features
            'z': z,                             # noise direction vector
        })

    average_trust_history = cp.zeros((num_nodes, max_rounds))

    for mc_run in range(monte_carlo_runs):
        print(f"Running Monte Carlo simulation {mc_run + 1}/{monte_carlo_runs}")
        for node in nodes:
            node['trust'] = np.ones(num_nodes)
            node['global_coef'] = np.zeros(global_dim)
            if node['is_noisy']:
                node['noise'] = np.zeros(len(node['features']))
        trust_history = cp.zeros((num_nodes, max_rounds))

        for round_num in range(max_rounds):
            # Fraction of each node's rows used this round. It grows linearly
            # from 1/max_rounds at round 0 to 100% at the final round.
            fraction = min(1.0, (round_num + 1) / max_rounds)

            # --- Node Selection Based on Trust ---
            selected_node_ids = set(
                _select_trusted_nodes(_average_peer_trust(nodes), num_selected).tolist()
            )

            # --- Local Training (Compute local coefficients) ---
            for node in nodes:
                node_feats = node['features']
                df_local = node['data'][node_feats + ["time", "death"]]
                num_rows_local = max(1, int(len(df_local) * fraction))
                df_local = df_local.iloc[:num_rows_local].fillna(0)
                try:
                    cox_model = CoxPHFitter()
                    cox_model.fit(df_local, duration_col="time", event_col="death", show_progress=False)
                except Exception as e:
                    # Best effort (e.g. early rounds see very few rows): report
                    # the failure and keep the node's previous local_coef.
                    print(f"Node {node['id']} encountered an error in fitting: {e}")
                    continue
                local_coef = get_global_coef_from_model(cox_model, node_feats, global_dim=global_dim)
                if node['is_noisy']:
                    node['noise'] += generate_noise(node, round_num, T_honest, T_ramp, epsilon_max)
                    for idx_local, feat in enumerate(node_feats):
                        try:
                            global_idx = int(feat.split("_")[1])
                        except (IndexError, ValueError):
                            continue  # name without a numeric "_<index>" suffix
                        if 0 <= global_idx < global_dim:
                            local_coef[global_idx] += node['noise'][idx_local]
                node['local_coef'] = local_coef.copy()

            # --- Custom Clustering Based on Feature Completeness and Risk Alignment ---
            cluster_assignment = custom_clustering(nodes, n_clusters, lambda_clust=lambda_clust, max_iter=10)

            # --- Trust Update Within Clusters ---
            for node_ids in cluster_assignment.values():
                for i in node_ids:
                    if i not in selected_node_ids:
                        continue
                    node = nodes[i]
                    df_local = node['data']
                    node_feats = node['features']
                    orig_cindex = evaluate_cindex(node['local_coef'], df_local, node_feats, global_dim=global_dim)
                    for peer_id in node_ids:
                        if peer_id == i:
                            continue
                        peer = nodes[peer_id]
                        avg_coef = (node['local_coef'] + peer['local_coef']) / 2.0
                        new_cindex = evaluate_cindex(avg_coef, df_local, node_feats, global_dim=global_dim)
                        improvement = new_cindex - orig_cindex
                        node['trust'][peer_id] = max(0, min(node['trust'][peer_id] + alpha * improvement, 1))

            # --- Cluster-wise Federated Averaging ---
            cluster_fedavg(nodes, cluster_assignment)

            # --- Record Trust History ---
            trust_history[:, round_num] = _average_peer_trust(nodes)

        average_trust_history += trust_history / monte_carlo_runs

    # Convert to NumPy for plotting (works with both the CuPy and NumPy backends)
    _plot_trust_history(
        _to_numpy(average_trust_history), noisy_node_ids, T_honest, T_ramp,
        max_rounds, epsilon_max, alpha, lambda_clust,
    )

##############################################
# Experiment: Noise Injection Magnitude over Time
##############################################

def test_noise_injection(T_honest, T_ramp, epsilon_max_values, rounds=200, d=1):
    """
    Trace the per-round noise magnitude that a noisy node would inject.

    A probe node whose direction vector ``z`` is all ones (length ``d``) is passed
    through ``generate_noise`` for every round ``t`` in ``0 .. rounds-1``, once per
    ``epsilon_max`` value. The Euclidean norm of each returned vector is recorded.

    Returns:
        (rounds_list, noise_curves): ``rounds_list`` is ``np.arange(rounds)``, and
        ``noise_curves`` maps each epsilon_max value to its list of magnitudes.
    """
    timeline = np.arange(rounds)

    def magnitudes_for(eps):
        probe = {'z': np.ones(d)}
        return [
            np.linalg.norm(generate_noise(probe, step, T_honest, T_ramp, eps))
            for step in timeline
        ]

    return timeline, {eps: magnitudes_for(eps) for eps in epsilon_max_values}

# Parameters for the noise-injection experiment
T_honest_test, T_ramp_test = 20, 10
epsilon_max_values = [0.05, 0.1, 0.2, 0.4, 0.8]  # five different values
rounds_test = 200

rounds_list, noise_curves = test_noise_injection(
    T_honest_test,
    T_ramp_test,
    epsilon_max_values,
    rounds=rounds_test,
    d=1,
)

# Write one PNG per epsilon_max value into ./noise_plots
output_dir = "noise_plots"
os.makedirs(output_dir, exist_ok=True)
noise_title = f'Noise magnitude over time (T_honest = {T_honest_test}, T_ramp = {T_ramp_test})'
for eps, curve in noise_curves.items():
    filename = os.path.join(output_dir, f"noise_plot_epsilon_{eps}.png")
    plt.figure(figsize=(10, 6))
    plt.plot(rounds_list, curve, label=f'epsilon_max = {eps}')
    plt.xlabel('Round (t)')
    plt.ylabel('Noise magnitude n_t')
    plt.title(noise_title)
    plt.legend()
    plt.grid()
    plt.savefig(filename)
    plt.close()
    print(f"Saved noise plot for epsilon_max = {eps} to {filename}")

##############################################
# Main
##############################################

if __name__ == '__main__':
    # Script/notebook entry point: run the federated trust simulation, which
    # reads node files from "datasets_h5" (written by generate_synthetic_data)
    # and displays the trust-evolution plot.
    main_simulation()
