In [None]:
import pandas as pd
from sklearn.tree import DecisionTreeClassifier # Import Decision Tree Classifier
from sklearn.model_selection import train_test_split # Import train_test_split function
from sklearn import metrics
from sklearn.tree import export_graphviz
from six import StringIO
from IPython.display import Image  
import pydotplus
import matplotlib.pyplot as plt 

In [None]:
# Load the raw customer data.
df = pd.read_csv('BankChurners.csv')

# Keep columns at positions 1..20 only (skips the first column and anything
# past position 20).
df = df[df.columns[1:21]]

# Encode the target: 0 = existing customer, 1 = attrited (churned) customer.
df['Attrition_Flag'] = df['Attrition_Flag'].map({'Existing Customer': 0, 'Attrited Customer': 1})

# One-hot encode the categorical features. `dtype=int` produces 0/1 integer
# indicator columns directly. The previous follow-up cast used
# `df.iloc[14:]`, which selects *rows* from position 14 onward (not columns),
# so it cast every column of those rows to int and truncated the float-valued
# features.
df = pd.get_dummies(
    df,
    columns=["Gender", "Education_Level", "Marital_Status", "Income_Category", "Card_Category"],
    dtype=int,
)

# Features are every column except the target; the target is kept as a
# single-column integer frame.
X = df.loc[:, df.columns != 'Attrition_Flag']
y = df.loc[:, df.columns == 'Attrition_Flag'].astype('int')



In [None]:
# Hold out 30% of the rows for testing.
# - random_state pins the shuffle so the split (and every metric computed
#   from it) is identical after Restart & Run All.
# - stratify=y keeps the churn / non-churn proportions the same in the
#   training and test partitions.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

In [None]:
# Create the decision tree classifier with default (unconstrained) settings.
# random_state is fixed because the tree builder permutes candidate features
# at each split, so an unseeded fit can differ from run to run.
clf = DecisionTreeClassifier(random_state=42)

# Train the classifier on the training partition.
clf = clf.fit(X_train, y_train)

# Predict churn (0/1) for the held-out customers.
y_pred = clf.predict(X_test)

In [None]:
# Share of held-out rows whose predicted label matches the true label,
# for the unconstrained tree.
full_tree_accuracy = metrics.accuracy_score(y_test, y_pred)
print("Accuracy:", full_tree_accuracy)

In [None]:
# Render the full (unconstrained) tree with Graphviz and show it inline.
dot_data = StringIO()
export_graphviz(
    clf,
    out_file=dot_data,
    filled=True,
    rounded=True,
    special_characters=True,
    # A plain list avoids depending on how the installed scikit-learn
    # release validates array-like arguments such as a pandas Index.
    feature_names=list(X.columns),
    class_names=['0', '1'],
)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

# Save under a name specific to this tree. The previous 'diabetes.png' was
# the same path the simple-tree cell writes to, so one image overwrote the
# other.
graph.write_png('churn_tree_full.png')
Image(graph.create_png())

In [None]:
#Creating a New, Simpler Tree

In [None]:
# Create a shallow decision tree (Gini impurity, at most 3 levels).
# random_state is fixed so the fitted tree and the metrics below are
# reproducible across runs.
clf_simple = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=42)

# Train the classifier on the training partition.
clf_simple = clf_simple.fit(X_train, y_train)

# Predict churn (0/1) for the held-out customers.
y_pred_simple = clf_simple.predict(X_test)

# Model metrics on the test partition.
print("Accuracy:", metrics.accuracy_score(y_test, y_pred_simple))
print("Precision:", metrics.precision_score(y_test, y_pred_simple))
print("Recall:", metrics.recall_score(y_test, y_pred_simple))

In [None]:
# ROC curve for the simple tree.
# The curve and AUC need a continuous score per sample, not the hard 0/1
# predictions: with hard labels there is a single operating point and the
# AUC is understated. Column 1 of predict_proba is the estimated probability
# of class 1 (churn), since the target is encoded as 0/1.
churn_scores = clf_simple.predict_proba(X_test)[:, 1]

fpr, tpr, _ = metrics.roc_curve(y_test, churn_scores)
auc = metrics.roc_auc_score(y_test, churn_scores)

plt.plot(fpr, tpr, label="data 1, auc="+str(auc))
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.title("ROC curve - simple decision tree")
plt.legend(loc=4)
plt.show()

In [None]:
# Render the depth-limited tree with Graphviz and show it inline.
dot_data = StringIO()
export_graphviz(
    clf_simple,
    out_file=dot_data,
    filled=True,
    rounded=True,
    special_characters=True,
    # A plain list avoids depending on how the installed scikit-learn
    # release validates array-like arguments such as a pandas Index.
    feature_names=list(X.columns),
    class_names=['0', '1'],
)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())

# Save under a name specific to this tree. The previous 'diabetes.png' was
# the same path the full-tree cell writes to, so one image overwrote the
# other.
graph.write_png('churn_tree_simple.png')
Image(graph.create_png())