### Installing packages ###

In [1]:
import numpy as np
import os
import pandas as pd
import matplotlib.pyplot as plt
import sys

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, GlobalAveragePooling1D
from tensorflow.keras.models import Model
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
import plotly.graph_objects as go
import plotly.subplots as sp

import survivalnet2
from survivalnet2.data.labels import stack_labels, unstack_labels
from survivalnet2.losses import efron, cox
from survivalnet2.metrics.concordance import HarrellsC
from survivalnet2.visualization import km_plot

# Set random seeds for reproducibility
np.random.seed(51)
tf.random.set_seed(51)

### Data preprocessing ###

In [2]:
def binarize_columns(df):
    # binarizr region identifier as 0 or 1
    df[df.columns[0]] = df.iloc[:, 0].apply(lambda x: 1 if x == 'TUMOR' else 0)
    return df

def compute_median_values(data_files):
    # Read in first file to get the columns
    sample = pd.read_csv(data_files[0])
    num_cols = sample.shape[1] - 2  # Exclude the first two columns

    # Initialize array to store values for all files
    values = np.empty((0, num_cols))

    # Iterate over all files and extract values
    for data_file in data_files:
        # Read in data and skip first row (assumed to be header)
        df = pd.read_csv(data_file, skiprows=[0], usecols=range(2, num_cols+2))
        df = binarize_columns(df)
        values = np.concatenate((values, df.values), axis=0)
    
    # Compute median values for each column
    median_values = np.median(values, axis=0)
    median_dict = {sample.columns[i+2]: median_values[i] for i in range(num_cols)}
    return median_dict


def pad_missing_values(df, median_dict):
    # Replace missing values with median value for each column
    for col in df.columns:
        median_value = median_dict[col]
        df[col].fillna(median_value, inplace=True)
    return df


def min_max_normalize_features(df):

    for col in df.columns:
        min_value = df[col].min()
        max_value = df[col].max()
        
        # Check if values are not already in the range [0, 1]
        if (min_value < 0 or max_value > 1):
            df[col] = (df[col] - min_value) / (max_value - min_value)
        
    return df

def create_label_dict(label_dir):
    df = pd.read_csv(label_dir)
    column_names = df.columns  # Get the column names from the first row
    label_dict = {}
    for i in range(1, len(df)):
        name = df.iloc[i, 0]  # Get the sample name from the first column of the current row
        time = df.iloc[i, column_names.get_loc('ClinicalFeats.Survival.BCSS.YearsFromDx')]  # Get the time data from the 'ClinicalFeats.Survival.BCSS.YearsFromDx' column
        event = df.iloc[i, column_names.get_loc('ClinicalFeats.Survival.BCSS')]  # Get the event data from the 'ClinicalFeats.Survival.BCSS' column
        label_dict[name] = (time, event)
    return label_dict

### Parameters definition ###

In [3]:
# define dimensionality
D = 49
print(f"Dimensionality of each feature vector is: {D}")

# Define the batch size you want to use
batch_size = 64


data_dir = '/Users/shangke/Desktop/pathology/perSlideRegionFeatures/CPSII_40X'
label_dir = '/Users/shangke/Desktop/pathology/FusedData_CPSII_40X.csv'
csv_names = os.listdir(data_dir)
null_count = 0
label_dict = create_label_dict(label_dir)
valid_csv_names = []

for name in csv_names:
    if name.rstrip('.csv') in list(label_dict.keys()):
        valid_csv_names.append(name)

    else:
        null_count += 1

data_files = [os.path.join(data_dir, str(csv_name)) for csv_name in valid_csv_names]

print(f"Number of samples with missing label data: {null_count}")


Dimensionality of each feature vector is: 49
Number of samples with missing label data: 54


### Generate dataset ###

In [15]:
def read_data(data_files, label_dict):
    """
    Reads in the data files, binarizes the columns, pads missing values, and normalizes the features.

    Args:
        data_files (list): List of file paths to the data files.
        label_dict (dict): Dictionary mapping file names to (time, event) tuples.

    Returns:
        A tuple containing:
            - rows_tensor (tf.RaggedTensor): A ragged tensor containing the feature vectors for each sample.
            - labels_tensor (tuple): A tuple of two tensors, containing the time and event labels for each sample.
    """
    rows_list = []
    time_list = []
    event_list = []
    empty_count = 0

    # Calculate median values for each feature across all data files
    median_values = compute_median_values(data_files)

    for data_file in data_files:
        # Get the name of the file without the extension
        name = os.path.splitext(os.path.basename(data_file))[0]

        # Skip the file if it is not in the label dictionary
        if name not in label_dict:
            empty_count += 1
            continue

        # Read in the data file
        df = pd.read_csv(data_file)
        

        # Skip the file if it has no rows
        if df.shape[0] < 1:
            empty_count += 1
            print(name)
            continue

        # Binarize the columns and pad missing values
        df = df.iloc[:, 2:]  # Drop the first two columns
        df = binarize_columns(df)
        df = pad_missing_values(df, median_values)

        # Normalize the features
        df = min_max_normalize_features(df)
    
        # Add the feature vector and labels to the lists
        rows_list.append(df.values)
        time, event = label_dict[name]
        time_list.append(time)
        event_list.append(event)
    
    # Convert the lists to tensors
    rows_tensor = tf.ragged.constant(rows_list, ragged_rank=1, dtype=tf.float32)
    labels_tensor = stack_labels(tf.convert_to_tensor(time_list, dtype=tf.float32),
                                 tf.convert_to_tensor(event_list, dtype=tf.float32))

    print(f"Number of samples: {len(rows_list)}")
    print(f"Number of empty data files: {empty_count}")

    return rows_tensor, labels_tensor

data, labels = read_data(data_files, label_dict)


Number of samples: 1654
Number of empty data files: 0


### Remove subjects with persistent NaNs ###

Some NaNs may remain for subjects that have a single region that also contains NaN features (median imputation doesn't work in this case).

In [16]:
indices = []
for i, subject in enumerate(data):
    if not np.sum(np.isnan(subject)):
        indices.append(i)
data = tf.gather(data, np.array(indices), axis=0)
labels = tf.gather(labels, np.array(indices), axis=0)

### Attention model architecture ###

In [29]:
def build_model(D):
    
    # Input layer
    inputs = tf.keras.layers.Input(shape=(None, D), ragged=True)

    # Attention weights
    att = tf.keras.layers.Dense(units=1, activation="relu", name="att")(inputs)

    # Normalize weights to sum to 1
    totals = tf.reduce_sum(att, axis=1, name="att_total")
    normalized = tf.math.divide_no_nan(att, tf.expand_dims(totals, axis=1), name="normalized")

    # Use attention weights to calculate weighted sum of regions
    pooled = tf.linalg.matmul(normalized, inputs, transpose_a=True)

    # Remove the ragged dimension and reshape pooled tensor
    pooled = tf.squeeze(pooled.to_tensor(), axis=1)

    # Apply a linear layer to the pooled vector to generate the time and event risk values
    risk = tf.keras.layers.Dense(units=1, activation="linear", name="risk")(pooled)

    # Build the model
    model = tf.keras.models.Model(inputs=inputs, outputs=[risk, normalized])

    print(f"The input shape of model is: {model.input_shape}")
    print(f"The output shape of model is: {model.output_shape}")

    return model

# Create and compile the model
model = build_model(D)
model.compile(
    loss={"risk": efron},
    metrics={"risk": HarrellsC()},
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)

The input shape of model is: (None, None, 49)
The output shape of model is: [(1, 1), (None, None, 1)]


###  Data batching ###

In [26]:
def create_datasets(data, labels, indices, shuffle=False):
    ds_data = tf.data.Dataset.from_tensor_slices(tf.gather(data, indices, axis=0))
    ds_labels = tf.data.Dataset.from_tensor_slices(tf.gather(labels, indices, axis=0))
    ds = tf.data.Dataset.zip((ds_data, ds_labels))
    
    if shuffle:
        buffer_size = len(indices)
        ds = ds.shuffle(buffer_size)
    
    ds = ds.batch(64)

    for i, batch in enumerate(ds):
        _, l = batch
        events = sum(l[:,1])
        if events < 1.:
            print(f"Warning, 0 events in batch {i}.")
            
    return ds


def perform_k_fold_cross_validation(data, labels, model, n_splits=5):
    kf = KFold(n_splits=n_splits)

    # Train and validate the model for each fold
    for fold, (train_index, val_index) in enumerate(kf.split(data), start=1):
        print(f"\nTraining fold {fold}")

        # Create training and validation datasets
        ds_train = create_datasets(data, labels, train_index, shuffle=True)
        ds_val = create_datasets(data, labels, val_index)

        # Train the model for the current fold
        history = model.fit(
            ds_train,
            epochs=100,
            verbose=1,
            shuffle=True,
            validation_data=ds_val,
            callbacks=[tf.keras.callbacks.ReduceLROnPlateau(patience=5, verbose=1)],
        )

        # Plot training metrics for the current fold
        plot_training_metrics(history, fold)

        # Print validation results for the current fold
        val_loss, val_harrells_c = history.history['val_risk_loss'][-1], history.history['val_risk_harrellsc'][-1]
        print(f"Fold {fold} validation results - Loss: {val_loss:.4f}, Harrell's C: {val_harrells_c:.4f}")


def plot_training_metrics(history, fold):
    epochs = list(range(1, len(history.history['risk_loss']) + 1))

    # Create subplots
    fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("Loss", "Harrell's C"))

    # Create a trace for training loss
    trace_train_loss = go.Scatter(x=epochs,
                                  y=history.history['risk_loss'],
                                  mode='lines',
                                  name='Train Loss')

    # Create a trace for validation loss
    trace_val_loss = go.Scatter(x=epochs,
                                y=history.history['val_risk_loss'],
                                mode='lines',
                                name='Validation Loss')

    # Add loss traces to the subplots
    fig.add_trace(trace_train_loss, row=1, col=1)
    fig.add_trace(trace_val_loss, row=1, col=1)

    # Create a trace for training Harrell's C
    trace_train_harrells_c = go.Scatter(x=epochs,
                                        y=history.history['risk_harrellsc'],
                                        mode='lines',
                                        name="Train Harrell's C")

    # Create a trace for validation Harrell's C
    trace_val_harrells_c = go.Scatter(x=epochs,
                                      y=history.history['val_risk_harrellsc'],
                                      mode='lines',
                                      name="Validation Harrell's C")

    # Add Harrell's C traces to the subplots
    fig.add_trace(trace_train_harrells_c, row=1, col=2)
    fig.add_trace(trace_val_harrells_c, row=1, col=2)

    # Update xaxis and yaxis titles
    fig.update_xaxes(title_text='Epoch', row=1, col=1)
    fig.update_yaxes(title_text='Loss', row=1, col=1)
    fig.update_xaxes(title_text='Epoch', row=1, col=2)
    fig.update_yaxes(title_text="Harrell's C", row=1, col=2)

    # Update the layout for the plot
    fig.update_layout(title=f'Fold {fold} Training Metrics')

    # Show the plot
    fig.show()



### Visualization ###

In [30]:
# Split the data into train (80%) and test (20%) sets
data_size = data.shape[0]
train_size = int(data_size * 0.8)
indices = np.random.permutation(data_size)
train_indices, test_indices = indices[:train_size], indices[train_size:]

train_data = tf.gather(data, train_indices, axis=0)
test_data = tf.gather(data, test_indices, axis=0)
train_labels = tf.gather(labels, train_indices, axis=0)
test_labels = tf.gather(labels, test_indices, axis=0)

# Use KFold cross-validation on the train set
perform_k_fold_cross_validation(train_data, train_labels, model, n_splits=5)


Training fold 1
Epoch 1/100



Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_4/input.to_tensor_4/RaggedToTensor/boolean_mask_1/GatherV2:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/model_4/input.to_tensor_4/RaggedToTensor/boolean_mask/GatherV2:0", shape=(None, 49), dtype=float32), dense_shape=Tensor("gradient_tape/model_4/input.to_tensor_4/RaggedToTensor/Shape:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.


Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/model_4/tf.math.divide_no_nan_4/RaggedTile/Reshape_3:0", shape=(None,), dtype=int64), values=Tensor("gradient_tape/model_4/tf.math.divide_no_nan_4/RaggedTile/Reshape_2:0", shape=(None, 1), dtype=float32), dense_shape=Tensor("gradient_tape/model_4/tf.math.divide_no_nan_4/RaggedTile/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.



Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 00006: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-06.
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 00011: ReduceLROnPlateau reducing learning rate to 9.999999747378752e-07.
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 00016: ReduceLROnPlateau reducing learning rate to 9.999999974752428e-08.
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 00021: ReduceLROnPlateau reducing learning rate to 1.0000000116860975e-08.
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 00026: ReduceLROnPlateau reducing learning rate to 9.999999939225292e-10.
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 00031: ReduceLROnPlateau reducing learning rate to 9.999999717180686e-11.
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 00036: ReduceLROnPlateau reducing learning rate to 9.9999994396249

Fold 1 validation results - Loss: 19.5122, Harrell's C: 0.5588

Training fold 2
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 00005: ReduceLROnPlateau reducing learning rate to 9.999999998199588e-25.
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 00010: ReduceLROnPlateau reducing learning rate to 1.0000000195414814e-25.
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 00015: ReduceLROnPlateau reducing learning rate to 1.0000000195414814e-26.
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 00020: ReduceLROnPlateau reducing learning rate to 9.999999887266024e-28.
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 00025: ReduceLROnPlateau reducing learning rate to 1.0000000272452012e-28.
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 00030: ReduceLROnPlateau reducing learning rate to 1.0000000031710769e-29.
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epo

Fold 2 validation results - Loss: nan, Harrell's C: 0.6854

Training fold 3
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 00005: ReduceLROnPlateau reducing learning rate to 9.949219096706202e-45.
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 00010: ReduceLROnPlateau reducing learning rate to 9.80908925027372e-46.
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 00015: ReduceLROnPlateau reducing learning rate to 1.4012984643248171e-46.
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoc

Fold 3 validation results - Loss: nan, Harrell's C: 0.6309

Training fold 4
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/1

Fold 4 validation results - Loss: 26.7820, Harrell's C: 0.6031

Training fold 5
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 

Fold 5 validation results - Loss: 26.2904, Harrell's C: 0.4627
