## 1. Introduction

Estimating the efficacy of real-world clinical interventions involves continuous time-to-event outcomes, such as time to death, re-hospitalization, or a composite event, that may be subject to censoring. Causal reasoning in such settings requires disentangling the effects of confounding physiological characteristics, which affect baseline survival rates, from the effects of the interventions being assessed. In this paper, we present a latent variable approach to modeling heterogeneous treatment effects, positing that each individual belongs to one of several latent clusters with distinct response characteristics. We show that this latent structure can mediate baseline survival rates and help determine the effect of an intervention. On multiple large randomized clinical trials, originally conducted to assess treatment strategies for reducing cardiovascular risk, we demonstrate that our approach can discover actionable phenotypes of individuals based on their response to treatment.
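To make the latent structure concrete, here is a toy sketch of an individual's survival curve as a mixture over a base-survival group `k` and a treatment-effect group `g`, mirroring the `k` and `g` hyper-parameters used later in this notebook. It is an illustration under stated assumptions, not the model's actual parameterization: the exponential baseline hazards, the constant log hazard ratio `psi_g` per treatment-effect group, the independent membership probabilities, and the helper `mixture_survival` are all hypothetical.

```python
import numpy as np

def mixture_survival(t, a, p_base, p_effect, base_rates, log_hazard_ratios):
    """Survival S(t | a) of one individual under a two-level latent mixture.

    Illustrative assumptions only: one exponential baseline hazard per
    base-survival group k, one constant log hazard ratio psi_g per
    treatment-effect group g, and independent probabilities P(k) and P(g).
    """
    t = np.asarray(t, dtype=float)
    surv = np.zeros_like(t)
    for pk, rate in zip(p_base, base_rates):
        for pg, psi in zip(p_effect, log_hazard_ratios):
            # Cumulative hazard of group (k, g): rate * t, scaled by exp(psi) if treated (a = 1).
            cum_hazard = rate * t * np.exp(a * psi)
            surv += pk * pg * np.exp(-cum_hazard)
    return surv

t = np.array([0.0, 1.0, 2.0])
p_base = [0.7, 0.3]       # hypothetical P(base-survival group k | x)
p_effect = [0.5, 0.5]     # hypothetical P(treatment-effect group g | x)
base_rates = [0.1, 0.5]   # hazard rate of each base-survival group
psis = [-0.7, 0.0]        # group 1 benefits from treatment, group 2 is unaffected

treated = mixture_survival(t, 1, p_base, p_effect, base_rates, psis)
control = mixture_survival(t, 0, p_base, p_effect, base_rates, psis)
print(treated - control)  # zero at t = 0, positive afterwards
```

In this toy setup, an individual whose treatment-effect group has `psi_g = 0` gains nothing from treatment regardless of baseline risk, which is the kind of separation the phenotypes are meant to recover.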

## 2. Synthetic Data Example

In [1]:
import sys
sys.path.append('../auton_survival/')

from datasets import load_dataset

# Load the synthetic dataset
outcomes, features, interventions = load_dataset(dataset='SYNTHETIC')

In [2]:
# Let's take a look at the dataset
features.head(5)

|   | X1 | X2 | X3 | X4 | X5 | X6 | X7 | X8 |
|---|----|----|----|----|----|----|----|----|
| 0 | 0.148745 | 1.892484 | 0.195254 | 0.860757 | 0.696523 | 0.483697 | 0.339551 | 0.374794 |
| 1 | 1.139439 | -0.94333 | 0.411054 | 0.179533 | 0.428686 | 0.683057 | 0.600948 | 0.070483 |
| 2 | -0.961237 | 0.782706 | -0.305381 | 0.583576 | 0.157478 | 0.070556 | 0.03459 | 0.776005 |
| 3 | 0.466508 | 0.694348 | -0.249651 | 1.567092 | 0.850959 | 0.416178 | 0.968841 | 0.863598 |
| 4 | -0.249002 | -0.552091 | 1.854651 | -0.466234 | 0.860385 | 0.367184 | 0.954347 | 0.74893 |


In [3]:
# Visualize the dataset
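One simple way to look at the data is a Kaplan-Meier survival curve per intervention arm. The sketch assumes `outcomes` has `time` and `event` columns (as read in the training cell) and that `interventions` is binary (0 = control, 1 = treated); it uses small toy arrays in place of the real data, and `kaplan_meier` is a hypothetical NumPy-only helper.

```python
import numpy as np

def kaplan_meier(time, event):
    """Kaplan-Meier estimate: distinct event times and S(t) just after each of them."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    event_times = np.unique(time[event])
    surv, s = [], 1.0
    for u in event_times:
        at_risk = np.sum(time >= u)           # still under observation just before u
        deaths = np.sum((time == u) & event)  # events recorded exactly at u
        s *= 1.0 - deaths / at_risk
        surv.append(s)
    return event_times, np.array(surv)

# Toy stand-ins for outcomes['time'], outcomes['event'] and interventions.
time = np.array([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 1.5, 2.5, 6.0, 7.0])
event = np.array([1, 1, 0, 1, 0, 1, 1, 0, 1, 0])
arm = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

curves = {a: kaplan_meier(time[arm == a], event[arm == a]) for a in (0, 1)}
for a, (ts, s) in curves.items():
    print(f"arm {a}: times={ts.tolist()} survival={np.round(s, 3).tolist()}")
```

With the real data, passing `outcomes['time'].values` and `outcomes['event'].values` restricted to each arm gives the two curves, and `matplotlib.pyplot.step(ts, s, where='post')` draws each as a step function.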

### Hyper-parameters

In [4]:
# Hyper-parameters
random_seed = 0
test_size = 0.25

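# Split the synthetic data into training and testing data
import numpy as np

np.random.seed(random_seed)
n = features.shape[0]

# Mark a random subset of rows as test data. Note: np.random.randint samples indices
# with replacement, so duplicate draws make the test set smaller than int(n * test_size).
test_idx = np.zeros(n).astype('bool')
test_idx[np.random.randint(n, size=int(n*test_size))] = True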

features_tr = features.iloc[~test_idx] 
outcomes_tr = outcomes.iloc[~test_idx]
interventions_tr = interventions[~test_idx]
print(f'Number of training data points: {len(features_tr)}')

features_te = features.iloc[test_idx] 
outcomes_te = outcomes.iloc[test_idx]
interventions_te = interventions[test_idx]
print(f'Number of test data points: {len(features_te)}')

Number of training data points: 3899
Number of test data points: 1101
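Because `np.random.randint` draws indices with replacement, repeated indices collapse and the test set above holds 1101 rows instead of `int(5000 * 0.25) = 1250` (the total of 5000 comes from the two counts printed above). A minimal sketch of an exact split draws a permutation instead; `split_indices` is a hypothetical helper, and it uses NumPy's `Generator` API rather than the global seed.

```python
import numpy as np

def split_indices(n, test_size, seed):
    """Boolean test mask with exactly int(n * test_size) True entries, drawn without replacement."""
    rng = np.random.default_rng(seed)
    test_idx = np.zeros(n, dtype=bool)
    test_idx[rng.permutation(n)[:int(n * test_size)]] = True
    return test_idx

test_idx = split_indices(5000, 0.25, seed=0)
print(test_idx.sum(), (~test_idx).sum())  # 1250 3750

# For contrast: drawing 1250 indices with replacement yields fewer than 1250 distinct rows.
draws = np.random.default_rng(0).integers(5000, size=1250)
print(len(np.unique(draws)))
```

`sklearn.model_selection.train_test_split(..., test_size=0.25, random_state=0)` is an off-the-shelf alternative. Either change alters which rows land in each split, so downstream numbers would differ from those shown in this notebook.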


In [5]:
# Hyper-parameters to train model

k = 2 # Number of underlying base survival phenotypes
g = 2 # Number of underlying treatment effect phenotypes
layers = [50] # Number of neurons in each hidden layer.

model_random_seed = 0
epochs = 50
lr = 1e-3
bs = 128
vsize = 0.15

In [7]:
import torch

# Set torch and numpy random seeds
torch.manual_seed(model_random_seed)
np.random.seed(model_random_seed)

# Convert training data into float32 NumPy arrays
x = features_tr.values.astype('float32')
t = outcomes_tr['time'].values.astype('float32')
e = outcomes_tr['event'].values.astype('float32')
a = interventions_tr.values.astype('float32')
print(f'Shape of covariates: {x.shape} | times: {t.shape} | events: {e.shape} | interventions: {a.shape}')

from models.cmhe import DeepCoxMixturesHeterogenousEffects

# Instantiate the CMHE model
model = DeepCoxMixturesHeterogenousEffects(k=k, g=g, layers=layers)

# Fit using the hyper-parameters defined above instead of repeating their values
model = model.fit(x, t, e, a, vsize=vsize, val_data=None, iters=epochs, learning_rate=lr,
                  batch_size=bs, optimizer="Adam", random_state=model_random_seed)

Shape of covariates: (3899, 8) | times: (3899,) | events: (3899,) | interventions: (3899,)


UnboundLocalError: local variable 'act' referenced before assignment
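Cox-type models commonly pair fitted risk scores with a Breslow estimate of the baseline cumulative hazard; the training fragment in the next cell returns `breslow_splines`, which suggests a spline-smoothed version of such an estimate. The sketch shows the plain Breslow estimator for a single Cox model, $\hat H_0(t) = \sum_{t_i \le t} d_i / \sum_{j:\, t_j \ge t_i} \exp(\eta_j)$, where $d_i$ is the number of events at $t_i$ and $\eta_j$ is subject $j$'s risk score. `breslow_cumulative_hazard` is a hypothetical helper, not the library's implementation.

```python
import numpy as np

def breslow_cumulative_hazard(time, event, risk_score):
    """Breslow estimate of the baseline cumulative hazard H0 at each distinct event time.

    risk_score holds the linear predictor eta_i, so subject i's hazard is h0(t) * exp(eta_i).
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=bool)
    exp_eta = np.exp(np.asarray(risk_score, dtype=float))
    event_times = np.unique(time[event])
    increments = []
    for u in event_times:
        deaths = np.sum((time == u) & event)   # d_i: events at time u
        risk_set = exp_eta[time >= u].sum()    # sum of exp(eta_j) over subjects at risk at u
        increments.append(deaths / risk_set)
    return event_times, np.cumsum(increments)

times, H0 = breslow_cumulative_hazard([1.0, 2.0, 3.0], [1, 1, 1], [0.0, np.log(2.0), 0.0])
print(times, H0)  # increments are 1/4, 1/3 and 1/1
```

With all risk scores equal to zero the estimator reduces to Nelson-Aalen, and an individual's survival follows as $S(t \mid x) = \exp\big(-\hat H_0(t)\, e^{\eta(x)}\big)$.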

In [None]:
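# Training
# This cell was pasted from the body of a function, so it is wrapped in one here to make
# the `return` statements valid. The helpers `train`, `_find_subgroup`,
# `_load_estimated_counterfactuals`, `_convert_to_torch`, `DeepCoxSubgroupMixture` and
# `CoxSubgroupMixture` come from the original research code and must already be
# defined; they are not imported in this notebook.

def train_and_find_subgroups(train_data, val_data, *, k, g, layers, epochs, lr, bs,
                             patience, smoothing_factor, use_cf_evaluation,
                             cf_path=None, return_model=False):
    # Training covariates and interventions as torch tensors.
    x, t, e, a = _convert_to_torch(features_tr.values, outcomes_tr['time'].values,
                                   outcomes_tr['event'].values, interventions_tr.values)

    # Neural variant with one hidden layer of layers[0] units, else the linear variant.
    if len(layers):
        model = DeepCoxSubgroupMixture(k=k, g=g, inputdim=x.shape[1], hidden=layers[0]).float()
    else:
        model = CoxSubgroupMixture(k=k, g=g, inputdim=x.shape[1]).float()

    (model, breslow_splines), losses = train(model, train_data, val_data,
                                             epochs=epochs, lr=lr, use_posteriors=True,
                                             patience=patience, return_losses=True, bs=bs,
                                             smoothing_factor=smoothing_factor)
    if return_model:
        return model, breslow_splines

    print("Treatment Effects:", model.treatment_effect)

    # Posterior over treatment-effect groups: model(x, a)[0] is assumed to hold joint
    # log-probabilities over (base-survival, treatment-effect) groups, so summing over
    # dim=1 marginalizes out the base-survival group.
    zeta_probs_train = torch.exp(model(x, a)[0]).sum(dim=1).detach().numpy()
    zeta_train = np.argmax(zeta_probs_train, axis=1)

    # Indices of the groups with the largest (MAX=True) and smallest (MAX=False) effect,
    # using either estimated counterfactual outcomes or the observed (factual) outcomes.
    if use_cf_evaluation:
        treated_outcomes_train, control_outcomes_train, _, _ = _load_estimated_counterfactuals(cf_path)
        cf_outcomes = (treated_outcomes_train, control_outcomes_train)
        max_treat_idx = _find_subgroup(zeta_probs_train, counterfactual_outcomes=cf_outcomes, MAX=True)
        min_treat_idx = _find_subgroup(zeta_probs_train, counterfactual_outcomes=cf_outcomes, MAX=False)
    else:
        factual_outcomes = (outcomes_tr, interventions_tr)
        max_treat_idx = _find_subgroup(zeta_probs_train, factual_outcomes=factual_outcomes, MAX=True)
        min_treat_idx = _find_subgroup(zeta_probs_train, factual_outcomes=factual_outcomes, MAX=False)

    # Same posterior on the held-out test split.
    x_te, t_te, e_te, a_te = _convert_to_torch(features_te.values, outcomes_te['time'].values,
                                               outcomes_te['event'].values, interventions_te.values)
    zeta_probs_test = torch.exp(model(x_te, a_te)[0]).sum(dim=1).detach().numpy()

    # Membership probabilities of the max/min treatment-effect groups as (train, test).
    max_subgroup_probs = (zeta_probs_train[:, max_treat_idx], zeta_probs_test[:, max_treat_idx])
    min_subgroup_probs = (zeta_probs_train[:, min_treat_idx], zeta_probs_test[:, min_treat_idx])
    return max_subgroup_probs, min_subgroup_probs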
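The expression `torch.exp(model(x, a)[0]).sum(dim=1)` turns the model's output into probabilities over treatment-effect groups. Assuming, as the fragment suggests but does not confirm, that `model(x, a)[0]` holds joint log-posteriors of shape `(n, k, g)`, the operation marginalizes out the base-survival group. A NumPy sketch with made-up probabilities:

```python
import numpy as np

# Hypothetical joint posteriors for n = 2 individuals over k = 2 base-survival groups
# and g = 2 treatment-effect groups, shaped (n, k, g); the model would output their logs.
joint = np.array([[[0.40, 0.10], [0.30, 0.20]],
                  [[0.05, 0.25], [0.10, 0.60]]])
log_joint = np.log(joint)

# Equivalent of torch.exp(log_joint).sum(dim=1): sum out the base-survival axis,
# leaving P(treatment-effect group | x, a) with shape (n, g).
zeta_probs = np.exp(log_joint).sum(axis=1)
zeta = np.argmax(zeta_probs, axis=1)   # most probable treatment-effect group per individual
print(zeta_probs.round(2))
print(zeta)  # [0 1]
```

Columns of `zeta_probs` are what `max_treat_idx` and `min_treat_idx` select from to produce the per-individual subgroup membership probabilities.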

In [None]:
# Train the clustering phenotyper