In [1]:
import importlib.resources as pkg_resources

import torch
import pandas as pd
import plotly.graph_objects as go
from torchmetrics.functional.classification import (binary_confusion_matrix,
    multiclass_confusion_matrix)

import deep_paint

In [2]:
def plot_cm(
    embeddings: pd.DataFrame,
    classes: list,
    renderer: str = "notebook",
    **kwargs
):
    """
    Plot an annotated confusion matrix as a Plotly heatmap.

    The matrix is computed from the ``y_true`` and ``y_pred`` columns of
    ``embeddings``. Matrix rows are drawn on the y-axis ("True Label") and
    matrix columns on the x-axis ("Predicted Label"); every cell is outlined
    and annotated with its count. The figure is shown, not returned.

    Parameters
    ----------
    embeddings: pd.DataFrame
        DataFrame containing a ``y_true`` and a ``y_pred`` column.
        NOTE(review): values are presumably integer class indices that line
        up with the order of ``classes`` -- confirm against the CSV files.
    classes: list
        List of class names for axes titles. Two names select the binary
        confusion matrix, more than two the multiclass one.
    renderer: str, default notebook
        Plotly renderer (ex. notebook, jupyterlab, vscode, iframe, etc).
    **kwargs
        Keyword arguments to pass to go.Figure().update_layout(). A keyword
        that matches one of the defaults set here (``xaxis``, ``yaxis``,
        ``margin``, ``legend``, ``font``, ...) replaces that default entirely.

    Raises
    ------
    ValueError
        If fewer than two class names are given.
    """
    num_classes = len(classes)
    if num_classes < 2:
        # Fail early with a clear message instead of an IndexError while
        # annotating a 2x2 binary matrix with too few labels.
        raise ValueError(
            f"plot_cm needs at least 2 class names, got {num_classes}."
        )

    preds = torch.tensor(embeddings['y_pred'].values)
    target = torch.tensor(embeddings['y_true'].values)
    if num_classes > 2:
        cm = multiclass_confusion_matrix(
            preds=preds,
            target=target,
            num_classes=num_classes
        ).numpy()
    else:
        cm = binary_confusion_matrix(
            preds=preds,
            target=target
        ).numpy()

    # Plot confusion matrix
    fig = go.Figure()

    fig.add_trace(go.Heatmap(
        z=cm,
        x=classes,
        y=classes,
        colorscale='Reds',
        showscale=True,
        colorbar=dict(title="Images")
    ))

    # Default layout. Merged with **kwargs below so that a caller may override
    # any of these keys without a "multiple values for keyword" TypeError.
    layout = dict(
        xaxis=dict(
            title='Predicted Label',
            title_standoff=25,
            showgrid=False
        ),
        yaxis=dict(
            title='True Label',
            title_standoff=25,
            showgrid=False,
            autorange='reversed'
        ),
        autosize=False,
        margin=dict(
            l=50,
            r=50,
            b=100,
            t=50,
            pad=4
        ),
        legend=dict(
            x=1,
            y=0,
            xanchor='right',
            yanchor='bottom',
            bordercolor='black',
            borderwidth=1,
            title="Images"
        ),
        font=dict(
            family="Arial, sans-serif",
            size=24,
            color="black"
        ),
        plot_bgcolor='white',
    )
    layout.update(kwargs)
    fig.update_layout(**layout)

    # Cells above the midpoint of the value range sit on the dark half of the
    # colorscale and get white text; the rest get black text. This keeps the
    # count readable even when a diagonal cell holds a low value.
    midpoint = (cm.min() + cm.max()) / 2

    # Add number of images
    for i, row in enumerate(cm):
        for j, value in enumerate(row):
            fig.add_annotation(go.layout.Annotation(
                text=str(value),
                x=classes[j],
                y=classes[i],
                xref="x",
                yref="y",
                showarrow=False,
                font=dict(
                    size=42,
                    color="white" if value > midpoint else "black"
                )
            ))
            # Outline the cell: corners are numeric offsets of +/- 0.5
            # around the cell's (column, row) index.
            fig.add_shape(
                type="rect",
                x0=j - 0.5,
                y0=i - 0.5,
                x1=j + 0.5,
                y1=i + 0.5,
                line=dict(color="black", width=5),
                layer="below"
            )

    fig.show(renderer=renderer)

## Binary Classifier

In [3]:
# Resolve the binary-classifier predictions CSV relative to the deep_paint
# package directory, then load it into a DataFrame.
binary_preds_path = (
    pkg_resources.files(deep_paint)
    .joinpath("..", "results", "predictions", "binary_rxrx2.csv")
    .resolve()
)
binary_preds = pd.read_csv(binary_preds_path)

In [4]:
# Change renderer according to your IDE
plot_cm(
    embeddings=binary_preds,
    classes=["Low", "High"],
    renderer="jupyterlab",
    height=700,
    width=700,
)

## Multiclass Classifier

In [5]:
# Resolve the multiclass-classifier predictions CSV relative to the deep_paint
# package directory, then load it into a DataFrame.
multiclass_preds_path = (
    pkg_resources.files(deep_paint)
    .joinpath("..", "results", "predictions", "multiclass_rxrx2.csv")
    .resolve()
)
multiclass_preds = pd.read_csv(multiclass_preds_path)

In [6]:
# Change renderer according to your IDE
plot_cm(
    embeddings=multiclass_preds,
    classes=['Cytokine-GF', 'Toxin', 'Untreated'],
    renderer="jupyterlab",
    width=800,
    height=700,
)