In [1]:
import importlib.resources as pkg_resources
from typing import Optional

import torch
import pandas as pd
import plotly.graph_objects as go
from sklearn.metrics import precision_recall_curve, auc
import torch.nn.functional as F

import deep_paint

In [2]:
def _class_probabilities(predictions_df: pd.DataFrame, num_classes: int):
    """Softmax the ``logits_0 .. logits_{num_classes-1}`` columns into an (n_rows, num_classes) numpy array."""
    logits_cols = [f"logits_{i}" for i in range(num_classes)]
    logits = torch.tensor(predictions_df[logits_cols].values)
    return F.softmax(logits, dim=1).numpy()


def _pr_axis_style() -> dict:
    """Return the shared Plotly styling applied to both the recall (x) and precision (y) axes."""
    return dict(
        title_standoff=25,
        showgrid=True,
        gridcolor="#D3D3D3",
        showticklabels=True,
        tickmode='array',
        tickvals=[0.00, 0.25, 0.50, 0.75, 1.00],
        showline=True,
        linewidth=1,
        linecolor='black',
        mirror=True,
        minor=dict(
            tickmode='linear',
            tick0=0,
            dtick=0.125,
            gridcolor='#E6E6E6',
            gridwidth=0.5
        )
    )


def plot_auprc(
    predictions_df: pd.DataFrame,
    classes_map: Optional[dict] = None,
    renderer: str = "notebook",
    **kwargs
):
    """
    Plot Precision-Recall curve(s) and their area (AUC-PR) using sklearn.

    Parameters
    ----------
    predictions_df: pd.DataFrame
        DataFrame of model predictions. Must contain an integer ``y_true``
        column and one ``logits_{i}`` column per class (``logits_0`` and
        ``logits_1`` for the binary case). Logits are softmaxed into
        probabilities before computing the curves.
    classes_map: dict, optional
        Ordered mapping of class name -> line color. The i-th key is treated
        as integer label ``i`` in ``y_true``, so insertion order matters.
        With more than two entries a one-vs-rest curve is drawn per class.
        When None (or with two or fewer entries) a single binary curve is
        drawn for label 1, and the colors are not used.
    renderer: str, default notebook
        Plotly renderer (ex. notebook, jupyterlab, vscode, iframe, etc).
    **kwargs
        Keyword arguments passed to ``go.Figure().update_layout()``. They are
        applied after the default layout, so they may override any default
        (e.g. ``legend``, ``xaxis``, ``font``).

    Returns
    -------
    None
        The figure is displayed via ``fig.show(renderer=renderer)``.

    Notes
    -----
    AUC-PR is computed with ``sklearn.metrics.auc`` (trapezoidal rule).
    sklearn notes that linear interpolation of a PR curve can be more
    optimistic than ``average_precision_score``.
    """
    num_classes = 2 if classes_map is None else len(classes_map)

    fig = go.Figure()

    if num_classes > 2:
        probs = _class_probabilities(predictions_df, num_classes)
        y_true = predictions_df['y_true'].values

        # One-vs-rest PR curve per class; enumerate() pairs the i-th key with label i
        for i, label in enumerate(classes_map):
            precision, recall, _ = precision_recall_curve(y_true == i, probs[:, i])
            pr_auc = auc(recall, precision)

            fig.add_trace(
                go.Scatter(
                    x=recall,
                    y=precision,
                    mode='lines',
                    name=f'AUC-PR {label} vs rest = {pr_auc:.4f}',
                    line=dict(color=classes_map[label])
                )
            )
    else:
        probs = _class_probabilities(predictions_df, 2)
        preds = probs[:, 1]  # probability of the positive class (label 1)
        true_labels = predictions_df['y_true'].values

        # Precision-recall curve and AUC for binary classification
        precision, recall, _ = precision_recall_curve(true_labels, preds)
        pr_auc = auc(recall, precision)

        fig.add_trace(go.Scatter(
            x=recall,
            y=precision,
            mode='lines',
            name=f'AUC-PR = {pr_auc:.4f}',
            line=dict(color='gray')
        ))

        # A no-skill classifier has precision equal to the positive-class
        # prevalence at every recall. Its PR "curve" is therefore a horizontal
        # line, and its AUC-PR equals that prevalence (the 0.5 diagonal
        # baseline belongs to ROC, not PR).
        prevalence = float((true_labels == 1).mean())
        fig.add_trace(
            go.Scatter(
                x=[0, 1],
                y=[prevalence, prevalence],
                mode='lines',
                name=f'Chance Level (AUC-PR = {prevalence:.4f})',
                line=dict(dash='dash', color='black')
            )
        )

    # Default layout
    fig.update_layout(
        xaxis_title='Recall',
        xaxis=_pr_axis_style(),
        yaxis_title='Precision',
        yaxis=_pr_axis_style(),
        autosize=False,
        margin=dict(
            l=50,
            r=50,
            b=100,
            t=50,
            pad=4
        ),
        plot_bgcolor='white',
        legend=dict(
            x=1,
            y=0,
            xanchor='right',
            yanchor='bottom',
            bordercolor='black',
            borderwidth=1,
            font=dict(
                family="Arial, sans-serif",
                size=20,
                color="black"
            )
        ),
        font=dict(
            family="Arial, sans-serif",
            size=24,
            color="black"
        ),
    )
    # Caller overrides go in a separate call. Passing a key such as `legend`
    # together with the defaults above would raise
    # "TypeError: got multiple values for keyword argument".
    if kwargs:
        fig.update_layout(**kwargs)

    fig.show(renderer=renderer)

## Binary Classifier

In [3]:
# Load the binary classifier's predictions from results/predictions/,
# one directory above the installed deep_paint package
binary_preds_path = (
    pkg_resources.files(deep_paint)
    .joinpath("..", "results", "predictions", "binary_rxrx2.csv")
    .resolve()
)
binary_preds = pd.read_csv(binary_preds_path)

In [4]:
# Change renderer according to your IDE (e.g. notebook, jupyterlab, vscode)
plot_auprc(
    binary_preds,
    renderer="jupyterlab",
    width=800,
    height=700,
)

## Multiclass Classifier

In [5]:
# Load the multiclass classifier's predictions from results/predictions/,
# one directory above the installed deep_paint package
multiclass_preds_path = (
    pkg_resources.files(deep_paint)
    .joinpath("..", "results", "predictions", "multiclass_rxrx2.csv")
    .resolve()
)
multiclass_preds = pd.read_csv(multiclass_preds_path)

In [6]:
# Class name -> curve colour. plot_auprc pairs the i-th key with integer
# label i in `y_true`, so keep the keys in label order.
color_map = {
    'Cytokine-GF': '#c22020',  # label 0
    'Toxin': '#dfc261',        # label 1
    'Untreated': '#79d2f0',    # label 2
}

# Change renderer according to your IDE
plot_auprc(multiclass_preds, color_map, renderer="jupyterlab", width=800, height=700)