# In [None]:
from pathlib import Path

import networkx as nx
import numpy as np
import polars as pl
import umap.umap_ as umap
from bokeh.io import output_notebook
from bokeh.models import CustomJS, HoverTool
from bokeh.palettes import Category20
from bokeh.plotting import ColumnDataSource, figure, show
from bokeh.transform import factor_cmap

# -------------------------------------------------------------
# 1) LOAD DATA
# -------------------------------------------------------------
data_dir = Path().absolute() / ".." / "data" / "dagster"

# The graph file and its embedding table share this identifier; keeping it
# in one variable prevents accidentally loading mismatched files.
dataset_id = "cm0i27jdj0000aqpa73ghpcxf"

# Load the graph
G = nx.read_graphml(data_dir / "recursive_causality" / f"{dataset_id}.graphml")

# Load the Polars DataFrame
df = pl.read_parquet(
    data_dir / "deduplicated_graph_w_embeddings" / f"{dataset_id}.snappy"
)

# The rest of this notebook reads these columns:
#   - 'label'     node identifier (compared, as a string, to GraphML node IDs)
#   - 'embedding' per-node embedding vector
#   - 'category'  value used for coloring
# Fail early with a readable message instead of a KeyError further down.
required_columns = {"label", "embedding", "category"}
missing_columns = required_columns.difference(df.columns)
if missing_columns:
    raise ValueError(
        f"Embedding table is missing required column(s): {sorted(missing_columns)}"
    )

print(df.head())

# -------------------------------------------------------------
# 2) EXTRACT EMBEDDINGS, RUN UMAP
# -------------------------------------------------------------
# Stack the per-row embedding vectors into an (N, embed_dim) matrix; labels
# and categories stay aligned with those N rows.
embeddings = np.stack(df["embedding"].to_list())
labels = df["label"].to_list()

# Stringify categories so they behave as categorical factors for plotting
# (a missing value becomes the string "None").
categories = [str(cat) for cat in df["category"].to_list()]

# Project to 2D; random_state is fixed so the layout is reproducible.
reducer = umap.UMAP(
    n_components=2,
    n_neighbors=15,
    min_dist=0.1,
    random_state=42,
    verbose=True,
)
umap_coords = reducer.fit_transform(embeddings)  # (N, 2)
x_coords, y_coords = umap_coords[:, 0], umap_coords[:, 1]

# -------------------------------------------------------------
# 3) BUILD A NEIGHBORS DICTIONARY
# -------------------------------------------------------------
# Map every node ID to the IDs of all adjacent nodes, both as strings so they
# match df["label"] (GraphML IDs may be parsed as non-strings).
#
# For a directed graph, G[node] yields only successors, which would leave
# nodes connected by an incoming edge out of the tooltip and the highlight.
# nx.all_neighbors returns predecessors and successors for directed graphs
# (and plain neighbors for undirected ones). dict.fromkeys removes
# duplicates, e.g. a node that is both predecessor and successor, while
# keeping first-seen order.
neighbors_dict = {
    str(node_id): list(
        dict.fromkeys(str(neigh) for neigh in nx.all_neighbors(G, node_id))
    )
    for node_id in G.nodes()
}

# One neighbor list per df row (aligned with `labels`); labels that are not
# graph nodes get an empty list.
neighbors_list = [neighbors_dict.get(str(lbl), []) for lbl in labels]

# -------------------------------------------------------------
# 4) MAP LABEL -> INDEX, PREPARE EDGE COORDS
# -------------------------------------------------------------
# Row index of each label in the UMAP output (a repeated label keeps its
# last row).
label_to_idx = {str(lbl): row for row, lbl in enumerate(labels)}

# Keep only edges whose two endpoints both have a UMAP position, recording
# (source_id, target_id, source_row, target_row) for each one.
kept_edges = []
for src, dst in G.edges():
    src_id = str(src)
    dst_id = str(dst)
    if src_id not in label_to_idx or dst_id not in label_to_idx:
        continue
    kept_edges.append((src_id, dst_id, label_to_idx[src_id], label_to_idx[dst_id]))

# Segment endpoints for drawing.
x0_list = [x_coords[src_row] for _, _, src_row, _ in kept_edges]
y0_list = [y_coords[src_row] for _, _, src_row, _ in kept_edges]
x1_list = [x_coords[dst_row] for _, _, _, dst_row in kept_edges]
y1_list = [y_coords[dst_row] for _, _, _, dst_row in kept_edges]

# Endpoint IDs, used by the hover highlight logic.
edge_source_ids = [src_id for src_id, _, _, _ in kept_edges]
edge_target_ids = [dst_id for _, dst_id, _, _ in kept_edges]

# -------------------------------------------------------------
# 5) BUILD BOKEH DATA SOURCES
# -------------------------------------------------------------
# Per-node columns, all aligned with the df rows. 'alpha' starts at 0.9 and
# is rewritten in the browser by the hover callback.
node_columns = {
    "x": x_coords,
    "y": y_coords,
    "node_id": [str(lbl) for lbl in labels],
    "category": categories,
    "neighbors": neighbors_list,
    "alpha": [0.9 for _ in labels],
}
node_source = ColumnDataSource(data=node_columns)

# Per-edge columns, one row per drawable edge; 'alpha' starts faint (0.05).
edge_columns = {
    "x0": x0_list,
    "y0": y0_list,
    "x1": x1_list,
    "y1": y1_list,
    "source_id": edge_source_ids,
    "target_id": edge_target_ids,
    "alpha": [0.05 for _ in x0_list],
}
edge_source = ColumnDataSource(data=edge_columns)

# -------------------------------------------------------------
# 6) CREATE BOKEH FIGURE
# -------------------------------------------------------------
output_notebook()  # if in Jupyter; else use output_file("some_name.html") + show(p)

p = figure(
    width=900,
    height=900,
    title="UMAP + Category Coloring + Interactive Hover (Highlight Neighbors)",
    tools="pan,wheel_zoom,box_zoom,reset,hover,tap",
    active_scroll="wheel_zoom",
)

# Draw edges as faint segments; per-edge alpha is read from edge_source.
edges_glyph = p.segment(
    x0="x0", y0="y0", x1="x1", y1="y1",
    source=edge_source,
    line_width=1,
    line_color="gray",
    alpha="alpha",
)

# Draw nodes, colored by category. Passing the bare column name as
# fill_color would make Bokeh interpret each category string as a CSS color,
# so categories are mapped to palette colors explicitly. Category20 is
# cycled, so colors repeat once there are more than 20 categories.
category_factors = sorted(set(categories))
category_palette = [Category20[20][i % 20] for i in range(len(category_factors))]
nodes_glyph = p.scatter(
    "x", "y",
    source=node_source,
    size=8,
    fill_color=factor_cmap("category", palette=category_palette, factors=category_factors),
    line_color=None,
    alpha="alpha",
)

# -------------------------------------------------------------
# 7) ADD HOVER TOOL
# -------------------------------------------------------------
# Configure the hover tool the figure created from the "hover" entry in
# `tools`: show node details, and attach it to the node renderer only so
# edges never trigger it.
node_tooltips = [
    ("Node", "@node_id"),
    ("Category", "@category"),
    ("Neighbors", "@neighbors"),
]
hover_tool = p.select_one(HoverTool)
hover_tool.update(tooltips=node_tooltips, renderers=[nodes_glyph])

# -------------------------------------------------------------
# 8) DEFINE JS CALLBACK TO HIGHLIGHT NEIGHBORS
# -------------------------------------------------------------
callback_code = """
// Runs on hover inspections of the node renderer. cb_data.index.indices are
// the indices of the hovered node(s); edge_source and node_source come from
// `args`. The alpha columns are mutated in place and change.emit() redraws.

// Keep the *_DEFAULT values in sync with the initial 'alpha' columns built
// in Python (0.9 for nodes, 0.05 for edges).
const NODE_ALPHA_DEFAULT = 0.9;
const NODE_ALPHA_DIM = 0.2;
const NODE_ALPHA_HIGHLIGHT = 1.0;
const EDGE_ALPHA_DEFAULT = 0.05;
const EDGE_ALPHA_HIGHLIGHT = 0.8;

const inds = cb_data.index.indices;
const node_alpha = node_source.data['alpha'];
const node_ids = node_source.data['node_id'];
const node_neighbors = node_source.data['neighbors'];
const edge_alpha = edge_source.data['alpha'];
const e_src = edge_source.data['source_id'];
const e_tgt = edge_source.data['target_id'];

// Nothing hovered -> restore the default look (otherwise the whole plot
// would stay dimmed after the mouse leaves a node). Something hovered,
// including several overlapping nodes -> dim everything, then re-highlight.
const hovering = inds.length > 0;

for (let i = 0; i < node_alpha.length; i++) {
    node_alpha[i] = hovering ? NODE_ALPHA_DIM : NODE_ALPHA_DEFAULT;
}
for (let i = 0; i < edge_alpha.length; i++) {
    edge_alpha[i] = EDGE_ALPHA_DEFAULT;
}

if (hovering) {
    // IDs of the hovered nodes, and of every node that should stay bright.
    const hovered_ids = new Set();
    const lit_ids = new Set();
    for (const idx of inds) {
        hovered_ids.add(node_ids[idx]);
        lit_ids.add(node_ids[idx]);
        for (const nb of (node_neighbors[idx] || [])) {
            lit_ids.add(nb);
        }
    }

    // Highlight every edge touching a hovered node, in either direction, and
    // keep its far endpoint bright even when the precomputed neighbor list
    // only holds one direction (e.g. successors of a directed graph).
    for (let i = 0; i < e_src.length; i++) {
        if (hovered_ids.has(e_src[i]) || hovered_ids.has(e_tgt[i])) {
            edge_alpha[i] = EDGE_ALPHA_HIGHLIGHT;
            lit_ids.add(e_src[i]);
            lit_ids.add(e_tgt[i]);
        }
    }

    for (let i = 0; i < node_ids.length; i++) {
        if (lit_ids.has(node_ids[i])) {
            node_alpha[i] = NODE_ALPHA_HIGHLIGHT;
        }
    }
}

edge_source.change.emit();
node_source.change.emit();
"""

hover_callback = CustomJS(
    args=dict(edge_source=edge_source, node_source=node_source),
    code=callback_code,
)
hover_tool.callback = hover_callback

# -------------------------------------------------------------
# 9) SHOW PLOT
# -------------------------------------------------------------
show(p)
