In [1]:
# Install coxkan
! pip install coxkan

# Hyperparameter searching

How to perform hyperparameter searches (sweeps) using the `coxkan.hyperparam_search` module.

In [2]:
from coxkan import CoxKAN
from coxkan.datasets import create_dataset
from coxkan.hyperparam_search import Sweep
from sklearn.model_selection import train_test_split
import numpy as np

# Generate a synthetic dataset
# The callable below (x1**2 + x2**2) is handed to create_dataset as the
# log-partial hazard; the resulting frame is later used with the
# 'duration' and 'event' columns.
log_partial_hazard = lambda x1, x2: x1**2 + x2**2
df = create_dataset(log_partial_hazard, baseline_hazard=0.01, n_samples=1000)

# Two successive 80/20 splits, both with a fixed random_state:
#   df            -> df_train_full (80%) + df_test (20%)
#   df_train_full -> df_train (80%)      + df_val (20%)
# i.e. roughly 64% train / 16% validation / 20% test of the generated rows.
df_train_full, df_test = train_test_split(df, test_size=0.2, random_state=42)
df_train, df_val = train_test_split(df_train_full, test_size=0.2, random_state=42)

Concordance index of true expression: 0.6139


  from .autonotebook import tqdm as notebook_tqdm


### Sweep using validation set

In [3]:
# Build the sweep from the search configuration file
# (assumed to describe the hyperparameter search space -- see the coxkan docs).
sweep = Sweep(search_config='./search_config.yml')

# Run the sweep, scoring each trial on the validation set df_val.
sweep.run_val(df_train, df_val, duration_col='duration', event_col='event', 
              n_trials=10, # number of trials
              n_runs_per_trial=1, # number of runs for each set of hyperparameters
              save_params='./result.yml', # save the best hyperparameters to this file
              seed=42) # presumably seeds the sweep for reproducibility -- confirm against the coxkan docs

[I 2024-08-15 14:00:59,645] A new study created in memory with name: no-name-969f6d6f-0b8e-4942-bd81-3cbead7c9fc8
Best trial: 0. Best value: 0.573135:  10%|█         | 1/10 [00:01<00:14,  1.64s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.5731345885634589
[I 2024-08-15 14:01:01,299] Trial 0 finished with value: 0.5731345885634589 and parameters: {'num_hidden': 0, 'hidden_dim': 5, 'base_fun': 'silu', 'grid': 3, 'k': 3, 'noise_scale': 0.031198904067240532, 'noise_scale_base': 0.011616722433639893, 'early_stopping': True, 'lr': 0.013311216080736894, 'steps': 52, 'lamb': 0.014548647782429914, 'lamb_entropy': 13, 'lamb_coef': 1, 'prune_threshold': 0.009091248360355032}. Best is trial 0 with value: 0.5731345885634589.


Best trial: 1. Best value: 0.614888:  20%|██        | 2/10 [00:02<00:07,  1.00it/s]

Run 0 c-index: 0.6148884239888424
[I 2024-08-15 14:01:01,841] Trial 1 finished with value: 0.6148884239888424 and parameters: {'num_hidden': 0, 'hidden_dim': 2, 'base_fun': 'silu', 'grid': 3, 'k': 3, 'noise_scale': 0.1223705789444759, 'noise_scale_base': 0.027898772130408367, 'early_stopping': False, 'lr': 0.0023345864076016252, 'steps': 129, 'lamb': 0.002995106732375396, 'lamb_entropy': 8, 'lamb_coef': 3, 'prune_threshold': 0.0023225206359998862}. Best is trial 1 with value: 0.6148884239888424.


Best trial: 1. Best value: 0.614888:  30%|███       | 3/10 [00:02<00:04,  1.41it/s]

Run 0 c-index: 0.5007845188284519
[I 2024-08-15 14:01:02,215] Trial 2 finished with value: 0.5007845188284519 and parameters: {'num_hidden': 1, 'hidden_dim': 1, 'base_fun': 'linear', 'grid': 5, 'k': 3, 'noise_scale': 0.16167946962329224, 'noise_scale_base': 0.06092275383467414, 'early_stopping': False, 'lr': 0.0020914981329035616, 'steps': 62, 'lamb': 0.007427653651669052, 'lamb_entropy': 0, 'lamb_coef': 5, 'prune_threshold': 0.012938999080000846}. Best is trial 1 with value: 0.6148884239888424.


Best trial: 3. Best value: 0.61576:  40%|████      | 4/10 [00:04<00:08,  1.37s/it] 

Best model loaded (early stopping).
Run 0 c-index: 0.6157601115760112
[I 2024-08-15 14:01:04,591] Trial 3 finished with value: 0.6157601115760112 and parameters: {'num_hidden': 1, 'hidden_dim': 2, 'base_fun': 'linear', 'grid': 3, 'k': 3, 'noise_scale': 0.19391692555291173, 'noise_scale_base': 0.15502656467222292, 'early_stopping': True, 'lr': 0.00621870472776908, 'steps': 143, 'lamb': 0.0013273875307787924, 'lamb_entropy': 3, 'lamb_coef': 0, 'prune_threshold': 0.016266516538163217}. Best is trial 3 with value: 0.6157601115760112.


Best trial: 4. Best value: 0.623082:  50%|█████     | 5/10 [00:06<00:07,  1.41s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.6230822873082287
[I 2024-08-15 14:01:06,072] Trial 4 finished with value: 0.6230822873082287 and parameters: {'num_hidden': 0, 'hidden_dim': 2, 'base_fun': 'silu', 'grid': 3, 'k': 3, 'noise_scale': 0.1085392166316497, 'noise_scale_base': 0.02818484499495253, 'early_stopping': True, 'lr': 0.09133995846860977, 'steps': 127, 'lamb': 0.002980735223012586, 'lamb_entropy': 0, 'lamb_coef': 4, 'prune_threshold': 0.035342867192380854}. Best is trial 4 with value: 0.6230822873082287.


Best trial: 4. Best value: 0.623082:  60%|██████    | 6/10 [00:09<00:07,  1.96s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.571652719665272
[I 2024-08-15 14:01:09,105] Trial 5 finished with value: 0.571652719665272 and parameters: {'num_hidden': 1, 'hidden_dim': 4, 'base_fun': 'linear', 'grid': 3, 'k': 3, 'noise_scale': 0.17262068517511872, 'noise_scale_base': 0.12465962536551159, 'early_stopping': True, 'lr': 0.0008569331925053991, 'steps': 82, 'lamb': 0.01094409267507096, 'lamb_entropy': 10, 'lamb_coef': 5, 'prune_threshold': 0.023610746258097465}. Best is trial 4 with value: 0.6230822873082287.


Best trial: 4. Best value: 0.623082:  70%|███████   | 7/10 [00:11<00:05,  1.86s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.5295502092050209
[I 2024-08-15 14:01:10,742] Trial 6 finished with value: 0.5295502092050209 and parameters: {'num_hidden': 0, 'hidden_dim': 4, 'base_fun': 'silu', 'grid': 5, 'k': 3, 'noise_scale': 0.09875911927287816, 'noise_scale_base': 0.10454656587639882, 'early_stopping': True, 'lr': 0.00021070472806578247, 'steps': 53, 'lamb': 0.009546156168956706, 'lamb_entropy': 5, 'lamb_coef': 3, 'prune_threshold': 0.04537832369630465}. Best is trial 4 with value: 0.6230822873082287.


Best trial: 4. Best value: 0.623082:  80%|████████  | 8/10 [00:12<00:03,  1.73s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.6188981868898187
[I 2024-08-15 14:01:12,208] Trial 7 finished with value: 0.6188981868898187 and parameters: {'num_hidden': 0, 'hidden_dim': 3, 'base_fun': 'silu', 'grid': 3, 'k': 3, 'noise_scale': 0.05795029058275361, 'noise_scale_base': 0.03224425745080089, 'early_stopping': True, 'lr': 0.007947147424653748, 'steps': 138, 'lamb': 0.012055081153486717, 'lamb_entropy': 2, 'lamb_coef': 5, 'prune_threshold': 0.026967112095782536}. Best is trial 4 with value: 0.6230822873082287.


Best trial: 4. Best value: 0.623082:  90%|█████████ | 9/10 [00:16<00:02,  2.32s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.553347280334728
[I 2024-08-15 14:01:15,814] Trial 8 finished with value: 0.553347280334728 and parameters: {'num_hidden': 1, 'hidden_dim': 5, 'base_fun': 'silu', 'grid': 3, 'k': 3, 'noise_scale': 0.08542155772525127, 'noise_scale_base': 0.16360295318449863, 'early_stopping': True, 'lr': 0.0034059785435329977, 'steps': 92, 'lamb': 0.0033316171570609535, 'lamb_entropy': 1, 'lamb_coef': 2, 'prune_threshold': 0.04714548519562596}. Best is trial 4 with value: 0.6230822873082287.


Best trial: 4. Best value: 0.623082: 100%|██████████| 10/10 [00:17<00:00,  1.77s/it]

Best model loaded (early stopping).
Run 0 c-index: 0.610965829846583
[I 2024-08-15 14:01:17,310] Trial 9 finished with value: 0.610965829846583 and parameters: {'num_hidden': 0, 'hidden_dim': 3, 'base_fun': 'silu', 'grid': 5, 'k': 3, 'noise_scale': 0.19248945898842226, 'noise_scale_base': 0.05035645916507284, 'early_stopping': True, 'lr': 0.000715354779469316, 'steps': 53, 'lamb': 0.009143465009698452, 'lamb_entropy': 8, 'lamb_coef': 0, 'prune_threshold': 0.013932323211830572}. Best is trial 4 with value: 0.6230822873082287.





<optuna.study.study.Study at 0x36058bef0>

In [4]:
import yaml

# Reload the hyperparameters that the validation-set sweep wrote to disk.
with open('./result.yml', 'r') as f:
    config = yaml.safe_load(f)

ckan = CoxKAN(**config['init_params'])
train_params = config['train_params']

# With early stopping a held-out set is passed alongside the training split;
# without it, no validation set is passed and the full training data is used.
if train_params['early_stopping']:
    fit_df, monitor_df = df_train, df_val
else:
    fit_df, monitor_df = df_train_full, None

ckan.train(fit_df, monitor_df, duration_col='duration', event_col='event', **train_params)

# Apply the pruning threshold stored with the saved hyperparameters.
ckan.prune_edges(config['prune_threshold'])

# Last expression of the cell, so its value (c-index on the test split) is displayed.
ckan.cindex(df_test)

train loss: 2.33e+00 | val loss: 2.02e+00: 100%|█████████████████| 300/300 [00:01<00:00, 152.86it/s]

Best model loaded (early stopping).





0.6236619090098127

### Cross-Validation sweep

In [5]:
# Cross-validated sweep on df_train; the return value is kept in `study`.
study = sweep.run_cv(
    df_train,
    duration_col='duration',
    event_col='event',
    n_trials=10,  # number of trials
    n_folds=4,  # number of cross-validation folds per trial
    save_params='./result_cv.yml',  # the evaluation cell below reloads the hyperparameters from this file
    verbose=0,
)


[I 2024-08-15 14:01:19,293] A new study created in memory with name: no-name-47a77c2f-36e0-4421-b506-f7795eeae521
  0%|          | 0/10 [00:00<?, ?it/s]

Fold 0 c-index: 0.5921134249003102
Fold 1 c-index: 0.5487295825771324
Fold 2 c-index: 0.5871625245843018


Best trial: 0. Best value: 0.569219:  10%|█         | 1/10 [00:03<00:28,  3.17s/it]

Fold 3 c-index: 0.5488721804511278
[I 2024-08-15 14:01:22,462] Trial 0 finished with value: 0.5692194281282181 and parameters: {'num_hidden': 1, 'hidden_dim': 2, 'base_fun': 'linear', 'grid': 4, 'k': 3, 'noise_scale': 0.005761103131819723, 'noise_scale_base': 0.13885886579345655, 'early_stopping': False, 'lr': 0.09160714155043215, 'steps': 118, 'lamb': 0.013312884060624764, 'lamb_entropy': 13, 'lamb_coef': 1, 'prune_threshold': 0.02517798341672471}. Best is trial 0 with value: 0.5692194281282181.
Best model loaded (early stopping).
Fold 0 c-index: 0.5254762959680992
Best model loaded (early stopping).
Fold 1 c-index: 0.5005444646098004
Best model loaded (early stopping).
Fold 2 c-index: 0.4983908456999821


Best trial: 0. Best value: 0.569219:  20%|██        | 2/10 [00:17<01:17,  9.65s/it]

Best model loaded (early stopping).
Fold 3 c-index: 0.5061477222467935
[I 2024-08-15 14:01:36,649] Trial 1 finished with value: 0.5076398321311688 and parameters: {'num_hidden': 1, 'hidden_dim': 5, 'base_fun': 'linear', 'grid': 5, 'k': 3, 'noise_scale': 0.03604483372105134, 'noise_scale_base': 0.15177672314541957, 'early_stopping': True, 'lr': 0.00010212400656174308, 'steps': 146, 'lamb': 0.0127156683050421, 'lamb_entropy': 12, 'lamb_coef': 2, 'prune_threshold': 0.0007681710338059034}. Best is trial 0 with value: 0.5692194281282181.
Fold 0 c-index: 0.5054497120070891
Fold 1 c-index: 0.5
Fold 2 c-index: 0.4808242445914536


Best trial: 0. Best value: 0.569219:  30%|███       | 3/10 [00:20<00:47,  6.85s/it]

Fold 3 c-index: 0.5150818222025653
[I 2024-08-15 14:01:40,159] Trial 2 finished with value: 0.5003389447002771 and parameters: {'num_hidden': 1, 'hidden_dim': 3, 'base_fun': 'silu', 'grid': 4, 'k': 3, 'noise_scale': 0.15092634667098273, 'noise_scale_base': 0.16915885847649653, 'early_stopping': False, 'lr': 0.013855793600165379, 'steps': 108, 'lamb': 0.011767120092815676, 'lamb_entropy': 13, 'lamb_coef': 0, 'prune_threshold': 0.033334356170348536}. Best is trial 0 with value: 0.5692194281282181.
Fold 0 c-index: 0.6419140451927338
Fold 1 c-index: 0.5866606170598911
Fold 2 c-index: 0.5886822814232076


Best trial: 3. Best value: 0.599715:  40%|████      | 4/10 [00:24<00:33,  5.65s/it]

Fold 3 c-index: 0.5816010614772225
[I 2024-08-15 14:01:43,966] Trial 3 finished with value: 0.5997145012882636 and parameters: {'num_hidden': 1, 'hidden_dim': 3, 'base_fun': 'linear', 'grid': 5, 'k': 3, 'noise_scale': 0.08365213158503833, 'noise_scale_base': 0.1936601698253311, 'early_stopping': False, 'lr': 0.005006356684698013, 'steps': 137, 'lamb': 0.0054486908189243875, 'lamb_entropy': 6, 'lamb_coef': 0, 'prune_threshold': 0.0020421609831000223}. Best is trial 3 with value: 0.5997145012882636.
Fold 0 c-index: 0.6404962339388569
Fold 1 c-index: 0.5934664246823956
Fold 2 c-index: 0.5473806543894153


Best trial: 3. Best value: 0.599715:  50%|█████     | 5/10 [00:26<00:21,  4.21s/it]

Fold 3 c-index: 0.5749668288367978
[I 2024-08-15 14:01:45,627] Trial 4 finished with value: 0.5890775354618665 and parameters: {'num_hidden': 0, 'hidden_dim': 3, 'base_fun': 'silu', 'grid': 5, 'k': 3, 'noise_scale': 0.03389261833046593, 'noise_scale_base': 0.06292264914726667, 'early_stopping': False, 'lr': 0.07261973176751384, 'steps': 111, 'lamb': 0.005915624001054734, 'lamb_entropy': 13, 'lamb_coef': 1, 'prune_threshold': 0.041795388147832443}. Best is trial 3 with value: 0.5997145012882636.
Best model loaded (early stopping).
Fold 0 c-index: 0.5945059813912273
Best model loaded (early stopping).
Fold 1 c-index: 0.5582577132486388
Best model loaded (early stopping).
Fold 2 c-index: 0.5862685499731808


Best trial: 3. Best value: 0.599715:  60%|██████    | 6/10 [00:34<00:22,  5.72s/it]

Best model loaded (early stopping).
Fold 3 c-index: 0.545422379478107
[I 2024-08-15 14:01:54,268] Trial 5 finished with value: 0.5711136560227885 and parameters: {'num_hidden': 1, 'hidden_dim': 2, 'base_fun': 'linear', 'grid': 3, 'k': 3, 'noise_scale': 0.19678864651095512, 'noise_scale_base': 0.04629897529137772, 'early_stopping': True, 'lr': 0.004340819641538362, 'steps': 132, 'lamb': 0.013312357931531689, 'lamb_entropy': 13, 'lamb_coef': 2, 'prune_threshold': 0.0018430695330362346}. Best is trial 3 with value: 0.5997145012882636.
Best model loaded (early stopping).
Fold 0 c-index: 0.6445724412937528
Best model loaded (early stopping).
Fold 1 c-index: 0.5779491833030853
Best model loaded (early stopping).
Fold 2 c-index: 0.5539066690505989


Best trial: 3. Best value: 0.599715:  70%|███████   | 7/10 [00:48<00:24,  8.31s/it]

Best model loaded (early stopping).
Fold 3 c-index: 0.5951348960636886
[I 2024-08-15 14:02:07,912] Trial 6 finished with value: 0.5928907974277815 and parameters: {'num_hidden': 1, 'hidden_dim': 5, 'base_fun': 'linear', 'grid': 3, 'k': 3, 'noise_scale': 0.1107043559274616, 'noise_scale_base': 0.1611410345119494, 'early_stopping': True, 'lr': 0.007458682185923105, 'steps': 135, 'lamb': 0.0011069010646453624, 'lamb_entropy': 11, 'lamb_coef': 5, 'prune_threshold': 0.03098769194034892}. Best is trial 3 with value: 0.5997145012882636.
Fold 0 c-index: 0.6227735932653965
Fold 1 c-index: 0.5574410163339383
Fold 2 c-index: 0.5652601466118362


Best trial: 3. Best value: 0.599715:  80%|████████  | 8/10 [00:49<00:12,  6.09s/it]

Fold 3 c-index: 0.567359575409111
[I 2024-08-15 14:02:09,264] Trial 7 finished with value: 0.5782085829050705 and parameters: {'num_hidden': 0, 'hidden_dim': 3, 'base_fun': 'silu', 'grid': 3, 'k': 3, 'noise_scale': 0.15087783173212, 'noise_scale_base': 0.010510014311042416, 'early_stopping': False, 'lr': 0.001874659985300332, 'steps': 97, 'lamb': 0.010625574038749704, 'lamb_entropy': 3, 'lamb_coef': 0, 'prune_threshold': 0.011714877789290357}. Best is trial 3 with value: 0.5997145012882636.
Fold 0 c-index: 0.6376606114311032
Fold 1 c-index: 0.5601633393829402
Fold 2 c-index: 0.5764348292508493


Best trial: 3. Best value: 0.599715:  90%|█████████ | 9/10 [00:52<00:04,  4.97s/it]

Fold 3 c-index: 0.6018575851393189
[I 2024-08-15 14:02:11,779] Trial 8 finished with value: 0.5940290913010529 and parameters: {'num_hidden': 1, 'hidden_dim': 3, 'base_fun': 'linear', 'grid': 3, 'k': 3, 'noise_scale': 0.04087160724549735, 'noise_scale_base': 0.14106211885762818, 'early_stopping': False, 'lr': 0.014822313350534654, 'steps': 78, 'lamb': 0.011667028202088986, 'lamb_entropy': 8, 'lamb_coef': 3, 'prune_threshold': 0.005129942095985285}. Best is trial 3 with value: 0.5997145012882636.


Best trial: 3. Best value: 0.599715: 100%|██████████| 10/10 [00:52<00:00,  5.29s/it]

Fold 0 c-index: 0.5293752769162605
[I 2024-08-15 14:02:12,156] Trial 9 pruned. 





In [7]:
import yaml

# Reload the hyperparameters that the cross-validation sweep wrote to disk.
with open('./result_cv.yml', 'r') as f:
    config = yaml.safe_load(f)

ckan = CoxKAN(**config['init_params'])
train_params = config['train_params']

# With early stopping a held-out set is passed alongside the training split;
# without it, no validation set is passed and the full training data is used.
if train_params['early_stopping']:
    fit_df, monitor_df = df_train, df_val
else:
    fit_df, monitor_df = df_train_full, None

ckan.train(fit_df, monitor_df, duration_col='duration', event_col='event', **train_params)

# Apply the pruning threshold stored with the saved hyperparameters.
ckan.prune_edges(config['prune_threshold'])

# Last expression of the cell, so its value (c-index on the test split) is displayed.
ckan.cindex(df_test)

train loss: 2.38e+00: 100%|███████████████████████████████████████| 137/137 [00:01<00:00, 68.76it/s]


Pruned activation (0,0,1)
Pruned activation (0,1,0)
Pruned activation (0,1,1)
Pruned activation (1,1,0)


0.6173059768064229