In [1]:
import os
import json
import re
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict

folder = "./misclassification-details" 

output_folder = "./output"
os.makedirs(output_folder, exist_ok=True)

In [6]:


def draw_decision_tree(json_path, csv_path, output_path, data_name=None, fold=None, depth=None):
    with open(json_path) as f:
        raw_tree = json.load(f)

    df = pd.read_csv(csv_path)
    feature_cols = [str(c) for c in df.columns if c not in ("y", "prediction")]

    tree = {}
    for node_id, feature_id in raw_tree:
        tree[node_id] = {
            "feature": str(feature_id + 1), ### for new
            "left": 2 * node_id,
            "right": 2 * node_id + 1
        }

    leaf_stats = defaultdict(lambda: {"total": 0, "misclass": 0})

    def find_leaf(row):
        node = 1
        while node in tree:
            feature = tree[node]["feature"]
            node = tree[node]["right"] if bool(row[feature]) else tree[node]["left"]
        return node

    for _, row in df.iterrows():
        leaf = find_leaf(row)
        leaf_stats[leaf]["total"] += 1
        if row["y"] != row["prediction"]:
            leaf_stats[leaf]["misclass"] += 1

    node_texts = {}
    for node_id, data in tree.items():
        node_texts[node_id] = f"$Node_{{{node_id}}}$\n$x_{{{data['feature']}}}=1$"
    for leaf, stats in leaf_stats.items():
        node_texts[leaf] = f"$Node_{{{leaf}}}$\nMisclass: {stats['misclass']}\nTotal: {stats['total']}"

    position_map = {}
    x_counter = [0]

    def layout_recursive(node, depth=0):
        if node not in tree:
            x = x_counter[0]
            position_map[node] = (x, -depth)
            x_counter[0] += 1
            return x
        left_x = layout_recursive(tree[node]["left"], depth + 1)
        right_x = layout_recursive(tree[node]["right"], depth + 1)
        x = (left_x + right_x) / 2
        position_map[node] = (x, -depth)
        return x

    layout_recursive(1)

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.axis("off")

    def draw_node(text, x, y, parent_coords=None, color="lightblue"):
        ax.text(
            x, y, text, ha="center", va="center",
            bbox=dict(facecolor=color, edgecolor='black', boxstyle='round,pad=0.5'),
            fontsize=9
        )
        if parent_coords:
            ax.plot([parent_coords[0], x], [parent_coords[1], y], color="black")

    for node, (x, y) in position_map.items():
        parent_id = node // 2 if node != 1 else None
        parent_coords = position_map.get(parent_id)
        is_leaf = node not in tree
        text = node_texts.get(node, f"$Node_{{{node}}}$")
        color = "lightgreen" if is_leaf else "lightblue"
        draw_node(text, x, y, parent_coords, color)

    total_misclass = sum(s["misclass"] for s in leaf_stats.values())
    total_rows = sum(s["total"] for s in leaf_stats.values())
    accuracy = 1 - (total_misclass / total_rows) if total_rows > 0 else 0

    meta_text = f"Total Misclassification: {total_misclass}\nAccuracy: {accuracy:.2%}"
    if data_name or fold or depth:
        meta_text += f"\n\nDataset: {data_name.replace('mcart','rstg') or ''}\nFold: {fold or ''}\nDepth: {depth or ''}"

    ax.text(
        0.01, 1.02,  # Top-left corner
        meta_text,
        ha="left", va="top", transform=ax.transAxes,
        fontsize=11,
        bbox=dict(facecolor="white", edgecolor="black", boxstyle="round,pad=0.5")
    )

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

# --- Loop Over Files ---
for filename in os.listdir(folder):
    if filename.endswith(".json"):
        base_name = filename.replace(".json", "")
        json_path = os.path.join(folder, base_name + ".json")
        csv_path = os.path.join(folder, base_name + ".csv")
        output_path = os.path.join(output_folder, base_name + ".png")

        match = re.match(r"(.+?)_fold=(\d+)_depth=(\d+)", base_name)
        if match:
            data_name, fold, depth = match.groups()
        else:
            data_name = fold = depth = None
            print("Skipping ", data_name)
        if os.path.exists(csv_path):
            print(f"Processing {base_name}...")
            draw_decision_tree(json_path, csv_path, output_path, data_name, fold, depth)


Processing wdbc_test_mcart_fold=3_depth=6...


  plt.tight_layout()
  plt.tight_layout()


Processing wdbc_test_mcart_fold=1_depth=3...
Processing wine_test_mcart_fold=10_depth=5...


  plt.tight_layout()
  plt.tight_layout()


Processing wdbc_test_mcart_fold=6_depth=3...
Processing wine_test_mcart_fold=3_depth=7...


  plt.tight_layout()


Processing wine_test_mcart_fold=6_depth=8...


  plt.tight_layout()


Processing wdbc_test_mcart_fold=10_depth=3...
Processing wine_test_mcart_fold=2_depth=6...


  plt.tight_layout()
  plt.tight_layout()


Processing wine_test_mcart_fold=5_depth=6...


  plt.tight_layout()


Processing wine_test_mcart_fold=5_depth=7...


  plt.tight_layout()


Processing wine_test_mcart_fold=2_depth=7...


  plt.tight_layout()
  plt.tight_layout()


Processing wdbc_test_mcart_fold=7_depth=3...
Processing wine_test_mcart_fold=10_depth=8...


  plt.tight_layout()


Processing wine_test_mcart_fold=3_depth=6...


  plt.tight_layout()


Processing titanic_test_mcart_fold=4_depth=4...


  plt.tight_layout()


Processing wine_test_mcart_fold=5_depth=8...


  plt.tight_layout()


Processing wine_test_mcart_fold=2_depth=8...


  plt.tight_layout()


Processing wine_test_mcart_fold=3_depth=5...
Processing wine_test_mcart_fold=10_depth=7...


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


Processing wine_test_mcart_fold=5_depth=4...
Processing wine_test_mcart_fold=6_depth=6...


  plt.tight_layout()
  plt.tight_layout()


Processing wdbc_test_mcart_fold=3_depth=3...
Processing wine_test_mcart_fold=3_depth=8...


  plt.tight_layout()


Processing wine_test_mcart_fold=6_depth=7...


  plt.tight_layout()


Processing wine_test_mcart_fold=2_depth=5...
Processing wine_test_mcart_fold=5_depth=5...


  plt.tight_layout()
  plt.tight_layout()


Processing wine_test_mcart_fold=10_depth=6...
Processing wdbc_test_mcart_fold=3_depth=4...


  plt.tight_layout()
  plt.tight_layout()


Processing titanic_test_mcart_fold=1_depth=3...
Processing wdbc_test_mcart_fold=3_depth=5...


  plt.tight_layout()
  plt.tight_layout()
