PyTorch implementation of the paper: SurvDiff: Diffusion-Based Generative Modeling for Survival Analysis

mariebrockschmidt/SurvDiff

SurvDiff

Diffusion Model for generating synthetic survival data. PyTorch implementation of the paper: SurvDiff: Diffusion-Based Generative Modeling for Survival Analysis
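For readers new to diffusion models, the sketch below shows the generic Gaussian forward (noising) process that diffusion models learn to reverse, applied to one row of numerical features. It is a standard DDPM-style formulation with a cosine noise schedule, chosen here for illustration only. SurvDiff's actual schedule, its treatment of categorical covariates (typically handled by a separate discrete process), and its handling of durations and censoring indicators may differ. The function names are hypothetical and not part of this repository.

```python
import math
import random


def cosine_alpha_bar(t: float, s: float = 0.008) -> float:
    """Cumulative signal level alpha_bar(t) for t in [0, 1] (cosine schedule).

    Illustrative choice; not necessarily the schedule SurvDiff uses.
    """
    f = lambda u: math.cos((u + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0.0)


def forward_noise(x0, t, rng=random):
    """Sample x_t ~ N(sqrt(alpha_bar) * x0, (1 - alpha_bar) * I) for a numerical row.

    Returns the noised row and the Gaussian noise that was added; a denoising
    network is typically trained to predict that noise from (x_t, t).
    """
    a = cosine_alpha_bar(t)
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return xt, eps


if __name__ == "__main__":
    rng = random.Random(0)
    row = [0.5, -1.2, 2.0]  # standardized numerical covariates (made-up values)
    for t in (0.0, 0.5, 1.0):
        xt, _ = forward_noise(row, t, rng)
        print(t, round(cosine_alpha_bar(t), 4), [round(v, 3) for v in xt])
```

At `t = 0` the row is returned unchanged; as `t` approaches 1 the signal level goes to zero and the row becomes pure noise, which is where sampling starts.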

Setup

Installation

First, create a virtual environment and install the requirements:

pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt

Prerequisites

Each dataset must contain:

  • Duration column: time until the event or until censoring
  • Event column: 1 = event observed, 0 = censored
  • Covariates: numerical and/or categorical features
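Before running the preparation script, it can help to check a raw CSV against the requirements above. The following is a minimal, stdlib-only sketch; the function name and the default column names `duration` and `event` are assumptions for illustration, not names the repository is known to use. It expects the event column to hold the literal strings `0` or `1`.

```python
import csv
import io


def validate_survival_csv(text, duration_col="duration", event_col="event"):
    """Return a list of problems found in a survival CSV (empty list = looks valid).

    duration_col / event_col are hypothetical defaults; pass your dataset's names.
    """
    reader = csv.DictReader(io.StringIO(text))
    fields = reader.fieldnames or []
    problems = [f"missing column: {c}" for c in (duration_col, event_col) if c not in fields]
    if problems:
        return problems
    if len(fields) < 3:
        problems.append("no covariate columns found")
    for i, row in enumerate(reader, start=1):  # i counts data rows, header excluded
        raw = row[duration_col]
        try:
            duration = float(raw)
        except (TypeError, ValueError):  # TypeError if the row is too short
            problems.append(f"row {i}: duration is not numeric: {raw!r}")
        else:
            if duration < 0:
                problems.append(f"row {i}: negative duration {duration}")
        if row[event_col] not in ("0", "1"):
            problems.append(f"row {i}: event must be 0 or 1, got {row[event_col]!r}")
    return problems


if __name__ == "__main__":
    sample = "duration,event,age\n5.0,1,40\n-1,2,55\n"
    for p in validate_survival_csv(sample):
        print(p)
```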

Training Example

1. Prepare data

python runnables/prepare_data.py --dataset aids

2. Train

python runnables/train_survival.py --dataset aids --exp_name exp1

3. Generate synthetic data

python runnables/sample.py \
    --model_path ckpts/exp1/model_epoch_1000.pt \
    --num_samples 1000 \
    --dataset aids \
    --output_path synthetic_data.csv
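One simple sanity check on the generated file is to compare Kaplan-Meier survival curves of the real and synthetic data. The sketch below implements the standard product-limit estimator in plain Python and reports the largest gap between two curves. The helper names are hypothetical, this is not the repository's evaluation code, and loading the duration and event columns from `synthetic_data.csv` is left to you since its column names are not documented here.

```python
from collections import Counter


def kaplan_meier(durations, events):
    """Product-limit estimate: list of (t, S(t)) at each distinct event time.

    events uses 1 = event, 0 = censored. Subjects censored at time t are still
    counted as at risk at t (the usual tie convention).
    """
    deaths = Counter(t for t, e in zip(durations, events) if e == 1)
    exits = Counter(durations)  # events and censorings both leave the risk set
    at_risk = len(durations)
    surv, curve = 1.0, []
    for t in sorted(exits):
        d = deaths.get(t, 0)
        if d:
            surv *= 1.0 - d / at_risk
            curve.append((t, surv))
        at_risk -= exits[t]
    return curve


def survival_at(curve, t):
    """Evaluate the right-continuous step function S(t) from a KM curve."""
    s = 1.0
    for ti, si in curve:
        if ti > t:
            break
        s = si
    return s


def max_km_gap(curve_a, curve_b):
    """Largest absolute difference between two KM curves over their jump times."""
    times = sorted({t for t, _ in curve_a} | {t for t, _ in curve_b})
    return max((abs(survival_at(curve_a, t) - survival_at(curve_b, t)) for t in times), default=0.0)


if __name__ == "__main__":
    real = kaplan_meier([1, 2, 2, 3, 4], [1, 1, 0, 1, 0])
    synth = kaplan_meier([1, 2, 3, 3, 5], [1, 0, 1, 1, 0])
    print(real)
    print(round(max_km_gap(real, synth), 3))
```

A small gap suggests the synthetic data reproduces the marginal survival distribution; it says nothing about how well covariate effects are preserved.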

Project Structure

survdiff/
├── config/           # Training configs
├── data/             # Datasets
├── models/           # Models
├── modules/          # Neural components
├── runnables/        # Scripts (train, sample, prepare)
├── outputs/          # Plots, reports, synthetic data
├── trainer_tabdiff.py
├── trainer_survival.py
└── utils_train.py

Acknowledgement

This repo is based on the implementation of TabDiff.
