# About this code
Guidelines for installing Python and Jupyter Notebook are available here: https://github.com/PaulCarrascosa/LMT_Widget_Tool-LWT?tab=readme-ov-file
## Packages to install from a Python (or Anaconda) console with `pip install <package>` (or `conda install <package>` if pip does not work):
- pandas
- numpy
- scipy
- scikit-learn
- matplotlib
- ipywidgets
- statsmodels

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import ttest_ind, norm
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, roc_curve
import ipywidgets as widgets
from IPython.display import display, clear_output
import os
import io

# Global variables
# Shared state read and rebound (via `global`) by the widget callbacks below.
data = None  # DataFrame stored by set_data(); None until a file has been loaded
columns = []  # column names offered as checkboxes (names starting with "Unnamed" are hidden)
selected_columns = []  # descriptions of the ticked checkboxes (see update_selected_columns)
column_widgets = []  # Checkbox widgets, one per entry in `columns`
group_column = "ADHD"  # column whose value splits rows into the two groups; synced from group_col_text
group_values = {"Group 0": 0, "Group 1": 1}  # value of group_column that identifies each group
results_path = "results"  # base folder (relative path) for saved figures and stats

# To store results from the last histogram run
histogram_figures = []  # list of (Figure, filename) tuples
histogram_ttest_results = []  # list of dicts with keys "Variable", "t_stat", "p_val"

# To store results from the last ROC run
roc_figures = []  # list of (Figure, filename) tuples
roc_feature_importance = None  # DataFrame of Random Forest importances from the last run, or None

def create_folders(base_path):
    """Make sure ``base_path`` and its ``figures`` and ``stats`` subfolders exist.

    Folders that already exist are left as they are (``exist_ok=True``).
    """
    for folder in (base_path, f"{base_path}/figures", f"{base_path}/stats"):
        os.makedirs(folder, exist_ok=True)

# Single-file upload control; upload_file() is attached to its `value` further below.
file_upload = widgets.FileUpload(accept='.csv,.xlsx', multiple=False)
# Shared area in which the callbacks print messages and display plots/widgets.
output = widgets.Output()

# Group settings; update_group_settings() copies them into group_column / group_values.
group_col_text = widgets.Text(value="ADHD", description="Group Column:")
group_val0_text = widgets.IntText(value=0, description="Group 0 Value:")
group_val1_text = widgets.IntText(value=1, description="Group 1 Value:")

# Action buttons; their on_click callbacks are registered at the end of the cell.
plot_button = widgets.Button(description="Plot Histograms")
run_analysis_button = widgets.Button(description="Run ROC-AUC")
reset_button = widgets.Button(description="Reset")

# Buttons for saving results after analyses
# (displayed inside `output` by plot_histograms / run_roc_auc after a run).
save_histogram_button = widgets.Button(description="Save Histogram Results")
save_roc_button = widgets.Button(description="Save ROC Results")

def update_group_settings(*args):
    """Copy the group column name and both group codes from the widgets into globals.

    Used as an ``observe`` callback, so any positional change payload is ignored.
    The column name is stripped of surrounding whitespace.
    """
    global group_column, group_values
    column_name = group_col_text.value
    code_for_0 = group_val0_text.value
    code_for_1 = group_val1_text.value
    group_column = column_name.strip()
    group_values = dict([("Group 0", code_for_0), ("Group 1", code_for_1)])

# Re-sync group_column / group_values whenever any of the group settings is edited.
group_col_text.observe(update_group_settings, names='value')
group_val0_text.observe(update_group_settings, names='value')
group_val1_text.observe(update_group_settings, names='value')

def upload_file(change):
    """Load the file chosen in ``file_upload`` and show the column picker.

    Registered as an ``observe`` callback on ``file_upload.value``; ``change``
    (the traitlets change payload) is not used. All messages go to ``output``.

    ``.csv`` files are parsed with ``;`` as the separator and ``.xlsx`` files
    with ``pd.read_excel``. The extension check is case-insensitive. Any error
    while unpacking or parsing the upload is reported instead of raised.
    """
    with output:
        clear_output()
        value = file_upload.value
        if not value:
            print("No file uploaded.")
            return

        try:
            # ipywidgets >= 8 stores a tuple of dicts with 'name' and 'content';
            # ipywidgets 7 stores a dict {filename: {'metadata': ..., 'content': ...}}.
            if isinstance(value, dict):
                file_name, uploaded_file = next(iter(value.items()))
            else:
                uploaded_file = value[0]
                file_name = uploaded_file['name']
            content = uploaded_file['content']

            extension = os.path.splitext(file_name)[1].lower()
            if extension == '.csv':
                data_ = pd.read_csv(io.BytesIO(content), sep=';')
            elif extension == '.xlsx':
                data_ = pd.read_excel(io.BytesIO(content))
            else:
                print("Unsupported file format!")
                return

            set_data(data_)
            print(f"File '{file_name}' uploaded successfully.")
        except Exception as e:
            # UI boundary: report the problem in the output area rather than
            # letting the callback raise.
            print(f"Error processing file: {e}")

def set_data(df):
    """Store ``df`` as the working dataset and rebuild the column checkboxes.

    Column labels that are not strings (e.g. numeric headers read from Excel)
    are converted to strings first. Checkbox descriptions must be strings, and
    the selected descriptions are later used to index ``data``, so labels must
    match those descriptions exactly. Columns whose name starts with "Unnamed"
    are hidden from the picker; pandas typically gives blank headers names
    like "Unnamed: 0". Any previous column selection is discarded.
    """
    global data, columns, selected_columns, column_widgets
    if not all(isinstance(col, str) for col in df.columns):
        df = df.rename(columns=str)
    data = df
    columns = [col for col in data.columns if not col.startswith("Unnamed")]
    selected_columns = []
    column_widgets = []
    display_columns_widget()

def display_columns_widget():
    """Show one checkbox per entry in ``columns`` plus "Select All"/"Validate" buttons.

    The checkboxes are laid out in two vertical columns; the left column takes
    the extra box when the count is odd. The rendered panel replaces whatever
    ``output`` was showing. The new checkboxes are stored in ``column_widgets``.
    """
    global column_widgets, selected_columns
    column_widgets = [widgets.Checkbox(value=False, description=name) for name in columns]

    split_at = (len(column_widgets) + 1) // 2
    left_column = widgets.VBox(column_widgets[:split_at])
    right_column = widgets.VBox(column_widgets[split_at:])
    checkbox_grid = widgets.HBox([left_column, right_column])

    select_all_btn = widgets.Button(description="Select All")
    validate_btn = widgets.Button(description="Validate Selection")

    def _tick_everything(change=None):
        # Tick every checkbox, then refresh the global selection list.
        for box in column_widgets:
            box.value = True
        update_selected_columns()

    def _confirm_selection(change=None):
        # Refresh the selection and report it; this also clears the checkbox panel from `output`.
        update_selected_columns()
        with output:
            clear_output()
            print(f"Validated columns: {selected_columns}")

    select_all_btn.on_click(_tick_everything)
    validate_btn.on_click(_confirm_selection)

    button_row = widgets.HBox([select_all_btn, validate_btn])
    with output:
        clear_output()
        display(widgets.VBox([checkbox_grid, button_row]))

def update_selected_columns():
    """Rebuild ``selected_columns`` from the ticked checkboxes, keeping checkbox order."""
    global selected_columns
    ticked = []
    for box in column_widgets:
        if box.value:
            ticked.append(box.description)
    selected_columns = ticked

def plot_histograms(change=None):
    """Plot per-group histograms with Gaussian overlays and t-test each selected column.

    Rows are split by ``group_column`` using ``group_values``. For every numeric
    column in ``selected_columns``, the non-null values of each group are drawn
    as density histograms (blue = Group 0, red = Group 1). Each histogram is
    overlaid with a normal PDF built from that group's mean and standard
    deviation. The two groups are then compared with ``scipy.stats.ttest_ind``
    using its defaults (equal variances assumed).

    Figures are kept in ``histogram_figures`` and t-test rows in
    ``histogram_ttest_results`` so ``save_histogram_results`` can write them.
    The save button is only shown when at least one figure was produced.
    ``change`` is the button passed by ``on_click`` and is ignored.
    """
    global histogram_figures, histogram_ttest_results
    with output:
        clear_output()
        histogram_figures = []
        histogram_ttest_results = []
        if data is None or not selected_columns:
            print("No data or no columns selected.")
            return
        if group_column not in data.columns:
            print("Group column not found in dataset.")
            return

        group_0 = data[data[group_column] == group_values["Group 0"]]
        group_1 = data[data[group_column] == group_values["Group 1"]]

        if group_0.empty or group_1.empty:
            print("One or both groups are empty.")
            return

        ttest_records = []
        for column in selected_columns:
            if column not in data.columns:
                print(f"Column {column} not found.")
                continue
            if not pd.api.types.is_numeric_dtype(data[column]):
                print(f"Column {column} is not numeric.")
                continue

            # Cast to float: boolean columns count as numeric for pandas, but
            # np.histogram cannot subtract booleans when computing bin edges.
            col_data_0 = group_0[column].dropna().astype(float)
            col_data_1 = group_1[column].dropna().astype(float)

            fig = plt.figure(figsize=(6, 4))
            plt.hist(col_data_0, bins=15, alpha=0.6, color='blue', label="Group 0", density=True)
            plt.hist(col_data_1, bins=15, alpha=0.6, color='red', label="Group 1", density=True)

            # Gaussian fit. It is skipped for constant data: a zero standard
            # deviation would make norm.pdf return NaN and emit warnings.
            for values, line_style in ((col_data_0, 'b--'), (col_data_1, 'r--')):
                if len(values) < 2:
                    continue
                mu, sigma = values.mean(), values.std()
                if sigma > 0:
                    x = np.linspace(values.min(), values.max(), 100)
                    plt.plot(x, norm.pdf(x, mu, sigma), line_style)

            plt.title(f"Histogram: {column}")
            plt.xlabel(column)
            plt.ylabel("Density")
            plt.legend()
            plt.grid()
            plt.show()

            # Keep the figure so it can be saved later.
            histogram_figures.append((fig, f"histogram_{column}.png"))

            # t-test
            if len(col_data_0) > 1 and len(col_data_1) > 1:
                t_stat, p_val = ttest_ind(col_data_0, col_data_1, nan_policy='omit')
                print(f"T-test for {column}: t={t_stat:.3f}, p={p_val:.3e}")
                ttest_records.append({"Variable": column, "t_stat": t_stat, "p_val": p_val})
            else:
                print(f"Not enough data for t-test on {column}")

        if ttest_records:
            histogram_ttest_results = ttest_records

        # Only offer saving when there is something to save.
        if histogram_figures:
            display(save_histogram_button)
        else:
            print("No numeric columns could be plotted.")

def save_histogram_results(change=None):
    """Write the figures and t-test table from the last histogram run to disk.

    Figures go to ``<results_path>/figures`` and the t-tests to
    ``<results_path>/stats/histogram_ttests.csv``. Figure filenames are built
    from column names, so characters that are invalid in file names (path
    separators, Windows-reserved characters, control characters) are replaced
    with "_". Otherwise ``savefig`` could fail or write outside the figures folder.
    ``change`` is the button passed by ``on_click`` and is ignored.
    """
    def _safe_filename(name):
        # Replace path separators, Windows-reserved and control characters.
        return "".join("_" if ch in '<>:"/\\|?*' or ord(ch) < 32 else ch for ch in name)

    with output:
        clear_output()
        if not histogram_figures and not histogram_ttest_results:
            print("Nothing to save. Run 'Plot Histograms' first.")
            return
        create_folders(results_path)
        figures_dir = os.path.join(results_path, "figures")
        # Save histogram figures
        for fig, filename in histogram_figures:
            fig.savefig(os.path.join(figures_dir, _safe_filename(filename)))
        # Save t-test results
        if histogram_ttest_results:
            df = pd.DataFrame(histogram_ttest_results)
            df.to_csv(os.path.join(results_path, "stats", "histogram_ttests.csv"), index=False)
            print("Histogram figures and t-test results saved.")
        else:
            print("Histogram figures saved. No t-test results to save.")

def run_roc_auc(change=None):
    """Show the ROC-AUC panel: a target-column field and a button to fit the models.

    "Execute Initial ROC-AUC" fits a Logistic Regression and a Random Forest on
    the currently selected columns. It uses an 80/20 train/test split,
    stratified by label, with ``random_state=42``. It then plots both ROC
    curves and the Random Forest feature importances, and offers
    "Refine Features" to re-run on a subset of the predictors. Figures and
    importances are stored in ``roc_figures`` / ``roc_feature_importance`` for
    ``save_roc_results``.

    The binary label is 1 for rows whose ``group_column`` equals
    ``group_values["Group 1"]`` and 0 for ``group_values["Group 0"]``. Rows with
    any other value are dropped. The label column and the target column are
    never used as predictors, since that would leak the label. Non-numeric
    predictors and rows with missing predictor values are skipped.

    NOTE(review): the "Target Column" field is only validated (it must exist in
    ``data``) and excluded from the predictors; it is not used to build the
    label. Confirm whether that is intended.

    ``change`` is the button passed by ``on_click`` and is ignored.
    """
    target = widgets.Text(value="ADHD", description="Target Column:")
    initial_execute_button = widgets.Button(description="Execute Initial ROC-AUC")
    refine_button = widgets.Button(description="Refine Features")
    reexecute_button = widgets.Button(description="Re-execute ROC-AUC with Refined Features")

    # Checkboxes shown by "Refine Features", and the predictors used by the last run.
    refined_feature_checkboxes = []
    current_predictors = []

    def _usable_predictors(predictors, target_col):
        """Return the predictors the models can use, printing why others are skipped."""
        usable = []
        for column in predictors:
            if column in (group_column, target_col):
                # The label comes from these columns; keeping them would leak it.
                print(f"Column {column} is used as the label; excluded from predictors.")
            elif column not in data.columns:
                print(f"Column {column} not found.")
            elif not pd.api.types.is_numeric_dtype(data[column]):
                print(f"Column {column} is not numeric.")
            else:
                usable.append(column)
        return usable

    def execute_roc_auc_func(predictors, title_suffix="", show_refine=True):
        """Fit both models on ``predictors`` and display ROC curves and RF importances."""
        nonlocal current_predictors
        global roc_figures, roc_feature_importance
        with output:
            clear_output()
            roc_figures = []
            roc_feature_importance = None
            target_col = target.value

            if data is None or not predictors or target_col not in data.columns:
                print("Data, predictors, or target column are invalid.")
                return

            if group_column not in data.columns:
                print("Group column not found in the dataset.")
                return

            usable = _usable_predictors(predictors, target_col)
            if not usable:
                print("No usable predictor columns.")
                return
            # "Refine Features" should only offer columns the models actually used.
            current_predictors = usable

            group_0 = data[data[group_column] == group_values["Group 0"]]
            group_1 = data[data[group_column] == group_values["Group 1"]]

            # Drop rows with missing predictor values before building X and y so they stay aligned.
            combined_data = pd.concat([group_0, group_1]).dropna(subset=usable)
            X = combined_data[usable]
            y = (combined_data[group_column] == group_values["Group 1"]).astype(int)

            if X.empty:
                print("No valid data for ROC.")
                return
            if y.nunique() < 2:
                print("ROC-AUC needs rows from both groups after removing missing values.")
                return

            try:
                # Stratify so both classes appear in the test set (roc_auc_score
                # raises when y_true contains a single class).
                X_train, X_test, y_train, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=42, stratify=y
                )
            except ValueError as err:
                # e.g. a group with one row, or too few rows for a 20% test set.
                print(f"Not enough data to split for ROC-AUC: {err}")
                return

            log_model = LogisticRegression(max_iter=1000)
            log_model.fit(X_train, y_train)
            y_prob_log = log_model.predict_proba(X_test)[:, 1]
            log_auc = roc_auc_score(y_test, y_prob_log)

            rf_model = RandomForestClassifier(n_estimators=100, random_state=42)
            rf_model.fit(X_train, y_train)
            y_prob_rf = rf_model.predict_proba(X_test)[:, 1]
            rf_auc = roc_auc_score(y_test, y_prob_rf)

            print(f"Logistic Regression AUC-ROC{title_suffix}: {log_auc:.3f}")
            print(f"Random Forest AUC-ROC{title_suffix}: {rf_auc:.3f}")

            fpr_log, tpr_log, _ = roc_curve(y_test, y_prob_log)
            fpr_rf, tpr_rf, _ = roc_curve(y_test, y_prob_rf)

            fig_roc = plt.figure(figsize=(8, 6))
            plt.plot(fpr_log, tpr_log, label=f"Logistic Regression (AUC = {log_auc:.2f})")
            plt.plot(fpr_rf, tpr_rf, label=f"Random Forest (AUC = {rf_auc:.2f})")
            plt.plot([0, 1], [0, 1], 'k--', label='Random')
            plt.xlabel("False Positive Rate")
            plt.ylabel("True Positive Rate")
            plt.title(f"ROC-AUC Curves{title_suffix}")
            plt.legend(loc="lower right")
            plt.grid()
            plt.show()

            # Store figure
            roc_figures.append((fig_roc, f"roc_auc_curve{title_suffix}.png"))

            importances = rf_model.feature_importances_
            importance_df = pd.DataFrame({'Feature': X.columns, 'Importance': importances})
            importance_df = importance_df.sort_values(by='Importance', ascending=False)

            print("\nFeature Importance (Random Forest):")
            print(importance_df)
            roc_feature_importance = importance_df

            fig_feat = plt.figure(figsize=(8, 6))
            plt.barh(importance_df['Feature'], importance_df['Importance'], color='skyblue')
            plt.xlabel("Importance")
            plt.title(f"Feature Importance{title_suffix}")
            plt.gca().invert_yaxis()  # most important feature at the top
            plt.grid(axis='x')
            plt.show()

            roc_figures.append((fig_feat, f"feature_importance{title_suffix}.png"))

            if show_refine:
                display(refine_button)
            # Show save results button
            display(save_roc_button)

    def initial_execution(change):
        """Run the analysis on a copy of the globally selected columns."""
        nonlocal current_predictors
        current_predictors = selected_columns[:]
        execute_roc_auc_func(current_predictors, title_suffix="", show_refine=True)

    def refine_features(change):
        """Show one pre-ticked checkbox per predictor of the last run."""
        with output:
            clear_output()
            refined_feature_checkboxes.clear()
            for col in current_predictors:
                refined_feature_checkboxes.append(widgets.Checkbox(value=True, description=col))
            half = (len(refined_feature_checkboxes) + 1) // 2
            col1 = widgets.VBox(refined_feature_checkboxes[:half])
            col2 = widgets.VBox(refined_feature_checkboxes[half:])
            grid = widgets.HBox([col1, col2])
            display(widgets.VBox([widgets.Label("Select features for refined ROC AUC:"), grid, reexecute_button]))

    def reexecute_roc_auc_func(change):
        """Re-run the analysis on the predictors still ticked in the refine panel."""
        nonlocal current_predictors
        chosen_features = [cb.description for cb in refined_feature_checkboxes if cb.value]
        current_predictors = chosen_features
        execute_roc_auc_func(current_predictors, title_suffix=" (Refined)", show_refine=True)

    initial_execute_button.on_click(initial_execution)
    refine_button.on_click(refine_features)
    reexecute_button.on_click(reexecute_roc_auc_func)

    with output:
        clear_output()
        display(widgets.VBox([target, initial_execute_button]))

def save_roc_results(change=None):
    """Write the figures and feature-importance table from the last ROC-AUC run to disk.

    Figures go to ``<results_path>/figures`` and the Random Forest importances
    to ``<results_path>/stats/feature_importance.csv``. Characters that are
    invalid in file names are replaced with "_" before saving, matching the
    histogram save. ``change`` is the button passed by ``on_click`` and is ignored.
    """
    def _safe_filename(name):
        # Replace path separators, Windows-reserved and control characters.
        return "".join("_" if ch in '<>:"/\\|?*' or ord(ch) < 32 else ch for ch in name)

    with output:
        clear_output()
        if not roc_figures and roc_feature_importance is None:
            print("Nothing to save. Run the ROC-AUC analysis first.")
            return
        create_folders(results_path)
        figures_dir = os.path.join(results_path, "figures")
        # Save ROC figures
        for fig, filename in roc_figures:
            fig.savefig(os.path.join(figures_dir, _safe_filename(filename)))
        # Save feature importance
        if roc_feature_importance is not None:
            roc_feature_importance.to_csv(
                os.path.join(results_path, "stats", "feature_importance.csv"), index=False
            )
        print("ROC figures and feature importance saved.")

def reset_interface(change=None):
    """Clear ``output`` and redisplay the column picker for the data already loaded.

    ``set_data`` resets the selection, so every checkbox starts unticked again.
    ``change`` is the button passed by ``on_click`` and is ignored.
    """
    with output:
        clear_output()
        if data is None:
            # Printed inside `output`: output produced by widget callbacks outside
            # an Output widget may not be displayed anywhere in the notebook.
            print("No data in memory. Please upload a file first.")
            return
    set_data(data)  # Re-initialize columns and selection from memory

# Attach the reset and save callbacks to their buttons.
reset_button.on_click(reset_interface)
save_histogram_button.on_click(save_histogram_results)
save_roc_button.on_click(save_roc_results)

# Reload the dataset whenever the upload widget's value changes.
file_upload.observe(upload_file, names='value')

# Top-level layout: upload control, group settings, action buttons, and finally
# the shared `output` area that the callbacks write into.
ui = widgets.VBox([
    widgets.Label("Upload your dataset and select columns for analysis."),
    file_upload,
    group_col_text,
    widgets.HBox([group_val0_text, group_val1_text]),
    plot_button,
    run_analysis_button,
    reset_button,
    output
])

# Attach the two analysis callbacks.
plot_button.on_click(plot_histograms)
run_analysis_button.on_click(run_roc_auc)

display(ui)


VBox(children=(Label(value='Upload your dataset and select columns for analysis.'), FileUpload(value=(), accep…