# Interactive notebook for visualisations of prediction performance

In [None]:
import numpy as np
import scanpy as sc
import plotly.express as px

from pathlib import Path
from src.utils.pred_func import prepare_test_data

%load_ext blackcellmagic

In [None]:
# Local project root (absolute path on the author's machine; adjust as needed)
project_prefix = "/Users/khoat/Development/medical_genomics/projects/xML-workFlow"

# Make the project root and ./src/ importable.
# Entries are inserted at the front (./src/ ends up first), and only if they
# are not already present, so re-running this cell does not keep growing
# sys.path with duplicates.
import sys

for import_root in (project_prefix, str(Path(project_prefix).joinpath("src"))):
    if import_root not in sys.path:
        sys.path.insert(0, import_root)

# Import main function from pred module
from src.pred import main as pred_main

In [None]:
# Run the prediction pipeline on the bundled test set (nothing is written to
# disk because save_output=False) and unpack its four outputs.
y_test_ser, pred_proba_df, pred_metrics_ser, roc_curve_df = pred_main(
    search_obj_path=Path(project_prefix) / "random_forest.pkl",
    ml_model="random_forest",
    test_data_path=Path(project_prefix) / "examples" / "sammut_et_al_test.h5ad",
    output_dir="examples/",
    save_output=False,
)

In [None]:
# Display the metrics Series as a one-row table (one column per metric)
pred_metrics_ser.to_frame().transpose()

## Plotting
It is more intuitive to explore model predictions in a Jupyter notebook. <br><br>
Here we provide basic plotting functions that save figures in both interactive (.html) and static (.png, .jpg and .svg) formats.

In [None]:
# Specify interaction suffix to safe (currently only support .html)
interactive_suffix = ".html"  # interactive export; only .html is currently supported

# Static exports to save (currently .png, .jpg and .svg are supported)
static_suffixes = ".png .svg .jpg".split()

### [Fig] ROC curve

In [None]:
# ROC curve: true positive rate against false positive rate
fig = px.line(
    roc_curve_df,
    x="fpr",
    y="tpr",
    # `labels` is keyed by column name; keys "x"/"y" would be ignored
    labels=dict(fpr="False Positive Rate", tpr="True Positive Rate"),
)

# Set line width
fig.update_traces(line=dict(width=1))

# Update axes (y is tied 1:1 to x so the plot stays square)
fig.update_yaxes(
    title="TPR",
    title_font_size=8,
    title_standoff=4,
    linecolor="black",
    linewidth=0.75,
    ticks="outside",
    showticklabels=True,
    dtick=0.1,
    tickfont_size=7,
    tickwidth=0.75,
    ticklen=3,
    matches=None,
    scaleanchor="x",
    scaleratio=1,
    range=[0, 1],
    showgrid=False,
    gridwidth=0.5,
    gridcolor="lightgray",
    griddash="dash",
)
fig.update_xaxes(
    title="FPR",
    title_font_size=8,
    title_standoff=4,
    linecolor="black",
    linewidth=0.75,
    ticks="outside",
    showticklabels=True,
    dtick=0.1,
    tickfont_size=7,
    tickwidth=0.75,
    ticklen=3,
    matches=None,
    constrain="domain",
    range=[0, 1],
    showgrid=False,
    gridwidth=0.5,
    gridcolor="lightgray",
    griddash="dash",
)

# Add 45 degree (chance-level) reference line
fig.add_shape(type="line", line=dict(dash="dash", width=0.75), x0=0, x1=1, y0=0, y1=1)

# Update layout
fig["layout"].update(
    font=dict(size=7, color="black", family="Arial"),
    plot_bgcolor="rgba(0,0,0,0)",
    paper_bgcolor="rgba(0,0,0,0)",
    legend=dict(),
    showlegend=False,
    newshape=dict(opacity=1),
    margin=dict(t=5, l=0, r=5, b=0),  # Tight margin
)

# Blank out any subplot annotations (no-op unless facets are added)
for anno in fig["layout"]["annotations"]:
    anno["text"] = ""

# Save image: one file stem for every format so the interactive and static
# exports of this figure share the same name
roc_fig_path = Path(project_prefix).joinpath("figures", "roc_curve")
# Create the output directory up front so the writes below don't fail on a
# fresh checkout
roc_fig_path.parent.mkdir(parents=True, exist_ok=True)
if interactive_suffix:
    fig.write_html(roc_fig_path.with_suffix(interactive_suffix))
if static_suffixes:
    for suffix in static_suffixes:
        fig.write_image(
            roc_fig_path.with_suffix(suffix),
            # SVG is vector output, so only raster formats are upscaled
            scale=1 if suffix == ".svg" else 5,
            width=200,
            height=200,
        )

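### [Fig] Histogram of predicted probabilities
Predicted probabilities of pCR are plotted in separate panels for each true label (pCR / No pCR) and coloured by RCB category. A dashed line marks a probability of 0.5, so correct and incorrect predictions can be compared within each panel.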

In [None]:
# For this dataset, we also have acess to RCB score and RCB category in the test annbdata object
test_adata = sc.read(Path(project_prefix).joinpath("examples/sammut_et_al_test.h5ad"))
rcb_meta_df = test_adata[:, ["RCB.score", "RCB.category"]].to_df()

# Rename labels to be more intuitive
pred_proba_df["y_truth_label"] = pred_proba_df["y_truth"].replace(
    {0.0: "No pCR", 1.0: "pCR"}
)

# Mere with rcb metadata
pred_proba_df = pred_proba_df.merge(
    rcb_meta_df, left_index=True, right_index=True, how="inner"
)

In [None]:
# Histogram of predicted pCR probability, one facet per true label,
# stacked/coloured by RCB category
fig = px.histogram(
    pred_proba_df,
    x="y_pred_proba",
    color="RCB.category",
    facet_col="y_truth_label",
    nbins=50,
    color_discrete_map={
        0.0: "blue",
        1.0: "pink",
        2.0: "red",
        3.0: "maroon",
    },
    category_orders={"RCB.category": [0, 1, 2, 3], "y_truth_label": ["pCR", "No pCR"]},
)

# Update all xaxes
fig.update_xaxes(
    title="Probability of pCR",
    title_font_size=8,
    title_standoff=6,
    linecolor="black",
    linewidth=0.5,
    ticks="outside",
    showticklabels=True,
    tickfont_size=7,
    tickwidth=0.5,
    ticklen=2,
    dtick=0.1,
    range=[0, 1],
    matches=None,
)
# Update all yaxes with no title,
# then add title and tick labels to the first yaxis only
fig.update_yaxes(
    linecolor="black",
    linewidth=0.5,
    tickwidth=0.5,
    ticklen=2,
    ticks="outside",
    showticklabels=False,
    showgrid=False,
    gridwidth=0.5,
    gridcolor="lightgray",
    griddash="dash",
    # Fixed range chosen for this dataset: bins with more than 7 samples
    # would be clipped, so adjust when plotting other data
    range=[0, 7],
)
fig.update_yaxes(
    title="Count",
    title_font_size=8,
    title_standoff=5,
    showticklabels=True,
    tickfont_size=7,
    dtick=1,
    col=1,
)

# Add vertical reference line at 0.5. A line shape is coloured via
# `line_color`; `fillcolor` only applies to filled shapes and had no effect.
fig.add_vline(
    x=0.5, line_width=0.5, line_color="slategray", opacity=0.5, line_dash="dash"
)

# Format facet titles: "y_truth_label=pCR" -> "pCR"
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))

# Update layout
fig["layout"].update(
    bargap=0,  # adjacent histogram bars touch (boxgap only affects box traces)
    font=dict(size=7, color="black", family="Arial"),
    plot_bgcolor="rgba(0,0,0,0)",
    paper_bgcolor="rgba(0,0,0,0)",
    showlegend=False,  # Set to True to show the legend (styled below)
    legend=dict(
        title_font_size=5,
        font_size=4,
    ),
    newshape=dict(opacity=1),
    margin=dict(t=11, l=0, r=6, b=0),  # Tight margin
)

# Save plot
hist_fig_path = Path(project_prefix).joinpath("figures", "proba_hist")
# Create the output directory up front so the writes below don't fail on a
# fresh checkout
hist_fig_path.parent.mkdir(parents=True, exist_ok=True)
if interactive_suffix:
    fig.write_html(hist_fig_path.with_suffix(interactive_suffix))

if static_suffixes:
    for suffix in static_suffixes:
        fig.write_image(
            hist_fig_path.with_suffix(suffix),
            # SVG is vector output, so only raster formats are upscaled
            scale=1 if suffix == ".svg" else 5,
            width=290,
            height=110,
        )

### [Fig] Confusion matrix

In [None]:
# Heatmap of the confusion matrix returned in the metrics Series, with the
# count printed in each cell
confusion = pred_metrics_ser["confusion"]

# Keep 55 as the colour-scale ceiling, but raise it if any cell exceeds 55,
# so that no count is clipped to the top colour
zmax = max(55, int(np.asarray(confusion).max()))

fig = px.imshow(
    confusion,
    x=["No pCR", "pCR"],
    y=["No pCR", "pCR"],
    color_continuous_scale="teal",
    zmin=0,
    zmax=zmax,
    text_auto=True,
)

# Update color axis; `zmax + 1` so the top of the scale also gets a tick
# (np.arange excludes its stop value)
fig.update_coloraxes(
    showscale=True,
    colorbar=dict(tickvals=np.arange(0, zmax + 1, 5)),
)

# Update axes
fig.update_xaxes(
    ticks="outside",
    tickfont_size=7,
    ticklen=3,
    tickwidth=1,
    showgrid=False,
)
fig.update_yaxes(
    ticks="outside",
    tickfont_size=7,
    ticklen=3,
    tickwidth=1,
    tickangle=90,
    showgrid=False,
)

# Adjust layout
fig["layout"].update(
    font=dict(size=8, color="black", family="Arial"),
    # Uncomment for a transparent background:
    # plot_bgcolor="rgba(0,0,0,0)",
    # paper_bgcolor="rgba(0,0,0,0)",
    newshape=dict(opacity=1),
    margin=dict(t=0, l=0, r=0, b=0),  # Tight margin
)

# Save plots
cm_fig_path = Path(project_prefix).joinpath("figures", "confusion_matrix")
# Create the output directory up front so the writes below don't fail on a
# fresh checkout
cm_fig_path.parent.mkdir(parents=True, exist_ok=True)
if interactive_suffix:
    fig.write_html(cm_fig_path.with_suffix(interactive_suffix))
if static_suffixes:
    for suffix in static_suffixes:
        fig.write_image(
            cm_fig_path.with_suffix(suffix),
            # SVG is vector output, so only raster formats are upscaled
            scale=1 if suffix == ".svg" else 5,
            width=225,
            height=200,
        )