# Visualization Notebook

In [5]:
import sys
from pathlib import Path

# Make the local dt-distance sources importable. The path is inserted at the
# front of sys.path so it is searched before any installed copy, and only once
# even if this cell is re-run. The relative path is resolved against the current
# working directory (presumably the notebook's own directory — TODO confirm).
src_path = Path("../src/dt-distance").resolve()
if str(src_path) not in sys.path:
    sys.path.insert(0, str(src_path))
    
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score
from dt_distance.distance_calculator import DistanceCalculator


## Pareto Frontier Visualization
- Plot the Pareto frontier of a collection of trees $\mathcal{T}$ with respect to each tree's average distance $d_{b}$ and its out-of-sample ROC AUC score $a_{b}$, $\forall b \in \mathcal{T}$.

In [None]:
def plot_pareto_frontier(distances, auc_scores, pareto_indices, title="Pareto Frontier"):
    """Scatter-plot trees by average distance vs. out-of-sample AUC, highlighting the Pareto set.

    Parameters
    ----------
    distances : array-like of float
        Average distance d_b of each tree b (lower means a more stable tree).
    auc_scores : array-like of float
        Out-of-sample ROC AUC a_b of each tree, aligned element-wise with
        ``distances``.
    pareto_indices : iterable of int
        Positions (into ``distances`` / ``auc_scores``) of the Pareto-optimal
        trees. Indices outside ``[0, len(distances))`` are ignored.
    title : str, optional
        Figure title. Defaults to ``"Pareto Frontier"``.

    Raises
    ------
    ValueError
        If ``distances`` and ``auc_scores`` have different lengths.
    """
    distances = np.asarray(distances, dtype=float)
    auc_scores = np.asarray(auc_scores, dtype=float)
    if len(distances) != len(auc_scores):
        raise ValueError(
            "distances and auc_scores must have the same length, "
            f"got {len(distances)} and {len(auc_scores)}"
        )

    n_trees = len(distances)
    # Build the mask with an explicit bool dtype: for empty input np.array([])
    # defaults to float64, and `~is_pareto` would then raise TypeError.
    is_pareto = np.zeros(n_trees, dtype=bool)
    valid = [i for i in set(pareto_indices) if 0 <= i < n_trees]
    is_pareto[np.asarray(valid, dtype=int)] = True

    # Plotting
    plt.figure(figsize=(8, 6))
    plt.scatter(distances[~is_pareto], auc_scores[~is_pareto], c='blue', label='Dominated Trees', alpha=0.6)
    plt.scatter(distances[is_pareto], auc_scores[is_pareto], c='red', edgecolors='black', s=80, label='Pareto Optimal Trees')
    # The x-axis is the average distance d_b: smaller distance = more stable tree.
    plt.xlabel("Average Distance (Lower is Better)")
    plt.ylabel("AUC (Higher is Better)")
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


## Tree Complexity: Mean and Standard Deviation of Depth and Node Count

In [None]:

def _scatter_with_mean_band(ax, values, title, ylabel):
    """Scatter ``values`` against their index on ``ax`` and overlay mean ± 1 std.

    The mean is drawn as a dashed horizontal line and the population standard
    deviation (``np.std``, ddof=0) as a shaded band around it. The overlay is
    skipped for empty ``values`` because the mean is undefined there.
    """
    values = np.asarray(values, dtype=float)
    ax.scatter(np.arange(len(values)), values, color='blue', alpha=0.7, label='Per-tree value')
    if values.size > 0:
        mean = values.mean()
        std = values.std()
        ax.axhline(mean, color='red', linestyle='--', label=f'Mean = {mean:.2f}')
        ax.axhspan(mean - std, mean + std, color='red', alpha=0.15, label=f'±1 Std = {std:.2f}')
        ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Tree Index")
    ax.set_ylabel(ylabel)
    ax.grid(True)


def plot_tree_complexity_metrics(
    trees,
    title_left="Mean and Standard Deviation of Tree Depth",
    title_right="Mean and Standard Deviation of Node Count",
):
    """Plot per-tree depth and node count side by side, with mean ± 1 std overlays.

    Parameters
    ----------
    trees : iterable
        Fitted tree estimators. Each must provide ``get_depth()`` and
        ``tree_.node_count`` (as scikit-learn ``DecisionTree*`` estimators do;
        presumably those are what callers pass — confirm).
    title_left : str, optional
        Title of the tree-depth panel (left).
    title_right : str, optional
        Title of the node-count panel (right).

    Returns
    -------
    dict
        ``{"depths": [...], "node_counts": [...]}``, one entry per tree, in the
        order the trees were given.
    """
    # Materialize once so generators can be passed; the input is read twice below.
    trees = list(trees)
    depths = [tree.get_depth() for tree in trees]
    node_counts = [tree.tree_.node_count for tree in trees]

    fig, (ax_depth, ax_nodes) = plt.subplots(1, 2, figsize=(12, 5))
    _scatter_with_mean_band(ax_depth, depths, title_left, "Tree Depth")
    _scatter_with_mean_band(ax_nodes, node_counts, title_right, "Number of Nodes")

    fig.suptitle("Tree Complexity Analysis")
    fig.tight_layout()
    plt.show()

    return {
        "depths": depths,
        "node_counts": node_counts,
    }