In [1]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn_pandas import DataFrameMapper
import torch
import torchtuples as tt
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
import os

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed columns
cols_standardize = ['Age','BMI', 'poverty_level', 'PC1','PC2','PC3','PC4','PC5','PC6','PC7','PC8','PC9']
cols_leave = ['Race', 'sex', 'Mobility', 'diabetes.y', 'Asthma', 'Arthritis', 'heart_failure', 'coronary_heart_disease', 
              'angina', 'stroke', 'thyroid', 'bronchitis', 'cancer']

standardize = [([col], StandardScaler()) for col in cols_standardize]
leave = [(col, None) for col in cols_leave]
x_mapper = DataFrameMapper(standardize + leave)

# Model architecture
def get_model(input_dim):
    num_nodes = [256, 256]
    net = tt.practical.MLPVanilla(input_dim, num_nodes, 1, batch_norm=True, dropout=0.7, output_bias=False).to(device)
    return CoxPH(net, tt.optim.Adam)

# For C-index storage
c_index_list = []

# Directory containing the splits
split_dir = "D:/Final_Year_Project/splits_final/"  # Modify this if needed

for i in range(1, 101):
    print(f"Processing split {i}...")

    # Load train/test split
    df_train = pd.read_csv(os.path.join(split_dir, f"train_split_{i}.csv"))
    df_test = pd.read_csv(os.path.join(split_dir, f"test_split_{i}.csv"))

    # Fit transformer only on training data
    x_train = torch.tensor(x_mapper.fit_transform(df_train).astype('float32')).to(device)
    x_val = torch.tensor(x_mapper.transform(df_test).astype('float32')).to(device)
    x_test = torch.tensor(x_mapper.transform(df_test).astype('float32')).to(device)

    # Get target
    get_target = lambda df: (
        torch.tensor(df['time_mort'].values, dtype=torch.float32).to(device),
        torch.tensor(df['mortstat'].values, dtype=torch.float32).to(device)
    )
    y_train = get_target(df_train)
    y_val = get_target(df_test)
    durations_test, events_test = get_target(df_test)
    val = x_val, y_val

    durations_np = durations_test.cpu().numpy()
    events_np = events_test.cpu().numpy()

    # Model
    model = get_model(x_train.shape[1])

    # Find LR
    batch_size = 64
    lrfinder = model.lr_finder(x_train, y_train, batch_size, tolerance=10)
    best_lr = lrfinder.get_best_lr()
    model.optimizer.set_lr(best_lr)

    # Train
    callbacks = [tt.callbacks.EarlyStopping()]
    model.fit(x_train, y_train, batch_size, 512, callbacks, verbose=False, 
              val_data=val, val_batch_size=batch_size)

    # Predict survival
    model.compute_baseline_hazards()
    surv = model.predict_surv_df(x_test)
    ev = EvalSurv(surv, durations_np, events_np, censor_surv='km')
    c_index = ev.concordance_td()
    c_index_list.append(c_index)

# Save C-indices
c_index_df = pd.DataFrame({'split': list(range(1, 101)), 'c_index': c_index_list})
# c_index_df.to_csv("D:/Final_Year_Project/c_index_results.csv", index=False)

print("All splits processed. Results saved to c_index_results.csv.")

Processing split 1...
Processing split 2...
Processing split 3...
Processing split 4...
Processing split 5...
Processing split 6...
Processing split 7...
Processing split 8...
Processing split 9...
Processing split 10...
Processing split 11...
Processing split 12...
Processing split 13...
Processing split 14...
Processing split 15...
Processing split 16...
Processing split 17...
Processing split 18...
Processing split 19...
Processing split 20...
Processing split 21...
Processing split 22...
Processing split 23...
Processing split 24...
Processing split 25...
Processing split 26...
Processing split 27...
Processing split 28...
Processing split 29...
Processing split 30...
Processing split 31...
Processing split 32...
Processing split 33...
Processing split 34...
Processing split 35...
Processing split 36...
Processing split 37...
Processing split 38...
Processing split 39...
Processing split 40...
Processing split 41...
Processing split 42...
Processing split 43...
Processing split 44.

In [2]:
np.mean(c_index_list)

0.7817317342027168

In [4]:
c_index_df.to_csv("D:/Final_Year_Project/c_index_PCdeepsurv.csv", index=False)