# Chapter 5: Decision Trees

**Student Learning Version**  
Work through this notebook to understand decision trees. The solution notebook will be provided after completion.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import BaggingClassifier
from sklearn.metrics import classification_report

## 5.1 Introduction & Motivation

Welcome to the world of decision trees! Now that we've built a solid foundation in classification models, we turn to one of the most intuitive and interpretable classifiers available to data scientists: **decision trees**.

Decision trees mirror the way humans naturally make decisions by breaking down complex problems into a series of simple yes/no questions. This makes them particularly valuable for understanding how predictions are made, unlike "black box" algorithms where the decision-making process is opaque.

Consider the following decision tree, which demonstrates how we can systematically classify animals into their respective families using a series of logical questions:

![](https://github.com/jakevdp/PythonDataScienceHandbook/blob/master/notebooks/figures/05.08-decision-tree.png?raw=1)

This animal classification tree shows the structure shared by every decision tree: each internal node asks a question about one feature, each branch corresponds to one answer to that question, and each leaf node gives the final classification.

While we could construct such a tree by hand as a chain of nested `if-else` statements, the real power of decision trees lies in learning these decision rules automatically from data. A learning algorithm analyzes a labeled dataset, chooses a sequence of questions that separates the classes well (typically greedily, picking the best single split at each step rather than searching for the globally optimal tree), and produces a tree that can classify new, unseen examples.

**Key advantages of decision trees:**
- **Interpretability**: The decision-making process is transparent and easy to explain
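- **No distributional assumptions**: They don't require the data to follow any particular distribution
- **Handles mixed data types**: The algorithm can split on numerical and categorical features alike, although scikit-learn's `DecisionTreeClassifier` expects numeric input, so categorical features are usually encoded first
- **Built-in feature ranking**: Splits favor the most informative features, so a fitted tree reveals which features matter most for its predictions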
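
As a rough illustration of the last point, the sketch below fits an unconstrained tree to the digits dataset used later in this chapter and lists the pixels it relies on most. It assumes scikit-learn's `DecisionTreeClassifier`, whose fitted `feature_importances_` attribute gives each feature's normalized total impurity reduction; `clf` and `top5` are illustrative names. These importances describe one particular fitted tree, not the data in general.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()
# Fit on the full dataset here purely to inspect feature usage (no evaluation).
clf = DecisionTreeClassifier(random_state=0).fit(digits.data, digits.target)

importances = clf.feature_importances_       # one value per pixel, 64 in total, summing to 1
top5 = np.argsort(importances)[::-1][:5]     # indices of the five most important pixels
for idx in top5:
    row, col = divmod(int(idx), 8)           # flat index -> (row, column) in the 8x8 image
    print(f"pixel (row {row}, col {col}): importance {importances[idx]:.3f}")

# Pixels the tree never splits on receive an importance of exactly 0.
print("unused pixels:", int(np.sum(importances == 0)))
```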

## 5.2 Problem Setting: Handwritten Digit Recognition

To demonstrate the power and versatility of decision trees, we'll tackle a classic machine learning challenge: recognizing handwritten digits.

For our Decision Tree Classifier implementation, we'll use the same digits dataset as in our Logistic Regression chapter. Reusing the dataset serves three purposes:

1. **Direct Comparison**: By using identical data, we can make fair, apples-to-apples comparisons between different algorithms
2. **Consistent Evaluation**: We can apply the same metrics and evaluation criteria across different models
3. **Understanding Trade-offs**: We'll discover when decision trees might be preferred over logistic regression and vice versa

**Dataset Overview:**
The digits dataset contains 8×8 pixel grayscale images of handwritten digits (0-9), where each pixel intensity is represented as a value between 0 and 16. This creates a 64-dimensional feature space that our decision tree will navigate to make classifications.
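
A quick sketch to confirm this structure, assuming scikit-learn's `load_digits()`: the `images` attribute keeps each sample as an 8×8 array, while `data` holds the same pixels flattened into 64 columns, which is the form the classifier consumes.

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()
print(digits.images.shape)                       # (1797, 8, 8): one 8x8 array per image
print(digits.data.shape)                         # (1797, 64): the same pixels flattened
print(digits.data.min(), digits.data.max())      # intensity range of the pixels

# Each row of `data` is the corresponding 8x8 image flattened row by row.
print(np.array_equal(digits.data, digits.images.reshape(len(digits.images), -1)))
```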

The following cells load and inspect the data exactly as in the Logistic Regression chapter; if any step seems unfamiliar, revisit that chapter for a detailed explanation:

In [None]:
digits = datasets.load_digits()
# List the attributes available on the returned Bunch object (data, target, images, ...)
print(dir(digits))
# 1797 images, each 8x8 pixels flattened into 64 features
print("Image Data Shape", digits.data.shape)
# 1797 labels (integers from 0 to 9)
print("Label Data Shape", digits.target.shape)

In [None]:
plt.figure(figsize=(20,4))
for index, (image, label) in enumerate(zip(digits.data[0:5], digits.target[0:5])):
    plt.subplot(1, 5, index + 1)
    plt.imshow(np.reshape(image, (8,8)), cmap=plt.cm.gray)
    plt.title('Training: %i\n' % label, fontsize=20)

## 5.3 Understanding Decision Trees: From Theory to Practice

Decision trees work by recursively partitioning the feature space into regions that are as homogeneous as possible with respect to the target variable. Let's explore how this process works both theoretically and practically.

### 5.3.1 The Decision Tree Algorithm: How It Works

Understanding decision trees conceptually is crucial before diving into implementation. Their formal setup draws on information theory, through concepts such as entropy, information gain, and Gini impurity; here we focus on the intuition, which visual examples convey well.

**Core Concept**: Decision trees work by asking a series of questions about the features in your data, with each question designed to split the data into groups that are as "pure" as possible (containing mostly one class).
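
To make "pure" concrete, here are the two impurity measures most often used to score a group of labels, where $p_k$ is the fraction of points in the group that belong to class $k$:

$$\text{Gini} = 1 - \sum_k p_k^2 \qquad\qquad \text{Entropy} = -\sum_k p_k \log_2 p_k$$

Both are 0 for a group containing a single class and grow as the classes become more mixed. scikit-learn's `DecisionTreeClassifier` uses Gini by default (`criterion="gini"`) and also accepts `criterion="entropy"`. The functions below are a minimal NumPy sketch of the formulas, not scikit-learn's internal code; `gini` and `entropy` are names chosen for this example.

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2, where p_k is the fraction of class k."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def entropy(labels):
    """Entropy in bits, -sum_k p_k log2 p_k, written as sum_k p_k log2(1/p_k).
    np.unique only returns classes that occur, so every p_k > 0."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1.0 / p)))

print(gini([0, 0, 0, 0]), entropy([0, 0, 0, 0]))   # pure group: 0.0 0.0
print(gini([0, 0, 1, 1]), entropy([0, 0, 1, 1]))   # two classes, 50/50: 0.5 1.0
print(gini([0, 1, 2, 3]), entropy([0, 1, 2, 3]))   # four classes, equal shares: 0.75 2.0
```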

To illustrate this process, let's examine a simple two-dimensional classification problem where we can visualize exactly how the algorithm makes its decisions:

In [None]:
X, y = make_blobs(n_samples=300, centers=4,
                  random_state=0, cluster_std=1.0)
plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='rainbow');

**Understanding the Visualization:**
- **X-axis and Y-axis**: These represent two independent variables (features) in our dataset
- **Colors**: Each color represents a different class (category) that we want to predict
- **Goal**: Create a decision tree that can accurately classify new points based on their X and Y coordinates

This scatter plot shows four distinct clusters of data points, each representing a different class. A human can easily see the patterns, but how does a computer algorithm learn to distinguish between these groups?

**The Decision Tree Learning Process:**

When a decision tree algorithm analyzes this data, it follows these steps:

1. **Initial Assessment**: Start with all data points mixed together
2. **Find the Best Split**: Identify the feature and threshold that best separates the classes
3. **Create Branches**: Split the data into two groups based on this criterion
4. **Repeat Recursively**: Apply the same process to each new group
5. **Stop When Pure**: Stop splitting a group once it contains a single class, or earlier if a stopping criterion such as a maximum depth or a minimum number of samples per node is reached
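
The five steps can be sketched in plain NumPy. This is a simplified illustration, not scikit-learn's optimized implementation: it scores every feature and every midpoint threshold by the size-weighted Gini impurity of the two resulting groups, keeps the best one, and recurses. The helper names (`best_split`, `build`, `predict_one`) and the `max_depth=3` stopping rule are assumptions made for this example.

```python
import numpy as np
from sklearn.datasets import make_blobs

def gini(labels):
    """Gini impurity of a set of integer class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def best_split(X, y):
    """Step 2: try every feature and every midpoint threshold; return the
    (feature, threshold, score) whose two child groups have the lowest
    size-weighted Gini impurity."""
    n_samples, n_features = X.shape
    best = (None, None, np.inf)
    for j in range(n_features):
        values = np.unique(X[:, j])
        for t in (values[:-1] + values[1:]) / 2:    # midpoints between distinct values
            left = X[:, j] <= t
            score = (left.sum() * gini(y[left]) + (~left).sum() * gini(y[~left])) / n_samples
            if score < best[2]:
                best = (j, t, score)
    return best

def build(X, y, depth=0, max_depth=3):
    """Steps 3-5: split, recurse on each branch, stop when pure or too deep."""
    if gini(y) == 0.0 or depth == max_depth:
        return {"leaf": int(np.bincount(y).argmax())}   # predict the majority class
    feature, threshold, _ = best_split(X, y)
    if feature is None:                                  # all points identical: cannot split
        return {"leaf": int(np.bincount(y).argmax())}
    left = X[:, feature] <= threshold
    return {"feature": feature, "threshold": threshold,
            "left": build(X[left], y[left], depth + 1, max_depth),
            "right": build(X[~left], y[~left], depth + 1, max_depth)}

def predict_one(node, x):
    """Follow the questions from the root down to a leaf."""
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)
root_feature, root_threshold, root_score = best_split(X, y)
print(f"root question: x{root_feature} <= {root_threshold:.2f}? "
      f"(Gini {gini(y):.3f} -> {root_score:.3f})")

model = build(X, y)
train_acc = np.mean([predict_one(model, x) == label for x, label in zip(X, y)])
print(f"training accuracy of the depth-3 sketch: {train_acc:.3f}")
```

Searching exhaustively is simple but slow; production libraries use sorted feature values and incremental class counts to evaluate thresholds far more efficiently.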

**Key Insight**: The algorithm draws lines (splits) through the feature space, with each line representing a decision boundary. These lines are always parallel to the axes because each split considers only one feature at a time.
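
You can confirm the axis-parallel structure directly on a fitted model. The sketch below, which assumes scikit-learn's `export_text` helper and the fitted tree's low-level `tree_` attribute, fits a depth-2 tree to the same blobs and prints each split: every internal node compares a single feature (`x0` or `x1`, names chosen here for readability) to a threshold.

```python
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier, export_text

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

# Human-readable rules: each line tests exactly one feature against a threshold.
rules = export_text(clf, feature_names=["x0", "x1"])
print(rules)

# The same splits read from the low-level tree_ arrays.
t = clf.tree_
split_features = []
for node in range(t.node_count):
    if t.children_left[node] != t.children_right[node]:   # internal node; leaves store -1 for both
        split_features.append(int(t.feature[node]))
        print(f"node {node}: feature x{t.feature[node]} <= {t.threshold[node]:.3f}")
```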

The following diagram illustrates how this iterative splitting process creates increasingly pure regions:

![](https://github.com/jakevdp/PythonDataScienceHandbook/blob/master/notebooks/figures/05.08-decision-tree-levels.png?raw=1)
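
A similar picture can be produced for the blobs data. This sketch, assuming scikit-learn and Matplotlib, fits trees with increasing `max_depth` (the depths 1, 2, 3 and 5 are arbitrary choices) and colors each region of the plane by the predicted class, so you can watch the axis-parallel regions multiply as the tree grows.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(n_samples=300, centers=4, random_state=0, cluster_std=1.0)

# A fine grid over the plotting area; each grid point gets the tree's prediction.
xx, yy = np.meshgrid(np.linspace(X[:, 0].min() - 1, X[:, 0].max() + 1, 200),
                     np.linspace(X[:, 1].min() - 1, X[:, 1].max() + 1, 200))
grid = np.c_[xx.ravel(), yy.ravel()]

depths = [1, 2, 3, 5]
n_leaves = {}
fig, axes = plt.subplots(1, len(depths), figsize=(20, 4))
for ax, depth in zip(axes, depths):
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0).fit(X, y)
    n_leaves[depth] = clf.get_n_leaves()
    zz = clf.predict(grid).reshape(xx.shape)
    # One filled band per class label 0-3; points and regions share the same color range.
    ax.contourf(xx, yy, zz, levels=np.arange(5) - 0.5, alpha=0.3, cmap="rainbow")
    ax.scatter(X[:, 0], X[:, 1], c=y, s=20, cmap="rainbow", vmin=-0.5, vmax=3.5)
    ax.set_title(f"depth = {depth} ({n_leaves[depth]} leaves)")
```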

### 5.3.2 Model Implementation and Training

Now that we understand the theory behind decision trees, let's implement one using scikit-learn and apply it to our handwritten digits dataset. The implementation process follows the standard machine learning workflow we've established in previous chapters.

**Step 1: Data Preparation**

Following machine learning best practices, we begin by splitting our dataset into training and testing portions. This separation is crucial for obtaining unbiased estimates of our model's performance:

In [None]:
x_train, x_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.25, random_state=0)

**Step 2: Model Creation and Training**

Similar to our Logistic Regression implementation, we instantiate a decision tree classifier and train it on our data. The scikit-learn implementation handles all the complex mathematics behind the scenes:

In [None]:
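# random_state fixes how ties between equally good splits are broken, so results are reproducible
tree = DecisionTreeClassifier(random_state=0)
# fit() returns the fitted estimator itself, so tree_fit and tree refer to the same object
tree_fit = tree.fit(x_train, y_train)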
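
Because no `max_depth` or other limit was set, the tree keeps splitting until its leaves are pure or cannot be split further. The sketch below repeats the same split and fit (so it runs on its own) and uses scikit-learn's `get_depth()` and `get_n_leaves()` to show how large the unconstrained tree grows; compare its training accuracy with the test accuracy computed in Section 5.4.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)
clf = DecisionTreeClassifier(random_state=0).fit(x_train, y_train)

print("depth:", clf.get_depth())                  # length of the longest root-to-leaf path
print("leaves:", clf.get_n_leaves())              # number of final regions
print("training accuracy:", clf.score(x_train, y_train))
```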

**Step 3: Making Predictions**

With our trained model, we can now make predictions on our test set. Each prediction represents the model's best guess about which digit (0-9) is represented in each test image:

In [None]:
predictions = tree.predict(x_test)
print(predictions)

## 5.4 Model Evaluation: Comprehensive Performance Analysis

Evaluating our decision tree's performance requires examining multiple metrics to gain a complete understanding of how well our model performs. We'll use the same evaluation framework established in previous chapters to enable direct comparison with other algorithms.

### Accuracy: The Foundation Metric

**What is Accuracy?**
Accuracy measures the proportion of correct predictions out of all predictions made. It's calculated as:

$$\text{Accuracy} = \frac{\text{Number of Correct Predictions}}{\text{Total Number of Predictions}}$$

This metric provides a quick, intuitive understanding of overall model performance.

Let's calculate our decision tree's accuracy and compare it with our previous logistic regression results. This comparison will help us understand the relative strengths and weaknesses of each approach.

In [None]:
score = tree_fit.score(x_test, y_test)
print(score)
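
The `score` method implements the accuracy formula above. To see the arithmetic explicitly, the sketch below counts correct predictions by hand and checks the result against both `score` and `metrics.accuracy_score`. It rebuilds the same split and model so it can run on its own.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)
clf = DecisionTreeClassifier(random_state=0).fit(x_train, y_train)
pred = clf.predict(x_test)

n_correct = int(np.sum(pred == y_test))      # numerator: number of correct predictions
n_total = len(y_test)                        # denominator: 450 test images (25% of 1797, rounded up)
manual = n_correct / n_total
print(f"{n_correct} / {n_total} = {manual:.4f}")

# score() and accuracy_score() compute the same quantity.
print(np.isclose(manual, clf.score(x_test, y_test)), np.isclose(manual, accuracy_score(y_test, pred)))
```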

##### Question 1: Interpret the accuracy of this model. Based on this result, would you prefer the decision tree or the logistic regression classifier for this dataset? Justify your answer.

**Your Analysis:**

### Precision, Recall, and F1-Score: Class-Level Performance

While accuracy gives us an overall picture, precision, recall, and F1-score provide deeper insights into how well our model performs for each individual digit (0-9). These metrics are particularly important for multi-class problems like digit recognition.

In [None]:
print(classification_report(y_test, predictions))
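
Each number in the report comes from counting, for every digit $k$, true positives ($TP_k$), false positives ($FP_k$), and false negatives ($FN_k$):

$$\text{Precision}_k = \frac{TP_k}{TP_k + FP_k}, \qquad \text{Recall}_k = \frac{TP_k}{TP_k + FN_k}, \qquad F_{1,k} = \frac{2 \cdot \text{Precision}_k \cdot \text{Recall}_k}{\text{Precision}_k + \text{Recall}_k}$$

The sketch below computes these directly from a confusion matrix (rows are actual digits, columns are predicted digits, as in scikit-learn's `confusion_matrix`) and cross-checks them against `precision_recall_fscore_support`. It refits the same model so it is self-contained, and it assumes every digit is predicted at least once (otherwise a column sum would be zero).

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)
pred = DecisionTreeClassifier(random_state=0).fit(x_train, y_train).predict(x_test)

cm = confusion_matrix(y_test, pred)   # rows = actual digit, columns = predicted digit
tp = np.diag(cm)                      # correctly classified images of each digit
precision = tp / cm.sum(axis=0)       # TP / (TP + FP): column sums count all predictions of a digit
recall = tp / cm.sum(axis=1)          # TP / (TP + FN): row sums count all true images of a digit
f1 = 2 * precision * recall / (precision + recall)

for digit in range(10):
    print(f"digit {digit}: precision {precision[digit]:.2f}  recall {recall[digit]:.2f}  f1 {f1[digit]:.2f}")

# Cross-check against scikit-learn's per-class values (average=None).
p_sk, r_sk, f_sk, _ = precision_recall_fscore_support(y_test, pred, average=None)
print(np.allclose(precision, p_sk), np.allclose(recall, r_sk), np.allclose(f1, f_sk))
```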

##### Question 2: Analyze the precision, recall, and F1-scores for each digit class. Which digits does the decision tree classify most accurately, and which ones pose the greatest challenges? Compare these results with logistic regression performance.

**Your Detailed Analysis:**

### Confusion Matrix: Visualizing Classification Patterns

The confusion matrix provides a comprehensive view of our model's classification behavior, showing exactly which digits are being confused with others. This visualization is invaluable for understanding systematic errors in our model's predictions.

In [None]:
cm = metrics.confusion_matrix(y_test, predictions)
print(cm)
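
Reading a 10×10 matrix by eye is tedious. The sketch below is one way to surface the most frequent mistakes: it zeroes the diagonal (the correct predictions) so only misclassifications remain, then lists the largest off-diagonal cells. Showing the top three is an arbitrary choice, and the model is refit so the block runs on its own.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

digits = load_digits()
x_train, x_test, y_train, y_test = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)
pred = DecisionTreeClassifier(random_state=0).fit(x_train, y_train).predict(x_test)
cm = confusion_matrix(y_test, pred)

# Zero the diagonal so only mistakes remain, then list the largest entries.
errors = cm.copy()
np.fill_diagonal(errors, 0)
print("total misclassified:", errors.sum())
for flat in np.argsort(errors, axis=None)[::-1][:3]:
    actual, predicted = np.unravel_index(flat, errors.shape)
    print(f"actual {actual} predicted as {predicted}: {errors[actual, predicted]} times")
```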

In [None]:
plt.figure(figsize=(10,10))
# fmt="d" shows the raw integer counts (".3f" would print e.g. 12.000)
sns.heatmap(cm, annot=True, fmt="d", linewidths=.5, square=True, cmap='Blues_r');
plt.ylabel('Actual label');
plt.xlabel('Predicted label');
all_sample_title = 'Accuracy Score: {0:.3f}'.format(score)
plt.title(all_sample_title, size=15);

##### Question 3: Analyze the confusion matrix visualization above. What patterns do you observe in the misclassifications? How does this compare to the logistic regression confusion matrix, and what does this tell us about the nature of each algorithm's decision-making process?

**Your Confusion Matrix Analysis:**

## 5.5 Exercises: Deepening Your Understanding

These exercises will help consolidate your understanding of decision trees and develop your skills in comparative model analysis. Work through each question systematically, using both the theoretical concepts and practical results we've explored.

##### Question 4: Comparative Confusion Matrix Analysis
Comparing confusion matrices between different algorithms can be challenging due to their complexity. Create a comprehensive visual analysis that shows the differences in prediction patterns between Logistic Regression and Decision Tree classifiers. Your analysis should help identify where each algorithm excels or struggles.

**Implementation Requirements:**
1. Train both a Logistic Regression and Decision Tree model on the same dataset
2. Generate confusion matrices for both models
3. Create a visualization showing the difference between the matrices
4. Interpret the results to understand algorithmic differences

In [None]:
from sklearn.linear_model import LogisticRegression

In [None]:
# Your implementation here
# Train both models, generate confusion matrices, and create comparative visualization

**Your Comparative Analysis:**

##### Question 5: Algorithm Comparison and Selection Guidelines

Based on your research and the practical results from this chapter, provide a comprehensive comparison between logistic regression and decision trees. Your analysis should address:

**Performance Characteristics:**
- When does each algorithm typically excel?
- What types of data favor each approach?
- How do computational requirements compare?

**Interpretability and Explainability:**
- Which algorithm provides clearer insights into decision-making?
- When might interpretability be more important than pure accuracy?

**Practical Considerations:**
- Data preprocessing requirements
- Hyperparameter tuning complexity  
- Scalability to large datasets

**Recommendation Framework:**
Develop guidelines for choosing between these algorithms based on project requirements, data characteristics, and business constraints.

**Your Comprehensive Analysis:**

**Logistic Regression Advantages:**

**Decision Tree Advantages:**

**When to Choose Logistic Regression:**

**When to Choose Decision Trees:**

**Recommendation Framework:**