# Federated Survival Analysis Simulation with Growing Data and Trust Comparisons

This notebook simulates three scenarios in federated survival analysis:

1. **Our Model:** Our custom trust mechanism with custom clustering and a trust update based on the improvement in the c-index. Local training uses a growing fraction of each node’s data over rounds.
2. **Baseline (FLTrust):** A baseline trust mechanism inspired by FLTrust where (a) clustering is done using standard KMeans on the feature completeness vectors and (b) trust updates use a lower learning rate.
3. **No Trust:** Plain federated averaging (FedAvg) without any trust mechanism (all nodes weighted equally).

At every round we record the number of nodes whose c-index improved compared to the previous round. Finally, a per-round comparison table is generated (and saved as a LaTeX table), and t-tests statistically compare our model against each of the other two approaches.

Before running, please select a GPU runtime in Google Colab (Runtime → Change runtime type → Hardware accelerator: GPU).

In [None]:
# If needed, install CuPy for your Colab GPU runtime (uncomment if necessary):
# !pip install cupy-cuda11x

!pip install lifelines
!pip install scikit-learn

import os
import random
import glob
import datetime
import h5py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index
from sklearn.cluster import KMeans
from scipy.stats import ttest_ind

# Try to import cupy. If not available, fallback to numpy.
try:
    import cupy as cp
except ImportError:
    cp = np


In [None]:
##############################################
# Synthetic Data Generation (HDF5 Files)
##############################################

# def generate_synthetic_data(
#     num_datasets=10,
#     num_rows_range=(100, 200),
#     total_features=50,
#     common_features=10,
#     censoring_percentage=0.30,
#     missing_value_fraction=0.1,
#     missingness_range=(0, 0.99),
#     feature_null_fraction=0.3,
#     output_dir="datasets_h5",
#     filename=None  # Optional specific filename
# ):
#     """
#     Generate synthetic survival analysis datasets with metadata stored in HDF5 format.
#     Each dataset file contains:
#       - A data matrix (features, time, death)
#       - A metadata group with:
#           * node_features: the list of features available at the center.
#           * binary_feature_vector: a binary vector of feature presence.
#           * feature_vector: the real-valued feature completeness vector.
#           * null_counts: number of null rows per available feature.
#           * Attributes: date_created, num_rows, num_features, num_censored.
#     """
#     os.makedirs(output_dir, exist_ok=True)

#     all_features = [f"feat_{i}" for i in range(total_features)]
#     all_features_path = os.path.join(output_dir, "all_features.txt")
#     with open(all_features_path, "w") as f:
#         f.write("\n".join(all_features))
#     print(f"Global feature set saved to {all_features_path}")

#     for i in range(num_datasets):
#         num_rows = random.randint(*num_rows_range)
#         lower_bound = int(common_features * 0.1)
#         upper_bound = int(common_features * 0.5)
#         num_features = common_features + random.randint(lower_bound, upper_bound)
#         f_i = sorted(random.sample(all_features, num_features))
#         values = np.random.default_rng().normal(loc=0, scale=1/num_features, size=(num_rows, num_features))
#         df = pd.DataFrame(values, columns=f_i)

#         null_counts = {}
#         num_null_features = int(feature_null_fraction * len(f_i))
#         features_with_nulls = random.sample(f_i, num_null_features)
#         for feature in features_with_nulls:
#             null_percentage = random.uniform(*missingness_range)
#             num_missing = int(null_percentage * num_rows)
#             null_counts[feature] = num_missing
#             missing_indices = np.random.choice(num_rows, num_missing, replace=False)
#             df.loc[missing_indices, feature] = np.nan

#         gamma = [random.uniform(-1, 1) for _ in f_i]
#         temp = df[f_i].fillna(0).values @ gamma
#         time_exponential = np.random.default_rng().exponential(scale=5 * np.exp(temp))
#         time = time_exponential.astype(int) + 1

#         d_nums = int(num_rows * (1 - censoring_percentage))
#         d_arr = np.array([1] * d_nums + [0] * (num_rows - d_nums))
#         np.random.shuffle(d_arr)
#         for j in range(len(d_arr)):
#             if d_arr[j] == 0:
#                 time[j] = random.randint(1, time[j])
#         df["time"] = time
#         df["death"] = d_arr

#         binary_feature_vector = [1 if feature in f_i else 0 for feature in all_features]
#         feature_vector = []
#         for feature in all_features:
#             if feature in f_i:
#                 num_non_null = df[feature].notnull().sum()
#                 feature_value = num_non_null / num_rows
#             else:
#                 feature_value = 0
#             feature_vector.append(feature_value)
#         max_value = max(feature_vector) if feature_vector else 1
#         feature_vector = [value / max_value for value in feature_vector]

#         if filename:
#             dataset_path = os.path.join(output_dir, filename)
#         else:
#             dataset_path = os.path.join(output_dir, f"dataset_node_{i}.h5")
#         with h5py.File(dataset_path, "w") as h5f:
#             h5f.create_dataset("data", data=df.values, compression="gzip")
#             metadata_group = h5f.create_group("metadata")
#             metadata_group.create_dataset("node_features", data=np.array(f_i, dtype="S10"))
#             metadata_group.create_dataset("gamma_values", data=np.array(gamma))
#             metadata_group.create_dataset("binary_feature_vector", data=binary_feature_vector)
#             metadata_group.create_dataset("feature_vector", data=feature_vector)
#             metadata_group.create_dataset("null_counts", data=[null_counts.get(f, 0) for f in f_i])
#             metadata_group.attrs["date_created"] = str(datetime.datetime.now())
#             metadata_group.attrs["num_rows"] = num_rows
#             metadata_group.attrs["num_features"] = num_features
#             metadata_group.attrs["num_censored"] = num_rows - d_nums

#         print(f"Dataset {i + 1}/{num_datasets} saved to {dataset_path}")


In [None]:
##############################################
# Helper Functions for Federated Survival Analysis
##############################################

def load_node_dataset(filepath):
    """
    Read one node's dataset and its metadata from an HDF5 file.

    The file must contain a "data" matrix and a "metadata" group holding the
    "node_features" and "feature_vector" datasets plus a "num_rows" attribute.

    Returns:
      - df: DataFrame built from "data", with columns node_features + ["time", "death"]
      - node_features: list of feature names stored for this node (decoded from UTF-8 bytes)
      - feature_vector: NumPy array read from metadata["feature_vector"]
      - num_rows: the "num_rows" metadata attribute converted to int
    """
    with h5py.File(filepath, "r") as store:
        matrix = store["data"][:]
        meta = store["metadata"]
        raw_names = meta["node_features"][:]
        node_features = [raw.decode("utf-8") for raw in raw_names]
        feature_vector = np.array(meta["feature_vector"][:])
        num_rows = int(meta.attrs["num_rows"])
    # Everything needed is already in memory, so the frame can be built after the file is closed.
    df = pd.DataFrame(matrix, columns=node_features + ["time", "death"])
    return df, node_features, feature_vector, num_rows

def get_global_coef_from_model(cox_model, node_features, global_dim=50):
    """
    Embed the local CoxPH model coefficients into a global vector of size global_dim.

    Each feature name is expected to look like "<prefix>_<k>"; the coefficient stored
    under that name in cox_model.params_ is written to position k of the result.

    Args:
      - cox_model: fitted model exposing params_ (indexable by feature name, with an .index)
      - node_features: iterable of feature names available at the node
      - global_dim: length of the returned vector

    Returns:
      NumPy array of length global_dim. Positions stay 0 for features whose name cannot
      be parsed, whose index falls outside [0, global_dim), or that are absent from
      cox_model.params_.
    """
    coef_full = np.zeros(global_dim)
    fitted = cox_model.params_
    for feat in node_features:
        try:
            idx = int(feat.split("_")[1])
        except (IndexError, ValueError):
            # No "_" separator or a non-numeric suffix: not a global feature name.
            continue
        # Reject out-of-range indices explicitly: a negative index would otherwise
        # wrap around and overwrite an unrelated coefficient.
        if not 0 <= idx < global_dim:
            continue
        if feat in fitted.index:
            coef_full[idx] = fitted[feat]
    return coef_full

def evaluate_cindex(coef, node_df, node_features, global_dim=50):
    """
    Evaluate the concordance index (c-index) on a node's dataset given the coefficient vector.

    The risk score of each row is the sum of feature value * coefficient over the node's
    features; missing values are treated as 0. Only the coefficients corresponding to the
    node's features are used.

    Args:
      - coef: coefficient vector indexed by global feature index
      - node_df: DataFrame with the node's feature columns plus "time" and "death"
      - node_features: feature names of the form "<prefix>_<k>"
      - global_dim: accepted for signature symmetry with the other helpers; not used here

    Returns:
      The value returned by concordance_index for the node's times, negated risk scores
      and event indicators.
    """
    node_df = node_df.fillna(0)
    risk_scores = np.zeros(len(node_df))
    for feat in node_features:
        try:
            idx = int(feat.split("_")[1])
        except (IndexError, ValueError):
            # Name does not carry a numeric global index; it contributes no risk.
            continue
        # Skip indices without a coefficient (and negative ones, which would wrap around).
        if not 0 <= idx < len(coef):
            continue
        if feat in node_df.columns:
            risk_scores += node_df[feat].astype(float).values * coef[idx]
    times = node_df["time"].astype(float).values
    events = node_df["death"].astype(int).values
    # The risk is negated so that a higher risk corresponds to a shorter predicted survival.
    return concordance_index(times, -risk_scores, events)

def generate_noise(node, round_num, T_honest, T_ramp, epsilon_max):
    """
    Return the noise increment of a noisy node for the given round.

    The first T_honest rounds yield zeros. Afterwards the schedule repeats with period
    T_honest + T_ramp: during the first T_ramp rounds of each period the node's direction
    node['z'] is scaled by min(phase / T_ramp, epsilon_max); the rest of the period
    yields zeros again.
    """
    direction = node['z']
    if round_num < T_honest:
        return np.zeros_like(direction)

    phase = (round_num - T_honest) % (T_honest + T_ramp)
    if phase >= T_ramp:
        return np.zeros_like(direction)

    scale = min(phase / T_ramp, epsilon_max)
    return scale * direction


In [None]:
##############################################
# New Helper: Compute Predicted Risk for Clustering
##############################################

def compute_predicted_risk(node):
    """
    Compute a summary risk for a node using its local model coefficients.

    For each row of node['data'] (restricted to node['features'], missing values replaced
    by 0), take the dot product with the matching entries of node['local_coef'] and
    return the mean over rows.

    A feature contributes with coefficient 0 when its name is not of the form
    "<prefix>_<k>" or when k lies outside the coefficient vector.
    """
    local_coef = node['local_coef']
    df_local = node['data'][node['features']].fillna(0)
    coefs = []
    for feat in node['features']:
        try:
            idx = int(feat.split('_')[1])
        except (IndexError, ValueError):
            # Unparseable feature name: no coefficient available.
            coefs.append(0)
            continue
        # Bounds-check explicitly so a negative index cannot wrap around and pick
        # up the coefficient of an unrelated feature.
        coefs.append(local_coef[idx] if 0 <= idx < len(local_coef) else 0)
    risk_values = df_local.dot(np.array(coefs))
    return np.mean(risk_values)


In [None]:
##############################################
# New Clustering Function (Custom Clustering)
##############################################

def custom_clustering(nodes, n_clusters, lambda_clust, max_iter=10):
    r"""
    Custom clustering that heuristically reduces the objective
      \(\sum_{i=1}^{c} \sum_{j \in C_i} \left( \|B_j - \mu_i\|_2 + \lambda_{\text{clust}} \cdot |r_j - \bar{r}_i| \right)\),
    where \(B_j\) is the feature vector of node j, \(\mu_i\) is the mean feature vector of
    cluster i, \(r_j\) is the node's predicted risk, and \(\bar{r}_i\) is the average risk
    in cluster i.

    Assignments are initialised with KMeans on the feature vectors and then refined for at
    most max_iter passes: each pass recomputes cluster centroids and average risks, then
    moves every node to the cluster with the lowest cost. Refinement stops early once a
    pass changes no assignment. An empty cluster gets a zero centroid and zero average risk.

    Returns a dictionary mapping cluster labels to lists of node indices and updates each
    node's 'cluster' entry with its assignment.
    """
    # Initialize cluster assignment using KMeans on feature vectors
    X = np.array([node['feature_vector'] for node in nodes])
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(X)
    for node, label in zip(nodes, kmeans.labels_):
        node['cluster'] = int(label)

    # A node's predicted risk depends only on its data and local coefficients, which
    # this function never modifies, so compute it once instead of on every pass.
    risks = [compute_predicted_risk(node) for node in nodes]

    for _ in range(max_iter):
        clusters = {c: [] for c in range(n_clusters)}
        for i, node in enumerate(nodes):
            clusters[node['cluster']].append(i)

        centroids = {}
        avg_risks = {}
        for c, members in clusters.items():
            if members:
                centroids[c] = np.mean(X[members], axis=0)
                avg_risks[c] = np.mean([risks[j] for j in members])
            else:
                centroids[c] = np.zeros_like(nodes[0]['feature_vector'])
                avg_risks[c] = 0

        changed = False
        for i, node in enumerate(nodes):
            best_cost = None
            best_cluster = None
            for c in range(n_clusters):
                cost = np.linalg.norm(X[i] - centroids[c]) + lambda_clust * abs(risks[i] - avg_risks[c])
                # Strict comparison: ties keep the lowest cluster label.
                if best_cost is None or cost < best_cost:
                    best_cost = cost
                    best_cluster = c
            if best_cluster != node['cluster']:
                changed = True
                node['cluster'] = best_cluster
        if not changed:
            break

    cluster_assignment = {c: [] for c in range(n_clusters)}
    for i, node in enumerate(nodes):
        cluster_assignment[node['cluster']].append(i)
    return cluster_assignment


In [None]:
##############################################
# Cluster-wise Federated Averaging
##############################################

def cluster_fedavg(nodes, cluster_assignment):
    """
    Aggregate local coefficients cluster by cluster.

    cluster_assignment maps a cluster label to the list of node indices in that cluster.
    Within each cluster, the members' node['local_coef'] vectors are averaged with weights
    node['num_rows'], and every member's node['global_coef'] is replaced by its own copy
    of that average. Clusters whose total row count is 0 (including empty clusters) are
    left untouched. The nodes are modified in place; nothing is returned.
    """
    for members in cluster_assignment.values():
        sample_total = sum(nodes[m]['num_rows'] for m in members)
        if sample_total == 0:
            continue

        accumulator = np.zeros_like(nodes[members[0]]['local_coef'])
        for m in members:
            contributor = nodes[m]
            accumulator += contributor['local_coef'] * contributor['num_rows']
        averaged = accumulator / sample_total

        # Each member receives an independent copy so later in-place edits do not propagate.
        for m in members:
            nodes[m]['global_coef'] = averaged.copy()


In [None]:
##############################################
# Main Simulation with Growing Data and Three Scenarios
##############################################

def _update_peer_trust(nodes, cluster_assignment, step):
    """Shift each node's trust in its cluster peers by step * (its own c-index change), clipped to [0, 1]."""
    for node_ids in cluster_assignment.values():
        for i in node_ids:
            # The change depends only on node i, so compute it once rather than per peer.
            improvement = nodes[i]['new_cindex'] - nodes[i]['prev_cindex']
            trust = nodes[i]['trust']
            for peer_id in node_ids:
                if peer_id != i:
                    trust[peer_id] = max(0, min(trust[peer_id] + step * improvement, 1))


def simulate_scenario(scenario, epsilon_max, alpha, lambda_clust, n_clusters, monte_carlo_runs, max_rounds):
    """
    Simulate the federated learning process for one scenario.

    Args:
      - scenario: 'our' (custom clustering, trust step alpha), 'baseline' (plain KMeans
        clustering on the feature vectors, fixed trust step 0.05) or 'no_trust'
        (unweighted mean of all local coefficient vectors, trust reset to ones)
      - epsilon_max: cap passed to generate_noise for the noisy nodes
      - alpha: trust learning rate for the 'our' scenario
      - lambda_clust: risk-term weight passed to custom_clustering
      - n_clusters: number of clusters for 'our' and 'baseline'
      - monte_carlo_runs: number of repetitions averaged into the result
      - max_rounds: number of federated rounds; round r trains on the first
        (r + 1) / max_rounds fraction of each node's rows

    Returns:
      Array of length max_rounds where each entry is the number of nodes whose c-index
      improved compared to the previous round (averaged over Monte Carlo runs).

    Raises:
      - ValueError: if scenario is not one of the three supported names
      - FileNotFoundError: if no file matches datasets_h5/dataset_node_*.h5

    NOTE(review): each node's c-index is evaluated on its own (possibly noised) local
    coefficients; the aggregated 'global_coef' and the 'trust' vectors are written but
    never read inside this function, so the returned counts do not depend on the
    aggregation step of the chosen scenario. Confirm whether evaluation was meant to use
    'global_coef'.
    NOTE(review): the noise directions z are drawn once, before the Monte Carlo loop, so
    repeated runs reuse the same z - confirm whether they should be re-sampled per run.
    """
    if scenario not in ("our", "baseline", "no_trust"):
        raise ValueError(f"Unknown scenario {scenario!r}; expected 'our', 'baseline' or 'no_trust'")

    global_dim = 50
    T_honest = 10
    T_ramp = 5
    alpha_baseline = 0.05  # lower trust learning rate used by the baseline scenario
    noisy_node_ids = {0, 4, 9}

    # Load nodes from datasets
    node_files = sorted(glob.glob(os.path.join("datasets_h5", "dataset_node_*.h5")))
    if not node_files:
        raise FileNotFoundError("No node datasets found matching datasets_h5/dataset_node_*.h5")
    # Trust vectors are indexed by node position, so size them by the number of
    # datasets actually loaded rather than a fixed constant.
    num_nodes = len(node_files)

    nodes = []
    for i, filepath in enumerate(node_files):
        df, node_features, feature_vector, num_rows = load_node_dataset(filepath)
        is_noisy = i in noisy_node_ids
        if is_noisy:
            z = np.random.uniform(-1, 1, size=len(node_features))
        else:
            z = np.zeros(len(node_features))
        nodes.append({
            'id': i,
            'data': df,
            'features': node_features,
            'feature_vector': feature_vector,
            'num_rows': num_rows,
            'local_coef': np.zeros(global_dim),
            'clean_coef': np.zeros(global_dim),
            'global_coef': np.zeros(global_dim),
            'trust': np.ones(num_nodes),
            'is_noisy': is_noisy,
            'noise': np.zeros(len(node_features)),
            'z': z,
        })

    # Prepare to record improvement counts
    improvements_all = np.zeros(max_rounds)

    for _ in range(monte_carlo_runs):
        # Reinitialize all per-run state so nothing leaks from the previous run.
        for node in nodes:
            node['trust'] = np.ones(num_nodes)
            node['global_coef'] = np.zeros(global_dim)
            node['local_coef'] = np.zeros(global_dim)
            node['clean_coef'] = np.zeros(global_dim)
            node['noise'] = np.zeros(len(node['features']))
            # Seed the "previous" c-index from a quick fit on a very small fraction of
            # the data; if that fit fails, start from 0.
            fraction = 0.01
            df_local = node['data'].iloc[:max(1, int(len(node['data']) * fraction))].fillna(0)
            try:
                cox_model = CoxPHFitter()
                cox_model.fit(df_local, duration_col="time", event_col="death", show_progress=False)
                prev_cindex = evaluate_cindex(
                    get_global_coef_from_model(cox_model, node['features'], global_dim),
                    df_local, node['features'], global_dim)
            except Exception:
                prev_cindex = 0
            node['prev_cindex'] = prev_cindex

        improvements = []

        for round_num in range(max_rounds):
            fraction = min(1.0, (round_num + 1) / max_rounds)

            # --- Local Training ---
            for node in nodes:
                df_local = node['data'].iloc[:max(1, int(len(node['data']) * fraction))].fillna(0)
                try:
                    cox_model = CoxPHFitter()
                    cox_model.fit(df_local, duration_col="time", event_col="death", show_progress=False)
                    clean_coef = get_global_coef_from_model(cox_model, node['features'], global_dim)
                except Exception:
                    # Fit failed: fall back to the last noise-free coefficients so the
                    # cumulative noise below is not applied on top of itself.
                    clean_coef = node['clean_coef']
                node['clean_coef'] = clean_coef

                local_coef = clean_coef.copy()
                if node['is_noisy']:
                    # node['noise'] accumulates the per-round increments.
                    node['noise'] += generate_noise(node, round_num, T_honest, T_ramp, epsilon_max)
                    for idx_local, feat in enumerate(node['features']):
                        try:
                            global_idx = int(feat.split("_")[1])
                        except (IndexError, ValueError):
                            continue
                        local_coef[global_idx] += node['noise'][idx_local]
                node['local_coef'] = local_coef
                node['new_cindex'] = evaluate_cindex(local_coef, df_local, node['features'], global_dim)

            # --- Trust and Aggregation based on Scenario ---
            if scenario == "our":
                cluster_assignment = custom_clustering(nodes, n_clusters, lambda_clust, max_iter=10)
                _update_peer_trust(nodes, cluster_assignment, alpha)
                cluster_fedavg(nodes, cluster_assignment)

            elif scenario == "baseline":
                # Baseline: standard KMeans clustering (without risk adjustment)
                X = np.array([node['feature_vector'] for node in nodes])
                labels = KMeans(n_clusters=n_clusters, random_state=0).fit(X).labels_
                clusters = {c: [] for c in range(n_clusters)}
                for i, node in enumerate(nodes):
                    node['cluster'] = labels[i]
                    clusters[node['cluster']].append(i)
                _update_peer_trust(nodes, clusters, alpha_baseline)
                cluster_fedavg(nodes, clusters)

            else:  # scenario == "no_trust"
                # No trust update: simply average all local models equally
                for node in nodes:
                    node['trust'] = np.ones(num_nodes)
                global_coef = np.mean([node['local_coef'] for node in nodes], axis=0)
                for node in nodes:
                    node['global_coef'] = global_coef.copy()

            # --- Count Improvement: number of nodes with new_cindex > prev_cindex ---
            improvements.append(sum(1 for node in nodes if node['new_cindex'] > node['prev_cindex']))

            # Update prev_cindex for next round
            for node in nodes:
                node['prev_cindex'] = node['new_cindex']

        improvements_all += np.array(improvements)

    improvements_all /= monte_carlo_runs
    return improvements_all

# Experiment settings shared by the three scenario runs
max_rounds = 100
monte_carlo_runs = 1  # Single run for illustration; increase for statistical significance
scenario_settings = dict(
    epsilon_max=0.2,
    alpha=0.1,
    lambda_clust=0.6,
    n_clusters=4,
    monte_carlo_runs=monte_carlo_runs,
    max_rounds=max_rounds,
)

# Run the simulation once per scenario (same order: our model, baseline, no trust)
our_results = simulate_scenario("our", **scenario_settings)
baseline_results = simulate_scenario("baseline", **scenario_settings)
no_trust_results = simulate_scenario("no_trust", **scenario_settings)

# Per-round improvement counts of the three scenarios, side by side
df_comparison = pd.DataFrame({
    'Round': np.arange(1, max_rounds + 1),
    'Our_Model': our_results,
    'FLTrust_Baseline': baseline_results,
    'No_Trust': no_trust_results,
})
print(df_comparison.head())

# Independent two-sample t-tests of our model against each alternative
t_stat1, p_val1 = ttest_ind(our_results, baseline_results)
t_stat2, p_val2 = ttest_ind(our_results, no_trust_results)
print("\nStatistical Comparison (T-Test):")
print("Our_Model vs FLTrust_Baseline: t-statistic = {:.3f}, p-value = {:.3f}".format(t_stat1, p_val1))
print("Our_Model vs No_Trust: t-statistic = {:.3f}, p-value = {:.3f}".format(t_stat2, p_val2))

# Export the comparison table in LaTeX form
latex_table = df_comparison.to_latex(index=False)
with open("comparison_table.tex", "w") as f:
    f.write(latex_table)
print("LaTeX comparison table saved to comparison_table.tex")


In [None]:
##############################################
# Experiment: Noise Injection Magnitude over Time
##############################################

def test_noise_injection(T_honest, T_ramp, epsilon_max_values, rounds=200, d=1):
    """
    Trace the instantaneous noise magnitude of a dummy noisy node over time.

    For every value in epsilon_max_values, a dummy node whose direction z is a vector of
    d ones is passed to generate_noise for rounds t = 0 .. rounds - 1, and the Euclidean
    norm of each returned noise vector is recorded.

    Returns:
      - the array of round indices
      - a dictionary mapping each epsilon_max value to its list of noise magnitudes
    """
    timeline = np.arange(rounds)
    curves = {}
    for cap in epsilon_max_values:
        probe = {'z': np.ones(d)}  # simple vector of ones
        curves[cap] = [
            np.linalg.norm(generate_noise(probe, step, T_honest, T_ramp, cap))
            for step in timeline
        ]
    return timeline, curves

# Configuration of the noise injection experiment
T_honest_test = 20
T_ramp_test = 10
epsilon_max_values = [0.05, 0.1, 0.2, 0.4, 0.8]
rounds_test = 200

rounds_list, noise_curves = test_noise_injection(
    T_honest_test, T_ramp_test, epsilon_max_values, rounds=rounds_test, d=1
)

# One PNG per epsilon_max value, written into the output directory
output_dir = "noise_plots"
os.makedirs(output_dir, exist_ok=True)
for eps, curve in noise_curves.items():
    fig, ax = plt.subplots(figsize=(10,6))
    ax.plot(rounds_list, curve, label=f'epsilon_max = {eps}')
    ax.set_xlabel('Round (t)')
    ax.set_ylabel('Noise magnitude n_t')
    ax.set_title('Noise magnitude over time (T_honest = {}, T_ramp = {})'.format(T_honest_test, T_ramp_test))
    ax.legend()
    ax.grid()
    filename = os.path.join(output_dir, f"noise_plot_epsilon_{eps}.png")
    fig.savefig(filename)
    # Close the figure so open figures do not accumulate across iterations.
    plt.close(fig)
    print(f"Saved noise plot for epsilon_max = {eps} to {filename}")
