In [1]:
import os
import numpy as np
import pandas as pd

import plotly.graph_objects as go

In [2]:
# Plotly offline/notebook setup. connected=True asks plotly to load plotly.js
# from its CDN instead of embedding the bundled copy (see plotly.offline docs).
# NOTE(review): `iplot` is not used in the cells shown here (the figure is
# rendered with fig.show()); kept in case other cells rely on it.
from plotly.offline import (
    init_notebook_mode,
    iplot,
)

init_notebook_mode(connected=True)

In [3]:
# The .npy file holds a single pickled object (presumably a dict): it needs
# allow_pickle=True, and .item() unwraps the one-element object array.
# Only load trusted files here -- unpickling can execute arbitrary code.
eval_scores = np.load(
    './data/eval_scores.npy',
    allow_pickle=True,
).item()

# Operating point chosen upstream, plus the per-threshold metric records.
best_thres = eval_scores['best_criteria_threshold']
thres_scores = eval_scores['thres_scores']

In [4]:
# Unzip the per-threshold metric records into parallel arrays whose positions
# line up with `thres`. thres_scores is presumably a dict; dict iteration order
# is insertion order, so its keys and values stay aligned.
precision, recall, f1_score = (
    np.array([scores[metric] for scores in thres_scores.values()])
    for metric in ('precision', 'recall', 'f1_score')
)
thres = list(thres_scores)

In [5]:
def plot_pr_curve(precision, recall, threshold, f1_score=None, best_thres=None, width=882, height=725):
    """Plot precision, recall and (optionally) F1 score against the decision threshold.

    Despite the name, this is a metric-vs-threshold sweep rather than a
    precision-vs-recall curve: every series uses ``threshold`` as its x-axis.

    Parameters
    ----------
    precision, recall : array-like
        Metric values, one per entry of ``threshold``. A series passed as
        ``None`` is omitted from the figure.
    threshold : array-like
        Threshold values used as the shared x-axis. The x-axis range is fixed
        to [-0.01, 1.01], so thresholds are presumably expected in [0, 1].
    f1_score : array-like, optional
        F1 values aligned with ``threshold``. When ``None`` the F1 trace is omitted.
    best_thres : float, optional
        Threshold highlighted with a dashed vertical line. When ``None`` no
        line is drawn.
    width, height : int
        Figure size in pixels.

    Returns
    -------
    plotly.graph_objects.Figure
        The configured figure; the caller decides how to display it.
    """
    # (legend name, y-values, colour) for each metric, in legend order.
    series = [
        ('precision', precision, '#00CC96'),
        ('recall', recall, '#636EFA'),
        ('f1 score', f1_score, '#EF553B'),
    ]

    fig = go.Figure()

    for name, values, color in series:
        if values is None:
            # Optional series (e.g. F1) not supplied -> no trace, no legend entry.
            continue
        fig.add_trace(go.Scatter(
            x=threshold,
            y=values,
            mode='lines',
            name=name,
            # Colour the line directly; mode='lines' draws no markers.
            line=dict(color=color),
        ))

    if best_thres is not None:
        # Mark the selected operating point.
        fig.add_vline(x=best_thres, line_dash="dash", line_color="navy")

    fig.update_layout(
        title={
                'text': 'Precision Recall Curve',
                'font': dict(size=22),
                'y':0.99,
                'x':0,
                'xanchor': 'left',
                'yanchor': 'top',
        },
        width=width,
        height=height,
        margin=dict(t=100, b=0, l=0, r=0),
        hoverlabel=dict(
            bgcolor="white",
            font_size=13,
            font_family="Rockwell",
        ),
        # Transparent paper so the figure blends into the notebook background.
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(238,238,238,1)',
        modebar=dict(
            bgcolor='rgba(0,0,0,0)', activecolor='rgba(68,68,68, 0.7)', color='rgba(68, 68, 68, 0.3)',
            remove=['zoom', 'lasso', 'select'],
        ),
        # One hover box listing every series at the hovered threshold.
        hovermode='x unified'
    )

    fig.update_xaxes(title={'text':'Thresholds', 'standoff': 50, 'font':{'size': 14}},
                     range=[-0.01, 1.01],
                    )
    fig.update_yaxes(title={'text':'Score', 'standoff': 50, 'font':{'size': 14}}, type='-')

    return fig

In [6]:
# Render the metric-vs-threshold sweep with the chosen threshold marked,
# hiding the Plotly logo in the modebar.
fig = plot_pr_curve(
    precision=precision,
    recall=recall,
    threshold=thres,
    f1_score=f1_score,
    best_thres=best_thres,
)
fig.show(config=dict(displaylogo=False))