# OS Status: advanced NN architectures (attention-based ones)

In [None]:
#Importing the necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
#Importing to preprocess the data
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import StandardScaler
#Importing to build the models
from tensorflow.keras import layers, regularizers, models
#Importing to evaluate the models
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
#Importing to explain the models
import shap

# Global seed for reproducible runs; 42 itself is just a conventional, arbitrary value.
SEED = 42
tf.keras.utils.set_random_seed(SEED)  # one call seeds Python, NumPy and the Keras backend RNGs

In [None]:
# Load the raw dataset (not distributed with this notebook - request it from the author).
data = pd.read_csv('dataset_b', sep=',', encoding='latin-1')

In [3]:
# Target column: "os_status" (binary variable).
# Keep only the candidate predictors plus the target.
relevant_columns = ['age', 'dcr', 'dnlr', 'histology', 'immuno_line', 'iorr',
                    'ldhpre', 'leucotpre', 'nb_meta_beforeimmuno', 'neuttpre',
                    'ps_befimmuno', 'sex', 'smoking_history', 'os_status']

# Drop rows with any missing value, then store these four columns as integers.
data = (
    data[relevant_columns]
    .dropna(axis=0)
    .astype({'dcr': int, 'age': int, 'iorr': int, 'ps_befimmuno': int})
)

In [4]:
# Normalise the categorical text columns to lower case before one-hot encoding.
for text_col in ('histology', 'sex', 'smoking_history'):
    data[text_col] = data[text_col].str.lower()

In [None]:
# Drop rows that became missing during the string normalisation above
# (pandas `.str` methods yield NaN for non-string entries), then show the size.
data = data.dropna(axis='index')
data.shape

In [6]:
# Shuffle the rows reproducibly before any splitting.
data = data.sample(frac=1, random_state=SEED)

# One-hot encode the categorical columns as 0/1 integers.
# The category values were lower-cased in an earlier cell, so the dummy column
# names are already lower-case. Names containing spaces are kept as-is; the next
# cell expects e.g. 'histology_nsclc other'. The previous `rename` mapping only
# matched capitalised names, which can no longer occur, so it was a no-op and
# has been removed.
one_hot_data = pd.get_dummies(
    data,
    columns=['histology', 'sex', 'smoking_history'],
    dtype=int,
)

In [None]:
# Ensure every one-hot indicator column holds 0/1 integers.
# Columns are selected by prefix rather than a hard-coded list, so a category
# that is absent from the data no longer raises KeyError. `astype(int)` replaces
# `replace({False: 0, True: 1})`, which depends on pandas' deprecated
# downcasting behaviour.
dummy_prefixes = ('histology_', 'sex_', 'smoking_history_')
dummy_columns = [col for col in one_hot_data.columns if col.startswith(dummy_prefixes)]
one_hot_data[dummy_columns] = one_hot_data[dummy_columns].astype(int)

In [None]:
# Split into features and target. Both come from the same frame, so rows stay aligned.
X = one_hot_data[one_hot_data.columns.difference(['os_status'])]
y = one_hot_data['os_status']

# First split: training+validation vs test (80% vs 20%).
# stratify keeps the class proportions identical in every subset; using SEED
# (instead of a literal) keeps the splits tied to the notebook-wide seed.
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED, stratify=y
)

# Second split: training vs validation (75% vs 25% of the 80%),
# i.e. 60% training, 20% validation and 20% test overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=SEED, stratify=y_temp
)

In [None]:
# Standardise the numerical features so they contribute on a comparable scale;
# every other column (the one-hot indicators) is left untouched.
numerical_features = ['age', 'dcr', 'dnlr', 'ldhpre', 'leucotpre',
                      'nb_meta_beforeimmuno', 'neuttpre', 'ps_befimmuno']
# NOTE(review): 'immuno_line' and 'iorr' are not in numerical_features, so they
# land here and stay unscaled - confirm they are really 0/1 indicators.
binary_features = [col for col in X.columns if col not in numerical_features]


def _apply_scaler(frame, fitted_scaler):
    """Return a copy of `frame` with `numerical_features` transformed by `fitted_scaler`."""
    scaled = frame.copy()
    scaled[numerical_features] = fitted_scaler.transform(scaled[numerical_features])
    return scaled


# Fit on the training split only. Validation, test and the full matrix are then
# transformed with exactly the statistics the models are trained on. Fitting on
# all of X would leak test-set statistics and shift inputs at prediction time.
scaler = StandardScaler().fit(X_train[numerical_features])
X_train_scaled = _apply_scaler(X_train, scaler)
X_val_scaled = _apply_scaler(X_val, scaler)
X_test_scaled = _apply_scaler(X_test, scaler)
X_scaled = _apply_scaler(X, scaler)

# The train+validation matrix (e.g. for a final refit) gets its own scaler,
# so the training scaler above is not overwritten.
train_val_scaler = StandardScaler().fit(X_temp[numerical_features])
X_train_val_scaled = _apply_scaler(X_temp, train_val_scaler)

In [None]:
# Scale the full feature matrix with statistics from the training split only.
# Predictions made on X_scaled then see inputs on the same scale the models were
# trained on; re-fitting on all of X would leak test statistics and shift the inputs.
train_scaler = StandardScaler().fit(X_train[numerical_features])
X_scaled = X.copy()
X_scaled[numerical_features] = train_scaler.transform(X_scaled[numerical_features])

# DataFrame views (with column names) of the scaled splits, e.g. for SHAP plots.
X_train_scaled_df = pd.DataFrame(X_train_scaled, columns=X_train.columns)
X_test_scaled_df = pd.DataFrame(X_test_scaled, columns=X_test.columns)

In [None]:
# Custom Keras layer that re-weights its input features with learned attention scores
class FeatureAttention(tf.keras.layers.Layer):
    """Per-sample feature attention: ``inputs * softmax(dense(tanh(dense(inputs))))``.

    Two Dense sub-layers, each as wide as the last input axis, score every
    feature. The softmax-normalised scores (over the last axis) then scale the
    inputs element-wise, so the output has the same shape as the input.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        width = input_shape[-1]
        # The scoring network is created here because its width depends on the input.
        self.attention_dense = layers.Dense(width, activation='tanh', name='attention_dense')
        self.attention_output = layers.Dense(width, activation='softmax', name='attention_output')
        super().build(input_shape)

    def call(self, inputs):
        hidden = self.attention_dense(inputs)
        scores = self.attention_output(hidden)
        # Gate each feature by its attention score.
        return inputs * scores


# 1. Attention-block at input level
def model_attention_input_only(input_shape):
    """Build the MLP variant whose raw inputs are first re-weighted by FeatureAttention.

    Args:
        input_shape: per-sample input shape, given to the attention layer.

    Returns:
        An uncompiled ``models.Sequential`` ending in one sigmoid unit.
    """
    def relu_512():
        # Hidden layer used by every block: 512 ReLU units, L2(0.01) on the kernel.
        return layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.01))

    stack = [FeatureAttention(input_shape=input_shape)]  # attention on raw input
    stack.extend([relu_512(), layers.LayerNormalization(), layers.Dropout(0.4)])
    for _ in range(2):
        stack.extend([relu_512(), layers.Dropout(0.4)])
    stack.append(layers.Dense(1, activation='sigmoid'))
    return models.Sequential(stack)

# 2. Attention-block after intermediate layer
def model_attention_intermediate(input_shape):
    """Build the MLP variant with FeatureAttention after the first hidden block.

    Args:
        input_shape: per-sample input shape, given to the first Dense layer.

    Returns:
        An uncompiled ``models.Sequential`` ending in one sigmoid unit.
    """
    def relu_512(**extra):
        # Hidden layer used by every block: 512 ReLU units, L2(0.01) on the kernel.
        return layers.Dense(512, activation='relu',
                            kernel_regularizer=regularizers.l2(0.01), **extra)

    stack = [relu_512(input_shape=input_shape), layers.LayerNormalization(), layers.Dropout(0.4)]
    stack.append(FeatureAttention())  # attention on transformed features
    for _ in range(2):
        stack.extend([relu_512(), layers.Dropout(0.4)])
    stack.append(layers.Dense(1, activation='sigmoid'))
    return models.Sequential(stack)

# 3. Multiple Attention-blocks (input + intermediate)
def model_multiple_attention(input_shape):
    """Build the MLP variant with FeatureAttention on the raw input and after the first block.

    Args:
        input_shape: per-sample input shape, given to the first attention layer.

    Returns:
        An uncompiled ``models.Sequential`` ending in one sigmoid unit.
    """
    def relu_512():
        # Hidden layer used by every block: 512 ReLU units, L2(0.01) on the kernel.
        return layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.01))

    stack = [FeatureAttention(input_shape=input_shape)]  # attention on raw input
    stack.extend([relu_512(), layers.LayerNormalization(), layers.Dropout(0.4)])
    stack.append(FeatureAttention())  # attention on transformed features
    for _ in range(2):
        stack.extend([relu_512(), layers.Dropout(0.4)])
    stack.append(layers.Dense(1, activation='sigmoid'))
    return models.Sequential(stack)

# 4. Attention-block just before output
def model_attention_pre_output(input_shape):
    """Build the MLP variant with FeatureAttention between the last hidden block and the output.

    Args:
        input_shape: per-sample input shape, given to the first Dense layer.

    Returns:
        An uncompiled ``models.Sequential`` ending in one sigmoid unit.
    """
    def relu_512(**extra):
        # Hidden layer used by every block: 512 ReLU units, L2(0.01) on the kernel.
        return layers.Dense(512, activation='relu',
                            kernel_regularizer=regularizers.l2(0.01), **extra)

    stack = [relu_512(input_shape=input_shape), layers.LayerNormalization(), layers.Dropout(0.4)]
    for _ in range(2):
        stack.extend([relu_512(), layers.Dropout(0.4)])
    stack.append(FeatureAttention())
    stack.append(layers.Dense(1, activation='sigmoid'))
    return models.Sequential(stack)

# 5. Residual Attention
class ResidualAttention(tf.keras.layers.Layer):
    """FeatureAttention wrapped with an identity shortcut: ``inputs + attention(inputs)``."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        self.attention = FeatureAttention()
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.attention(inputs)  # residual (skip) connection

def model_residual_attention(input_shape):
    """Build the MLP variant with a ResidualAttention layer after the first hidden block.

    Args:
        input_shape: per-sample input shape, given to the first Dense layer.

    Returns:
        An uncompiled ``models.Sequential`` ending in one sigmoid unit.
    """
    def relu_512(**extra):
        # Hidden layer used by every block: 512 ReLU units, L2(0.01) on the kernel.
        return layers.Dense(512, activation='relu',
                            kernel_regularizer=regularizers.l2(0.01), **extra)

    stack = [relu_512(input_shape=input_shape), layers.LayerNormalization(), layers.Dropout(0.4)]
    stack.append(ResidualAttention())  # attention with residual connection
    for _ in range(2):
        stack.extend([relu_512(), layers.Dropout(0.4)])
    stack.append(layers.Dense(1, activation='sigmoid'))
    return models.Sequential(stack)


def _plot_training_history(name, history):
    """Show training/validation loss and accuracy curves for one fitted model."""
    curves = history.history
    epoch_axis = range(1, len(curves['loss']) + 1)

    plt.figure(figsize=(12, 5))

    # Loss curves
    plt.subplot(1, 2, 1)
    plt.plot(epoch_axis, curves['loss'], 'b', label='Training Loss')
    plt.plot(epoch_axis, curves['val_loss'], 'r', label='Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(f'{name}: Training and Validation Loss')
    plt.legend()

    # Accuracy curves
    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, curves['accuracy'], 'b', label='Training Accuracy')
    plt.plot(epoch_axis, curves['val_accuracy'], 'r', label='Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.title(f'{name}: Training and Validation accuracy')
    plt.legend()

    plt.show()


def _evaluate_binary(model, X_eval, y_eval, threshold=0.5):
    """Score a sigmoid-output model on (X_eval, y_eval); print and return the metrics."""
    prob_predictions = np.squeeze(model.predict(X_eval, verbose=0))
    class_predictions = (prob_predictions >= threshold).astype(int)

    metrics = {
        'accuracy': accuracy_score(y_eval, class_predictions),
        # zero_division=0 returns the same value as the default, without the warning
        'precision': precision_score(y_eval, class_predictions, zero_division=0),
        'recall': recall_score(y_eval, class_predictions, zero_division=0),
        'f1': f1_score(y_eval, class_predictions, zero_division=0),
        'roc_auc': roc_auc_score(y_eval, prob_predictions),
        'confusion_matrix': confusion_matrix(y_eval, class_predictions),
    }

    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1-score: {metrics['f1']:.4f}")
    print(f"AUC-ROC: {metrics['roc_auc']:.4f}")
    print(metrics['confusion_matrix'])
    return metrics


def _plot_shap_summary(name, model, background_data, X_explain, n_explain=50, seed=42):
    """Explain `model` with KernelSHAP on a sample of `X_explain`; show and save the summary plot."""
    background = shap.kmeans(background_data, 10)
    # KernelExplainer calls predict many times; verbose=0 avoids a progress bar per call.
    explainer = shap.KernelExplainer(lambda x: model.predict(x, verbose=0), background)

    X_explain = pd.DataFrame(X_explain)
    # DataFrame.sample raises if asked for more rows than exist.
    X_subset = X_explain.sample(min(n_explain, len(X_explain)), random_state=seed)
    shap_values = explainer.shap_values(X_subset, nsamples=50, silent=True)

    # Normalise the single-output result to a (samples, features) array.
    if isinstance(shap_values, list):
        shap_values = shap_values[0]
    if shap_values.ndim == 3:
        shap_values = shap_values[:, :, 0]

    print(f"SHAP values shape: {shap_values.shape}, X_subset shape: {X_subset.shape}")

    plt.figure(figsize=(10, 8))
    shap.summary_plot(
        shap_values,
        features=X_subset,
        feature_names=X_subset.columns,
        plot_type='dot',
        max_display=X_subset.shape[1],
        show=False
    )
    plt.title(f"SHAP Summary Plot – {name} – Class 1")
    plt.tight_layout()
    # One file per model; a single shared filename was overwritten by every model.
    plt.savefig(f"shap_summary_binary_{name}.png")
    plt.show()


def compare_models(X_train, y_train, X_val, y_val, test_data=None, epochs=500, batch_size=16):
    """Train, evaluate and explain each attention-based architecture.

    Args:
        X_train, y_train: data the models are fitted on; X_train is also the
            SHAP background (summarised with k-means).
        X_val, y_val: validation data monitored during fitting.
        test_data: optional ``(X, y)`` held-out pair. Defaults to the notebook
            globals ``X_test_scaled`` and ``y_test``.
        epochs, batch_size: forwarded to ``model.fit``.

    Side effects:
        Shows training curves and SHAP plots and writes
        ``shap_summary_binary_<name>.png`` for each model. Overwrites the
        global ``data['Predicted']`` with each model's predictions on the
        global ``X_scaled``.

    Returns:
        dict mapping model name -> dict with the fitted ``'model'``, its
        ``'history'`` and the test metrics (accuracy, precision, recall, f1,
        roc_auc, confusion_matrix).
    """
    if test_data is None:
        test_data = (X_test_scaled, y_test)
    X_eval, y_eval = test_data
    input_shape = (X_train.shape[1],)

    models_to_test = {
        'attention_input': model_attention_input_only(input_shape),
        'attention_intermediate': model_attention_intermediate(input_shape),
        'attention_pre_output': model_attention_pre_output(input_shape),
        'multiple_attention': model_multiple_attention(input_shape),
        'residual_attention': model_residual_attention(input_shape)
    }

    results = {}

    for name, model in models_to_test.items():
        print(f"\nTraining {name}...")
        model.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['accuracy'])

        # Fit on the arguments passed in; the previous version ignored them and read globals.
        history = model.fit(X_train, y_train,
                            validation_data=(X_val, y_val),
                            epochs=epochs,
                            batch_size=batch_size,
                            verbose=0)

        _plot_training_history(name, history)
        metrics = _evaluate_binary(model, X_eval, y_eval)

        # Predictions for every row. X_scaled is built from the same rows, in the
        # same order, as `data`.
        full_probs = np.squeeze(model.predict(X_scaled, verbose=0))
        data['Predicted'] = (full_probs >= 0.5).astype(int)
        print(data.tail(10))

        _plot_shap_summary(name, model, X_train, X_eval)

        results[name] = {'model': model, 'history': history.history, **metrics}

    return results

In [None]:
# Train, evaluate and explain every architecture on the scaled splits.
results = compare_models(
    X_train=X_train_scaled,
    y_train=y_train,
    X_val=X_val_scaled,
    y_val=y_val,
)