In [1]:
import random
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import math
import time
from IPython.display import display, Markdown
%run Plot\ Utilities.ipynb
%matplotlib inline

In [2]:
def sample_mushroom(edible = 0.5):
    """Draw one random synthetic mushroom.

    Parameters
    ----------
    edible : float
        Probability that the sampled mushroom is labelled 'edible'.

    Returns
    -------
    dict
        Keys 'label' ('edible' or 'poisonous'), 'height', 'width' and
        'color'. Edible mushrooms get a height in [0.75, 1.0] and a width
        in [0.25, 0.5]; poisonous ones get the two ranges swapped. 'color'
        is drawn from [-0.5, 0.5] regardless of the label.
    """
    small = (0.25, 0.5)
    large = (0.75, 1.0)

    # The draws happen in a fixed order (label, height, width, color) so a
    # seeded `random` always yields the same mushroom.
    if random.random() < edible:
        kind = 'edible'
        tall = random.uniform(*large)
        wide = random.uniform(*small)
    else:
        kind = 'poisonous'
        tall = random.uniform(*small)
        wide = random.uniform(*large)
    shade = random.uniform(-0.5, 0.5)

    return dict(label = kind, height = tall, width = wide, color = shade)

In [3]:
def show_mushrooms(D, cols, label = True):
    """Draw every mushroom in ``D`` on a grid of small subplots.

    Parameters
    ----------
    D : sequence of dict
        Mushrooms to draw. Each one is read through the keys 'height',
        'width' and 'color', plus 'label' when ``label`` is true.
    cols : int
        Number of subplots per row; the number of rows is derived from
        ``len(D)``.
    label : bool
        When true, each subplot is titled with the mushroom's 'label'.

    Returns
    -------
    None
        The figure is shown with ``plt.show()``. Nothing is drawn (and the
        image file is not read) when ``D`` is empty.

    Notes
    -----
    Reads 'mushroom.png' relative to the current working directory. The
    image must have at least two channels along its last axis.
    """
    # An empty collection would request a grid with zero rows, which
    # plt.subplots() rejects; there is nothing to draw anyway.
    if len(D) == 0:
        return

    img = mpimg.imread('mushroom.png')
    # 0/1 mask of the pixels whose first channel exceeds the second by more
    # than 0.1; only these pixels are shifted by the mushroom's 'color'.
    red = (img[:,:,0] > img[:,:,1] + 0.1) * 1
    # Collapse the channels into a single grayscale plane.
    img = np.mean(img, axis = 2)

    rows = math.ceil(len(D) / cols)
    # squeeze = False always yields a 2-D array of Axes. Without it a 1x1
    # grid returns a bare Axes object, which has no .flatten().
    _, axes = plt.subplots(rows, cols, sharex = False, sharey = False,
                           figsize = (2 * cols, 2 * rows), squeeze = False)
    axes = axes.flatten().tolist()

    # Hide the frame and ticks on every cell, including unused trailing ones.
    for axis in axes:
        axis.set_xticks([])
        axis.set_yticks([])
        axis.axis('off')

    for idx, mushroom in enumerate(D):
        tinted = img + mushroom['color'] * red

        axis = axes[idx]
        if label:
            axis.set_title(mushroom['label'])

        # Centre the image on a 100x100 canvas: the margins shrink as the
        # mushroom's width / height approach 1.
        x_margin = 100 * (1 - mushroom['width']) / 2
        y_margin = 100 * (1 - mushroom['height']) / 2
        # imshow's extent is (left, right, bottom, top).
        # NOTE(review): imshow appears to rescale the axis limits to this
        # extent, so likely only the width/height ratio (not the absolute
        # size) is visible - confirm that this is intended.
        extent = (x_margin, 100 - x_margin, 100 - y_margin, y_margin)

        axis.imshow(tinted, cmap = 'gray', vmin = 0, vmax = 1, extent = extent)

    plt.show()

In [1]:
def plot_mushroom_features(df, features, split = None):
    """Scatter-plot one or two mushroom features, styled by label.

    Parameters
    ----------
    df : DataFrame
        Must contain a 'label' column holding 'edible' / 'poisonous' and
        one column per requested feature.
    features : str or list of str
        A single feature name draws a flat strip plot (all points at
        y = 0); with two names the first goes on the x axis and the second
        on the y axis.
    split : number, optional
        When given, a black vertical line is drawn at this x position.

    Notes
    -----
    COLOR_PALETTE and SHAPE_PALETTE are not defined in this cell; they are
    presumably provided by the notebook executed via ``%run`` at the top of
    the file - verify. Index 0 is used for 'edible', index 1 for
    'poisonous'.
    """
    if isinstance(features, str):
        features = [features]
    single_feature = len(features) == 1

    plt.figure(figsize=(6, 0.5 if single_feature else 6))

    for label in ['poisonous', 'edible']:
        subset = df[df['label'] == label]
        x = subset[features[0]]
        y = [0] * len(x) if single_feature else subset[features[1]]
        palette_idx = 0 if label == 'edible' else 1
        plt.scatter(x, y, color = COLOR_PALETTE[palette_idx], marker = SHAPE_PALETTE[palette_idx])

    plt.xlabel(features[0])
    if single_feature:
        plt.ylabel('')
        plt.yticks([])
    else:
        plt.ylabel(features[1])

    # Compare against None explicitly: a threshold of exactly 0 is a valid
    # split and must still be drawn.
    if split is not None:
        plt.axvline(split, color = 'black')

    plt.show()


In [6]:
def generate_and_test_mushroom_rules(df):
    """Enumerate and score every single-threshold rule on each feature.

    For each of 'color', 'width' and 'height', the rows are ordered by that
    feature and a candidate threshold is placed midway between every pair
    of adjacent values. Each threshold is plotted, printed as an IF/ELSE
    rule and reported together with its accuracy on ``df``.

    Parameters
    ----------
    df : DataFrame
        Must contain the columns 'label' ('edible' / 'poisonous'), 'color',
        'width' and 'height'. It is not modified.

    Returns
    -------
    None
        Results are emitted through plots, ``print`` and ``display``.
    """
    for feature in ['color', 'width', 'height']:
        # Sort a copy so the caller's DataFrame keeps its row order.
        ordered = df.sort_values(by = [feature])
        values = np.asarray(ordered[feature])
        is_edible = np.asarray(ordered['label'] == 'edible')
        total = len(ordered)

        for idx in range(total - 1):
            split = (values[idx] + values[idx + 1]) / 2

            # Score both orientations of the rule exactly as they are
            # printed, so a value lying on the threshold (tied neighbours)
            # is counted under the ELSE branch in either case.
            accuracy_below = 100 * np.count_nonzero((values < split) == is_edible) / total
            accuracy_above = 100 * np.count_nonzero((values > split) == is_edible) / total
            if accuracy_below >= accuracy_above:
                compare, accuracy = '<', accuracy_below
            else:
                compare, accuracy = '>', accuracy_above

            plot_mushroom_features(ordered, feature, split)
            print('IF %s %s %.2f THEN\n\tlabel = edible\nELSE\n\tlabel = poisonous' % (feature, compare, split))
            print('Training set accuracy = %.2f' % accuracy)
            display(Markdown('<hr style="height:3px;color:black;" />'))