In [1]:
class ModelInspector:
    """Visual summary of a fitted classifier evaluated on a fixed dataset.

    ``inspect()`` draws three side-by-side panels:
    a ROC curve, an annotated confusion-matrix grid, and a table of scalar metrics.

    NOTE(review): the metrics use scikit-learn defaults (``pos_label=1``) and
    ``predict_proba(...)[:, 1]``. This presumably assumes binary labels where
    1 is the positive class; confirm against callers.
    """

    def __init__(self, model, X, y):
        """Store the model and the evaluation data.

        Args:
            model: Estimator exposing ``predict`` and ``predict_proba``.
            X: Features passed to the model.
            y: Ground-truth labels aligned with ``X``.
        """
        self.model = model
        self.X = X
        self.y = y

    def _plot_confusion_matrix(self, y_pred, ax):
        """Draw an annotated 2x2 confusion matrix on ``ax``.

        Layout: top row is TP | FP, bottom row is FN | TN. The diagonal
        cells (TP, TN) are shaded; the counts are shown as text.

        Args:
            y_pred: Hard predictions for ``self.X``.
            ax: Matplotlib axes to draw on.
        """
        tn, fp, fn, tp = confusion_matrix(self.y, y_pred).ravel()

        # Colour values only: the identity matrix highlights the diagonal
        # (TP, TN); the actual counts come from the text annotations.
        shading = np.eye(2)
        cell_text = [[f'TP\n{tp}', f'FP\n{fp}'], [f'FN\n{fn}', f'TN\n{tn}']]

        # fmt='' is required because the annotations are pre-formatted strings.
        # Only one tick per axis is labelled; presumably the positive class is
        # "Bad client". TODO confirm with the label encoding.
        sns.heatmap(shading, annot=cell_text, annot_kws={"size": 20}, fmt='', cmap='Greens', cbar=False,
                    xticklabels=['', 'Good client'], yticklabels=['Bad client', ''], ax=ax)

    def _plot_metrics(self, y_pred, roc_auc, ax):
        """Render a single-column table of scalar metrics on ``ax``.

        Args:
            y_pred: Hard predictions for ``self.X``.
            roc_auc: Pre-computed ROC AUC, shown first with 4 decimals.
            ax: Matplotlib axes to draw on.
        """
        # Alternating 1/0 colour values give the five rows a striped background.
        stripes = np.array([[1, 0, 1, 0, 1]]).T

        metric_text = np.array([[
            f'ROC AUC: {roc_auc:.4f}',
            f'Balanced accuracy: {balanced_accuracy_score(self.y, y_pred):.3f}',
            f'F1-score: {f1_score(self.y, y_pred):.3f}',
            f'Precision score: {precision_score(self.y, y_pred):.3f}',
            f'Recall score: {recall_score(self.y, y_pred):.3f}'
        ]]).T

        sns.heatmap(stripes, annot=metric_text, fmt='', cbar=False, yticklabels=[],
                    xticklabels=[], annot_kws={'size': 16, 'ha': 'center'}, cmap='GnBu', ax=ax)

    def _plot_logistic_regression(self, probs, ax):
        """Plot the ROC curve of ``probs`` against ``self.y`` on ``ax``.

        Despite the method name, this draws a ROC curve (true vs. false
        positive rate) for any model that provides scores. A dashed
        chance-level diagonal is drawn as the baseline.

        Args:
            probs: Scores for the positive class, aligned with ``self.y``.
            ax: Matplotlib axes to draw on.
        """
        fpr, tpr, _ = roc_curve(self.y, probs)

        # Explicit x and y: the chance diagonal from (0, 0) to (1, 1).
        ax.plot([0, 1], [0, 1], label='Baseline', linestyle='--')
        ax.plot(fpr, tpr, label='Regression')
        ax.set_ylabel('True Positive Rate')
        ax.set_xlabel('False Positive Rate')
        ax.legend(loc='lower right')

    def inspect(self, size=5):
        """Evaluate the model on ``(X, y)`` and show the three-panel figure.

        Args:
            size: Figure height; the width is ``3 * size`` (one slot per panel).

        Calls ``plt.show()`` and returns ``None``.
        """
        y_pred = self.model.predict(self.X)
        # Column 1 of predict_proba is taken as the positive-class score.
        probs = self.model.predict_proba(self.X)[:, 1]
        roc_auc = roc_auc_score(self.y, probs)

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(3*size, size))

        fig.suptitle(f'Model inspection. ROC AUC: {roc_auc:.4f}', fontsize=20)

        ax1.set_title('Logistic Regression')
        ax2.set_title('Confusion matrix')
        ax3.set_title('Metrics')

        self._plot_logistic_regression(probs, ax1)
        self._plot_confusion_matrix(y_pred, ax2)
        self._plot_metrics(y_pred, roc_auc, ax3)

        plt.show()