In [1]:
# Simplified and Best Practice Version of Import Statements

import gc
import json
import os
import math
import multiprocessing
import numpy as np
import pandas as pd
import torch
import importlib
import logging
from pathlib import Path
from sklearn.model_selection import GroupKFold, GroupShuffleSplit

# Pycox and PyTorch tuples for survival analysis
import torchtuples as tt
import pycox
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from pycox.models import CoxPH, DeepHit
from pycox.evaluation import EvalSurv

# Ray for hyperparameter tuning and distributed processing
import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.search.bayesopt import BayesOptSearch
from ray.tune.search.optuna import OptunaSearch
from ray.tune.search import ConcurrencyLimiter
from ray.tune.schedulers import ASHAScheduler, PopulationBasedTraining
from ray.air import session
import ray.cloudpickle as pickle

# Custom modules for data handling, balancing, training, evaluation, and model architectures
import dataloader2
import databalancer2
import datatrainer2
import modeleval
import netweaver2

# Reload custom modules to ensure latest changes are available
importlib.reload(dataloader2)
importlib.reload(databalancer2)
importlib.reload(datatrainer2)
importlib.reload(modeleval)
importlib.reload(netweaver2)

# Import specific functions from custom modules to keep code clean and readable
from netweaver2 import (
    lstm_net_init, DHANNWrapper, LSTMWrapper, generalized_ann_net_init
)
from dataloader2 import (
    load_and_transform_data, preprocess_data #stack_sequences, dh_dataset_loader
)
from databalancer2 import (
    define_medoid_general, df_event_focus, rebalance_data, underbalance_data_general, medoid_cluster, 
    dh_rebalance_data
)
from datatrainer2 import (
    recursive_clustering, prepare_training_data, 
    prepare_validation_data, lstm_training
)
from modeleval import (
    dh_test_model, nam_dagostino_chi2, get_baseline_hazard_at_timepoints
)

import psutil
torch.cuda.empty_cache()
gc.collect()

# Define Constants and Load Datasets
RANDOM_SEED = 12345
N_SPLIT = 10
FEATURE_COLS = ['gender', 'dm', 'ht', 'sprint', 'a1c', 'po4', 'UACR_mg_g', 'Cr', 'age', 'alb', 'ca', 'hb', 'hco3']
DURATION_COL = 'date_from_sub_60'
EVENT_COL = 'endpoint'
CLUSTER_COL = 'key'
TIME_GRID = np.array([i * 365 for i in range(6)])

# Define Feature Groups
CAT_FEATURES = ['gender', 'dm', 'ht', 'sprint']
LOG_FEATURES = ['a1c', 'po4', 'UACR_mg_g', 'Cr']
STANDARD_FEATURES = ['age', 'alb', 'ca', 'hb', 'hco3']
PASSTHROUGH_FEATURES = ['key', 'date_from_sub_60', 'endpoint']

# Load and Transform Data
BASE_FILENAME = '/mnt/d/pydatascience/g3_regress/data/X/X_20240628'
X_train_transformed, X_test_transformed = load_and_transform_data(
    BASE_FILENAME, CAT_FEATURES, LOG_FEATURES, STANDARD_FEATURES, PASSTHROUGH_FEATURES
)

2024-11-03 23:15:30,703 - INFO - Transforming training data...
2024-11-03 23:15:46,784 - INFO - Transforming test data...


In [11]:
def create_neural_network(config):
    """
    Function to create a neural network based on the given configuration.

    Args:
        config (dict): Configuration dictionary containing model type, network type, and hyperparameters.

    Returns:
        torch.nn.Module: Created neural network model.
    """
    gc.collect()
    torch.cuda.empty_cache()

    # Create the Neural Network
    if config['net'] == 'ann':
        net = generalized_ann_net_init(
            input_size=len(config['features']),
            num_nodes=config["num_nodes"],
            batch_norm=config["batch_norm"],
            dropout=config["dropout"],
            output_size=1  # Default output size for DeepSurv
        )
    elif config['net'] == 'lstm':
        net = lstm_net_init(
            input_size=len(config['features']),
            num_nodes=config["num_nodes"],
            batch_norm=config["batch_norm"],
            dropout=config["dropout"]
        )
    else:
        raise ValueError("Unknown network type: {}".format(config['net']))

    optimizer = tt.optim.AdamWR(decoupled_weight_decay=1e-6, cycle_eta_multiplier=0.8)
    if config['model'] == 'deepsurv':
        model = CoxPH(net, optimizer)
    elif config['model'] == 'deephit':
        model = DeepHit(net, optimizer)
    model.optimizer.set_lr(config["lr"])
    
    return model

def train_neural_network(model, config, X_train, X_val, duration_col, event_col, cluster_col, callbacks, time_grid=None):
    """
    Function to train a given neural network using the provided datasets.

    Args:
        net (torch.nn.Module): Neural network to be trained.
        config (dict): Configuration dictionary containing model hyperparameters.
        X_train (pd.DataFrame): Training dataset with features.
        X_val (pd.DataFrame): Validation dataset with features.
        duration_col (str): Column representing event durations.
        event_col (str): Column representing event occurrences.
        cluster_col (str): Column for grouping during cross-validation.
        callbacks (list): List of callbacks for training.
        time_grid (np.array, optional): Time grid for evaluation if required. Defaults to None.

    Returns:
        model: Trained PyCox model.
        logs: Training logs.
    """
    gc.collect()
    torch.cuda.empty_cache()

    
    # Prepare validation data (features and target labels)
    X_val_processed, y_val = preprocess_data(X_val, config['features'], duration_col, event_col)
    val_data = (X_val_processed, y_val)

    # Train the model
    if config['model'] == 'deepsurv':
        if config['net'] == 'ann':
            if config['balance_method'] == 'clustering':
                model, logs = recursive_clustering(model, X_train, duration_col, event_col, config, val_data, callbacks, max_repeats=30)
        elif config['net'] == 'lstm':
            if config['balance_method'] == 'clustering':
                model, logs = lstm_training(model, X_train, X_val, duration_col, event_col, cluster_col, config, callbacks, time_grid)

    # Free memory after training
    gc.collect()
    torch.cuda.empty_cache()

    return model, logs

In [9]:
config = {
    'model': 'deepsurv',
    'net': 'lstm',
    'features': ['gender', 'dm', 'ht', 'sprint', 'a1c', 'po4', 'UACR_mg_g', 'Cr', 'age', 'alb', 'ca', 'hb', 'hco3'],
    'endpoint': 1,
    'num_nodes': [8, 4],
    'batch_norm': False,
    'dropout': 0.1144793446270997,
    'lr': 0.1,
    'max_epochs': 9,
    'batch_size': 512,
    'sampling_strategy': 0.05,
    'seq_length': 3,
}

if config['net'] == 'ann':
    net = generalized_ann_net_init(
        input_size=len(config['features']),
        num_nodes=config["num_nodes"],
        batch_norm=config["batch_norm"],
        dropout=config["dropout"],
        output_size=1  # Default output size for DeepSurv
    )
elif config['net'] == 'lstm':
    net = lstm_net_init(
        input_size=len(config['features']),
        num_nodes=config["num_nodes"],
        batch_norm=config["batch_norm"],
        dropout=config["dropout"]
    )

optimizer = tt.optim.AdamWR(decoupled_weight_decay=1e-6, cycle_eta_multiplier=0.8)
model = CoxPH(net, optimizer)
model.optimizer.set_lr(config["lr"])
callbacks = [tt.cb.EarlyStopping()]

In [9]:
model_type = config['model']
event_focus = config['endpoint']
feature_col = config['features']

gc.collect()
# prepare_training_data(df, feature_col, duration_col, event_col, params, cluster_col, clustering_method='define_medoid', time_grid=None)
X_train, y_train = prepare_training_data(X_train_transformed_3, feature_col, DURATION_COL, EVENT_COL, config, CLUSTER_COL, model_type, TIME_GRID)
X_val, y_val = prepare_validation_data(X_val_fold, feature_col, DURATION_COL, EVENT_COL, config, CLUSTER_COL, model_type, TIME_GRID)
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = (torch.tensor(y_train[0], dtype=torch.float32), torch.tensor(y_train[1], dtype=torch.float32))

X_val_tensor = torch.tensor(X_val, dtype=torch.float32)
y_val_tensor = (torch.tensor(y_val[0], dtype=torch.float32), torch.tensor(y_val[1], dtype=torch.float32))
val_data = (X_val_tensor, y_val_tensor)

dataset_size = X_train_tensor.size()[0]
batch_size = min(config['batch_size'], dataset_size)
# if dataset_size % batch_size == 1:
#     batch_size = math.ceil(dataset_size / (math.floor(dataset_size / batch_size) + 1))

2024-11-03 23:45:58,875 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-03 23:45:58,882 - INFO - Performing clustering iteration 1 / 20
2024-11-03 23:45:58,883 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 23:45:58,887 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-03 23:45:59,630 - INFO - Defined medoid for deepsurv model with 1207 clusters.
2024-11-03 23:45:59,631 - INFO - Performing clustering iteration 2 / 20
2024-11-03 23:45:59,632 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 23:45:59,635 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-03 23:46:00,119 - INFO - Defined medoid for deepsurv model with 1207 clusters.
2024-11-03 23:46:00,120 - INFO - Performing clustering iteration 3 / 20
2024-11-03 23:46:00,120 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 23:46:00,124 - INFO - Event column 'endpoint' updated with focus on event 

KeyboardInterrupt: 

In [13]:
gc.collect()
torch.cuda.empty_cache()

config = {
    'model': 'deepsurv',
    'net': 'lstm',
    'balance_method': 'clustering',
    'features': ['gender', 'dm', 'ht', 'sprint', 'a1c', 'po4', 'UACR_mg_g', 'Cr', 'age', 'alb', 'ca', 'hb', 'hco3'],
    'endpoint': 1,
    'num_nodes': [8, 4],
    'batch_norm': False,
    'dropout': 0.1144793446270997,
    'lr': 0.1,
    'max_epochs': 9,
    'batch_size': 512,
    'sampling_strategy': 0.05,
    'seq_length': 3,
}
# Split final validation (fin_val) data for meta-learner
gss1 = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_SEED)
for train_idx_1, fin_val_idx in gss1.split(X=X_train_transformed[FEATURE_COLS], y=X_train_transformed[EVENT_COL], groups=X_train_transformed[CLUSTER_COL]):
    X_train_transformed_2, X_fin_val = X_train_transformed.iloc[train_idx_1, :], X_train_transformed.iloc[fin_val_idx, :]
    gc.collect()
    
    gss2 = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=RANDOM_SEED)
    test_results = []
    brier_df = pd.DataFrame()
    for train_idx, val_idx in gss2.split(X=X_train_transformed_2[FEATURE_COLS], y=X_train_transformed_2[EVENT_COL], groups=X_train_transformed_2[CLUSTER_COL]):
        gc.collect()
        torch.cuda.empty_cache()
        callbacks = [tt.cb.EarlyStopping()]
        
        X_train_transformed_3 = X_train_transformed_2.iloc[train_idx]
        X_val_fold = X_train_transformed_2.iloc[val_idx]
        X_val, y_val = preprocess_data(X_val_fold, config['features'], DURATION_COL, EVENT_COL)
        val = (X_val, y_val)

        model = create_neural_network(config)
    
        # model, logs = recursive_clustering(model, X_train_transformed_3, DURATION_COL, EVENT_COL, config, val, callbacks, max_repeats=30)
        # model, logs = lstm_training(model, X_train_transformed_3, X_val_fold, DURATION_COL, EVENT_COL, CLUSTER_COL, config, callbacks, TIME_GRID)
        model, logs = train_neural_network(model, config, X_train=X_train_transformed_3, X_val=X_val_fold, duration_col=DURATION_COL,
                                                    event_col=EVENT_COL, cluster_col=CLUSTER_COL, callbacks=callbacks, time_grid=TIME_GRID)

del model, logs

gc.collect()
torch.cuda.empty_cache()
    

2024-11-04 00:04:27,580 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-04 00:04:27,583 - INFO - Performing clustering iteration 1 / 20
2024-11-04 00:04:27,584 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-04 00:04:27,586 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-04 00:04:28,231 - INFO - Defined medoid for deepsurv model with 1207 clusters.
2024-11-04 00:04:28,232 - INFO - Performing clustering iteration 2 / 20
2024-11-04 00:04:28,233 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-04 00:04:28,237 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-04 00:04:28,685 - INFO - Defined medoid for deepsurv model with 1207 clusters.
2024-11-04 00:04:28,686 - INFO - Performing clustering iteration 3 / 20
2024-11-04 00:04:28,687 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-04 00:04:28,691 - INFO - Event column 'endpoint' updated with focus on event 

0:	[0s / 0s],		train_loss: 3.1756,	val_loss: 5.9737
1:	[0s / 1s],		train_loss: 2.4935,	val_loss: 5.4709
2:	[0s / 2s],		train_loss: 2.3906,	val_loss: 5.6250
3:	[0s / 3s],		train_loss: 2.3369,	val_loss: 5.4939
4:	[0s / 3s],		train_loss: 2.2701,	val_loss: 5.5683
5:	[0s / 4s],		train_loss: 2.2454,	val_loss: 5.4326
6:	[0s / 5s],		train_loss: 2.2264,	val_loss: 5.3574
7:	[0s / 5s],		train_loss: 2.3140,	val_loss: 5.2434
8:	[0s / 6s],		train_loss: 2.2340,	val_loss: 5.4267


  self.net.load_state_dict(torch.load(path, **kwargs))


In [None]:
recursive_clustering(model, X_train_transformed, config['endpoint'], DURATION_COL, EVENT_COL, config['features'], config, val, callbacks, max_repeats=30)

In [38]:
gc.collect()
torch.cuda.empty_cache()

repeat_count = 0
logs = []
model_type = config['model']
event_focus = config['endpoint']
feature_col = config['features']
max_repeats = 30

remaining_data = df_event_focus(df=X_train_transformed, event_col=EVENT_COL, event_focus=config['endpoint']) if model_type == 'deepsurv' else df.copy()
df_minor = remaining_data[remaining_data[EVENT_COL] == event_focus].copy() if model_type == 'deepsurv' else remaining_data[remaining_data[EVENT_COL] != 0].copy()
df_major = remaining_data[remaining_data[EVENT_COL] != event_focus].copy() if model_type == 'deepsurv' else remaining_data[remaining_data[EVENT_COL] == 0].copy()

goal = round(len(df_major) / len(df_minor)) - 1 if max_repeats == -1 else round(1 / config['sampling_strategy'])

while len(remaining_data) > 0 and repeat_count < goal:
    logging.info(f"Performing clustering iteration {repeat_count + 1} / {goal}")
    if model_type == 'deepsurv':
        X_cluster, remaining_data = define_medoid_general(df=remaining_data, feature_col=feature_col, event_col=EVENT_COL)
        X_train_cluster, y_train_cluster = preprocess_data(df=X_cluster, feature_col=feature_col, duration_col=DURATION_COL, event_col=EVENT_COL)
    else:
        X_cluster, remaining_data = define_medoid_general(df=remaining_data, feature_col=feature_col, event_col=EVENT_COL)
        X_train_cluster, y_train_cluster = preprocess_data(X_cluster, feature_col, DURATION_COL, EVENT_COL, TIME_GRID, discretize=True)

    log = model.fit(X_train_cluster, y_train_cluster, config['batch_size'], config['max_epochs'], callbacks, verbose=True)
    logs.append(log)
    gc.collect()

    # Early stopping check
    if callbacks and hasattr(callbacks[0], 'stopped_epoch') and callbacks[0].stopped_epoch > 0:
        logging.info(f"Early stopping at epoch {callbacks[0].stopped_epoch}")
        break

    repeat_count += 1
    
gc.collect()
torch.cuda.empty_cache()

2024-11-03 15:00:55,359 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-03 15:00:56,217 - INFO - Performing clustering iteration 1 / 20
2024-11-03 15:00:56,217 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:00:56,225 - INFO - Event column 'endpoint' updated with focus on event value 1.
2024-11-03 15:00:57,441 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:00:57,622 - INFO - Performing clustering iteration 2 / 20
2024-11-03 15:00:57,623 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:00:57,628 - INFO - Event column 'endpoint' updated with focus on event value 1.


28:	[0s / 0s],		train_loss: 5.0468


2024-11-03 15:00:58,552 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:00:58,774 - INFO - Performing clustering iteration 3 / 20
2024-11-03 15:00:58,775 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:00:58,779 - INFO - Event column 'endpoint' updated with focus on event value 1.


29:	[0s / 0s],		train_loss: 4.9419


2024-11-03 15:00:59,712 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:00:59,896 - INFO - Performing clustering iteration 4 / 20
2024-11-03 15:00:59,897 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:00:59,902 - INFO - Event column 'endpoint' updated with focus on event value 1.


30:	[0s / 0s],		train_loss: 4.9586


2024-11-03 15:01:00,804 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:00,993 - INFO - Performing clustering iteration 5 / 20
2024-11-03 15:01:00,994 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:00,998 - INFO - Event column 'endpoint' updated with focus on event value 1.


31:	[0s / 0s],		train_loss: 4.9585


2024-11-03 15:01:01,954 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:02,176 - INFO - Performing clustering iteration 6 / 20
2024-11-03 15:01:02,176 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:02,181 - INFO - Event column 'endpoint' updated with focus on event value 1.


32:	[0s / 0s],		train_loss: 4.9698


2024-11-03 15:01:03,060 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:03,248 - INFO - Performing clustering iteration 7 / 20
2024-11-03 15:01:03,248 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:03,253 - INFO - Event column 'endpoint' updated with focus on event value 1.


33:	[0s / 0s],		train_loss: 4.9719


2024-11-03 15:01:04,152 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:04,349 - INFO - Performing clustering iteration 8 / 20
2024-11-03 15:01:04,350 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:04,355 - INFO - Event column 'endpoint' updated with focus on event value 1.


34:	[0s / 0s],		train_loss: 4.9622


2024-11-03 15:01:05,217 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:05,432 - INFO - Performing clustering iteration 9 / 20
2024-11-03 15:01:05,433 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:05,437 - INFO - Event column 'endpoint' updated with focus on event value 1.


35:	[0s / 0s],		train_loss: 4.9839


2024-11-03 15:01:06,318 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:06,503 - INFO - Performing clustering iteration 10 / 20
2024-11-03 15:01:06,503 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:06,508 - INFO - Event column 'endpoint' updated with focus on event value 1.


36:	[0s / 0s],		train_loss: 4.9794


2024-11-03 15:01:07,362 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:07,596 - INFO - Performing clustering iteration 11 / 20
2024-11-03 15:01:07,597 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:07,601 - INFO - Event column 'endpoint' updated with focus on event value 1.


37:	[0s / 0s],		train_loss: 5.0072


2024-11-03 15:01:08,525 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:08,711 - INFO - Performing clustering iteration 12 / 20
2024-11-03 15:01:08,712 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:08,716 - INFO - Event column 'endpoint' updated with focus on event value 1.


38:	[0s / 0s],		train_loss: 4.9794


2024-11-03 15:01:09,560 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:09,755 - INFO - Performing clustering iteration 13 / 20
2024-11-03 15:01:09,756 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:09,760 - INFO - Event column 'endpoint' updated with focus on event value 1.


39:	[0s / 0s],		train_loss: 4.9976


2024-11-03 15:01:10,621 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:10,827 - INFO - Performing clustering iteration 14 / 20
2024-11-03 15:01:10,828 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:10,833 - INFO - Event column 'endpoint' updated with focus on event value 1.


40:	[0s / 0s],		train_loss: 5.0132


2024-11-03 15:01:11,696 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:11,907 - INFO - Performing clustering iteration 15 / 20
2024-11-03 15:01:11,908 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:11,913 - INFO - Event column 'endpoint' updated with focus on event value 1.


41:	[0s / 0s],		train_loss: 5.0160


2024-11-03 15:01:12,760 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:12,940 - INFO - Performing clustering iteration 16 / 20
2024-11-03 15:01:12,941 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:12,945 - INFO - Event column 'endpoint' updated with focus on event value 1.


42:	[0s / 0s],		train_loss: 5.0499


2024-11-03 15:01:13,761 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:13,987 - INFO - Performing clustering iteration 17 / 20
2024-11-03 15:01:13,988 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:13,995 - INFO - Event column 'endpoint' updated with focus on event value 1.


43:	[0s / 0s],		train_loss: 5.0269


2024-11-03 15:01:14,836 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:15,048 - INFO - Performing clustering iteration 18 / 20
2024-11-03 15:01:15,049 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:15,055 - INFO - Event column 'endpoint' updated with focus on event value 1.


44:	[0s / 0s],		train_loss: 5.0653


2024-11-03 15:01:15,887 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:16,120 - INFO - Performing clustering iteration 19 / 20
2024-11-03 15:01:16,121 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:16,125 - INFO - Event column 'endpoint' updated with focus on event value 1.


45:	[0s / 0s],		train_loss: 5.0541


2024-11-03 15:01:16,958 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))
2024-11-03 15:01:17,141 - INFO - Performing clustering iteration 20 / 20
2024-11-03 15:01:17,142 - INFO - CUDA environment set up and GPU memory cleared.
2024-11-03 15:01:17,147 - INFO - Event column 'endpoint' updated with focus on event value 1.


46:	[0s / 0s],		train_loss: 5.0698


2024-11-03 15:01:17,945 - INFO - Defined medoid for deepsurv model with 1925 clusters.
  self.net.load_state_dict(torch.load(path, **kwargs))


47:	[0s / 0s],		train_loss: 5.0389


In [31]:
remaining_data[remaining_data[EVENT_COL] == event_focus].copy()

Unnamed: 0,gender,dm,ht,sprint,a1c,po4,UACR_mg_g,Cr,age,alb,ca,hb,hco3,key,date_from_sub_60,endpoint
501,1.0,1.0,1.0,0.0,0.126422,0.718803,0.917315,0.766480,0.688889,0.350878,0.303458,0.309179,0.363289,2695029,740.0,1
724,1.0,1.0,1.0,0.0,0.151458,0.781061,0.832484,0.726740,0.411112,0.578948,0.401560,0.502416,0.267687,4565409,1764.0,1
1302,1.0,1.0,1.0,1.0,0.169871,0.668293,0.874092,0.651472,0.800000,0.456141,0.339003,0.338165,0.191205,1655813,300.0,1
1306,0.0,1.0,0.0,0.0,0.239518,0.657125,0.876405,0.634679,0.677778,0.456141,0.412330,0.429952,0.305928,6052274,1790.0,1
1656,1.0,0.0,1.0,0.0,0.151458,0.746326,0.840789,0.815412,0.577778,0.421053,0.412742,0.454107,0.344169,5608947,1764.0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
395678,1.0,1.0,1.0,1.0,0.246211,0.735924,0.891154,0.714247,0.566667,0.473685,0.304223,0.400967,0.458892,1655764,608.0,1
395860,0.0,0.0,1.0,0.0,0.187555,0.631314,0.886437,0.622796,0.311112,0.385966,0.328057,0.483092,0.363289,415799,482.0,1
395884,0.0,0.0,1.0,0.0,0.161575,0.726289,0.866884,0.713979,0.677778,0.561404,0.298044,0.338165,0.327184,4827892,586.0,1
396200,0.0,1.0,0.0,0.0,0.150980,0.687764,0.915196,0.626973,0.466667,0.578948,0.374166,0.415460,0.325048,1677979,1523.0,1


In [None]:
recursive_clustering(model, df, event_focus, duration_col, event_col, feature_col, params, val, callbacks, max_repeats, model_type='deepsurv')