In [1]:
import os
import subprocess
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import export_graphviz

In [2]:
def get_iris_data(
    path="iris.csv",
    url=(
        "https://raw.githubusercontent.com/pydata/pandas/master/"
        "pandas/tests/data/iris.csv"
    ),
):
    """Load the iris dataset, using a local CSV cache when one exists.

    If `path` exists it is read back with its first column as the index
    (the cache is written with the index included, so this round-trips).
    Otherwise the CSV is read from `url` and a copy is saved to `path`.

    Args
    ----
    path -- location of the local cache file (default "iris.csv").
    url -- source passed to pandas.read_csv when no cache exists.
           NOTE(review): the default is the URL the notebook was written
           with; confirm it still resolves, or pass a current one.

    Returns
    -------
    df -- pandas DataFrame with the dataset.

    Raises
    ------
    SystemExit -- if reading from `url` fails; the underlying error is
                  attached as the cause.
    """
    if os.path.exists(path):
        print("iris.csv在本地")
        return pd.read_csv(path, index_col=0)

    print("从github下载数据")
    try:
        df = pd.read_csv(url)
    except Exception as err:
        # Raise explicitly instead of calling exit(): inside an IPython
        # kernel exit() only schedules a shutdown and does not stop this
        # function, which would then continue with `df` unbound.
        # Catching Exception (not a bare except) lets KeyboardInterrupt
        # propagate untouched.
        raise SystemExit("不能下载iris数据集") from err

    # Let pandas open the file itself so newline handling and encoding are
    # managed by to_csv rather than by a text-mode handle.
    df.to_csv(path)
    print("iris.csv数据保存到本地")
    return df

In [3]:
# Load the iris data: read the local iris.csv if present, otherwise download it.
df = get_iris_data()

从github下载数据
iris.csv数据保存到本地


In [4]:
# Preview the first rows of the loaded frame.
print(df.head())

   SepalLength  SepalWidth  PetalLength  PetalWidth         Name
0          5.1         3.5          1.4         0.2  Iris-setosa
1          4.9         3.0          1.4         0.2  Iris-setosa
2          4.7         3.2          1.3         0.2  Iris-setosa
3          4.6         3.1          1.5         0.2  Iris-setosa
4          5.0         3.6          1.4         0.2  Iris-setosa


In [5]:
# Preview the last rows of the loaded frame.
print(df.tail())

     SepalLength  SepalWidth  PetalLength  PetalWidth            Name
145          6.7         3.0          5.2         2.3  Iris-virginica
146          6.3         2.5          5.0         1.9  Iris-virginica
147          6.5         3.0          5.2         2.0  Iris-virginica
148          6.2         3.4          5.4         2.3  Iris-virginica
149          5.9         3.0          5.1         1.8  Iris-virginica


In [6]:
# List the distinct values in the "Name" column (the class labels).
print(df["Name"].unique())

['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']


In [7]:
def encode_target(df, target_column):
    """Add an integer-coded "Target" column derived from `target_column`.

    Distinct values are numbered 0, 1, 2, ... in order of first
    appearance. The mapping is printed. The input frame is not modified;
    a copy is returned.

    Args
    ----
    df -- pandas DataFrame.
    target_column -- name of the column to map to int, producing
                     the new "Target" column.

    Returns
    -------
    df_mod -- copy of df with the extra "Target" column.
    targets -- the distinct values of `target_column` as returned by
               Series.unique(), so targets[i] is the name for code i.
    """
    df_mod = df.copy()
    targets = df_mod[target_column].unique()
    map_to_int = {name: code for code, name in enumerate(targets)}
    print(map_to_int)
    # Use map() rather than replace(): every value is a key of the mapping,
    # and map() builds the integer column directly instead of relying on
    # replace() downcasting an object column after substitution.
    df_mod["Target"] = df_mod[target_column].map(map_to_int)

    return (df_mod, targets)

In [8]:
# Encode the "Name" labels as integers; keep the label names for later lookup.
df2, targets = encode_target(df, target_column="Name")

{'Iris-setosa': 0, 'Iris-versicolor': 1, 'Iris-virginica': 2}


In [9]:
# Preview the first rows, now including the "Target" column.
print(df2.head())

   SepalLength  SepalWidth  PetalLength  PetalWidth         Name  Target
0          5.1         3.5          1.4         0.2  Iris-setosa       0
1          4.9         3.0          1.4         0.2  Iris-setosa       0
2          4.7         3.2          1.3         0.2  Iris-setosa       0
3          4.6         3.1          1.5         0.2  Iris-setosa       0
4          5.0         3.6          1.4         0.2  Iris-setosa       0


In [10]:
# Preview the last rows, now including the "Target" column.
print(df2.tail())

     SepalLength  SepalWidth  PetalLength  PetalWidth            Name  Target
145          6.7         3.0          5.2         2.3  Iris-virginica       2
146          6.3         2.5          5.0         1.9  Iris-virginica       2
147          6.5         3.0          5.2         2.0  Iris-virginica       2
148          6.2         3.4          5.4         2.3  Iris-virginica       2
149          5.9         3.0          5.1         1.8  Iris-virginica       2


In [11]:
# Show the label names returned by encode_target.
print(targets)

['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']


In [12]:
# The first four columns are used as the feature names.
features = df2.columns[:4].tolist()

In [13]:
# Confirm which columns were selected as features.
print(features)

['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']


In [14]:
# Feature matrix and integer-coded target for the classifier.
X, y = df2[features], df2["Target"]

In [15]:
# Decision tree classifier; random_state is fixed so reruns use the same seed.
dt = DecisionTreeClassifier(
    min_samples_split=20,
    random_state=99,
)

In [16]:
# Fit the tree on the feature matrix X and the target y.
dt.fit(X,y)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=20,
            min_weight_fraction_leaf=0.0, presort=False, random_state=99,
            splitter='best')

In [17]:
def visualize_tree(tree, feature_names):
    """Create a PNG of the tree using graphviz.

    Writes the tree in DOT format to "dt.dot" in the current directory,
    then runs the `dot` command to render it to "dt.png".

    Args
    ----
    tree -- scikit-learn DecisionTree.
    feature_names -- list of feature names.

    Raises
    ------
    SystemExit -- if `dot` cannot be started or exits with a non-zero
                  status; the underlying error is attached as the cause.
    """
    with open("dt.dot", "w") as dot_file:
        export_graphviz(tree, out_file=dot_file, feature_names=feature_names)

    command = ["dot", "-Tpng", "dt.dot", "-o", "dt.png"]
    try:
        subprocess.check_call(command)
    except (OSError, subprocess.CalledProcessError) as err:
        # OSError covers a missing `dot` executable; CalledProcessError
        # covers a non-zero exit status. Raise explicitly instead of
        # calling exit(), which in an IPython kernel schedules a kernel
        # shutdown rather than raising.
        raise SystemExit(
            "Could not run dot, ie graphviz, to "
            "produce visualization"
        ) from err

In [18]:
# Export the fitted tree to dt.dot and render it to dt.png.
visualize_tree(dt, features)