In [1]:
import os
import numpy as np
import pandas as pd

import plotly.graph_objects as go

In [2]:
from plotly.offline import iplot, init_notebook_mode

# Enable inline Plotly rendering in this notebook. connected=True makes the
# notebook fetch plotly.js from a CDN rather than embedding the library.
init_notebook_mode(connected=True)

In [3]:
# The .npy file holds a single pickled Python object (presumably a dict written
# with np.save — it is indexed by string keys in the next cell); .item() unwraps
# it from the size-1 object array np.load returns.
# NOTE: allow_pickle=True can execute arbitrary code while loading — only point
# this at files you produced yourself.
eval_scores = np.load(
    './data/eval_scores.npy',
    allow_pickle=True,
).item()

In [4]:
# Results recorded at the threshold stored under 'best_criteria_threshold'.
best_thres = eval_scores['best_criteria_threshold']
best_scores = eval_scores['thres_scores'][best_thres]
cm = best_scores['confusion_matrix']

# Display names for the two classes. The plotting cell orders them
# (neg_lab, pos_lab), i.e. index 0 is the negative class — presumably matching
# the row/column order of `cm`; verify against the evaluation code.
neg_lab, pos_lab = 'no', 'yes'

cm

array([[6804.,  652.],
       [ 293.,  695.]])

In [5]:
def format_cell_value(val):
    """Return the display text for a single heatmap cell.

    Integral floats such as ``6804.0`` (a confusion matrix stored as a float
    array) are rendered without the spurious trailing ``.0``; any other value
    is rendered with plain ``str``.
    """
    if isinstance(val, (float, np.floating)) and float(val).is_integer():
        return str(int(val))
    return str(val)


def define_annotation(data):
    """Build one centred text annotation per cell of a 2-D heatmap matrix.

    Parameters
    ----------
    data : array-like, shape (n_rows, n_cols)
        Matrix plotted as the heatmap ``z``. Cell ``(n, m)`` is annotated at
        ``x=m``, ``y=n`` on the first x/y axes.

    Returns
    -------
    list of go.layout.Annotation
        Row-major list of annotations. Text is black on cells below half of
        the matrix maximum and white otherwise, so it stays readable on the
        dark high-value end of a sequential colour scale such as 'Blues'.
        An empty matrix yields an empty list.
    """
    arr = np.asarray(data)
    # The colour switch point depends only on the matrix maximum, so compute it
    # once; guard the empty case, where .max() would raise.
    half_max = arr.max() / 2 if arr.size else 0

    annotations = []
    for (n, m), val in np.ndenumerate(arr):
        font_color = "#000000" if (val < half_max) else "#FFFFFF"
        annotations.append(
            go.layout.Annotation(
                text=format_cell_value(val),
                x=m,
                y=n,
                xref="x1",
                yref="y1",
                font=dict(color=font_color),
                showarrow=False,
            )
        )
    return annotations

In [6]:
# Confusion-matrix heatmap: rows are the true labels ('True' axis), columns the
# predicted labels ('Predict' axis).
fig = go.Figure()

fig.add_trace(
    go.Heatmap(
        x=[neg_lab, pos_lab],
        y=[neg_lab, pos_lab],
        z=cm,
        colorscale='Blues',
        showscale=True,
        reversescale=False,
        zmin=0,  # anchor the colour scale at zero (counts are non-negative)
    )
)

# One text annotation per cell showing its value.
annotations = define_annotation(cm)

fig.update_layout(
    title={
        'text': "Confusion Matrix",
        'y': 0.99,
        'x': 0,
        'xanchor': 'left',
        'yanchor': 'top',
        'font': {'size': 22},
    },
    width=882,
    height=725,
    # Large top margin leaves room for the title and the toggle button.
    margin=dict(t=120, b=0, l=0, r=0),
    hovermode=False,
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(238,238,238,1)',
    modebar=dict(
        bgcolor='rgba(0,0,0,0)',
        activecolor='rgba(68,68,68, 0.7)',
        color='rgba(68, 68, 68, 0.3)',
    ),
    annotations=annotations,
)

# Toggle for the cell values. For toggle buttons Plotly applies `args` when the
# button is clicked while inactive and `args2` when it is clicked while active.
# The values start out visible, so the button starts active (active=0): the
# button is highlighted exactly when the values are shown, matching its label.
fig.update_layout(
    updatemenus=[
        dict(
            type="buttons",
            direction="right",
            buttons=[
                dict(
                    args=[{"annotations": annotations}],
                    args2=[{"annotations": []}],
                    label="Show value",
                    method="relayout",
                )
            ],
            active=0,
            x=0.0,
            xanchor="left",
            y=1,
            yanchor="top",
            # Negative top padding shifts the button upward, out of the plot
            # area and into the top margin.
            pad={"t": -65},
            bgcolor='rgba(238,238,238,1)',
        ),
    ]
)

# fixedrange disables zoom/pan; scaleanchor='y' ties one x unit to one y unit,
# so each categorical cell is drawn square.
fig.update_xaxes(title='Predict', side='bottom', tickmode='linear', fixedrange=True, scaleanchor='y', constrain='domain')
# Reversed y puts the first row (neg_lab) at the top, as in a conventional confusion matrix.
fig.update_yaxes(title='True', autorange='reversed', tickmode='linear', ticksuffix='  ', fixedrange=True)

fig.show(config={'displaylogo': False})
# To export an embeddable HTML fragment instead:
# fig.write_html('./automl_plot/confusion_matrix.html', config={'displaylogo': False}, include_plotlyjs='cdn', full_html=False)