# Pareto Front Analysis: GA-Optimized Decision Trees

This notebook uses NSGA-II multi-objective optimization to explore the
accuracy vs. interpretability trade-off for decision trees on the Iris dataset.

**Steps:**
1. Run Pareto-front evolution
2. Visualize the Pareto front
3. Identify the 'knee' solution (best balanced trade-off)
4. Inspect decision rules for selected solutions

In [None]:
import random

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris

from ga_trees.ga.multi_objective import ParetoOptimizer
from ga_trees.ga.engine import TreeInitializer, Mutation
from ga_trees.fitness.calculator import FitnessCalculator

# Load the Iris benchmark and derive the problem dimensions the GA needs.
iris = load_iris()
X = iris.data
y = iris.target
feature_names = list(iris.feature_names)
n_features = X.shape[1]
n_classes = np.unique(y).size
# Per-feature (min, max) bounds as plain Python floats, keyed by column index;
# handed to the Mutation operator below.
feature_ranges = {
    col: (float(lo), float(hi))
    for col, (lo, hi) in enumerate(zip(X.min(axis=0), X.max(axis=0)))
}
print(f'Iris: {X.shape[0]} samples, {n_features} features, {n_classes} classes')

## 1. Setup and Run NSGA-II Evolution

In [None]:
# Generator for the random trees that make up the initial population.
initializer = TreeInitializer(
    n_features=n_features,
    n_classes=n_classes,
    max_depth=5,
    min_samples_split=10,
    min_samples_leaf=5,
)

# In 'pareto' mode calculate_fitness yields an (accuracy, interpretability)
# tuple; the cells below unpack it that way.
# NOTE(review): the 0.7/0.3 weights are presumably unused in pareto mode — confirm.
fc = FitnessCalculator(
    mode='pareto',
    accuracy_weight=0.7,
    interpretability_weight=0.3,
)

mut = Mutation(n_features, feature_ranges)
# Per-operator weights forwarded to Mutation.mutate (they sum to 1.0).
default_mutation_types = dict(
    threshold_perturbation=0.4,
    feature_replacement=0.3,
    prune_subtree=0.2,
    expand_leaf=0.1,
)

def mutation_fn(tree):
    """Mutate ``tree`` with the notebook-level mutation operator.

    One-argument adapter handed to ParetoOptimizer as ``mutation_fn``; it
    forwards to ``mut.mutate`` using the ``default_mutation_types`` weights
    (both looked up at call time) and returns whatever ``mutate`` returns.
    """
    operator_weights = default_mutation_types
    return mut.mutate(tree, operator_weights)

# Assemble the NSGA-II driver from the pieces above. random_state=42 is given to
# the optimizer only; the initializer and mutation operator receive no seed here.
optimizer = ParetoOptimizer(
    random_state=42,
    initializer=initializer,
    fitness_fn=fc.calculate_fitness,
    mutation_fn=mutation_fn,
    crossover_prob=0.7,
    mutation_prob=0.2,
)
print('ParetoOptimizer created')

In [None]:
# ParetoOptimizer receives random_state=42, but TreeInitializer and Mutation were
# constructed without any seed. In case they draw from the global `random` /
# NumPy RNGs, seed both right before evolving so Restart & Run All (or re-running
# just this cell) reproduces the same front. Harmless if they do not.
random.seed(42)
np.random.seed(42)

pareto_front = optimizer.evolve_pareto_front(
    X, y,
    population_size=30,
    n_generations=20,
    verbose=True,
)
# Every later cell indexes into the front (argmax, knee detection); fail fast
# here with a clear message rather than an opaque error further down.
assert len(pareto_front) > 0, 'NSGA-II returned an empty Pareto front'
print(f'\nPareto front size: {len(pareto_front)} solutions')

## 2. Visualize the Pareto Front

In [None]:
# Score every front member with the pareto-mode fitness, i.e. an
# (accuracy, interpretability) pair per tree. This is the same (X, y) the GA
# evolved on, so these are in-sample accuracies.
scores = [fc.calculate_fitness(tree, X, y) for tree in pareto_front]
accuracies = [acc for acc, _ in scores]
interpretabilities = [interp for _, interp in scores]

fig, ax = plt.subplots(figsize=(9, 6))
scatter = ax.scatter(
    interpretabilities,
    accuracies,
    c=range(len(pareto_front)),
    cmap='viridis',
    s=80,
    zorder=3,
)
ax.set_xlabel('Interpretability Score', fontsize=13)
ax.set_ylabel('Accuracy', fontsize=13)
ax.set_title('Pareto Front: Accuracy vs. Interpretability (Iris)', fontsize=14)
ax.grid(True, alpha=0.3)
fig.colorbar(scatter, ax=ax, label='Solution index')
fig.tight_layout()
plt.show()

## 3. Identify the 'Knee' Solution

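The knee point is the solution with the largest perpendicular distance from the straight
line joining the two extreme points of the Pareto front: the most accurate and the most
interpretable solutions. Distances are measured after normalising both objectives to
[0, 1]. The knee marks where further gains in one objective start to cost
disproportionately in the other, which makes it the best balanced trade-off.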

In [None]:
def find_knee_index(xs, ys):
    """Return the index of the knee point of a 2-D trade-off front.

    Both objectives are treated as *maximised*. Each objective is min-max
    normalised to [0, 1] (a constant objective maps to 0). A chord is drawn
    between the two extremes of the front, the best ``ys`` solution and the best
    ``xs`` solution. The point farthest (perpendicularly) from that chord is
    returned.

    Args:
        xs: Sequence of first-objective values (here: interpretability).
        ys: Sequence of second-objective values (here: accuracy), same length.

    Returns:
        int index into ``xs``/``ys``. If both extremes are the same point (e.g.
        a single-solution front) there is no chord, and that point's index is
        returned.

    Raises:
        ValueError: if the inputs are empty or of different lengths.
    """
    points = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
    if len(points) == 0:
        raise ValueError('cannot locate the knee of an empty Pareto front')

    # Normalise each objective to [0, 1] so neither dominates the distance.
    p_min, p_max = points.min(axis=0), points.max(axis=0)
    p_range = np.where(p_max - p_min > 0, p_max - p_min, 1.0)
    norm = (points - p_min) / p_range

    # Chord between the front's extremes: most accurate -> most interpretable.
    # (Anchoring on the *least* accurate point would normally select the same
    # point as the most interpretable one, giving a zero-length chord, 0/0 NaN
    # distances and an arbitrary knee at index 0.)
    i_best_y = int(np.argmax(norm[:, 1]))
    i_best_x = int(np.argmax(norm[:, 0]))
    anchor = norm[i_best_y]
    line_vec = norm[i_best_x] - anchor
    line_len = np.linalg.norm(line_vec)
    if line_len == 0:
        return i_best_y

    # Perpendicular distance in 2-D: |line_vec x (p - anchor)| / |line_vec|.
    rel = norm - anchor
    distances = np.abs(line_vec[0] * rel[:, 1] - line_vec[1] * rel[:, 0]) / line_len
    return int(np.argmax(distances))


knee_idx = find_knee_index(interpretabilities, accuracies)
knee_tree = pareto_front[knee_idx]
knee_acc = accuracies[knee_idx]
knee_interp = interpretabilities[knee_idx]

# Summarise the selected knee solution (one print call, same lines as before).
print(
    f'Knee solution (index {knee_idx}):\n'
    f'  Accuracy:          {knee_acc:.4f}\n'
    f'  Interpretability:  {knee_interp:.4f}\n'
    f'  Nodes:             {knee_tree.get_num_nodes()}\n'
    f'  Depth:             {knee_tree.get_depth()}'
)

# Re-draw the front and mark the knee with a large red star drawn on top.
fig2, ax2 = plt.subplots(figsize=(9, 6))
ax2.scatter(interpretabilities, accuracies, s=70, label='Pareto solutions', zorder=2)
ax2.scatter(
    [knee_interp], [knee_acc],
    s=200, marker='*', color='red', zorder=3,
    label=f'Knee (idx={knee_idx})',
)
ax2.set_xlabel('Interpretability Score', fontsize=13)
ax2.set_ylabel('Accuracy', fontsize=13)
ax2.set_title('Pareto Front with Knee Solution Highlighted', fontsize=14)
ax2.grid(True, alpha=0.3)
ax2.legend(fontsize=12)
fig2.tight_layout()
plt.show()

## 4. Inspect Decision Rules

Print the interpretable rules extracted from three representative solutions:
- **Most accurate** solution
- **Knee** (best balance)
- **Most interpretable** solution

In [None]:
def print_rules(tree, label, acc, interp, max_rules=10):
    """Print a summary banner and (a prefix of) the decision rules of ``tree``.

    Args:
        tree: A GA tree exposing ``get_num_nodes()``, ``get_depth()`` and
            ``to_rules()`` (an iterable of printable rules).
        label: Heading shown in the ``=== label ===`` banner.
        acc: Accuracy value to display (4 decimals).
        interp: Interpretability score to display (4 decimals).
        max_rules: Maximum number of rules to print (default 10). ``None``
            prints every rule. Any omitted rules are summarised as
            ``... (N more rules)``.
    """
    print(f'=== {label} ===')
    print(f'    Accuracy: {acc:.4f}  |  Interpretability: {interp:.4f}')
    print(f'    Nodes: {tree.get_num_nodes()}  |  Depth: {tree.get_depth()}')
    # Materialise so generators/tuples/arrays all support len(), slicing and truthiness.
    rules = list(tree.to_rules())
    if not rules:
        print('    (no rules — leaf-only tree)')
        print()
        return

    shown = rules if max_rules is None else rules[:max_rules]
    for rule in shown:
        print(f'    {rule}')
    hidden = len(rules) - len(shown)
    if hidden > 0:
        print(f'    ... ({hidden} more rules)')
    print()

# Compare the two extremes of the front with the knee solution.
best_acc_idx = int(np.argmax(accuracies))
best_interp_idx = int(np.argmax(interpretabilities))

for heading, chosen, chosen_acc, chosen_interp in (
    ('Most Accurate', pareto_front[best_acc_idx],
     accuracies[best_acc_idx], interpretabilities[best_acc_idx]),
    ('Knee (Best Balance)', knee_tree, knee_acc, knee_interp),
    ('Most Interpretable', pareto_front[best_interp_idx],
     accuracies[best_interp_idx], interpretabilities[best_interp_idx]),
):
    print_rules(chosen, heading, chosen_acc, chosen_interp)

## Summary

The Pareto front reveals the inherent trade-off between accuracy and interpretability:

- **Most accurate** trees tend to be deeper and use more nodes.
- **Most interpretable** trees use fewer nodes but sacrifice some accuracy.
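- The **knee solution** provides a principled way to select a balanced model without
  manually tuning the accuracy/interpretability weight.
- **Caveat:** all accuracies above are measured on the same Iris data used during
  evolution, so they are training accuracies. Evaluate the selected trees on a
  held-out split before drawing conclusions about how well they generalise.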