In [1]:
from pathlib import Path
import pandas as pd
from scipy.stats import false_discovery_control
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output
import dash_bootstrap_components as dbc
import sys

sys.path.append("../../")
from lib.general import get_stage_list
from lib.stats import (
    demographic_characteristics,
    multiple_linear_regression,
    mask_outlier,
)
from lib.plotly import standard_layout

### Input

In [2]:
# Define input paths
# The three input tables are listed as relative paths and each one is made
# absolute in a single pass; the order of the names on the left matches the
# order of the literals below.
path_demographics, path_lipidomics, path_lipidomics_dict = (
    Path(relative_path).resolve()
    for relative_path in (
        "../../../data/processed/adni/demographics_biomarkers.csv",
        "../../../data/processed/adni/lipidomics_total.csv",
        "../../../data/processed/adni/lipidomics_dict.csv",
    )
)
# Folder that receives the exported figures
path_output_figure: Path = Path("../../../assets/figures/adni/").resolve()

In [3]:
# Read files
# The same parser options are used for all three CSV inputs.
_csv_options: dict[str, bool] = {"low_memory": False}
demographics: pd.DataFrame = pd.read_csv(path_demographics, **_csv_options)
lipidomics: pd.DataFrame = pd.read_csv(path_lipidomics, **_csv_options)
# Only the dictionary table is additionally passed through convert_dtypes().
lipidomics_dict_raw: pd.DataFrame = pd.read_csv(path_lipidomics_dict, **_csv_options)
lipidomics_dict: pd.DataFrame = lipidomics_dict_raw.convert_dtypes()

In [4]:
# Join dataframes
# Re-index the lipidomics table by its "RID" column, then inner-join it onto
# the demographics table through the demographics "RID" column, so only rows
# whose RID is present in both tables remain.
lipidomics_by_rid: pd.DataFrame = lipidomics.set_index("RID")
df: pd.DataFrame = demographics.join(lipidomics_by_rid, on="RID", how="inner")

In [5]:
# Define lists
# Stage labels come from the project helper; lipid classes are the distinct
# values of the dictionary's "lipid_class" column in order of first appearance.
stage_list: list[str] = get_stage_list(2)
lipid_list: list[str] = lipidomics_dict["lipid_class"].drop_duplicates().tolist()

### Processing

In [6]:
# Compute statistics table
# demographic_characteristics is a project helper (lib.stats) whose internals
# are not visible here. Its result is later read as df_stats.loc["N", <stage>]
# to add the per-stage sample count to each plot panel label.
df_stats: pd.DataFrame = demographic_characteristics(df, stage_list)

In [7]:
# Define the list of comparison options
# "none" (no split) always comes first, followed by every boolean-typed
# column of the joined dataframe in column order.
comparison_options: list[str] = [
    "none",
    *df.select_dtypes(include=[bool]).columns,
]

In [8]:
# Function to compute the 10 most significant lipid classes for a given stage
def generate_list_of_lipid_classes(stage: str) -> list[str]:
    """Build the radio-button labels of the top-ranked lipid classes in a stage.

    The rows of ``df`` belonging to ``stage`` are passed to
    ``multiple_linear_regression`` together with ``lipid_list`` and the
    ``"strem2_log10"`` column. The returned p-values are adjusted with the
    Benjamini-Hochberg procedure and the ten rows with the smallest adjusted
    p-value (ties broken by the raw p-value) are kept.

    Args:
        stage: Value of the ``"stage"`` column selecting the rows to analyse.

    Returns:
        At most ten strings of the form ``"p=<adjusted p, 4 decimals> <predictor>"``,
        ordered from most to least significant.
    """
    # Keep only the rows of the requested stage
    stage_rows: pd.DataFrame = df.loc[df["stage"] == stage]

    # One regression result row per lipid class (project helper)
    fits: pd.DataFrame = multiple_linear_regression(
        stage_rows, lipid_list, "strem2_log10"
    )

    # Benjamini-Hochberg adjustment across all lipid classes of this stage
    fits["bh_adj_p_value"] = false_discovery_control(
        fits["p_value"], axis=0, method="bh"
    )

    # Rank by adjusted p-value, then raw p-value, and keep the first ten
    ranked: pd.DataFrame = fits.sort_values(
        by=["bh_adj_p_value", "p_value"], ascending=True
    )
    top_ten: pd.DataFrame = ranked.head(10)

    # Prefix every predictor name with its adjusted p-value
    labels: list[str] = []
    for predictor, adjusted_p in zip(top_ten["predictor"], top_ten["bh_adj_p_value"]):
        labels.append(f"p={adjusted_p:.4f} {predictor}")
    return labels

In [9]:
# Generate the 10 most significant lipid classes for each stage
# Maps every stage label to its list of radio-button labels.
all_options: dict[str, list[str]] = dict(
    zip(stage_list, map(generate_list_of_lipid_classes, stage_list))
)

In [10]:
# Instantiate the list of annotation dataframes
# Order: [none, cond1-False, cond1-True, cond2-False, cond2-True, ...]
# i.e. entry 0 is the unsplit analysis, and the boolean option at position k
# (k >= 1) of comparison_options occupies entries 2k - 1 (False) and 2k (True).
annot_df_list: list[tuple[pd.DataFrame, pd.DataFrame]] = []

# Spell the order out explicitly: one (comparison column, condition value)
# pair per entry. The condition value is unused for "none".
annot_conditions: list[tuple[str, bool | None]] = [("none", None)] + [
    (option, condition)
    for option in comparison_options[1:]
    for condition in (False, True)
]

# For each comparison option and each condition, compute the p-values and R^2 values for each lipid class
for comparison, condition in annot_conditions:
    # Two tables per entry (rows: lipid classes, columns: stages) holding the
    # B-H adjusted p-values and the R^2 values
    pval_annot_df: pd.DataFrame = pd.DataFrame(index=lipid_list, columns=stage_list)
    rsqr_annot_df: pd.DataFrame = pd.DataFrame(index=lipid_list, columns=stage_list)

    # For each stage, perform multiple linear regression and store the results
    for stage in stage_list:
        # Select the rows of this stage, restricted to the condition value
        # unless no comparison is requested
        row_mask: pd.Series = df["stage"] == stage
        if comparison != "none":
            row_mask = row_mask & (df[comparison] == condition)
        subset: pd.DataFrame = df.loc[row_mask]

        # Perform multiple linear regression
        linreg: pd.DataFrame = multiple_linear_regression(
            subset, lipid_list, "strem2_log10"
        )
        # Perform B-H false discovery control
        linreg["bh_adj_p_value"] = false_discovery_control(
            linreg["p_value"], axis=0, method="bh"
        )
        # Index by predictor: the column assignments below align on the index,
        # so the row order of the regression table does not matter
        linreg = linreg.set_index("predictor")

        # Store the p-values and R^2 values in the dataframes
        pval_annot_df[stage] = linreg["bh_adj_p_value"]
        rsqr_annot_df[stage] = linreg["rsq"]
    # Append the dataframes to the list
    annot_df_list.append((pval_annot_df, rsqr_annot_df))

In [11]:
# fmt: off
app: Dash = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Left column: stage selector, comparison selector and outlier toggle
selection_column = dbc.Col([
    dbc.Label(children="Select a stage:"),
    dcc.RadioItems(options=stage_list, value=stage_list[1], id="stage", inline=False),
    html.Hr(),
    dbc.Label(children="Select a condition to compare:"),
    dbc.RadioItems(options=comparison_options, value="none", id="comparison", inline=True),
    html.Hr(),
    dbc.Switch(label="Toggle outlier removal", id="outlier_removal", value=False),
], width={"size": 3, "offset": 1})

# Middle column: lipid selector; its label, options and value are filled in by a callback
lipid_column = dbc.Col([
    dbc.Label(id="lipid_label"),
    dbc.RadioItems(id="lipid", inline=False),
], width=4)

# Right column: figure size sliders and the export button
height_marks = {x: str(x) for x in range(100, 600 + 200, 200)}
width_marks = {x: str(x) for x in range(400, 1400 + 200, 200)}
export_column = dbc.Col([
    dbc.Label(children="Adjust height"),
    dcc.Slider(id="height", min=100, max=600, step=25, value=350, marks=height_marks),
    dbc.Label(children="Adjust width"),
    dcc.Slider(id="width", min=400, max=1400, step=25, value=800, marks=width_marks),
    html.Hr(),
    dbc.Button(children="Download as PDF+PNG", id="download", n_clicks=0),
], width=3)

# Page: control row on top, graph below, separated by horizontal rules
app.layout = html.Div([
    html.Hr(),
    dbc.Row([selection_column, lipid_column, export_column]),
    html.Hr(),
    dbc.Col(dcc.Graph(id="graph"), width={"size": 6, "offset": 2}),
], style={"backgroundColor": "white"})

In [12]:
# fmt: off
@app.callback(
    Output("lipid_label", "children"),
    Output("lipid", "options"),
    Output("lipid", "value"),
    Input("stage", "value"))
def update_lipid_options(stage: str) -> tuple[str, list[str], str]:
    """Refresh the lipid selector whenever another stage is chosen.

    Args:
        stage: Currently selected stage; must be a key of ``all_options``.

    Returns:
        The heading text of the selector, the option labels stored for the
        stage, and the first of those labels as the new selected value.
    """
    stage_options: list[str] = all_options[stage]
    heading: str = f"Lipids with the most significant p-values in {stage}"
    return heading, stage_options, stage_options[0]

@app.callback(
    Output("graph", "figure"),
    Input("stage", "value"),
    Input("lipid", "value"),
    Input("comparison", "value"),
    Input("outlier_removal", "value"),
    Input("height", "value"),
    Input("width", "value"),
    Input("download", "n_clicks"))
def update_graph(stage: str, lipid_label: str, comparison: str, outlier_removal: bool, height: int, width: int, download: int) -> go.Figure:
    """Draw the sTREM2-vs-lipid scatter plot with one panel per stage.

    Args:
        stage: Selected stage. Not used to build the figure; it is listed as
            an input so the callback is also triggered by the stage selector.
        lipid_label: Selected lipid option, formatted as ``"p=<value> <lipid>"``.
        comparison: ``"none"`` or the name of a boolean column of ``df``; when
            a column is given, the figure gets one facet row per condition value.
        outlier_removal: Whether to drop the rows flagged by ``mask_outlier``.
        height: Base figure height in pixels (multiplied by 1.5 when a
            comparison is shown).
        width: Figure width in pixels.
        download: Click count of the download button.

    Returns:
        The finished figure. When the callback was triggered by the download
        button, the figure is additionally written as PDF and PNG into
        ``path_output_figure``.

    Raises:
        PreventUpdate: If no lipid option is selected yet.
    """
    # Function-scope imports keep this cell runnable without touching the import cell
    from dash import ctx
    from dash.exceptions import PreventUpdate

    # Nothing to draw until the lipid selector has a value
    if not lipid_label:
        raise PreventUpdate

    # Extract the lipid class name; split only once so that a name containing
    # spaces is kept whole
    lipid: str = lipid_label.split(" ", 1)[1]

    # Remove outliers if checked
    if outlier_removal:
        df_plot: pd.DataFrame = df.loc[mask_outlier(df[[lipid, "strem2_log10"]])]
    else:
        df_plot = df

    # Generate the scatter plot
    if comparison == "none":
        fig: go.Figure = px.scatter(df_plot, x=lipid, y="strem2_log10", trendline="ols", trendline_color_override="#5654a2",
                                    facet_col="stage", facet_col_spacing=0.01, category_orders={"stage": stage_list})
        height_plot: int | float = height
    else:
        fig = px.scatter(df_plot, x=lipid, y="strem2_log10", trendline="ols", trendline_color_override="#5654a2", facet_col="stage",
                         facet_col_spacing=0.01, facet_row=comparison, category_orders={"stage": stage_list, comparison: [True, False]})
        # Two facet rows need more vertical room
        height_plot = 1.5 * height

    # Update the color and size of the markers
    fig.update_traces(marker=dict(color="#f78dcb", size=4))

    # Update the panel label to include the number of samples in each stage
    # (this must run before any annotation is added so the first
    # len(stage_list) annotations are still the ones created by px.scatter)
    for ii, stage_label in enumerate(stage_list):
        panel_label: str = stage_label + "<br>(" + str(df_stats.loc["N", stage_label]) + ")"
        fig["layout"]["annotations"][ii].update(text=panel_label)

    # Update the axis labels to include the plasma lipid and CSF sTREM2 measurement units
    lipid_unit: str = lipidomics_dict.loc[lipidomics_dict["lipid_class"] == lipid, "unit"].values[0]
    fig.add_annotation(
        x=0.5, y=0.09 - 160 / height_plot,
        text=f"Total Plasma {lipid} ({lipid_unit})<br>(log<sub>10</sub> transformed)",
        xref="paper", yref="paper",
        textangle=0, showarrow=False,
        font=dict(color="black", size=18, family="Arial"),
    )
    fig.add_annotation(
        x=-90 / width, y=0.5,
        text="CSF sTREM2 (pg/mL)<br>(log<sub>10</sub> transformed)",
        xref="paper", yref="paper",
        textangle=-90, showarrow=False,
        font=dict(color="black", size=18, family="Arial"),
    )

    # Pick the precomputed annotation tables.
    # annot_df_list is ordered [none, cond1-False, cond1-True, cond2-False, ...],
    # so the option at position k >= 1 of comparison_options owns entries
    # 2k - 1 (False) and 2k (True); using k + row would read another
    # option's tables for every option after the first.
    # NOTE(review): these tables are computed from the full dataframe, so the
    # annotated values do not react to the outlier toggle - confirm intended.
    if comparison == "none":
        annot_rows = [annot_df_list[comparison_options.index("none")]]
    else:
        first_entry: int = 2 * comparison_options.index(comparison) - 1
        annot_rows = [annot_df_list[first_entry], annot_df_list[first_entry + 1]]

    # Add the p-values and R^2 values to the plot: one pair per stage column
    # and, when a comparison is selected, per condition row. For the second
    # row the y values are larger than 1, i.e. the text is placed above the
    # referenced axis domain.
    for row, (pval_df, rsqr_df) in enumerate(annot_rows):
        yref: str = "y domain" if row == 0 else f"y{row + 1} domain"
        # Values of the selected lipid, one per stage in stage_list order
        pvals: pd.Series = pval_df.loc[lipid]
        rsqrs: pd.Series = rsqr_df.loc[lipid]
        for col in range(len(stage_list)):
            xref: str = "x domain" if col == 0 else f"x{col + 1} domain"
            fig.add_annotation(
                x=0.05,
                y=0.05 if row == 0 else row * 1.2 + 0.05,
                xref=xref,
                yref=yref,
                text=f"p={pvals.iloc[col]:.4f}",
                showarrow=False,
            )
            fig.add_annotation(
                x=1,
                y=1 if row == 0 else row * 1.2 + 0.85,
                xref=xref,
                yref=yref,
                text=f"R<sup>2</sup>={rsqrs.iloc[col]:.2f}",
                showarrow=False,
            )

    # Update the layout
    fig = standard_layout(fig, True)
    fig.update_layout(
        height=height_plot,
        width=width,
        margin=dict(t=50, b=100),
    )

    # Update the boundary box
    fig.for_each_xaxis(lambda a: a.update(title=None, tickfont=dict(size=14), mirror=True, ticks="outside", showline=True))
    fig.for_each_yaxis(lambda a: a.update(title=None, tickfont=dict(size=14), mirror=True, ticks="outside", showline=True))

    # Update the trendline width (the trendlines are the traces drawn in "lines" mode)
    fig.update_traces(line_width=4, selector=dict(mode="lines"))

    # Download as PDF+PNG, but only when the button itself fired this callback;
    # n_clicks stays positive after the first click, so checking it alone
    # would rewrite the files on every later control change
    if ctx.triggered_id == "download" and download:
        path_output_figure.mkdir(parents=True, exist_ok=True)
        fig.write_image(path_output_figure / f"regression_{lipid}_{comparison}.pdf")
        fig.write_image(path_output_figure / f"regression_{lipid}_{comparison}.png", scale=2)
    return fig

In [13]:
# Start the Dash server on port 7605 with Dash's debug mode enabled and the
# reloader turned off.
# NOTE(review): jupyter_height=1000 presumably sets the pixel height of the
# inline notebook display - confirm against the installed Dash version.
app.run(debug=True, jupyter_height=1000, port=7605, use_reloader=False)