# Exploring Palace Partitioning

In [None]:
import pickle
from collections import deque
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple

import graphviz as gv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from gtsam import Pose3
from plotly.colors import qualitative

from gtsfm.graph_partitioner.metis_partitioner import MetisPartitioner
from gtsfm.products.cluster_tree import ClusterTree
from gtsfm.products.visibility_graph import (
    VisibilityGraph,
    visibility_graph_keys,
)
from gtsfm.utils.io import load_poses, save_poses

# Test data for the "palace" scene; the path is relative to the current working directory.
PALACE = Path("..") / "tests" / "data" / "palace"

## Common Visualization Functions

In [None]:
def get_edge_coordinates(xy: np.ndarray, edges: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Build the x/y arrays that draw many line segments as a single Plotly trace.

    Every edge ``(a, b)`` contributes three consecutive entries: the coordinate of
    node ``a``, the coordinate of node ``b``, and a NaN that breaks the polyline so
    consecutive segments are not joined together.

    Args:
        xy: Node positions; column 0 holds x and column 1 holds y.
        edges: Integer array of shape (E, 2) whose rows index rows of ``xy``.

    Returns:
        ``(xs, ys)``, each a float array of length ``3 * E``; two empty arrays when
        ``edges`` is empty.
    """
    if edges.size == 0:
        return np.array([]), np.array([])

    starts = xy[edges[:, 0]]
    ends = xy[edges[:, 1]]
    separators = np.full(len(edges), np.nan)
    # Rows of (start, end, NaN); flattening row-major interleaves them per edge.
    xs = np.column_stack((starts[:, 0], ends[:, 0], separators)).ravel()
    ys = np.column_stack((starts[:, 1], ends[:, 1], separators)).ravel()
    return xs, ys

def create_base_figure_with_background(xy: np.ndarray, edges_arr: np.ndarray) -> go.Figure:
    """Return a new Plotly figure containing only a faint backdrop of the whole graph.

    The figure holds exactly two traces, in this order:
      0. every edge of ``edges_arr`` as 1px light-gray lines at 20% opacity, no hover;
      1. every node of ``xy`` as small light-gray markers whose hover shows the node index.

    Axes are hidden, the background is white, margins are zero, and the y axis is
    anchored to the x axis with ratio 1 so distances are not distorted.
    """
    bg_x, bg_y = get_edge_coordinates(xy, edges_arr)

    edge_backdrop = go.Scatter(
        x=bg_x,
        y=bg_y,
        mode="lines",
        line=dict(width=1, color="lightgray"),
        opacity=0.2,
        hoverinfo="none",
        showlegend=False,
    )
    node_backdrop = go.Scatter(
        x=xy[:, 0],
        y=xy[:, 1],
        mode="markers",
        marker=dict(size=3, color="lightgray"),
        customdata=np.arange(len(xy)),
        hovertemplate="node %{customdata}<extra></extra>",
        showlegend=False,
    )
    layout = dict(
        paper_bgcolor="white",
        plot_bgcolor="white",
        margin=dict(l=0, r=0, t=0, b=0),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False, scaleanchor="x", scaleratio=1),
    )
    return go.Figure(data=[edge_backdrop, node_backdrop], layout=layout)

# --- Helper functions for the interactive cluster hierarchy plot ---

def build_cluster_records(root: ClusterTree) -> List[Dict[str, Any]]:
    """Flatten a cluster tree into a list of records, in breadth-first order.

    Each record is a dict with keys:
      - ``id``: the record's index in the returned list (the root is 0);
      - ``node``: the cluster-tree node itself;
      - ``children``: ids of the node's child records, in child order;
      - ``path_tuple``: 1-based child positions from the root (``()`` for the root);
      - ``path_str``: display label, ``"C"`` for the root and e.g. ``"C_1_2"`` for
        the second child of the root's first child.

    Because the traversal is breadth-first, a parent's id is always smaller than
    the ids of its children.
    """
    records: List[Dict[str, Any]] = []
    # Queue entries: (node, path of this node, id of the parent record or None for root).
    # A child's 1-based position is fixed when it is enqueued. This avoids re-walking
    # the parent's children with ``list.index``, which is O(k) per child and matches
    # by equality, so two equal children would otherwise receive the same path.
    queue = deque([(root, (), None)])

    while queue:
        node, path, parent_id = queue.popleft()

        rid = len(records)
        records.append(
            dict(
                id=rid,
                node=node,
                children=[],
                path_tuple=path,
                path_str="C" + "".join(f"_{position}" for position in path),
            )
        )

        if parent_id is not None:
            records[parent_id]["children"].append(rid)

        for position, child in enumerate(node._child_clusters(), start=1):
            queue.append((child, (*path, position), rid))

    return records

def get_subtree_ids(records: List[Dict[str, Any]], start_cid: int) -> Set[int]:
    """Return the id ``start_cid`` together with the ids of all its descendants.

    ``records[i]["children"]`` must list the ids (i.e. indices into ``records``) of
    record ``i``'s children; descendants are followed through those lists.
    """
    found: Set[int] = set()
    pending = [start_cid]
    while pending:
        current = pending.pop()
        found.add(current)
        pending.extend(records[current]["children"])
    return found


## Data Loading and Pre-computation

In [None]:
# Pairwise NetVLAD image-similarity scores, stored as comma-separated text.
sim = np.loadtxt(fname=PALACE / "netvlad_similarity_matrix.txt", delimiter=",")

In [None]:
# Heat map of the similarity matrix. np.triu zeroes everything below the diagonal,
# so only the upper triangle is shown (presumably the matrix is symmetric, which
# would make the lower half redundant). The colorbar and axis labels let the
# similarity values and image indices be read off the plot.
plt.imshow(np.triu(sim))
plt.colorbar(label="similarity")
plt.xlabel("image index")
plt.ylabel("image index")
plt.title("Image Similarity Matrix")
plt.show()

> TODO: Run MegaLoc and show its image similarity matrix for comparison with the NetVLAD matrix above.

In [None]:
# Visibility-graph edges: one (i, j) pair of image indices per CSV row.
df = pd.read_csv(PALACE / "visibility_graph.csv")
graph: VisibilityGraph = [(i, j) for i, j in zip(df["i"], df["j"])]

In [None]:
# Basic size statistics of the visibility graph.
print(f"Number of edges in visibility graph: {len(graph)}")
print(f"Number of keys: {len(visibility_graph_keys(graph))}")

In [None]:
# The poses were originally extracted from a COLMAP bundle-adjustment output with the
# commented-out code below and then stored in the small palace/poses.pkl file via
# save_poses (note that read_images_txt is not imported in this notebook):
#   colmap_path = Path("../results/ba_output")
#   poses, img_fnames = read_images_txt(str(colmap_path / "images.txt"))
#   save_poses(poses, PALACE / "poses.pkl")
poses = load_poses(PALACE.joinpath("poses.pkl"))

In [None]:
# Extract the 2D (x, y) camera positions for plotting; the z component is dropped.
# reshape keeps the array shaped (N, 2) even when there are no poses.
xy = np.array([p.translation()[:2] for p in poses], dtype=float).reshape(-1, 2)
N = len(xy)

# Keep only edges whose endpoints both index an existing pose. Negative ids are
# rejected too: NumPy fancy indexing would otherwise silently wrap them around to
# the end of xy. reshape keeps the array shaped (E, 2) even for an empty graph, so
# the column slicing used by the plots below does not fail.
edges_arr = np.asarray(graph, dtype=int).reshape(-1, 2)
valid_mask = ((edges_arr >= 0) & (edges_arr < N)).all(axis=1)
edges_arr = edges_arr[valid_mask]

# Report dropped edges instead of discarding them silently.
num_dropped = int((~valid_mask).sum())
if num_dropped:
    print(f"Dropped {num_dropped} edge(s) that reference a missing pose.")


## Plot 1: Full Visibility Graph

In [None]:
# Full visibility graph: gray edges drawn over the 2D camera positions, with the
# y axis locked to the x axis scale so the layout is not distorted.
xe, ye = get_edge_coordinates(xy, edges_arr)
edge_trace = go.Scatter(
    x=xe,
    y=ye,
    mode="lines",
    line=dict(width=1, color="lightgray"),
    hoverinfo="none",
    showlegend=False,
)
pose_trace = go.Scatter(x=xy[:, 0], y=xy[:, 1], mode="markers", marker=dict(size=5))

fig = go.Figure(data=[edge_trace, pose_trace])
fig.update_layout(
    xaxis=dict(title=dict(text="x")),
    yaxis=dict(title=dict(text="y"), scaleanchor="x", scaleratio=1),
    margin=dict(l=0, r=0, t=0, b=0),
)

fig.show()

## Graph Partitioning with METIS

In [None]:
# Partition the visibility graph with METIS. The resulting cluster tree is explored by
# the cells below; the leaves cell guards against it being None.
partitioner = MetisPartitioner()
cluster_tree = partitioner.run(graph)

In [None]:
# Collect the leaves of the cluster tree (none if partitioning returned no tree)
# and report how many keys each leaf holds locally.
leaves = () if cluster_tree is None else tuple(cluster_tree.leaves())
for index, leaf in enumerate(leaves, start=1):
    keys = leaf.local_keys()
    print(f"Leaf {index} has {len(keys)} keys.")

In [None]:
# Display the cluster tree: as the cell's last expression it is rendered by the
# notebook via its __repr__, which was updated in cluster_tree.py to produce this view.
cluster_tree

In [None]:
# Symbolic Bayes tree for the same visibility graph, built by the partitioner. In this
# notebook it is only referenced by the optional (commented-out) Graphviz rendering below.
bayes_tree = partitioner.symbolic_bayes_tree(graph)

In [None]:
# Uncomment to render the symbolic Bayes tree (requires the `graphviz` Python package and the Graphviz system executables)
# gv.Source(bayes_tree.dot())

## Plot 2: Visualizing Leaf Clusters

In [None]:
# Draw each leaf cluster (its nodes and the edges among them) in its own color over the
# faint full graph. Clicking a legend entry toggles both traces of that leaf together.
fig = create_base_figure_with_background(xy, edges_arr)
palette = qualitative.Plotly

for idx, leaf in enumerate(leaves, 1):
    leaf_name = f"Leaf {idx}"
    legendgroup = f"leaf{idx}"
    # One explicit color shared by the leaf's edge and node traces. Without it, Plotly
    # gives every trace the next colorway entry, so a leaf's edges and nodes would be
    # drawn in different colors.
    color = palette[(idx - 1) % len(palette)]

    # Node indices of the current leaf that correspond to an existing pose
    nodes = np.array([k for k in leaf.all_keys() if 0 <= k < N], dtype=int)
    if nodes.size == 0:
        continue

    # Edges with both endpoints inside the leaf
    mask = np.isin(edges_arr[:, 0], nodes) & np.isin(edges_arr[:, 1], nodes)
    leaf_edges = edges_arr[mask]
    xe, ye = get_edge_coordinates(xy, leaf_edges)

    # Leaf edges trace (carries the legend entry for the group)
    fig.add_trace(go.Scatter(
        x=xe, y=ye, mode="lines",
        line=dict(width=1, color=color), hoverinfo="none",
        name=leaf_name, legendgroup=legendgroup,
    ))

    # Leaf nodes trace (hidden from the legend; toggled with its group)
    fig.add_trace(go.Scatter(
        x=xy[nodes, 0], y=xy[nodes, 1], mode="markers",
        marker=dict(size=6, color=color),
        name=leaf_name, legendgroup=legendgroup, showlegend=False,
        customdata=nodes, hovertemplate="node %{customdata}<extra></extra>",
    ))

fig.update_layout(
    legend=dict(
        groupclick="togglegroup",
        orientation="v", x=0, xanchor="right", y=1, yanchor="top",
    ),
    margin=dict(l=0, r=0, t=0, b=0),
)

fig.show()
# fig.write_html("visibility_graph_leaves.html", include_plotlyjs="cdn", full_html=True)

## Plot 3: Interactive Cluster Hierarchy Explorer

In [None]:
# --- 0. The explorer needs a cluster tree; fail with a clear message instead of deep in the traversal ---
if cluster_tree is None:
    raise ValueError("Partitioning produced no cluster tree; nothing to explore.")

# --- 1. Process cluster tree hierarchy (records are in breadth-first order, root first) ---
records = build_cluster_records(cluster_tree)

# --- 2. Create base figure (traces 0 and 1 are the faint background edges and nodes) ---
fig = create_base_figure_with_background(xy, edges_arr)

# --- 3. Pre-generate all possible traces (initially invisible) ---
# Each cluster gets one explicit color shared by its three traces. Without it, Plotly
# gives every trace the next colorway entry, so a cluster's edges and nodes would differ.
palette = qualitative.Plotly
trace_indices = []  # per record id: (local_edges_idx, local_nodes_idx, all_nodes_idx) into fig.data
for rec in records:
    node = rec["node"]
    color = palette[rec["id"] % len(palette)]
    k_local = np.array([k for k in node.local_keys() if 0 <= k < N], dtype=int)
    k_all = np.array([k for k in node.all_keys() if 0 <= k < N], dtype=int)

    # a) Local edges trace: edges with both endpoints among this cluster's local keys
    mask = np.isin(edges_arr[:, 0], k_local) & np.isin(edges_arr[:, 1], k_local)
    local_edges = edges_arr[mask]
    xe, ye = get_edge_coordinates(xy, local_edges)
    fig.add_trace(go.Scatter(x=xe, y=ye, mode="lines", line=dict(width=1, color=color),
                             hoverinfo="none", visible=False, showlegend=False))
    edge_idx = len(fig.data) - 1

    # b) Local nodes trace
    fig.add_trace(go.Scatter(x=xy[k_local, 0], y=xy[k_local, 1], mode="markers",
                             marker=dict(size=6, color=color),
                             customdata=k_local, hovertemplate="node %{customdata}<extra></extra>",
                             visible=False, showlegend=False))
    node_local_idx = len(fig.data) - 1

    # c) All nodes (in subtree) trace
    fig.add_trace(go.Scatter(x=xy[k_all, 0], y=xy[k_all, 1], mode="markers",
                             marker=dict(size=6, color=color),
                             customdata=k_all, hovertemplate="node %{customdata}<extra></extra>",
                             visible=False, showlegend=False))
    node_all_idx = len(fig.data) - 1

    trace_indices.append((edge_idx, node_local_idx, node_all_idx))

# --- 4. Define button logic ---
num_traces = len(fig.data)


def create_visibility_mask(visible_indices: List[int]) -> List[bool]:
    """Return one visibility flag per trace in ``fig``.

    The background edges (trace 0) are always shown, and the background nodes
    (trace 1) are always hidden so they do not obscure the selected cluster. Of the
    remaining traces, only those listed in ``visible_indices`` are shown.
    """
    mask = [False] * num_traces
    mask[0] = True  # Always show background edges
    mask[1] = False  # Hide background nodes when showing specific clusters
    for i in visible_indices:
        mask[i] = True
    return mask


# --- 5. Build dropdown menus ---
buttons_subtree = []
buttons_local = []

for rec in records:
    cid = rec["id"]
    # Subtree view: all nodes of this cluster's subtree, plus the local edges of every
    # cluster in that subtree
    subtree_nodes_to_show = [trace_indices[cid][2]]
    subtree_edges_to_show = [trace_indices[i][0] for i in get_subtree_ids(records, cid)]
    buttons_subtree.append(dict(label=rec["path_str"], method="update",
                                args=[{"visible": create_visibility_mask(subtree_nodes_to_show + subtree_edges_to_show)}]))

    # Local view: show only local nodes and local edges
    local_traces_to_show = [trace_indices[cid][0], trace_indices[cid][1]]
    buttons_local.append(dict(label=rec["path_str"], method="update",
                              args=[{"visible": create_visibility_mask(local_traces_to_show)}]))

# --- 6. Update figure layout with menus ---
fig.update_layout(
    updatemenus=[
        dict(type="dropdown", direction="down", x=0.01, y=0.99, xanchor="left", yanchor="top",
             buttons=buttons_subtree, showactive=True, active=0,
             pad={"r": 10, "t": 10},
             bgcolor="#f0f0f0", bordercolor="Black", borderwidth=1,
             font={"size": 11}),
        dict(type="dropdown", direction="down", x=0.25, y=0.99, xanchor="left", yanchor="top",
             buttons=buttons_local, showactive=True, active=0,
             pad={"r": 10, "t": 10},
             bgcolor="#f0f0f0", bordercolor="Black", borderwidth=1,
             font={"size": 11}),
    ],
    annotations=[
        dict(text="Subtree View:", x=0.01, y=1.05, xref="paper", yref="paper", align="left", showarrow=False),
        dict(text="Local View:", x=0.25, y=1.05, xref="paper", yref="paper", align="left", showarrow=False)
    ],
    margin=dict(l=0, r=0, t=40, b=0),
)

# Initial view: exactly what the active (first) subtree button shows, i.e. the root
# cluster's subtree. Reusing that button's mask keeps the initial state consistent with
# every later selection, including hiding the background nodes.
for trace, is_visible in zip(fig.data, buttons_subtree[0]["args"][0]["visible"]):
    trace.visible = is_visible

fig.show()