# Exploring Palace Partitioning

In [None]:
import pickle
from collections import deque
from pathlib import Path

import graphviz as gv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from gtsam import Pose3

from gtsfm.graph_partitioner.metis_partitioner import MetisPartitioner
from gtsfm.products.visibility_graph import (VisibilityGraph,
                                             visibility_graph_keys)
# from gtsfm.utils.io import read_cameras_txt, read_images_txt
from gtsfm.utils.io import load_poses, save_poses

PALACE = Path("../tests/data/palace")

In [None]:
# Load the precomputed NetVLAD image-similarity matrix stored as comma-delimited text.
sim = np.loadtxt(PALACE / "netvlad_similarity_matrix.txt", delimiter=",")

In [None]:
# np.triu zeroes the entries below the main diagonal, so only the upper
# triangle (diagonal included) of the similarity matrix is drawn.
plt.imshow(np.triu(sim))
plt.title("Image Similarity Matrix")

> TODO: run MegaLoc and show 

In [None]:
# Read the edge list; columns "i" and "j" hold the two endpoints of each edge.
df = pd.read_csv(PALACE / "visibility_graph.csv")
# Pair the two columns row by row into a list of (i, j) tuples.
graph: VisibilityGraph = list(zip(df["i"].tolist(), df["j"].tolist()))

In [None]:
# Report the size of the visibility graph: edge count, then distinct key count.
for label, count in (
    ("Number of edges in visibility graph:", len(graph)),
    ("Number of keys:", len(visibility_graph_keys(graph))),
):
    print(label, count)

In [None]:
# The poses were originally produced by the commented-out code below (reading a
# COLMAP images.txt) and then written once to palace/poses.pkl with save_poses,
# so that only a tiny file needs to be kept alongside the test data.
# colmap_path = Path("../results/ba_output")
# poses, img_fnames = read_images_txt(str(colmap_path / "images.txt"))
# save_poses(poses, PALACE / "poses.pkl")
poses = load_poses(PALACE / "poses.pkl")

In [None]:
# Pose translations, one row per pose; only the first two columns are plotted.
xy = np.array([pose.translation() for pose in poses])

# --- Build every edge segment in one pass ---
# Keep only the edges whose two endpoints both index an existing pose.
pose_count = len(xy)
valid_edges = [(src, dst) for src, dst in graph if max(src, dst) < pose_count]
if valid_edges:
    edges_arr = np.array(valid_edges)
    starts = edges_arr[:, 0]
    ends = edges_arr[:, 1]
    # Each edge becomes (start, end, NaN); the NaN breaks the polyline so all
    # edges fit in a single trace.
    gap = np.full(len(edges_arr), np.nan)
    xe = np.column_stack((xy[starts, 0], xy[ends, 0], gap)).ravel()
    ye = np.column_stack((xy[starts, 1], xy[ends, 1], gap)).ravel()

# --- Figure: edges go in first so they are drawn underneath the markers ---
fig = go.Figure()

if valid_edges:
    edge_trace = go.Scatter(
        x=xe,
        y=ye,
        mode="lines",
        line={"width": 1, "color": "lightgray"},
        hoverinfo="none",
        showlegend=False,
    )
    fig.add_trace(edge_trace)

# --- Poses as markers, drawn on top of the edges ---
pose_trace = go.Scatter(
    x=xy[:, 0],
    y=xy[:, 1],
    mode="markers",
    marker={"size": 5},
)
fig.add_trace(pose_trace)

# Equal scaling on both axes and no margins around the plot.
fig.update_layout(
    xaxis_title="x",
    yaxis_title="y",
    yaxis_scaleanchor="x",
    yaxis_scaleratio=1,
    margin={"l": 0, "r": 0, "t": 0, "b": 0},
)

fig.show()

In [None]:
# Run the MetisPartitioner on the visibility graph. The partitioner object is
# kept because a later cell also asks it for the symbolic Bayes tree.
partitioner = MetisPartitioner()
cluster_tree = partitioner.run(graph)


In [None]:
# Collect the leaf clusters; an absent tree simply yields no leaves.
if cluster_tree is None:
    leaves = ()
else:
    leaves = tuple(cluster_tree.leaves())

# Print the size and contents of each leaf's local key set (1-based numbering).
for index, leaf in enumerate(leaves, start=1):
    keys = leaf.local_keys()
    print(f"Leaf {index} has {len(keys)} keys.")
    print(keys)

In [None]:
# Bare expression: the notebook displays the cluster tree's representation.
cluster_tree

In [None]:
# Ask the same partitioner for the symbolic Bayes tree of the visibility graph.
bayes_tree = partitioner.symbolic_bayes_tree(graph)

In [None]:
# Uncomment the line below to render the Bayes tree with graphviz
# gv.Source(bayes_tree.dot())

In [None]:
# Pose translations, one row per pose; N is the number of poses.
xy = np.array([p.translation() for p in poses])
N = len(xy)

# Edge list as an (E, 2) integer array.
edges_arr = np.asarray(graph, dtype=int)
if edges_arr.size == 0:
    # An empty graph converts to a 1-D array; give it two columns so the
    # column indexing used below (and in the plotting code) cannot fail.
    edges_arr = edges_arr.reshape(0, 2)
# Drop edges that reference a pose outside [0, N). The lower bound matters
# because a negative index would silently wrap around to another pose.
endpoint_ok = (edges_arr[:, :2] >= 0) & (edges_arr[:, :2] < N)
edges_arr = edges_arr[endpoint_ok.all(axis=1)]

fig = go.Figure()

# --- background: all edges, very faint (drawn first, under everything) ---
if edges_arr.size:
    # Each edge contributes (start, end, NaN); the NaN breaks the polyline so
    # every edge can live in one trace.
    xe_bg = np.empty(3 * len(edges_arr))
    ye_bg = np.empty(3 * len(edges_arr))
    xe_bg[0::3] = xy[edges_arr[:, 0], 0]
    ye_bg[0::3] = xy[edges_arr[:, 0], 1]
    xe_bg[1::3] = xy[edges_arr[:, 1], 0]
    ye_bg[1::3] = xy[edges_arr[:, 1], 1]
    xe_bg[2::3] = np.nan
    ye_bg[2::3] = np.nan
    fig.add_trace(go.Scatter(
        x=xe_bg, y=ye_bg, mode="lines",
        line=dict(width=1, color="lightgray"),
        opacity=0.12, hoverinfo="none", showlegend=False,
    ))

# --- background: every pose as a small gray marker ---
# showlegend=False keeps this unnamed trace out of the legend, matching the
# background edges, so the legend only lists traces that ask to be shown.
fig.add_trace(go.Scatter(
    x=xy[:, 0],
    y=xy[:, 1],
    mode="markers",
    marker=dict(size=3, color="lightgray"),
    customdata=np.arange(N),  # pose index, shown in the hover tooltip
    hovertemplate="node %{customdata}<extra></extra>",
    showlegend=False,
))

# one legend entry toggles both traces in the same group
fig.update_layout(legend=dict(groupclick="togglegroup"))

# Add one edge trace and one marker trace per leaf; both share a legend group
# so the single legend entry (carried by the edge trace) controls them together.
for idx, leaf in enumerate(leaves, start=1):
    leaf_name = f"Leaf {idx}"
    legendgroup = f"leaf{idx}"

    # Keys of this leaf that index an existing pose.
    nodes = np.array([k for k in leaf.all_keys() if 0 <= k < N], dtype=int)
    if not nodes.size:
        continue

    # Edges whose two endpoints both belong to the leaf.
    if edges_arr.size:
        both_inside = np.isin(edges_arr, nodes).all(axis=1)
        E = edges_arr[both_inside]
        if len(E):
            # (start, end, NaN) triples: NaN separates consecutive segments.
            gap = np.full(len(E), np.nan)
            xe = np.column_stack((xy[E[:, 0], 0], xy[E[:, 1], 0], gap)).ravel()
            ye = np.column_stack((xy[E[:, 0], 1], xy[E[:, 1], 1], gap)).ravel()
            leaf_edges = go.Scatter(
                x=xe,
                y=ye,
                mode="lines",
                line={"width": 1},
                hoverinfo="none",
                name=leaf_name,
                legendgroup=legendgroup,
                showlegend=True,
            )
            fig.add_trace(leaf_edges)

    leaf_nodes = go.Scatter(
        x=xy[nodes, 0],
        y=xy[nodes, 1],
        mode="markers",
        marker={"size": 6},
        name=leaf_name,
        legendgroup=legendgroup,
        showlegend=False,
        customdata=nodes,  # node ids for the tooltip
        hovertemplate="node %{customdata}<extra></extra>",  # tooltip shows the id only
    )
    fig.add_trace(leaf_nodes)

# Hide both axes, use a white background, remove the margins, and place the
# legend at the top with its right edge anchored at x=0.
fig.update_layout(
    xaxis_title=None,
    yaxis_title=None,
    xaxis=dict(visible=False),
    yaxis=dict(visible=False),
    paper_bgcolor="white",
    plot_bgcolor="white",
    margin=dict(l=0, r=0, t=0, b=0),
    legend=dict(orientation="v", x=0, xanchor="right", y=1, yanchor="top"),
)

# NOTE(review): automargin lets Plotly grow the margins to fit axis labels; it
# does not by itself reduce padding, and both axes are hidden above — confirm
# whether these two calls are still needed.
fig.update_xaxes(automargin=True)
fig.update_yaxes(automargin=True)

fig.show()

In [None]:
# fig.write_html("visibility_graph.html", include_plotlyjs="cdn", full_html=True)

In [None]:
# --- positions & edges ---
# Pose translations, one row per pose; N is the number of poses.
xy = np.array([p.translation() for p in poses])
N = len(xy)

# Edge list as an (E, 2) integer array.
edges_arr = np.asarray(graph, dtype=int)
if edges_arr.size == 0:
    # An empty graph converts to a 1-D array; give it two columns so the
    # column indexing used below (and in the plotting code) cannot fail.
    edges_arr = edges_arr.reshape(0, 2)
# Drop edges that reference a pose outside [0, N). The lower bound matters
# because a negative index would silently wrap around to another pose.
endpoint_ok = (edges_arr[:, :2] >= 0) & (edges_arr[:, :2] < N)
edges_arr = edges_arr[endpoint_ok.all(axis=1)]

# --- enumerate clusters and assign path labels like C_2_3_1 ---
records = []  # list of dicts: {id, node, children, path_tuple, path_str}


def build_records(root):
    """Flatten the cluster tree under *root* into the module-level ``records``.

    Clusters are visited breadth-first. For each one a dict is appended with:

    * ``id``         -- its position in ``records``;
    * ``node``       -- the cluster object itself;
    * ``children``   -- record ids of its direct children;
    * ``path_tuple`` -- the 1-based child positions leading from the root to
      this cluster (``()`` for the root);
    * ``path_str``   -- ``"C"`` followed by ``_<position>`` per path step,
      e.g. ``"C_2_3_1"``.

    Args:
        root: Root cluster exposing ``_child_clusters()``. ``None`` (no tree)
            leaves ``records`` untouched.
    """
    if root is None:
        return

    # Queue entries are (node, path, parent_id). A child's path is fixed when
    # it is enqueued, using its position from enumerate(); this avoids
    # searching the parent's child list with ==, which is slower and gives
    # equal-comparing siblings the same position.
    pending = deque([(root, (), None)])
    while pending:
        node, path, parent_id = pending.popleft()
        rid = len(records)
        label = "C" + "".join(f"_{step}" for step in path)
        records.append(dict(id=rid, node=node, children=[], path_tuple=path, path_str=label))
        if parent_id is not None:
            records[parent_id]["children"].append(rid)
        for position, child in enumerate(node._child_clusters(), start=1):
            pending.append((child, (*path, position), rid))

# Fill records with one entry per cluster of the partitioner's tree.
build_records(cluster_tree)

# helpers
def subtree_ids(cid):
    """Return the set of record ids in the subtree rooted at record *cid*.

    The result includes *cid* itself. Child ids come from the ``children``
    list of each entry in the module-level ``records``.

    NOTE: a function with the same name is defined again further down in this
    cell; that later definition replaces this one.
    """
    collected = {cid}
    frontier = [cid]
    while frontier:
        current = frontier.pop()
        kids = records[current]["children"]
        collected.update(kids)
        frontier.extend(kids)
    return collected

# --- figure: faint background ---
fig = go.Figure()
if edges_arr.size:
    starts = edges_arr[:, 0]
    ends = edges_arr[:, 1]
    # (start, end, NaN) per edge; the NaN breaks the polyline between edges.
    gap = np.full(len(edges_arr), np.nan)
    xe_bg = np.column_stack((xy[starts, 0], xy[ends, 0], gap)).ravel()
    ye_bg = np.column_stack((xy[starts, 1], xy[ends, 1], gap)).ravel()
    background = go.Scatter(
        x=xe_bg,
        y=ye_bg,
        mode="lines",
        line={"width": 1, "color": "lightgray"},
        opacity=0.12,
        hoverinfo="none",
        showlegend=False,
    )
    fig.add_trace(background)
# NOTE(review): BG is not read anywhere in the visible code; presumably it is
# the number of background traces — confirm before relying on it.
BG = 1

trace_idx = []  # per cluster: (edge_local_idx or None, node_local_idx, node_all_idx)

# Every cluster gets up to three hidden traces: its local edges (only when
# there are any), its local keys, and all of its keys. Their positions in
# fig.data are remembered so visibility masks can be built afterwards.
for rec in records:
    node = rec["node"]

    # Keep only keys that index an existing pose.
    k_local = np.array([k for k in node.local_keys() if 0 <= k < N], dtype=int)
    k_all = np.array([k for k in node.all_keys() if 0 <= k < N], dtype=int)

    # Local edges: both endpoints are among the cluster's local keys.
    e_idx = None
    if edges_arr.size and k_local.size:
        E = edges_arr[np.isin(edges_arr, k_local).all(axis=1)]
        if len(E):
            gap = np.full(len(E), np.nan)  # NaN separates consecutive segments
            xe = np.column_stack((xy[E[:, 0], 0], xy[E[:, 1], 0], gap)).ravel()
            ye = np.column_stack((xy[E[:, 0], 1], xy[E[:, 1], 1], gap)).ravel()
            fig.add_trace(go.Scatter(
                x=xe,
                y=ye,
                mode="lines",
                line={"width": 1},
                hoverinfo="none",
                visible=False,
                showlegend=False,
            ))
            e_idx = len(fig.data) - 1

    # Marker traces: local keys first, then all keys.
    for key_ids in (k_local, k_all):
        fig.add_trace(go.Scatter(
            x=xy[key_ids, 0],
            y=xy[key_ids, 1],
            mode="markers",
            marker={"size": 6},
            customdata=key_ids,
            hovertemplate="node %{customdata}<extra></extra>",
            visible=False,
            showlegend=False,
        ))
    n_all_idx = len(fig.data) - 1
    n_loc_idx = n_all_idx - 1

    trace_idx.append((e_idx, n_loc_idx, n_all_idx))

# Total number of traces; every visibility mask has this length.
M = len(fig.data)

from collections import deque
def subtree_ids(cid):
    """Return the set of record ids reachable from record *cid*, itself included.

    Walks the ``children`` lists in the module-level ``records`` breadth-first;
    an id that has already been collected is never queued a second time.
    """
    seen = {cid}
    queue = deque((cid,))
    while queue:
        for child in records[queue.popleft()]["children"]:
            if child in seen:
                continue
            seen.add(child)
            queue.append(child)
    return seen

def base_vis():
    """Return the baseline visibility mask: one boolean per trace in ``fig``.

    Everything is hidden except the faint background edge trace. That trace is
    only added when ``edges_arr`` is non-empty, in which case it is trace 0;
    otherwise trace 0 belongs to the first cluster and has to stay hidden.

    Returns:
        A fresh list of ``M`` booleans that callers may modify freely.
    """
    v = [False] * M
    if edges_arr.size:
        v[0] = True  # keep faint background
    return v

def mask_local(cid):
    """Return the visibility mask showing only cluster *cid*'s local content.

    Starts from ``base_vis()`` and switches on the cluster's local-edge trace
    (when it has one) and its local-keys marker trace.
    """
    edge_trace, local_nodes_trace, _ = trace_idx[cid]
    visible = base_vis()
    if edge_trace is not None:
        visible[edge_trace] = True
    visible[local_nodes_trace] = True  # show only local keys
    return visible

def mask_subtree(cid):
    """Return the visibility mask showing the whole subtree under cluster *cid*.

    Starts from ``base_vis()``, switches on the local-edge trace of every
    cluster returned by ``subtree_ids(cid)``, and shows the nodes through the
    single all-keys marker trace of *cid* itself.
    """
    visible = base_vis()
    # Edges: union of the local edges of every cluster in the subtree.
    for member in subtree_ids(cid):
        edge_trace = trace_idx[member][0]
        if edge_trace is not None:
            visible[edge_trace] = True
    # Nodes: one trace holding all keys of the subtree's root cluster.
    visible[trace_idx[cid][2]] = True
    return visible

# One button per cluster and per mode. Each button is labelled with the
# cluster's path string and carries the full visibility mask to apply.
buttons_subtree = []
buttons_local = []
for rec in records:
    buttons_subtree.append(
        dict(label=rec["path_str"], method="update",
             args=[{"visible": mask_subtree(rec["id"])}])
    )
    buttons_local.append(
        dict(label=rec["path_str"], method="update",
             args=[{"visible": mask_local(rec["id"])}])
    )

# Two dropdowns along the top edge. The first (x=0.00) shows a cluster's entire
# subtree; the second (x=0.30) shows only a cluster's local edges and keys.
dropdown_menus = []
for x_pos, menu_buttons in ((0.00, buttons_subtree), (0.30, buttons_local)):
    dropdown_menus.append(
        dict(
            type="dropdown",
            direction="down",
            x=x_pos,
            y=1.10,
            xanchor="left",
            buttons=menu_buttons,
            showactive=True,
            pad=dict(l=2, r=2, t=2, b=2),
        )
    )

fig.update_layout(
    updatemenus=dropdown_menus,
    # no title; the top margin only leaves room for the dropdowns
    margin=dict(l=0, r=0, t=36, b=0),
)

# NB: The two dropdowns do not share state. Every button carries a complete
# visibility mask for all traces, so the most recent selection in either menu
# replaces whatever the other menu was showing. The buttons are precomputed in
# Python, so one menu cannot react to the other's current selection;
# presumably the other menu's highlighted label is then stale — re-select from
# it to apply its view.

fig.show()