In [4]:
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, accuracy_score

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier

In [2]:
 titanic_df = pd.read_csv('data/titanic/processed.csv')

In [7]:
X = titanic_df.drop('Survived', axis=1)
Y= titanic_df['Survived']

x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)

In [8]:
def summ_clf(y_test, y_pred):
    acc = accuracy_score(y_test, y_pred, normalize=True)
    num_acc = accuracy_score(y_test, y_pred, normalize=False)
    
    prec = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    
    return {'accuracy': acc,
            'precision': prec,
            'recall' : recall,
            'accuracy_count' : num_acc}

In [9]:
from sklearn.model_selection import GridSearchCV

parameters = {'max_depth': [2,4,5,7,8,9,10]}

grid = GridSearchCV(DecisionTreeClassifier(),parameters,cv=4, return_train_score=True)

grid.fit(x_train, y_train)

grid.best_params_

{'max_depth': 4}

In [11]:
for i in range(7):
    
    print('Params :', grid.cv_results_['params'][i])
    
    print('Mean Test Score:', grid.cv_results_['mean_test_score'][i])
    
    print('Rank :', grid.cv_results_['rank_test_score'][i])

Params : {'max_depth': 2}
Mean Test Score: 0.7943760984182777
Rank : 2
Params : {'max_depth': 4}
Mean Test Score: 0.7961335676625659
Rank : 1
Params : {'max_depth': 5}
Mean Test Score: 0.7785588752196837
Rank : 3
Params : {'max_depth': 7}
Mean Test Score: 0.7715289982425307
Rank : 5
Params : {'max_depth': 8}
Mean Test Score: 0.7785588752196837
Rank : 3
Params : {'max_depth': 9}
Mean Test Score: 0.7680140597539543
Rank : 7
Params : {'max_depth': 10}
Mean Test Score: 0.7715289982425307
Rank : 5


In [12]:
model = DecisionTreeClassifier(max_depth=grid.best_params_['max_depth']).fit(x_train, y_train)

y_pred = model.predict(x_test)

print(summ_clf(y_test, y_pred))

{'accuracy': 0.8111888111888111, 'precision': 0.8857142857142857, 'recall': 0.5740740740740741, 'accuracy_count': 116}
