In [1]:
% matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import datasets
from sklearn.ensemble import RandomForestClassifier, ExtraTreesClassifier
from sklearn import tree

  from numpy.core.umath_tests import inner1d


In [2]:
def treeCrossValidate(data, model, name=''):
    data['train'] = (np.random.uniform(0,1, len(data)) * 10).astype(int)
    err = 0
    print(f"\n*** {name} ***")
    for i in range(10):
        data_train = data[data['train'] != i]
        data_test = data[data['train'] == i]
        
        y_train = data_train['class']
        X_train = data_train.drop('class', axis = 1)
        model.fit(X_train, y_train)
        y_test = data_test['class']
        X_test = data_test.drop('class', axis = 1)
        
        err += model.score(X_test, y_test)
#         print(f'\t{model.score(X_test, y_test)}')
    print('Score: ', err/10)
    return err

In [11]:
data = pd.read_csv('data/bank-additional-full.csv',
                   sep=';',
                   names=[
                       'age',
                       'job',
                       'marital',
                       'education',
                       'default',
                       'housing',
                       'loan',
                       'contact',
                       'month',
                       'day_of_week',
                       'duration',
                       'campaign',
                       'pdays',
                       'previous',
                       'poutcome',
                       'emp.var.rate',
                       'cons.price.idx',
                       'cons.conf.idx',
                       'euribor3m',
                       'nr.employed',
                       'class'
                   ]
                   )

data.head()

Unnamed: 0,age,job,marital,education,default,housing,loan,contact,month,day_of_week,...,campaign,pdays,previous,poutcome,emp.var.rate,cons.price.idx,cons.conf.idx,euribor3m,nr.employed,class
0,56,housemaid,married,basic.4y,no,no,no,telephone,may,mon,...,1,999,0,nonexistent,1.1,93.994,-36.4,4.857,5191.0,no
1,57,services,married,high.school,unknown,no,no,telephone,may,mon,...,1,999,0,nonexistent,1.1,93.994,-36.4,4.857,5191.0,no
2,37,services,married,high.school,no,yes,no,telephone,may,mon,...,1,999,0,nonexistent,1.1,93.994,-36.4,4.857,5191.0,no
3,40,admin.,married,basic.6y,no,no,no,telephone,may,mon,...,1,999,0,nonexistent,1.1,93.994,-36.4,4.857,5191.0,no
4,56,services,married,high.school,no,no,yes,telephone,may,mon,...,1,999,0,nonexistent,1.1,93.994,-36.4,4.857,5191.0,no


In [12]:
data = data.drop('duration', axis=1)
for name in data.columns.values:
    if data[name].dtype == 'object':
        data[name] = pd.Categorical(data[name]).codes
data.head()

Unnamed: 0,age,job,marital,education,default,housing,loan,contact,month,day_of_week,campaign,pdays,previous,poutcome,emp.var.rate,cons.price.idx,cons.conf.idx,euribor3m,nr.employed,class
0,56,3,1,0,0,0,0,1,6,1,1,999,0,1,1.1,93.994,-36.4,4.857,5191.0,0
1,57,7,1,3,1,0,0,1,6,1,1,999,0,1,1.1,93.994,-36.4,4.857,5191.0,0
2,37,7,1,3,0,2,0,1,6,1,1,999,0,1,1.1,93.994,-36.4,4.857,5191.0,0
3,40,0,1,1,0,0,0,1,6,1,1,999,0,1,1.1,93.994,-36.4,4.857,5191.0,0
4,56,7,1,3,0,0,2,1,6,1,1,999,0,1,1.1,93.994,-36.4,4.857,5191.0,0


In [13]:
models = [
    (
        RandomForestClassifier(n_estimators=20, max_depth=13, min_samples_split=10),
        'RandomForestClassifier'
    ),
    (
        ExtraTreesClassifier(n_estimators=30, min_samples_split=20, max_leaf_nodes=15),
        'ExtraTreesClassifier'
    ),
    (
        tree.DecisionTreeClassifier(), 'Default DecisionTreeClassifier'
    ),
    (
        RandomForestClassifier(), 'Default RandomForestClassifier'
    )
]

In [14]:
for model, name in models:
    treeCrossValidate(data, model, name)


*** RandomForestClassifier ***
Score:  0.9000991831640344

*** ExtraTreesClassifier ***
Score:  0.8991793392494056

*** Default DecisionTreeClassifier ***
Score:  0.8358843836197917

*** Default RandomForestClassifier ***
Score:  0.8948574144917437
