In [None]:
# install libraries
#import os
#os.environ["PATH"] = "/srv/conda/lib" + os.pathsep + os.environ["PATH"]
#print(os.environ["PATH"])


#not all libraries can be installed so easily. Often graphics libraries (graphs) have to do some 
# clever stuff to talk to the graphics card.
# unfortunately, those libraries have to be isntalled in a slightly different way...
# from home ,choose New -> Terminal
# enter the following (copy paste), and press enter
# conda install -c anaconda graphviz --yes
# repeat for:
# conda install -c anaconda pydotplus --yes

!conda install -c anaconda graphviz --yes 
!pip install pandas 
!pip install sklearn
!pip install matplotlib
!pip install seaborn
!pip install graphviz
!pip install pydotplus


In [None]:
# Import libraries
# 🐼 pandas is for working with tables of data (http://pandas.pydata.org/)
import pandas as pd

# scikit-learn (imported as `sklearn`) is for machine learning (http://scikit-learn.org)
from sklearn import tree

# matplotlib draws the plots; pandas and seaborn use it under the hood.
# Display plots inside this notebook rather than opening another window
%matplotlib inline

# seaborn: statistical plots (e.g. countplot) built on matplotlib
import seaborn as sns

import graphviz  # Python interface to the Graphviz graph-drawing software
import pydotplus  # parses Graphviz DOT text and renders it (e.g. to PNG) via Graphviz

from sklearn.model_selection import cross_val_score, GridSearchCV  # cross-validation and hyper-parameter search, for checking and improving learning

from IPython.display import Image  # displays image data (e.g. PNG bytes) in the notebook

In [None]:
# Load the mushroom data set from a CSV file in the notebook's working directory
df = pd.read_csv(filepath_or_buffer='mushrooms.csv')

In [None]:
# Peek at the first five rows to see what the raw columns look like
df.head(n=5)

In [None]:
# Summarise every column. The result is transposed (.T) so each original column becomes
# one row: this stays readable even when there are more columns than pandas shows side
# by side (display.max_columns), which would otherwise hide the middle ones behind "...".
df.describe().T

In [None]:
# Let's visualise this data, following the approach in https://www.kaggle.com/surajit346/ml-models-and-visualizations-for-beginners


In [None]:
# Count mushrooms for each odour, split by class (e = edible, p = poisonous).
# If one colour dominates a bar, odour on its own says a lot about the class.
ax = sns.countplot(x='odor', hue='class', data=df)
ax.set(title='Mushroom odour by class', xlabel='odor', ylabel='number of mushrooms');


In [None]:
# One-hot encoding: each categorical column is expanded into one indicator column per
# distinct value (named <column>_<value>), because scikit-learn's trees need numeric input
features = pd.get_dummies(data=df)
features.head(n=5)

In [None]:
# We want to predict whether a mushroom is edible or poisonous, so the class must NOT be
# one of the model's inputs: otherwise the model could simply read the answer off the
# class_e / class_p columns - that would be cheating (target leakage).
# The features are therefore the one-hot encoding of every column except 'class', and
# the labels we are predicting are stored separately in `classes`.
# Re-encoding from `df` (instead of dropping columns from `features`) keeps this cell
# safe to re-run: dropping class_e / class_p a second time would raise a KeyError.
features = pd.get_dummies(df.drop(columns='class'))
classes = df['class']

In [None]:
# Now fit a decision tree that predicts `classes` from the one-hot `features`.
# random_state is fixed because scikit-learn randomly permutes the features at each
# split; without a seed, equally good splits can be picked differently on every run,
# so the tree (and its plot) would not be reproducible.
model = tree.DecisionTreeClassifier(random_state=0)
model.fit(features,classes)

In [None]:
# Helper to draw a fitted decision tree, adapted from scikit-learn's documentation:
# http://scikit-learn.org/stable/modules/tree.html#classification
def plotTree(clf=None, feature_names=None):
    """Render a fitted decision tree as PNG bytes.

    Parameters
    ----------
    clf : fitted DecisionTreeClassifier, optional
        The tree to draw. Defaults to the notebook-level ``model``, looked up when the
        function is called, so it always draws the most recently fitted model.
    feature_names : sequence of str, optional
        Names of the columns the tree was trained on. Defaults to the columns of the
        notebook-level ``features`` frame.

    Returns
    -------
    bytes
        PNG image data; display it with ``IPython.display.Image``.

    Rendering is done by pydotplus, which needs the Graphviz software to be installed.
    """
    if clf is None:
        clf = model
    if feature_names is None:
        feature_names = features.columns
    dot_data = tree.export_graphviz(
        clf,
        out_file=None,  # return the DOT source as a string instead of writing a file
        # Plain lists of str work with every scikit-learn version and with non-string labels
        feature_names=[str(name) for name in feature_names],
        class_names=[str(label) for label in clf.classes_],
        filled=True,  # colour each node by its majority class
        rounded=True,
        special_characters=True,
    )
    graph = pydotplus.graph_from_dot_data(dot_data)
    return graph.create_png()

In [None]:
# Draw the tree. Rendering needs the Graphviz software from the conda install at the
# top of the notebook (conda install -c anaconda graphviz --yes).
img = plotTree()
Image(data=img)

In [None]:
# First things first: the tree splits on odour! We can't tell odour from a picture, so
# we have to remove it - and in fact there are a lot of attributes we can't tell from a
# picture. Drop them all (together with the 'class' label) and re-encode the rest.
NON_VISUAL_COLUMNS = ['odor', 'gill-attachment', 'gill-spacing', 'stalk-shape',
                      'stalk-root', 'stalk-surface-above-ring', 'stalk-surface-below-ring',
                      'population', 'gill-size', 'habitat', 'bruises', 'spore-print-color']
features = pd.get_dummies(df.drop(columns=['class'] + NON_VISUAL_COLUMNS))

# random_state is fixed because scikit-learn permutes the features at each split, so an
# unseeded tree (and its plot) can differ from run to run.
model = tree.DecisionTreeClassifier(random_state=0)
model.fit(features,classes)
img = plotTree()
Image(img)

In [None]:
# Right, the tree now has to ask a LOT of questions before it is satisfied - it is hard
# to read. How good is it? score() returns the accuracy: 0 means every prediction is
# wrong, 1 means every prediction is right.
# Careful: this is measured on the same rows the tree was trained on, so it is optimistic.
model.score(X=features, y=classes)

In [None]:
# A near-perfect score on the training data can just mean the tree memorised it
# (overfitting). Cross-validation gives a fairer estimate: split the data into folds,
# train on all folds but one, test on the held-out fold, and repeat for every fold.
# http://scikit-learn.org/stable/modules/cross_validation.html
# cv=5 is spelled out because the default number of folds changed between scikit-learn
# versions (3 before 0.22, 5 since), which would silently change these numbers.
# NOTE(review): for a classifier these folds are stratified but NOT shuffled, so if the
# CSV rows are ordered the per-fold scores can vary a lot - confirm that is intended.
train_test = cross_val_score(model, features, classes, cv=5)
print(train_test)

In [None]:
# Average accuracy over all the folds; the standard deviation shows how much the
# individual folds disagree, which a mean on its own hides.
print(f'{train_test.mean():.3f} accuracy with a standard deviation of {train_test.std():.3f}')
train_test.mean()

In [None]:
# Limit the tree's depth to fight overfitting: try max_depth = 2..19 and keep the depth
# with the best cross-validated accuracy.
parameters = {'max_depth': range(2, 20)}

# Same fold count as the cross-validation above; seeded so the chosen tree is reproducible
modelSearch = GridSearchCV(tree.DecisionTreeClassifier(random_state=0), parameters, cv=5)
modelSearch.fit(features, classes)

depth = modelSearch.best_params_['max_depth']
# With refit=True (the default) GridSearchCV has already retrained the best depth on all
# of the data, so there is no need to build and fit a second model by hand.
model = modelSearch.best_estimator_

# The cross-validated score is the honest estimate; the training score (last line,
# displayed as the cell output) is shown for comparison.
print(f'best max_depth = {depth}, mean cross-validated accuracy = {modelSearch.best_score_:.3f}')
model.score(features, classes)

In [None]:
# Draw the depth-limited tree picked by the grid search
img = plotTree()
Image(data=img)