# Visualizing Top 3 Best and Worst Auto-Generated Maps

This notebook loads the top 3 best and worst maps (by improvement of RepeatedTopK over Shortest Path) from the no-pruning benchmark run and visualizes them.

**Legend:**
- **Node colors**: Height-based terrain colormap (green=low/corridor, brown=high/plateau), limegreen=source, cyan=target
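- **Edge colors**: Red = chokepoint edges, blue = blob entry edges, gray = normal edges; the shortest path is overlaid as a dashed green line
- **Blob entry points**: Edges connecting plateau (blob) nodes to corridor (non-blob) nodes (drawn in blue unless they are also chokepoints)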

In [None]:
import pickle
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as mcolors

# Read the saved benchmark results.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# open a 'top_bottom_maps.pkl' that was produced by a trusted run.
with open('top_bottom_maps.pkl', 'rb') as pkl_file:
    data = pickle.load(pkl_file)

# 'top3' and 'bottom3' are lists of (map_data, improvement_pct) pairs.
top3, bottom3, all_improvements = (
    data['top3'],
    data['bottom3'],
    data['all_improvements'],
)

# Summary of what was loaded, one statistic per line.
print(
    f"Loaded {len(top3)} best and {len(bottom3)} worst maps",
    f"Total maps benchmarked: {len(all_improvements)}",
    f"Mean improvement: {np.mean(all_improvements):.2f}%",
    f"Median improvement: {np.median(all_improvements):.2f}%",
    sep="\n",
)

## Distribution of Improvements Across All 50 Maps

In [None]:
# Histogram of the per-map improvement with two reference lines:
# zero (break-even) and the mean over all maps.
mean_improvement = np.mean(all_improvements)

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(all_improvements, bins=20, color='steelblue', alpha=0.7, edgecolor='black')
ax.axvline(0, color='red', linestyle='--', linewidth=1.5, label='Break-even')
ax.axvline(
    mean_improvement,
    color='orange',
    linestyle='-',
    linewidth=1.5,
    label=f'Mean={mean_improvement:.1f}%',
)
ax.set_xlabel('Improvement (%)')
ax.set_ylabel('Count')
ax.set_title('Distribution of RepeatedTopK Improvement over Shortest Path')
ax.legend()
fig.tight_layout()
plt.show()

## Helper: Draw a single map

In [None]:
def _edge_key(u, v):
    """Return an orientation-independent key for the edge joining ``u`` and ``v``."""
    return tuple(sorted((u, v)))


def draw_map(map_data, improvement_pct, ax=None):
    """
    Draw a single map onto a matplotlib axes.

    The rendering shows:
    - Terrain-colored nodes (based on the 'height' node attribute)
    - Source in limegreen, target in cyan
    - Chokepoint edges in RED
    - Blob boundary edges (entry points) in blue
    - All remaining edges in light gray
    - Shortest source->target path (by 'distance') as a dashed green overlay

    Parameters
    ----------
    map_data : dict
        Must contain 'env_graph' (a networkx graph whose nodes carry a 'pos'
        attribute), 'source', 'target', 'chokepoints' (iterable of (u, v)
        pairs), 'blobs' (iterable of node collections) and 'grid_size'.
        'seed' and 'label' are optional and default to '?'.
    improvement_pct : float
        Improvement percentage shown in the title.
    ax : matplotlib axes, optional
        Axes to draw on. When omitted, a new 8x8-inch figure is created.

    Returns
    -------
    The axes the map was drawn on.
    """
    g = map_data['env_graph']
    source = map_data['source']
    target = map_data['target']
    chokepoints = map_data['chokepoints']
    blobs = map_data['blobs']
    grid_size = map_data['grid_size']
    seed = map_data.get('seed', '?')
    label = map_data.get('label', '?')

    # Every node that belongs to at least one blob.
    all_blob_nodes = set().union(*blobs)

    # Chokepoints as orientation-independent keys, so (u, v) matches (v, u).
    cp_set = {_edge_key(u, v) for u, v in chokepoints}

    # Entry points: edges with exactly one endpoint inside a blob.
    entry_edges = {
        _edge_key(u, v)
        for u, v in g.edges()
        if (u in all_blob_nodes) != (v in all_blob_nodes)
    }

    # Positions
    pos = nx.get_node_attributes(g, 'pos')

    # Node colors: height-based, with fixed colors for source/target.
    # Nodes without a 'height' attribute are treated as height 0.
    heights = [attrs.get('height', 0) for _, attrs in g.nodes(data=True)]
    max_height = max(heights) if heights else 1
    norm = mcolors.Normalize(vmin=0, vmax=max_height + 1)
    cmap = plt.cm.terrain

    node_colors = []
    node_sizes = []
    for node, attrs in g.nodes(data=True):
        if node == source:
            node_colors.append('limegreen')
            node_sizes.append(120)
        elif node == target:
            node_colors.append('cyan')
            node_sizes.append(120)
        else:
            node_colors.append(cmap(norm(attrs.get('height', 0))))
            node_sizes.append(60)

    # Edge styling; a chokepoint that is also an entry edge is drawn as a chokepoint.
    edge_colors = []
    edge_widths = []
    for u, v in g.edges():
        key = _edge_key(u, v)
        if key in cp_set:
            edge_colors.append('red')
            edge_widths.append(2.5)
        elif key in entry_edges:
            edge_colors.append('dodgerblue')
            edge_widths.append(2.0)
        else:
            edge_colors.append('lightgray')
            edge_widths.append(0.8)

    # Shortest path by 'distance'.
    try:
        sp = nx.shortest_path(g, source, target, weight='distance')
    except nx.NetworkXNoPath:
        sp_edges = set()
        sp_cost = float('inf')
    else:
        sp_pairs = list(zip(sp, sp[1:]))
        sp_edges = {_edge_key(u, v) for u, v in sp_pairs}
        # nx.shortest_path treats a missing 'distance' as 1; use the same
        # default here so the reported cost matches the path that was chosen
        # instead of raising KeyError.
        sp_cost = sum(g.edges[u, v].get('distance', 1) for u, v in sp_pairs)

    if ax is None:
        _, ax = plt.subplots(figsize=(8, 8))

    # Draw the graph; the color/width lists follow the order of g.edges().
    nx.draw(
        g, pos=pos, ax=ax,
        node_color=node_colors,
        node_size=node_sizes,
        with_labels=False,
        edge_color=edge_colors,
        width=edge_widths,
    )

    # Overlay shortest path as green dashed
    if sp_edges:
        sp_edge_list = [(u, v) for u, v in g.edges() if _edge_key(u, v) in sp_edges]
        nx.draw_networkx_edges(
            g, pos=pos, ax=ax,
            edgelist=sp_edge_list,
            edge_color='green',
            width=2.0,
            style='dashed',
            alpha=0.7
        )

    # Chokepoints listed in the 'visible_edges' attribute of any blob node.
    visible_cps = set()
    for blob in blobs:
        for node in blob:
            if node in g:
                for edge in g.nodes[node].get('visible_edges', []):
                    key = tuple(sorted(edge))
                    if key in cp_set:
                        visible_cps.add(key)
    cp_vis_pct = len(visible_cps) / len(cp_set) * 100 if cp_set else 0

    title = (f"{label} (seed={seed}, {grid_size}x{grid_size})\n"
             f"Improvement: {improvement_pct:+.1f}% | "
             f"Chokepoints: {len(chokepoints)} ({cp_vis_pct:.0f}% blob-visible)\n"
             f"Blobs: {len(blobs)} (sizes: {[len(b) for b in blobs]}) | "
             f"SP cost: {sp_cost:.1f} | Entry edges: {len(entry_edges)}")
    ax.set_title(title, fontsize=10)
    ax.axis('off')

    return ax

## Top 3 BEST Maps (Highest Improvement)

In [None]:
# One row of three panels, one panel per best map.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
for panel, (best_map, best_improvement) in enumerate(top3):
    draw_map(best_map, best_improvement, ax=axes[panel])
fig.suptitle(
    'Top 3 BEST Maps (RepeatedTopK outperforms SP)',
    fontsize=14,
    fontweight='bold',
)
fig.tight_layout()
plt.show()

## Top 3 WORST Maps (Lowest Improvement)

In [None]:
# One row of three panels, one panel per worst map.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(24, 8))
for panel, (worst_map, worst_improvement) in enumerate(bottom3):
    draw_map(worst_map, worst_improvement, ax=axes[panel])
fig.suptitle(
    'Top 3 WORST Maps (SP outperforms or matches RepeatedTopK)',
    fontsize=14,
    fontweight='bold',
)
fig.tight_layout()
plt.show()

## Individual Large Views (Best Maps)

In [None]:
# Each best map again, on its own 10x10-inch figure.
for best_map, best_improvement in top3:
    fig, ax = plt.subplots(figsize=(10, 10))
    draw_map(best_map, best_improvement, ax=ax)
    fig.tight_layout()
    plt.show()

## Individual Large Views (Worst Maps)

In [None]:
# Each worst map again, on its own 10x10-inch figure.
for worst_map, worst_improvement in bottom3:
    fig, ax = plt.subplots(figsize=(10, 10))
    draw_map(worst_map, worst_improvement, ax=ax)
    fig.tight_layout()
    plt.show()

## Detailed Structural Comparison

Compare key structural properties between best and worst maps.

In [None]:
import os
import sys

# Make the local `analytic_prune` module importable. Entries are inserted only
# when missing, so re-running this cell does not keep growing sys.path.
for _search_path in (os.path.dirname(os.path.abspath('analytic_prune.py')), '.'):
    if _search_path not in sys.path:
        sys.path.insert(0, _search_path)
from analytic_prune import score_map


def _count_entry_edges(md):
    """Count the edges of a map that have exactly one endpoint inside a blob."""
    # Build the blob-node set once per map rather than once per edge endpoint.
    blob_nodes = set().union(*md['blobs'])
    return sum(
        1 for u, v in md['env_graph'].edges()
        if (u in blob_nodes) != (v in blob_nodes)
    )


def _print_row(row_label, values, n_best):
    """Print one table row: label, the best-map cells, a '|', the worst-map cells."""
    cells = [f"{value:>10}" for value in values]
    print(f"{row_label:<30} {' '.join(cells[:n_best])} | {' '.join(cells[n_best:])}")


best_maps = list(top3)
worst_maps = list(bottom3)
all_maps = best_maps + worst_maps
n_best = len(best_maps)

# Column titles follow the number of maps actually loaded instead of assuming 3 + 3.
column_titles = (
    [f"Best {k}" for k in range(1, n_best + 1)]
    + [f"Worst {k}" for k in range(1, len(worst_maps) + 1)]
)

print("=" * 80)
_print_row('Property', column_titles, n_best)
print("=" * 80)

# score_map returns (composite score, per-component scores) for one map.
scores = []
for md, imp in all_maps:
    sc, comp = score_map(md)
    scores.append((sc, comp, imp))

# Print improvement
_print_row('Improvement', [f"{s[2]:+.1f}%" for s in scores], n_best)

# Print composite score
_print_row('Composite score', [f"{s[0]:.2f}" for s in scores], n_best)

# Print component scores
for key in ['cp_visibility', 'detour_near_blob', 'entry_efficiency', 'cp_impact', 'path_diversity']:
    _print_row(key, [f"{s[1][key]:.3f}" for s in scores], n_best)

# Print structural properties
for prop_name, prop_fn in [
    ('Grid size', lambda md: md['grid_size']),
    ('Num blobs', lambda md: len(md['blobs'])),
    ('Num chokepoints', lambda md: len(md['chokepoints'])),
    ('Total blob nodes', lambda md: sum(len(b) for b in md['blobs'])),
    ('Num edges', lambda md: md['env_graph'].number_of_edges()),
    ('Entry points', _count_entry_edges),
]:
    _print_row(prop_name, [f"{prop_fn(md)}" for md, _ in all_maps], n_best)

print("=" * 80)