In [1]:
from pathlib import Path
import pandas as pd
from scipy.stats import zscore
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output
import dash_bootstrap_components as dbc
import sys

sys.path.append("../../")
from lib.general import get_stage_list
from lib.stats import demographic_characteristics

### Input

In [2]:
# Define I/O paths: three processed ADNI input tables and the figure output directory.
# All locations are written relative to the working directory and resolved to absolute paths.
path_demographics: Path
path_lipidomics: Path
path_lipidomics_dict: Path
path_output_figure: Path
path_demographics, path_lipidomics, path_lipidomics_dict, path_output_figure = (
    Path(relative_location).resolve()
    for relative_location in (
        "../../../data/processed/adni/demographics_biomarkers.csv",
        "../../../data/processed/adni/lipidomics_total.csv",
        "../../../data/processed/adni/lipidomics_dict.csv",
        "../../../assets/figures/adni/",
    )
)

In [3]:
# Read the input tables
def _read_clean_csv(csv_path: Path) -> pd.DataFrame:
    """Load a CSV file, then drop rows with any missing value and exact duplicate rows."""
    return pd.read_csv(csv_path).dropna().drop_duplicates()


demographics: pd.DataFrame = _read_clean_csv(path_demographics)
lipidomics: pd.DataFrame = _read_clean_csv(path_lipidomics)
# The dictionary table is additionally converted to pandas' nullable dtypes
lipidomics_dict: pd.DataFrame = _read_clean_csv(path_lipidomics_dict).convert_dtypes()

In [4]:
# Attach the lipid measurements to the demographics via the "RID" key.
# The inner join keeps only the RIDs that are present in both tables.
lipidomics_by_rid: pd.DataFrame = lipidomics.set_index("RID")
df: pd.DataFrame = demographics.join(lipidomics_by_rid, on="RID", how="inner")

In [5]:
# Ordered list of stage names used throughout the notebook
# NOTE(review): the meaning of the argument `2` is defined in lib.general.get_stage_list — confirm there
stage_list: list[str] = get_stage_list(2)
# Distinct lipid classes, in order of first appearance in the lipidomics dictionary
lipid_list: list[str] = lipidomics_dict["lipid_class"].drop_duplicates().tolist()

In [6]:
# Convert columns to the compatible data types:
# boolean/integer columns -> int, float columns -> float, object columns -> str
for source_dtypes, target_dtype in (
    ([bool, int], int),
    ([float], float),
    ([object], str),
):
    matching_columns = df.select_dtypes(include=source_dtypes).columns
    df[matching_columns] = df[matching_columns].astype(target_dtype)

In [7]:
# Compute the summary statistics table for the joined data, given the ordered stage list.
# NOTE(review): `demographic_characteristics` is a project helper (lib.stats) not visible here.
# Later cells read `stats_df.loc["N", stage]` for every stage, so the result is expected to have
# a row labelled "N" and one column per stage — confirm against the helper.
stats_df: pd.DataFrame = demographic_characteristics(df, stage_list)

### Dashboard

In [8]:
# Build one axis label per stage: the stage name, a line break, and the sample count in parentheses
stage_label_list: list[str] = []
for stage_name in stage_list:
    n_samples = stats_df.loc["N", stage_name]
    stage_label_list.append(f"{stage_name}<br>({n_samples})")

In [9]:
# Compute the mean of each lipid class in each stage.
# Result layout: one row per lipid class (in `lipid_list` order), one column per stage
# (in `stage_list` order).
# Only the lipid columns are averaged; running `describe()` on the whole frame would compute
# count/std/quantiles for every numeric column just to keep the mean.
lipid_means: pd.DataFrame = (
    df.groupby(by="stage")[lipid_list]
    .mean()
    .reindex(stage_list)  # enforce the stage order; stages without samples become NaN
    .transpose()
)

In [10]:
# Z-transform every lipid class (row) across the stages, overwriting the means in place.
# After this cell `lipid_means` holds z-scores rather than raw means.
lipid_means.iloc[:] = zscore(lipid_means.to_numpy(), axis=1)

In [11]:
# Reshape to long format for plotting: one row per (stage, lipid) pair with its z-score
zscores_by_stage: pd.DataFrame = lipid_means.transpose().reset_index(drop=False)
df_melted: pd.DataFrame = zscores_by_stage.melt(
    id_vars=["stage"],
    var_name="lipid",
    value_name="zscore",
)

In [12]:
# Describe the trend of each lipid class: for every pair of consecutive stages, flag whether
# the later stage has a higher value than the stage before it (one boolean column per pair)
df_trend: pd.DataFrame = pd.DataFrame(index=lipid_means.index)
for previous_stage, current_stage in zip(stage_list, stage_list[1:]):
    is_increasing = lipid_means[current_stage] > lipid_means[previous_stage]
    df_trend[f"{current_stage} > {previous_stage}"] = is_increasing

In [13]:
# Count how many lipid classes share each distinct trend pattern, most frequent pattern first.
# The final row position (0, 1, 2, ...) therefore orders the patterns by prevalence.
trend_pattern_columns: list[str] = df_trend.columns.tolist()
trend_pattern_sizes: pd.Series = df_trend.groupby(by=trend_pattern_columns).size()
df_trend_label: pd.DataFrame = (
    trend_pattern_sizes.reset_index(name="count")
    .sort_values(by="count", ascending=False)
    .reset_index(drop=True)
)

In [14]:
# Assign a prevalence rank to each lipid class.
# Rank 1 is the most frequent trend pattern in `df_trend_label`; "count" is the number of lipid
# classes sharing that pattern. A single left merge on the trend columns replaces a nested
# row-by-row comparison and yields integer columns instead of object columns.
pattern_ranks: pd.DataFrame = df_trend_label.assign(
    prevalence_rank=range(1, len(df_trend_label) + 1)
)
df_prevalence: pd.DataFrame = (
    df_trend.merge(
        pattern_ranks,
        on=df_trend.columns.tolist(),
        how="left",  # keeps one output row per lipid class, in the original order
        validate="many_to_one",  # every trend pattern must appear exactly once in the lookup
    )
    .set_axis(df_trend.index, axis=0)  # merge resets the index; restore the lipid class labels
    .loc[:, ["prevalence_rank", "count"]]
)

In [15]:
# Order the lipid classes by prevalence rank, then build a label of the form
# "prevalence_rank (count)" that identifies the trend pattern each lipid class follows
df_prevalence = df_prevalence.sort_values(by="prevalence_rank", ascending=True)
rank_text: pd.Series = df_prevalence["prevalence_rank"].astype(str)
count_text: pd.Series = df_prevalence["count"].astype(str)
df_prevalence["lipid_label"] = rank_text + " (" + count_text + ")"

In [16]:
# Assemble the plotting frame: add each lipid's trend label and each stage's axis label
lipid_labels: pd.Series = df_prevalence["lipid_label"]
df_plot: pd.DataFrame = df_melted.join(lipid_labels, on="lipid", how="left")
stage_to_label: dict[str, str] = dict(zip(stage_list, stage_label_list))
df_plot["stage_label"] = df_plot["stage"].map(stage_to_label)

### Dash app

In [17]:
# fmt: off
app: Dash = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

# Trend labels offered in the checklist; all of them are selected initially
trend_label_options = df_prevalence["lipid_label"].unique()

# Tick marks shown on the two size sliders (every 200 px)
height_marks = {tick: str(tick) for tick in range(200, 1000 + 200, 200)}
width_marks = {tick: str(tick) for tick in range(400, 1400 + 200, 200)}

# Left column: connecting line, legend and per-lipid colour toggles
column_trace_options = dbc.Col(
    [
        dbc.Label(children="Add connecting lines:"),
        dbc.RadioItems(options=["none", "mean", "median"], value="none", id="connecting_line", inline=True),
        html.Hr(),
        dbc.Switch(label="Show legend", id="show_legend", value=True),
        html.Hr(),
        dbc.Switch(label="Separate lipids by color", id="separate_lipids", value=False),
    ],
    width={"size": 3, "offset": 1},
)

# Middle column: figure size sliders
column_figure_size = dbc.Col(
    [
        dbc.Label(children="Adjust height"),
        dcc.Slider(id="height", min=200, max=1000, step=25, value=450, marks=height_marks),
        dbc.Label(children="Adjust width"),
        dcc.Slider(id="width", min=400, max=1400, step=25, value=800, marks=width_marks),
    ],
    width=4,
)

# Right column: colour of the connecting line and the export button
column_color_and_export = dbc.Col(
    [
        dbc.Label("Line color"),
        dbc.Input(type="color", id="color_line", value="#000000", style={"width": 75, "height": 50}),
        html.Hr(),
        dbc.Button(children="Download as PDF+PNG", id="download", n_clicks=0),
    ],
    width=3,
)

# Bottom row: trend checklist next to the graph
column_trend_checklist = dbc.Col(
    dcc.Checklist(id="checklist", options=trend_label_options, value=trend_label_options, inline=False),
    width={"size": 1, "offset": 1},
)
column_graph = dbc.Col(dcc.Graph(id="graph"), width=10)

app.layout = html.Div(
    [
        html.Hr(),
        dbc.Row([column_trace_options, column_figure_size, column_color_and_export]),
        html.Hr(),
        dbc.Row([column_trend_checklist, column_graph]),
    ],
    style={"backgroundColor": "white"},
)

In [18]:
# fmt: off
@app.callback(
    Output("graph", "figure"),
    Input("connecting_line", "value"),
    Input("show_legend", "value"),
    Input("separate_lipids", "value"),
    Input("height", "value"),
    Input("width", "value"),
    Input("color_line", "value"),
    Input("checklist", "value"),
    Input("download", "n_clicks"))
def update_plot(connecting_line: str, show_legend: bool, separate_lipids: bool, height: int,
                width: int, color_line: str, trend_list: list[str], download: int) -> go.Figure:
    """Rebuild the lipid trend figure from the current state of the dashboard controls.

    Args:
        connecting_line: "none", "mean" or "median"; for the latter two an extra line joining
            the per-stage mean/median z-score of the displayed lipids is added.
        show_legend: Whether the legend is shown; when shown the figure is widened by 300 px.
        separate_lipids: If true each lipid keeps its own colour, otherwise all lipid lines
            share one colour.
        height: Figure height in pixels.
        width: Figure width in pixels (excluding the extra legend width).
        color_line: Colour of the connecting line.
        trend_list: Selected trend labels ("rank (count)"); only lipids with these labels are drawn.
        download: Click count of the download button.

    Returns:
        The updated figure. As a side effect, when this update was triggered by the download
        button, the figure is also written to `path_output_figure` as PDF and PNG.
    """
    # Imported locally so the callback stays self-contained with respect to the import cell
    from dash import callback_context

    # Filter the dataframe based on the selected trends (an empty selection yields an empty plot)
    df_plot_filtered: pd.DataFrame = df_plot.loc[df_plot["lipid_label"].isin(trend_list or [])]

    # Initialize the figure object
    fig: go.Figure = px.line(df_plot_filtered, x="stage_label", y="zscore", color="lipid", category_orders={"stage_label": stage_label_list})

    # Add horizontal line at y=0
    fig.add_hline(y=0, line_width=2, line_dash='dot')

    # Reserve additional horizontal space for the legend so the plot area keeps its width
    legend_extra_width: int = 300 if show_legend else 0

    # Update figure layout
    fig.update_layout(
        template="plotly",
        paper_bgcolor= 'rgba(255, 255, 255, 1)',
        plot_bgcolor= 'rgba(255, 255, 255, 1)',
        xaxis=dict(title=None, mirror=True, ticks='outside', showline=True),
        yaxis=dict(title='lipid concentration<br>(z-transformed)', mirror=True, ticks='outside', showline=True),
        font=dict(color='black', size=18, family='Arial'),
        title=dict(text=None, font=dict(color='black', size=22, family='Arial')),
        height=height, width=width + legend_extra_width,
        margin=dict(t=100, b=100),
        showlegend=show_legend,
    )

    # Separate lipids by line color if selected, otherwise use the default color for all lipids.
    # This runs before the connecting line is added so that line keeps its own colour.
    if not separate_lipids:
        fig.update_traces(line_color="#9da4ef")

    # Add a connecting line through the per-stage mean or median if selected
    if connecting_line in ("mean", "median"):
        y_value: list[float] = (
            df_plot_filtered.groupby("stage")["zscore"]
            .agg(connecting_line)
            .reindex(stage_list)  # stage order of the x axis; stages without data become NaN
            .tolist()
        )
        fig.add_trace(go.Scatter(x=stage_label_list, y=y_value, mode='lines', showlegend=False, line=dict(color=color_line)))

    # Download as PDF+PNG, but only when the button itself triggered this update. `n_clicks`
    # stays above zero after the first click, so testing it alone would rewrite the files on
    # every later change of any control.
    download_clicked: bool = any(
        trigger["prop_id"] == "download.n_clicks" for trigger in callback_context.triggered
    )
    if download_clicked and download:
        path_output_figure.mkdir(parents=True, exist_ok=True)
        fig.write_image(path_output_figure / "trend_lipidomics.pdf")
        fig.write_image(path_output_figure / "trend_lipidomics.png", scale=2)

    return fig

In [19]:
# Start the Dash server on port 7611 in debug mode with the auto-reloader disabled
app.run(
    port=7611,
    debug=True,
    use_reloader=False,
    jupyter_height=800,
)