In [1]:
import numpy as np
import matplotlib.pyplot as plt

def plot_classifier(clf, X, y, mesh_step_size=0.01):
    """Plot a fitted two-feature classifier's decision regions with the data on top.

    The classifier is evaluated on a regular grid extending 1 unit beyond the
    data on every side. Each grid cell is shaded by its predicted class, and the
    samples in ``X`` are drawn over it, coloured by their label in ``y``. The
    figure is displayed with ``plt.show()``; nothing is returned.

    Parameters
    ----------
    clf : object
        Fitted model exposing ``predict``; it is called with an array of shape
        (n_grid_points, 2).
    X : array-like of shape (n_samples, 2)
        Feature matrix; columns 0 and 1 are the two plotted features.
    y : array-like of shape (n_samples,)
        Labels for ``X``. Numeric or string labels are accepted, as long as they
        are comparable with what ``clf.predict`` returns.
    mesh_step_size : float, default 0.01
        Grid spacing for sampling the decision surface. Small steps on widely
        spread data produce very large grids.
    """
    X = np.asarray(X)
    y = np.asarray(y).ravel()

    min_x1, max_x1 = X[:, 0].min() - 1, X[:, 0].max() + 1
    min_x2, max_x2 = X[:, 1].min() - 1, X[:, 1].max() + 1

    x1_vals, x2_vals = np.meshgrid(
        np.arange(min_x1, max_x1, mesh_step_size),
        np.arange(min_x2, max_x2, mesh_step_size),
    )
    y_preds = np.asarray(clf.predict(np.c_[x1_vals.ravel(), x2_vals.ravel()])).ravel()

    # Map true and predicted labels onto ONE shared, sorted class list. Without
    # this, pcolormesh and scatter normalise colours independently (so a class
    # can get different colours in the background vs. the points when the mesh
    # predictions don't cover the same classes as y), and string labels could
    # not be colour-mapped at all.
    classes, codes = np.unique(np.concatenate([y, y_preds]), return_inverse=True)
    codes = codes.ravel()
    y_codes = codes[:y.size]
    pred_codes = codes[y.size:].reshape(x1_vals.shape)
    max_code = len(classes) - 1

    fig, ax = plt.subplots()
    ax.pcolormesh(x1_vals, x2_vals, pred_codes, cmap=plt.cm.Spectral,
                  shading='auto', alpha=0.05, vmin=0, vmax=max_code)
    ax.scatter(X[:, 0], X[:, 1], c=y_codes, s=50, ec='black', lw=1,
               cmap=plt.cm.Spectral, vmin=0, vmax=max_code, alpha=1.0)

    ax.set_xlim(min_x1, max_x1)
    ax.set_ylim(min_x2, max_x2)

    # Unit-spaced integer ticks covering the visible range, inclusive of both
    # ends (int() truncation + exclusive arange used to drop the last tick).
    ax.set_xticks(np.arange(int(np.ceil(min_x1)), int(np.floor(max_x1)) + 1))
    ax.set_yticks(np.arange(int(np.ceil(min_x2)), int(np.floor(max_x2)) + 1))

    plt.show()