In [None]:
import joblib
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

In [None]:
# Plot styling.
# The bundled seaborn styles were renamed to 'seaborn-v0_8-*' (deprecated in
# Matplotlib 3.6, old names later removed), so use whichever spelling this
# Matplotlib install actually ships; the old names raise on newer releases.
_seaborn_style_prefix = ('seaborn-v0_8'
                         if 'seaborn-v0_8-white' in plt.style.available
                         else 'seaborn')
plt.style.use([f'{_seaborn_style_prefix}-white',
               f'{_seaborn_style_prefix}-paper'])
plt.rc('font', family='sans-serif')
sns.set_palette(['#6da7de', '#9e0059', '#dee000', '#d82222', '#5ea15d',
                 '#943fa6', '#63c5b5', '#ff38ba', '#eb861e', '#ee266d'])
sns.set_context('paper', font_scale=1.3)

In [None]:
# Directory holding the outputs of the classifier-training runs; change it
# here to point the notebook at a different set of results.
DATA_DIR = '../data/processed'

# Only the first element of each joblib dump (the statistics mapping indexed
# by string keys in the plotting cell) is kept; the rest is discarded.
# NOTE(review): joblib.load unpickles its input, so only load files that this
# project produced itself.
rf_stats, *_ = joblib.load(f'{DATA_DIR}/train_classifier_rf.joblib')
lr_stats, *_ = joblib.load(f'{DATA_DIR}/train_classifier_lr.joblib')
svm_stats, *_ = joblib.load(f'{DATA_DIR}/train_classifier_svm.joblib')

In [None]:
def _pinned_mean_tpr(stats):
    """Return a copy of ``stats['tpr_mean_test']`` forced through (0, 0) and (1, 1).

    The array is copied before the end points are overwritten so that the
    statistics loaded from disk are left untouched for any later cell.
    """
    tpr = np.array(stats['tpr_mean_test'], dtype=float)
    tpr[0], tpr[-1] = 0, 1
    return tpr


width = 7
height = width / 1.618    # Golden ratio.
fig, ax = plt.subplots(figsize=(width, height))

tpr_rf = _pinned_mean_tpr(rf_stats)
tpr_lr = _pinned_mean_tpr(lr_stats)
tpr_svm = _pinned_mean_tpr(svm_stats)
# The mean TPR curves are plotted against an evenly spaced FPR grid on [0, 1];
# size the grid from the data instead of hard-coding its length.
# NOTE(review): assumes all three curves share the same grid length — confirm.
interval = np.linspace(0, 1, len(tpr_rf))

# Keep this draw order so each classifier keeps the same palette colour.
for clf_name, clf_tpr, clf_stats in [('Random forest', tpr_rf, rf_stats),
                                     ('Logistic regression', tpr_lr, lr_stats),
                                     ('SVM', tpr_svm, svm_stats)]:
    ax.plot(interval, clf_tpr,
            label=f'{clf_name} (AUC = {clf_stats["roc_auc_test"]:.3f} '
                  f'± {clf_stats["roc_auc_std_test"]:.3f})')
    # Shaded band: mean TPR ± one standard deviation across test folds.
    ax.fill_between(interval, clf_tpr - clf_stats['tpr_std_test'],
                    clf_tpr + clf_stats['tpr_std_test'], alpha=0.2)

# Diagonal y = x: the random-guess baseline on ROC axes.
ax.plot([0, 1], [0, 1], c='black', ls='--')

ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')

ax.legend(loc='lower right', frameon=False)

sns.despine(ax=ax)

# Save and close through the figure handle rather than pyplot's current figure.
fig.savefig('train_classifier_roc.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close(fig)