In [1]:
from pathlib import Path
from itertools import product
import pandas as pd
from scipy.stats import false_discovery_control
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, dash_table
import dash_bootstrap_components as dbc
import sys

sys.path.append("../../")
from lib.general import get_stage_list
from lib.stats import (
    demographic_characteristics,
    multiple_linear_regression,
    mask_outlier,
)
from lib.plotly import standard_layout

### Input

In [2]:
# Specify I/O paths.
# All processed inputs share one relative root; keeping it in a single constant
# means relocating the notebook (or the data) only requires editing one line.
path_data_dir: Path = Path("../../../data/processed/adni")
path_demographics: Path = (path_data_dir / "demographics_biomarkers.csv").resolve()
path_lipoprotein: Path = (path_data_dir / "lipoprotein.csv").resolve()
path_lipoprotein_dict: Path = (path_data_dir / "lipoprotein_dict.csv").resolve()
# Destination for figures exported from the dashboard's download button
path_output_figure: Path = Path("../../../assets/figures/adni/").resolve()

In [3]:
# Load the processed ADNI tables. low_memory=False makes pandas infer each
# column's dtype from the whole file instead of chunk by chunk.
demographics: pd.DataFrame = pd.read_csv(path_demographics, low_memory=False)
lipoprotein: pd.DataFrame = pd.read_csv(path_lipoprotein, low_memory=False)
# The dictionary is converted to pandas' nullable/extension dtypes
lipoprotein_dict: pd.DataFrame = (
    pd.read_csv(path_lipoprotein_dict, low_memory=False)
    .convert_dtypes()
)

### Processing

In [4]:
# Attach the lipoprotein measurements to the demographics by participant ID
# (RID). The inner join keeps only participants present in both tables.
# A repeated RID on the lipoprotein side would silently duplicate demographic
# rows and inflate every N and regression below, so fail fast instead.
assert lipoprotein["RID"].is_unique, "lipoprotein table contains duplicate RIDs"
df: pd.DataFrame = demographics.join(
    lipoprotein.set_index("RID"), on="RID", how="inner"
)

In [5]:
# Stage labels come from lib.general.get_stage_list (argument semantics defined
# there); the lipoprotein labels are the dictionary rows flagged in "pct".
stage_list: list[str] = get_stage_list(2)
lipo_list: list[str] = lipoprotein_dict["label"][lipoprotein_dict["pct"]].tolist()

In [6]:
# Summarise cohort characteristics per stage (with a p-value column across
# stages); the resulting table is shown as this cell's output.
df_stats: pd.DataFrame = demographic_characteristics(
    df,
    stage_list,
)
df_stats

Unnamed: 0,Total,CSF-/PET-,CSF+/PET-,CSF-/PET+,CSF+/PET+,p_value
,,,,,,
N,353,175,15,33,130,
Age (year),71.5 (7.0),70.1 (7.1),75.0 (7.7),72.3 (6.3),72.9 (6.6),< 0.001
BMI (kg/m^2),28.1 (5.0),29.0 (5.2),25.9 (3.3),28.6 (5.5),27.0 (4.5),0.0014
Sex,,,,,,
Female (n),161 (46%),78 (45%),7 (47%),20 (61%),56 (43%),0.3325
Male (n),192 (54%),97 (55%),8 (53%),13 (39%),74 (57%),
Cognitive Status,,,,,,
CN (n),122 (35%),79 (45%),5 (33%),15 (45%),23 (18%),< 0.001
MCI (n),231 (65%),96 (55%),10 (67%),18 (55%),107 (82%),


In [7]:
# Every boolean column of df is a candidate stratifying variable, used both for
# the regression tables and for the facet rows of the plot.
comparison_options: list[str] = df.select_dtypes(include=[bool]).columns.tolist()

In [8]:
# Minimum number of participants a stage x condition subset must contain before
# a regression is fitted for it; smaller subsets are skipped and stay NaN.
MIN_SUBSET_SIZE: int = 8


def _fit_subset(subset: pd.DataFrame) -> pd.DataFrame:
    """Run the lipoprotein regressions on ``subset`` and BH-adjust the p-values.

    Calls ``multiple_linear_regression`` (lib.stats) with ``lipo_list`` as
    predictors and ``"strem2_log10"`` as outcome, adds a ``bh_adj_p_value``
    column (Benjamini-Hochberg over all predictors of this subset), and returns
    the result indexed by ``predictor`` so it aligns with the lipo-indexed
    annotation tables.
    """
    linreg: pd.DataFrame = multiple_linear_regression(subset, lipo_list, "strem2_log10")
    linreg["bh_adj_p_value"] = false_discovery_control(linreg["p_value"], axis=0, method="bh")
    # No sorting needed: callers assign by index label, not by position
    return linreg.set_index("predictor")


# One (p-value table, R^2 table) pair per stage, in stage_list order
annot_df_list: list[tuple[pd.DataFrame, pd.DataFrame]] = []

# Column order: [none, cond1=0, cond1=1, cond2=0, cond2=1, ...];
# "=0"/"=1" select rows where the boolean comparison column is False/True.
conditions: list[str] = ["none"] + [
    f"{option}={condition}"
    for option, condition in product(comparison_options, [0, 1])
]

for stage in stage_list:
    # Rows: lipoprotein labels; columns: conditions. Unfitted cells stay NaN.
    pval_annot_df: pd.DataFrame = pd.DataFrame(index=lipo_list, columns=conditions)
    rsqr_annot_df: pd.DataFrame = pd.DataFrame(index=lipo_list, columns=conditions)
    in_stage: pd.Series = df["stage"] == stage

    for column in conditions:
        if column == "none":
            # Unstratified: every participant in this stage
            subset: pd.DataFrame = df.loc[in_stage]
        else:
            # Stratified: participants in this stage whose comparison flag matches
            comparison, value = column.split("=")
            subset: pd.DataFrame = df.loc[in_stage & (df[comparison] == bool(int(value)))]
            if subset.shape[0] < MIN_SUBSET_SIZE:
                continue

        linreg: pd.DataFrame = _fit_subset(subset)
        # Store the adjusted p-values and the "rsq" reported by the regression helper
        pval_annot_df[column] = linreg["bh_adj_p_value"]
        rsqr_annot_df[column] = linreg["rsq"]

    annot_df_list.append((pval_annot_df, rsqr_annot_df))

### Dash app

In [9]:
# fmt: off
app: Dash = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Page layout:
#   row 1 - stage selector + outlier switch | plot-size sliders + download button
#   row 2 - scatter plot (left) | sortable, paged p-value table (right)
# Slider marks use stop = max + 1 so the maximum itself is labelled and no tick
# is placed outside the slider's [min, max] range.
app.layout = html.Div([
    html.Hr(),
    dbc.Row([
        dbc.Col([
            dbc.Label(children="Select a stage:"),
            dcc.RadioItems(options=stage_list, value=stage_list[1], id="stage", inline=False),
            html.Hr(),
            dbc.Switch(label="Outlier removal", id="outlier_removal", value=False),
        ], width={"size": 2, "offset": 1}),
        dbc.Col([
            dbc.Label(children="plot height"),
            dcc.Slider(id="height", min=100, max=600, step=25, value=350, marks={x: str(x) for x in range(100, 600 + 1, 100)}),
            dbc.Label(children="plot width"),
            dcc.Slider(id="width", min=400, max=1400, step=25, value=800, marks={x: str(x) for x in range(400, 1400 + 1, 200)}),
            html.Hr(),
            dbc.Button(children="Download as PDF+PNG", id="download", n_clicks=0),
        ], width=3),
    ], justify="evenly"),
    html.Hr(),
    dbc.Row([
        dbc.Col(dcc.Graph(id="graph"), width={"size": 4, "offset": 0}),
        dbc.Col(dash_table.DataTable(id="table",
                                     style_data_conditional=[{'if': {'row_index': 'odd'}, 'backgroundColor': 'rgb(220, 220, 220)'}],
                                     style_data={'fontWeight': '500'},
                                     style_header={'backgroundColor': 'rgb(220, 220, 220)', 'fontWeight': 'bold'},
                                     style_cell={'font_family': 'serif', 'font_size': '16px', 'color': 'black'},
                                     page_action='native',
                                     sort_action='native',
                                     page_size=8,),
                width={"size": 5, "offset": 0}),
    ]),
], style={"backgroundColor": "white"})

In [10]:
# fmt: off
@app.callback(
    Output("table", "data"),
    Input("stage", "value"))
def update_table(stage: str) -> list[dict]:
    """Return the BH-adjusted p-value table of ``stage`` as DataTable records.

    Rows are ordered by the unstratified ("none") p-value and rounded to four
    decimals. The lipoprotein label is exposed as the ``id`` field, which the
    table then reports back as ``active_cell["row_id"]``.
    """
    stage_idx: int = stage_list.index(stage)
    pvals: pd.DataFrame = annot_df_list[stage_idx][0]
    table: pd.DataFrame = (
        pvals.sort_values(by="none", ascending=True)
        .round(decimals=4)
        .rename_axis("id")
        .reset_index()
    )
    return table.to_dict(orient="records")


@app.callback(
    Output("graph", "figure"),
    Input("stage", "value"),
    Input("table", "active_cell"),
    Input("outlier_removal", "value"),
    Input("height", "value"),
    Input("width", "value"),
    Input("download", "n_clicks"),
    prevent_initial_call=True)
def update_graph(stage: str, active_cell: dict, outlier_removal: bool, height: int, width: int, download: int) -> go.Figure:
    """Plot the selected lipoprotein against strem2_log10, one panel per stage.

    The lipoprotein is the clicked table row (``row_id``). The clicked column
    picks the stratification: "none"/"id" gives a single row of panels, and a
    "<variable>=<0|1>" column adds a True/False facet row for ``<variable>``.
    Each panel is annotated with the BH-adjusted p-value and R^2 precomputed in
    ``annot_df_list``. Those values come from the full data: outlier removal
    only affects the points and OLS trendlines drawn here.

    The figure is exported to ``path_output_figure`` as PDF and PNG only when
    the download button is the input that fired this call.

    Args:
        stage: Selected stage; not read here, it only re-triggers the callback.
        active_cell: DataTable active cell, or None when nothing is selected.
        outlier_removal: Whether to keep only rows passing ``mask_outlier``.
        height: Base plot height in px (x1.5 when a facet row is added).
        width: Plot width in px.
        download: Cumulative click count of the download button.

    Returns:
        The Plotly figure (an empty figure when no cell is selected).
    """
    # ctx tells which Input fired this call (used to gate the export)
    from dash import ctx

    # If no cell is selected, return an empty figure
    if active_cell is None:
        return go.Figure()

    # Row id = lipoprotein label; column id = "none", "id" or "<variable>=<0|1>"
    lipo: str = active_cell["row_id"]
    condition: str = active_cell["column_id"]
    comparison: str = condition.split("=")[0] if "=" in condition else "none"

    # Optionally keep only rows that mask_outlier accepts on the two plotted variables
    if outlier_removal:
        df_plot: pd.DataFrame = df.loc[mask_outlier(df[[lipo, "strem2_log10"]])]
    else:
        df_plot: pd.DataFrame = df

    # Scatter + OLS trendline per stage column (plus a True/False row if stratified)
    scatter_kwargs: dict = dict(
        x=lipo, y="strem2_log10", trendline="ols", trendline_color_override="#5654a2",
        facet_col="stage", facet_col_spacing=0.01,
    )
    if comparison == "none":
        fig: go.Figure = px.scatter(df_plot, category_orders={"stage": stage_list}, **scatter_kwargs)
        height_plot: int | float = height
    else:
        fig: go.Figure = px.scatter(df_plot, facet_row=comparison,
                                    category_orders={"stage": stage_list, comparison: [True, False]},
                                    **scatter_kwargs)
        height_plot: int | float = 1.5 * height

    # Update the color and size of the markers
    fig.update_traces(marker=dict(color="#f78dcb", size=4))

    # Relabel each stage column title with its sample count(s). Assumes the first
    # len(stage_list) layout annotations are plotly express's facet column titles.
    for ii, stg in enumerate(stage_list):
        in_stage: pd.Series = df_plot["stage"] == stg
        if comparison == "none":
            count: int = df_plot.loc[in_stage].shape[0]
            panel_label: str = f"{stg}<br>({count})"
        else:
            count_true: int = df_plot.loc[in_stage & df_plot[comparison]].shape[0]
            count_false: int = df_plot.loc[in_stage & ~df_plot[comparison]].shape[0]
            panel_label: str = f"{stg}<br>({count_true}) / ({count_false})"
        fig["layout"]["annotations"][ii].update(text=panel_label)

    # Shared axis titles, positioned in paper coordinates relative to the plot size
    lipo_unit: str = lipoprotein_dict.loc[lipoprotein_dict["label"] == lipo, "unit"].values[0]
    fig.add_annotation(
        x=0.5, y=0.1 - 170 / height_plot,
        text=f"Total Plasma {lipo} ({lipo_unit})<br>(log<sub>10</sub> transformed)",
        xref="paper", yref="paper",
        textangle=0, showarrow=False,
        font=dict(color="black", size=18, family="Arial"),
    )
    fig.add_annotation(
        x=-100 / width, y=0.5,
        text="CSF sTREM2 (pg/mL)<br>(log<sub>10</sub> transformed)",
        xref="paper", yref="paper",
        textangle=-90, showarrow=False,
        font=dict(color="black", size=18, family="Arial"),
    )

    # Annotate each panel with its precomputed adjusted p-value and R^2.
    # annot_df_list follows stage_list order, which is also the facet column order.
    for xaxis in range(len(stage_list)):
        pval_df, rsqr_df = annot_df_list[xaxis]
        xref: str = "x domain" if xaxis == 0 else f"x{xaxis + 1} domain"
        if comparison == "none":
            fig.add_annotation(x=0.05, y=0.05, xref=xref, yref="y domain",
                               text=f"p={pval_df.at[lipo, comparison]:.4f}", showarrow=False)
            fig.add_annotation(x=1, y=1, xref=xref, yref="y domain",
                               text=f"R<sup>2</sup>={rsqr_df.at[lipo, comparison]:.2f}", showarrow=False)
        else:
            for yaxis in range(2):
                # Column "<comparison>=0" / "<comparison>=1" (False / True subset).
                # The yaxis=1 annotations are offset by ~1.2 domain heights so they
                # land in the facet row above the one "y domain" refers to.
                key: str = f"{comparison}={yaxis}"
                yref: str = "y domain" if yaxis == 0 else f"y{yaxis + 1} domain"
                fig.add_annotation(x=0.05, y=0.05 if yaxis == 0 else yaxis * 1.2 + 0.05,
                                   xref=xref, yref=yref,
                                   text=f"p={pval_df.at[lipo, key]:.4f}", showarrow=False)
                fig.add_annotation(x=1, y=1 if yaxis == 0 else yaxis * 1.2 + 0.85,
                                   xref=xref, yref=yref,
                                   text=f"R<sup>2</sup>={rsqr_df.at[lipo, key]:.2f}", showarrow=False)

    # Update figure layout
    fig: go.Figure = standard_layout(fig, True)
    fig.update_layout(
        height=height_plot,
        width=width,
        margin=dict(t=50, b=100),
    )

    # Boxed axes without per-panel titles (the shared titles are annotations above)
    fig.for_each_xaxis(lambda a: a.update(title=None, tickfont=dict(size=14), mirror=True, ticks="outside", showline=True))
    fig.for_each_yaxis(lambda a: a.update(title=None, tickfont=dict(size=14), mirror=True, ticks="outside", showline=True))

    # Thicken the OLS trendlines (the only traces drawn in "lines" mode)
    fig.update_traces(line_width=4, selector=dict(mode="lines"))

    # n_clicks is cumulative, so `download > 0` alone would re-export on every later
    # interaction; export only when the download button itself fired this call.
    if download and ctx.triggered_id == "download":
        path_output_figure.mkdir(parents=True, exist_ok=True)
        fig.write_image(path_output_figure / f"regression_{lipo}_{comparison}.pdf")
        fig.write_image(path_output_figure / f"regression_{lipo}_{comparison}.png", scale=2)

    return fig

In [11]:
# Serve the dashboard on port 7606, rendered inline with a 1000 px output frame;
# debug mode stays on while the auto-reloader is disabled.
app.run(port=7606, debug=True, use_reloader=False, jupyter_height=1000)