In [1]:
from __future__ import annotations
import json
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from tqdm.auto import tqdm
import webbrowser
import networkx as nx
from pyvis.network import Network
import concurrent.futures
import contextlib
import random
import os

# Additional imports for Dash Cytoscape visualization.
import dash
from jupyter_dash import JupyterDash
import dash_cytoscape as cyto
import dash_html_components as html
import threading

from IPython.display import clear_output

# Register dash-cytoscape's optional layout extensions once, at import time.
# visualize_graph_dash() below requests the "dagre" layout; presumably that is
# one of these extra layouts rather than a built-in one (confirm against the
# dash-cytoscape layout documentation).
cyto.load_extra_layouts()

def load_taxonomy_json(filename: str = "taxonomy.json") -> nx.DiGraph:
    """
    Read a node-link JSON document from disk and rebuild the taxonomy graph.

    The file is decoded as UTF-8, parsed as JSON, and handed to
    ``nx.node_link_graph`` with the edge list expected under the ``"edges"``
    key of the document.

    Args:
        filename (str, optional): Path of the JSON file to read.
            Defaults to "taxonomy.json".

    Returns:
        networkx.DiGraph: The graph rebuilt by ``nx.node_link_graph`` from the
            parsed node-link data.
    """
    with open(filename, mode="r", encoding="utf-8") as handle:
        payload = json.loads(handle.read())
    return nx.node_link_graph(payload, edges="edges")



def visualize_graph_dash(G: nx.DiGraph, host="127.0.0.1", port=8050) -> None:
    """
    Serve ``G`` as a Dash Cytoscape page and open it in a web browser.

    Every node becomes a Cytoscape element whose label is the node's
    ``"label"`` attribute (falling back to the stringified node id); every
    edge becomes a source/target element. The graph is drawn with the
    "dagre" layout, oriented left to right.

    Args:
        G: Graph to display. Only ``G.nodes(data=True)`` and ``G.edges()``
            are read.
        host: Host name the Dash server is started on and the browser is
            pointed at.
        port: Port the Dash server is started on and the browser is
            pointed at.

    Returns:
        None. Side effects: schedules a browser tab to open after one second,
        clears the current notebook cell output, and starts the Dash server
        in "external" mode with debug enabled.
    """
    # Cytoscape elements: nodes first, then edges, with ids coerced to str.
    node_elements = [
        {"data": {"id": str(node_id), "label": attrs.get("label", str(node_id))}}
        for node_id, attrs in G.nodes(data=True)
    ]
    edge_elements = [
        {"data": {"source": str(src), "target": str(dst)}}
        for src, dst in G.edges()
    ]

    dagre_layout = {
        "name": "dagre",
        "rankDir": "LR",  # orient the tree from left to right
        "nodeSep": 30,
        "edgeSep": 10,
        "rankSep": 70,
        "padding": 10,
    }
    node_style = {
        "label": "data(label)",
        "background-color": "#87CEFA",  # light sky blue
        "color": "#000",  # black font color
        "font-size": "14px",
        "text-valign": "bottom",
        "text-halign": "center",
        "text-margin-y": "10px",
        "width": "40px",  # bigger node size
        "height": "40px",
        "text-wrap": "wrap",  # allow text to wrap
        "text-max-width": "200px",  # increase maximum text width
    }
    edge_style = {"line-color": "#B3B3B3"}

    app = JupyterDash(__name__)
    graph_view = cyto.Cytoscape(
        id="cytoscape-graph",
        elements=node_elements + edge_elements,
        layout=dagre_layout,
        style={"width": "100%", "height": "100vh"},
        stylesheet=[
            {"selector": "node", "style": node_style},
            {"selector": "edge", "style": edge_style},
        ],
    )
    app.layout = html.Div([graph_view])

    # Open the page shortly after the server has been asked to start.
    url = f"http://{host}:{port}"
    threading.Timer(1, lambda: webbrowser.open(url)).start()

    # Clear the inline output in the notebook.
    clear_output(wait=True)

    # Drop any server thread already registered for this (host, port); the
    # original author notes this prevents thread-kill errors on re-run.
    server_key = (host, port)
    if server_key in app._server_threads:
        del app._server_threads[server_key]

    app.run_server(mode="external", debug=True, host=host, port=port)


  from .autonotebook import tqdm as notebook_tqdm
The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html


In [6]:
###############################################
# Main Execution: Visualize the Hierarchical Tree
###############################################

# Configuration parameters.
# In this cell these values are only used to build the name of an existing
# JSON file under trees/ (plus the port for the Dash server); nothing is
# generated here. They presumably mirror the settings used when the tree was
# produced -- TODO confirm against the generation code.
num_classes = 10000
top_p = 0.9
max_depth = 6  # For debugging, try a shallow tree first.
max_width = 16  # Maximum children per node.
max_attempts = 32
temperature = 1.5
port = 8050  # Port handed to visualize_graph_dash for the Dash server.

# Path stem (without extension) that encodes every parameter above except
# the port; ".json" is appended when the file is loaded below.
filename = (
    f"trees/num_classes_{num_classes}_top_p_{top_p}_max_depth_{max_depth}"
    f"_max_width_{max_width}_max_attempts_{max_attempts}"
    f"_temperature_{temperature}"
)

# Load the taxonomy from disk and serve it on the local Dash server.
G = load_taxonomy_json(filename=f"{filename}.json")
visualize_graph_dash(G, host="127.0.0.1", port=port)

Dash app running on http://127.0.0.1:8050/


In [9]:
import glob
import json

def count_label_matches(file_pattern, search_term):
    """
    Count, for each JSON file matching a glob pattern, the node labels that
    contain a search term.

    Each file is expected to be a JSON object with a ``"nodes"`` list whose
    entries are objects with an optional ``"label"`` string. Labels that are
    missing, ``null`` or not strings never match; they do not cause the file
    to be skipped. A missing or ``null`` ``"nodes"`` entry counts as empty.

    Files that cannot be read or parsed are reported on stdout and left out
    of the result (best effort, so one bad file does not abort the scan).

    Args:
        file_pattern (str): Glob pattern selecting the JSON files.
        search_term (str): Substring to look for in each node label.

    Returns:
        list[tuple[str, int]]: ``(file_path, count)`` pairs ordered by count
        descending, ties broken by path ascending so the order is the same
        on every run and every filesystem.
    """
    matches = []

    # glob.glob gives no ordering guarantee; sort for reproducible output.
    for file_path in sorted(glob.glob(file_pattern)):
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                data = json.load(file)

            # Treat an absent or null node list as empty.
            nodes = data.get('nodes') or []

            # Only string labels can contain the term; anything else is a
            # non-match instead of an error that would drop the whole file.
            count = sum(
                1
                for node in nodes
                if isinstance(node.get('label'), str)
                and search_term in node['label']
            )
        except Exception as e:
            print(f"Error processing {file_path}: {e}")
            continue

        matches.append((file_path, count))

    # Highest count first; equal counts fall back to path order.
    matches.sort(key=lambda item: (-item[1], item[0]))
    return matches


file_pattern = './trees/*.json'
search_term = '(Q5)'  # the string to search for

# (file_path, count) pairs, highest count first.
results = count_label_matches(file_pattern, search_term)

# Print out the filename and the count in sorted order
for file_path, count in results:
    print(f"{file_path}: {count}")


./trees/num_classes_10_top_p_0.9_max_depth_8_max_width_16_max_attempts_32_temperature_1.5.json: 5
./trees/num_classes_10000_top_p_0.9_max_depth_5_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_1000_top_p_0.9_max_depth_5_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_10000_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_10_top_p_0.9_max_depth_4_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_100_top_p_0.9_max_depth_4_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_10_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_1000_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_100_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/num_classes_10_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: 0
./trees/

In [7]:
import glob
import json
import math
from collections import Counter

def calculate_entropy(hist):
    """
    Return the Shannon entropy (base 2) of a histogram of counts.

    Args:
        hist: Mapping (e.g. a Counter) from item to count. Only the values
            are read.

    Returns:
        0 when the counts sum to zero; otherwise the entropy accumulated
        over the entries whose count is positive, each weighted by its
        share of the total.
    """
    grand_total = sum(hist.values())
    if grand_total == 0:
        return 0

    bits = 0
    for occurrences in hist.values():
        # Non-positive counts contribute nothing (and log2 would reject them).
        if not occurrences > 0:
            continue
        share = occurrences / grand_total
        bits = bits - share * math.log2(share)
    return bits

file_pattern = './trees/*.json'
results = []

# Score every JSON tree matched by the pattern; a file that fails to load or
# parse is reported and skipped.
for file_path in glob.glob(file_pattern):
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            tree = json.load(handle)

        # Histogram of node labels; a node without a label counts under ''.
        label_hist = Counter(
            entry.get('label', '') for entry in tree.get('nodes', [])
        )

        results.append((file_path, calculate_entropy(label_hist)))
    except Exception as e:
        print(f"Error processing {file_path}: {e}")

# Highest entropy first.
results.sort(reverse=True, key=lambda row: row[1])

# One line per file: path and entropy to four decimals.
for file_path, entropy in results:
    print(f"{file_path}: Entropy = {entropy:.4f}")


./trees/num_classes_1000_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 9.2372
./trees/num_classes_10000_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 8.6960
./trees/num_classes_100_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 8.5158
./trees/num_classes_1000_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 7.9472
./trees/num_classes_10_top_p_0.9_max_depth_8_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 7.4928
./trees/num_classes_100_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 7.4777
./trees/num_classes_10_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 7.3393
./trees/num_classes_10000_top_p_0.9_max_depth_5_max_width_16_max_attempts_32_temperature_1.5.json: Entropy = 7.2751
./trees/num_classes_1000_top_p_0.9_max_depth_5_max_width_16_max_attempts_32_temperat

In [8]:
import glob
import json
from collections import defaultdict

file_pattern = './trees/*.json'
results = []

# Walk every JSON tree matched by the pattern; a file that fails to load or
# parse is reported and skipped.
for file_path in glob.glob(file_pattern):
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            tree = json.load(handle)

        # Tally outgoing edges per source node; edges without a source are
        # ignored.
        fanout = defaultdict(int)
        for link in tree.get('edges', []):
            origin = link.get('source')
            if origin is None:
                continue
            fanout[origin] += 1

        # Mean over nodes that have at least one child; 0 when there are none.
        avg_children = sum(fanout.values()) / len(fanout) if fanout else 0

        results.append((file_path, avg_children))
    except Exception as e:
        print(f"Error processing {file_path}: {e}")

# Largest average first.
results.sort(reverse=True, key=lambda row: row[1])

# One line per file: path and average to four decimals.
for file_path, avg_children in results:
    print(f"{file_path}: Average children = {avg_children:.4f}")


./trees/num_classes_10_top_p_0.9_max_depth_8_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 4.4373
./trees/num_classes_10000_top_p_0.9_max_depth_4_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 4.0455
./trees/num_classes_1000_top_p_0.9_max_depth_4_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 3.5833
./trees/num_classes_1000_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 3.5365
./trees/num_classes_100_top_p_0.9_max_depth_4_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 3.5294
./trees/num_classes_10000_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 3.5062
./trees/num_classes_1000_top_p_0.9_max_depth_6_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 3.0243
./trees/num_classes_100_top_p_0.9_max_depth_7_max_width_16_max_attempts_32_temperature_1.5.json: Average children = 2.9214
./trees/nu