In [8]:
%pip install -U "bart-survival" "arviz" "xarray" "pymc" "pymc-bart" "numpy<2.0" "scipy<1.11"


Collecting pymc-bart
  Using cached pymc_bart-0.10.0-py3-none-any.whl.metadata (4.6 kB)
Collecting numpy<2.0
  Using cached numpy-1.26.4-cp39-cp39-macosx_11_0_arm64.whl.metadata (61 kB)
INFO: pip is looking at multiple versions of pymc-bart to determine which version is compatible with other requirements. This could take a while.
Collecting pymc-bart
  Using cached pymc_bart-0.9.2-py3-none-any.whl.metadata (4.6 kB)
  Using cached pymc_bart-0.9.1-py3-none-any.whl.metadata (4.6 kB)
  Using cached pymc_bart-0.9.0-py3-none-any.whl.metadata (4.6 kB)
  Using cached pymc_bart-0.8.2-py3-none-any.whl.metadata (4.4 kB)
  Using cached pymc_bart-0.8.1-py3-none-any.whl.metadata (4.4 kB)
  Using cached pymc_bart-0.8.0-py3-none-any.whl.metadata (4.4 kB)
  Using cached pymc_bart-0.7.1-py3-none-any.whl.metadata (4.4 kB)
INFO: pip is still looking at multiple versions of pymc-bart to determine which version is compatible with other requirements. This could take a while.
  Using cached pymc

In [9]:
from lifelines.datasets import load_rossi
from bart_survival import surv_bart as sb
import numpy as np

######################################
# Load rossi dataset from lifelines
rossi = load_rossi()
names = rossi.columns.to_numpy()
rossi = rossi.to_numpy()

######################################
# Transform data into 'augmented' dataset
# Requires creation of the training dataset and a predictive dataset for inference
trn = sb.get_surv_pre_train(
    y_time=rossi[:,0],
    y_status=rossi[:,1],
    x = rossi[:,2:],
    time_scale=7
)

post_test = sb.get_posterior_test(
    y_time=rossi[:,0],
    y_status=rossi[:,1],
    x = rossi[:,2:],
    time_scale=7
)

######################################
# Instantiate the BART models
# model_dict is defines specific model parameters
model_dict = {"trees": 50,
    "split_rules": [
        "pmb.ContinuousSplitRule()", # time
        "pmb.OneHotSplitRule()", # fin
        "pmb.ContinuousSplitRule()",  # age
        "pmb.OneHotSplitRule()", # race
        "pmb.OneHotSplitRule()", # wexp
        "pmb.OneHotSplitRule()", # mar
        "pmb.OneHotSplitRule()", # paro
        "pmb.ContinuousSplitRule()", # prio
    ]
}
# sampler_dict defines specific sampling parameters
sampler_dict = {
            "draws": 200,
            "tune": 200,
            "cores": 8,
            "chains": 8,
            "compute_convergence_checks": False
        }
BSM = sb.BartSurvModel(model_config=model_dict, sampler_config=sampler_dict)

#####################################
# Fit Model
BSM.fit(
    y =  trn["y"],
    X = trn["x"],
    weights=trn["w"],
    coords = trn["coord"],
    random_seed=5
)

# Get posterior predictive for evaluation.
post1 = BSM.sample_posterior_predictive(X_pred=post_test["post_x"], coords=post_test["coords"])

# Convert to SV probability.
sv_prob = sb.get_sv_prob(post1)


Multiprocess sampling (8 chains in 8 jobs)
PGBART: [f]


Sampling 8 chains for 200 tune and 200 draw iterations (1_600 + 1_600 draws total) took 9 seconds.
Sampling: [f]


In [10]:
sv_prob

{'prob': array([[[0.05596996, 0.06444161, 0.07741863, ..., 0.07610453,
          0.07146469, 0.0478056 ],
         [0.08796638, 0.10106372, 0.12226745, ..., 0.14347506,
          0.13609504, 0.09687294],
         [0.04690611, 0.05504917, 0.06864433, ..., 0.11266153,
          0.11803959, 0.0917231 ],
         ...,
         [0.01896996, 0.01868439, 0.0244235 , ..., 0.03305725,
          0.03514822, 0.01831914],
         [0.02767041, 0.03250586, 0.04011887, ..., 0.04589881,
          0.04862446, 0.03135156],
         [0.01043219, 0.01007625, 0.0129527 , ..., 0.01781143,
          0.01906031, 0.00930301]],
 
        [[0.04621877, 0.06679033, 0.06679033, ..., 0.06003639,
          0.05330596, 0.05330596],
         [0.05879051, 0.08460081, 0.08460081, ..., 0.08531744,
          0.08533358, 0.08533358],
         [0.05119797, 0.06679885, 0.06679885, ..., 0.07740295,
          0.08636034, 0.08636034],
         ...,
         [0.03477013, 0.04981479, 0.04981479, ..., 0.04850022,
          0.0547