CoxKAN: Kolmogorov-Arnold Networks for Interpretable, High-Performance Survival Analysis

knottwill/CoxKAN

CoxKAN

CoxKAN: Kolmogorov-Arnold Networks for Survival Analysis

Installation | Usage | Datasets | Reproducibility | Credits

This repository contains the code accompanying the paper "CoxKAN: Kolmogorov-Arnold Networks for Interpretable, High-Performance Survival Analysis".

  • Paper: ArXiv.
  • Installation: pip install coxkan
  • Documentation: Read-the-Docs.
  • Quick-start: tutorials/intro.ipynb

Repo Structure:

├── checkpoints/        # Results / checkpoints from paper
├── configs/            # Model configuration files
├── coxkan/             # CoxKAN package 
├── data/               # Data 
├── docs/               # Documentation
├── media/              # Figures used in paper
├── reprod/             # Reproducibility instructions/code
├── tutorials/          # Tutorials for CoxKAN
|
# standard stuff:
├── .gitignore         
├── LICENSE          
├── README.md          
└── setup.py            

Installation

Pip

CoxKAN can be installed via:

pip install coxkan

Git

Alternatively, you may want the full codebase and environment used to produce all results in the associated paper:

git clone https://github.com/knottwill/CoxKAN.git 
cd CoxKAN
pip install -r reprod/requirements.txt

Please refer to the reproducibility instructions in reprod/README.md.

Usage

Tutorials are available in tutorials/ and on Read-the-Docs.

Example

from coxkan import CoxKAN
from coxkan.datasets import metabric 

df_train, df_test = metabric.load(split=True)

ckan = CoxKAN(width=[len(metabric.covariates), 1])

_ = ckan.train(
    df_train, 
    df_test, 
    duration_col='duration', 
    event_col='event',
    steps=100)

# evaluate model: concordance index (C-index) on the test set
ckan.cindex(df_test)
# 0.6441975461899737 (the exact value may vary between runs)
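
To make the evaluation metric concrete, here is a minimal sketch of Harrell's concordance index in plain Python. It assumes that `ckan.cindex` reports this standard metric; the function name `harrell_cindex` and the handling of tied durations (skipped here for simplicity) are illustrative choices, not the package's implementation.

```python
from itertools import combinations


def harrell_cindex(durations, events, risk_scores):
    """Fraction of comparable pairs whose risk ordering matches their survival ordering.

    Hypothetical helper for illustration only (not part of coxkan).
    A higher risk score should correspond to a shorter survival time.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(durations)), 2):
        # Order the pair so that `a` is the sample with the shorter observed time.
        a, b = (i, j) if durations[i] < durations[j] else (j, i)
        # A pair is comparable only if the shorter time is an observed event.
        # Pairs with tied durations are skipped in this simplified sketch.
        if durations[a] == durations[b] or not events[a]:
            continue
        comparable += 1
        if risk_scores[a] > risk_scores[b]:
            concordant += 1.0      # correctly ranked
        elif risk_scores[a] == risk_scores[b]:
            concordant += 0.5      # tied predictions count as half
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Toy data: 4 patients, the third one is censored.
score = harrell_cindex(
    durations=[2, 4, 6, 8],
    events=[1, 1, 0, 1],
    risk_scores=[0.9, 0.2, 0.5, 0.1],
)
print(score)  # 0.8 (4 of 5 comparable pairs are ranked correctly)
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking, which gives some context for the score reported above.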

CoxKAN Package

The coxkan/ package has 4 main components:

coxkan/
    ├── datasets/             # datasets subpackage
    ├── CoxKAN.py             # CoxKAN model
    ├── utils.py              # utility functions
    └── hyperparam_search.py  # hyperparameter searching

Datasets

Synthetic Datasets

coxkan.datasets.create_dataset makes it easy to generate synthetic survival data assuming a proportional-hazards, time-independent hazard function: $$h = h_0 e^{\theta(\mathbf{x})} \rightarrow T_s \sim \text{Exp}(h)$$ and a uniform censoring distribution $T_c \sim \text{Uniform}(0, T_{max})$.
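
As a short worked consequence of this setup (standard results for a constant hazard, not taken from the package source), the survival function and the mean event time are $$S(t \mid \mathbf{x}) = \exp\left(-h_0 e^{\theta(\mathbf{x})}\, t\right), \qquad \mathbb{E}[T_s \mid \mathbf{x}] = \frac{1}{h_0 e^{\theta(\mathbf{x})}}.$$ Under the usual right-censoring convention (an assumption about how the censoring time is applied), each sample is then observed as a duration $T = \min(T_s, T_c)$ with event indicator $\delta = \mathbb{1}[T_s \le T_c]$.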

In the example below, we use a log-partial hazard of $\theta(\mathbf{x}) = 5 e^{-2(x_1^2 + x_2^2)}$ and a baseline hazard of $h_0 = 0.01$.

import numpy as np
from coxkan.datasets import create_dataset

# log-partial hazard theta(x) = 5 * exp(-2 * (x1^2 + x2^2))
log_partial_hazard = lambda x1, x2: 5*np.exp(-2*(x1**2 + x2**2))
df = create_dataset(log_partial_hazard, baseline_hazard=0.01, n_samples=10000)
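
For intuition, the generative process described above can be sketched with the standard library alone. This is not the implementation of `create_dataset`: the function name `simulate_survival`, the covariate range of [-1, 1], the value of `t_max`, and the returned list-of-dicts format are all assumptions made for illustration.

```python
import math
import random


def simulate_survival(log_partial_hazard, baseline_hazard=0.01,
                      n_samples=1000, t_max=100.0, seed=0):
    """Hypothetical sketch: exponential event times with uniform censoring."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n_samples):
        # Assumed covariate distribution: independent Uniform(-1, 1).
        x1, x2 = rng.uniform(-1, 1), rng.uniform(-1, 1)
        # Time-independent hazard h = h0 * exp(theta(x)).
        hazard = baseline_hazard * math.exp(log_partial_hazard(x1, x2))
        t_event = rng.expovariate(hazard)    # T_s ~ Exp(h)
        t_censor = rng.uniform(0.0, t_max)   # T_c ~ Uniform(0, T_max)
        rows.append({
            "x1": x1,
            "x2": x2,
            "duration": min(t_event, t_censor),   # observed time
            "event": int(t_event <= t_censor),    # 1 = event, 0 = censored
        })
    return rows


rows = simulate_survival(lambda x1, x2: 5 * math.exp(-2 * (x1**2 + x2**2)))
event_rate = sum(r["event"] for r in rows) / len(rows)
print(f"{len(rows)} samples, observed event rate {event_rate:.2f}")
```

Samples near the origin have the largest hazard and therefore the shortest expected survival times, which is the pattern a fitted model should recover.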

Clinical Datasets

Five clinical datasets are available in the coxkan.datasets subpackage (inspired by pycox). For example:

from coxkan.datasets import gbsg
df_train, df_test = gbsg.load(split=True)

You can decide where to store them using the COXKAN_DATA_DIR environment variable.
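
One way to set this variable from Python is sketched below. The directory path is a hypothetical example, and setting the variable before importing `coxkan.datasets` is a cautious assumption, since the README does not say when the variable is read or whether the directory is created automatically.

```python
import os
import tempfile
from pathlib import Path

# Hypothetical storage location; any writable directory should do.
data_dir = Path(tempfile.gettempdir()) / "coxkan_data"
data_dir.mkdir(parents=True, exist_ok=True)

# Set the variable before the datasets subpackage is imported (assumption).
os.environ["COXKAN_DATA_DIR"] = str(data_dir)

# With coxkan installed, loading would then proceed as usual:
# from coxkan.datasets import gbsg
# df_train, df_test = gbsg.load(split=True)
```

The same effect can be achieved by exporting the variable in the shell before starting Python.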

| Dataset | Description | Source |
|---|---|---|
| GBSG | The Rotterdam & German Breast Cancer Study Group. | DeepSurv |
| METABRIC | The Molecular Taxonomy of Breast Cancer International Consortium. | DeepSurv |
| SUPPORT | Study to Understand Prognoses Preferences Outcomes and Risks of Treatment. | DeepSurv |
| NWTCO | National Wilms' Tumor Study. | Rdatasets |
| FLCHAIN | Assay of Serum Free Light Chain. | Rdatasets |

Unfortunately, DeepSurv did not retain the column names. We manually restored the names by obtaining the datasets elsewhere and comparing the columns, which allows us to use the same train/test split.

Genomics Datasets

We curated four genomics datasets from The Cancer Genome Atlas Program (TCGA). The raw or pre-processed data is available on request: please email me at knottenbeltwill@gmail.com.

Two of the datasets (GBMLGG, KIRC) are the unaltered datasets used in Pathomic Fusion.

| Dataset | Description | Source |
|---|---|---|
| STAD | Stomach Adenocarcinoma. | TCGA |
| BRCA | Breast Invasive Carcinoma. | TCGA |
| GBM/LGG | Merged dataset from two types of brain cancer: Glioblastoma Multiforme and Lower Grade Glioma. | Chen et al. |
| KIRC | Kidney Renal Clear Cell Carcinoma. | Chen et al. |

Reproducibility

All results in the associated paper can be reproduced using the code in reprod/. Please refer to the instructions in reprod/README.md.

Credits

Special thanks to:

  • All authors of Kolmogorov-Arnold Networks and the incredible pykan package.
  • Håvard Kvamme for pycox and torchtuples.
