A diffusion model for generating synthetic survival data. This is the PyTorch implementation of the paper "SurvDiff: Diffusion-Based Generative Modeling for Survival Analysis".
First, create a virtual environment and install the requirements:
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

Datasets must contain:
- Duration column: event/censoring times
- Event column: 1 = event, 0 = censored
- Covariates: numerical or categorical features
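
The dataset requirements above can be checked before running the preparation step. The following is a minimal sketch using only the Python standard library, not the repo's own loader. The column names `duration` and `event` and the helper `validate_survival_rows` are hypothetical, so adjust them to match your dataset.

```python
import csv
import io
import math


def validate_survival_rows(rows, duration_col="duration", event_col="event"):
    """Return a list of problems; an empty list means the rows look usable.

    Column names are assumptions for illustration, not names used by the repo.
    """
    problems = []
    for i, row in enumerate(rows):
        # Both the duration and the event column must be present.
        missing = [c for c in (duration_col, event_col) if c not in row]
        if missing:
            problems.append(f"row {i}: missing column(s) {missing}")
            continue
        try:
            duration = float(row[duration_col])
            event = float(row[event_col])
        except (TypeError, ValueError):
            problems.append(f"row {i}: duration or event is not numeric")
            continue
        # Durations are event/censoring times, so they must be finite and non-negative.
        if not math.isfinite(duration) or duration < 0:
            problems.append(f"row {i}: invalid duration {duration}")
        # The event indicator must be 1 (event) or 0 (censored).
        if event not in (0.0, 1.0):
            problems.append(f"row {i}: event must be 0 or 1, got {row[event_col]!r}")
    return problems


# Small in-memory example; the last row violates both rules.
sample_csv = io.StringIO(
    "duration,event,age,sex\n"
    "12.5,1,54,M\n"
    "30.0,0,61,F\n"
    "-3.0,2,47,F\n"
)
rows = list(csv.DictReader(sample_csv))
problems = validate_survival_rows(rows)
for p in problems:
    print(p)
```

To check a real file, replace `sample_csv` with an open file handle. Covariate columns are not inspected here, since they may be either numerical or categorical.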
1. Prepare data
python runnables/prepare_data.py --dataset aids
2. Train
python runnables/train_survival.py --dataset aids --exp_name exp1
3. Generate synthetic data
python runnables/sample.py \
--model_path ckpts/exp1/model_epoch_1000.pt \
--num_samples 1000 \
--dataset aids \
--output_path synthetic_data.csv

survdiff/
├── config/ # Training configs
├── data/ # Datasets
├── models/ # Models
├── modules/ # Neural components
├── runnables/ # Scripts (train, sample, prepare)
├── outputs/ # Plots, reports, synthetic data
├── trainer_tabdiff.py
├── trainer_survival.py
└── utils_train.py
This repo is based on the implementation of TabDiff.
