<a href="https://colab.research.google.com/github/dli-stats/distributed_cox_paper_repro/blob/main/distributed_cox_paper_repro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install Packages and Dependencies

To start, we will install our package [distributed_cox](https://github.com/dli-stats/distributed_cox/tree/master) and all its dependencies. 

Note that for reproducing the experiments, we install a [frozen version pointed to by the git tag `repro`](https://github.com/dli-stats/distributed_cox/releases/tag/repro).

In [1]:
!pip install --find-links https://storage.googleapis.com/jax-releases/jax_releases.html git+https://github.com/dli-stats/distributed_cox.git@repro

Looking in links: https://storage.googleapis.com/jax-releases/jax_releases.html
Collecting git+https://github.com/dli-stats/distributed_cox.git@repro
  Cloning https://github.com/dli-stats/distributed_cox.git (to revision repro) to /tmp/pip-req-build-sojuo97c
  Running command git clone -q https://github.com/dli-stats/distributed_cox.git /tmp/pip-req-build-sojuo97c
  Running command git checkout -q a2dab476da42fac7522ea2c4b17caf9e41483f88
Collecting oryx==0.1.4
  Downloading oryx-0.1.4.tar.gz (91 kB)
     |████████████████████████████████| 91 kB 10.5 MB/s
Collecting jax==0.2.9
  Downloading jax-0.2.9.tar.gz (551 kB)
     |████████████████████████████████| 551 kB 47.2 MB/s
Collecting jaxlib==0.1.59
  Downloading https://storage.googleapis.com/jax-releases/cuda111/jaxlib-0.1.59%2Bcuda111-cp37-none-manylinux2010_x86_64.whl (182.4 MB)
     |████████████████████████████████| 182.4 MB 79 kB/s
Collecting dataclasses_json
  Downloading dataclasses_json-0.5.6-py3-none-an

# Data Preparation

In this section, we will download the simulated dataset and prepare it into the desired format.

First, download the raw simulated dataset (simulated to mimic the real data) from GitHub.

In [2]:
!git clone https://github.com/dli-stats/distributed_cox_paper_simudata simulated_data

Cloning into 'simulated_data'...
remote: Enumerating objects: 12, done.
remote: Counting objects: 100% (12/12), done.
remote: Compressing objects: 100% (9/9), done.
remote: Total 12 (delta 1), reused 12 (delta 1), pack-reused 0
Unpacking objects: 100% (12/12), done.


Now load the data.

In [3]:
import pandas as pd

df = pd.read_csv("simulated_data/dat_std_simulated.csv", index_col=0)
df

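| | time | status | A | X2 | X3 | X6 | X8 | X9 | X11 | X12 | X13 | X16 | X24 | X26 | indDP |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 22 | 1 | 1 | 0.138245 | 1 | 0 | 0 | 1 | 0 | 0 | 1 | 0 | -1.320056 | 0.147349 | 2 |
| 2 | 30 | 0 | 0 | 0.325395 | 0 | 0 | 0 | 0 | 0 | 1 | 1 | 0 | 0.407715 | 0.184558 | 3 |
| 3 | 30 | 0 | 1 | 1.120127 | 1 | 1 | 1 | 0 | 0 | 1 | 0 | 1 | -0.427274 | 1.199993 | 1 |
| 4 | 30 | 0 | 1 | 0.771549 | 1 | 1 | 1 | 1 | 0 | 1 | 0 | 1 | 0.536396 | 0.961339 | 2 |
| 5 | 30 | 0 | 1 | -1.787154 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | -0.403683 | -1.468517 | 3 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 11575 | 30 | 0 | 1 | 1.844121 | 1 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0.110043 | 1.717636 | 2 |
| 11576 | 30 | 0 | 1 | -0.063699 | 1 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | -1.249988 | 0.000226 | 1 |
| 11577 | 30 | 0 | 1 | 0.956254 | 1 | 0 | 0 | 0 | 1 | 0 | 1 | 1 | 0.983444 | 0.799436 | 2 |
| 11578 | 30 | 0 | 0 | -0.606473 | 1 | 0 | 1 | 0 | 1 | 1 | 1 | 0 | -2.268432 | -0.584255 | 1 |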


As we can see, besides the row index, the CSV data file contains 15 columns:
- **time**: time of the event or censoring time
- **status**: event indicator
- **A**: treatment
- **X2**, **X3**, **X6**, ..., **X26**: 11 covariates, corresponding to the covariates used in the real data analysis
- **indDP**: site indicator
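
The covariate names skip numbers (X2, X3, X6, ...), so the conversion code below orders them by their numeric suffix rather than alphabetically. The stdlib-only sketch that follows, using the column names displayed in the table above, illustrates why a plain string sort would not do, and that "A" plus the covariates gives 12 predictors, matching the `X_DIM` reported later. It uses `re.fullmatch`, which is slightly stricter than the `re.match` check in `convert_from_df`; the variable names are illustrative only.

```python
import re

# Column names of the simulated data, as displayed above.
columns = ["time", "status", "A", "X2", "X3", "X6", "X8", "X9", "X11",
           "X12", "X13", "X16", "X24", "X26", "indDP"]

# Covariate columns are "X" followed by one or more digits.
covariates = [c for c in columns if re.fullmatch(r"X\d+", c)]

# Plain string sorting would place "X11" before "X2" ...
lexicographic = sorted(covariates)
# ... so sort on the integer suffix instead.
numeric = sorted(covariates, key=lambda s: int(s[1:]))

predictors = ["A"] + numeric  # treatment first, then covariates in numeric order
print(lexicographic[:3])
print(predictors, len(predictors))
```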

The first step in reproducing the results is to read this raw simulated data into a format that our code can consume.

We next convert the data into the following arrays:
- `X`: an N × X_DIM matrix containing the predictors (the treatment **A** followed by the covariates in numeric order), where N is the sample size and X_DIM is the number of predictors
- `delta`: a length-N boolean vector, the event indicator (**status**)
- `group_labels`: a length-N vector of zero-based site labels (**indDP** - 1)
- `T`: a length-N vector of observed times (**time**)

All four arrays are sorted by decreasing observed time.
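
`convert_from_df` (next cell) sorts every array by decreasing time, and the analysis functions later in this notebook receive `X`, `delta`, `beta_guess` and `group_labels` but not `T`. Our reading, which is an assumption and not something the package documents here, is that the row order carries the time information: with rows sorted by decreasing time and no tied times, the Cox risk set of row i, {j : T_j >= T_i}, is simply rows 0..i, so risk-set sums become cumulative sums. The sketch below checks this on random toy data; both function names are hypothetical, and ties (which the real data has) are deliberately avoided.

```python
import numpy as np

def cox_loglik_bruteforce(beta, X, T, delta):
    """Log partial likelihood with explicit risk sets {j : T_j >= T_i}."""
    eta = X @ beta
    total = 0.0
    for i in range(len(T)):
        if delta[i]:
            at_risk = T >= T[i]
            total += eta[i] - np.log(np.sum(np.exp(eta[at_risk])))
    return total

def cox_loglik_sorted(beta, X, delta):
    """Same quantity, assuming rows sorted by decreasing time and no ties.

    Row i's risk set is then rows 0..i, so each log-denominator is a
    running log-sum-exp over the prefix.
    """
    eta = X @ beta
    log_risk = np.logaddexp.accumulate(eta)
    return np.sum(np.where(delta, eta - log_risk, 0.0))

rng = np.random.default_rng(0)
N, p = 50, 3
X = rng.normal(size=(N, p))
T = rng.exponential(size=N)       # continuous times: ties have probability 0
delta = rng.random(N) < 0.7       # roughly 70% observed events
beta = np.array([0.5, -0.2, 0.1])

order = np.argsort(-T)            # same ordering as convert_from_df
ll_sorted = cox_loglik_sorted(beta, X[order], delta[order])
ll_brute = cox_loglik_bruteforce(beta, X, T, delta)
print(ll_sorted, ll_brute)
```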

In [4]:
import re
import numpy as np
import jax.numpy as jnp

def convert_from_df(dataframe: pd.DataFrame):
  x_headers = ["A"] + list(
      sorted(
          (k for k in dataframe.keys() if re.match(r'X\d+', k) and len(k) > 1),
          key=lambda s: int(s[1:])))
  X = dataframe[x_headers]
  delta = dataframe["status"]
  T = dataframe["time"]
  group_labels = dataframe["indDP"] - 1
  X, delta, T, group_labels = map(lambda x: x.to_numpy(),
                                  (X, delta, T, group_labels))
  X = X.astype(np.float32)
  T = T.astype(np.float32)
  delta = delta.astype(bool)
  sorted_idx = np.argsort(-T)
  T = np.take(T, sorted_idx, axis=0)
  X = np.take(X, sorted_idx, axis=0)
  delta = np.take(delta, sorted_idx, axis=0)
  group_labels = np.take(group_labels, sorted_idx, axis=0)

  return X, delta, group_labels, T

X, delta, group_labels, T = convert_from_df(df)
N, X_DIM = X.shape
K = max(group_labels) + 1 # Number of sites
group_sizes = tuple(np.sum(group_labels == k) for k in range(K))

print("group sizes: ", group_sizes)

print("X shape: ", X.shape)
print("delta shape:", delta.shape)
print("group_labels shape:", group_labels.shape)
print("T shape:", T.shape)

group sizes:  (3825, 3921, 3833)
X shape:  (11579, 12)
delta shape: (11579,)
group_labels shape: (11579,)
T shape: (11579,)


# Analysis

We first import all the source functions needed to:
 
  1. perform the analysis to obtain the log HR estimate for each analysis method, and
  2. compute the corresponding estimated variance. 
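
For intuition about both steps, here is a minimal NumPy sketch of a Newton-Raphson fit of the pooled, unstratified Cox model, followed by the usual model-based variance estimate (the inverse of the observed information at the solution). The stopping parameters mirror the names in `solver_config` in the next cell, but how `cox_solve` actually combines them is an assumption: here we stop once both the log-likelihood change and the score norm are below their thresholds. The function names are hypothetical, rows are assumed sorted by decreasing time with no ties, and the distributed, stratified and meta-analysis variants (and the several `cov:` estimates used later) are not attempted.

```python
import numpy as np

def cox_derivatives(beta, X, delta):
    """Log partial likelihood, score and observed information (= -Hessian).

    Assumes rows sorted by decreasing time with no ties, so the risk set
    of row i is rows 0..i and risk-set sums are cumulative sums.
    """
    eta = X @ beta
    w = np.exp(eta)
    s0 = np.cumsum(w)                                                          # (N,)
    s1 = np.cumsum(w[:, None] * X, axis=0)                                     # (N, p)
    s2 = np.cumsum(w[:, None, None] * X[:, :, None] * X[:, None, :], axis=0)   # (N, p, p)
    xbar = s1 / s0[:, None]
    d = delta.astype(float)
    loglik = np.sum(d * (eta - np.log(s0)))
    score = np.sum(d[:, None] * (X - xbar), axis=0)
    info = np.sum(d[:, None, None] * (s2 / s0[:, None, None]
                                      - xbar[:, :, None] * xbar[:, None, :]), axis=0)
    return loglik, score, info

def newton_raphson(X, delta, beta0, max_num_steps=40, loglik_eps=1e-5, score_norm_eps=1e-3):
    beta = np.array(beta0, dtype=float)
    loglik, score, info = cox_derivatives(beta, X, delta)
    for _ in range(max_num_steps):
        beta = beta + np.linalg.solve(info, score)   # Newton step for a maximum
        new_loglik, score, info = cox_derivatives(beta, X, delta)
        small_change = abs(new_loglik - loglik) < loglik_eps
        loglik = new_loglik
        if small_change and np.linalg.norm(score) < score_norm_eps:
            break
    cov = np.linalg.inv(info)   # model-based covariance of beta_hat
    return beta, cov

# Toy data: exponential event times with hazard exp(X @ true_beta), random censoring.
rng = np.random.default_rng(1)
N, p = 300, 2
X = rng.normal(size=(N, p))
true_beta = np.array([0.5, -0.3])
event_time = rng.exponential(1.0 / np.exp(X @ true_beta))
censor_time = rng.exponential(2.0, size=N)
T = np.minimum(event_time, censor_time)
delta = event_time <= censor_time
order = np.argsort(-T)          # sort by decreasing time
X, delta = X[order], delta[order]

beta_hat, cov = newton_raphson(X, delta, np.zeros(p))
ese = np.sqrt(np.diag(cov))
print(beta_hat, ese)
```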

In [5]:
import functools
from jax import jit
import distributed_cox.cox_solve as cox_solve

# We use Newton-Raphson as our solver;
# these are its configurations.
solver_config = dict(
    max_num_steps=40, # Sufficiently large number of steps
    loglik_eps=1e-5, 
    score_norm_eps=1e-3
)

# Some configurations for the distributed settings
distributed_config = dict(
  pt2_use_average_guess=True,
  hessian_use_taylor=True,
  taylor_order=1,
)

def get_solve_and_cov_fn(method:str, **kwargs):
   return jit(cox_solve.get_cox_solve_and_cov_fn(method, 
                                                 group_sizes=group_sizes, 
                                                 solver=solver_config, 
                                                 **kwargs))

# Functions for the 6 analysis methods.
# ========================================================
# These are just convenience functions used for reproducing the results.
# They all take in X, delta, beta_guess and group_labels, and return
# a solution together with the analytical covariances.
# The observed times T are not passed in, so the time information enters
# only through the row order (hence the sort in `convert_from_df`).
#
# Note that for a practical distributed algorithm, one will not have access
# to all the predictors X, which would be sharded/distributed across sites.
# The subpackage `distributed_cox.distributed` contains the functions needed
# to perform such distributed inference with a message-passing protocol.
#
# However, for the sake of reproducing the numbers in Table S8, the convenience
# functions below, which compute the solution and covariances in one shot,
# are sufficient.

unstratified_pooled = get_solve_and_cov_fn("unstratified_pooled")
unstratified_distributed = get_solve_and_cov_fn("unstratified_distributed", 
                                                distributed=distributed_config)

stratified_pooled = get_solve_and_cov_fn("stratified_pooled")
stratified_distributed = get_solve_and_cov_fn("stratified_distributed", 
                                              distributed=distributed_config)

multivariate_meta_analysis = get_solve_and_cov_fn("meta_analysis", 
                                meta_analysis=dict(univariate=False))
univariate_variate_meta_analysis = get_solve_and_cov_fn("meta_analysis", 
                                meta_analysis=dict(univariate=True))

# An initial guess used for all the solvers
beta_guess = np.zeros(X_DIM)
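
The two meta-analysis variants above differ only in the `univariate` flag. We have not checked how `cox_solve` combines the per-site estimates; a common textbook choice, assumed here, is fixed-effect inverse-variance weighting. Under that assumption, the multivariate version pools the per-site estimates with their full covariance matrices, beta = (Σ_k V_k⁻¹)⁻¹ Σ_k V_k⁻¹ beta_k, while the univariate version pools each coefficient separately using only the per-site variances. The sketch below uses hypothetical per-site numbers and hypothetical function names.

```python
import numpy as np

# Fixed-effect inverse-variance pooling of per-site estimates
# (a textbook choice, assumed here; not necessarily what cox_solve implements).

def multivariate_meta(betas, covs):
    """Combine site estimates using their full covariance matrices."""
    precisions = [np.linalg.inv(c) for c in covs]
    total_precision = sum(precisions)
    cov = np.linalg.inv(total_precision)
    beta = cov @ sum(P @ b for P, b in zip(precisions, betas))
    return beta, cov

def univariate_meta(betas, covs):
    """Combine each coefficient separately, using only the site variances."""
    weights = np.array([1.0 / np.diag(c) for c in covs])  # (K, p)
    betas = np.asarray(betas)                              # (K, p)
    beta = np.sum(weights * betas, axis=0) / np.sum(weights, axis=0)
    var = 1.0 / np.sum(weights, axis=0)
    return beta, np.diag(var)

# Hypothetical estimates from K = 3 sites for p = 2 coefficients.
betas = [np.array([0.2, -0.1]), np.array([0.4, 0.0]), np.array([0.3, -0.2])]
covs = [np.array([[0.04, 0.01], [0.01, 0.09]]),
        np.diag([0.05, 0.08]),
        np.array([[0.03, -0.005], [-0.005, 0.06]])]

beta_mv, cov_mv = multivariate_meta(betas, covs)
beta_uv, cov_uv = univariate_meta(betas, covs)
print(beta_mv, np.sqrt(np.diag(cov_mv)))
print(beta_uv, np.sqrt(np.diag(cov_uv)))
```

When every site covariance is diagonal, the two variants coincide; they differ only through the off-diagonal (between-coefficient) covariances.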



We now run all six analysis methods (unstratified pooled, unstratified distributed, stratified pooled, stratified distributed, multivariate meta-analysis, and univariate meta-analysis).

The first time you run the cell below, each method is JIT-compiled, which is slow: in the run recorded below, compilation took between about 14 s and 115 s per method, while execution took at most 0.12 s. Once compiled, the computation is fast.

For example, if you run the cell below twice, the second run will be much faster, because the compiled functions are reused.

In [6]:
import time
import jax
import jax.tree_util as tu

def run_and_wait(name: str, analysis):
  """Executes `analysis` and wait for the result, also performs some logging."""
  print(f"Running {name}Analysis... ", end="", flush=True)
  start_time = time.time()
  result = analysis(X, delta, beta_guess, group_labels)
  compile_end = time.time()
  print(f"Compile finished in {compile_end - start_time:.2f}s.. ", 
        end="", flush=True)
  tu.tree_map(lambda x: x.block_until_ready(), result)
  execution_end = time.time()
  print(f"Execution finished in {execution_end - compile_end:.2f}s.. ", 
        end="", flush=True)
  print(f"Done.", flush=True)
  return result

unstratified_pooled_sol      = run_and_wait("Unstratified Pooled ", unstratified_pooled)
unstratified_distributed_sol = run_and_wait("Unstratified Distributed ", unstratified_distributed)
stratified_pooled_sol        = run_and_wait("Stratified Pooled ", stratified_pooled)
stratified_distributed_sol   = run_and_wait("Stratified Distributed ", stratified_distributed)
multivariate_meta_analysis_sol = run_and_wait("Multivariate Meta-", multivariate_meta_analysis)
univariate_meta_analysis_sol   = run_and_wait("Univariate Meta-", univariate_variate_meta_analysis)

Running Unstratified Pooled Analysis... Compile finished in 14.20s.. Execution finished in 0.05s.. Done.
Running Unstratified Distributed Analysis... Compile finished in 69.05s.. Execution finished in 0.10s.. Done.
Running Stratified Pooled Analysis... Compile finished in 37.68s.. Execution finished in 0.03s.. Done.
Running Stratified Distributed Analysis... Compile finished in 115.09s.. Execution finished in 0.12s.. Done.
Running Multivariate Meta-Analysis... Compile finished in 30.22s.. Execution finished in 0.03s.. Done.
Running Univariate Meta-Analysis... Compile finished in 32.39s.. Execution finished in 0.03s.. Done.


# Extract Results

Finally, we pull the results from the 6 analysis methods and format them into the table we want.

First, let's define some helpers to extract the results.

In [7]:
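import jax.scipy.stats  # provides jax.scipy.stats.norm.cdf used below
import scipy.stats  # `import scipy` alone does not guarantee scipy.stats is loaded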

# Some helpers for collecting the results of interest

get_beta_hat = lambda sol: sol.pt2.guess

def get_ese(sol, cov_key):
  return jnp.sqrt(jnp.diag(sol.covs[cov_key]))

def compute_confidence_interval_overlap(beta_eq,
                                        std_eq,
                                        beta_eq_true,
                                        std_eq_true,
                                        lb=0.025,
                                        ub=1 - 0.025):
  f_orig_cdf = functools.partial(jax.scipy.stats.norm.cdf,
                                 loc=beta_eq_true,
                                 scale=std_eq_true)
  f_rel_cdf = functools.partial(jax.scipy.stats.norm.cdf,
                                loc=beta_eq,
                                scale=std_eq)
  L_rel = beta_eq + scipy.stats.norm.ppf(lb) * std_eq
  U_rel = beta_eq + scipy.stats.norm.ppf(ub) * std_eq
  L_orig = beta_eq_true + scipy.stats.norm.ppf(lb) * std_eq_true
  U_orig = beta_eq_true + scipy.stats.norm.ppf(ub) * std_eq_true

  I = ((f_orig_cdf(U_rel) - f_orig_cdf(L_rel)) +
       (f_rel_cdf(U_orig) - f_rel_cdf(L_orig))) / 2

  return I
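
The helper above computes a probability-based confidence interval overlap: I = ½[(Φ_ref(U_rel) − Φ_ref(L_rel)) + (Φ_rel(U_ref) − Φ_rel(L_ref))], where Φ_ref and Φ_rel are normal CDFs centred at the reference and the assessed estimates, and (L, U) are their 95% intervals. The stdlib-only sketch below mirrors that formula with `statistics.NormalDist` in place of `jax.scipy`/`scipy` (a swap for illustration; `ci_overlap` is a hypothetical name). One consequence worth noting: comparing a method with itself gives 1 − 2 × 0.025 = 0.95 rather than 1, which is what the pooled-versus-pooled CIOs computed below will show.

```python
from statistics import NormalDist

def ci_overlap(beta, se, beta_ref, se_ref, lb=0.025, ub=0.975):
    """Probability-based CI overlap, mirroring compute_confidence_interval_overlap."""
    rel = NormalDist(beta, se)          # method being assessed
    ref = NormalDist(beta_ref, se_ref)  # reference (pooled) method
    z_lo, z_hi = NormalDist().inv_cdf(lb), NormalDist().inv_cdf(ub)
    L_rel, U_rel = beta + z_lo * se, beta + z_hi * se
    L_ref, U_ref = beta_ref + z_lo * se_ref, beta_ref + z_hi * se_ref
    return ((ref.cdf(U_rel) - ref.cdf(L_rel)) + (rel.cdf(U_ref) - rel.cdf(L_ref))) / 2

# Comparing a method with itself gives 0.95 (up to rounding), not 1.
print(ci_overlap(-0.343323, 0.080176, -0.343323, 0.080176))
# Row 0 of the results table: unstratified distributed vs. unstratified pooled.
print(ci_overlap(-0.340499, 0.081007, -0.343323, 0.080176))
```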

Now we extract all the result values.

In [8]:
# Collect beta_hat
unstratified_pooled_beta_hat      = get_beta_hat(unstratified_pooled_sol)
unstratified_distributed_beta_hat = get_beta_hat(unstratified_distributed_sol)
stratified_pooled_beta_hat        = get_beta_hat(stratified_pooled_sol)
stratified_distributed_beta_hat   = get_beta_hat(stratified_distributed_sol)
multivariate_meta_analysis_beta_hat  = get_beta_hat(multivariate_meta_analysis_sol)
univariate_meta_analysis_beta_hat    = get_beta_hat(univariate_meta_analysis_sol)

# Collect the covariance of interest
unstratified_pooled_ese = get_ese(unstratified_pooled_sol, "cov:no_group_correction|no_sandwich|no_cox_correction|no_sum_first")
unstratified_distributed_ese = get_ese(unstratified_distributed_sol, "cov:group_correction|no_sandwich|no_cox_correction|no_sum_first")
stratified_pooled_ese = get_ese(stratified_pooled_sol, "cov:no_group_correction|no_sandwich|no_cox_correction|no_sum_first")
stratified_distributed_ese = get_ese(stratified_distributed_sol, "cov:group_correction|no_sandwich|no_cox_correction|no_sum_first")
multivariate_meta_analysis_ese = get_ese(multivariate_meta_analysis_sol, "cov:meta_analysis")
univariate_meta_analysis_ese = get_ese(univariate_meta_analysis_sol, "cov:meta_analysis")

# Compute the confidence interval overlap (CIO) with the corresponding pooled analysis
# as reference (NOTE: not applicable to real data in practice, because pooled data is not available)
unstratified_pooled_cios = compute_confidence_interval_overlap(unstratified_pooled_beta_hat, 
                                                              unstratified_pooled_ese, 
                                                              unstratified_pooled_beta_hat, 
                                                              unstratified_pooled_ese)
unstratified_distributed_cios = compute_confidence_interval_overlap(unstratified_distributed_beta_hat, 
                                                                   unstratified_distributed_ese, 
                                                                   unstratified_pooled_beta_hat, 
                                                                   unstratified_pooled_ese)
stratified_pooled_cios = compute_confidence_interval_overlap(stratified_pooled_beta_hat, 
                                                            stratified_pooled_ese, 
                                                            stratified_pooled_beta_hat, 
                                                            stratified_pooled_ese)
stratified_distributed_cios = compute_confidence_interval_overlap(stratified_distributed_beta_hat, 
                                                                 stratified_distributed_ese, 
                                                                 stratified_pooled_beta_hat, 
                                                                 stratified_pooled_ese)
multivariate_meta_analysis_cios = compute_confidence_interval_overlap(multivariate_meta_analysis_beta_hat, 
                                                                     multivariate_meta_analysis_ese, 
                                                                     unstratified_pooled_beta_hat, 
                                                                     unstratified_pooled_ese)
univariate_meta_analysis_cios = compute_confidence_interval_overlap(univariate_meta_analysis_beta_hat, 
                                                                   univariate_meta_analysis_ese, 
                                                                   unstratified_pooled_beta_hat, 
                                                                   unstratified_pooled_ese)

You can inspect each of these statistics:

In [9]:
stratified_pooled_ese # change to any of the statistics above

DeviceArray([0.08018788, 0.2206358 , 0.09646697, 0.14364463, 0.08645597,
             0.08289067, 0.09646788, 0.08069014, 0.08574339, 0.10069784,
             0.04091563, 0.22263509], dtype=float32)

Here is a tabular view of the results, corresponding to Table S8 in the supplementary material.

In [12]:
methods = ["unstratified_pooled", "unstratified_distributed", 
           "stratified_pooled", "stratified_distributed", 
           "multivariate_meta_analysis", "univariate_meta_analysis"]
stats = ["log(HR)", "ese",]# "cios"]

header = pd.MultiIndex.from_product([methods, stats])
df = pd.DataFrame(index=range(X_DIM), 
                  columns=header)
for method in methods:
  for stat in stats:
    stat1 = "beta_hat" if stat == "log(HR)" else stat
    val = locals()[f"{method}_{stat1}"]
    df.loc[:, (method, stat)] = val
df

| | unstratified_pooled log(HR) | unstratified_pooled ese | unstratified_distributed log(HR) | unstratified_distributed ese | stratified_pooled log(HR) | stratified_pooled ese | stratified_distributed log(HR) | stratified_distributed ese | multivariate_meta_analysis log(HR) | multivariate_meta_analysis ese | univariate_meta_analysis log(HR) | univariate_meta_analysis ese |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | -0.343323 | 0.080176 | -0.340499 | 0.081007 | -0.343969 | 0.080188 | -0.341047 | 0.081075 | -0.343627 | 0.080241 | -0.341394 | 0.080341 |
| 1 | -0.483166 | 0.220629 | -0.490233 | 0.230544 | -0.482762 | 0.220636 | -0.489872 | 0.230379 | -0.485822 | 0.220835 | -0.512128 | 0.221369 |
| 2 | 0.091849 | 0.096452 | 0.088991 | 0.097825 | 0.091236 | 0.096467 | 0.088371 | 0.097917 | 0.088855 | 0.096572 | 0.086991 | 0.09669 |
| 3 | -0.21686 | 0.143656 | -0.196638 | 0.150693 | -0.217094 | 0.143645 | -0.197192 | 0.150659 | -0.201172 | 0.143764 | -0.199869 | 0.144232 |
| 4 | -0.169933 | 0.086452 | -0.168633 | 0.088631 | -0.169386 | 0.086456 | -0.168204 | 0.088632 | -0.168323 | 0.086576 | -0.165982 | 0.086707 |
| 5 | -0.094372 | 0.082876 | -0.095029 | 0.083361 | -0.094515 | 0.082891 | -0.095198 | 0.083378 | -0.094854 | 0.082917 | -0.092455 | 0.083104 |
| 6 | 0.086915 | 0.096469 | 0.085723 | 0.09758 | 0.086298 | 0.096468 | 0.085105 | 0.097586 | 0.084525 | 0.096533 | 0.085843 | 0.096717 |
| 7 | 0.260662 | 0.080684 | 0.25132 | 0.084263 | 0.260858 | 0.08069 | 0.251525 | 0.084258 | 0.256377 | 0.080981 | 0.252335 | 0.081093 |
| 8 | 0.076135 | 0.085722 | 0.07743 | 0.086302 | 0.075584 | 0.085743 | 0.076891 | 0.08641 | 0.074077 | 0.08577 | 0.07463 | 0.085936 |
| 9 | -0.097749 | 0.100714 | -0.089427 | 0.103719 | -0.097325 | 0.100698 | -0.089142 | 0.103668 | -0.090232 | 0.100895 | -0.106721 | 0.101159 |
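
Row 0 of the table is the first column of `X`, the treatment **A** (see `x_headers` in `convert_from_df`), and the table reports estimates on the log hazard ratio scale. To read an entry as a hazard ratio, one common post-processing step, not part of the reproduced Table S8, is to exponentiate the estimate and a normal-approximation (Wald) interval, log(HR) ± z × ese. A minimal stdlib sketch, with `hazard_ratio_ci` as a hypothetical helper name:

```python
import math
from statistics import NormalDist

def hazard_ratio_ci(log_hr, ese, level=0.95):
    """Exponentiate a log hazard ratio and its Wald (normal-approximation) CI."""
    z = NormalDist().inv_cdf(0.5 + level / 2)   # about 1.96 for level=0.95
    return (math.exp(log_hr),
            math.exp(log_hr - z * ese),
            math.exp(log_hr + z * ese))

# Row 0 (treatment A), unstratified_pooled columns of the table above.
hr, lower, upper = hazard_ratio_ci(-0.343323, 0.080176)
print(f"HR = {hr:.3f}, 95% CI = ({lower:.3f}, {upper:.3f})")
```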
