# 🚀 Machine Learning Analysis on Processed Data  

## 🎯 Objective  
This notebook performs machine learning analysis on the processed dataset.  
It includes data preprocessing specific to machine learning (feature/target separation, train/test splitting, and data shuffling), feature selection, model training, and evaluation.  

The analysis can focus on one of two conditions:  
- **Odor Recognition**  
- **Associative Memory**  

📌 **Important:** The condition is selected in the **[Condition Selection](#toc3_2_)** step, where the user must set `condition` to `"odor_recognition"` or `"associative_memory"` before running the rest of the analysis.  
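The two conditions differ only in which trials are kept and which column becomes the binary target (see the Condition Selection cell). The sketch below expresses that mapping as a lookup table, assuming each trial is a plain dict; `CONDITIONS` and `select_condition` are hypothetical names, and the sketch leaves out the per-condition distance columns the real cell also sets.

```python
# Hypothetical lookup table mirroring the Condition Selection cell:
# which trials to keep (filter column == 1) and which column becomes the outcome.
CONDITIONS = {
    "odor_recognition": {"filter_col": "is_target", "outcome_col": "hit"},
    "associative_memory": {"filter_col": "hit", "outcome_col": "mem"},
}


def select_condition(rows, condition):
    """Keep rows whose filter column equals 1 and attach an 'outcome' key."""
    if condition not in CONDITIONS:
        raise ValueError(f"Unknown condition: {condition!r}")
    cfg = CONDITIONS[condition]
    selected = []
    for row in rows:
        if row[cfg["filter_col"]] == 1:
            # Copy so the input rows stay untouched (like .copy() in pandas)
            new_row = dict(row)
            new_row["outcome"] = row[cfg["outcome_col"]]
            selected.append(new_row)
    return selected


trials = [
    {"is_target": 1, "hit": 1, "mem": 0},
    {"is_target": 1, "hit": 0, "mem": 0},
    {"is_target": 0, "hit": 1, "mem": 1},
]
print(len(select_condition(trials, "odor_recognition")))    # 2
print(len(select_condition(trials, "associative_memory")))  # 2
```

Keeping the mapping in one table makes it harder for the two branches to drift apart as columns are added.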

## 📑 Table of Contents  

- [Libraries](#toc1_)  
  *Importing necessary libraries for machine learning and data analysis.*  
- [Functions](#toc2_)  
  *Defining functions used throughout the notebook.*  
- [Data Loading & Preprocessing](#toc3_)  
  - [Loading](#toc3_1_) – *Load the processed dataset into a DataFrame.*  
  - [Condition Selection](#toc3_2_) – *Select the analysis condition: `"odor_recognition"` or `"associative_memory"`.*  
  - [Info on Data](#toc3_3_)  
    - [Check for Unique Values](#toc3_3_1_) – *Inspect key dataset attributes (e.g., number of participants and number of odors).*  
    - [Target Distribution](#toc3_3_2_) – *Visualize the distribution of the target variable.*  
  - [Features/Target Split](#toc3_4_) – *Separate predictor variables and target variable.*  
  - [Train/Test Split](#toc3_5_) – *Split data into training and testing sets.*  
  - [Data Shuffling](#toc3_6_) – *Randomize the row order to avoid ordering biases.*  
- [Feature Selection](#toc4_)  
  *Select the most relevant features for model training.*  
- [Base Model](#toc5_)  
  *Train and evaluate a baseline model without hyperparameter tuning.*  
- [All Features - Models with Tuning](#toc6_)  
  *Train machine learning models using all features with hyperparameter tuning.*  
- [Final Models](#toc7_)  
  *Train the final models using selected features and optimized hyperparameters.*  

# <a id='toc1_'></a>[Libraries](#toc0_)

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import seaborn as sns
import shap
import xgboost as xgb

import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.patches import Polygon

from imblearn.over_sampling import SMOTE, SMOTENC
from imblearn.pipeline import Pipeline as imbpipeline

from sklearn.ensemble import RandomForestClassifier
from sklearn.manifold import TSNE
from sklearn.metrics import (
    classification_report,
    roc_auc_score,
)
from sklearn.model_selection import (
    StratifiedGroupKFold,
    cross_validate,
    permutation_test_score,
)
from sklearn.utils import shuffle

from skopt import BayesSearchCV
from skopt.space import Categorical, Integer, Real

# <a id='toc2_'></a>[Functions](#toc0_)

In [None]:
feature_name_mapping = {
    "gender_encoded": "Gender",
    "pleasantness": "Pleasantness",
    "emotional_strength": "Emotional strength",
    "intensity": "Intensity",
    "familiarity": "Familiarity",
    "percept_dist": "Perceptive distance",
    "nb_words": "Number of words",
    "jaccard_dist": "Semantic distance"
}

In [None]:
def perform_random_forest(X_train, y_train, groups_train, model_type="default_rf", tuning=False, categorical_features="No"):
    # Define categorical feature
    if categorical_features == "Yes":
        categorical_columns = ["gender_encoded"]
        cat_col_index = [X_train.columns.get_loc(col) for col in categorical_columns if col in X_train.columns]
    else:
        cat_col_index = None

    # Define classifier
    if model_type == "default_rf":
        model = RandomForestClassifier(random_state=42)
    elif model_type == "balanced_rf":
        model = RandomForestClassifier(class_weight="balanced", random_state=42)
    elif model_type == "smote_rf":
        if categorical_features == "Yes":
            smote = SMOTENC(categorical_features=cat_col_index, random_state=42)
        else:
            smote = SMOTE(random_state=42)
    
        model = imbpipeline([
            ('smote', smote),
            ('rfc', RandomForestClassifier(random_state=42))
        ])
    else:
        raise ValueError("Invalid model_type. Choose from 'default_rf', 'balanced_rf', 'smote_rf'.")
    
    # Define cross-validation
    cv = StratifiedGroupKFold(n_splits=10)
    scoring = ["accuracy", "roc_auc", "balanced_accuracy", "recall", "precision", "f1", "f1_weighted"]
    
    # If no tuning, perform cross-validation
    if not tuning:
        cv_results = cross_validate(
            model,
            X_train,
            y_train,
            groups=groups_train,
            cv=cv,
            n_jobs=-1,
            scoring=scoring,
            return_train_score=True
            )
        
        print("Average Performance Scores:")
        for metric in scoring:
            train_score = np.mean(cv_results[f"train_{metric}"])
            test_score = np.mean(cv_results[f"test_{metric}"])
            print(f"{metric}: Train - {train_score}, Validation - {test_score}")
        
        model.fit(X_train, y_train)
        return model, cv_results
    
    # Define the parameters to be tested with BayesSearchCV
    param_distributions = {
        'n_estimators': Integer(500, 2000),
        'max_depth': Integer(3, 50, prior='uniform'),
        'max_leaf_nodes': Integer(10, 300, prior='uniform'),
        'min_samples_leaf': Integer(1, 20),
        'min_samples_split': Integer(2, 20),
        'max_features': Categorical(['sqrt', 'log2', None])
    }

    # If SMOTE is used, prepend 'rfc__' to parameters
    if model_type == "smote_rf":
        param_distributions = {f'rfc__{key}': value for key, value in param_distributions.items()}

    # If tuning, search for the best hyperparameters with BayesSearchCV
    search_cv = BayesSearchCV(
        model,
        param_distributions,
        n_iter=100,
        scoring=scoring,
        cv=cv,
        refit="roc_auc",
        random_state=42
        )
    search_cv.fit(X_train, y_train, groups=groups_train)
    
    # Best model
    best_model = search_cv.best_estimator_
    print(f"Best model: {best_model}")
    
    # Results
    columns = [f"param_{name}" for name in param_distributions.keys()]
    columns += [f"mean_test_{metric}" for metric in scoring]
    cv_results_df = pd.DataFrame(search_cv.cv_results_)

    return best_model, cv_results_df[columns].sort_values(by="mean_test_roc_auc", ascending=False)
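`StratifiedGroupKFold` keeps every trial from a given group in a single validation fold, so the model is never validated on a group it also trained on (assuming `groups_train` holds participant IDs, which is set elsewhere in the notebook). The sketch below illustrates only that grouping constraint with a simple round-robin assignment of whole groups; unlike scikit-learn's splitter, it makes no attempt to balance class proportions. `group_kfold_indices` is a hypothetical helper.

```python
def group_kfold_indices(groups, n_splits):
    """Yield (train_idx, val_idx) pairs in which no group is split across the two sets.

    Whole groups are dealt to folds round-robin in order of first appearance.
    No stratification is attempted (StratifiedGroupKFold also balances classes).
    """
    # Collect sample indices per group (dicts preserve insertion order)
    by_group = {}
    for idx, g in enumerate(groups):
        by_group.setdefault(g, []).append(idx)

    # Assign each whole group to exactly one fold
    folds = [[] for _ in range(n_splits)]
    for k, indices in enumerate(by_group.values()):
        folds[k % n_splits].extend(indices)

    all_idx = set(range(len(groups)))
    for val_idx in folds:
        train_idx = sorted(all_idx - set(val_idx))
        yield train_idx, sorted(val_idx)


participants = ["p1", "p1", "p2", "p2", "p3", "p3", "p4"]
for train_idx, val_idx in group_kfold_indices(participants, n_splits=2):
    train_groups = {participants[i] for i in train_idx}
    val_groups = {participants[i] for i in val_idx}
    print(sorted(val_groups), "overlap:", train_groups & val_groups)
```

An empty overlap in every fold is the property that plain `KFold` or `StratifiedKFold` would not guarantee when participants contribute several trials.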

In [None]:
def perform_xgboost(X_train, y_train, groups_train):
    # Calculation of scale_pos_weight from the training labels (class_0 / class_1)
    scale_pos_weight = y_train.value_counts(normalize=True)[0] / y_train.value_counts(normalize=True)[1]

    # XGBoost classifier initialization
    model = xgb.XGBClassifier(
        objective="binary:logistic",
        random_state=42,
        eval_metric='auc',
        scale_pos_weight=scale_pos_weight
    )

    # Define the parameters to be tested with BayesSearchCV
    param_distributions = {
        'n_estimators': Integer(500, 2000),
        'max_depth': Integer(3, 50, prior='uniform'),
        'max_leaves': Integer(10, 300, prior='uniform'),
        'min_child_weight': Integer(1, 20),
        'gamma': Real(0, 10),
        'subsample': Real(0.5, 1.0),
        'colsample_bytree': Categorical([0.25, 0.5, 1.0]),
        'num_parallel_tree': Integer(1, 10)
    }

    # Define cross-validation
    cv = StratifiedGroupKFold(n_splits=10)
    scoring = ["accuracy", "roc_auc", "balanced_accuracy", "recall", "precision", "f1", "f1_weighted"]

    # Search for the best hyperparameters with BayesSearchCV
    search_cv = BayesSearchCV(
        model,
        param_distributions,
        n_iter=100,
        scoring=scoring,
        cv=cv,
        refit="roc_auc",
        random_state=42
    )
    
    # Training
    search_cv.fit(X_train, y_train, groups=groups_train)
    
    # Best model
    best_model = search_cv.best_estimator_
    print(f"Best model: {best_model}")
    
    # Results
    columns = [f"param_{name}" for name in param_distributions.keys()]
    columns += [f"mean_test_{metric}" for metric in scoring]
    cv_results_df = pd.DataFrame(search_cv.cv_results_)

    return best_model, cv_results_df[columns].sort_values(by="mean_test_roc_auc", ascending=False)
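`scale_pos_weight` is set to the ratio of negative to positive examples (class 0 over class 1), a common heuristic for imbalanced binary targets in XGBoost. Dividing the normalized frequencies, as `perform_xgboost` does, gives the same value as dividing the raw counts. A minimal standard-library sketch, assuming labels are coded 0/1 (`scale_pos_weight_from_labels` is a hypothetical name):

```python
from collections import Counter


def scale_pos_weight_from_labels(labels):
    """Return n_negative / n_positive for 0/1 labels."""
    counts = Counter(labels)
    if counts[1] == 0:
        raise ValueError("No positive examples: the ratio is undefined.")
    return counts[0] / counts[1]


y_example = [0, 0, 0, 1, 0, 1, 0, 0]
print(scale_pos_weight_from_labels(y_example))  # 6 negatives / 2 positives = 3.0
```

Computing the ratio from the training labels keeps the weight consistent with the data the model is actually fitted on.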

In [None]:
def permutation_test_with_plot(model, X, y, groups, n_permutations=100, scoring="roc_auc"):
    # Define cv
    cv = StratifiedGroupKFold(n_splits=10)
    
    # Perform permutation test
    score, perm_scores, p_value = permutation_test_score(
        model,
        X, y, groups=groups,
        scoring=scoring,
        cv=cv,
        n_permutations=n_permutations,
        n_jobs=-1,
        random_state=42 
    )

    # Print results
    print(f"\nPermutation test results:\nCross-validated {scoring}: {score:.3f}, P-value: {p_value:.3f}")

    # Plot the distribution of permutation scores
    plt.figure(figsize=(8, 5))
    plt.hist(perm_scores, bins=50, density=True, alpha=0.5, color='blue', label='Permutation Scores')
    plt.axvline(score, color='red', linestyle='dashed', linewidth=2, label='Actual Score')

    plt.xlabel("ROC AUC" if scoring == "roc_auc" else scoring, fontsize=20)
    plt.ylabel("Density", fontsize=20)
    plt.legend(fontsize=16)

    # plt.xlim(0.4, 0.61)
    # plt.xticks([0.4, 0.45, 0.5, 0.55, 0.6], fontsize=20)
    # plt.ylim(0, 33)
    plt.grid(False)
    sns.despine()

    plt.gca().spines['bottom'].set_linewidth(2)
    plt.gca().spines['left'].set_linewidth(2)
    plt.tick_params(axis='both', length=10, labelsize=20, width=2)

    plt.tight_layout()

    return score, perm_scores, p_value, plt
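`permutation_test_score` refits the model on copies of the data with shuffled labels; according to the scikit-learn documentation, the p-value is `(C + 1) / (n_permutations + 1)`, where `C` is the number of permuted scores greater than or equal to the true score. The sketch below reproduces only that final step from precomputed scores (`permutation_p_value` and the numbers are illustrative):

```python
def permutation_p_value(score, perm_scores):
    """One-sided p-value: (C + 1) / (n_permutations + 1),
    where C counts permuted scores >= the observed score."""
    c = sum(1 for s in perm_scores if s >= score)
    return (c + 1) / (len(perm_scores) + 1)


observed = 0.62
permuted = [0.48, 0.51, 0.55, 0.63, 0.50, 0.49, 0.52, 0.47, 0.53]
print(permutation_p_value(observed, permuted))  # (1 + 1) / (9 + 1) = 0.2
```

With the default `n_permutations=100`, the smallest attainable p-value is therefore 1/101 (about 0.0099).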

In [None]:
def evaluate_model(best_model, X_test, y_test):
    # Recovering the best model from the pipeline (if SMOTE/SMOTE-NC was used)
    if hasattr(best_model, 'steps'):
        best_model = best_model.steps[-1][1]

    # Predictions
    y_pred = best_model.predict(X_test)
    
    # Confusion matrix
    print(
        "\nConfusion matrix:",
        pd.crosstab(
            y_test,
            y_pred,
            rownames=['Actual class'],
            colnames=['\nPredicted class']
        )
    )

    # Classification report
    print("\nClassification Report:\n", classification_report(y_test, y_pred))

    # Probabilities (for AUC calculation)
    y_prob = best_model.predict_proba(X_test)[:, 1]  # Probabilities for class 1

    # AUC score
    auc_score = roc_auc_score(y_test, y_prob)
    print(f"Test AUC: {auc_score}")
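`evaluate_model` passes `predict_proba(X_test)[:, 1]`, the predicted probability of class 1, to `roc_auc_score`. AUC equals the probability that a randomly chosen positive trial receives a higher score than a randomly chosen negative one, so it does not depend on any decision threshold. A quadratic-time sketch of that definition, for intuition only (`roc_auc_pairwise` is a hypothetical name; `roc_auc_score` remains the tool to use in practice):

```python
def roc_auc_pairwise(y_true, y_score):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    if not pos or not neg:
        raise ValueError("Both classes are needed to compute AUC.")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


print(roc_auc_pairwise([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```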

In [None]:
def shap_feature_selection(best_model, X_train, y_train, model="rf"):
    # Recovering the best model from the pipeline (if SMOTE/SMOTE-NC was used)
    if hasattr(best_model, 'steps'):
        best_model = best_model.steps[-1][1]

    # Create a SHAP explainer
    explainer = shap.TreeExplainer(best_model)

    # Calculate SHAP values for train set
    shap_values = explainer.shap_values(X_train)

    # Map feature names
    mapped_feature_names = [feature_name_mapping.get(name, name) for name in X_train.columns]
    
    # X_train with new names
    X_train_mapped = X_train.copy()
    X_train_mapped.columns = mapped_feature_names

    # Bar plot of SHAP values
    plt.figure(figsize=(8, 5))
    bar_summary_plot = shap.summary_plot(shap_values, X_train_mapped, plot_type="bar")

    # Dot plot of SHAP values
    plt.figure(figsize=(8, 5))
    # If the model is a RF, dot plot of SHAP values for class 1
    if model == "rf":
        dot_summary_plot = shap.summary_plot(shap_values[1], X_train_mapped)
    # If the model is a XGB, dot plot of SHAP values
    elif model == "xgb":
        dot_summary_plot = shap.summary_plot(shap_values, X_train_mapped)

    return bar_summary_plot, dot_summary_plot
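The bar summary plot ranks features by their mean absolute SHAP value across samples, which is also how `plot_shap_summary` and `plot_shap_scatter` order their panels. A dependency-free sketch of that ranking, assuming one row of SHAP values per sample for the positive class (`rank_features_by_mean_abs_shap` and the values are illustrative). Note that, depending on the installed `shap` version, `TreeExplainer.shap_values` for a scikit-learn random forest may return either a list with one array per class (the form the `shap_values[1]` indexing here assumes) or a single array with the class on the last axis.

```python
def rank_features_by_mean_abs_shap(shap_rows, feature_names):
    """Return (feature, mean |SHAP|) pairs sorted from most to least important."""
    n_samples = len(shap_rows)
    importance = []
    for j, name in enumerate(feature_names):
        mean_abs = sum(abs(row[j]) for row in shap_rows) / n_samples
        importance.append((name, mean_abs))
    return sorted(importance, key=lambda item: item[1], reverse=True)


shap_rows = [
    [0.10, -0.02, 0.00],
    [-0.06, 0.01, 0.03],
    [0.08, -0.03, -0.01],
]
names = ["Familiarity", "Intensity", "Pleasantness"]
for name, value in rank_features_by_mean_abs_shap(shap_rows, names):
    print(f"{name}: {value:.3f}")
```

Taking the absolute value before averaging matters: a feature that pushes predictions strongly up for some trials and strongly down for others would otherwise average out to near zero.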

In [None]:
def error_analysis(best_model, X_test, y_test):
    # Recovering the best model from the pipeline (if SMOTE/SMOTE-NC was used)
    if hasattr(best_model, 'steps'):
        best_model = best_model.steps[-1][1]

    # Predictions
    y_pred = best_model.predict(X_test)

    # Classification results
    correct_class_0 = (y_test == 0) & (y_pred == 0)
    correct_class_1 = (y_test == 1) & (y_pred == 1)
    misclassified_class_0_as_1 = (y_test == 0) & (y_pred == 1)
    misclassified_class_1_as_0 = (y_test == 1) & (y_pred == 0)

    # Create DataFrame with classifications
    df_errors = pd.concat([
        X_test[correct_class_0].assign(classification='Correct Class 0'),
        X_test[correct_class_1].assign(classification='Correct Class 1'),
        X_test[misclassified_class_0_as_1].assign(classification='Misclassified 0->1'),
        X_test[misclassified_class_1_as_0].assign(classification='Misclassified 1->0')
    ], axis=0).reset_index(drop=True)

    # Classification order
    classification_order = ['Correct Class 0', 'Misclassified 1->0', 'Correct Class 1', 'Misclassified 0->1']

    # Axes configuration
    feature_config = {
        "Gender": {"xticks": [0, 1], "xticklabels": ['W', 'M'], "xlim": (-0.5, 1.5)},
        "Pleasantness": {"xlim": (-5, 5), "xticks": np.arange(-5, 6, 2.5)},
        "Emotional strength": {"xlim": (0, 5), "xticks": np.arange(0, 6, 1)},
        "Intensity": {"xlim": (0, 10), "xticks": np.arange(0, 11, 2.5)},
        "Familiarity": {"xlim": (0, 10), "xticks": np.arange(0, 11, 2.5)},
        "Perceptive distance": {"xlim": (0, 12), "xticks": np.arange(0, 13, 2)},
        "Number of words": {"xlim": (0, 12), "xticks": np.arange(0, 13, 1), "xticklabels": np.arange(0, 13, 1)},
        "Semantic distance": {"xlim": (0, 1), "xticks": np.arange(0, 1.25, 0.25)},
    }

    # Figure setup
    num_features = len(X_test.columns)
    rows = (num_features // 3) + 1
    fig, axes = plt.subplots(rows, 3, figsize=(15, rows * 3.5))
    axes = axes.flatten()
    fontsize=16
    labelsize=16

    for i, feature_name in enumerate(X_test.columns):
        ax = axes[i]
        x_label = feature_name_mapping.get(feature_name, feature_name)

        if feature_name == 'gender_encoded':
            gender_counts = df_errors.groupby(['classification', 'gender_encoded']).size().unstack(fill_value=0)
            gender_proportions = gender_counts.div(gender_counts.sum(axis=1), axis=0).reindex(classification_order)
            gender_proportions = gender_proportions.rename(columns={0: 'W', 1: 'M'})

            ax.barh(range(len(classification_order)), gender_proportions['W'], 
                     color=[palette[c] for c in classification_order], height=0.8, label='W')
            ax.barh(range(len(classification_order)), gender_proportions['M'], 
                     left=gender_proportions['W'],
                     color=[palette[c] for c in classification_order], height=0.8, label='M', hatch='//')

            ax.set_xlim(-0.02, 1.02)
            ax.set_xticks(np.arange(0, 1.25, 0.25))
            ax.set_xlabel('Proportion', fontsize=fontsize)
            ax.set_yticks(range(len(classification_order)))
            ax.set_yticklabels(classification_order, fontsize=fontsize)
            ax.set_ylabel('')
            ax.yaxis.set_ticks_position('none')
            ax.legend(title='Gender', loc='upper right', bbox_to_anchor=(1.5, 0.9), fontsize=fontsize, title_fontsize=fontsize)
            ax.invert_yaxis()

        else:
            sns.stripplot(
                data=df_errors,
                x=feature_name,
                y='classification',
                order=classification_order,
                hue='classification',
                hue_order=classification_order,
                dodge=False,
                jitter=0.2,
                palette=palette,
                marker='o',
                size=6,
                ax=ax
            )

            # Apply feature-specific X-axis configuration
            config = feature_config.get(x_label, {})
            if "xlim" in config:
                # Check if the feature is "Semantic distance"
                if feature_name == "jaccard_dist":
                    ax.set_xlim(config["xlim"][0] - 0.02, config["xlim"][1] + 0.02)
                else:
                    ax.set_xlim(config["xlim"][0] - 0.2, config["xlim"][1] + 0.2)

            if "xticks" in config:
                ax.set_xticks(config["xticks"])

            if "xticklabels" in config:
                ax.set_xticklabels(config["xticklabels"], fontsize=fontsize)

            # Special case for "Number of words"
            if feature_name == "nb_words":
                ax.set_xlim(-0.2, 12.2)
                ax.set_xticks(np.arange(0, 13, 1))
                ax.set_xticklabels([i if i % 2 == 0 else '' for i in np.arange(0, 13, 1)], fontsize=fontsize)

            ax.set_xlabel(x_label, fontsize=fontsize)
            ax.set_ylabel('')
            ax.yaxis.set_ticks_position('none')

        # Display y labels only in first column
        if i % 3 == 0:
            ax.set_yticklabels(classification_order, fontsize=fontsize)
        else:
            ax.set_yticklabels([])

        ax.tick_params(axis='both', length=10, width=2, labelsize=labelsize)
        sns.despine(ax=ax)
        ax.spines['left'].set_visible(False)
        ax.spines['bottom'].set_linewidth(2)
        ax.spines['top'].set_visible(False)

    # t-SNE or scatterplot depending on the number of features
    if num_features > 2:
        # Note: scikit-learn >= 1.5 renames TSNE's `n_iter` to `max_iter` (the old name is removed in 1.7)
        tsne = TSNE(n_components=2, perplexity=30, n_iter=1000, random_state=42)
        X_tsne = tsne.fit_transform(df_errors[X_test.columns])
        tsne_df = pd.DataFrame(X_tsne, columns=['t-SNE 1', 't-SNE 2'])
        tsne_df['classification'] = df_errors['classification'].reset_index(drop=True)

        ax_tsne = axes[num_features] if num_features < len(axes) else fig.add_subplot(rows, 3, num_features + 1)
        sns.scatterplot(
            data=tsne_df,
            x='t-SNE 1',
            y='t-SNE 2',
            hue='classification',
            palette=palette,
            s=60,
            ax=ax_tsne
        )
        ax_tsne.set_xlabel('t-SNE 1', fontsize=fontsize)
        ax_tsne.set_ylabel('t-SNE 2', fontsize=fontsize)

        # Remove legend
        ax_tsne.legend_.remove()

        # Formatting
        sns.despine(ax=ax_tsne)
        ax_tsne.spines['bottom'].set_linewidth(2)
        ax_tsne.spines['left'].set_linewidth(2)
        ax_tsne.spines['top'].set_visible(False)
        ax_tsne.spines['right'].set_visible(False)
        ax_tsne.xaxis.set_tick_params(length=10, width=2, labelsize=labelsize)
        ax_tsne.yaxis.set_tick_params(length=10, width=2, labelsize=labelsize)

    elif num_features == 2:
        ax_scatter = axes[num_features] if num_features < len(axes) else fig.add_subplot(rows, 3, num_features + 1)

        feature_x, feature_y = X_test.columns[0], X_test.columns[1]
        x_label = feature_name_mapping.get(feature_x, feature_x)
        y_label = feature_name_mapping.get(feature_y, feature_y)

        sns.scatterplot(
            data=df_errors,
            x=feature_x,
            y=feature_y,
            hue='classification',
            palette=palette,
            s=60,
            ax=ax_scatter
        )

        ax_scatter.set_xlabel(x_label, fontsize=fontsize)
        ax_scatter.set_ylabel(y_label, fontsize=fontsize)

        # Remove legend
        ax_scatter.legend_.remove()

        # Formatting
        sns.despine(ax=ax_scatter)
        ax_scatter.spines['bottom'].set_linewidth(2)
        ax_scatter.spines['left'].set_linewidth(2)
        ax_scatter.spines['top'].set_visible(False)
        ax_scatter.spines['right'].set_visible(False)
        ax_scatter.xaxis.set_tick_params(length=10, width=2, labelsize=labelsize)
        ax_scatter.yaxis.set_tick_params(length=10, width=2, labelsize=labelsize)

    # Remove unused subplots
    for i in range(num_features + 1, len(axes)):
        fig.delaxes(axes[i])

    plt.tight_layout()
    plt.show()
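`error_analysis` splits the test trials into four groups by comparing true and predicted labels. The same bookkeeping in plain Python, assuming 0/1 labels; the category strings match the keys of the `palette` dict defined in the Condition Selection cell, and `classification_category` is a hypothetical helper:

```python
from collections import Counter


def classification_category(true_label, predicted_label):
    """Map a (true, predicted) pair to the labels used in error_analysis."""
    if true_label == predicted_label:
        return f"Correct Class {true_label}"
    return f"Misclassified {true_label}->{predicted_label}"


y_true = [0, 0, 1, 1, 0, 1]
y_pred = [0, 1, 1, 0, 0, 1]
counts = Counter(classification_category(t, p) for t, p in zip(y_true, y_pred))
print(counts)
```

These counts are the row totals of the confusion matrix printed by `evaluate_model`, split by whether the prediction was right.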

In [None]:
def plot_shap_summary(best_model, X_test, model="rf", figsize=(9, 4)):
    # Recovering the best model from the pipeline (if SMOTE/SMOTE-NC was used)
    if hasattr(best_model, 'steps'):
        best_model = best_model.steps[-1][1]

    # Create a SHAP explainer
    explainer = shap.TreeExplainer(best_model)

    # Retrieve SHAP values and feature names
    shap_values = explainer.shap_values(X_test)

    if model == "rf":
        shap_values = shap_values[1]
    elif model == "xgb":
        shap_values = shap_values
        
    feature_names = X_test.columns
    feature_values = X_test.values

    # Map the feature names using the global feature_name_mapping
    mapped_feature_names = [feature_name_mapping.get(name, name) for name in feature_names]

    # Calculate the average magnitude of the SHAP values for each feature
    feature_magnitudes = np.abs(shap_values).mean(axis=0)

    # Create a sorted order based on the magnitude of SHAP values
    sorted_indices = np.argsort(feature_magnitudes)[::-1]
    sorted_feature_names = [mapped_feature_names[i] for i in sorted_indices]

    # Sort SHAP values and feature values accordingly
    sorted_shap_values = shap_values[:, sorted_indices]
    sorted_feature_values = feature_values[:, sorted_indices]

    # Custom colormap for feature values
    cmap = LinearSegmentedColormap.from_list("custom_cmap", ["#3C89F3", "#EA3355"])

    # Compute adaptive x-axis limits separately for positive and negative values
    min_shap = np.min(shap_values)
    max_shap = np.max(shap_values)

    xlim_min = np.floor(min_shap / 0.05) * 0.05
    xlim_max = np.ceil(max_shap / 0.05) * 0.05

    # Create the plot
    plt.figure(figsize=figsize)

    # For each feature, create a scatter plot
    for i, feature_name in enumerate(sorted_feature_names):
        # Retrieve the SHAP values and feature values for each feature
        shap_vals = sorted_shap_values[:, i]
        feature_vals = sorted_feature_values[:, i]

        # Add jitter for spacing between points and reduce overlap
        jitter = np.random.normal(0, 0.1, size=len(shap_vals))
        y_pos = np.full(len(shap_vals), len(sorted_feature_names) - i + 0.5 * jitter)

        # Scatter plot with color-coded feature values
        plt.scatter(
            shap_vals, 
            y_pos, 
            c=feature_vals, 
            cmap=cmap,
            linewidth=0.5
        )

    # Configure labels and title
    plt.yticks(range(1, len(sorted_feature_names) + 1), reversed(sorted_feature_names), fontsize=16)
    plt.xlabel("SHAP value (impact on model output)", fontsize=16)
    plt.ylabel("", fontsize=16)

    # Apply custom x-axis limits with a small extension of ± 0.02
    plt.xlim(xlim_min - 0.02, xlim_max + 0.02)

    # Apply ticks separately for negative and positive ranges
    neg_ticks = np.arange(xlim_min, 0, 0.05)
    pos_ticks = np.arange(0, xlim_max + 0.05, 0.05)
    plt.xticks(np.concatenate([neg_ticks, pos_ticks]))

    # Adjust the colorbar
    cbar = plt.colorbar()
    cbar.set_label("Feature Value", fontsize=16)
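    # Each feature is colored on its own min-max range, so place the ticks at the ends of the color scale
    cbar.set_ticks(cbar.mappable.get_clim())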
    cbar.set_ticklabels(["Low", "High"], fontsize=14)

    # Remove ticks on the y-axis but keep the feature names
    plt.tick_params(axis='y', which='both', left=False)

    # Increase x-axis line width and tick label size
    plt.tick_params(axis='x', which='both', labelsize=16, width=2)
    plt.gca().tick_params(axis='x', width=2, length=10)

    # Remove top, left, and right borders
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['left'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().spines['bottom'].set_linewidth(2)

    # Add a vertical line at x = 0
    plt.axvline(x=0, color="black", linestyle="-", linewidth=2)

    # Display the plot with tight layout
    plt.tight_layout()
    plt.show()

In [None]:
def plot_shap_scatter(best_model, X_test, model="rf", figsize=(14, 12)):
    # Recovering the best model from the pipeline (if SMOTE/SMOTE-NC was used)
    if hasattr(best_model, 'steps'):
        best_model = best_model.steps[-1][1]

    # SHAP explainer creation
    explainer = shap.TreeExplainer(best_model)
    shap_values = explainer.shap_values(X_test)
    if model == "rf":
        shap_values = shap_values[1]
    elif model == "xgb":
        shap_values = shap_values
    feature_names = X_test.columns
    feature_values = X_test.values

    # Apply mapping, otherwise keep original name
    mapped_feature_names = [feature_name_mapping.get(name, name) for name in feature_names]

    # Sort features by absolute average importance SHAP
    feature_magnitudes = np.abs(shap_values).mean(axis=0)
    sorted_indices = np.argsort(feature_magnitudes)[::-1]
    sorted_feature_names = [mapped_feature_names[i] for i in sorted_indices]
    sorted_shap_values = shap_values[:, sorted_indices]
    sorted_feature_values = feature_values[:, sorted_indices]

    # Compute global SHAP value limits for all features
    min_shap_all = np.min(sorted_shap_values)
    max_shap_all = np.max(sorted_shap_values)
    
    # Creation of subplots (3x3)
    fig, axes = plt.subplots(3, 3, figsize=figsize)
    axes = axes.flatten()

    # Axis configuration per feature
    feature_config = {
        "Gender": {"xticks": [0, 1], "xticklabels": ['W', 'M'], "xlim": (-0.5, 1.5)},
        "Pleasantness": {"xlim": (-5, 5), "xticks": np.arange(-5, 6, 2.5)},
        "Emotional strength": {"xlim": (0, 5), "xticks": np.arange(0, 6, 1)},
        "Intensity": {"xlim": (0, 10), "xticks": np.arange(0, 11, 2.5)},
        "Familiarity": {"xlim": (0, 10), "xticks": np.arange(0, 11, 2.5)},
        "Perceptive distance": {"xlim": (0, 12), "xticks": np.arange(0, 13, 2)},
        "Number of words": {"xlim": (0, 12), "xticks": np.arange(0, 13, 1), "xticklabels": np.arange(0, 13, 1)},
        "Semantic distance": {"xlim": (0, 1), "xticks": np.arange(0, 1.25, 0.25)},
    }

    # Plot graphs
    for i, feature_name in enumerate(sorted_feature_names):
        if i >= len(axes):
            break

        ax = axes[i]
        shap_vals = sorted_shap_values[:, i]
        feature_vals = sorted_feature_values[:, i]

        # Scatter plot
        ax.scatter(feature_vals, shap_vals, color='black', linewidth=0.5)
        ax.axhline(y=0, color='grey', linestyle='--', linewidth=2)

        # Labels
        ax.set_xlabel(feature_name, fontsize=16)
        ax.set_ylabel("SHAP value", fontsize=16)

        # Use global SHAP value limits for Y axis
        y_lim_neg = np.floor(min_shap_all / 0.05) * 0.05 - 0.02
        y_lim_pos = np.ceil(max_shap_all / 0.05) * 0.05 + 0.02
        ax.set_ylim(y_lim_neg, y_lim_pos)

        # X axis configuration according to feature
        config = feature_config.get(feature_name, {})
        if "xlim" in config:
            # Check if the feature is "Semantic distance"
            if feature_name == "Semantic distance":
                ax.set_xlim(config["xlim"][0] - 0.02, config["xlim"][1] + 0.02)
            else:
                ax.set_xlim(config["xlim"][0] - 0.2, config["xlim"][1] + 0.2)
        if "xticks" in config:
            ax.set_xticks(config["xticks"])
        if "xticklabels" in config:
            ax.set_xticklabels(config["xticklabels"], fontsize=16)
        
        # For "Number of words", labels 1/2 tick
        if feature_name == "Number of words":
            labels = ax.get_xticklabels()
            for j in range(1, len(labels), 2):  # Hide every other label
                labels[j] = ''
            ax.set_xticklabels(labels)

        # Customize axes
        ax.tick_params(axis='x', length=10, width=2, labelsize=16)
        ax.tick_params(axis='y', length=10, width=2, labelsize=16)

        # Remove unnecessary borders
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['left'].set_linewidth(2)
        ax.spines['bottom'].set_linewidth(2)

    # Remove empty subplots if fewer than 9 features
    for j in range(i + 1, len(axes)):
        fig.delaxes(axes[j])

    # Display adjustment
    plt.tight_layout(pad=3)
    plt.show()

In [None]:
def plot_shap_waterfall(best_model, X_test, trial_index, predicted_class, model="rf"):
    # Recovering the best model from the pipeline (if SMOTE/SMOTE-NC was used)
    if hasattr(best_model, 'steps'):
        best_model = best_model.steps[-1][1]
    
    # Compute SHAP values
    explainer = shap.TreeExplainer(best_model)
    shap_values = explainer.shap_values(X_test)
    if model == "rf":
        shap_vals_trial = shap_values[predicted_class][trial_index]
        base_value = explainer.expected_value[predicted_class]
    elif model == "xgb":
        shap_vals_trial = shap_values[trial_index]
        base_value = explainer.expected_value
    predicted_value = base_value + shap_vals_trial.sum()

    # Map feature names
    original_feature_names = list(X_test.columns)
    mapped_feature_names = [feature_name_mapping.get(name, name) for name in original_feature_names]

    # Sort SHAP values by absolute magnitude
    sorted_indices = sorted(range(len(shap_vals_trial)), key=lambda i: abs(shap_vals_trial[i]))
    shap_vals_trial_sorted = [shap_vals_trial[i] for i in sorted_indices]
    custom_feature_names_sorted = [mapped_feature_names[i] for i in sorted_indices]

    # Calculation of cumulative values of f(x)
    cumulative_values = [base_value]
    current_value = base_value

    for shap_value in shap_vals_trial_sorted:
        current_value += shap_value
        cumulative_values.append(current_value)

    # Adjustment of X-axis terminals according to values reached
    x_min = min(cumulative_values) - 0.01
    x_max = max(cumulative_values) + 0.01

    # Plot settings
    fig, ax = plt.subplots(figsize=(9, 4))
    position = base_value
    bar_height = 0.8
    arrow_head_fraction = 0.1

    arrow_end_positions = []
    arrow_centers = []

    for i, shap_value in enumerate(shap_vals_trial_sorted):
        color = "#EA3355" if shap_value > 0 else "#3C89F3"
        y_center = i
        y_bottom = y_center - bar_height / 2
        y_top = y_center + bar_height / 2
        
        x_start = position
        x_end = position + shap_value
        arrow_width = arrow_head_fraction * abs(shap_value)
        rectangle_end = x_end - arrow_width if shap_value > 0 else x_end + arrow_width

        arrow_vertices = [
            (x_start, y_bottom), (rectangle_end, y_bottom),
            (x_end, y_center), (rectangle_end, y_top), (x_start, y_top)
        ]

        arrow = Polygon(arrow_vertices, closed=True, facecolor=color, edgecolor='none')
        ax.add_patch(arrow)

        text_x = (x_start + rectangle_end) / 2
        text_ha = "center"

        if abs(shap_value) < 0.016:
            text_x = x_end + (0.001 if shap_value > 0 else -0.001)
            text_ha = "left" if shap_value > 0 else "right"

        shap_text = f"$\\bf{{{shap_value:+.2f}}}$"

        ax.text(
            text_x, y_center, shap_text, ha=text_ha, va='center',
            color="white" if abs(shap_value) >= 0.016 else color, fontsize=12
        )

        position += shap_value
        arrow_end_positions.append(x_end)
        arrow_centers.append(y_center)

    for i in range(len(arrow_centers) - 1):
        ax.plot([arrow_end_positions[i], arrow_end_positions[i]], 
                [arrow_centers[i], arrow_centers[i + 1]], color='grey', linestyle='--', linewidth=1)

    ax.axvline(predicted_value, color='grey', linestyle='--', linewidth=2)
    ax.text(predicted_value, len(shap_vals_trial_sorted) + 0.2, 
            f"$f(x) = {predicted_value:.3f}$", ha='left', va='center', fontsize=16, color='black')

    # Base_value limited to the height of features
    ax.plot([base_value, base_value], [-0.5, 0.5],
        color='grey', linestyle='--', linewidth=1)

    ax.text(base_value, -2, f"$E[f(X)] = {base_value:.3f}$", ha='center', va='center', fontsize=16, color='black')

    ax.set_yticks(range(len(shap_vals_trial_sorted)))
    ax.set_yticklabels([
        f"{custom_feature_names_sorted[i]} = {X_test.iloc[trial_index, sorted_indices[i]]:.2f}"
        if custom_feature_names_sorted[i] in ["Semantic distance", "Perceptive distance"] 
        else f"{custom_feature_names_sorted[i]} = {X_test.iloc[trial_index, sorted_indices[i]]}"
        for i in range(len(shap_vals_trial_sorted))
    ], fontsize=16, color='black')

    ax.set_xlim([x_min, x_max])
    ax.tick_params(axis='x', which='both', bottom=True, width=2, length=10, labelsize=16)
    ax.tick_params(axis='y', which='both', length=0)
    ax.set_ylim(-0.5, len(shap_vals_trial_sorted))
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)

    plt.tight_layout()
    plt.show()

    print(f"Actual outcome for test trial {trial_index}: {y_test.iloc[trial_index]}")
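The waterfall plot relies on the additivity of SHAP values: starting from the base value E[f(X)], each feature's SHAP value shifts the running `position`, and the last arrow ends at the prediction f(x). A minimal sketch of that bookkeeping, using made-up numbers (the helper name `waterfall_positions` and all values are hypothetical, for illustration only):

```python
def waterfall_positions(base_value, shap_values):
    """Return the (x_start, x_end) span of each arrow and the final position.

    Hypothetical helper: accumulates SHAP values from the base value,
    the same way the plotting loop updates `position`.
    """
    segments = []
    position = base_value
    for shap_value in shap_values:
        x_start = position
        x_end = position + shap_value
        segments.append((x_start, x_end))
        position = x_end  # the next arrow starts where this one ends
    return segments, position


# Hypothetical example: base value 0.40 and three feature contributions
segments, prediction = waterfall_positions(0.40, [0.10, -0.05, 0.20])
for start, end in segments:
    print(f"{start:.2f} -> {end:.2f}")
print(f"f(x) = {prediction:.2f}")
```

Because the contributions are additive, the final position equals the base value plus the sum of all SHAP values, whatever order the features are drawn in.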

# <a id='toc3_'></a>[Data loading & Preprocessing](#toc0_)

## <a id='toc3_1_'></a>[Loading](#toc0_)

In [None]:
# Get the parent directory of the current working directory
project_dir = Path.cwd().parent

# Define the path to the data folder
data_folder = project_dir / "data"

# Define the path to the CSV file
file_path = data_folder / "dataset.csv"

# Read the CSV file
df = pd.read_csv(file_path)

# Display first lines
df.head()

## <a id='toc3_2_'></a>[Condition selection](#toc0_)

In [None]:
# Select a condition
condition = "odor_recognition"  # "odor_recognition" or "associative_memory"

# Initialize variables according to condition
if condition == "odor_recognition":
    # Target odors only; outcome = whether the odor was recognized (hit)
    data = df[df["is_target"] == 1].copy()
    data["outcome"] = data["hit"]
    data["percept_dist"] = data["avg_distance_target"]
    data["jaccard_dist"] = data["mean_lemma_jaccard_target"]
    fillcolor = "#5187EC"
    final_features = ["gender_encoded", "pleasantness", "emotional_strength", "intensity", "familiarity", "nb_words", "jaccard_dist"]
    palette = {
        'Correct Class 0': '#F39C12', 
        'Misclassified 1->0': '#C75F1D',
        'Correct Class 1': '#3C89F3',
        'Misclassified 0->1': '#1F4E79'
    }

elif condition == "associative_memory":
    # Recognized (hit) trials only; outcome = associative memory (mem)
    data = df[df["hit"] == 1].copy()
    data["outcome"] = data["mem"]
    data["percept_dist"] = data["avg_distance_hit"]
    data["jaccard_dist"] = data["mean_lemma_jaccard_hit"]
    fillcolor = "#5EACA3"
    first_feature_sel = ["emotional_strength", "familiarity", "percept_dist", "nb_words", "jaccard_dist"]
    second_feature_sel = ["familiarity", "percept_dist", "nb_words", "jaccard_dist"]
    final_features = ["familiarity", "jaccard_dist"]
    palette = {
        'Correct Class 0': '#F2645A', 
        'Misclassified 1->0': '#A73C35',
        'Correct Class 1': '#3CAEA3',
        'Misclassified 0->1': '#0F524B'
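    }

else:
    raise ValueError(
        f"Unknown condition {condition!r}: use 'odor_recognition' or 'associative_memory'"
    )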
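The `palette` keys above name four prediction outcomes used when plotting errors. How `error_analysis` assigns them is not shown in this section. Assuming that `'Misclassified a->b'` means "true class a, predicted as class b" (consistent with the colors, where each misclassified key shares the hue of the predicted class), a hypothetical helper could label predictions like this:

```python
def label_prediction_outcome(y_true, y_pred):
    """Hypothetical helper: map each (true, predicted) pair to a palette key.

    Assumes 'Misclassified a->b' means true class a predicted as class b.
    """
    labels = []
    for true, pred in zip(y_true, y_pred):
        if true == pred:
            labels.append(f"Correct Class {true}")
        else:
            labels.append(f"Misclassified {true}->{pred}")
    return labels


example = label_prediction_outcome([0, 1, 1, 0], [0, 0, 1, 1])
print(example)
```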

## <a id='toc3_3_'></a>[Info on data](#toc0_)

### <a id='toc3_3_1_'></a>[Check for unique values (number of participants, odors, ...)](#toc0_)

In [None]:
data.nunique()

### <a id='toc3_3_2_'></a>[Target distribution](#toc0_)

In [None]:
# Target distribution
print(data["outcome"].value_counts())

## <a id='toc3_4_'></a>[Features/Target split](#toc0_)

In [None]:
# Candidate predictor features (numerical and categorical)
features = [
    'gender_encoded',
    'pleasantness',
    'emotional_strength',
    'intensity',
    'familiarity', 
    'percept_dist',
    'nb_words',
    'jaccard_dist',
    ]

X = data.loc[:, features]
y = data["outcome"]

# Participant IDs, used to keep each participant's trials together when splitting
groups = data["participant"]

In [None]:
X

In [None]:
y

## <a id='toc3_5_'></a>[Train/Test split](#toc0_)

Our data includes multiple rows per participant. Because the aim is to make predictions for new participants, testing on rows from participants who also have rows in the training set would likely give an optimistically biased estimate: responses from the same participant tend to be more similar to each other than to those of a different participant. StratifiedGroupKFold keeps all rows of a participant in the same fold and, within that constraint, tries to preserve the outcome distribution in each fold. Here, the first of its five folds is used as the test set and the remaining four as the training set.

In [None]:
# Outcome distribution in y
y.value_counts(normalize=True)

In [None]:
# Initialize StratifiedGroupKFold with 5 folds
splitter = StratifiedGroupKFold(n_splits=5)

# Use the first fold as the test set (roughly one fifth of the rows); the other folds form the training set
train_indx, test_indx = next(splitter.split(X, y, groups))

# Number of rows in the training and test sets
print(f"Train length: n={len(train_indx)}")
print(f"Test length: n={len(test_indx)}")

In [None]:
# Split X into train & test
X_train = X.iloc[train_indx]
X_test = X.iloc[test_indx]

# Split y into train & test
y_train = y.iloc[train_indx]
y_test = y.iloc[test_indx]

# Split groups into train & test
groups_train = groups.iloc[train_indx]
groups_test = groups.iloc[test_indx]

# Outcome distribution in y_train
print(f"Outcome distribution in y_train:\n{y_train.value_counts(normalize=True)}")
print()

# Outcome distribution in y_test
print(f"Outcome distribution in y_test:\n{y_test.value_counts(normalize=True)}")

In [None]:
# Participants ID in the training set
unique_groups_train = groups_train.unique()
print(f"Participants in groups_train: {unique_groups_train}")
print()

# Participants ID in the test set
unique_groups_test = groups_test.unique()
print(f"Participants in groups_test: {unique_groups_test}")
print()

# Check that no participant appears in both sets
common_groups = set(unique_groups_train).intersection(unique_groups_test)

if common_groups:
    print(f"Participants in both groups: {common_groups}")
else:
    print("No participant appears in both groups_train and groups_test.")

## <a id='toc3_6_'></a>[Data Shuffling](#toc0_)

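Use `shuffle` to randomly permute the training rows so that their original order (for example, trials stored participant by participant) does not introduce unwanted bias into the learning process. `X_train`, `y_train` and `groups_train` are shuffled together with a single permutation, so features, outcomes and participant IDs stay aligned, and `random_state=42` makes the shuffle reproducible. The test set is left in its original order, since row order does not affect evaluation.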
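A minimal sketch of why shuffling the three objects in one call keeps them aligned, assuming `shuffle` is scikit-learn's `sklearn.utils.shuffle` (its signature matches the call below). The toy `*_demo` data are hypothetical.

```python
import pandas as pd
from sklearn.utils import shuffle

# Toy training data, ordered participant by participant
X_demo = pd.DataFrame({"familiarity": [1, 2, 3, 4, 5, 6]})
y_demo = pd.Series([0, 0, 1, 1, 0, 1], name="outcome")
groups_demo = pd.Series(["p1", "p1", "p2", "p2", "p3", "p3"], name="participant")

X_sh, y_sh, groups_sh = shuffle(X_demo, y_demo, groups_demo, random_state=42)

# The same permutation is applied to all three objects, so index labels still match
assert X_sh.index.equals(y_sh.index) and y_sh.index.equals(groups_sh.index)
print(pd.concat([X_sh, y_sh, groups_sh], axis=1))
```

Pandas index labels survive the shuffle, which is why the notebook resets the indices afterwards.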

In [None]:
X_train_sh, y_train_sh, groups_train_sh = shuffle(X_train, y_train, groups_train, random_state=42)

In [None]:
# Reset the indices of the training and test sets so that index labels run from 0 to n-1
X_train_ri, X_test_ri = X_train_sh.reset_index(drop=True), X_test.reset_index(drop=True)
y_train_ri, y_test_ri = y_train_sh.reset_index(drop=True), y_test.reset_index(drop=True)
groups_train_ri = groups_train_sh.reset_index(drop=True)

# <a id='toc4_'></a>[Feature selection](#toc0_)

In [None]:
# All the features
X_train_ri_all = X_train_ri.copy()
X_test_ri_all = X_test_ri.copy()

# Final feature subset (defined per condition in the condition-selection cell)
X_train_ri_final = X_train_ri[final_features]
X_test_ri_final = X_test_ri[final_features]

# <a id='toc5_'></a>[Base model](#toc0_)

In [None]:
# Training
base_model, results_bm = perform_random_forest(
    X_train_ri_all,
    y_train_ri,
    groups_train_ri,
    model_type="default_rf",
    tuning=None,
    categorical_features="Yes"
    )

# Permutation test
permutation_test_with_plot(
    model=base_model,
    X=X_train_ri_all,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(base_model, X_test_ri_all, y_test_ri)

# <a id='toc6_'></a>[All features - models with tuning](#toc0_)

In [None]:
# Training
first_brf, results_first_brf = perform_random_forest(
    X_train_ri_all,
    y_train_ri,
    groups_train_ri,
    model_type="balanced_rf",
    tuning=True,
    categorical_features="Yes"
    )

# Permutation test
permutation_test_with_plot(
    model=first_brf,
    X=X_train_ri_all,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(first_brf, X_test_ri_all, y_test_ri)

# Feature selection
shap_feature_selection(first_brf, X_train_ri_all, y_train_ri, model="rf")

In [None]:
# Second tuned random forest: SMOTE-based for odor recognition, default RF for associative memory
if condition == "odor_recognition":
    model_type = "smote_rf"
elif condition == "associative_memory":
    model_type = "default_rf"

# Training
first_smrf, results_first_smrf = perform_random_forest(
    X_train_ri_all,
    y_train_ri,
    groups_train_ri,
    model_type=model_type,
    tuning=True,
    categorical_features="Yes"
    )

# Permutation test
permutation_test_with_plot(
    model=first_smrf,
    X=X_train_ri_all,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(first_smrf, X_test_ri_all, y_test_ri)

# Feature selection
shap_feature_selection(first_smrf, X_train_ri_all, y_train_ri, model="rf")

In [None]:
# Training
first_xgb, results_first_xgb = perform_xgboost(
    X_train_ri_all,
    y_train_ri,
    groups_train_ri
    )

# Permutation test
permutation_test_with_plot(
    model=first_xgb,
    X=X_train_ri_all,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(first_xgb, X_test_ri_all, y_test_ri)

# Feature selection
shap_feature_selection(first_xgb, X_train_ri_all, y_train_ri, model="xgb")

# <a id='toc7_'></a>[Final models](#toc0_)

In [None]:
# Training
final_brf, results_final_brf = perform_random_forest(
    X_train_ri_final,
    y_train_ri,
    groups_train_ri,
    model_type="balanced_rf",
    tuning=True,
    categorical_features="Yes"
    )

# Permutation test
permutation_test_with_plot(
    model=final_brf,
    X=X_train_ri_final,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(final_brf, X_test_ri_final, y_test_ri)

# Error analysis
error_analysis(final_brf, X_test_ri_final, y_test_ri)

# Interpretability
plot_shap_summary(final_brf, X_test_ri_final, model="rf", figsize=(9, 4))

plot_shap_scatter(final_brf, X_test_ri_final, model="rf")

plot_shap_waterfall(final_brf, X_test_ri_final, trial_index=70, predicted_class=1, model="rf")

In [None]:
# Training
final_smrf, results_final_smrf = perform_random_forest(
    X_train_ri_final,
    y_train_ri,
    groups_train_ri,
    model_type=model_type,
    tuning=True,
    categorical_features="Yes"
    )

# Permutation test
permutation_test_with_plot(
    model=final_smrf,
    X=X_train_ri_final,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(final_smrf, X_test_ri_final, y_test_ri)

# Error analysis
error_analysis(final_smrf, X_test_ri_final, y_test_ri)

# Interpretability
plot_shap_summary(final_smrf, X_test_ri_final, model="rf", figsize=(9, 4))

plot_shap_scatter(final_smrf, X_test_ri_final, model="rf")

plot_shap_waterfall(final_smrf, X_test_ri_final, trial_index=70, predicted_class=1, model="rf")

In [None]:
# Training
final_xgb, results_final_xgb = perform_xgboost(
    X_train_ri_final,
    y_train_ri,
    groups_train_ri
    )

# Permutation test
permutation_test_with_plot(
    model=final_xgb,
    X=X_train_ri_final,
    y=y_train_ri,
    groups=groups_train_ri
)

# Model evaluation
evaluate_model(final_xgb, X_test_ri_final, y_test_ri)

# Error analysis
error_analysis(final_xgb, X_test_ri_final, y_test_ri)

# Interpretability
plot_shap_summary(final_xgb, X_test_ri_final, model="xgb", figsize=(9, 4))

plot_shap_scatter(final_xgb, X_test_ri_final, model="xgb")

plot_shap_waterfall(final_xgb, X_test_ri_final, trial_index=70, predicted_class=1, model="xgb")

# Note: for XGBoost, the SHAP explanation (including f(x) and E[f(X)] in the waterfall plot) is in log-odds units, not probabilities.
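Log-odds values from the XGBoost explanation can be mapped back to probabilities with the logistic (sigmoid) function. A minimal sketch with hypothetical base and predicted values (the helper name `log_odds_to_probability` is illustrative):

```python
import math


def log_odds_to_probability(log_odds):
    """Convert a log-odds (logit) value to a probability with the logistic function."""
    return 1.0 / (1.0 + math.exp(-log_odds))


# Hypothetical waterfall values in log-odds units
base_log_odds = -0.2
prediction_log_odds = 0.8
print(f"E[f(X)] = {log_odds_to_probability(base_log_odds):.3f}")
print(f"f(x)    = {log_odds_to_probability(prediction_log_odds):.3f}")
```

Only these totals map to probabilities this way. Because the sigmoid is nonlinear, individual SHAP contributions are additive in log-odds but not in probability.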