In [2]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
# Colormaps and marker shapes used to plot points belonging to different
# classes; they are assigned to classes in sorted-label order.
# DEFAULT_CMAP_INDEX is the default position (0..1) at which each colormap
# is sampled to obtain a point / line color.
def _get_cmap(name):
    """Look up a registered Matplotlib colormap by name.

    `matplotlib.colormaps` (Matplotlib >= 3.5) is preferred because
    `matplotlib.cm.get_cmap` was deprecated in 3.7 and removed in 3.9; the
    old function is only used on versions that lack the registry.
    """
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is None:
        return matplotlib.cm.get_cmap(name)
    return registry[name]

COLOR_PALETTE = [_get_cmap(name) for name in ('Reds', 'Greens', 'Blues', 'Purples')]
DEFAULT_CMAP_INDEX = 0.65
SHAPE_PALETTE = ['o', 'P', '*', 's']

In [4]:
def mbline(m, b, color = 'black'):
    """Draw the line y = m*x + b across the current axes' x-limits.

    Arguments:
      m (float) - slope of the line.
      b (float) - y-intercept of the line.
      color - Matplotlib color specification forwarded to plt.plot.
    """
    left, right = plt.gca().get_xlim()
    xs = np.array([left, right])
    plt.plot(xs, m * xs + b, color = color)

In [5]:
def weight_to_mb(w):
    """Convert a linear-separator weight vector to slope/intercept form.

    The separator is the set of points where w0 + w1*x + w2*y = 0, i.e.
    y = -(w1/w2)*x - (w0/w2).

    Arguments:
      w (sequence) - [w1, w2] or [w0, w1, w2]; in the former case w0 = 0.
                     Lists, tuples and NumPy arrays are all accepted.

    Returns:
      (m, b) - slope and intercept of the separator.

    Any zero weight (including w0) is replaced by 0.0001 before dividing,
    so a vertical separator (w2 = 0) yields a very steep finite slope
    instead of raising ZeroDivisionError.
    """
    assert len(w) in [2, 3], 'Weight vector must have length 2 or 3'
    # list() copies (the caller's vector is never modified) and normalizes
    # the input type: `[0] + ndarray` would broadcast-add instead of
    # prepending, and `[0] + tuple` raises TypeError.
    w = list(w)
    if len(w) == 2:
        w = [0] + w
    w = [val or 0.0001 for val in w]  # Prevent infinite and zero slope
    return -w[1]/w[2], -w[0]/w[2]

In [6]:
def plot_2d(df, label = 'y', colors = COLOR_PALETTE,
            shapes = SHAPE_PALETTE, weights = None,
            color_idx = DEFAULT_CMAP_INDEX,
            show_weight_vector = False, show_separator = False,
            pad = None, make_figure = True):
    """Scatter-plot a labelled two-feature dataset, optionally overlaying
    linear separators and weight vectors.

    Arguments:
      df (DataFrame) - dataset to plot; the first two non-label columns are
                       used as the x and y axes.
      label (str) - name of column in df that contains the class label.
      colors (list) - colormaps (callables mapping a value in [0, 1] to a
                      color), one per class in sorted-label order.
      shapes (list) - marker styles, one per class in sorted-label order.
      weights (list or dict) - a dictionary of class/weight pairs, or a
                       single weight vector ([w1, w2] or [w0, w1, w2]) in
                       the binary case.
      color_idx (float, or one value per row of df) - position in each
                       colormap used to color the points (vmin=0, vmax=1).
                       When it is a scalar it also colors the separators.
      show_weight_vector (bool) - plot weight vector(s) from the origin if true.
      show_separator (bool) - plot linear separator(s) if true.
      pad (float or None) - if truthy, widen both axis ranges by this
                       fraction of their current span on each side.
      make_figure (bool) - start a new 4x4 figure with an equal aspect
                       ratio; otherwise draw on the current axes.

    Returns:
      dict mapping each plotted class label to the object returned by
      plt.scatter for that class.
    """
    scatters = {}

    if make_figure:
        plt.figure(figsize = (4, 4))
        plt.axes().set_aspect('equal', 'box')

    features = list(df.columns)
    features.remove(label)

    plt.xlabel(features[0])
    plt.ylabel(features[1])

    classes = sorted(df[label].unique())

    # Colormap position for every row.  Kept as a separate Series (aligned
    # like a DataFrame column assignment) instead of a new column on a copy
    # of df, so a real column named 'color_idx' is never overwritten.
    point_idx = pd.Series(color_idx, index = df.index)

    # Scatter plot points using distinctive shapes and colors
    for cls, cmap, shape in zip(classes, colors, shapes):
        mask = df[label] == cls
        scatters[cls] = plt.scatter(df.loc[mask, features[0]],
                                    df.loc[mask, features[1]],
                                    cmap = cmap, c = point_idx[mask],
                                    marker = shape, vmin = 0, vmax = 1)

    if weights is not None:
        # Drawing lines/arrows autoscales the axes; remember the limits set
        # by the scatter plot so they can be restored afterwards.
        xlim = plt.xlim()
        ylim = plt.ylim()

        if isinstance(weights, dict):
            weight_list = [list(weights[cls]) for cls in classes]
        else:
            weight_list = [list(weights)]

        # Separators share the points' colormap position when it is a single
        # value; per-row positions cannot color a single line.
        line_idx = color_idx if np.isscalar(color_idx) else DEFAULT_CMAP_INDEX
        for idx, w in enumerate(weight_list):
            if show_separator:
                m, b = weight_to_mb(w)
                line_color = colors[idx](line_idx) if len(weight_list) > 2 else 'black'
                mbline(m, b, color = line_color)
            if show_weight_vector:
                plt.quiver([0], [0], [w[-2]], [w[-1]],
                           angles = 'xy', scale_units = 'xy', scale = 1.)

        plt.xlim(xlim)
        plt.ylim(ylim)

    if pad:
        # plt.xlim / plt.ylim return the limits when called without
        # arguments and set them when called with two.
        for axis_limits in (plt.xlim, plt.ylim):
            lo, hi = axis_limits()
            margin = pad * (hi - lo)
            axis_limits(lo - margin, hi + margin)

    return scatters

In [6]:
def project_2d(df, weights, label = 'y',
               colors = COLOR_PALETTE, shapes = SHAPE_PALETTE):
    """Plot every point's projection w . x on a one-dimensional strip.

    Arguments:
      df (DataFrame) - dataset; its first two non-label columns are the
                       features that get projected.
      weights (sequence) - [w1, w2], or [w0, w1, w2] where w0 is a bias
                       added to every projection (x0 = 1).
      label (str) - name of column in df that contains the class label.
      colors (list) - colormaps (callables), one per class in sorted-label
                      order; each is sampled at DEFAULT_CMAP_INDEX.
      shapes (list) - marker styles, one per class in sorted-label order.

    Raises:
      ValueError - if weights does not contain 2 or 3 values.
    """
    w = np.asarray(weights, dtype = float).ravel()
    if w.size == 2:
        bias = 0.0
    elif w.size == 3:
        bias, w = w[0], w[1:]
    else:
        raise ValueError('Weight vector must have length 2 or 3')

    # Select features by name rather than position so the label column is
    # never projected, even when it is not the last column.
    features = [col for col in df.columns if col != label][:2]

    plt.figure(figsize=(6, 0.5))

    for cls, cmap, shape in zip(sorted(df[label].unique()), colors, shapes):
        points = df.loc[df[label] == cls, features].to_numpy(dtype = float)
        projection = points @ w + bias
        plt.scatter(projection, np.zeros(len(projection)),
                    color = cmap(DEFAULT_CMAP_INDEX), marker = shape)

    plt.yticks([])
    plt.ylabel('')
    plt.xlabel(r'${\bf w} \cdot {\bf x}$')
    plt.show()