In [1]:
# ============================
# 0. Import Libraries
# ============================
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
import torchtuples as tt

# ============================
# 1. Configuration
# ============================
config = {
    'lstm_hidden_size': 5,
    'lstm_layers': 3,
    'lstm_dropout': 0.3,
    'batch_size': 64,
    'epochs_lstm_cox': 512,
}

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_device(x):
    if isinstance(x, (tuple, list)):
        return tuple(to_device(xx) for xx in x)
    return x.to(device)

# ============================
# 2. Load and Prepare Dataset
# ============================
data = pd.read_csv("D:\\Final_Year_Project\\NHANES_11_14_survPA_Python.csv")
data_act = pd.read_csv("D:\\Final_Year_Project\\activity.csv")
df = pd.concat([data, data_act], axis=1)
df['row_id'] = df.index

# Shuffle and split
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
df_val = df.sample(frac=0.1, random_state=42).reset_index(drop=True)
df_train = df.drop(df_val.index).reset_index(drop=True)

# ============================
# 3. Feature Columns
# ============================
static_cols = ['Age', 'Race', 'BMI', 'sex', 'Mobility', 'diabetes.y', 'poverty_level',
               'Asthma', 'Arthritis', 'heart_failure', 'coronary_heart_disease',
               'angina', 'stroke', 'thyroid', 'bronchitis', 'cancer']
time_cols = [str(i) for i in range(1, 1441)]

def prepare_data(df):
    X_static = df[static_cols].copy()
    X_time = df[time_cols].values.reshape(len(df), 1440, 1)
    y_event = df['mortstat'].values.astype(np.float32)
    y_time = df['time_mort'].values.astype(np.float32)
    row_ids = df['row_id'].values
    return X_static, X_time.astype(np.float32), y_time, y_event, row_ids

X_train_static, X_train_time, time_train, y_train_event, row_ids_train = prepare_data(df_train)
X_val_static, X_val_time, time_val, y_val_event, row_ids_val = prepare_data(df_val)

# ============================
# 4. BiLSTM Feature Extractor
# ============================
class BiLSTMFeatureExtractor(nn.Module):
    def __init__(self, input_size=1, hidden_size=9, num_layers=2, dropout=0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            batch_first=True,
                            bidirectional=True)

    def forward(self, x):
        _, (hn, _) = self.lstm(x)
        hn_forward = hn[-2]
        hn_backward = hn[-1]
        return torch.cat((hn_forward, hn_backward), dim=1)

# ============================
# 5. Wrapper for CoxPH Model
# ============================
class FeatureWrapperForCox(nn.Module):
    def __init__(self, lstm_model, static_input_dim, hidden_dim):
        super().__init__()
        self.lstm_model = lstm_model
        self.linear = nn.Linear(2 * hidden_dim + static_input_dim, 1)

    def forward(self, X_time, X_static):
        lstm_feat = self.lstm_model(X_time)
        return self.linear(torch.cat([lstm_feat, X_static], dim=1))

# ============================
# 6. Train BiLSTM with CoxPH Loss
# ============================
def train_bilstm_cox(X_train_time, X_train_static, durations_train, events_train,
                     X_val_time, X_val_static, durations_val, events_val):

    scaler = StandardScaler()
    X_train_static = scaler.fit_transform(X_train_static).astype(np.float32)
    X_val_static = scaler.transform(X_val_static).astype(np.float32)

    X_train_time, X_train_static, durations_train, events_train = to_device((
        torch.tensor(X_train_time), torch.tensor(X_train_static),
        torch.tensor(durations_train), torch.tensor(events_train)))
    X_val_time, X_val_static, durations_val, events_val = to_device((
        torch.tensor(X_val_time), torch.tensor(X_val_static),
        torch.tensor(durations_val), torch.tensor(events_val)))

    y_train = (durations_train, events_train)
    y_val = (durations_val, events_val)

    lstm_model = BiLSTMFeatureExtractor(
        hidden_size=config['lstm_hidden_size'],
        num_layers=config['lstm_layers'],
        dropout=config['lstm_dropout']
    ).to(device)

    feature_net = FeatureWrapperForCox(
        lstm_model, X_train_static.shape[1], config['lstm_hidden_size']
    ).to(device)

    model = CoxPH(feature_net, tt.optim.Adam)
    batch_size = config['batch_size']
    model.optimizer.set_lr(model.lr_finder((X_train_time, X_train_static), y_train, batch_size=batch_size).get_best_lr())

    model.fit((X_train_time, X_train_static), y_train, batch_size, epochs=config['epochs_lstm_cox'],
              callbacks=[tt.callbacks.EarlyStopping()],
              val_data=((X_val_time, X_val_static), y_val), val_batch_size=batch_size)

    model.compute_baseline_hazards()
    surv = model.predict_surv_df((X_val_time, X_val_static))
    ev = EvalSurv(surv, durations_val.cpu().numpy(), events_val.cpu().numpy(), censor_surv='km')
    print("BiLSTM-CoxPH Concordance Index:", ev.concordance_td())

    return lstm_model, scaler

# ============================
# 7. Extract Only Dynamic (LSTM) Features
# ============================
def extract_lstm_only_features(lstm_model, X_time):
    X_time = torch.tensor(X_time).to(device)
    with torch.no_grad():
        lstm_feat = lstm_model(X_time)
    return lstm_feat.cpu().numpy()

# ============================
# 8. Train and Extract Features
# ============================
lstm_model, scaler = train_bilstm_cox(X_train_time, X_train_static, time_train, y_train_event,
                                      X_val_time, X_val_static, time_val, y_val_event)

train_lstm_features = extract_lstm_only_features(lstm_model, X_train_time)
val_lstm_features = extract_lstm_only_features(lstm_model, X_val_time)

0:	[1s / 1s],		train_loss: 3.6418,	val_loss: 3.1728
1:	[1s / 2s],		train_loss: 3.5117,	val_loss: 3.2201
2:	[1s / 3s],		train_loss: 3.5215,	val_loss: 3.2181
3:	[1s / 4s],		train_loss: 3.5306,	val_loss: 3.2120
4:	[1s / 5s],		train_loss: 3.5230,	val_loss: 3.1649
5:	[1s / 6s],		train_loss: 3.5299,	val_loss: 3.1980
6:	[1s / 7s],		train_loss: 3.5166,	val_loss: 3.2092
7:	[1s / 8s],		train_loss: 3.5074,	val_loss: 3.2968
8:	[1s / 9s],		train_loss: 3.4923,	val_loss: 3.2343
9:	[1s / 10s],		train_loss: 3.4947,	val_loss: 3.2118
10:	[1s / 11s],		train_loss: 3.5061,	val_loss: 3.2786
11:	[1s / 12s],		train_loss: 3.5194,	val_loss: 3.2123
12:	[1s / 14s],		train_loss: 3.5264,	val_loss: 3.2045
13:	[1s / 15s],		train_loss: 3.4778,	val_loss: 3.2661
14:	[1s / 16s],		train_loss: 3.5083,	val_loss: 3.2898
BiLSTM-CoxPH Concordance Index: 0.8022700532908852


In [2]:
# ============================
# 9. Save to DataFrames and Merge
# ============================
train_feat_df = pd.DataFrame(train_lstm_features, columns=[f"dyn_feat_{i}" for i in range(train_lstm_features.shape[1])])
train_feat_df['row_id'] = row_ids_train

val_feat_df = pd.DataFrame(val_lstm_features, columns=[f"dyn_feat_{i}" for i in range(val_lstm_features.shape[1])])
val_feat_df['row_id'] = row_ids_val

train_merged = pd.merge(df_train[static_cols + ['time_mort', 'mortstat', 'row_id']], train_feat_df, on='row_id')
val_merged = pd.merge(df_val[static_cols + ['time_mort', 'mortstat', 'row_id']], val_feat_df, on='row_id')

In [3]:
# Concatenate row-wise
df_LSTM = pd.concat([train_merged, val_merged], axis=0).reset_index(drop=True)

In [4]:
df_LSTM.columns

Index(['Age', 'Race', 'BMI', 'sex', 'Mobility', 'diabetes.y', 'poverty_level',
       'Asthma', 'Arthritis', 'heart_failure', 'coronary_heart_disease',
       'angina', 'stroke', 'thyroid', 'bronchitis', 'cancer', 'time_mort',
       'mortstat', 'row_id', 'dyn_feat_0', 'dyn_feat_1', 'dyn_feat_2',
       'dyn_feat_3', 'dyn_feat_4', 'dyn_feat_5', 'dyn_feat_6', 'dyn_feat_7',
       'dyn_feat_8', 'dyn_feat_9'],
      dtype='object')

In [8]:
selected_df = df_LSTM[['row_id', 'dyn_feat_0', 'dyn_feat_1', 'dyn_feat_2','dyn_feat_3', 'dyn_feat_4', 'dyn_feat_5', 
                  'dyn_feat_6', 'dyn_feat_7','dyn_feat_8', 'dyn_feat_9']]

In [9]:
merged_df = pd.merge(df, selected_df, on='row_id', how='inner')

In [10]:
final_df = merged_df[['Age', 'Race', 'BMI', 'sex', 'Mobility', 'diabetes.y', 'poverty_level','Asthma', 'Arthritis', 
                      'heart_failure', 'coronary_heart_disease','angina', 'stroke', 'thyroid', 'bronchitis', 'cancer', 
                      'time_mort','mortstat','PC1','PC2','PC3','PC4','PC5','PC6','PC7','PC8','PC9','dyn_feat_0', 'dyn_feat_1', 
                      'dyn_feat_2','dyn_feat_3', 'dyn_feat_4', 'dyn_feat_5','dyn_feat_6', 'dyn_feat_7','dyn_feat_8', 'dyn_feat_9']]

In [11]:
import pandas as pd
from sklearn.model_selection import train_test_split

# Ensure 'mortstat' exists and it's the event indicator column (True/False)
if 'mortstat' not in df.columns:
    raise ValueError("'mortstat' column not found in the data.")

# Define your target columns
y = final_df['mortstat']  # Event indicator column
X = final_df.drop(columns=['mortstat', 'time_mort'])  # Drop target and time column for features
time_col = 'time_mort'  # Assuming 'time_mort' is your time column

# Perform 100 stratified splits and save
for i in range(1, 101):
    print(f"Generating split {i}...")

    # Perform stratified split (80/20)
    train, test = train_test_split(final_df, test_size=0.2, stratify=y, random_state=i)

    # Save to CSV
    train_file = f"train_split_{i}.csv"
    test_file = f"test_split_{i}.csv"

    train.to_csv(train_file, index=False)
    test.to_csv(test_file, index=False)

    print(f"Split {i} saved as {train_file} and {test_file}")

print("\nAll splits generated and saved successfully!")

Generating split 1...
Split 1 saved as train_split_1.csv and test_split_1.csv
Generating split 2...
Split 2 saved as train_split_2.csv and test_split_2.csv
Generating split 3...
Split 3 saved as train_split_3.csv and test_split_3.csv
Generating split 4...
Split 4 saved as train_split_4.csv and test_split_4.csv
Generating split 5...
Split 5 saved as train_split_5.csv and test_split_5.csv
Generating split 6...
Split 6 saved as train_split_6.csv and test_split_6.csv
Generating split 7...
Split 7 saved as train_split_7.csv and test_split_7.csv
Generating split 8...
Split 8 saved as train_split_8.csv and test_split_8.csv
Generating split 9...
Split 9 saved as train_split_9.csv and test_split_9.csv
Generating split 10...
Split 10 saved as train_split_10.csv and test_split_10.csv
Generating split 11...
Split 11 saved as train_split_11.csv and test_split_11.csv
Generating split 12...
Split 12 saved as train_split_12.csv and test_split_12.csv
Generating split 13...
Split 13 saved as train_split