In [2]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [3]:
# Per-class plot styling. plot_2d zips these with the sorted unique labels,
# so the i-th class gets COLOR_PALETTE[i] / SHAPE_PALETTE[i] and any class
# beyond the fourth is not drawn.
COLOR_PALETTE = [
    'red',
    'green',
    'blue',
    'cyan',
]
SHAPE_PALETTE = [
    'o',  # circle
    'P',  # filled plus
    '*',  # star
    's',  # square
]

In [4]:
def mbline(m, b, color = 'black', ax = None, **plot_kwargs):
    """Draw the straight line y = m*x + b across an axes' current x-range.

    Parameters
    ----------
    m : float
        Slope of the line.
    b : float
        Intercept (value of y at x = 0).
    color : str, optional
        Matplotlib colour for the line (default ``'black'``).
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).
    **plot_kwargs
        Extra keyword arguments forwarded to ``Axes.plot`` (e.g.
        ``linestyle``, ``label``).

    Note: this does not freeze the axis limits, so autoscaling may widen the
    view to fit the new line; callers that want to keep the data-driven view
    should save and restore ``xlim``/``ylim`` around the call.
    """
    if ax is None:
        ax = plt.gca()
    # Two endpoints suffice: the line spans exactly the current x-limits.
    x = np.array(ax.get_xlim())
    y = m * x + b
    ax.plot(x, y, color = color, **plot_kwargs)

In [5]:
def weight_to_mb(w, eps = 0.001):
    """Convert linear-classifier weights to slope/intercept form (m, b).

    The decision boundary is where the weighted sum is zero:

    * ``len(w) == 2``: ``w[0]*x + w[1]*y = 0`` (no bias term)
    * ``len(w) == 3``: ``w[0] + w[1]*x + w[2]*y = 0`` (``w[0]`` is the bias)

    Solving for ``y`` gives ``y = m*x + b``.

    Parameters
    ----------
    w : sequence of 2 or 3 numbers
        Weight vector, laid out as described above.
    eps : float, optional
        Used in place of a zero y-coefficient so a vertical boundary becomes
        a very steep line instead of a division by zero (default 0.001).

    Returns
    -------
    tuple
        ``(m, b)``; ``b`` is the integer ``0`` when ``w`` has no bias term.

    Raises
    ------
    ValueError
        If ``w`` does not have exactly 2 or 3 entries.
    """
    if len(w) == 2:
        d = w[1] or eps
        return -w[0]/d, 0
    if len(w) == 3:
        d = w[2] or eps
        return -w[1]/d, -w[0]/d
    # Any other length used to be silently mis-read (or fail with IndexError).
    raise ValueError(f'expected 2 or 3 weights, got {len(w)}')

In [6]:
def plot_binary_2d(df, label = 'y', colors = ('red', 'green'), 
                   weights = None, 
                   show_weight_vector = False, 
                   show_separator = False):
    """Scatter-plot a two-class, two-feature DataFrame with an optional linear model.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``label`` column; the first two remaining columns
        are used as the x and y axes.
    label : str
        Name of the class-label column.
    colors : sequence
        One colour per class, matched to the sorted unique labels; classes
        beyond ``len(colors)`` are not drawn.
    weights : sequence of 2 or 3 numbers, optional
        ``[w1, w2]`` or ``[w0 (bias), w1, w2]``; may be a list or NumPy array.
    show_weight_vector : bool
        Draw an arrow from the origin to ``(weights[-2], weights[-1])``.
    show_separator : bool
        Draw (in black) the line where the weighted sum is zero. The line is
        clipped to the current view instead of stretching the axes.
    """
    plt.figure(figsize = (4, 4))
    plt.axes().set_aspect('equal', 'box')

    features = list(df.columns)
    features.remove(label)

    plt.xlabel(features[0])
    plt.ylabel(features[1])

    for y, color in zip(sorted(df[label].unique()), colors):
        subset = df[df[label] == y]
        plt.scatter(subset[features[0]], subset[features[1]], color = color)

    # `is not None` + len() instead of plain truthiness: `if weights:` raises
    # "truth value of an array is ambiguous" for NumPy arrays.
    if weights is not None and len(weights) > 0:
        if show_weight_vector:
            # Drawn first so the origin can still be pulled into the view.
            plt.quiver([0], [0], [weights[-2]], [weights[-1]], 
                       angles = 'xy', scale_units='xy', scale=1.)
        if show_separator:
            # Keep the current view: a steep separator would otherwise
            # autoscale the y-axis far beyond the data.
            xlim = plt.xlim()
            ylim = plt.ylim()
            m, b = weight_to_mb(weights)
            mbline(m, b)
            plt.xlim(xlim)
            plt.ylim(ylim)
        

In [9]:
def plot_2d(df, label = 'y', colors = COLOR_PALETTE, 
            shapes = SHAPE_PALETTE, weights = None, 
            show_weight_vector = False, show_separator = False):
    """Scatter-plot a labelled two-feature DataFrame, one colour/marker per class.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the ``label`` column; the first two remaining columns
        are used as the x and y axes.
    label : str
        Name of the class-label column.
    colors, shapes : sequence
        Colour and marker for each class, matched to the sorted unique
        labels; classes beyond ``min(len(colors), len(shapes))`` are not
        drawn. Separator lines reuse ``colors`` cyclically.
    weights : dict, sequence, or None
        Either a single weight vector (``[w1, w2]`` or ``[w0, w1, w2]``), or
        a dict mapping every class label to its own weight vector.
    show_weight_vector : bool
        Draw each weight vector's last two components as an arrow from the
        origin.
    show_separator : bool
        Draw each weight vector's zero-sum decision line.

    The axis limits set by the scatter points are restored after the lines
    and arrows are drawn, so they do not stretch the view.
    """
    plt.figure(figsize = (4, 4))
    plt.axes().set_aspect('equal', 'box')

    features = list(df.columns)
    features.remove(label)

    plt.xlabel(features[0])
    plt.ylabel(features[1])

    classes = sorted(df[label].unique())
    for y, color, shape in zip(classes, colors, shapes):
        subset = df[df[label] == y]
        plt.scatter(subset[features[0]], subset[features[1]],
                    color = color, marker = shape)

    if weights is not None:
        xlim = plt.xlim()
        ylim = plt.ylim()

        # isinstance (not `type(...) == dict`) so dict subclasses such as
        # OrderedDict/defaultdict are treated as per-class weights instead of
        # having their keys read as a single weight vector.
        if isinstance(weights, dict):
            weight_list = [weights[c] for c in classes]
        else:
            weight_list = [list(weights)]
        for idx, w in enumerate(weight_list):
            if show_separator:
                m, b = weight_to_mb(w)
                # Cycle colours: more weight vectors than colours must not IndexError.
                mbline(m, b, color = colors[idx % len(colors)])
            if show_weight_vector:
                plt.quiver([0], [0], [w[-2]], [w[-1]], 
                           angles = 'xy', scale_units='xy', scale=1.)

        plt.xlim(xlim)
        plt.ylim(ylim)