In [1]:
import pandas as pd
from pathlib import Path
import numpy as np
from scipy.stats import zscore
from statsmodels.stats.multitest import multipletests
from sklearn.impute import SimpleImputer
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
from dash import Dash, dcc, html, Input, Output, dash_table
import dash_bootstrap_components as dbc
import sys

sys.path.append("../../")
from lib.stats import multiple_linear_regression, cluster_corr_df, remove_diagonal

### Input

In [2]:
# Define I/O paths
path_input_demographics: Path = Path(
    "../../../data/processed/adni/demographics_tau.csv"
).resolve()
path_input_dict: Path = Path("../../../data/processed/adni/somascan_dict.csv").resolve()
path_input_proteomics: Path = Path(
    "../../../data/processed/adni/somascan.csv"
).resolve()
path_input_davidson: Path = Path(
    "../../../data/processed/other/hdl_proteome_davidson.csv"
).resolve()

In [3]:
# Read files
df_demographics: pd.DataFrame = pd.read_csv(path_input_demographics).convert_dtypes()
df_dict: pd.DataFrame = pd.read_csv(path_input_dict).convert_dtypes()
df_proteomics: pd.DataFrame = pd.read_csv(path_input_proteomics).convert_dtypes()
df_davidson: pd.DataFrame = pd.read_csv(path_input_davidson).convert_dtypes()

### Processing

In [4]:
# Join cognitive status table with proteomics data
df: pd.DataFrame = df_demographics.join(
    df_proteomics.set_index("RID"), on="RID", how="inner"
).reset_index(drop=True)

In [5]:
# Join dictionary table with davidson proteome watchlist
df_dict: pd.DataFrame = df_dict.join(
    df_davidson.set_index("uniprot_id"), on="uniprot_id", how="inner"
).reset_index(drop=True)

In [6]:
# Filter out columns that are not in the dictionary
df: pd.DataFrame = df[df_demographics.columns.tolist() + df_dict["label"].tolist()]

In [7]:
# log10 transform ptau and ttau
df["ptau"] = np.log10(df["ptau"])
df["ttau"] = np.log10(df["ttau"])

### PCA

In [8]:
# Prepare data for PCA
# Shape: (proteins, observations)
data_pca: pd.DataFrame = (
    df.drop(columns=[var for var in df_demographics.columns if var != "RID"])
    .set_index("RID")
    .transpose()
)

In [9]:
# Standardize data using z-score
data_standardized: pd.DataFrame | np.ndarray = zscore(
    data_pca.astype(float), axis=1, nan_policy="omit"
)
# Remove outliers using z-score threshold of 3
data_standardized[data_standardized > 3] = np.nan

In [10]:
# Impute missing values using mean
imputer: SimpleImputer = SimpleImputer(strategy="mean")
data_imputed: np.ndarray = imputer.fit_transform(data_standardized)
df[df_dict["label"]] = data_imputed.transpose()

In [11]:
# Perform PCA
pca: PCA = PCA()
pca: PCA = pca.fit(data_imputed)

In [12]:
# Get cumulated explained variance ratio
exp_var: np.ndarray = pca.explained_variance_ratio_.cumsum()

In [13]:
# Visualize explained variance ratio
px.area(
    x=range(1, exp_var.shape[0] + 1),
    y=exp_var,
    labels={"x": "Number of PCs", "y": "Explained Variance"},
)

### Determine the number of PCs and cluster proteins

In [14]:
# Find the number of components that explain 99% of the variance
n_components: np.intp = np.argmax(exp_var > 0.99) + 1

In [15]:
# Stack original data with the principal components
stacked: np.ndarray = np.vstack((data_imputed, pca.components_))

In [16]:
# Compute correlation matrix
corr_mat: np.ndarray = np.corrcoef(stacked, rowvar=True)

In [17]:
# Get the correlations between components and the original proteins
corr_mat_components: np.ndarray = corr_mat[: -pca.n_components_, -pca.n_components_ :]

In [18]:
# Assign each protein to a cluster
df_dict["cluster"] = np.argmax(np.abs(corr_mat_components), axis=1)

In [19]:
# Keep only the first n_components clusters
df_dict: pd.DataFrame = df_dict.loc[df_dict["cluster"] < n_components]

# Display the number of unique clusters
len(df_dict["cluster"].unique())

20

In [20]:
# A cluster that contains APOL1
cluster: int = df_dict.loc[df_dict["entrez_gene_symbol"] == "APOL1", "cluster"].values[
    0
]

### Multiple linear regression model adjusted for age and sex

In [21]:
# Change to compatible data types
df[df.select_dtypes(include=[bool, int]).columns] = df[
    df.select_dtypes(include=[bool, int]).columns
].astype(int)
df[df.select_dtypes(include=[float]).columns] = df[
    df.select_dtypes(include=[float]).columns
].astype(float)
df[df.select_dtypes(include=[object]).columns] = df[
    df.select_dtypes(include=[object]).columns
].astype(str)

In [22]:
# Perform multiple linear regression
glm_ptau: pd.DataFrame = multiple_linear_regression(
    data=df,
    list_of_predictors=df_dict["label"].tolist(),
    dv="ptau",
    not_adjust=["bmi", "apoe4"],
)

In [23]:
# Join the results with the dictionary
glm_ptau: pd.DataFrame = df_dict.join(
    glm_ptau.set_index("predictor"), on="label", how="inner"
)

In [24]:
# For each cluster, keep the protein with the smallest p-value
glm_cluster: pd.DataFrame = glm_ptau.loc[
    glm_ptau.groupby(by="cluster")["p_value"].idxmin()
]

# Apply Bonferroni correction
_, glm_cluster["p_value_bonf_adj"], _, _ = multipletests(
    glm_cluster["p_value"], alpha=0.05, method="bonferroni"
)

In [25]:
# Make entrez_gene_symbol unique by appending asterisks to duplicates
glm_ptau["entrez_gene_symbol"] = glm_ptau["entrez_gene_symbol"] + glm_ptau.groupby(
    "entrez_gene_symbol"
).cumcount().apply(lambda x: "*" * (x)).astype(str)

### Dashboard

In [26]:
# Create a dictionary to map entrez_gene_symbol to target label
map_label: dict[str, str] = dict(zip(glm_ptau["entrez_gene_symbol"], glm_ptau["label"]))

In [27]:
# fmt: off
app: Dash = Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])

app.layout = html.Div([
    html.Hr(),
    dbc.Row([
        dbc.Col([
            dbc.Label(children="Select a protein (name may not be unique):"),
            dcc.Dropdown(options=glm_ptau["target_full_name"].tolist(), value="Apolipoprotein E (isoform E4)", id="protein"),
            html.Hr(),
            dash_table.DataTable(id='protein_table',
                        style_data_conditional=[{'if': {'row_index': 'odd'}, 'backgroundColor': 'rgb(220, 220, 220)'}],
                        style_data={'fontWeight': '500'},
                        style_header={'backgroundColor': 'rgb(220, 220, 220)', 'fontWeight': 'bold'},
                        style_cell={'font_family': 'serif', 'font_size': '16px', 'color': 'black'},
                        page_size=5,
                        page_action='native',)
        ], width={"size": 4, "offset": 1}),
        dbc.Col([
            html.H5(id="cluster_label"),
            dash_table.DataTable(id='cluster_table',
                        style_data_conditional=[{'if': {'row_index': 'odd'}, 'backgroundColor': 'rgb(220, 220, 220)'}],
                        style_data={'fontWeight': '500'},
                        style_header={'backgroundColor': 'rgb(220, 220, 220)', 'fontWeight': 'bold'},
                        style_cell={'font_family': 'serif', 'font_size': '16px', 'color': 'black'},
                        page_size=8,
                        page_action='native',
                        sort_action='native',)
        ], width=6),
    ]),
    html.Hr(),
    dbc.Row([
        dbc.Col([
            dcc.Graph(id="scatter_plot_top"),
            dcc.Graph(id="scatter_plot_bot"),
        ], width={"size": 3, "offset": 1},
        ),
        dbc.Col(
            dcc.Graph(id="correlation_heatmap"),
            width=6,
        ),
    ]),
], style={"backgroundColor": "white"})

In [28]:
# fmt: off
@app.callback(
    Output("cluster_label", "children"), 
    Input("protein", "value"))
def update_cluster_lable(protein: str) -> str:
    cluster: int = glm_ptau.loc[glm_ptau["target_full_name"] == protein, "cluster"].min()
    count: int = (glm_ptau["cluster"] == cluster).sum()
    label: str = f"{protein} belongs to Cluster {cluster + 1}, which contains {count} proteins"
    return label


@app.callback(
    Output("protein_table", "data"), 
    Input("protein", "value"))
def update_cluster_label(protein: str) -> list[dict]:
    subset: pd.DataFrame = glm_ptau.loc[glm_ptau["target_full_name"] == protein, ["target_full_name", "entrez_gene_symbol", "coef", "p_value"]]
    # Display numbers in scientific notation
    subset.sort_values(by="p_value", ascending=True, inplace=True)
    subset.reset_index(drop=True, inplace=True)
    subset["coef"] = subset["coef"].apply(lambda x: f"{x:.4e}")
    subset["p_value"] = subset["p_value"].apply(lambda x: f"{x:.4e}")
    return subset.to_dict(orient="records")


@app.callback(
    Output("cluster_table", "data"), 
    Input("protein", "value"))
def update_cluster_table(protein: str) -> list[dict]:
    cluster: int = glm_ptau.loc[glm_ptau["target_full_name"] == protein, "cluster"].min()
    p_value_adjusted: float = glm_cluster.loc[glm_cluster["cluster"] == cluster, "p_value_bonf_adj"].values[0]
    subset: pd.DataFrame = glm_ptau.loc[glm_ptau["cluster"] == cluster, ["target_full_name", "entrez_gene_symbol", "coef", "p_value"]]
    # Display numbers in scientific notation
    subset.sort_values(by="p_value", ascending=True, inplace=True)
    subset.reset_index(drop=True, inplace=True)
    subset.loc[0, "p_value_adjusted"] = f"{p_value_adjusted:.4e}*"
    subset["coef"] = subset["coef"].apply(lambda x: f"{x:.4e}")
    subset["p_value"] = subset["p_value"].apply(lambda x: f"{x:.4e}")
    return subset.to_dict(orient="records")


@app.callback(
    Output("correlation_heatmap", "figure"), 
    Input("protein", "value"))
def update_correlation_heatmap(protein: str) -> go.Figure:
    cluster: int = glm_ptau.loc[glm_ptau["target_full_name"] == protein, "cluster"].min()
    subset: pd.DataFrame = glm_ptau.loc[glm_ptau["cluster"] == cluster].reset_index(drop=True)
    df_plot: pd.DataFrame = cluster_corr_df(df[subset["label"]].corr(method="pearson"))
    index_map: dict = dict(zip(subset["label"], subset["entrez_gene_symbol"]))
    fig: go.Figure = px.imshow(
        remove_diagonal(df_plot),
        x=df_plot.columns.map(index_map),
        y=df_plot.index.map(index_map),
        zmin=-1,
        zmax=1,
        color_continuous_scale=px.colors.diverging.RdBu_r,
    )
    fig: go.Figure = fig.update_layout(
        template="simple_white",
        width=800, height=750,
        font=dict(color="black", size=12),
        xaxis=dict(title=None, showticklabels=True, mirror=True, ticks="outside", showline=True),
        yaxis=dict(title=None, showticklabels=True, mirror=True, ticks="outside", showline=True),
        margin=dict(t=10, b=20, l=20, r=10),
    )
    return fig

def update_scatter_plot(hoverData: dict, y_var: str) -> go.Figure:
    x_label: str = hoverData["points"][0]["y"]
    x_var: str = map_label[x_label]
    scatter_data: pd.DataFrame = df[[x_var, y_var]].dropna().astype(float)
    fig: go.Figure = px.scatter(scatter_data, x=x_var, y=y_var, trendline="ols", trendline_color_override="#5654a2")
    fig.update_traces(marker=dict(color="#f78dcb", size=4))
    fig.for_each_xaxis(lambda a: a.update(title=None, tickfont=dict(size=14), mirror=True, ticks="outside", showline=True))
    fig.for_each_yaxis(lambda a: a.update(title=None, tickfont=dict(size=14), mirror=True, ticks="outside", showline=True))
    fig.update_layout(
        template="simple_white",
        paper_bgcolor="rgba(255, 255, 255, 1)",
        xaxis=dict(title=x_label + "<br>(z-transformed)", showticklabels=True, mirror=True, ticks="outside", showline=True),
        yaxis=dict(title=y_var + "<br>(log10-transformed)", mirror=True, ticks="outside", showline=True),
        font=dict(color="black", size=18, family="Arial"),
        title=dict(text=None, font=dict(color="black", size=22, family="Arial"), x=0.5),
        height=350,
        width=400,
        margin=dict(t=50, b=20, l=20, r=20),
    )
    return fig

@app.callback(
    Output("scatter_plot_top", "figure"),
    Input("correlation_heatmap", "hoverData"),
    prevent_initial_call=True)
def update_scatter_plot_top(hoverData: dict) -> go.Figure:
    return update_scatter_plot(hoverData, "ptau")


@app.callback(
    Output("scatter_plot_bot", "figure"),
    Input("correlation_heatmap", "hoverData"),
    prevent_initial_call=True)
def update_scatter_plot_bot(hoverData: dict) -> go.Figure:
    return update_scatter_plot(hoverData, "ttau")

In [30]:
app.run(debug=True, port=7522, use_reloader=False, jupyter_height=1350)