<a href="https://colab.research.google.com/github/earthianhivemind/DLlearning/blob/main/Decision_tree_example_python.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classification Tree Example in Python

We will run a simple example of a classification tree in Python, with the aim of looking in more detail at the output produced by the scikit-learn package.

We will try to predict whether a person aboard the Titanic survived the sinking, based on their Age, their Sex and the Class they were travelling in.

In [None]:
# Loading packages
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Set style for better-looking plots
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (10, 6)

## Looking at the data

In [None]:
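# Building the Titanic data
# We recreate by hand the aggregated table that R ships as `Titanic`.
# Each Freq line below holds the "No" counts for 1st, 2nd, 3rd, Crew, then the "Yes" counts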
titanic_df = pd.DataFrame({
    'Class': ['1st', '2nd', '3rd', 'Crew'] * 8,
    'Sex': ['Male']*16 + ['Female']*16,
    'Age': (['Child']*8 + ['Adult']*8) * 2,
    'Survived': ['No', 'No', 'No', 'No', 'Yes', 'Yes', 'Yes', 'Yes'] * 4,
    'Freq': [0, 0, 35, 0, 5, 11, 13, 0,  # Male Child
             118, 154, 387, 670, 57, 14, 75, 192,  # Male Adult
             0, 0, 17, 0, 1, 13, 14, 0,  # Female Child
             4, 13, 89, 3, 140, 80, 76, 20]  # Female Adult
})

print(titanic_df.head(10))
print("\nDataset summary:")
print(titanic_df.describe())
print("\nData types:")
print(titanic_df.dtypes)

The dataset contains four categorical variables: Class (1st, 2nd, 3rd, and Crew), Sex (Male, Female), Age (Child, Adult), and Survived (No, Yes).

Freq contains the number of people observed for each combination of Class, Sex, Age and Survived.

In [None]:
# Convert categorical variables to appropriate types
titanic_df['Class'] = pd.Categorical(titanic_df['Class'], categories=['1st', '2nd', '3rd', 'Crew'], ordered=True)
titanic_df['Sex'] = pd.Categorical(titanic_df['Sex'])
titanic_df['Age'] = pd.Categorical(titanic_df['Age'])
titanic_df['Survived'] = pd.Categorical(titanic_df['Survived'])

print(titanic_df.info())

## Reshaping the data

Like the Titanic dataset in R, our table is aggregated: each row contains a frequency count representing multiple individuals.

We need to expand the data so that there is one row per person. This is done by replicating each row according to its Freq value, so combinations with a Freq of 0 drop out.

In [None]:
# Reshaping of the data to one row per individual
expanded = titanic_df.loc[titanic_df.index.repeat(titanic_df['Freq'])].reset_index(drop=True)
expanded = expanded[['Class', 'Sex', 'Age', 'Survived']]

print(f"Expanded dataset shape: {expanded.shape}")
print(f"Total people on board (passengers and crew): {len(expanded)}")
print("\nFirst 10 rows:")
print(expanded.head(10))
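
To see what the replication step does in isolation, here is a minimal sketch on a made-up three-row table (the `toy` frame and its values are hypothetical, not part of the Titanic data). `Index.repeat` repeats each index label as many times as its count, so a row with a count of 0 disappears from the result.

```python
import pandas as pd

# Hypothetical aggregated table: one row per group, with a count column
toy = pd.DataFrame({
    'Group': ['A', 'B', 'C'],
    'Freq': [2, 0, 3],
})

# Repeat each index label Freq times, then select those labels with .loc
toy_expanded = toy.loc[toy.index.repeat(toy['Freq'])].reset_index(drop=True)
toy_expanded = toy_expanded[['Group']]

print(toy_expanded['Group'].tolist())          # ['A', 'A', 'C', 'C', 'C']
print(len(toy_expanded) == toy['Freq'].sum())  # True
```

The same check (number of rows equal to the sum of Freq) is a quick way to confirm that the expansion of the real table worked.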

Now we can start looking at some descriptive statistics. The plot below shows how the proportion of survivors differs by Sex. Sex is clearly a very strong predictor: females were far more likely to survive than males.

In [None]:
# Some descriptive statistics about the data
# Sex
sex_survival = pd.crosstab(expanded['Sex'], expanded['Survived'], normalize='index') * 100

fig, ax = plt.subplots(figsize=(8, 6))
sex_survival.plot(kind='bar', stacked=True, ax=ax, color=['#F8766D', '#00BFC4'])
ax.set_ylabel('Proportion (%)')
ax.set_xlabel('Sex')
ax.set_title('Survival Proportion by Sex')
ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
ax.legend(title='Survived', labels=['No', 'Yes'])

# Add percentage labels on bars
for container in ax.containers:
    ax.bar_label(container, fmt='%.1f%%', label_type='center')

plt.tight_layout()
plt.show()

print("\nSurvival rates by Sex:")
print(sex_survival)

We will finish by looking at the differences by Age. Although Age is clearly a useful predictor, notice that it is less powerful than Sex at separating survivors from non-survivors.

In [None]:
# Some descriptive statistics about the data
# Age
age_survival = pd.crosstab(expanded['Age'], expanded['Survived'], normalize='index') * 100

fig, ax = plt.subplots(figsize=(8, 6))
age_survival.plot(kind='bar', stacked=True, ax=ax, color=['#F8766D', '#00BFC4'])
ax.set_ylabel('Proportion (%)')
ax.set_xlabel('Age')
ax.set_title('Survival Proportion by Age')
ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
ax.legend(title='Survived', labels=['No', 'Yes'])

# Add percentage labels on bars
for container in ax.containers:
    ax.bar_label(container, fmt='%.1f%%', label_type='center')

plt.tight_layout()
plt.show()

print("\nSurvival rates by Age:")
print(age_survival)

## Training the Classification Tree

Now we will build our classification tree model using scikit-learn's DecisionTreeClassifier.

In [None]:
# Prepare the data for modeling
# Convert categorical variables to numerical using one-hot encoding
X = pd.get_dummies(expanded[['Class', 'Sex', 'Age']], drop_first=False)
y = (expanded['Survived'] == 'Yes').astype(int)  # Convert to binary: 1 for Yes, 0 for No

print("Feature columns:")
print(X.columns.tolist())
print(f"\nFeature matrix shape: {X.shape}")
print(f"Target variable shape: {y.shape}")
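
The encoding step can be easier to follow on a tiny example. The sketch below uses a hypothetical three-row frame (`toy` is made up for illustration) and passes `dtype=int` so that the dummies print as 0/1 regardless of the pandas version. Each category becomes its own column; with `drop_first=False` the two Sex columns carry the same information, which is redundant but harmless for a tree.

```python
import pandas as pd

# Hypothetical data with two categorical columns
toy = pd.DataFrame({
    'Sex': ['Male', 'Female', 'Male'],
    'Age': ['Adult', 'Child', 'Adult'],
})

# One 0/1 column per category, named <column>_<category>
toy_encoded = pd.get_dummies(toy, drop_first=False, dtype=int)

print(toy_encoded.columns.tolist())
# ['Sex_Female', 'Sex_Male', 'Age_Adult', 'Age_Child']
print(toy_encoded)
```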

In [None]:
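# Training the classification tree
# min_samples_split and min_samples_leaf mirror rpart's minsplit=20 and minbucket=7;
# rpart's complexity-parameter pruning (cp) is not replicated here, so the trees may differ
model = DecisionTreeClassifier(
    criterion='gini',      # rpart also uses Gini for classification by default
    random_state=42,
    min_samples_split=20,  # Minimum samples required to split a node
    min_samples_leaf=7     # Minimum samples required in a leaf
)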

model.fit(X, y)

print("Decision Tree Model trained successfully!")
print(f"\nTree depth: {model.get_depth()}")
print(f"Number of leaves: {model.get_n_leaves()}")
print(f"Number of input features: {model.n_features_in_}")

# Feature importance
feature_importance = pd.DataFrame({
    'Feature': X.columns,
    'Importance': model.feature_importances_
}).sort_values('Importance', ascending=False)

print("\nFeature Importances:")
print(feature_importance)
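
Because the features are one-hot encoded, the importances above are reported per dummy column rather than per original variable. One possible way to summarize them, sketched below, is to sum the importances of all dummies that share a prefix. The block rebuilds the data and model so that it runs on its own; the names `people` and `by_variable` and the prefix-splitting rule are illustrative choices, and the rule assumes the original column names contain no underscore.

```python
import itertools
import pandas as pd
from sklearn.tree import DecisionTreeClassifier

# Rebuild the aggregated table in the same row order as the notebook
# (Class varies fastest, then Survived, then Age, then Sex)
combos = itertools.product(['Male', 'Female'], ['Child', 'Adult'],
                           ['No', 'Yes'], ['1st', '2nd', '3rd', 'Crew'])
df = pd.DataFrame(list(combos), columns=['Sex', 'Age', 'Survived', 'Class'])
df['Freq'] = [0, 0, 35, 0, 5, 11, 13, 0,
              118, 154, 387, 670, 57, 14, 75, 192,
              0, 0, 17, 0, 1, 13, 14, 0,
              4, 13, 89, 3, 140, 80, 76, 20]

# Expand to one row per person and encode, as in the cells above
people = df.loc[df.index.repeat(df['Freq'])].reset_index(drop=True)
X = pd.get_dummies(people[['Class', 'Sex', 'Age']], dtype=int)
y = (people['Survived'] == 'Yes').astype(int)

model = DecisionTreeClassifier(criterion='gini', random_state=42,
                               min_samples_split=20, min_samples_leaf=7)
model.fit(X, y)

# Sum the per-dummy importances back onto the original variable names
importances = pd.Series(model.feature_importances_, index=X.columns)
by_variable = importances.groupby(lambda name: name.split('_')[0]).sum()

print(by_variable.sort_values(ascending=False))
```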

In [None]:
# Get predictions and accuracy
predictions = model.predict(X)
accuracy = (predictions == y).mean()

print(f"Training Accuracy: {accuracy:.2%}")

# Confusion matrix
from sklearn.metrics import confusion_matrix, classification_report

cm = confusion_matrix(y, predictions)
print("\nConfusion Matrix:")
print(pd.DataFrame(cm,
                   index=['Actual: No', 'Actual: Yes'],
                   columns=['Predicted: No', 'Predicted: Yes']))

print("\nClassification Report:")
print(classification_report(y, predictions, target_names=['No', 'Yes']))
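
A training accuracy is easier to judge next to a naive baseline. As a rough point of comparison, the sketch below computes the accuracy of always predicting the more common outcome, using only the Freq counts from the aggregated table (the variable names are illustrative).

```python
# Freq counts from the aggregated table, split by outcome.
# Order within each list: Male Child, Male Adult, Female Child, Female Adult,
# each over 1st, 2nd, 3rd, Crew.
freq_no = [0, 0, 35, 0, 118, 154, 387, 670, 0, 0, 17, 0, 4, 13, 89, 3]
freq_yes = [5, 11, 13, 0, 57, 14, 75, 192, 1, 13, 14, 0, 140, 80, 76, 20]

n_no, n_yes = sum(freq_no), sum(freq_yes)
total = n_no + n_yes

# Always predicting the more common outcome gets this share of cases right
baseline_accuracy = max(n_no, n_yes) / total

print(f"Did not survive: {n_no}, survived: {n_yes}, total: {total}")
print(f"Majority-class baseline accuracy: {baseline_accuracy:.2%}")
```

The tree's training accuracy is only informative to the extent that it exceeds this figure.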

## Visualizing the Decision Tree

In [None]:
# Plotting the tree with better visualization
plt.figure(figsize=(20, 10))
plot_tree(model,
          feature_names=list(X.columns),  # a plain list works across scikit-learn versions
          class_names=['No', 'Yes'],
          filled=True,
          rounded=True,
          fontsize=10)
plt.title("Decision Tree: Titanic Survival Classification", fontsize=16, pad=20)
plt.tight_layout()
plt.show()

In [None]:
# Alternative: a plain-text view of the same tree, useful when the plot is hard to read
from sklearn.tree import export_text

# Text representation of the tree
tree_rules = export_text(model, feature_names=list(X.columns))
print("Decision Tree Rules:")
print(tree_rules)

## Looking at the tree plot

* The root node (top box) shows the distribution before any splits take place. It displays the gini impurity, total samples, and the value counts for each class.

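* The first split indicates the most important variable. Here, Sex is the key predictor, separating males from females. Because the features are one-hot encoded, the split appears in the plot as a threshold of 0.5 on one of the Sex dummy columns.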

* Subsequent splits refine those decisions. For males, Age becomes important as children have higher survival rates. For females, Class becomes the distinguishing factor.

* Looking at the tree allows us to see how the splitting happened. The boxes at the bottom (leaf nodes) show:
  - The predicted class (No or Yes)
  - The gini impurity (how pure the node is)
  - The number of samples in that node
  - The distribution of classes [No, Yes]

* The color coding helps visualize the prediction: darker orange for "No" (didn't survive) and darker blue for "Yes" (survived).

* The whole population is being split into groups based on combinations of Sex, Age, and Class.

* You may notice that the leaf nodes predicting "Yes" are typically purer than those predicting "No". In other words, nodes predicting survival contain mostly survivors, whereas nodes predicting death contain a mix of both outcomes. This relates to the classification error, which we will explore in more detail later.
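
The gini values shown in each box can be reproduced by hand. For a node with class proportions `p_k`, the Gini impurity is `1 - sum(p_k ** 2)`: 0 for a perfectly pure node and 0.5 for an even two-class mix. The sketch below applies this to the root node and to the two groups produced by splitting on Sex, using class counts summed from the Freq column. The `gini` helper is written for illustration and is not part of scikit-learn.

```python
def gini(counts):
    """Gini impurity of a node given its class counts, e.g. [n_no, n_yes]."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

# Class counts [No, Yes], summed from the Freq column of the aggregated table
root = [1490, 711]     # everyone on board
males = [1364, 367]
females = [126, 344]

# Impurity after the split: child impurities weighted by node size
n_total = sum(root)
after_split = (sum(males) * gini(males) + sum(females) * gini(females)) / n_total

print(f"Root gini:    {gini(root):.3f}")     # 0.437
print(f"Males gini:   {gini(males):.3f}")    # 0.334
print(f"Females gini: {gini(females):.3f}")  # 0.392
print(f"Weighted gini after splitting on Sex: {after_split:.3f}")  # 0.347
```

The drop from the root impurity to the weighted impurity after the split is the quantity the tree tries to maximize when choosing each split.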