# In [None]:
import networkx as nx
from bokeh.plotting import figure, show
from bokeh.models import (Circle, MultiLine, EdgesAndLinkedNodes, NodesAndLinkedEdges,
                         HoverTool, TapTool, BoxSelectTool,
                         ColumnDataSource, StaticLayoutProvider, Div, CustomJS, GraphRenderer) # Import GraphRenderer
from bokeh.layouts import column, row
from bokeh.palettes import Spectral4 # Example palette
import pandas as pd
from collections import Counter # Ensure Counter is imported for preprocessing functions
import seaborn as sns # Ensure seaborn is imported for preprocessing functions
import matplotlib.pyplot as plt # Ensure matplotlib is imported for preprocessing functions
from typing import List, Dict, Optional

# =============================================================================
# 1. Your Preprocessing Functions (Ensure they are defined/imported)
#    (Copied from your request for completeness, with minor corrections)
# =============================================================================

def add_node_label(G, node_key_ref):
    """Attach a human-readable ``vizLabel`` attribute to every node of ``G``.

    The label is the value of the node attribute named by
    ``node_key_ref[nodeType]`` (e.g. ``{'Paper': 'title', 'Author': 'name'}``).
    When the type has no (or an empty) mapping, or the node lacks that
    attribute, the label falls back to ``"<nodeType> ID: <nid>"``, or to
    ``"ID: <nid>"`` for nodes without a ``nodeType``. ``G`` is modified in
    place; nothing is returned.
    """
    for node_id, attrs in G.nodes(data=True):
        attr_name = node_key_ref.get(attrs.get('nodeType'))
        if attr_name and attr_name in attrs:
            # The type maps to an attribute this node actually carries.
            attrs['vizLabel'] = attrs[attr_name]
            continue
        if 'nodeType' in attrs:
            attrs['vizLabel'] = f"{attrs['nodeType']} ID: {node_id}"
        else:
            attrs['vizLabel'] = f"ID: {node_id}"

def add_edges_label(G):
    """Attach a descriptive ``vizLabel`` attribute to every edge of ``G``.

    For multigraphs the label is ``"u -> v (RELTYPE, key=K)"``; the key
    distinguishes parallel edges. For simple graphs, which have no edge keys,
    it is ``"u -> v (RELTYPE)"``. Edges without a ``relationshipType`` are
    labelled ``UNKNOWN``. ``G`` is modified in place.
    """
    if G.is_multigraph():
        for u, v, key, edge_data in G.edges(data=True, keys=True):
            rel_type = edge_data.get('relationshipType', 'UNKNOWN')
            edge_data['vizLabel'] = f"{u} -> {v} ({rel_type}, key={key})"
    else:
        # Graph/DiGraph edge views do not accept ``keys=``, so iterate without it.
        for u, v, edge_data in G.edges(data=True):
            rel_type = edge_data.get('relationshipType', 'UNKNOWN')
            edge_data['vizLabel'] = f"{u} -> {v} ({rel_type})"

def assign_node_size(
        G,
        sig_nid_lst: Optional[List[str]] = None,
        min_node_size: Optional[int] = 10,
        max_node_size: Optional[int] = 50,
        ):
    """Assign a ``vizSize`` attribute to every node of ``G`` (in place).

    * Nodes in ``sig_nid_lst`` get ``max_node_size``.
    * ``Paper`` nodes are scaled linearly from ``min_node_size`` to
      ``max_node_size`` by their number of incoming ``CITES`` edges.
    * ``Author`` nodes are scaled the same way by their number of outgoing
      ``WRITES`` edges.
    * Every other node gets ``min_node_size``.

    The min/max used for scaling are computed separately per node type over
    all nodes of that type. If every node of a type has the same count, all
    of them get ``min_node_size``. Results are clamped to
    ``[min_node_size, max_node_size]``.
    """
    # Set lookup instead of scanning the list once per node.
    significant = set(sig_nid_lst) if sig_nid_lst is not None else set()

    # First pass: per-type relationship counts.
    paper_cites_ref, author_writes_ref = {}, {}
    for nid, node_data in G.nodes(data=True):
        node_type = node_data.get('nodeType')
        if node_type == 'Paper':
            paper_cites_ref[nid] = sum(
                1 for _, _, data in G.in_edges(nid, data=True)
                if data.get('relationshipType') == 'CITES')
        elif node_type == 'Author':
            author_writes_ref[nid] = sum(
                1 for _, _, data in G.out_edges(nid, data=True)
                if data.get('relationshipType') == 'WRITES')

    def _make_scaler(counts):
        """Return a function mapping a count from ``counts`` to a clamped size."""
        lo = min(counts.values(), default=0)
        hi = max(counts.values(), default=0)
        span = hi - lo if hi > lo else 1  # avoid division by zero when all counts are equal

        def _scale(value):
            size = min_node_size + (max_node_size - min_node_size) * (value - lo) / span
            return max(min_node_size, min(max_node_size, size))
        return _scale

    scale_paper = _make_scaler(paper_cites_ref)
    scale_author = _make_scaler(author_writes_ref)

    # Second pass: assign sizes. Only Paper/Author nodes appear in the count maps.
    for nid, node_data in G.nodes(data=True):
        if nid in significant:
            node_data['vizSize'] = max_node_size
        elif nid in paper_cites_ref:
            node_data['vizSize'] = scale_paper(paper_cites_ref[nid])
        elif nid in author_writes_ref:
            node_data['vizSize'] = scale_author(author_writes_ref[nid])
        else:
            node_data['vizSize'] = min_node_size


def assign_edge_weight(
        G,
        edge_type_weight_ref,
        default_weight: Optional[float] = 0.1
        ):
    """Fill in missing edge ``weight`` values and derive ``vizWidth`` (in place).

    An edge keeps its existing ``weight``. Otherwise it gets
    ``edge_type_weight_ref[relationshipType]`` (e.g.
    ``{'CITES': 0.5, 'WRITES': 0.3, ...}``), or ``default_weight`` when its
    type is missing or unmapped. ``vizWidth`` is ``float(weight) * 10``.
    A weight that cannot be converted to ``float`` is left in ``weight`` as
    is, but ``vizWidth`` falls back to ``default_weight * 10``.
    """
    for _, _, data in G.edges(data=True):
        if data.get('weight') is None:
            weight = edge_type_weight_ref.get(data.get('relationshipType'))
            data['weight'] = default_weight if weight is None else weight
        # Weights may arrive as strings (e.g. from file imports); "0.5" * 10 would
        # silently repeat the string instead of scaling it, so coerce first.
        try:
            data['vizWidth'] = float(data['weight']) * 10
        except (ValueError, TypeError):
            data['vizWidth'] = float(default_weight) * 10


def assign_node_color(
        G,
        sig_nid_lst: Optional[List[str]] = None,
        default_colormap_name: Optional[str] = 'tab20', # Seaborn color map like 'tab10', 'colorblind', 'deep', 'muted' are good choices
        default_color_cnt: Optional[int] = 10 # Max distinct palette colors drawn; extra types fall back to grey
        ):
    """Assign fill and border styling to every node of ``G`` (in place).

    ``vizColor``: each distinct ``nodeType`` gets a color from the seaborn
    palette ``default_colormap_name``, with types ordered by their string
    form. At most ``default_color_cnt`` palette colors are drawn; any further
    types get grey ``#808080``. Nodes without a ``nodeType`` get ``#CCCCCC``.

    ``vizBorderColor`` / ``vizBorderWidth``: nodes in ``sig_nid_lst`` get a
    gold (``#FFD700``) border of width 4. All other nodes get a grey
    (``#888888``) border of width 1.
    """
    highlight_border_color = '#FFD700'  # gold; stands out on any fill
    highlight_border_width = 4
    normal_border_width = 1
    default_node_color = '#CCCCCC'      # nodes whose type is missing/unmapped
    default_border_color = '#888888'

    # Set lookup instead of scanning the list once per node.
    significant = set(sig_nid_lst) if sig_nid_lst is not None else set()

    # key=str: plain sorted() raises TypeError if types mix e.g. str and int;
    # for all-string types the resulting order is identical.
    unique_node_types = sorted(
        {data.get('nodeType') for _, data in G.nodes(data=True)} - {None},
        key=str,
    )
    n_types = len(unique_node_types)

    if n_types == 0:
        colors_hex = []
    else:
        # Draw at most default_color_cnt palette colors; pad any remaining types with grey.
        colors_hex = sns.color_palette(
            default_colormap_name, n_colors=min(n_types, default_color_cnt)).as_hex()
        colors_hex.extend(['#808080'] * (n_types - default_color_cnt))

    type_to_color = dict(zip(unique_node_types, colors_hex))

    for nid, node_data in G.nodes(data=True):
        node_data['vizColor'] = type_to_color.get(node_data.get('nodeType'), default_node_color)
        if nid in significant:
            node_data['vizBorderColor'] = highlight_border_color
            node_data['vizBorderWidth'] = highlight_border_width
        else:
            node_data['vizBorderColor'] = default_border_color
            node_data['vizBorderWidth'] = normal_border_width

def assign_edge_color(
        G,
        default_colormap_name: Optional[str] = 'Pastel1', # Use a different palette for edges
        default_color_cnt: Optional[int] = 9 # Pastel1 has 9 colors
        ):
    """Assign a ``vizColor`` attribute to every edge of ``G`` (in place).

    Each distinct ``relationshipType`` gets a color from the seaborn palette
    ``default_colormap_name``, with types ordered by their string form.
    At most ``default_color_cnt`` palette colors are drawn; further types get
    light grey ``#D3D3D3``. Edges without a ``relationshipType`` get
    ``#AAAAAA``.
    """
    default_edge_color = '#AAAAAA'  # edges whose type is missing/unmapped

    # key=str: plain sorted() raises TypeError if types mix e.g. str and int;
    # for all-string types the resulting order is identical.
    unique_edge_types = sorted(
        {data.get('relationshipType') for _, _, data in G.edges(data=True)} - {None},
        key=str,
    )
    n_types = len(unique_edge_types)

    if n_types == 0:
        colors_hex = []
    else:
        # Draw at most default_color_cnt palette colors; pad any remaining types with light grey.
        colors_hex = sns.color_palette(
            default_colormap_name, n_colors=min(n_types, default_color_cnt)).as_hex()
        colors_hex.extend(['#D3D3D3'] * (n_types - default_color_cnt))

    type_to_color = dict(zip(unique_edge_types, colors_hex))

    for _, _, edge_data in G.edges(data=True):
        edge_data['vizColor'] = type_to_color.get(edge_data.get('relationshipType'), default_edge_color)


# =============================================================================
# 2. Bokeh Visualization Function
# =============================================================================

def visualize_graph_bokeh(G, title="new_test"):
    """
    Visualize a preprocessed NetworkX MultiDiGraph with Bokeh and show it.

    Node styling comes from the ``vizSize``, ``vizColor``, ``vizBorderColor``,
    ``vizBorderWidth`` and ``vizLabel`` node attributes. Edge styling comes
    from ``vizWidth``, ``vizColor`` and ``vizLabel``. Missing values fall
    back to fixed defaults. ``vizSize`` is drawn as a circle radius in screen
    pixels. Every other node/edge attribute is copied into an
    ``attr_<name>`` column, so the selection callback can list it in the info
    panel shown above the plot.

    Args:
        G (nx.MultiDiGraph): The graph with 'viz*' attributes already added.
        title (str): The title for the Bokeh plot.
    """
    # --- 1. Calculate Layout ---
    # spring_layout (k controls spacing); fall back to Kamada-Kawai, then random.
    try:
        pos = nx.spring_layout(G, k=0.5, iterations=50, seed=42)
        print("Layout calculated using spring_layout.")
    except Exception as e:
        print(f"Spring layout failed ({e}), trying Kamada-Kawai layout.")
        try:
            pos = nx.kamada_kawai_layout(G)
            print("Layout calculated using kamada_kawai_layout.")
        except Exception as e2:
            print(f"Kamada-Kawai layout also failed ({e2}), using random layout.")
            pos = nx.random_layout(G, seed=42)
            print("Layout calculated using random_layout.")

    # --- 2. Prepare Data Sources ---
    node_ids = list(G.nodes())
    node_attrs = [G.nodes[nid] for nid in node_ids]
    node_viz_keys = {'vizSize', 'vizColor', 'vizBorderColor', 'vizBorderWidth', 'vizLabel'}
    node_data = dict(
        index=node_ids,
        x=[pos[nid][0] for nid in node_ids],
        y=[pos[nid][1] for nid in node_ids],
        vizSize=[a.get('vizSize', 10) for a in node_attrs],
        vizColor=[a.get('vizColor', '#CCCCCC') for a in node_attrs],
        vizBorderColor=[a.get('vizBorderColor', '#888888') for a in node_attrs],
        vizBorderWidth=[a.get('vizBorderWidth', 1) for a in node_attrs],
        vizLabel=[a.get('vizLabel', str(nid)) for nid, a in zip(node_ids, node_attrs)],
    )
    # Union of attribute names over ALL nodes, in first-seen order. Node types
    # carry different attributes (e.g. 'title'/'year' vs 'name'), so reading
    # only the first node's keys would silently drop columns.
    node_extra_keys = [key for key in dict.fromkeys(name for a in node_attrs for name in a)
                       if key not in node_viz_keys]
    for key in node_extra_keys:
        # 'attr_' prefix avoids clashes with the viz columns and marks what the JS shows.
        node_data[f"attr_{key}"] = [a.get(key, 'N/A') for a in node_attrs]
    node_source = ColumnDataSource(data=node_data)

    # One pass over (u, v, key, data) keeps every edge column aligned.
    edge_records = list(G.edges(keys=True, data=True))
    edge_viz_keys = {'vizWidth', 'vizColor', 'vizLabel'}
    edge_data = dict(
        start=[u for u, _, _, _ in edge_records],
        end=[v for _, v, _, _ in edge_records],
        vizWidth=[d.get('vizWidth', 1) for _, _, _, d in edge_records],
        vizColor=[d.get('vizColor', '#AAAAAA') for _, _, _, d in edge_records],
        vizLabel=[d.get('vizLabel', '') for _, _, _, d in edge_records],
    )
    edge_extra_keys = [key for key in dict.fromkeys(name for _, _, _, d in edge_records for name in d)
                       if key not in edge_viz_keys]
    for key in edge_extra_keys:
        edge_data[f"attr_{key}"] = [d.get(key, 'N/A') for _, _, _, d in edge_records]
    edge_source = ColumnDataSource(data=edge_data)

    # --- 3. Create Bokeh Plot ---
    plot = figure(title=title,
                  x_range=(-1.1, 1.1), y_range=(-1.1, 1.1),  # layouts above are centred, roughly in [-1, 1]
                  tools="pan,wheel_zoom,box_zoom,reset,save",
                  width=800, height=600,
                  x_axis_location=None, y_axis_location=None)  # hide axes
    plot.grid.grid_line_color = None

    # --- 4. Setup GraphRenderer ---
    graph_renderer = GraphRenderer()
    graph_renderer.layout_provider = StaticLayoutProvider(graph_layout=pos)
    graph_renderer.node_renderer.data_source = node_source
    graph_renderer.edge_renderer.data_source = edge_source

    # --- 5. Configure Node Glyphs ---
    # radius_units='screen': vizSize values (e.g. 8-50) are pixel-scale. In the
    # default data units they would dwarf the +/-1.1 plot range.
    graph_renderer.node_renderer.glyph = Circle(
        radius='vizSize',
        radius_units='screen',
        fill_color='vizColor',
        line_color='vizBorderColor',
        line_width='vizBorderWidth',
        fill_alpha=0.8,
        line_alpha=1.0
    )
    graph_renderer.node_renderer.selection_glyph = Circle(
        radius='vizSize', radius_units='screen', fill_color='vizColor', line_color='red', line_width=3)
    graph_renderer.node_renderer.hover_glyph = Circle(
        radius='vizSize', radius_units='screen', fill_color='vizColor', line_color='orange', line_width=3)

    # --- 6. Configure Edge Glyphs ---
    graph_renderer.edge_renderer.glyph = MultiLine(
        line_color='vizColor',
        line_width='vizWidth',
        line_alpha=0.6
    )
    graph_renderer.edge_renderer.selection_glyph = MultiLine(
        line_color='red', line_width='vizWidth', line_alpha=1.0)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(
        line_color='orange', line_width='vizWidth', line_alpha=1.0)

    # --- 7. Add Renderer to Plot ---
    plot.renderers.append(graph_renderer)

    # --- 8. Configure HoverTool ---
    node_hover_tooltips = [
        ("Label", "@vizLabel"),
        ("Node ID", "@index"),
    ]
    node_hover = HoverTool(tooltips=node_hover_tooltips, renderers=[graph_renderer.node_renderer])

    edge_hover_tooltips = [
        ("Label", "@vizLabel"),
    ]
    edge_hover = HoverTool(tooltips=edge_hover_tooltips, renderers=[graph_renderer.edge_renderer])

    plot.add_tools(node_hover, edge_hover)

    # --- 9. Configure TapTool and Info Display ---
    info_div = Div(text="Click on a node or edge to see its details.", width=780)

    # Lists every 'attr_*' column of the first selected node (or, if no node is
    # selected, the first selected edge). Values are HTML-escaped because the
    # Div renders its text as HTML.
    callback_code = """
        const esc = (v) => String(v).replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
        const node_indices = node_source.selected.indices;
        const edge_indices = edge_source.selected.indices;
        let html = "<b>Selected Element Details:</b><br><hr>";

        if (node_indices.length > 0) {
            const index = node_indices[0]; // Show info for the first selected node
            html += "<b>Type:</b> Node<br>";
            html += "<b>ID:</b> " + esc(node_source.data['index'][index]) + "<br>";
            for (const key in node_source.data) {
                if (key.startsWith('attr_')) {
                    const attr_name = key.substring(5); // Remove 'attr_' prefix
                    html += "<b>" + esc(attr_name) + ":</b> " + esc(node_source.data[key][index]) + "<br>";
                }
            }
        } else if (edge_indices.length > 0) {
            const index = edge_indices[0]; // Show info for the first selected edge
            html += "<b>Type:</b> Edge<br>";
            html += "<b>From:</b> " + esc(edge_source.data['start'][index]) + "<br>";
            html += "<b>To:</b> " + esc(edge_source.data['end'][index]) + "<br>";
            for (const key in edge_source.data) {
                if (key.startsWith('attr_')) {
                    const attr_name = key.substring(5); // Remove 'attr_' prefix
                    html += "<b>" + esc(attr_name) + ":</b> " + esc(edge_source.data[key][index]) + "<br>";
                }
            }
        } else {
            html = "Click on a node or edge to see its details.";
        }

        info_div.text = html;
    """

    selection_callback = CustomJS(args=dict(node_source=node_source,
                                            edge_source=edge_source,
                                            info_div=info_div),
                                  code=callback_code)
    # Re-render the panel whenever either selection changes, including when it
    # is cleared. A TapTool callback only fires when a glyph is hit, so the
    # "nothing selected" branch above would never run.
    node_source.selected.js_on_change('indices', selection_callback)
    edge_source.selected.js_on_change('indices', selection_callback)

    tap_tool = TapTool(renderers=[graph_renderer.node_renderer, graph_renderer.edge_renderer])
    plot.add_tools(tap_tool)

    # Interaction policies
    graph_renderer.selection_policy = NodesAndLinkedEdges()  # select node and its edges
    graph_renderer.inspection_policy = EdgesAndLinkedNodes()  # hover edge and its nodes

    # --- 10. Layout and Show ---
    layout = column(info_div, plot)  # info panel above the plot
    show(layout)


# =============================================================================
# 3. Example Usage
# =============================================================================
if __name__ == '__main__':
    # --- Build a small sample publication graph ---
    G = nx.MultiDiGraph()

    # Nodes as (id, attributes), added in this order. The 'vizLabel' values are
    # placeholders: add_node_label() below assigns every node a fresh one.
    sample_nodes = [
        ("Paper1", dict(nodeType='Paper', title='Intro to Graphs', year=2021, vizLabel="P1:Intro")),
        ("Paper2", dict(nodeType='Paper', title='Advanced Networks', year=2022, vizLabel="P2:Adv")),
        ("Paper3", dict(nodeType='Paper', title='Visualization Techniques', year=2023, vizLabel="P3:Viz")),
        ("Author1", dict(nodeType='Author', name='Alice', affiliation='Inst A', vizLabel="A1:Alice")),
        ("Author2", dict(nodeType='Author', name='Bob', affiliation='Inst B', vizLabel="A2:Bob")),
        ("Venue1", dict(nodeType='Venue', name='Conf X', vizLabel="V1:ConfX")),
        ("Journal1", dict(nodeType='Journal', name='Journal Y', vizLabel="J1:JY")),
        ("MissingTypeNode", dict(vizLabel="M?")),  # deliberately has no 'nodeType'
    ]
    for node_id, node_attrs in sample_nodes:
        G.add_node(node_id, **node_attrs)

    # Edges as (source, target, key, attributes). key=None lets networkx pick
    # the key exactly as a plain add_edge() call would.
    sample_edges = [
        ("Author1", "Paper1", None, dict(relationshipType='WRITES', weight=0.8, vizLabel="A1 writes P1")),
        ("Author1", "Paper2", None, dict(relationshipType='WRITES', weight=0.9, vizLabel="A1 writes P2")),
        ("Author2", "Paper1", None, dict(relationshipType='WRITES', weight=0.7, vizLabel="A2 writes P1")),
        ("Author2", "Paper3", None, dict(relationshipType='WRITES', weight=0.8, vizLabel="A2 writes P3")),
        ("Paper2", "Paper1", None, dict(relationshipType='CITES', weight=0.5, vizLabel="P2 cites P1")),
        ("Paper3", "Paper1", None, dict(relationshipType='CITES', weight=0.6, vizLabel="P3 cites P1")),
        ("Paper3", "Paper2", None, dict(relationshipType='CITES', weight=0.4, vizLabel="P3 cites P2")),
        ("Paper1", "Venue1", None, dict(relationshipType='RELEASES_IN', weight=0.2, vizLabel="P1 in V1")),
        ("Paper2", "Journal1", None, dict(relationshipType='PRINTS_ON', weight=0.3, vizLabel="P2 on J1")),
        ("Paper3", "Venue1", None, dict(relationshipType='RELEASES_IN', weight=0.2, vizLabel="P3 in V1")),
        ("Author1", "Author2", None, dict(vizLabel="A1 -> A2 (Unknown)")),  # no relationshipType
        ("Author1", "Paper1", "review",  # parallel edge with an explicit key
         dict(relationshipType='REVIEWS', weight=0.1, vizLabel="A1 reviews P1")),
    ]
    for src, dst, edge_key, edge_attrs in sample_edges:
        G.add_edge(src, dst, key=edge_key, **edge_attrs)

    # --- Preprocess: derive the viz* attributes the plot reads ---
    print("Running preprocessing...")
    # Attribute used as the display label for each node type.
    node_key_ref = {'Paper': 'title', 'Author': 'name', 'Venue': 'name', 'Journal': 'name'}
    # Per-type weights, applied only to edges that have no explicit 'weight'.
    edge_type_weight_ref = {
        'CITES': 0.5, 'DISCUSS': 0.4, 'WRITES': 0.3, 'WORKS_IN': 0.2,
        'PRINTS_ON': 0.1, 'RELEASES_IN': 0.1, 'REVIEWS': 0.05,
    }
    significant_nodes = ["Paper1"]  # drawn at max size with a highlighted border

    add_node_label(G, node_key_ref)
    add_edges_label(G)
    assign_node_size(G, sig_nid_lst=significant_nodes, min_node_size=8, max_node_size=30)
    assign_edge_weight(G, edge_type_weight_ref, default_weight=0.05)
    assign_node_color(G, sig_nid_lst=significant_nodes, default_colormap_name='tab10')
    assign_edge_color(G, default_colormap_name='Pastel2')
    print("Preprocessing complete.")

    # --- Visualize ---
    print("Generating Bokeh visualization...")
    visualize_graph_bokeh(G, title="Interactive Publication Network")
    print("Done.")

# In [2]:
import networkx as nx
from bokeh.plotting import figure, show # Removed from_networkx import
from bokeh.models import (Circle, MultiLine, EdgesAndLinkedNodes, NodesAndLinkedEdges,
                          HoverTool, TapTool, BoxSelectTool,
                          ColumnDataSource, StaticLayoutProvider, Div, CustomJS,
                          GraphRenderer) # Added GraphRenderer import explicitly
from bokeh.layouts import column, row
from bokeh.palettes import Spectral4 # Example palette
import pandas as pd
from collections import Counter # Ensure Counter is imported for preprocessing functions
import seaborn as sns # Ensure seaborn is imported for preprocessing functions
import matplotlib.pyplot as plt # Ensure matplotlib is imported for preprocessing functions
from typing import List, Dict, Optional

# =============================================================================
# 1. Preprocessing Functions
#    (add_node_label, add_edges_label, assign_node_size, ...)
# =============================================================================
# These re-define the preprocessing functions from the previous cell, with minor fixes.

def add_node_label(G, node_key_ref):
    """Attach a human-readable ``vizLabel`` attribute to every node of ``G``.

    The label is the value of the node attribute named by
    ``node_key_ref[nodeType]`` (e.g. ``{'Paper': 'title', 'Author': 'name'}``).
    When the type has no (or an empty) mapping, or the node lacks that
    attribute, the label falls back to ``"<nodeType> ID: <nid>"``, or to
    ``"ID: <nid>"`` for nodes without a ``nodeType``. ``G`` is modified in
    place; nothing is returned.
    """
    for node_id, attrs in G.nodes(data=True):
        attr_name = node_key_ref.get(attrs.get('nodeType'))
        if attr_name and attr_name in attrs:
            # The type maps to an attribute this node actually carries.
            attrs['vizLabel'] = attrs[attr_name]
            continue
        if 'nodeType' in attrs:
            attrs['vizLabel'] = f"{attrs['nodeType']} ID: {node_id}"
        else:
            attrs['vizLabel'] = f"ID: {node_id}"

def add_edges_label(G):
    """Attach a descriptive ``vizLabel`` attribute to every edge of ``G``.

    For multigraphs the label is ``"u -> v (RELTYPE, key=K)"``; the key
    distinguishes parallel edges. For simple graphs, which have no edge keys,
    it is ``"u -> v (RELTYPE)"``. Edges without a ``relationshipType`` are
    labelled ``UNKNOWN``. ``G`` is modified in place.
    """
    if G.is_multigraph():
        for u, v, key, edge_data in G.edges(data=True, keys=True):
            rel_type = edge_data.get('relationshipType', 'UNKNOWN')
            edge_data['vizLabel'] = f"{u} -> {v} ({rel_type}, key={key})"
    else:
        # Graph/DiGraph edge views do not accept ``keys=``, so iterate without it.
        for u, v, edge_data in G.edges(data=True):
            rel_type = edge_data.get('relationshipType', 'UNKNOWN')
            edge_data['vizLabel'] = f"{u} -> {v} ({rel_type})"

def assign_node_size(
        G,
        sig_nid_lst: Optional[List[str]] = None,
        min_node_size: Optional[int] = 10,
        max_node_size: Optional[int] = 50,
        ):
    """Assign a ``vizSize`` attribute to every node of ``G`` (in place).

    * Nodes in ``sig_nid_lst`` get ``max_node_size``.
    * ``Paper`` nodes are scaled linearly from ``min_node_size`` to
      ``max_node_size`` by their number of incoming ``CITES`` edges.
    * ``Author`` nodes are scaled the same way by their number of outgoing
      ``WRITES`` edges.
    * Every other node gets ``min_node_size``.

    The min/max used for scaling are computed separately per node type over
    all nodes of that type. If every node of a type has the same count, all
    of them get ``min_node_size``. Results are clamped to
    ``[min_node_size, max_node_size]``.
    """
    # Set lookup instead of scanning the list once per node.
    significant = set(sig_nid_lst) if sig_nid_lst is not None else set()

    # First pass: per-type relationship counts (always ints, so no type checks are needed).
    paper_cites_ref, author_writes_ref = {}, {}
    for nid, node_data in G.nodes(data=True):
        node_type = node_data.get('nodeType')
        if node_type == 'Paper':
            paper_cites_ref[nid] = sum(
                1 for _, _, data in G.in_edges(nid, data=True)
                if data.get('relationshipType') == 'CITES')
        elif node_type == 'Author':
            author_writes_ref[nid] = sum(
                1 for _, _, data in G.out_edges(nid, data=True)
                if data.get('relationshipType') == 'WRITES')

    def _make_scaler(counts):
        """Return a function mapping a count from ``counts`` to a clamped size."""
        lo = min(counts.values(), default=0)
        hi = max(counts.values(), default=0)
        span = hi - lo if hi > lo else 1  # avoid division by zero when all counts are equal

        def _scale(value):
            size = min_node_size + (max_node_size - min_node_size) * (value - lo) / span
            return max(min_node_size, min(max_node_size, size))
        return _scale

    scale_paper = _make_scaler(paper_cites_ref)
    scale_author = _make_scaler(author_writes_ref)

    # Second pass: assign sizes. Only Paper/Author nodes appear in the count maps.
    for nid, node_data in G.nodes(data=True):
        if nid in significant:
            node_data['vizSize'] = max_node_size
        elif nid in paper_cites_ref:
            node_data['vizSize'] = scale_paper(paper_cites_ref[nid])
        elif nid in author_writes_ref:
            node_data['vizSize'] = scale_author(author_writes_ref[nid])
        else:
            node_data['vizSize'] = min_node_size


def assign_edge_weight(
        G,
        edge_type_weight_ref,
        default_weight: Optional[float] = 0.1
        ):
    """Fill in missing edge ``weight`` values and derive ``vizWidth`` (in place).

    An edge keeps its existing ``weight``. Otherwise it gets
    ``edge_type_weight_ref[relationshipType]`` (e.g.
    ``{'CITES': 0.5, 'WRITES': 0.3, ...}``), or ``default_weight`` when its
    type is missing or unmapped. ``vizWidth`` is ``float(weight) * 10``.
    A weight that cannot be converted to ``float`` is left in ``weight`` as
    is, but ``vizWidth`` falls back to ``default_weight * 10``.
    """
    # edges(data=True) works for both simple graphs and multigraphs (one tuple
    # per parallel edge); the edge key is not needed here.
    for _, _, data in G.edges(data=True):
        if data.get('weight') is None:
            weight = edge_type_weight_ref.get(data.get('relationshipType'))
            data['weight'] = default_weight if weight is None else weight
        # Weights may arrive as strings (e.g. from file imports); "0.5" * 10 would
        # silently repeat the string instead of scaling it, so coerce first.
        try:
            data['vizWidth'] = float(data['weight']) * 10
        except (ValueError, TypeError):
            data['vizWidth'] = float(default_weight) * 10


def assign_node_color(
        G,
        sig_nid_lst: Optional[List[str]] = None,
        default_colormap_name: Optional[str] = 'tab20', # Seaborn color map like 'tab10', 'colorblind', 'deep', 'muted' are good choices
        default_color_cnt: Optional[int] = 10 # Max distinct palette colors drawn; extra types fall back to grey
        ):
    """Assign fill and border styling to every node of ``G`` (in place).

    ``vizColor``: each distinct ``nodeType`` gets a color from the seaborn
    palette ``default_colormap_name``, with types ordered by their string
    form. At most ``default_color_cnt`` palette colors are drawn; any further
    types get grey ``#808080``. Nodes without a ``nodeType`` get ``#CCCCCC``.

    ``vizBorderColor`` / ``vizBorderWidth``: nodes in ``sig_nid_lst`` get a
    gold (``#FFD700``) border of width 4. All other nodes get a grey
    (``#888888``) border of width 1.
    """
    highlight_border_color = '#FFD700'  # gold; stands out on any fill
    highlight_border_width = 4
    normal_border_width = 1
    default_node_color = '#CCCCCC'      # nodes whose type is missing/unmapped
    default_border_color = '#888888'

    # Set lookup instead of scanning the list once per node.
    significant = set(sig_nid_lst) if sig_nid_lst is not None else set()

    # key=str: plain sorted() raises TypeError if types mix e.g. str and int;
    # for all-string types the resulting order is identical.
    unique_node_types = sorted(
        {data.get('nodeType') for _, data in G.nodes(data=True)} - {None},
        key=str,
    )
    n_types = len(unique_node_types)

    if n_types == 0:
        colors_hex = []
    else:
        # Draw at most default_color_cnt palette colors; pad any remaining types with grey.
        colors_hex = sns.color_palette(
            default_colormap_name, n_colors=min(n_types, default_color_cnt)).as_hex()
        colors_hex.extend(['#808080'] * (n_types - default_color_cnt))

    type_to_color = dict(zip(unique_node_types, colors_hex))

    for nid, node_data in G.nodes(data=True):
        node_data['vizColor'] = type_to_color.get(node_data.get('nodeType'), default_node_color)
        if nid in significant:
            node_data['vizBorderColor'] = highlight_border_color
            node_data['vizBorderWidth'] = highlight_border_width
        else:
            node_data['vizBorderColor'] = default_border_color
            node_data['vizBorderWidth'] = normal_border_width

def assign_edge_color(
        G,
        default_colormap_name: Optional[str] = 'Pastel1', # Use a different palette for edges
        default_color_cnt: Optional[int] = 9 # Pastel1 has 9 colors
        ):
    """Assign a colour to every edge of ``G`` based on its ``relationshipType`` (in place).

    Each distinct non-None ``relationshipType`` (sorted) is mapped to a colour
    from the seaborn palette ``default_colormap_name``. At most
    ``default_color_cnt`` palette colours are used; any remaining types share
    light grey ``#D3D3D3``. Edges whose ``relationshipType`` is missing or None
    get ``#AAAAAA``. The chosen hex colour is written to each edge's
    ``vizColor`` attribute.

    Works for Graph, DiGraph, MultiGraph and MultiDiGraph. Parallel edges in a
    multigraph are coloured individually.

    Args:
        G: NetworkX graph; its edge attribute dicts are modified in place.
        default_colormap_name: seaborn / matplotlib palette name.
        default_color_cnt: maximum number of distinct palette colours to use.
            ``None`` means one palette colour per type (no grey overflow).
            For qualitative palettes keep this <= the palette's native size
            (e.g. Pastel1 = 9, Pastel2 = 8). seaborn cycles such a palette
            when asked for more colours, so types would otherwise share colours.
    """
    default_edge_color = '#AAAAAA'  # edges with no / None relationshipType
    overflow_color = '#D3D3D3'      # types beyond the palette budget

    # G.edges(data=True) yields (u, v, attr_dict) for simple graphs and
    # multigraphs alike (one tuple per parallel edge). The dicts are the live
    # attribute dicts, so writing to them updates the graph.
    edge_attrs = [data for _, _, data in G.edges(data=True)]

    unique_edge_types = sorted({d.get('relationshipType') for d in edge_attrs} - {None})
    type_cnt = len(unique_edge_types)

    if default_color_cnt is None:
        palette_cnt = type_cnt
    else:
        palette_cnt = max(0, min(type_cnt, default_color_cnt))

    colors_hex = (list(sns.color_palette(default_colormap_name, n_colors=palette_cnt).as_hex())
                  if palette_cnt > 0 else [])
    colors_hex.extend([overflow_color] * (type_cnt - palette_cnt))

    type_to_color = dict(zip(unique_edge_types, colors_hex))

    for data in edge_attrs:
        data['vizColor'] = type_to_color.get(data.get('relationshipType'), default_edge_color)


# =============================================================================
# 2. Bokeh Visualization Function
# =============================================================================

# JavaScript run by the TapTool: shows the first selected node (or, failing
# that, the first selected edge) in `info_div`. Values are HTML-escaped
# because Div text is rendered as HTML and attributes such as titles may
# contain '<' or '&'.
_TAP_CALLBACK_JS = """
    const node_indices = node_source.selected.indices;
    const edge_indices = edge_source.selected.indices;
    let html = "<b>Selected Element Details:</b><br><hr>";

    // Escape text before inserting it into the HTML-rendering Div.
    function esc(value) {
        return String(value)
            .replace(/&/g, "&amp;")
            .replace(/</g, "&lt;")
            .replace(/>/g, "&gt;")
            .replace(/"/g, "&quot;")
            .replace(/'/g, "&#39;");
    }

    // Read source.data[key][index], mapping missing/null/undefined to 'N/A'.
    function getData(source, key, index) {
        const column = source.data[key];
        if (column && index < column.length) {
            const value = column[index];
            return (value === null || value === undefined) ? 'N/A' : value;
        }
        return 'N/A';
    }

    // Label row followed by one row per attr_<name> column.
    function describe(source, index) {
        let out = "";
        if ('vizLabel' in source.data) {
            out += "<b>Label:</b> " + esc(getData(source, 'vizLabel', index)) + "<br>";
        }
        for (const key in source.data) {
            if (key.startsWith('attr_') && key !== 'attr_vizLabel') {
                out += "<b>" + esc(key.substring(5)) + ":</b> "
                    + esc(getData(source, key, index)) + "<br>";
            }
        }
        return out;
    }

    if (node_indices.length > 0) {
        const index = node_indices[0];  // show the first selected node only
        html += "<b>Type:</b> Node<br>";
        html += "<b>ID:</b> " + esc(getData(node_source, 'index', index)) + "<br>";
        html += describe(node_source, index);
    } else if (edge_indices.length > 0) {
        const index = edge_indices[0];  // show the first selected edge only
        html += "<b>Type:</b> Edge<br>";
        html += "<b>From:</b> " + esc(getData(edge_source, 'start', index)) + "<br>";
        html += "<b>To:</b> " + esc(getData(edge_source, 'end', index)) + "<br>";
        html += describe(edge_source, index);
    } else {
        html = "Click on a node or edge to see its details.";
    }

    info_div.text = html;
"""


def _bokeh_graph_layout(G):
    """Return ``{node: (x, y)}``, trying spring, then Kamada-Kawai, then random layout."""
    try:
        pos = nx.spring_layout(G, k=0.5, iterations=50, seed=42)
        print("Layout calculated using spring_layout.")
    except Exception as e:  # best-effort: fall back to the next layout algorithm
        print(f"Spring layout failed ({e}), trying Kamada-Kawai layout.")
        try:
            pos = nx.kamada_kawai_layout(G)
            print("Layout calculated using kamada_kawai_layout.")
        except Exception as e2:
            print(f"Kamada-Kawai layout also failed ({e2}), using random layout.")
            pos = nx.random_layout(G, seed=42)
            print("Layout calculated using random_layout.")
    # NetworkX returns numpy arrays; plain float tuples serialize cleanly.
    return {node: (float(xy[0]), float(xy[1])) for node, xy in pos.items()}


def _bokeh_node_columns(G):
    """Build node ColumnDataSource columns: 'index', viz* styling, and attr_<key> for other attributes."""
    viz_keys = {'vizSize', 'vizColor', 'vizBorderColor', 'vizBorderWidth', 'vizLabel'}
    node_ids = list(G.nodes())
    attrs = [G.nodes[nid] for nid in node_ids]
    columns = {
        'index': node_ids,  # GraphRenderer matches nodes to edge start/end via 'index'
        'vizSize': [a.get('vizSize', 10) for a in attrs],
        'vizColor': [a.get('vizColor', '#CCCCCC') for a in attrs],
        'vizBorderColor': [a.get('vizBorderColor', '#888888') for a in attrs],
        'vizBorderWidth': [a.get('vizBorderWidth', 1) for a in attrs],
        'vizLabel': [a.get('vizLabel', str(nid)) for nid, a in zip(node_ids, attrs)],
    }
    # Every other attribute, padded with 'N/A', for the tap callback. Sorted so
    # the info panel lists attributes in a stable order.
    extra_keys = set().union(*(a.keys() for a in attrs)) - viz_keys
    for key in sorted(extra_keys, key=str):
        columns[f"attr_{key}"] = [a.get(key, 'N/A') for a in attrs]
    return columns


def _bokeh_edge_columns(G):
    """Build edge ColumnDataSource columns: 'start'/'end', viz* styling, and attr_<key> for other attributes."""
    viz_keys = {'vizWidth', 'vizColor', 'vizLabel'}
    # (u, v, attr_dict) for any graph type; multigraph parallel edges appear once each.
    edges = list(G.edges(data=True))
    attrs = [d for _, _, d in edges]
    columns = {
        'start': [u for u, _, _ in edges],
        'end': [v for _, v, _ in edges],
        'vizWidth': [d.get('vizWidth', 1) for d in attrs],
        'vizColor': [d.get('vizColor', '#AAAAAA') for d in attrs],
        'vizLabel': [d.get('vizLabel', '') for d in attrs],
    }
    extra_keys = set().union(*(d.keys() for d in attrs)) - viz_keys
    for key in sorted(extra_keys, key=str):
        columns[f"attr_{key}"] = [d.get(key, 'N/A') for d in attrs]
    return columns


def _present_tooltips(tooltips, source):
    """Keep only (label, '@column') tooltips whose column exists in ``source``."""
    return [(label, field) for label, field in tooltips
            if (field[1:] if field.startswith('@') else field) in source.data]


def visualize_graph_bokeh(G, title="check"):
    """
    Visualize a preprocessed NetworkX graph with Bokeh and show it.

    Nodes are drawn from the 'vizSize' (diameter in screen pixels), 'vizColor',
    'vizBorderColor', 'vizBorderWidth' and 'vizLabel' attributes. Edges are
    drawn from 'vizWidth', 'vizColor' and 'vizLabel'. Missing attributes fall
    back to defaults. Every other attribute is exposed as an ``attr_<name>``
    column, shown in hover tooltips and in the click-info panel.

    Args:
        G (nx.Graph | nx.MultiDiGraph): Graph with 'viz*' attributes already added.
        title (str): The title for the Bokeh plot.
    """
    from bokeh.models import Scatter  # marker sized in screen pixels

    graph_layout = StaticLayoutProvider(graph_layout=_bokeh_graph_layout(G))
    node_source = ColumnDataSource(data=_bokeh_node_columns(G))
    edge_source = ColumnDataSource(data=_bokeh_edge_columns(G))

    plot = figure(title=title,
                  x_range=(-1.1, 1.1), y_range=(-1.1, 1.1),
                  tools="pan,wheel_zoom,box_zoom,reset,save",  # hover/tap added below
                  width=800, height=600,
                  x_axis_location=None, y_axis_location=None)
    plot.grid.grid_line_color = None

    graph_renderer = GraphRenderer()
    graph_renderer.layout_provider = graph_layout

    # Nodes: 'vizSize' is a pixel diameter, so use Scatter(size=...).
    # Circle(radius=...) is in data units: on this +/-1.1 canvas a radius of
    # 10-35 would cover the whole plot.
    node_shape = dict(marker='circle', size='vizSize', fill_color='vizColor')
    graph_renderer.node_renderer.data_source = node_source
    graph_renderer.node_renderer.glyph = Scatter(
        **node_shape, line_color='vizBorderColor', line_width='vizBorderWidth',
        fill_alpha=0.8, line_alpha=1.0)
    graph_renderer.node_renderer.selection_glyph = Scatter(
        **node_shape, line_color='red', line_width=3, fill_alpha=0.9, line_alpha=1.0)
    graph_renderer.node_renderer.hover_glyph = Scatter(
        **node_shape, line_color='orange', line_width=3, fill_alpha=0.9, line_alpha=1.0)

    graph_renderer.edge_renderer.data_source = edge_source
    graph_renderer.edge_renderer.glyph = MultiLine(
        line_color='vizColor', line_width='vizWidth', line_alpha=0.6)
    graph_renderer.edge_renderer.selection_glyph = MultiLine(
        line_color='red', line_width='vizWidth', line_alpha=1.0)
    graph_renderer.edge_renderer.hover_glyph = MultiLine(
        line_color='orange', line_width='vizWidth', line_alpha=1.0)

    # Clicking a node selects it and its edges. Hovering an edge highlights
    # it and its end nodes.
    # NOTE(review): with NodesAndLinkedEdges the graph is hit-tested on nodes,
    # so tapping an edge by itself may not select it. Confirm whether the
    # edge branch of the info callback is reachable.
    graph_renderer.selection_policy = NodesAndLinkedEdges()
    graph_renderer.inspection_policy = EdgesAndLinkedNodes()

    plot.renderers.append(graph_renderer)

    # Tooltip fields are '@column' references; keep only those whose column
    # exists in the source (the '@' must be stripped before checking).
    node_hover = HoverTool(
        tooltips=_present_tooltips([
            ("Label", "@vizLabel"),
            ("Node ID", "@index"),
            ("Type", "@attr_nodeType"),
        ], node_source),
        renderers=[graph_renderer.node_renderer])
    edge_hover = HoverTool(
        tooltips=_present_tooltips([
            ("Label", "@vizLabel"),
            ("Type", "@attr_relationshipType"),
            ("Weight", "@attr_weight"),
        ], edge_source),
        renderers=[graph_renderer.edge_renderer])
    plot.add_tools(node_hover, edge_hover)

    info_div = Div(text="Click on a node or edge to see its details.",
                   width=780, height=100, styles={'overflow-y': 'auto'})
    tap_callback = CustomJS(args=dict(node_source=node_source,
                                      edge_source=edge_source,
                                      info_div=info_div),
                            code=_TAP_CALLBACK_JS)
    plot.add_tools(TapTool(renderers=[graph_renderer.node_renderer, graph_renderer.edge_renderer],
                           callback=tap_callback))

    show(column(info_div, plot))


# =============================================================================
# 3. Example Usage
# =============================================================================
if __name__ == '__main__':
    # --- Build a small sample publication graph (MultiDiGraph) ---
    G = nx.MultiDiGraph()

    # Nodes with types and attributes ('vizLabel' is added by preprocessing)
    G.add_node("Paper1", nodeType='Paper', title='Intro to Graphs', year=2021)
    G.add_node("Paper2", nodeType='Paper', title='Advanced Networks', year=2022)
    G.add_node("Paper3", nodeType='Paper', title='Visualization Techniques', year=2023)
    G.add_node("Author1", nodeType='Author', name='Alice', affiliation='Inst A')
    G.add_node("Author2", nodeType='Author', name='Bob', affiliation='Inst B')
    G.add_node("Venue1", nodeType='Venue', name='Conf X')
    G.add_node("Journal1", nodeType='Journal', name='Journal Y')
    G.add_node("MissingTypeNode")  # node without 'nodeType' exercises the fallbacks

    # Edges with types and attributes
    G.add_edge("Author1", "Paper1", relationshipType='WRITES', weight=0.8)
    G.add_edge("Author1", "Paper2", relationshipType='WRITES', weight=0.9)
    G.add_edge("Author2", "Paper1", relationshipType='WRITES', weight=0.7)
    G.add_edge("Author2", "Paper3", relationshipType='WRITES', weight=0.8)
    G.add_edge("Paper2", "Paper1", relationshipType='CITES', weight=0.5)  # P1 cited once
    G.add_edge("Paper3", "Paper1", relationshipType='CITES', weight=0.6)  # P1 cited twice
    G.add_edge("Paper3", "Paper2", relationshipType='CITES', weight=0.4)  # P2 cited once
    G.add_edge("Paper1", "Venue1", relationshipType='RELEASES_IN', weight=0.2)
    G.add_edge("Paper2", "Journal1", relationshipType='PRINTS_ON', weight=0.3)
    G.add_edge("Paper3", "Venue1", relationshipType='RELEASES_IN', weight=0.2)
    # Edge without a relationshipType (exercises the fallbacks)
    G.add_edge("Author1", "Author2")
    # Parallel edge alongside the WRITES edge (key 0) between the same nodes
    G.add_edge("Author1", "Paper1", key="review", relationshipType='REVIEWS', weight=0.1)

    # --- Run Preprocessing ---
    print("Running preprocessing...")
    # Which node attribute supplies the label for each nodeType
    node_key_ref = {'Paper': 'title', 'Author': 'name', 'Venue': 'name', 'Journal': 'name'}
    # Base weights per edge type (used if an edge has no 'weight' attribute)
    edge_type_weight_ref = {'CITES':0.5, 'DISCUSS':0.4, 'WRITES':0.3, 'WORKS_IN':0.2, 'PRINTS_ON':0.1, 'RELEASES_IN':0.1, 'REVIEWS': 0.05}
    # Nodes to highlight (larger size, different border)
    significant_nodes = ["Paper1"]

    # Add the 'viz*' attributes consumed by the visualizer
    add_node_label(G, node_key_ref)
    add_edges_label(G)  # adds 'vizLabel' to edges
    assign_node_size(G, sig_nid_lst=significant_nodes, min_node_size=10, max_node_size=35)
    assign_edge_weight(G, edge_type_weight_ref, default_weight=0.05)  # adds 'weight' and 'vizWidth'
    assign_node_color(G, sig_nid_lst=significant_nodes, default_colormap_name='tab10')  # adds 'vizColor', 'vizBorderColor', 'vizBorderWidth'
    # Pastel2 has only 8 colours. Cap the palette budget so a 9th type gets the
    # grey overflow colour instead of reusing the first palette colour.
    assign_edge_color(G, default_colormap_name='Pastel2', default_color_cnt=8)  # adds 'vizColor'

    # Spot-check the viz attributes after preprocessing
    print("Sample node data after preprocessing (Paper1):", G.nodes["Paper1"])
    print("Sample edge data after preprocessing (Author1 -> Paper1, WRITES):", G.get_edge_data("Author1", "Paper1")[0])  # key 0 = the WRITES edge

    print("Preprocessing complete.")

    # --- Visualize ---
    print("Generating Bokeh visualization...")
    visualize_graph_bokeh(G, title="Interactive Publication Network")
    print("Done.")

Running preprocessing...
Sample node data after preprocessing (Paper1): {'nodeType': 'Paper', 'title': 'Intro to Graphs', 'year': 2021, 'vizLabel': 'Intro to Graphs', 'vizSize': 35, 'vizColor': '#2ca02c', 'vizBorderColor': '#FFD700', 'vizBorderWidth': 4}
Sample edge data after preprocessing (Author1 -> Paper1, WRITES): {'relationshipType': 'WRITES', 'weight': 0.8, 'vizLabel': 'Author1 -> Paper1 (WRITES, key=0)', 'vizWidth': 8.0, 'vizColor': '#e6f5c9'}
Preprocessing complete.
Generating Bokeh visualization...
Layout calculated using spring_layout.
Done.
