In [97]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import GridSearchCV

In [98]:
url = "https://raw.githubusercontent.com/digipodium/Datasets/main/classfication/heart.csv"
df = pd.read_csv(url)
df.head()

Unnamed: 0,age,sex,cp,trestbps,chol,fbs,restecg,thalach,exang,oldpeak,slope,ca,thal,target
0,63,1,3,145,233,1,0,150,0,2.3,0,0,1,1
1,37,1,2,130,250,0,1,187,0,3.5,0,0,2,1
2,41,0,1,130,204,0,0,172,0,1.4,2,0,2,1
3,56,1,1,120,236,0,1,178,0,0.8,2,0,2,1
4,57,0,0,120,354,0,1,163,1,0.6,2,0,2,1


In [99]:
scaler = StandardScaler()
X = df.drop(columns=["target"])
X = scaler.fit_transform(X)
y = df["target"]


In [100]:
clf = DecisionTreeClassifier()
params = {
    'criterion': ["gini", "entropy", "log_loss"],
    'splitter' : ["best", "random"],
    'max_depth': [5,10,15,20,50,100,150,200,250,300,350,400,450,500],
    'min_samples_split': [2,3,4,5,6,7,8,9,10],
}

In [101]:
grid = GridSearchCV(
    estimator=clf,
    param_grid=params,
    cv = 3,
    n_jobs = -1,
    verbose =3,
    return_train_score=True
)

In [102]:
grid.fit(X,y)

Fitting 3 folds for each of 756 candidates, totalling 2268 fits


In [103]:
grid_tree_results = pd.DataFrame(grid.cv_results_)
grid_tree_results

Unnamed: 0,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_criterion,param_max_depth,param_min_samples_split,param_splitter,params,split0_test_score,split1_test_score,split2_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,mean_train_score,std_train_score
0,0.001333,4.714827e-04,0.000000,0.000000,gini,5,2,best,"{'criterion': 'gini', 'max_depth': 5, 'min_sam...",0.811881,0.712871,0.673267,0.732673,0.058295,491,0.960396,0.920792,0.940594,0.940594,0.016168
1,0.000667,4.714827e-04,0.000668,0.000472,gini,5,2,random,"{'criterion': 'gini', 'max_depth': 5, 'min_sam...",0.712871,0.722772,0.732673,0.722772,0.008084,643,0.886139,0.891089,0.910891,0.896040,0.010694
2,0.001333,4.706403e-04,0.000666,0.000471,gini,5,3,best,"{'criterion': 'gini', 'max_depth': 5, 'min_sam...",0.801980,0.732673,0.673267,0.735974,0.052599,440,0.960396,0.920792,0.940594,0.940594,0.016168
3,0.000667,4.715390e-04,0.000334,0.000472,gini,5,3,random,"{'criterion': 'gini', 'max_depth': 5, 'min_sam...",0.742574,0.801980,0.752475,0.765677,0.025987,106,0.891089,0.915842,0.905941,0.904290,0.010172
4,0.001333,4.715951e-04,0.000333,0.000472,gini,5,4,best,"{'criterion': 'gini', 'max_depth': 5, 'min_sam...",0.762376,0.702970,0.673267,0.712871,0.037046,727,0.960396,0.915842,0.935644,0.937294,0.018227
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
751,0.000667,4.714831e-04,0.000334,0.000472,log_loss,500,8,random,"{'criterion': 'log_loss', 'max_depth': 500, 'm...",0.752475,0.742574,0.782178,0.759076,0.016828,146,0.910891,0.910891,0.920792,0.914191,0.004667
752,0.001666,4.703026e-04,0.000334,0.000472,log_loss,500,9,best,"{'criterion': 'log_loss', 'max_depth': 500, 'm...",0.762376,0.722772,0.712871,0.732673,0.021389,491,0.950495,0.935644,0.935644,0.940594,0.007001
753,0.000667,4.715390e-04,0.000667,0.000472,log_loss,500,9,random,"{'criterion': 'log_loss', 'max_depth': 500, 'm...",0.772277,0.732673,0.752475,0.752475,0.016168,209,0.896040,0.891089,0.920792,0.902640,0.012993
754,0.001333,4.711462e-04,0.000334,0.000472,log_loss,500,10,best,"{'criterion': 'log_loss', 'max_depth': 500, 'm...",0.742574,0.762376,0.693069,0.732673,0.029148,491,0.935644,0.925743,0.920792,0.927393,0.006174


In [104]:
grid_tree_results.sort_values(by='rank_test_score', inplace=True)
grid_tree_results.reset_index(inplace=True)
grid_tree_results

Unnamed: 0,index,mean_fit_time,std_fit_time,mean_score_time,std_score_time,param_criterion,param_max_depth,param_min_samples_split,param_splitter,params,...,split1_test_score,split2_test_score,mean_test_score,std_test_score,rank_test_score,split0_train_score,split1_train_score,split2_train_score,mean_train_score,std_train_score
0,71,0.001003,3.132884e-06,0.000000,0.000000,gini,20,10,random,"{'criterion': 'gini', 'max_depth': 20, 'min_sa...",...,0.851485,0.801980,0.815182,0.025987,1,0.891089,0.891089,0.900990,0.894389,0.004667
1,731,0.001000,3.371748e-07,0.000000,0.000000,log_loss,450,7,random,"{'criterion': 'log_loss', 'max_depth': 450, 'm...",...,0.782178,0.782178,0.805281,0.032672,2,0.925743,0.935644,0.940594,0.933993,0.006174
2,257,0.000000,0.000000e+00,0.000333,0.000471,entropy,5,4,random,"{'criterion': 'entropy', 'max_depth': 5, 'min_...",...,0.801980,0.752475,0.801980,0.040421,3,0.871287,0.826733,0.925743,0.874587,0.040488
3,645,0.001000,1.946680e-07,0.000334,0.000472,log_loss,200,9,random,"{'criterion': 'log_loss', 'max_depth': 200, 'm...",...,0.881188,0.722772,0.801980,0.064673,3,0.910891,0.920792,0.881188,0.904290,0.016828
4,9,0.001000,6.743496e-07,0.000000,0.000000,gini,5,6,random,"{'criterion': 'gini', 'max_depth': 5, 'min_sam...",...,0.841584,0.772277,0.795380,0.032672,5,0.836634,0.900990,0.891089,0.876238,0.028294
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
751,749,0.000667,4.717076e-04,0.000334,0.000472,log_loss,500,7,random,"{'criterion': 'log_loss', 'max_depth': 500, 'm...",...,0.742574,0.702970,0.706271,0.028391,749,0.905941,0.920792,0.920792,0.915842,0.007001
752,548,0.001333,4.719890e-04,0.000667,0.000472,log_loss,15,6,best,"{'criterion': 'log_loss', 'max_depth': 15, 'mi...",...,0.702970,0.693069,0.706271,0.012349,749,0.975248,0.960396,0.965347,0.966997,0.006174
753,164,0.001334,4.720449e-04,0.000667,0.000472,gini,300,3,best,"{'criterion': 'gini', 'max_depth': 300, 'min_s...",...,0.702970,0.673267,0.702970,0.024252,754,0.990099,0.980198,0.995050,0.988449,0.006174
754,220,0.001334,4.715392e-04,0.000666,0.000471,gini,450,4,best,"{'criterion': 'gini', 'max_depth': 450, 'min_s...",...,0.702970,0.693069,0.702970,0.008084,754,0.985149,0.975248,0.960396,0.973597,0.010172


In [105]:
grid_tree_results.columns

Index(['index', 'mean_fit_time', 'std_fit_time', 'mean_score_time',
       'std_score_time', 'param_criterion', 'param_max_depth',
       'param_min_samples_split', 'param_splitter', 'params',
       'split0_test_score', 'split1_test_score', 'split2_test_score',
       'mean_test_score', 'std_test_score', 'rank_test_score',
       'split0_train_score', 'split1_train_score', 'split2_train_score',
       'mean_train_score', 'std_train_score'],
      dtype='object')

In [106]:
px.line(grid_tree_results,
       y=['mean_test_score','mean_train_score'],
       title='Decision Tree Grid Search',
       hover_data=['param_criterion', 'param_splitter', 'param_max_depth','param_min_samples_split'],)