# OS duration prediction: advanced attention-based neural-network architectures

In [None]:
#Importing necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
#Importing libraries for data preprocessing 
from sklearn.model_selection import train_test_split 
from sklearn.preprocessing import StandardScaler
#Importing libraries for building the model
from tensorflow.keras import layers, regularizers, models
#Importing libraries for model evaluation
from sklearn.metrics import mean_absolute_error, mean_squared_error
#Importing library for explainability
import shap

# Global seed used for the Keras/TensorFlow RNG setup below and for every
# random shuffle/split in this notebook; 42 is just the conventional value.
SEED = 42
tf.keras.utils.set_random_seed(seed=SEED)


In [None]:
# Load the raw dataset (the file is available on request from the author).
data = pd.read_csv(
    'dataset_b.csv',
    sep=',',
    encoding='latin-1',
)

In [None]:
# Keep only the predictors of interest plus the target "osduree" (continuous),
# then drop every row that has a missing value in any of those columns.
relevant_columns = [
    'age', 'dcr', 'dnlr', 'histology', 'immuno_line', 'iorr',
    'ldhpre', 'leucotpre', 'nb_meta_beforeimmuno', 'neuttpre',
    'ps_befimmuno', 'sex', 'smoking_history', 'osduree',
]
data = data.loc[:, relevant_columns].dropna(axis=0)


In [None]:
# Normalise the categorical text columns to lower case so the one-hot column
# names built later do not depend on the capitalisation used in the CSV.
for categorical_col in ('histology', 'sex', 'smoking_history'):
    data[categorical_col] = data[categorical_col].str.lower()

In [None]:
# `.str.lower()` may turn non-string entries into NaN, so drop incomplete rows
# once more and display the resulting shape.
data = data.dropna(axis='index')
data.shape

In [None]:
# Shuffle the rows reproducibly before encoding.
data = data.sample(frac=1, random_state=SEED)

# One-hot encode the categorical columns. Their values were lower-cased above,
# so every dummy column is named "<column>_<lower-case value>"
# (e.g. "sex_male", "histology_adenocarcinoma"); no renaming is needed.
# dtype=int yields 0/1 integers directly instead of booleans.
one_hot_data = pd.get_dummies(
    data,
    columns=['histology', 'sex', 'smoking_history'],
    dtype=int,
)

In [None]:
# Make sure every one-hot dummy column holds numeric 0/1 values
# (pd.get_dummies returns booleans by default since pandas 2.0).
# - astype(int) avoids the deprecated downcasting that
#   Series.replace({False: 0, True: 1}) relies on for bool columns.
# - Selecting columns by prefix avoids a KeyError when a category does not
#   occur in the data, and also covers categories not listed by hand.
dummy_columns = [
    col for col in one_hot_data.columns
    if col.startswith(('histology_', 'sex_', 'smoking_history_'))
]
one_hot_data[dummy_columns] = one_hot_data[dummy_columns].astype(int)

In [None]:
# Features = every column except the target (Index.difference returns them
# sorted alphabetically); target = "osduree". Both come from the same frame,
# so rows stay aligned.
X = one_hot_data[one_hot_data.columns.difference(['osduree'])]
y = one_hot_data['osduree']

# Hold out 20% of the rows as the test set. The target is continuous, so this
# is a plain random (non-stratified) split.
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED
)

# Split the remaining 80% into training (75%) and validation (25%), giving
# 60% / 20% / 20% train / validation / test overall.
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.25, random_state=SEED
)

In [None]:
# Standardize the numerical features so they contribute on a comparable scale.
# All other columns (one-hot dummies and any remaining predictors) are left
# untouched.
numerical_features = ['age', 'dcr', 'dnlr', 'ldhpre', 'leucotpre',
                      'nb_meta_beforeimmuno', 'neuttpre', 'ps_befimmuno']

binary_features = [col for col in X.columns if col not in numerical_features]

# `scaler` is fitted on the training split only and reused unchanged for the
# validation and test splits, so no held-out statistics leak into the model
# inputs. The other variants below use their own scaler objects, so `scaler`
# keeps the training statistics.
scaler = StandardScaler().fit(X_train[numerical_features])

X_train_scaled = X_train.copy()
X_val_scaled = X_val.copy()
X_test_scaled = X_test.copy()
X_train_scaled[numerical_features] = scaler.transform(X_train[numerical_features])
X_val_scaled[numerical_features] = scaler.transform(X_val[numerical_features])
X_test_scaled[numerical_features] = scaler.transform(X_test[numerical_features])

# Whole dataset standardized with its own statistics, which include the test
# rows. Do NOT use X_scaled to measure held-out performance.
full_data_scaler = StandardScaler()
X_scaled = X.copy()
X_scaled[numerical_features] = full_data_scaler.fit_transform(X[numerical_features])

# Train + validation rows standardized with their own statistics. NOTE(review):
# presumably meant for a final refit before testing; confirm with later cells.
train_val_scaler = StandardScaler()
X_train_val_scaled = X_temp.copy()
X_train_val_scaled[numerical_features] = train_val_scaler.fit_transform(
    X_temp[numerical_features]
)

In [None]:
def smape_f(y_true, y_pred):
    """Symmetric mean absolute percentage error, in percent (usable as a Keras metric).

    Computes ``100 * mean(|y_true - y_pred| / d)`` with
    ``d = max((|y_true| + |y_pred| + eps) / 2, eps)``, where ``eps`` is the Keras
    backend epsilon, so the denominator is never zero.

    Args:
        y_true: target values (tensor or array-like). Cast to ``y_pred``'s dtype.
        y_pred: predicted values (tensor or array-like).

    Returns:
        A scalar tensor.

    Note:
        Pass tensors of the same shape. Under TF broadcasting, e.g. ``(batch,)``
        against ``(batch, 1)`` would silently produce a ``(batch, batch)`` matrix.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Direct calls may mix dtypes (e.g. float64 pandas targets with float32
    # model outputs), which TF arithmetic rejects; align on the prediction dtype.
    y_true = tf.cast(y_true, y_pred.dtype)

    epsilon = tf.keras.backend.epsilon()
    denominator = tf.maximum(
        (tf.abs(y_true) + tf.abs(y_pred) + epsilon) / 2.0,
        epsilon,
    )
    diff = tf.abs(y_true - y_pred)
    return 100 * tf.reduce_mean(diff / denominator)

In [None]:
# DataFrame copies of the scaled train/test splits, with the column order of
# the corresponding unscaled splits.
X_train_scaled_df = X_train_scaled.reindex(columns=X_train.columns)
X_test_scaled_df = X_test_scaled.reindex(columns=X_test.columns)

In [None]:
# Attention Layer Definition
class FeatureAttention(tf.keras.layers.Layer):
    """Re-weight each feature of the input with learned, per-sample weights.

    For an input ``x`` whose last dimension is ``d``, two Dense layers of width
    ``d`` compute ``w = softmax(Dense(tanh(Dense(x))))`` over the last axis, and
    the layer returns ``x * w``, which has the same shape as ``x``.

    NOTE(review): because the softmax weights sum to 1 over the last axis, the
    output magnitude shrinks roughly by a factor ``d``. This may matter when the
    layer is placed after a wide hidden layer.
    """

    def build(self, input_shape):
        width = input_shape[-1]
        # Hidden projection of the features used to score them.
        self.attention_dense = layers.Dense(
            width, activation='tanh', name='attention_dense'
        )
        # Softmax turns the scores into weights that sum to 1 per sample.
        self.attention_output = layers.Dense(
            width, activation='softmax', name='attention_output'
        )
        super().build(input_shape)

    def call(self, inputs):
        scores = self.attention_dense(inputs)
        weights = self.attention_output(scores)
        return inputs * weights



# 1. attention-block at input level
def model_attention_input_only(input_shape, *, units=512, l2_strength=0.01,
                               dropout_rate=0.4):
    """Build an MLP regressor with a FeatureAttention block on the raw inputs.

    Architecture: FeatureAttention -> Dense -> LayerNorm -> Dropout
    -> Dense -> Dropout -> Dense -> Dropout -> Dense(1, linear),
    where every hidden Dense is ReLU with L2 kernel regularization.

    Args:
        input_shape: shape of one sample, e.g. ``(n_features,)``.
        units: width of each hidden Dense layer.
        l2_strength: L2 regularization factor of the hidden kernels.
        dropout_rate: dropout rate after each hidden layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    def hidden():
        return layers.Dense(units, activation='relu',
                            kernel_regularizer=regularizers.l2(l2_strength))

    return models.Sequential([
        layers.Input(shape=input_shape),
        FeatureAttention(),
        hidden(),
        layers.LayerNormalization(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        layers.Dense(1, activation='linear'),
    ])

# 2. attention-block after intermediate layer
def model_attention_intermediate(input_shape, *, units=512, l2_strength=0.01,
                                 dropout_rate=0.4):
    """Build an MLP regressor with FeatureAttention after the first hidden block.

    Architecture: Dense -> LayerNorm -> Dropout -> FeatureAttention
    -> Dense -> Dropout -> Dense -> Dropout -> Dense(1, linear),
    where every hidden Dense is ReLU with L2 kernel regularization.

    Args:
        input_shape: shape of one sample, e.g. ``(n_features,)``.
        units: width of each hidden Dense layer.
        l2_strength: L2 regularization factor of the hidden kernels.
        dropout_rate: dropout rate after each hidden layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    def hidden():
        return layers.Dense(units, activation='relu',
                            kernel_regularizer=regularizers.l2(l2_strength))

    return models.Sequential([
        layers.Input(shape=input_shape),
        hidden(),
        layers.LayerNormalization(),
        layers.Dropout(dropout_rate),
        FeatureAttention(),
        hidden(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        layers.Dense(1, activation='linear'),
    ])

# 3. multiple attention-blocks at different levels
def model_multiple_attention(input_shape, *, units=512, l2_strength=0.01,
                             dropout_rate=0.4):
    """Build an MLP regressor with FeatureAttention on the inputs and after block 1.

    Architecture: FeatureAttention -> Dense -> LayerNorm -> Dropout
    -> FeatureAttention -> Dense -> Dropout -> Dense -> Dropout
    -> Dense(1, linear), where every hidden Dense is ReLU with L2 kernel
    regularization.

    Args:
        input_shape: shape of one sample, e.g. ``(n_features,)``.
        units: width of each hidden Dense layer.
        l2_strength: L2 regularization factor of the hidden kernels.
        dropout_rate: dropout rate after each hidden layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    def hidden():
        return layers.Dense(units, activation='relu',
                            kernel_regularizer=regularizers.l2(l2_strength))

    return models.Sequential([
        layers.Input(shape=input_shape),
        FeatureAttention(),
        hidden(),
        layers.LayerNormalization(),
        layers.Dropout(dropout_rate),
        FeatureAttention(),
        hidden(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        layers.Dense(1, activation='linear'),
    ])

# 4. attention-block just before output
def model_attention_pre_output(input_shape, *, units=512, l2_strength=0.01,
                               dropout_rate=0.4):
    """Build an MLP regressor with FeatureAttention right before the output layer.

    Architecture: Dense -> LayerNorm -> Dropout -> Dense -> Dropout
    -> Dense -> Dropout -> FeatureAttention -> Dense(1, linear),
    where every hidden Dense is ReLU with L2 kernel regularization.

    Args:
        input_shape: shape of one sample, e.g. ``(n_features,)``.
        units: width of each hidden Dense layer.
        l2_strength: L2 regularization factor of the hidden kernels.
        dropout_rate: dropout rate after each hidden layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    def hidden():
        return layers.Dense(units, activation='relu',
                            kernel_regularizer=regularizers.l2(l2_strength))

    return models.Sequential([
        layers.Input(shape=input_shape),
        hidden(),
        layers.LayerNormalization(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        FeatureAttention(),
        layers.Dense(1, activation='linear'),
    ])

# 5. attention with residual connection
class ResidualAttention(tf.keras.layers.Layer):
    """FeatureAttention wrapped in a skip connection.

    Returns ``x + FeatureAttention(x)``, i.e. the input plus its attention-
    weighted copy; the output shape equals the input shape.
    """

    def build(self, input_shape):
        # Inner attention block; its own weights are created when first called.
        self.attention = FeatureAttention()
        super().build(input_shape)

    def call(self, inputs):
        gated = self.attention(inputs)
        # Skip connection keeps the original signal alongside the gated one.
        return inputs + gated

def model_residual_attention(input_shape, *, units=512, l2_strength=0.01,
                             dropout_rate=0.4):
    """Build an MLP regressor with a ResidualAttention block after block 1.

    Architecture: Dense -> LayerNorm -> Dropout -> ResidualAttention
    -> Dense -> Dropout -> Dense -> Dropout -> Dense(1, linear),
    where every hidden Dense is ReLU with L2 kernel regularization.

    Args:
        input_shape: shape of one sample, e.g. ``(n_features,)``.
        units: width of each hidden Dense layer.
        l2_strength: L2 regularization factor of the hidden kernels.
        dropout_rate: dropout rate after each hidden layer.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    def hidden():
        return layers.Dense(units, activation='relu',
                            kernel_regularizer=regularizers.l2(l2_strength))

    return models.Sequential([
        layers.Input(shape=input_shape),
        hidden(),
        layers.LayerNormalization(),
        layers.Dropout(dropout_rate),
        ResidualAttention(),
        hidden(),
        layers.Dropout(dropout_rate),
        hidden(),
        layers.Dropout(dropout_rate),
        layers.Dense(1, activation='linear'),
    ])



def _plot_history(history, name):
    """Plot training vs. validation loss and MAE per epoch for one fitted model."""
    curves = history.history
    epochs = range(1, len(curves['loss']) + 1)

    plt.figure(figsize=(12, 5))
    for position, (key, label) in enumerate((('loss', 'Loss'), ('mae', 'MAE')), start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs, curves[key], 'b', label=f'Training {label}')
        plt.plot(epochs, curves[f'val_{key}'], 'r', label=f'Validation {label}')
        plt.xlabel('Epochs')
        plt.ylabel(label)
        plt.title(f'{name}: Training and Validation {label}')
        plt.legend()
    plt.tight_layout()
    plt.show()


def _regression_metrics(y_true, y_pred):
    """Return MAE, MSE, MAPE (%) and SMAPE (%) of *y_pred* against *y_true*.

    1e-8 is added to the MAPE and SMAPE denominators to avoid division by zero.
    """
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    abs_err = np.abs(y_true - y_pred)
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'MSE': mean_squared_error(y_true, y_pred),
        'MAPE': np.mean(abs_err / np.abs(y_true + 1e-8)) * 100,
        'SMAPE': 100 / len(y_true) * np.sum(
            2 * abs_err / (np.abs(y_true) + np.abs(y_pred) + 1e-8)
        ),
    }


def _plot_predictions(y_true, y_pred, name):
    """Scatter predicted vs. true values together with the y = x reference line."""
    low, high = y_true.min(), y_true.max()
    plt.figure(figsize=(6, 6))
    plt.scatter(y_true, y_pred, alpha=0.6)
    plt.plot([low, high], [low, high], 'r--')
    plt.xlabel("True Values")
    plt.ylabel("Predicted Values")
    plt.title(f"{name}: Predicted vs True Values")
    plt.show()


def _plot_shap(model, X_background, X_explain):
    """Explain *model* on *X_explain* with SHAP and draw a beeswarm of all features."""
    explainer = shap.Explainer(model, X_background)
    shap_values = explainer(X_explain)
    print("SHAP summary plot:")
    shap.plots.beeswarm(shap_values, max_display=X_explain.shape[1])


def compare_models(X_train, y_train, X_val, y_val, X_test=None, y_test=None,
                   *, epochs=50, batch_size=16):
    """Train every attention variant, then report curves, test metrics and SHAP.

    Each model is compiled with Adam, MAE loss and MAE/MSE/SMAPE metrics. It is
    fitted on ``(X_train, y_train)`` with ``(X_val, y_val)`` as validation data
    and evaluated on ``(X_test, y_test)``. SHAP values for the test rows are
    computed against ``X_train`` as background.

    Args:
        X_train, y_train: (scaled) training features and target.
        X_val, y_val: validation features and target (monitoring only).
        X_test, y_test: test features and target. When omitted, the
            module-level ``X_test_scaled`` and ``y_test`` are used.
        epochs: number of training epochs per model.
        batch_size: mini-batch size passed to ``model.fit``.

    Returns:
        A DataFrame indexed like ``y_test`` with columns:
        ``osduree_true``; ``osduree_pred``, the predictions of the last model
        trained (``residual_attention``); and one ``osduree_pred_<name>``
        column per model.
    """
    # Fall back to the notebook-level test split when none is given.
    if X_test is None:
        X_test = globals()['X_test_scaled']
    if y_test is None:
        y_test = globals()['y_test']

    input_shape = (X_train.shape[1],)

    # Build every model before training any of them, so all seeded weight
    # initialisation happens before any training-time random draws.
    models_to_test = {
        'attention_input': model_attention_input_only(input_shape),
        'attention_intermediate': model_attention_intermediate(input_shape),
        'attention_pre_output': model_attention_pre_output(input_shape),
        'multiple_attention': model_multiple_attention(input_shape),
        'residual_attention': model_residual_attention(input_shape),
    }

    predictions = pd.DataFrame({'osduree_true': y_test})
    summary = {}

    for name, model in models_to_test.items():
        print(f"\nTraining {name}...")
        model.compile(optimizer='adam',
                      loss='mae',
                      metrics=['mae', 'mse', smape_f])

        history = model.fit(X_train, y_train,
                            validation_data=(X_val, y_val),
                            epochs=epochs,
                            batch_size=batch_size,
                            verbose=0)
        _plot_history(history, name)

        y_pred = model.predict(X_test).flatten()

        metrics = _regression_metrics(y_test, y_pred)
        print(f"MAE: {metrics['MAE']:.4f}")
        print(f"MSE: {metrics['MSE']:.4f}")
        print(f"MAPE: {metrics['MAPE']:.2f}%")
        print(f"SMAPE: {metrics['SMAPE']:.2f}%")
        summary[name] = metrics

        model_predictions = pd.DataFrame({
            'osduree_true': y_test,   # true values
            'osduree_pred': y_pred,   # predicted values
        })
        print(model_predictions.tail(10))
        predictions[f'osduree_pred_{name}'] = y_pred

        _plot_predictions(y_test, y_pred, name)
        _plot_shap(model, X_train, X_test)

    print("\nTest-set metrics per model:")
    print(pd.DataFrame(summary).T)

    # Keep the 'osduree_pred' column (last model) so existing consumers of the
    # two-column result still work.
    predictions.insert(1, 'osduree_pred', y_pred)
    return predictions

In [None]:
# Train, evaluate and explain every attention-based model variant.
results = compare_models(
    X_train=X_train_scaled,
    y_train=y_train,
    X_val=X_val_scaled,
    y_val=y_val,
)