## Look for specific entity

In [44]:
import glob
import json
from tqdm.auto import tqdm

file_pattern = './trees/*.json'
search_term = '(Q5)'  # the string to search for

results = []

# Loop through all JSON files matching the pattern
for file_path in glob.glob(file_pattern):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        
        # Get the nodes list (if it exists)
        nodes = data.get('nodes', [])
        
        # Count how many node labels contain the search term
        count = sum(1 for node in nodes if search_term in node.get('label', ''))
        
        results.append((file_path, count))
    except Exception as e:
        print(f"Error processing {file_path}: {e}")

# Sort results in descending order of the counts
results.sort(key=lambda x: x[1], reverse=True)

# Print out the filename and the count in sorted order
for file_path, count in tqdm(results):
    print(f"{file_path}: {count}")


100%|██████████| 36/36 [00:00<00:00, 291496.03it/s]

./trees/model_size_small_num_classes_100_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_14_max_width_128_temperature_0.75_width_decay_factor_0.75.json: 124
./trees/model_size_small_num_classes_100_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_15_max_width_128_temperature_0.75_width_decay_factor_0.76.json: 76
./trees/model_size_small_num_classes_100_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_13_max_width_128_temperature_0.75_width_decay_factor_0.73.json: 28
./trees/model_size_small_num_classes_100_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_12_max_width_128_temperature_0.75_width_decay_factor_0.71.json: 16
./trees/model_size_small_num_classes_100_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_11_max_width_128_temperature_0.75_width_decay_factor_0.69.json: 15
./trees/model_size_small_num_classes_100_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_10_max_width_128_temperature_0.75_width_decay_factor




## Sort by entropy

In [1]:
import glob
import json
import math
from tqdm.auto import tqdm
from collections import Counter

def calculate_entropy(hist, normalize=False):
    """Calculate Shannon entropy given a histogram (Counter) of counts.
    
    Args:
        hist (Counter): A histogram of counts (label frequencies).
        normalize (bool): If True, normalize the entropy by the max possible entropy.

    Returns:
        float: Computed entropy (normalized if specified).
    """
    total = sum(hist.values())
    if total == 0:
        return 0
    
    entropy = 0
    for count in hist.values():
        if count > 0:
            p = count / total
            entropy -= p * math.log2(p)

    if normalize:
        num_categories = len(hist)
        if num_categories > 1:
            entropy /= math.log2(num_categories)  # Normalize by max entropy
        else:
            entropy = 0  # If there's only one category, normalized entropy is 0
    
    return entropy

file_pattern = './trees/*.json'
results = []

# Process each JSON file matching the pattern
for file_path in glob.glob(file_pattern):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        
        nodes = data.get('nodes', [])
        # Extract the node labels
        labels = [node.get('label', '') for node in nodes]
        
        # Build a histogram of node labels
        hist = Counter(labels)
        
        # Calculate the Shannon entropy (both raw and normalized)
        entropy = calculate_entropy(hist)
        norm_entropy = calculate_entropy(hist, normalize=True)
        
        results.append((file_path, entropy, norm_entropy))
    except Exception as e:
        print(f"Error processing {file_path}: {e}")

# Sort the results by normalized entropy (descending order)
results.sort(key=lambda x: x[1], reverse=True)

# Print out the filename and the computed entropy
for file_path, entropy, norm_entropy in tqdm(results):
    print(f"{file_path}: Entropy = {entropy:.4f}, Normalized Entropy = {norm_entropy:.4f}")


  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 38/38 [00:00<00:00, 260430.64it/s]

./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_12_max_width_128_temperature_0.75_width_decay_factor_0.71.json: Entropy = 11.2245, Normalized Entropy = 0.9065
./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_11_max_width_128_temperature_0.75_width_decay_factor_0.69.json: Entropy = 11.2107, Normalized Entropy = 0.9201
./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_10_max_width_128_temperature_0.75_width_decay_factor_0.66.json: Entropy = 11.0419, Normalized Entropy = 0.9420
./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_13_max_width_128_temperature_0.75_width_decay_factor_0.73.json: Entropy = 10.9718, Normalized Entropy = 0.8697
./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_14_max_width_128_temperature_0.75_wi




## Sort by average children

In [2]:
import glob
import json
from tqdm.auto import tqdm
from collections import defaultdict

file_pattern = './trees/*.json'
results = []

# Process each JSON file matching the pattern
for file_path in glob.glob(file_pattern):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        
        # Count children for each parent node based on the edges
        edges = data.get('edges', [])
        children_counts = defaultdict(int)
        for edge in edges:
            parent = edge.get('source')
            if parent is not None:
                children_counts[parent] += 1
        
        # Compute the average number of children per parent node
        if children_counts:
            avg_children = sum(children_counts.values()) / len(children_counts)
        else:
            avg_children = 0
        
        results.append((file_path, avg_children))
    except Exception as e:
        print(f"Error processing {file_path}: {e}")

# Sort the results by average number of children (descending order)
results.sort(key=lambda x: x[1], reverse=True)

# Print out the filename and the average number of children
for file_path, avg_children in tqdm(results):
    print(f"{file_path}: Average children = {avg_children:.4f}")


100%|██████████| 37/37 [00:00<00:00, 287387.50it/s]

./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_6_max_width_128_temperature_0.75_width_decay_factor_0.5.json: Average children = 2.0376
./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_5_max_width_128_temperature_0.75_width_decay_factor_0.44.json: Average children = 1.9926
./trees/model_size_small_num_classes_1000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_6_max_width_128_temperature_0.75_width_decay_factor_0.5.json: Average children = 1.9368
./trees/model_size_small_num_classes_1000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_4_max_width_128_temperature_0.75_width_decay_factor_0.36.json: Average children = 1.9362
./trees/model_size_small_num_classes_10000_allowed_threshold_0.5_loss_threshold_0.1_top_p_0.9_max_depth_4_max_width_128_temperature_0.75_width_decay_factor_0.36.json: Average children = 1.9091
./trees/model_size_small_num_classes_10


