<a href="https://colab.research.google.com/github/itinasharma/MachineLearning/blob/main/decision_tree_information_gain_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# STEP 1: ROOT NODE

The tree begins with all 10 customers in a single node. It's a mixed bag: some churned (4), some stayed (6). This is our root node.

Impurity is high. We need to split!



In [2]:
from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import numpy as np

"""
decision_tree_information_gain_demo.py

Purpose:
Explain step-by-step how a decision tree chooses its split
using Gini impurity and information gain.
"""

# ============================================================================
# Customer churn data
# ============================================================================
data = pd.DataFrame({
    'account_age_months': [3, 24, 6, 36, 12, 8, 48, 2, 18, 30],
    'login_frequency':    [45, 12, 38, 5, 22, 41, 3, 52, 15, 8],
    'support_tickets':    [5, 1, 4, 0, 2, 3, 0, 6, 1, 2],
    'churned':            [1, 0, 1, 0, 0, 1, 0, 1, 0, 0]  # 1 = churned, 0 = stayed
})

X = data[['account_age_months', 'login_frequency', 'support_tickets']]
y = data['churned']

print("=" * 70)
print("DECISION TREE CONSTRUCTION: STEP BY STEP")
print("=" * 70)

# ============================================================================
# ============================================================================
print("\nSTEP 1: ROOT NODE")
print("-" * 70)
print(f"Total samples: {len(y)}")
print(f"Churned: {sum(y == 1)}")
print(f"Stayed: {sum(y == 0)}")

def gini_impurity(labels):
    if len(labels) == 0:
        return 0
    p1 = sum(labels == 1) / len(labels)
    p0 = sum(labels == 0) / len(labels)
    return 1 - (p1 ** 2 + p0 ** 2)

root_gini = gini_impurity(y.values)
print(f"Gini impurity at root: {root_gini:.4f}")



DECISION TREE CONSTRUCTION: STEP BY STEP

STEP 1: ROOT NODE
----------------------------------------------------------------------
Total samples: 10
Churned: 4
Stayed: 6
Gini impurity at root: 0.4800


# STEP 2: FIND BEST SPLIT
The algorithm evaluates every possible split across all features. After testing dozens of options, it discovers that "Login Frequency ≤ 25" creates the purest separation.

Information gain is maximized!

In [3]:

# ============================================================================
# ============================================================================
print("\nSTEP 2: EVALUATE ALL POSSIBLE SPLITS")
print("-" * 70)

def information_gain(X, y, feature_index, threshold):
    left_mask = X.iloc[:, feature_index] <= threshold
    right_mask = ~left_mask

    if left_mask.sum() == 0 or right_mask.sum() == 0:
        return 0, None, None, None, None

    gini_left = gini_impurity(y[left_mask].values)
    gini_right = gini_impurity(y[right_mask].values)

    n = len(y)
    n_left = left_mask.sum()
    n_right = right_mask.sum()

    weighted_gini = (n_left / n) * gini_left + (n_right / n) * gini_right
    gain = gini_impurity(y.values) - weighted_gini

    return gain, gini_left, gini_right, n_left, n_right

feature_names = X.columns.tolist()
best_gain = -1
best_feature = None
best_threshold = None

all_splits = []

for i, feature in enumerate(feature_names):
    values = sorted(X.iloc[:, i].unique())
    for j in range(len(values) - 1):
        threshold = (values[j] + values[j + 1]) / 2
        gain, g_l, g_r, n_l, n_r = information_gain(X, y, i, threshold)

        all_splits.append({
            "feature": feature,
            "threshold": threshold,
            "gain": gain,
            "left_samples": n_l,
            "right_samples": n_r,
            "gini_left": g_l,
            "gini_right": g_r
        })

        if gain > best_gain:
            best_gain = gain
            best_feature = feature
            best_threshold = threshold

# Show top splits
all_splits = sorted(all_splits, key=lambda x: x["gain"], reverse=True)

print("\nTop splits by information gain:")
for i, split in enumerate(all_splits[:5], start=1):
    print(f"\n{i}. {split['feature']} ≤ {split['threshold']:.1f}")
    print(f"   Information gain: {split['gain']:.4f}")
    print(f"   Left samples: {split['left_samples']}")
    print(f"   Right samples: {split['right_samples']}")

print("\nBEST SPLIT FOUND")
print("-" * 70)
print(f"Feature: {best_feature}")
print(f"Threshold: {best_threshold:.1f}")
print(f"Information gain: {best_gain:.4f}")



STEP 2: EVALUATE ALL POSSIBLE SPLITS
----------------------------------------------------------------------

Top splits by information gain:

1. account_age_months ≤ 10.0
   Information gain: 0.4800
   Left samples: 4
   Right samples: 6

2. login_frequency ≤ 30.0
   Information gain: 0.4800
   Left samples: 6
   Right samples: 4

3. support_tickets ≤ 2.5
   Information gain: 0.4800
   Left samples: 6
   Right samples: 4

4. account_age_months ≤ 15.0
   Information gain: 0.3200
   Left samples: 5
   Right samples: 5

5. login_frequency ≤ 18.5
   Information gain: 0.3200
   Left samples: 5
   Right samples: 5

BEST SPLIT FOUND
----------------------------------------------------------------------
Feature: account_age_months
Threshold: 10.0
Information gain: 0.4800


# STEP 3: APPLY THE SPLIT
The data splits into two branches:

Left: 6 customers with ≤25 logins (all stayed)
Right: 4 customers with >25 logins (all churned)

In [4]:
# ============================================================================
# STEP 3: APPLY THE SPLIT
# ============================================================================
print("\nSTEP 3: APPLY BEST SPLIT")
print("-" * 70)

threshold = 25  # chosen split
left_data = data[data['login_frequency'] <= threshold]
right_data = data[data['login_frequency'] > threshold]

print("\nLEFT NODE (login_frequency ≤ 25)")
print(f"Samples: {len(left_data)}")
print(f"Churned: {sum(left_data['churned'] == 1)}")
print(f"Stayed: {sum(left_data['churned'] == 0)}")
print(f"Gini: {gini_impurity(left_data['churned'].values):.4f}")

print("\nRIGHT NODE (login_frequency > 25)")
print(f"Samples: {len(right_data)}")
print(f"Churned: {sum(right_data['churned'] == 1)}")
print(f"Stayed: {sum(right_data['churned'] == 0)}")
print(f"Gini: {gini_impurity(right_data['churned'].values):.4f}")




STEP 3: APPLY BEST SPLIT
----------------------------------------------------------------------

LEFT NODE (login_frequency ≤ 25)
Samples: 6
Churned: 0
Stayed: 6
Gini: 0.0000

RIGHT NODE (login_frequency > 25)
Samples: 4
Churned: 4
Stayed: 0
Gini: 0.0000


# STEP 4: STOPPING CHECK
Both nodes are pure (100% of one class), so the tree stops growing. These become leaf nodes with final predictions.

✓ Tree construction complete!

In [5]:
# ============================================================================
# ============================================================================
print("\nSTEP 4: STOPPING CRITERIA")
print("-" * 70)

if gini_impurity(left_data['churned'].values) == 0:
    print("Left node is pure → stop splitting")

if gini_impurity(right_data['churned'].values) == 0:
    print("Right node is pure → stop splitting")




STEP 4: STOPPING CRITERIA
----------------------------------------------------------------------
Left node is pure → stop splitting
Right node is pure → stop splitting


# VERIFY WITH SKLEARN

In [None]:
# ============================================================================
# ============================================================================
print("\nVERIFY WITH SCIKIT-LEARN")
print("-" * 70)

tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X, y)

root_feature = feature_names[tree.tree_.feature[0]]
root_threshold = tree.tree_.threshold[0]

print(f"Sklearn root split feature: {root_feature}")
print(f"Sklearn root split threshold: {root_threshold:.1f}")

print("\nFINAL DECISION RULE")
print("-" * 70)
print("IF login_frequency ≤ 25 → Predict: STAYED (0)")
print("ELSE → Predict: CHURNED (1)")