<a href="https://colab.research.google.com/github/dli-stats/distributed_cox_paper_repro/blob/main/distributed_cox_paper_repro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Packages and Dependencies




To start, we install the `distributed_cox` package and all of its required dependencies.

In [None]:
!pip install git+https://github.com/dli-stats/distributed_cox.git

Collecting git+https://github.com/dli-stats/distributed_cox.git
  Cloning https://github.com/dli-stats/distributed_cox.git to /tmp/pip-req-build-foj1q73t
  Running command git clone -q https://github.com/dli-stats/distributed_cox.git /tmp/pip-req-build-foj1q73t
Collecting oryx==0.1.4
  Downloading https://files.pythonhosted.org/packages/de/31/8a42ab5140418145a554b0d6a3286eb9b4178724c4e0f1f130f465b20c39/oryx-0.1.4.tar.gz (91kB)
Collecting jax==0.2.9
  Downloading https://files.pythonhosted.org/packages/6d/4b/cd013403adac3a7b3b6a616bf40abc7ba767fba63facdb260d2f09ba4e18/jax-0.2.9.tar.gz (551kB)
Collecting jaxlib==0.1.59
  Downloading https://files.pythonhosted.org/packages/a3/03/4bd29f0e81aaff4e51fa2bbcb6c5391ecec8d0fc9137f995b4c8827f18c7/jaxlib-0.1.59-cp37-none-manylinux2010_x86_64.whl (34.3MB)

# Data Preparation

In this section, we will download the simulated dataset and prepare it into the desired format.

First, download the raw simulated data from GitHub.

In [None]:
!git clone https://github.com/dli-stats/distributed_cox_paper_simudata simulated_data

Now load the data.

In [None]:
import pandas as pd

df = pd.read_csv("simulated_data/dat_std_simulated.csv", index_col=0)
df

Unnamed: 0,time,status,A,X1,X2,X3,X6,X8,X9,X11,X12,X13,X15,X16,X24,X25,X26,indDP
1,22,1,1,1,0.138245,1,0,0,1,0,0,1,0,0,-1.320056,0,0.147349,2
2,30,0,0,1,0.325395,0,0,0,0,0,1,1,0,0,0.407715,0,0.184558,3
3,30,0,1,0,1.120127,1,1,1,0,0,1,0,0,1,-0.427274,1,1.199993,1
4,30,0,1,0,0.771549,1,1,1,1,0,1,0,0,1,0.536396,1,0.961339,2
5,30,0,1,1,-1.787154,0,0,0,0,0,0,0,0,0,-0.403683,1,-1.468517,3
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
11575,30,0,1,1,1.844121,1,0,0,0,0,0,1,0,0,0.110043,0,1.717636,2
11576,30,0,1,1,-0.063699,1,0,0,0,1,0,0,0,0,-1.249988,0,0.000226,1
11577,30,0,1,1,0.956254,1,0,0,0,1,0,1,0,1,0.983444,0,0.799436,2
11578,30,0,0,0,-0.606473,1,0,1,0,1,1,1,0,0,-2.268432,0,-0.584255,1


As we can see, the CSV data file contains 18 columns:
- **time**: time of the event or censoring time
- **status**: event indicator
- **A**: treatment
- **X1** to **X26**: covariates corresponding to those used in the real data analysis
- **indDP**: site indicator

The first step in reproducing the results is to convert this raw simulated data into a format consumable by our code.

We convert the data into the following arrays:
- `X`: an N by X_DIM matrix containing the predictors, where N is the sample size and X_DIM is the number of predictors.
- `delta`: a length N vector containing the status variable, which indicates whether an event was observed (status).
- `group_labels`: a length N vector containing the site of each sample (indDP).
- `T`: a length N vector containing the observed times (time).

In [None]:
import re
import numpy as np
import jax.numpy as jnp

def convert_from_df(dataframe: pd.DataFrame):
  """Converts the raw dataframe into (X, delta, group_labels, T) NumPy arrays."""
  # Predictors: treatment "A" first, then X1, X2, ... in numeric order.
  x_headers = ["A"] + list(
      sorted(
          (k for k in dataframe.keys() if re.match(r'X\d+', k) and len(k) > 1),
          key=lambda s: int(s[1:])))
  X = dataframe[x_headers]
  delta = dataframe["status"]
  T = dataframe["time"]
  # Site labels are 1-based in the CSV; shift them to start at 0.
  group_labels = dataframe["indDP"] - 1
  X, delta, T, group_labels = map(lambda x: x.to_numpy(),
                                  (X, delta, T, group_labels))
  X = X.astype(np.float32)
  T = T.astype(np.float32)
  delta = delta.astype(bool)
  # Sort every array by descending observed time.
  sorted_idx = np.argsort(-T)
  T = np.take(T, sorted_idx, axis=0)
  X = np.take(X, sorted_idx, axis=0)
  delta = np.take(delta, sorted_idx, axis=0)
  group_labels = np.take(group_labels, sorted_idx, axis=0)

  return X, delta, group_labels, T

X, delta, group_labels, T = convert_from_df(df)
N, X_DIM = X.shape
K = max(group_labels) + 1 # Number of sites
group_sizes = tuple(np.sum(group_labels == k) for k in range(K))

print("group sizes: ", group_sizes)

print("X shape: ", X.shape)
print("delta shape:", delta.shape)
print("group_labels shape:", group_labels.shape)
print("T shape:", T.shape)
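
`convert_from_df` sorts every array by descending observed time. One common reason for this ordering in Cox-model code is that risk-set sums then reduce to cumulative sums. Whether `distributed_cox` relies on exactly this trick is an assumption here. The toy arrays below are made up for illustration and contain no tied times.

```python
import numpy as np

# Toy data, already sorted by descending time (as np.argsort(-T) would produce).
toy_T = np.array([9.0, 7.0, 4.0, 2.0])
toy_scores = np.array([1.5, 0.5, 2.0, 1.0])  # stands in for exp(X @ beta)

# Brute force: for each subject, sum the scores of everyone still at risk,
# i.e. everyone whose observed time is at least that subject's time.
brute_force = np.array([toy_scores[toy_T >= t].sum() for t in toy_T])

# With descending times and no ties, the risk set of subject i is subjects
# 0..i, so a single cumulative sum yields the same totals.
via_cumsum = np.cumsum(toy_scores)

print(brute_force)
print(via_cumsum)
```

The simulated dataset contains many tied times (for example, many rows with `time` equal to 30), so a real implementation needs extra care at ties; this sketch ignores that.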

# Analysis


We first import all the source functions needed to:
 
  1. perform the analysis to obtain the log HR estimate for each analysis method, and
  2. compute the corresponding estimated variance. 


In [None]:
import functools
from jax import jit
import distributed_cox.cox_solve as cox_solve

solver_config = dict(
    max_num_steps=40, 
    loglik_eps=1e-5, 
    score_norm_eps=1e-3
)

distributed_config = dict(
  pt2_use_average_guess=True,
  hessian_use_taylor=True,
  taylor_order=1,
)

def get_solve_and_cov_fn(eq:str, **kwargs):
   return jit(cox_solve.get_cox_solve_and_cov_fn(eq, 
                                                 group_sizes=group_sizes, 
                                                 solver=solver_config, 
                                                 **kwargs))

# functions for the 6 analysis methods

unstratified_pooled = get_solve_and_cov_fn("eq1")
unstratified_distributed = get_solve_and_cov_fn("eq2", 
                                                distributed=distributed_config)

stratified_pooled = get_solve_and_cov_fn("eq3")
stratified_distributed = get_solve_and_cov_fn("eq4", 
                                              distributed=distributed_config)

multivariate_meta_analysis = get_solve_and_cov_fn("meta_analysis", 
                                meta_analysis=dict(univariate=False))
univariate_variate_meta_analysis = get_solve_and_cov_fn("meta_analysis", 
                                meta_analysis=dict(univariate=True))

# An initial guess used for all the solvers
beta_guess = np.zeros(X_DIM)
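
For intuition about what a pooled, unstratified solve produces (a log HR estimate plus its estimated covariance), here is a minimal NumPy sketch of Newton-Raphson on the Cox partial likelihood. It is not the `distributed_cox` implementation: the function names, the stopping rule, and the synthetic data are all assumptions made for illustration, and tied event times are not handled.

```python
import numpy as np

def score_and_information(beta, X, delta):
    """Score vector and observed information of the Cox partial likelihood.

    Assumes rows are sorted by descending time with no ties, so the risk
    set of row i is rows 0..i and risk-set sums are cumulative sums.
    """
    w = np.exp(X @ beta)
    s0 = np.cumsum(w)
    s1 = np.cumsum(w[:, None] * X, axis=0)
    s2 = np.cumsum(w[:, None, None] * X[:, :, None] * X[:, None, :], axis=0)
    mean = s1 / s0[:, None]
    second = s2 / s0[:, None, None]
    score = (delta[:, None] * (X - mean)).sum(axis=0)
    outer = mean[:, :, None] * mean[:, None, :]
    info = (delta[:, None, None] * (second - outer)).sum(axis=0)
    return score, info

def cox_newton(X, delta, max_steps=40, tol=1e-6):
    """Returns (beta_hat, covariance), where covariance = inverse information."""
    beta = np.zeros(X.shape[1])
    for _ in range(max_steps):
        score, info = score_and_information(beta, X, delta)
        if np.linalg.norm(score) < tol:
            break
        beta = beta + np.linalg.solve(info, score)
    _, info = score_and_information(beta, X, delta)
    return beta, np.linalg.inv(info)

# Synthetic data (hypothetical, not the simulated dataset used above).
rng = np.random.default_rng(0)
n = 300
toy_X = rng.normal(size=(n, 2))
true_beta = np.array([0.5, -0.5])
event_time = rng.exponential(1.0 / np.exp(toy_X @ true_beta))
censor_time = rng.exponential(2.0, size=n)
toy_T = np.minimum(event_time, censor_time)
toy_delta = event_time <= censor_time

# Sort by descending time, mirroring convert_from_df.
order = np.argsort(-toy_T)
toy_X, toy_delta = toy_X[order], toy_delta[order]

beta_hat, cov = cox_newton(toy_X, toy_delta)
ese = np.sqrt(np.diag(cov))  # estimated standard errors
print("beta_hat:", beta_hat)
print("ese:", ese)
```

The square root of the diagonal of the inverse information plays the same role as the estimated standard errors extracted later in this notebook.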

We now run all the analysis methods (unstratified pooled, unstratified distributed, stratified pooled, stratified distributed, multivariate meta-analysis, and univariate meta-analysis).

When you first run the cell below, the compilation will take a few moments, but the computation will be fast once compiled.

That is, if you run the cell below twice, the second run will be much faster.

In [None]:
import time
import jax
import jax.tree_util as tu

def run_and_wait(name: str, analysis):
  """Executes `analysis`, waits for the result, and logs timing information."""
  print(f"Running {name}Analysis... ", end="", flush=True)
  start_time = time.time()
  result = analysis(X, delta, beta_guess, group_labels)
  compile_end = time.time()
  print(f"Compile finished in {compile_end - start_time:.2f}s.. ", 
        end="", flush=True)
  tu.tree_map(lambda x: x.block_until_ready(), result)
  execution_end = time.time()
  print(f"Execution finished in {execution_end - compile_end:.2f}s.. ", 
        end="", flush=True)
  print(f"Done.", flush=True)
  return result

unstratified_pooled_sol      = run_and_wait("Unstratified Pooled ", unstratified_pooled)
unstratified_distributed_sol = run_and_wait("Unstratified Distributed ", unstratified_distributed)
stratified_pooled_sol        = run_and_wait("Stratified Pooled ", stratified_pooled)
stratified_distributed_sol   = run_and_wait("Stratified Distributed ", stratified_distributed)
multivariate_meta_analysis_sol = run_and_wait("Multivariate Meta-", multivariate_meta_analysis)
univariate_meta_analysis_sol   = run_and_wait("Univariate Meta-", univariate_variate_meta_analysis)
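
The compile-then-reuse behavior mentioned above can be reproduced with a toy function. In this sketch `toy_step` is a hypothetical stand-in for an analysis function; the first call includes JIT compilation, while the second call reuses the compiled program. As in `run_and_wait`, `block_until_ready()` is used so that the timing covers the actual computation.

```python
import time

import jax
import jax.numpy as jnp

@jax.jit
def toy_step(x):
    # Hypothetical stand-in for an expensive analysis function.
    return jnp.tanh(x @ x.T).sum()

toy_input = jnp.ones((200, 200))

def timed_call():
    """Runs toy_step once and returns (value, elapsed seconds)."""
    start = time.perf_counter()
    out = toy_step(toy_input)
    out.block_until_ready()  # wait for the asynchronous computation to finish
    return float(out), time.perf_counter() - start

first_value, first_seconds = timed_call()    # includes compilation
second_value, second_seconds = timed_call()  # reuses the compiled program

print(f"first call:  {first_seconds:.4f}s")
print(f"second call: {second_seconds:.4f}s")
```

The exact timings depend on the machine and backend, so no particular speedup is claimed here.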

# Extract Results

Finally, we pull the results from the 6 analysis methods and format them into the table we want. 

In [None]:
import scipy.stats      # provides scipy.stats.norm.ppf
import jax.scipy.stats  # provides jax.scipy.stats.norm.cdf

# Some helpers for collecting the results of interest
get_beta_hat = lambda sol: sol.pt2.guess

def get_ese(sol, cov_key):
  return jnp.sqrt(jnp.diag(sol.covs[cov_key]))

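def compute_confidence_interval_overlap(beta_eq,
                                        std_eq,
                                        beta_eq_true,
                                        std_eq_true,
                                        lb=0.025,
                                        ub=1 - 0.025):
  """Averages the mass of each normal approximation inside the other's interval.

  `beta_eq`/`std_eq` are the estimate and standard error being compared;
  `beta_eq_true`/`std_eq_true` are those of the reference analysis.
  With the default `lb` and `ub`, the intervals are 95% confidence intervals.
  """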
  f_orig_cdf = functools.partial(jax.scipy.stats.norm.cdf,
                                 loc=beta_eq_true,
                                 scale=std_eq_true)
  f_rel_cdf = functools.partial(jax.scipy.stats.norm.cdf,
                                loc=beta_eq,
                                scale=std_eq)
  L_rel = beta_eq + scipy.stats.norm.ppf(lb) * std_eq
  U_rel = beta_eq + scipy.stats.norm.ppf(ub) * std_eq
  L_orig = beta_eq_true + scipy.stats.norm.ppf(lb) * std_eq_true
  U_orig = beta_eq_true + scipy.stats.norm.ppf(ub) * std_eq_true

  I = ((f_orig_cdf(U_rel) - f_orig_cdf(L_rel)) +
       (f_rel_cdf(U_orig) - f_rel_cdf(L_orig))) / 2

  return I
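
To make the overlap measure concrete, here is a scalar re-implementation that uses only the Python standard library. The name `ci_overlap` and its `level` parameter are introduced here for illustration; the formula mirrors `compute_confidence_interval_overlap` above. The inputs are the rounded row 0 values from the results table at the end of this notebook.

```python
from statistics import NormalDist

def ci_overlap(beta, se, beta_ref, se_ref, level=0.95):
    """Scalar confidence interval overlap between an estimate and a reference."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    est = NormalDist(beta, se)
    ref = NormalDist(beta_ref, se_ref)
    lo, hi = beta - z * se, beta + z * se
    lo_ref, hi_ref = beta_ref - z * se_ref, beta_ref + z * se_ref
    # Mass of the reference distribution inside the estimate's interval,
    # averaged with the mass of the estimate's distribution inside the
    # reference interval.
    return ((ref.cdf(hi) - ref.cdf(lo)) + (est.cdf(hi_ref) - est.cdf(lo_ref))) / 2

# Comparing an analysis with itself gives the nominal level.
same = ci_overlap(-0.332484, 0.080267, -0.332484, 0.080267)

# Unstratified distributed vs. unstratified pooled, row 0 (rounded inputs).
shifted = ci_overlap(-0.318657, 0.083456, -0.332484, 0.080267)

print(same)
print(shifted)
```

This explains why every pooled analysis reports a `cios` value of 0.95: it is compared against itself. The second value should land close to the 0.946241 reported for row 0, up to rounding of the inputs.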

In [None]:
# Collect beta_hat
unstratified_pooled_beta_hat      = get_beta_hat(unstratified_pooled_sol)
unstratified_distributed_beta_hat = get_beta_hat(unstratified_distributed_sol)
stratified_pooled_beta_hat        = get_beta_hat(stratified_pooled_sol)
stratified_distributed_beta_hat   = get_beta_hat(stratified_distributed_sol)
multivariate_meta_analysis_beta_hat  = get_beta_hat(multivariate_meta_analysis_sol)
univariate_meta_analysis_beta_hat    = get_beta_hat(univariate_meta_analysis_sol)

# Collect the covariance of interest
unstratified_pooled_ese = get_ese(unstratified_pooled_sol, "cov:no_group_correction|no_sandwich|no_cox_correction|no_sum_first")
unstratified_distributed_ese = get_ese(unstratified_distributed_sol, "cov:group_correction|no_sandwich|no_cox_correction|no_sum_first")
stratified_pooled_ese = get_ese(stratified_pooled_sol, "cov:no_group_correction|no_sandwich|no_cox_correction|no_sum_first")
stratified_distributed_ese = get_ese(stratified_distributed_sol, "cov:group_correction|no_sandwich|no_cox_correction|no_sum_first")
multivariate_meta_analysis_ese = get_ese(multivariate_meta_analysis_sol, "cov:meta_analysis")
univariate_meta_analysis_ese = get_ese(univariate_meta_analysis_sol, "cov:meta_analysis")

# Compute cio
unstratified_pooled_cios = compute_confidence_interval_overlap(unstratified_pooled_beta_hat, 
                                                              unstratified_pooled_ese, 
                                                              unstratified_pooled_beta_hat, 
                                                              unstratified_pooled_ese)
unstratified_distributed_cios = compute_confidence_interval_overlap(unstratified_distributed_beta_hat, 
                                                                   unstratified_distributed_ese, 
                                                                   unstratified_pooled_beta_hat, 
                                                                   unstratified_pooled_ese)
stratified_pooled_cios = compute_confidence_interval_overlap(stratified_pooled_beta_hat, 
                                                            stratified_pooled_ese, 
                                                            stratified_pooled_beta_hat, 
                                                            stratified_pooled_ese)
stratified_distributed_cios = compute_confidence_interval_overlap(stratified_distributed_beta_hat, 
                                                                 stratified_distributed_ese, 
                                                                 stratified_pooled_beta_hat, 
                                                                 stratified_pooled_ese)
multivariate_meta_analysis_cios = compute_confidence_interval_overlap(multivariate_meta_analysis_beta_hat, 
                                                                     multivariate_meta_analysis_ese, 
                                                                     unstratified_pooled_beta_hat, 
                                                                     unstratified_pooled_ese)
univariate_meta_analysis_cios = compute_confidence_interval_overlap(univariate_meta_analysis_beta_hat, 
                                                                   univariate_meta_analysis_ese, 
                                                                   unstratified_pooled_beta_hat, 
                                                                   unstratified_pooled_ese)

You can inspect each of these statistics:

In [None]:
stratified_pooled_ese # change to any of the statistics above

The table below reproduces Table S6 in the supplementary material.

In [None]:
methods = ["unstratified_pooled", "unstratified_distributed", 
           "stratified_pooled", "stratified_distributed", 
           "multivariate_meta_analysis", "univariate_meta_analysis"]
stats = ["log(HR)", "ese", "cios"]

header = pd.MultiIndex.from_product([methods, stats])
df = pd.DataFrame(index=range(X_DIM), 
                  columns=header)
for method in methods:
  for stat in stats:
    stat1 = "beta_hat" if stat == "log(HR)" else stat
    val = locals()[f"{method}_{stat1}"]
    df.loc[:, (method, stat)] = val
df

Unnamed: 0_level_0,unstratified_pooled,unstratified_pooled,unstratified_pooled,unstratified_distributed,unstratified_distributed,unstratified_distributed,stratified_pooled,stratified_pooled,stratified_pooled,stratified_distributed,stratified_distributed,stratified_distributed,multivariate_meta_analysis,multivariate_meta_analysis,multivariate_meta_analysis,univariate_meta_analysis,univariate_meta_analysis,univariate_meta_analysis
Unnamed: 0_level_1,log(HR),ese,cios,log(HR),ese,cios,log(HR),ese,cios,log(HR),ese,cios,log(HR),ese,cios,log(HR),ese,cios
0,-0.332484,0.080267,0.95,-0.318657,0.083456,0.946241,-0.333016,0.080276,0.95,-0.318878,0.083761,0.946014,-0.332692,0.080329,0.949999,-0.329679,0.080508,0.949858
1,0.067154,0.093844,0.95,0.082734,0.095312,0.946809,0.068854,0.093859,0.95,0.083851,0.095148,0.947051,0.067213,0.093951,0.95,0.063231,0.094204,0.949796
2,-0.507255,0.221013,0.95,-0.552357,0.239132,0.943626,-0.507198,0.221029,0.95,-0.551326,0.23873,0.943892,-0.509879,0.2213,0.949983,-0.53939,0.221998,0.947579
3,0.102681,0.09677,0.95,0.106871,0.098904,0.949635,0.102361,0.096779,0.95,0.106602,0.098954,0.949624,0.101225,0.096843,0.949974,0.10094,0.097038,0.949961
4,-0.203842,0.150227,0.95,-0.165398,0.164648,0.940532,-0.203771,0.150207,0.95,-0.165354,0.164724,0.940511,-0.1861,0.150179,0.9484,-0.195609,0.150751,0.949653
5,-0.178424,0.086561,0.95,-0.183143,0.090514,0.949028,-0.17753,0.086564,0.95,-0.182357,0.090576,0.948994,-0.177069,0.086711,0.949971,-0.176778,0.086975,0.949951
6,-0.086646,0.083628,0.95,-0.090756,0.086367,0.949395,-0.087013,0.083654,0.95,-0.091277,0.086409,0.949371,-0.087344,0.083615,0.949992,-0.085424,0.083881,0.949973
7,0.020118,0.099569,0.95,0.022392,0.103748,0.949393,0.019626,0.099556,0.95,0.022121,0.103819,0.94936,0.0163,0.09961,0.949832,0.020899,0.099926,0.949989
8,0.246817,0.081014,0.95,0.211235,0.092181,0.925683,0.247245,0.081022,0.95,0.21156,0.092747,0.925249,0.243078,0.081311,0.949753,0.239387,0.081456,0.949032
9,0.064316,0.085833,0.95,0.076371,0.088899,0.947421,0.063576,0.085861,0.95,0.076223,0.089516,0.947057,0.06124,0.085906,0.949853,0.060631,0.086128,0.949786
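
The `log(HR)` and `ese` columns can be turned into a hazard ratio with a Wald-type confidence interval by exponentiating. The sketch below does this for row 0 of the unstratified pooled analysis, which corresponds to the treatment `A` because `A` comes first in `x_headers`. The helper name `hazard_ratio_with_ci` is hypothetical, and the normal-approximation interval is an assumption consistent with the overlap computation above.

```python
import math
from statistics import NormalDist

def hazard_ratio_with_ci(log_hr, ese, level=0.95):
    """Returns (HR, lower, upper) from a log hazard ratio and its standard error."""
    z = NormalDist().inv_cdf(0.5 + level / 2)
    return (math.exp(log_hr),
            math.exp(log_hr - z * ese),
            math.exp(log_hr + z * ese))

# Row 0, unstratified pooled: log(HR) = -0.332484, ese = 0.080267.
hr, lower, upper = hazard_ratio_with_ci(-0.332484, 0.080267)
print(f"HR = {hr:.3f}, 95% CI = ({lower:.3f}, {upper:.3f})")
```

A hazard ratio below 1 with an interval that excludes 1 corresponds to a negative log HR whose interval excludes 0.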
