**GraphReader:**

- `read_and_clean_graph`: Reads a graph from a GraphML file, assigns node IDs, cleans up attribute names, and prints summary information about the graph.

**LayoutUtility:**

- `fr_layout_nx`: Performs a Fruchterman-Reingold layout on the graph using NetworkX. It allows for customization of layout parameters and prints information about the layout process.

**ZCoordinateAdder:**

- `add_z_coordinate_to_nodes`: Adds a z-coordinate to each node in the graph based on its centrality value. It calculates the bounds of the x and y coordinates, normalizes centrality values, scales the z-coordinates, and adds them to the nodes.

**GraphBundler2d:**

- `prune_edges_by_percentile_weight`: Removes edges with weights below a specified percentile threshold.
  bundle_edges: Performs edge bundling using the Hammer Bundle algorithm. It first prunes edges and then performs bundling with user-defined parameters. It also groups the bundled edges by edge ID and includes source and target positions.

**EdgeZInterpolator:**

- `interpolate_z_to_edges`: Interpolates z-coordinates for each edge in a DataFrame using cubic spline interpolation. It considers the source and target node z-coordinates and the edge path to assign z-coordinates to all points along the bundled edge.

**Apply3DEdgeBundling:**

- `apply_3d_bundling`: Applies 3D edge bundling to a graph's edges. It utilizes neighbor information, forces, and smoothing to create a bundled representation of edges in 3D space. It allows for customization of various parameters like the number of iterations, step size, smoothing iterations, and neighbor radius.

**GraphSaver:**

- `save_igraph_nodes_to_json`: Saves node data from an igraph graph to a JSON file. It offers options to return the JSON string as well as specify which node attributes to include.


# Imports


In [1]:
# Standard library imports
import json
import random
import time
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from typing import Dict, List, Optional, Tuple, Union
import logging

# Data manipulation and analysis
import numpy as np
import pandas as pd

# Disable SettingWithCopyWarning
pd.options.mode.chained_assignment = None

# Graphs and networks
import networkx as nx
import igraph as ig

# Visualization
import matplotlib.pyplot as plt
import colorcet as cc
from matplotlib.colors import to_hex, to_rgb

# Data visualization and processing
import datashader as ds
import datashader.transfer_functions as tf
from datashader.bundling import hammer_bundle

# Scientific computing
from scipy.spatial import cKDTree
from scipy.interpolate import CubicSpline, interp1d


# Progress bars
from tqdm import tqdm

# Performance optimization
from numba import jit, prange

# Custom modules
from fa2_modified import ForceAtlas2

# Warnings
import warnings

In [2]:
# 2. Constants and configuration
INPUT_GRAPH_PATH = "../data/07-clustered-graphs/alpha0.3_k10_res0.002.graphml"
CLUSTER_INFO_LABEL_TREE = "../output/cluster-qualifications/ClusterInfoLabelTree.xlsx"
CLUSTER_LABEL_DICT_PATH = "../data/99-testdata/cluster_label_dict.json"
CLUSTER_TREE_PATH = "../output/cluster-qualifications/ClusterHierachy_noComments.json"
OUTPUT_DIR = "../data/99-testdata/"
THREEJS_OUTPUT_DIR = (
    "/Users/jlq293/Projects/Random Projects/LW-ThreeJS/2d_ssrinetworkviz/src/data/"
)
CLUSTER_HIERACHY_FOR_LEGEND_PATH = (
    "../output/cluster-qualifications/ClusterHierachy_noComments.json"
)

# utility functions


## Graph Reader


In [3]:
class GraphReader:
    @staticmethod
    def read_and_clean_graph(path: str) -> ig.Graph:
        g = ig.Graph.Read_GraphML(path)
        g.vs["node_id"] = [int(i) for i in range(g.vcount())]

        if "id" in g.vs.attribute_names():
            g.vs["node_name"] = g.vs["id"]
            del g.vs["id"]

        if "cluster" in g.vs.attribute_names():
            g.vs["cluster"] = [int(cluster) for cluster in g.vs["cluster"]]

        if "year" in g.vs.attribute_names():
            g.vs["year"] = [int(year) for year in g.vs["year"]]

        if "eid" in g.vs.attribute_names():
            del g.vs["eid"]

        if "centrality_alpha0.3_k10_res0.006" in g.vs.attribute_names():
            del g.vs["centrality_alpha0.3_k10_res0.006"]

        if "centrality_alpha0.3_k10_res0.002" in g.vs.attribute_names():
            g.vs["centrality"] = g.vs["centrality_alpha0.3_k10_res0.002"]
            del g.vs["centrality_alpha0.3_k10_res0.002"]

        g.es["edge_id"] = list(range(g.ecount()))
        print("Node Attributes:", g.vs.attribute_names())
        print("Edge Attributes:", g.es.attribute_names())
        # print number of nodes and edges
        print(f"Number of nodes: {g.vcount()}")
        print(f"Number of edges: {g.ecount()}")
        return g

    @staticmethod
    def subgraph_of_clusters(G, clusters):
        if isinstance(G, nx.Graph):
            nodes = [
                node for node in G.nodes if G.nodes[node].get("cluster") in clusters
            ]
            return G.subgraph(nodes)
        elif isinstance(G, ig.Graph):
            nodes = [v.index for v in G.vs if v["cluster"] in clusters]
            return G.subgraph(nodes)
        else:
            raise TypeError("Input must be a NetworkX Graph or an igraph Graph")

    @staticmethod
    def add_cluster_labels(
        G: Union[nx.Graph, ig.Graph],
        labels_file_path: str = "../output/cluster-qualifications/raw_cluster_labels.json",
    ) -> Tuple[Union[nx.Graph, ig.Graph], Dict[float, str]]:
        """
        Add cluster labels to the graph nodes.

        Args:
            G (Union[nx.Graph, ig.Graph]): The input graph (NetworkX or igraph).
            labels_file_path (str): Path to the JSON file containing cluster labels.

        Returns:
            Tuple[Union[nx.Graph, ig.Graph], Dict[float, str]]:
                The graph with added cluster labels and the cluster label dictionary.
        """
        with open(labels_file_path) as file:
            cluster_label_dict = json.load(file)
        cluster_label_dict = {float(k): v[0] for k, v in cluster_label_dict.items()}

        if isinstance(G, nx.Graph):
            for node in G.nodes:
                cluster = G.nodes[node]["cluster"]
                G.nodes[node]["cluster_label"] = cluster_label_dict.get(
                    cluster, "Unknown"
                )
        elif isinstance(G, ig.Graph):
            G.vs["cluster_label"] = [
                cluster_label_dict.get(v["cluster"], "Unknown") for v in G.vs
            ]
        else:
            raise TypeError("Input must be a NetworkX Graph or an igraph Graph")

        return G, cluster_label_dict

## LayoutUtility


In [4]:
import time
import networkx as nx
import igraph as ig
from typing import Union, Dict, Tuple, Optional


class LayoutUtility:
    """
    Layout utility class for igraph layout operations. made for fruchterman-reingold layout.

    Args:
        g (Union[nx.Graph, ig.Graph]): The input graph (NetworkX or igraph).
        layout_params (Optional[Dict]): The layout parameters.

    Returns:
        Tuple[nx.Graph, Dict]: The graph with assigned coordinates and the layout dictionary.
    """

    @staticmethod
    def fr_layout_nx(
        g: Union[nx.Graph, ig.Graph], layout_params: Optional[Dict] = None
    ) -> Tuple[nx.Graph, Dict]:
        print("Starting Fruchterman-Reingold layout process...")
        start_time = time.time()

        if layout_params is None:
            layout_params = {
                "iterations": 100,
                "threshold": 0.00001,
                "weight": "weight",
                "scale": 1,
                "center": (0, 0),
                "dim": 2,
                "seed": 1887,
            }
        print(f"Layout parameters: {layout_params}")

        if not isinstance(g, nx.Graph):
            print("Converting to NetworkX Graph...")
            G = g.to_networkx()
            print("Conversion complete.")
        else:
            G = g

        print(f"Graph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges.")

        print("Calculating layout...")
        layout_start_time = time.time()
        pos = nx.spring_layout(G, **layout_params)
        layout_end_time = time.time()
        print(
            f"Layout calculation completed in {layout_end_time - layout_start_time:.2f} seconds."
        )

        print("Processing layout results...")
        node_xy_dict = {node: pos[node] for node in G.nodes}

        x_values, y_values = zip(*node_xy_dict.values())
        min_x, max_x = min(x_values), max(x_values)
        min_y, max_y = min(y_values), max(y_values)

        print(f"Layout boundaries:")
        print(f"X-axis: Min = {min_x:.2f}, Max = {max_x:.2f}")
        print(f"Y-axis: Min = {min_y:.2f}, Max = {max_y:.2f}")

        print("Assigning coordinates to nodes...")
        for node in G.nodes:
            G.nodes[node]["x"] = node_xy_dict[node][0]
            G.nodes[node]["y"] = node_xy_dict[node][1]

        end_time = time.time()
        total_time = end_time - start_time
        print(f"Layout process completed in {total_time:.2f} seconds.")

        return G, pos


# Usage example:
# g = ... # your graph object
# G, pos = LayoutUtility.fr_layout_nx(g)

## Z Coordinate Adder

adds a z coodrinate bases on the centrality of the node


In [5]:
class ZCoordinateAdder:
    def __init__(self, g, scale_factor=0.15):
        self.g = g
        self.scale_factor = scale_factor

    def add_z_coordinate_to_nodes(self):
        """
        Add a z-coordinate to the nodes of the graph based on their centrality values.

        Args:
            g (nx.Graph): The input graph.
            scale_factor (float): The scaling factor for the z-coordinates. (they should not be as spread out as x and y)

        Returns:
            nx.Graph: The graph with the z-coordinate added to the nodes.
        """
        # Calculate the bounds of x and y coordinates
        # Assuming self.g is a NetworkX graph
        xvalues = [attributes["x"] for _, attributes in self.g.nodes(data=True)]
        yvalues = [attributes["y"] for _, attributes in self.g.nodes(data=True)]
        min_x, max_x = min(xvalues), max(xvalues)
        min_y, max_y = min(yvalues), max(yvalues)

        print("Bounds of the layout:")
        print(f"Min x: {min_x}, Max x: {max_x}")
        print(f"Min y: {min_y}, Max y: {max_y}")

        # Extract centrality values from nodes
        centralities = np.array(
            [self.g.nodes[node]["centrality"] for node in self.g.nodes]
        )

        # Normalize centrality values to range [0, 1]
        centrality_min = centralities.min()
        centrality_max = centralities.max()
        centralities_normalized = (centralities - centrality_min) / (
            centrality_max - centrality_min
        )

        # Adjust normalized centrality values to range [-1, 1]
        centralities_adjusted = centralities_normalized * 2 - 1

        # Scale down the z-values to make them less pronounced
        z_coordinates = centralities_adjusted * self.scale_factor

        # Add z-coordinate to nodes
        for i, node in enumerate(self.g.nodes):
            self.g.nodes[node]["z"] = z_coordinates[i]

        # Describe the distribution of z values
        print("Description of the Z coordinate values:")
        print(pd.Series(z_coordinates).describe())

        print("Z coordinate added to nodes")
        return self.g

## Pruning and 2D Bundling


In [103]:
class GraphBundler2d:
    """
    A class for bundling edges in graphs using the Hammer Bundle algorithm.

    This class supports both igraph and NetworkX graph objects as input.
    """

    def __init__(
        self,
        graph: Union[ig.Graph, nx.Graph],
        pruning_weight_percentile: float = 50,
        bundle_kwargs: Optional[Dict] = None,
    ):
        """
        Initialize the GraphBundler.

        Args:
            graph (Union[ig.Graph, nx.Graph]): The input graph.
            pruning_weight_percentile (float): The percentile to use for pruning edges (default is 50).
            bundle_kwargs (Optional[Dict]): Optional parameters for the bundling algorithm.
        """
        self.graph = self._ensure_igraph(graph)
        self.pruning_weight_percentile = pruning_weight_percentile
        self.bundle_kwargs = bundle_kwargs or {
            "decay": 0.90,
            "initial_bandwidth": 0.10,
            "iterations": 15,
            "include_edge_id": True,
        }
        self.bundled_edges = None

    def _ensure_igraph(self, graph: Union[ig.Graph, nx.Graph]) -> ig.Graph:
        """
        Ensure the input graph is an igraph object.

        Args:
            graph (Union[ig.Graph, nx.Graph]): The input graph.

        Returns:
            ig.Graph: The graph as an igraph object.

        Raises:
            ValueError: If the input is neither an igraph nor a NetworkX graph object.
        """
        if isinstance(graph, ig.Graph):
            return graph
        if isinstance(graph, nx.Graph):
            return ig.Graph.from_networkx(graph)
        raise ValueError(
            "Input graph must be either an igraph or NetworkX graph object."
        )

    def prune_edges_by_percentile_weight(
        self, g: ig.Graph, percentile: float
    ) -> ig.Graph:
        """
        Remove edges from the graph that have weight less than or equal to the specified percentile weight.

        Args:
            g (ig.Graph): The input graph. Must have a 'weight' attribute for edges.
            percentile (float): The percentile to use as the threshold for pruning edges.

        Returns:
            ig.Graph: A new graph with edges removed based on the specified percentile.

        Raises:
            ValueError: If the input graph has no 'weight' attribute for edges.
        """
        # Check if 'weight' attribute exists
        if "weight" not in g.es.attributes():
            raise ValueError("Input graph must have a 'weight' attribute for edges.")

        # Get initial number of edges and isolates
        initial_edge_count = g.ecount()
        initial_isolates = len(g.vs.select(_degree=0))

        # Get all weights and calculate the specified percentile
        weights = g.es["weight"]
        weight_threshold = np.percentile(weights, percentile)

        # Identify edges to keep
        edges_to_keep = [
            edge.index for edge in g.es if edge["weight"] > weight_threshold
        ]
        threshold_edges = [
            edge.index for edge in g.es if edge["weight"] == weight_threshold
        ]

        # Randomly select from threshold edges to reach target number of edges
        target_edge_count = int(initial_edge_count * (1 - percentile / 100))
        edges_to_add = target_edge_count - len(edges_to_keep)
        if edges_to_add > 0:
            random.shuffle(threshold_edges)
            edges_to_keep.extend(threshold_edges[:edges_to_add])

        # Create a new graph with only the selected edges
        g_pruned = g.subgraph_edges(edges_to_keep, delete_vertices=False)

        # Get final number of edges and isolates
        final_edge_count = g_pruned.ecount()
        final_isolates = len(g_pruned.vs.select(_degree=0))

        # Print results
        print(f"Pruning edges by weight percentile: {percentile}%")
        print("-" * 20)
        print(f"Number of edges before: {initial_edge_count}")
        print(f"Number of edges after: {final_edge_count}")
        print(f"Number of isolates before: {initial_isolates}")
        print(f"Number of isolates after: {final_isolates}")

        return g_pruned

    def bundle_edges(self) -> Optional[pd.DataFrame]:
        """
        Perform edge bundling on the graph.

        Returns:
            Optional[pd.DataFrame]: A DataFrame containing the bundled edges,
            or None if an error occurs.
        """
        g_pruned = self.prune_edges_by_percentile_weight(
            self.graph, self.pruning_weight_percentile
        )
        self.graph = g_pruned

        print("Starting edge bundling process...")

        try:
            df_nodes = pd.DataFrame(
                {
                    "x": self.graph.vs["x"],
                    "y": self.graph.vs["y"],
                    "z": self.graph.vs["z"],
                    "cluster": self.graph.vs["cluster"],
                }
            )
            edges_df = pd.DataFrame(
                {
                    "source": [e.source for e in self.graph.es],
                    "target": [e.target for e in self.graph.es],
                    "edge_id": self.graph.es["edge_id"],
                    "weight": self.graph.es["weight"],
                }
            )
            bundled_edges = hammer_bundle(df_nodes, edges_df, **self.bundle_kwargs)
            bundled_edges = pd.DataFrame(
                bundled_edges, columns=["x", "y", "edge_id", "weight"]
            )
            self.bundled_edges = self._group_bundled_edges(bundled_edges)
            return self.bundled_edges, g_pruned
        except Exception as e:
            print(f"An error occurred during edge bundling: {e}")
            return None

    def _group_bundled_edges(self, bundled_edges: pd.DataFrame) -> pd.DataFrame:
        """
        Group the bundled edges by edge_id and include source and target positions.

        Args:
            bundled_edges (pd.DataFrame): DataFrame containing the bundled edges.

        Returns:
            pd.DataFrame: A DataFrame with grouped bundled edges including source and target positions.
        """

        def _get_node_positions(node_id):
            return {
                "x": self.graph.vs[node_id]["x"],
                "y": self.graph.vs[node_id]["y"],
            : self.graph.vs[node_id]["z"],
            }

        grouped = bundled_edges.groupby("edge_id")
        result = pd.DataFrame(
            {
                "source": [
                    self.graph.es.find(edge_id=eid).source for eid in grouped.groups
                ],
                "target": [
                    self.graph.es.find(edge_id=eid).target for eid in grouped.groups
                ],
                "x": [group["x"].values for _, group in grouped],
                "y": [group["y"].values for _, group in grouped],
                "weight": grouped["weight"].first(),
            },
            index=grouped.groups.keys(),
        )

        # Convert the DataFrame's index into a column named 'edge_id'
        result.reset_index(inplace=True)
        result.rename(columns={"index": "edge_id"}, inplace=True)

        # Add source and target positions
        result["source_position"] = result["source"].apply(_get_node_positions)
        result["target_position"] = result["target"].apply(_get_node_positions)

        return result

SyntaxError: invalid syntax (2395667157.py, line 170)

## Interpolation of Z coordinates

add z coordinates to the positions of bundled edges using cubic spline interpolation.


In [7]:
class EdgeZInterpolator:
    def __init__(self, pruned_bundled_edges_2d, graph, centrality_scale_factor=0.2):
        """
        Initialize the EdgeZInterpolator.

        Args:
            pruned_bundled_edges_2d (pd.DataFrame): The DataFrame containing bundled edges.
            graph (ig.Graph): The original graph object.
        """
        self.pruned_bundled_edges_2d = pruned_bundled_edges_2d
        self.graph = graph
        self.adjusted_edges_3d = None

    def interpolate_z_to_edges(self):
        """
        Interpolate z-coordinates for each edge in the pruned_bundled_edges_2d DataFrame.
        """
        self.adjusted_edges_3d = self.pruned_bundled_edges_2d.apply(
            self._interpolate_z, axis=1
        )
        self.adjusted_edges_3d = pd.concat(
            [self.pruned_bundled_edges_2d, self.adjusted_edges_3d], axis=1
        )
        print("Initial Z Coordinates added to edges")
        return self.adjusted_edges_3d

    def _interpolate_z(self, row):
        """
        Interpolate z-coordinates for a single edge.

        Args:
            row (pd.Series): A pandas Series representing a single row in the pruned_bundled_edges_2d DataFrame.

        Returns:
            pd.Series: A pandas Series containing the x, y, and interpolated z-coordinates.
        """
        x = np.array(row["x"])
        y = np.array(row["y"])

        source_z = row["source_position"]["z"]
        target_z = row["target_position"]["z"]

        distances = np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)
        cumulative_distances = np.cumsum(distances)

        if cumulative_distances.size == 0 or cumulative_distances[-1] == 0:
            num_points = len(x)
            z = np.linspace(source_z, target_z, num_points)
        else:
            t = np.insert(cumulative_distances, 0, 0) / cumulative_distances[-1]
            cs = CubicSpline([0, 1], [source_z, target_z])
            z = cs(t)

        return pd.Series(
            {"interpolated_x": x, "interpolated_y": y, "interpolated_z": z}
        )

## 3D Bundling


In [8]:
class Apply3DEdgeBundling:
    """Applies 3D edge bundling to a graph's edges.
    bundled_edges: DataFrame containing edge data with interpolated 3D coordinates.
    bundling_iterations: Number of iterations for the bundling algorithm.More iterations result in more bundling.
    step_size: Controls the magnitude of force application in each iteration. Smaller steps provide more stable bundling but require more iterations for the same effect.
    compatibility_threshold: Threshold for determining edge compatibility (not used in current implementation). High: More strict compatibility, fewer edges are bundled together.
    smoothing_iterations: Number of smoothing passes applied to each edge. High: More smoothing iterations create smoother curves but may lose some detail. Low: Fewer smoothing iterations preserve more original path details but may result in jagged edges.
    neighbor_radius: Radius for finding neighboring points, 'auto' for automatic inference. High: Larger radius considers more distant points, potentially leading to more global bundling. Low: Smaller radius only considers nearby points, resulting in more local bundling.
    radius_multiplier: Used to adjust the automatically inferred radius. High: Increases the automatically inferred neighbor radius, considering more distant points. Low: Decreases the automatically inferred neighbor radius, focusing on more local interactions.
    n_jobs: Number of CPU cores to use for parallel processing.
    points: Array of 3D coordinates for all edge points.
    neighbor_indices: Array of indices of neighboring points for each point.
    neighbor_counts: Array of neighbor counts for each point.
    point_to_edge: Array mapping each point to its corresponding edge index.
    """

    def __init__(
        self,
        bundled_edges,
        bundling_iterations=20,
        step_size=0.3,
        compatibility_threshold=0.3,
        smoothing_iterations=5,
        neighbor_radius="auto",
        radius_multiplier=0.2,
    ):
        self.bundled_edges = bundled_edges
        self.bundling_iterations = bundling_iterations
        self.step_size = step_size
        self.compatibility_threshold = compatibility_threshold
        self.smoothing_iterations = smoothing_iterations
        self.neighbor_radius = neighbor_radius
        self.radius_multiplier = radius_multiplier
        self.n_jobs = min(4, max(1, cpu_count() - 2))  # Limit to 4 processes
        self.adjusted_edges = None

        # Set up logging
        logging.basicConfig(
            level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
        )
        self.logger = logging.getLogger(__name__)

    def _infer_neighbor_radius(self, points):
        try:
            min_coords = np.min(points, axis=0)
            max_coords = np.max(points, axis=0)
            diagonal = np.linalg.norm(max_coords - min_coords)

            base_percentage = 0.05
            point_count_factor = np.log10(len(points)) / 10
            adjusted_percentage = base_percentage / (1 + point_count_factor)

            inferred_radius = diagonal * adjusted_percentage * self.radius_multiplier

            self.logger.info(f"Diagonal of bounding box: {diagonal}")
            self.logger.info(f"Adjusted percentage: {adjusted_percentage:.6f}")
            self.logger.info(f"Inferred radius: {inferred_radius:.6f}")

            return inferred_radius
        except Exception as e:
            self.logger.error(f"Error in _infer_neighbor_radius: {str(e)}")
            raise

    @staticmethod
    @jit(nopython=True, parallel=True)
    def _apply_forces(
        points, neighbor_indices, neighbor_counts, point_to_edge, step_size
    ):
        new_points = points.copy()
        for i in prange(1, len(points) - 1):
            edge_index = point_to_edge[i]
            force = np.zeros(3)
            count = 0
            for j in range(neighbor_counts[i]):
                n = neighbor_indices[i, j]
                if n != -1 and point_to_edge[n] != edge_index:
                    direction = points[n] - points[i]
                    distance = np.linalg.norm(direction)
                    if distance > 0:
                        force += direction / distance
                    count += 1
            if count > 0:
                force_magnitude = np.linalg.norm(force)
                if force_magnitude > 0:
                    new_points[i] += step_size * (force / force_magnitude)
        return new_points

    @staticmethod
    @jit(nopython=True)
    def _smooth_edge(edge_points, smoothing_iterations):
        for _ in range(smoothing_iterations):
            new_points = edge_points.copy()
            new_points[1:-1] = 0.5 * edge_points[1:-1] + 0.25 * (
                edge_points[:-2] + edge_points[2:]
            )
            edge_points = new_points
        return edge_points

    def apply_3d_bundling(self):
        try:
            self.adjusted_edges = self.bundled_edges.copy()

            # Convert interpolated x, y, z to points
            self.adjusted_edges["points"] = self.adjusted_edges.apply(
                lambda row: np.column_stack(
                    (
                        row["interpolated_x"],
                        row["interpolated_y"],
                        row["interpolated_z"],
                    )
                ),
                axis=1,
            )

            all_points = np.vstack(self.adjusted_edges["points"].values)

            point_to_edge = np.repeat(
                np.arange(len(self.adjusted_edges)),
                self.adjusted_edges["points"].apply(len),
            )

            if self.neighbor_radius == "auto":
                self.neighbor_radius = self._infer_neighbor_radius(all_points)
                self.logger.info(
                    f"Inferred neighbor radius: {self.neighbor_radius:.4f}"
                )

            tree = cKDTree(all_points)
            neighbors_list = tree.query_ball_point(all_points, r=self.neighbor_radius)

            max_neighbors = max(len(n) for n in neighbors_list)
            neighbor_indices = np.full(
                (len(all_points), max_neighbors), -1, dtype=np.int64
            )
            neighbor_counts = np.zeros(len(all_points), dtype=np.int64)

            for i, neighbors in enumerate(neighbors_list):
                neighbor_counts[i] = len(neighbors)
                neighbor_indices[i, : len(neighbors)] = neighbors

            self.logger.info(
                f"Starting 3D edge bundling with {self.bundling_iterations} iterations"
            )
            self.logger.info(f"Using {self.n_jobs} CPU cores for parallel processing")

            with tqdm(
                total=self.bundling_iterations, desc="3D Bundling Progress"
            ) as pbar:
                for iteration in range(self.bundling_iterations):
                    iteration_start_time = time.time()

                    all_points = self._apply_forces(
                        all_points,
                        neighbor_indices,
                        neighbor_counts,
                        point_to_edge,
                        self.step_size,
                    )

                    edge_points = np.split(
                        all_points,
                        np.cumsum(self.adjusted_edges["points"].apply(len))[:-1],
                    )
                    with Pool(self.n_jobs) as pool:
                        smoothed_edges = pool.starmap(
                            self._smooth_edge,
                            [(edge, self.smoothing_iterations) for edge in edge_points],
                        )
                    all_points = np.concatenate(smoothed_edges)

                    iteration_time = time.time() - iteration_start_time
                    self.logger.info(
                        f"Iteration {iteration + 1}/{self.bundling_iterations} completed in {iteration_time:.2f}s"
                    )
                    pbar.update(1)

            # Create new columns for bundled coordinates
            self.adjusted_edges["bundled_x"] = None
            self.adjusted_edges["bundled_y"] = None
            self.adjusted_edges["bundled_z"] = None

            # Update the adjusted_edges with new bundled coordinates
            start = 0
            for i, length in enumerate(self.adjusted_edges["points"].apply(len)):
                self.adjusted_edges.at[i, "bundled_x"] = all_points[
                    start : start + length, 0
                ].tolist()
                self.adjusted_edges.at[i, "bundled_y"] = all_points[
                    start : start + length, 1
                ].tolist()
                self.adjusted_edges.at[i, "bundled_z"] = all_points[
                    start : start + length, 2
                ].tolist()
                start += length

            self.logger.info("3D edge bundling applied successfully")
        except Exception as e:
            self.logger.error(f"Error in apply_3d_bundling: {str(e)}")
            raise

    def get_bundled_edges(self):
        return self.adjusted_edges

## EdgesSaver


In [186]:
class EdgesSaver:
    """
    A utility class for saving graph data to JSON format, particularly for use in JavaScript applications.
    """

    @staticmethod
    def add_color_bool_to_edges(
        bundled_edges_3d: pd.DataFrame, g: ig.Graph
    ) -> pd.DataFrame:
        """
        Add color boolean if source and target node are of the same cluster.
        """
        bundled_edges_3d["color"] = False
        for idx, edge in bundled_edges_3d.iterrows():
            source_cluster = g.vs[edge["source"]]["cluster"]
            target_cluster = g.vs[edge["target"]]["cluster"]
            if source_cluster == target_cluster:
                bundled_edges_3d.at[idx, "color"] = True

        print(
            f"{bundled_edges_3d['color'].sum()} out of {len(bundled_edges_3d)} edges have the same source and target cluster."
        )
        return bundled_edges_3d

    @staticmethod
    def inspect_nan_edges(bundled_edges_3d: pd.DataFrame, x_col, y_col, z_col):
        """
        Inspect and print information about edges containing NaN values.
        Args:
            bundled_edges_3d (pd.DataFrame): DataFrame containing adjusted edge data.
        """

        # Function to check if any element in a list is NaN
        def has_nan(lst):
            return any(pd.isna(x) for x in lst)

        # Filter rows where any of bundled_x, bundled_y, or bundled_z contains a NaN
        nan_edges = bundled_edges_3d[
            bundled_edges_3d[x_col].apply(has_nan)
            | bundled_edges_3d[y_col].apply(has_nan)
            | bundled_edges_3d[z_col].apply(has_nan)
            | pd.isna(bundled_edges_3d["weight"])
        ]

        if nan_edges.empty:
            print("No edges with NaN values found.")
        else:
            print(f"Found {len(nan_edges)} edges with NaN values:")
            for idx, edge in nan_edges.iterrows():
                print(f"Edge ID: {edge['edge_id']}")
                print(f"  Source: {edge['source']}, Target: {edge['target']}")
                print(f"  Weight: {edge['weight']}")
                print("  NaN positions:")
                for i, (x, y, z) in enumerate(
                    zip(edge["x"], edge["y"], edge["bundled_z"])
                ):
                    if pd.isna(x) or pd.isna(y) or pd.isna(z):
                        print(f"    Point {i}: x={x}, y={y}, z={z}")
                print()

    @staticmethod
    def prepare_edges_for_js(
        bundled_edges_3d: pd.DataFrame, x_col, y_col, z_col
    ) -> List[Dict]:
        """
        Prepare adjusted edges data for efficient use in JavaScript.
        Args:
            bundled_edges_3d (pd.DataFrame): DataFrame containing adjusted edge data.
        Returns:
            List[Dict]: List of edge objects ready for JSON serialization.
        Raises:
            ValueError: If bundled_edges_3d is None.
        """
        if bundled_edges_3d is None:
            raise ValueError("Adjusted edges data is not available.")

        # First, inspect edges with NaN values
        EdgesSaver.inspect_nan_edges(bundled_edges_3d, x_col, y_col, z_col)

        # Then proceed with the rest of the method
        edges_for_js = []
        for _, edge in bundled_edges_3d.iterrows():
            edge_object = {
                "id": int(edge["edge_id"]),
                "source": int(edge["source"]),
                "target": int(edge["target"]),
                "weight": (
                    float(edge["weight"]) if not pd.isna(edge["weight"]) else None
                ),
                "colored": bool(edge["color"]) if "color" in edge else False,
                "points": [
                    {"x": float(x), "y": float(y), "z": float(z)}
                    for x, y, z in zip(edge[x_col], edge[y_col], edge[z_col])
                    if not (pd.isna(x) or pd.isna(y) or pd.isna(z))
                ],
            }
            edges_for_js.append(edge_object)
        return edges_for_js

    @staticmethod
    def save_edges_for_js(
        bundled_edges_3d: pd.DataFrame,
        output_files: Union[str, List[str]],
        add_color_bool: bool = False,
        g: ig.Graph = None,
        return_json: bool = False,
        x_col: str = "x",
        y_col: str = "y",
        z_col: str = "z",
    ) -> Optional[List[Dict]]:
        """
        Save adjusted edges to one or more JSON files optimized for JavaScript use.
        Args:
            bundled_edges_3d (pd.DataFrame): DataFrame containing adjusted edge data.
            output_files (Union[str, List[str]]): Path or list of paths to the output JSON file(s).
            add_color_bool (bool): If True, add color boolean based on cluster information.
            g (ig.Graph): Graph object required if add_color_bool is True.
            return_json (bool): If True, return the JSON data as well as saving it.
        Returns:
            Optional[List[Dict]]: List of edge objects if return_json is True, else None.
        """
        if add_color_bool:
            if not g:
                raise ValueError("Graph object is required to add color boolean.")
            bundled_edges_3d = EdgesSaver.add_color_bool_to_edges(bundled_edges_3d, g)

        edges_data = EdgesSaver.prepare_edges_for_js(
            bundled_edges_3d, x_col, y_col, z_col
        )

        # Convert single path to list for consistent processing
        if isinstance(output_files, str):
            output_files = [output_files]

        # Save to all specified paths
        for output_file in output_files:
            with open(output_file, "w") as f:
                json.dump(edges_data, f)
            print(f"Edges data saved to {output_file}")

        return edges_data if return_json else None

## Nodes Saver


In [10]:
import json
import igraph as ig
from typing import List, Dict, Optional, Union


class NodesSaver:
    """
    A utility class for saving graph data to JSON format, particularly for use in JavaScript applications.
    """

    @staticmethod
    def save_igraph_nodes_to_json(
        g: ig.Graph,
        paths: Union[str, List[str]],
        return_json: bool = False,
        attributes: List[str] = None,
    ) -> Optional[List[Dict]]:
        """
        Save the igraph nodes to one or more JSON files.
        Args:
            g (ig.Graph): The input graph.
            paths (Union[str, List[str]]): Path or list of paths to save the JSON file(s).
            return_json (bool): If True, return the JSON data as well as saving it.
            attributes (List[str]): List of node attributes to include in the JSON.
        Returns:
            Optional[List[Dict]]: List of node dictionaries if return_json is True, else None.
        Raises:
            ValueError: If a specified attribute is missing from a node.
        """
        if attributes is None:
            attributes = [
                "node_id",
                "node_name",
                "doi",
                "year",
                "title",
                "cluster",
                "centrality",
                "x",
                "y",
                "z",
            ]

        # Fix encoding of titles
        g.vs["title"] = [NodesSaver.fix_encoding(title) for title in g.vs["title"]]

        nodes_json = []
        for node in g.vs:
            if not all(attr in node.attributes() for attr in attributes):
                raise ValueError(f"Missing attribute in node: {node.attributes()}")
            node_dict = {attr: node[attr] for attr in attributes}
            nodes_json.append(node_dict)

        # Convert single path to list for consistent processing
        if isinstance(paths, str):
            paths = [paths]

        # Save to all specified paths
        for path in paths:
            with open(path, "w") as f:
                json.dump(nodes_json, f)
            print(f"Graph nodes saved to {path}")

        return nodes_json if return_json else None

    @staticmethod
    def fix_encoding(title: str) -> str:
        """
        Fix the encoding of a string.
        Args:
            title (str): The input string to fix.
        Returns:
            str: The fixed string.
        """
        try:
            decoded_title = title.encode("utf-8").decode("unicode_escape")
            return decoded_title.encode("latin1").decode("utf-8")
        except UnicodeEncodeError:
            # If the above method fails, return the original title
            return title

## VisualizationUtility


In [11]:
class VisualizationUtility:
    @staticmethod
    def plot_graph_with_bundled_edges(g, bundled_edges, **kwargs):
        """
        Plot the graph with bundled edges.
        Args:
        g (igraph.Graph or networkx.Graph): The graph object containing node positions and cluster information.
        bundled_edges (pd.DataFrame): DataFrame containing the bundled edge coordinates.
        **kwargs: Additional keyword arguments for customizing the plot.
            figsize (tuple): Figure size in inches. Default is (10, 10).
            node_size (int): Size of the nodes in the scatter plot. Default is 10.
            edge_alpha (float): Alpha (transparency) of the edges. Default is 0.2.
            edge_width (float): Width of the edge lines. Default is 0.2.
            node_alpha (float): Alpha (transparency) of the nodes. Default is 0.7.
            edge_color (str): Color of the edges. Default is "black".
            cmap (str): Colormap for the nodes. Default is "tab20".
        Returns:
        None: Displays the plot.
        """
        # Default values
        defaults = {
            "figsize": (10, 10),
            "node_size": 10,
            "edge_alpha": 0.2,
            "edge_width": 0.2,
            "node_alpha": 0.7,
            "edge_color": "black",
            "cmap": "tab20",
        }
        # transform the graph if not igraph
        if not isinstance(g, ig.Graph):
            g = ig.Graph.from_networkx(g)
            print("Converted to igraph Graph")

        # Update defaults with any provided kwargs
        defaults.update(kwargs)

        plt.figure(figsize=defaults["figsize"])

        # Plot edges
        plt.plot(
            bundled_edges["x"],
            bundled_edges["y"],
            color=defaults["edge_color"],
            alpha=defaults["edge_alpha"],
            linewidth=defaults["edge_width"],
        )

        # Get unique clusters and map them to consecutive integers
        unique_clusters = sorted(set(g.vs["cluster"]))
        cluster_map = {c: i for i, c in enumerate(unique_clusters)}

        # Map cluster values to consecutive integers
        cluster_colors = [cluster_map[c] for c in g.vs["cluster"]]

        # Create a custom colormap
        cmap = plt.get_cmap(defaults["cmap"])
        n_colors = len(unique_clusters)
        custom_cmap = cmap(np.linspace(0, 1, n_colors))

        # Plot nodes
        scatter = plt.scatter(
            g.vs["x"],
            g.vs["y"],
            s=defaults["node_size"],
            c=cluster_colors,
            cmap=cmap,
            alpha=defaults["node_alpha"],
        )

        # Add a colorbar
        # plt.colorbar(scatter, label="Cluster", ticks=range(len(unique_clusters)))
        # plt.clim(-0.5, len(unique_clusters) - 0.5)

        plt.axis("off")
        plt.tight_layout()
        plt.show()
        # Draw the network

# FULL RUN

1. read graph
2. layout
3. prune edges
4. bundle edges
5. visualize
6. adjust for 3d plotting
   1. add z coordinate to nodes
   2. add z coordinate to bundled edges
7. use extra 3d bundling step
8. save nodes and edges for 3d plotting


## 2D Steps


In [12]:
g = GraphReader.read_and_clean_graph(INPUT_GRAPH_PATH)

cluster_list = list(range(40, 51))

# subset to only cluster 0 to 100
g = GraphReader.subgraph_of_clusters(g, cluster_list)

total_nodes = len(g.vs)
################################################################################################
layout_params = {
    # "k": 0.5, # distance between nodes; best to leave it to algo
    "iterations": 20,  # (default=50) use 100
    "threshold": 0.0001,  # default 0.0001
    "weight": "weight",
    "scale": 1,
    "center": (0, 0),
    "dim": 2,
    "seed": 1887,
}

g_fr, pos = LayoutUtility.fr_layout_nx(g, layout_params)

print("#" * 100)
print("Layout done")
print("#" * 100)

################################################################################################

g_fr_z = ZCoordinateAdder(g_fr, scale_factor=0.15).add_z_coordinate_to_nodes()
print("#" * 100)
print("Z coordinate added to nodes")
print("#" * 100)
################################################################################################
bundle_kwargs = {
    "decay": 0.90,
    "initial_bandwidth": 0.10,
    "iterations": 15,
    "include_edge_id": True,
}

bundler = GraphBundler2d(
    g_fr_z, pruning_weight_percentile=75, bundle_kwargs=bundle_kwargs
)
pruned_bundled_edges_2d, g_fr_z_bundled_pruned = bundler.bundle_edges()

print("#" * 100)
print("Edge bundling done")
print("#" * 100)

################################################################################################
# if total_nodes < 5000:
#    VisualizationUtility.plot_graph_with_bundled_edges(
#        g_fr_z_bundled_pruned, pruned_bundled_edges_2d
#    )

################################################################################################


interpolator = EdgeZInterpolator(pruned_bundled_edges_2d, g_fr_z_bundled_pruned)
adjusted_edges_3d = interpolator.interpolate_z_to_edges()

print("#" * 100)
print("Z interpolation done")
print("#" * 100)
################################################################################################

  g = ig.Graph.Read_GraphML(path)


Node Attributes: ['doi', 'year', 'title', 'cluster', 'node_id', 'node_name', 'centrality']
Edge Attributes: ['weight', 'edge_id']
Number of nodes: 40643
Number of edges: 602779
Starting Fruchterman-Reingold layout process...
Layout parameters: {'iterations': 20, 'threshold': 0.0001, 'weight': 'weight', 'scale': 1, 'center': (0, 0), 'dim': 2, 'seed': 1887}
Converting to NetworkX Graph...
Conversion complete.
Graph has 12292 nodes and 132380 edges.
Calculating layout...
Layout calculation completed in 101.01 seconds.
Processing layout results...
Layout boundaries:
X-axis: Min = -1.00, Max = 0.90
Y-axis: Min = -0.99, Max = 0.82
Assigning coordinates to nodes...
Layout process completed in 101.49 seconds.
####################################################################################################
Layout done
####################################################################################################
Bounds of the layout:
Min x: -1.0, Max x: 0.8989654779434204
Min y: -0.9858

In [13]:
adjusted_edges_3d.head(2)

Unnamed: 0,edge_id,source,target,x,y,weight,source_position,target_position,interpolated_x,interpolated_y,interpolated_z
0,50.0,0,1,"[0.3307619094848633, 0.34947332739830017]","[0.19845418632030487, 0.30729520320892334]",0.663864,"{'x': 0.3307619094848633, 'y': 0.1984541863203...","{'x': 0.34947332739830017, 'y': 0.307295203208...","[0.3307619094848633, 0.34947332739830017]","[0.19845418632030487, 0.30729520320892334]","[-0.13823016870468868, -0.1340834649785878]"
1,102.0,2,35,"[-0.4124625027179718, -0.39503922612022846, -0...","[-0.4567071497440338, -0.40076967835872734, -0...",0.80623,"{'x': -0.4124625027179718, 'y': -0.45670714974...","{'x': -0.36125415563583374, 'y': -0.4384048879...","[-0.4124625027179718, -0.39503922612022846, -0...","[-0.4567071497440338, -0.40076967835872734, -0...","[-0.1310213731838211, -0.13087355803861248, -0..."


## 3D edge bundling


In [14]:
# Usage
params = {
    "bundling_iterations": 10,
    "step_size": 0.3,
    "compatibility_threshold": 0.6,
    "smoothing_iterations": 5,
    "neighbor_radius": "auto",
    "radius_multiplier": 0.2,
}

try:
    bundler_3d = Apply3DEdgeBundling(adjusted_edges_3d, **params)
    bundler_3d.apply_3d_bundling()
    bundled_edges_3d = bundler_3d.get_bundled_edges()
except Exception as e:
    logging.error(f"An error occurred during 3D bundling: {str(e)}")

2024-07-24 14:49:36,102 - INFO - Diagonal of bounding box: 2.6338850009501926
2024-07-24 14:49:36,102 - INFO - Adjusted percentage: 0.031992
2024-07-24 14:49:36,102 - INFO - Inferred radius: 0.016853
2024-07-24 14:49:36,103 - INFO - Inferred neighbor radius: 0.0169
2024-07-24 14:50:20,344 - INFO - Starting 3D edge bundling with 10 iterations
2024-07-24 14:50:20,349 - INFO - Using 4 CPU cores for parallel processing
3D Bundling Progress:   0%|          | 0/10 [00:00<?, ?it/s]2024-07-24 14:51:08,626 - INFO - Iteration 1/10 completed in 48.19s
3D Bundling Progress:  10%|█         | 1/10 [00:48<07:13, 48.19s/it]2024-07-24 14:51:40,528 - INFO - Iteration 2/10 completed in 31.90s
3D Bundling Progress:  20%|██        | 2/10 [01:20<05:08, 38.61s/it]2024-07-24 14:52:17,253 - INFO - Iteration 3/10 completed in 36.72s
3D Bundling Progress:  30%|███       | 3/10 [01:56<04:24, 37.75s/it]2024-07-24 14:52:52,526 - INFO - Iteration 4/10 completed in 35.27s
3D Bundling Progress:  40%|████      | 4/10 [

In [178]:
# Save to multiple paths
nodes_json = NodesSaver.save_igraph_nodes_to_json(
    g_fr_z_bundled_pruned,
    [
        OUTPUT_DIR + f"nodes_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
        THREEJS_OUTPUT_DIR
        + f"nodes_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
    ],
    return_json=True,
)

# print first 2 nodes
nodes_json[:2]

Graph nodes saved to ../data/99-testdata/nodes_3d_clusters45to50.json
Graph nodes saved to /Users/jlq293/Projects/Random Projects/LW-ThreeJS/2d_ssrinetworkviz/src/data/nodes_3d_clusters45to50.json


[{'node_id': 6,
  'node_name': 'Horita_1982',
  'doi': '',
  'year': 1982,
  'title': 'Centrally administered thyrotropin-releasing hormone (TRH) stimulates colonic transit and diarrhea production by a vagally mediated serotonergic mechanism in the rabbit',
  'cluster': 48,
  'centrality': 0.136328538692653,
  'x': 0.6405702233314514,
  'y': -0.3191435635089874,
  'z': -0.10927152627688637},
 {'node_id': 8,
  'node_name': 'Mcelroy_1982',
  'doi': '10.1007/BF00432770',
  'year': 1982,
  'title': 'The effects of fenfluramine and fluoxetine on the acquisition of a conditioned avoidance response in rats',
  'cluster': 50,
  'centrality': 0.0151485334617365,
  'x': 0.5508721470832825,
  'y': -0.4072713851928711,
  'z': -0.14564939253508402}]

In [16]:
edges_json = EdgesSaver.save_edges_for_js(
    bundled_edges_3d,
    [
        OUTPUT_DIR
        + f"bundled_edges_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
        THREEJS_OUTPUT_DIR
        + f"bundled_edges_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
    ],
    add_color_bool=True,
    g=g_fr_z_bundled_pruned,
    return_json=True,
)

edges_json[99]

31719 out of 33095 edges have the same source and target cluster.
No edges with NaN values found.
Edges data saved to ../data/99-testdata/bundled_edges_3d_clusters20to50.json
Edges data saved to /Users/jlq293/Projects/Random Projects/LW-ThreeJS/2d_ssrinetworkviz/src/data/bundled_edges_3d_clusters20to50.json


{'id': 9282,
 'source': 55,
 'target': 628,
 'weight': 0.809315085411072,
 'colored': True,
 'points': [{'x': 0.1797104785509199,
   'y': 0.08640227269980194,
   'z': -0.11921884171978157},
  {'x': 0.16762777694242395,
   'y': 0.07244703881857949,
   'z': -0.13564365842741866},
  {'x': 0.15903036949242866,
   'y': 0.06342433086241353,
   'z': -0.15484917390021652},
  {'x': 0.15385792925439534,
   'y': 0.061205861293863566,
   'z': -0.1751324490140186},
  {'x': 0.15009907378869408,
   'y': 0.0664796134829248,
   'z': -0.19236211035300899},
  {'x': 0.14717861149056716,
   'y': 0.08213241384961667,
   'z': -0.20430724700450514},
  {'x': 0.14578842711877277,
   'y': 0.11159344556639125,
   'z': -0.21088572610528544},
  {'x': 0.14567282979197266,
   'y': 0.1536696058033668,
   'z': -0.21200172298069808},
  {'x': 0.14628638132851618,
   'y': 0.20140796591181168,
   'z': -0.20667695328035182},
  {'x': 0.1497297300315264,
   'y': 0.24635118507521314,
   'z': -0.19244827127885808},
  {'x': 0.15

# FAILED QUALITY CHECK


In [90]:
correct_int = 0
false_int = 0
correct_bund = 0
false_bund = 0
correct_both = 0
false_both = 0
for i, row in bundled_edges_3d.iterrows():
    s_x = round(row["source_position"]["x"], 7)
    x_int = round(row["interpolated_x"][0], 7)
    x_bund = round(row["bundled_x"][0], 7)
    s_y = round(row["source_position"]["y"], 7)
    y_int = round(row["interpolated_y"][0], 7)
    y_bund = round(row["bundled_y"][0], 7)
    s_z = round(row["source_position"]["z"], 7)
    z_int = round(row["interpolated_z"][0], 7)
    z_bund = round(row["bundled_z"][0], 7)

    if s_x == x_int and s_y == y_int and s_z == z_int:
        correct_int += 1
    else:
        false_int += 1

    if s_x == x_bund and s_y == y_bund and s_z == z_bund:
        correct_bund += 1
    else:
        false_bund += 1

    if (
        s_x == x_int
        and s_y == y_int
        and s_z == z_int
        and s_x == x_bund
        and s_y == y_bund
        and s_z == z_bund
    ):
        correct_both += 1
    else:
        false_both += 1


print("correct_int: ", correct_int)
print("false_int: ", false_int)

print("correct_bund: ", correct_bund)
print("false_bund: ", false_bund)

print("correct_both: ", correct_both)
print("false_both: ", false_both)

correct_int:  33095
false_int:  0
correct_bund:  5810
false_bund:  27285
correct_both:  5810
false_both:  27285


In [282]:
g = GraphReader.read_and_clean_graph(INPUT_GRAPH_PATH)

cluster_list = list(range(45, 51))

# subset to only cluster 0 to 100
g = GraphReader.subgraph_of_clusters(g, cluster_list)

total_nodes = len(g.vs)
################################################################################################
layout_params = {
    # "k": 0.5, # distance between nodes; best to leave it to algo
    "iterations": 20,  # (default=50) use 100
    "threshold": 0.0001,  # default 0.0001
    "weight": "weight",
    "scale": 1,
    "center": (0, 0),
    "dim": 2,
    "seed": 1887,
}

g_fr, pos = LayoutUtility.fr_layout_nx(g, layout_params)

print("#" * 100)
print("Layout done")
print("#" * 100)

################################################################################################

g_fr_z = ZCoordinateAdder(g_fr, scale_factor=0.15).add_z_coordinate_to_nodes()
print("#" * 100)
print("Z coordinate added to nodes")
print("#" * 100)
################################################################################################

bundle_kwargs = {
    "decay": 0.95,
    "initial_bandwidth": 0.05,
    "iterations": 50,
    "include_edge_id": True,
}
bundler = GraphBundler2d(
    g_fr_z, pruning_weight_percentile=75, bundle_kwargs=bundle_kwargs
)
pruned_bundled_edges_2d, g_fr_z_bundled_pruned = bundler.bundle_edges()

print("#" * 100)
print("Edge bundling done")
print("#" * 100)

################################################################################################
# if total_nodes < 5000:
#    VisualizationUtility.plot_graph_with_bundled_edges(
#        g_fr_z_bundled_pruned, pruned_bundled_edges_2d
#    )

################################################################################################
#
#
# interpolator = EdgeZInterpolator(pruned_bundled_edges_2d, g_fr_z_bundled_pruned)
# adjusted_edges_3d = interpolator.interpolate_z_to_edges()
#
# print("#" * 100)
# print("Z interpolation done")
# print("#" * 100)
#################################################################################################

  g = ig.Graph.Read_GraphML(path)


Node Attributes: ['doi', 'year', 'title', 'cluster', 'node_id', 'node_name', 'centrality']
Edge Attributes: ['weight', 'edge_id']
Number of nodes: 40643
Number of edges: 602779
Starting Fruchterman-Reingold layout process...
Layout parameters: {'iterations': 20, 'threshold': 0.0001, 'weight': 'weight', 'scale': 1, 'center': (0, 0), 'dim': 2, 'seed': 1887}
Converting to NetworkX Graph...
Conversion complete.
Graph has 1976 nodes and 17666 edges.
Calculating layout...
Layout calculation completed in 4.33 seconds.
Processing layout results...
Layout boundaries:
X-axis: Min = -0.82, Max = 1.00
Y-axis: Min = -0.88, Max = 0.80
Assigning coordinates to nodes...
Layout process completed in 4.39 seconds.
####################################################################################################
Layout done
####################################################################################################
Bounds of the layout:
Min x: -0.819740355014801, Max x: 1.0
Min y: -0.88231164216

In [283]:
pruned_bundled_edges_2d.head(1)  # , g_fr_z_bundled_pruned

Unnamed: 0,edge_id,source,target,x,y,weight,source_position,target_position
0,149.0,1,280,"[0.5508721470832825, 0.55393686033313, 0.55205...","[-0.4072713851928711, -0.3896731439439783, -0....",0.936576,"{'x': 0.5508721470832825, 'y': -0.407271385192...","{'x': 0.445087194442749, 'y': -0.2985744476318..."


In [284]:
pruned_bundled_edges_2d.source_position.apply(lambda x: x["z"])

0      -0.145649
1      -0.144170
2      -0.120307
3      -0.120307
4      -0.120307
          ...   
4412   -0.130054
4413   -0.034908
4414   -0.034908
4415   -0.000343
4416   -0.090827
Name: source_position, Length: 4417, dtype: float64

In [285]:
import numpy as np
import pandas as pd


class BundleQualityChecker:
    @staticmethod
    def connection_quality_check(
        df,
        source_col="source_position",
        target_col="target_position",
        x_col="x",
        y_col="y",
        z_col="z",
        tolerance=1e-6,
    ):
        """
        Perform quality checks on the bundled edges, accounting for positive and negative coordinates.

        Parameters:
        df (pd.DataFrame): The dataframe containing edge data
        source_col (str): Name of the column containing source node positions
        target_col (str): Name of the column containing target node positions
        x_col (str): Name of the column containing bundled x coordinates
        y_col (str): Name of the column containing bundled y coordinates
        z_col (str): Name of the column containing bundled z coordinates
        tolerance (float): Tolerance for floating point comparisons

        Returns:
        dict: A dictionary containing the results of various checks
        """
        results = {
            "total_edges": len(df),
            "start_point_mismatches": 0,
            "end_point_mismatches": 0,
            "invalid_z_interpolations": 0,
            "high_z_connection_count": df["is_high_z_connection"].sum(),
        }

        for _, row in df.iterrows():
            # Check start point
            if (
                abs(row[x_col][0] - row[source_col]["x"]) > tolerance
                or abs(row[y_col][0] - row[source_col]["y"]) > tolerance
                or abs(row[z_col][0] - row[source_col]["z"]) > tolerance
            ):
                results["start_point_mismatches"] += 1

            # Check end point
            if (
                abs(row[x_col][-1] - row[target_col]["x"]) > tolerance
                or abs(row[y_col][-1] - row[target_col]["y"]) > tolerance
                or abs(row[z_col][-1] - row[target_col]["z"]) > tolerance
            ):
                results["end_point_mismatches"] += 1

            # Check z interpolation
            z_min, z_max = min(row[source_col]["z"], row[target_col]["z"]), max(
                row[source_col]["z"], row[target_col]["z"]
            )
            if any(z < z_min - tolerance or z > z_max + tolerance for z in row[z_col]):
                results["invalid_z_interpolations"] += 1

        # Calculate percentages
        total = results["total_edges"]
        results["start_point_mismatch_percentage"] = (
            results["start_point_mismatches"] / total
        ) * 100
        results["end_point_mismatch_percentage"] = (
            results["end_point_mismatches"] / total
        ) * 100
        results["invalid_z_interpolation_percentage"] = (
            results["invalid_z_interpolations"] / total
        ) * 100
        results["high_z_connection_percentage"] = (
            results["high_z_connection_count"] / total
        ) * 100

        return results

    @staticmethod
    def perform_quality_check(bundled_edges_3d):
        if bundled_edges_3d is None:
            print("No bundled edges available.")
            return None

        results = BundleQualityChecker.connection_quality_check(bundled_edges_3d)

        print("Quality Check Results:")
        print(f"Total edges: {results['total_edges']}")
        print(
            f"Start point mismatches: {results['start_point_mismatches']} ({results['start_point_mismatch_percentage']:.2f}%)"
        )
        print(
            f"End point mismatches: {results['end_point_mismatches']} ({results['end_point_mismatch_percentage']:.2f}%)"
        )
        print(
            f"Invalid z interpolations: {results['invalid_z_interpolations']} ({results['invalid_z_interpolation_percentage']:.2f}%)"
        )
        print(
            f"High z connections: {results['high_z_connection_count']} ({results['high_z_connection_percentage']:.2f}%)"
        )

        return results

    @staticmethod
    def analyze_edge_points(bundled_edges_3d):
        if bundled_edges_3d is None:
            print("No bundled edges available. Run adjust_bundling_for_3d first.")
            return None

        point_counts = bundled_edges_3d["x"].apply(len)

        analysis = {
            "min_points": point_counts.min(),
            "max_points": point_counts.max(),
            "mean_points": point_counts.mean(),
            "median_points": point_counts.median(),
        }

        print("Edge Point Analysis:")
        print(f"Minimum points per edge: {analysis['min_points']}")
        print(f"Maximum points per edge: {analysis['max_points']}")
        print(f"Mean points per edge: {analysis['mean_points']:.2f}")
        print(f"Median points per edge: {analysis['median_points']}")

        return analysis

In [286]:
class EdgeBundler3d:
    def __init__(
        self,
        bundled_edges_2d,
        z_threshold_percentile=75,
        vertical_influence=0.8,
    ):
        self.bundled_edges_2d = bundled_edges_2d
        self.z_threshold_percentile = z_threshold_percentile
        self.vertical_influence = vertical_influence
        self.bundled_edges_3d = None

    def define_z_threshold(self, z_values):
        z_threshold = np.percentile(z_values, self.z_threshold_percentile)
        print(f"Z threshold set to {z_threshold:.4f}")
        return z_threshold

    def adjust_bundling_for_3d(self):
        z_values = pd.concat(
            [
                self.bundled_edges_2d.source_position.apply(lambda x: x["z"]),
                self.bundled_edges_2d.target_position.apply(lambda x: x["z"]),
            ]
        )
        z_threshold = self.define_z_threshold(z_values)

        adjusted_edges = []
        for _, row in tqdm(
            self.bundled_edges_2d.iterrows(), total=len(self.bundled_edges_2d)
        ):
            edge_id = row["edge_id"]
            source = row["source_position"]
            target = row["target_position"]
            source_z = source["z"]
            target_z = target["z"]
            x = np.array(row["x"])
            y = np.array(row["y"])

            num_points = len(x)
            is_high_z_connection = (source_z > z_threshold and source_z > target_z) or (
                target_z > z_threshold and target_z > source_z
            )

            # Determine high and low nodes
            if source_z > target_z:
                high_node, low_node = source, target
            else:
                high_node, low_node = target, source

            # Create control points for cubic spline
            if is_high_z_connection:
                # For high-z connections, create a smooth vertical path
                mid_x = (source["x"] + target["x"]) / 2
                mid_y = (source["y"] + target["y"]) / 2
                control_points_t = [0, 0.25, 0.75, 1]
                control_points_x = [source["x"], mid_x, mid_x, target["x"]]
                control_points_y = [source["y"], mid_y, mid_y, target["y"]]
            else:
                # For non-high-z connections, use more of the original path
                control_points_t = [0, 1 / 3, 2 / 3, 1]
                control_points_x = [
                    source["x"],
                    x[num_points // 3],
                    x[2 * num_points // 3],
                    target["x"],
                ]
                control_points_y = [
                    source["y"],
                    y[num_points // 3],
                    y[2 * num_points // 3],
                    target["y"],
                ]

            # Apply cubic spline interpolation with natural boundary conditions
            cs_x = CubicSpline(control_points_t, control_points_x, bc_type="natural")
            cs_y = CubicSpline(control_points_t, control_points_y, bc_type="natural")

            # Use the original number of points, but ensure it's at least 20 for smoother curves
            t = np.linspace(0, 1, num_points)
            x_adj = cs_x(t)
            y_adj = cs_y(t)

            # Create a smooth z-coordinate transition

            z_min, z_max = min(source_z, target_z), max(source_z, target_z)

            if is_high_z_connection:
                # For high-z connections, create a more pronounced vertical effect
                z_mid = (z_min + z_max) / 2
                z_control = [z_min, z_max, z_max, z_max]
                cs_z = CubicSpline(control_points_t, z_control, bc_type="clamped")
                z = np.clip(cs_z(t), z_min, z_max)
            else:
                # For non-high-z connections, use linear interpolation
                z = np.interp(t, [0, 1], [z_min, z_max])

            adjusted_edges.append(
                {
                    "edge_id": edge_id,
                    "bundled_x": x_adj.tolist(),
                    "bundled_y": y_adj.tolist(),
                    "bundled_z": z.tolist(),
                    "is_high_z_connection": is_high_z_connection,
                }
            )

        # Create a new dataframe with the adjusted edges
        adjusted_df = pd.DataFrame(adjusted_edges)

        # Merge the new dataframe with the original one
        merged_df = pd.merge(
            self.bundled_edges_2d, adjusted_df, on="edge_id", how="left"
        )

        # Rename columns to avoid confusion
        merged_df = merged_df.rename(
            columns={
                "x": "original_x",
                "y": "original_y",
                "bundled_x": "x",
                "bundled_y": "y",
                "bundled_z": "z",
            }
        )

        self.bundled_edges_3d = merged_df
        return self.bundled_edges_3d

    # Add this method to your GraphBundler3d class

In [287]:
bundler_3d = EdgeBundler3d(pruned_bundled_edges_2d)
adjusted_edges_3d = bundler_3d.adjust_bundling_for_3d()

Z threshold set to -0.0193


100%|██████████| 4417/4417 [00:01<00:00, 2466.46it/s]


In [288]:
adjusted_edges_3d

Unnamed: 0,edge_id,source,target,original_x,original_y,weight,source_position,target_position,x,y,z,is_high_z_connection
0,149.0,1,280,"[0.5508721470832825, 0.55393686033313, 0.55205...","[-0.4072713851928711, -0.3896731439439783, -0....",0.936576,"{'x': 0.5508721470832825, 'y': -0.407271385192...","{'x': 0.445087194442749, 'y': -0.2985744476318...","[0.5508721470832825, 0.5482098857147208, 0.544...","[-0.4072713851928711, -0.38927841208412234, -0...","[-0.1460439847414049, -0.14600811272264846, -0...",False
1,271.0,2,7,"[0.49992355704307556, 0.5130205154418945]","[-0.5801289677619934, -0.6044682264328003]",0.816939,"{'x': 0.49992355704307556, 'y': -0.58012896776...","{'x': 0.5130205154418945, 'y': -0.604468226432...","[0.49992355704307556, 0.5130205154418945]","[-0.5801289677619934, -0.6044682264328003]","[-0.1448336408037463, -0.1441700028383177]",False
2,971.0,3,16,"[0.5524348020553589, 0.5497690907852348, 0.544...","[-0.34138861298561096, -0.339813561933337, -0....",0.795334,"{'x': 0.5524348020553589, 'y': -0.341388612985...","{'x': 0.45631155371665955, 'y': -0.40058809518...","[0.5524348020553589, 0.5496931334538281, 0.545...","[-0.34138861298561096, -0.3414993881606116, -0...","[-0.13980900556810782, -0.13785882390783016, -...",False
3,974.0,3,44,"[0.5524348020553589, 0.5585797046720011, 0.566...","[-0.34138861298561096, -0.3394194623550637, -0...",0.806165,"{'x': 0.5524348020553589, 'y': -0.341388612985...","{'x': 0.662406861782074, 'y': -0.4320113360881...","[0.5524348020553589, 0.5597469018915837, 0.567...","[-0.34138861298561096, -0.3394893537229342, -0...","[-0.1436823801075662, -0.14108513664731787, -0...",False
4,981.0,3,549,"[0.5524348020553589, 0.5483290342278686, 0.541...","[-0.34138861298561096, -0.3477689670685117, -0...",0.792766,"{'x': 0.5524348020553589, 'y': -0.341388612985...","{'x': 0.49700018763542175, 'y': -0.46284896135...","[0.5524348020553589, 0.5482820949181224, 0.541...","[-0.34138861298561096, -0.34846635889273353, -...","[-0.12947644557359098, -0.12794823613888104, -...",False
...,...,...,...,...,...,...,...,...,...,...,...,...
4412,601809.0,1931,1970,"[-0.34231740236282343, -0.354872689493168, -0....","[-0.21710354089736938, -0.21300441973668005, -...",0.800884,"{'x': -0.3423174023628235, 'y': -0.21710354089...","{'x': -0.43638426065444946, 'y': -0.2356306910...","[-0.3423174023628235, -0.35194551355679315, -0...","[-0.21710354089736938, -0.2147058240132259, -0...","[-0.13220869745589156, -0.13190087264187106, -...",False
4413,602031.0,1936,1954,"[0.5319066643714905, 0.5140461921691895]","[0.2123633772134781, 0.2381635308265686]",0.663164,"{'x': 0.5319066643714905, 'y': 0.2123633772134...","{'x': 0.5140461921691895, 'y': 0.2381635308265...","[0.5319066643714905, 0.5140461921691895]","[0.2123633772134781, 0.23816353082656858]","[-0.10930934920714531, -0.034908049376908085]",False
4414,602032.0,1936,1975,"[0.5319066643714905, 0.5406255788234311, 0.548...","[0.2123633772134781, 0.2014645253000802, 0.191...",0.666620,"{'x': 0.5319066643714905, 'y': 0.2123633772134...","{'x': 0.5727753043174744, 'y': 0.3141528069972...","[0.5319066643714905, 0.5406255788234311, 0.548...","[0.2123633772134781, 0.2014645253000802, 0.191...","[-0.10072348897801676, -0.07878500911098053, -...",False
4415,602314.0,1949,1959,"[0.48814910650253296, 0.46257609128952026]","[-0.20887938141822815, -0.11553271114826202]",0.672519,"{'x': 0.48814910650253296, 'y': -0.20887938141...","{'x': 0.46257609128952026, 'y': -0.11553271114...","[0.48814910650253296, 0.46257609128952026]","[-0.20887938141822815, -0.11553271114826202]","[-0.08850872864036753, -0.0003430996737725345]",True


In [289]:
checker = BundleQualityChecker()
checker.perform_quality_check(adjusted_edges_3d)

checker.analyze_edge_points(adjusted_edges_3d)

Quality Check Results:
Total edges: 4417
Start point mismatches: 2500 (56.60%)
End point mismatches: 2500 (56.60%)
Invalid z interpolations: 0 (0.00%)
High z connections: 1744 (39.48%)
Edge Point Analysis:
Minimum points per edge: 2
Maximum points per edge: 75
Mean points per edge: 8.30
Median points per edge: 7.0


{'min_points': 2,
 'max_points': 75,
 'mean_points': 8.300656554222323,
 'median_points': 7.0}

In [290]:
adjusted_edges_3d.columns

Index(['edge_id', 'source', 'target', 'original_x', 'original_y', 'weight',
       'source_position', 'target_position', 'x', 'y', 'z',
       'is_high_z_connection'],
      dtype='object')

In [294]:
adjusted_edges_3d

test_edge_df = adjusted_edges_3d.copy()
test_edge_df["x"] = test_edge_df["original_x"]
test_edge_df["y"] = test_edge_df["original_y"]
test_edge_df["z"] = test_edge_df["original_y"].apply(lambda x: [1] * len(x))

In [304]:
adjusted_edges_3d.head(2)

Unnamed: 0,edge_id,source,target,original_x,original_y,weight,source_position,target_position,x,y,z,is_high_z_connection
0,149.0,1,280,"[0.5508721470832825, 0.55393686033313, 0.55205...","[-0.4072713851928711, -0.3896731439439783, -0....",0.936576,"{'x': 0.5508721470832825, 'y': -0.407271385192...","{'x': 0.445087194442749, 'y': -0.2985744476318...","[0.5508721470832825, 0.5482098857147208, 0.544...","[-0.4072713851928711, -0.38927841208412234, -0...","[-0.1460439847414049, -0.14600811272264846, -0...",False
1,271.0,2,7,"[0.49992355704307556, 0.5130205154418945]","[-0.5801289677619934, -0.6044682264328003]",0.816939,"{'x': 0.49992355704307556, 'y': -0.58012896776...","{'x': 0.5130205154418945, 'y': -0.604468226432...","[0.49992355704307556, 0.5130205154418945]","[-0.5801289677619934, -0.6044682264328003]","[-0.1448336408037463, -0.1441700028383177]",False


In [312]:
import numpy as np
from itertools import combinations
from tqdm import tqdm
from scipy.spatial import cKDTree


def find_common_coordinates(adjusted_edges_3d, tolerance=1e-7):
    # Extract x and y coordinates, excluding first and last points
    xs = [x[1:-1] for x in adjusted_edges_3d["original_x"].values]
    ys = [y[1:-1] for y in adjusted_edges_3d["original_y"].values]

    # Combine all points into a single array and keep track of edge indices
    all_points = []
    edge_indices = []
    for i, (x, y) in enumerate(zip(xs, ys)):
        points = np.column_stack((x, y))
        all_points.append(points)
        edge_indices.extend([i] * len(points))

    all_points = np.vstack(all_points)
    edge_indices = np.array(edge_indices)

    # Build a KD-tree for efficient nearest neighbor search
    tree = cKDTree(all_points)

    # Find pairs of points within the tolerance
    pairs = tree.query_pairs(r=tolerance)

    # Filter pairs to only include points from different edges
    common_coords = []
    for p1, p2 in pairs:
        if edge_indices[p1] != edge_indices[p2]:
            i, j = edge_indices[p1], edge_indices[p2]
            idx1 = p1 - np.sum(edge_indices < i)
            idx2 = p2 - np.sum(edge_indices < j)
            common_coords.append((i, j, idx1, idx2))

    return common_coords


# Usage
common_coordinates = find_common_coordinates(adjusted_edges_3d)
# for i, j, idx1, idx2 in common_coordinates:
#    print(
#        f"Edge {i} and Edge {j} have a common coordinate at indices {idx1} and {idx2}"
#    )
len(common_coordinates)

1133

In [None]:
common_coordinates

1419

In [299]:
# check for common points amongst edges

xs = adjusted_edges_3d["original_x"]

In [300]:
xs

0       [0.5508721470832825, 0.55393686033313, 0.55205...
1               [0.49992355704307556, 0.5130205154418945]
2       [0.5524348020553589, 0.5497690907852348, 0.544...
3       [0.5524348020553589, 0.5585797046720011, 0.566...
4       [0.5524348020553589, 0.5483290342278686, 0.541...
                              ...                        
4412    [-0.34231740236282343, -0.354872689493168, -0....
4413             [0.5319066643714905, 0.5140461921691895]
4414    [0.5319066643714905, 0.5406255788234311, 0.548...
4415           [0.48814910650253296, 0.46257609128952026]
4416    [-0.46193641424179077, -0.447839639490751, -0....
Name: original_x, Length: 4417, dtype: object

In [296]:
# save to multiple paths
nodes_json = NodesSaver.save_igraph_nodes_to_json(
    g_fr_z_bundled_pruned,
    [
        OUTPUT_DIR + f"nodes_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
        THREEJS_OUTPUT_DIR
        + f"nodes_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
    ],
    return_json=True,
)


edges_json = EdgesSaver.save_edges_for_js(
    test_edge_df,
    [
        OUTPUT_DIR
        + f"bundled_edges_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
        THREEJS_OUTPUT_DIR
        + f"bundled_edges_3d_clusters{min(cluster_list)}to{max(cluster_list)}.json",
    ],
    add_color_bool=True,
    g=g_fr_z_bundled_pruned,
    return_json=True,
)

Graph nodes saved to ../data/99-testdata/nodes_3d_clusters45to50.json
Graph nodes saved to /Users/jlq293/Projects/Random Projects/LW-ThreeJS/2d_ssrinetworkviz/src/data/nodes_3d_clusters45to50.json
4409 out of 4417 edges have the same source and target cluster.
No edges with NaN values found.
Edges data saved to ../data/99-testdata/bundled_edges_3d_clusters45to50.json
Edges data saved to /Users/jlq293/Projects/Random Projects/LW-ThreeJS/2d_ssrinetworkviz/src/data/bundled_edges_3d_clusters45to50.json


# 3D Visualization


In [257]:
class GraphProcessor3d:
    def __init__(self, G: Union[nx.Graph, ig.Graph]):
        self.g = self._ensure_igraph(G)

    @staticmethod
    def _ensure_igraph(G: Union[nx.Graph, ig.Graph]) -> ig.Graph:
        if isinstance(G, nx.Graph):
            return ig.Graph.from_networkx(G)
        return G

    def apply_layout(self, layout_name: str = "fruchterman_reingold_3d", **kwargs):
        """
        Apply a layout to the graph. For Fruchterman-Reingold 3D, available kwargs are:
        - dim: The dimension of the layout (default: 3)
        - weights: Edge weights to be used. Can be a list or the name of an edge attribute.
        - niter: The number of iterations to perform (default: 500)
        - start_temp: The starting temperature (default: 10)
        - seed: Random seed to use (default: None)
        """
        if layout_name == "fruchterman_reingold_3d":
            # Set default values
            layout_kwargs = {
                "dim": 3,
                "weights": "weight",
                "niter": 500,
                "start_temp": 10,
                "seed": None,
            }
            # Update with provided kwargs
            layout_kwargs.update(kwargs)

            layout = self.g.layout_fruchterman_reingold_3d(**layout_kwargs)
        else:
            layout = self.g.layout(layout_name, **kwargs)

        for i, coords in enumerate(layout):
            self.g.vs[i]["x"] = coords[0]
            self.g.vs[i]["y"] = coords[1]
            self.g.vs[i]["z"] = coords[2]

        print(f"Layout {layout_name} applied to the graph.")
        print("Coordinates stored in node attributes.")
        return self.g


# example usage
gp3d = GraphProcessor3d(g)
g = gp3d.apply_layout("fruchterman_reingold_3d")

Layout fruchterman_reingold_3d applied to the graph.
Coordinates stored in node attributes.


# color and label assignment for full graph


In [18]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import to_hex
from matplotlib.colors import to_rgb

In [19]:
class ClusterColorAssigner:
    """
    TO DO:
    1. clusters not mutually exclusive - need to assign to multiple categories
    2. colors too similar - need to assign more distinct colors

    A class for assigning colors to clusters based on their characteristics.

    This class provides methods to categorize clusters into color palettes,
    assign specific colors within those palettes, and create a mapping
    between clusters and their assigned colors.

    Attributes:
        colormaps (dict): A dictionary mapping color names to matplotlib colormaps.
        condition_names (dict): A dictionary mapping color names to condition names.

    Methods:
        assign_color_categories(clust_hierarchy): Assigns color categories to clusters.
        print_color_mapping(): Prints the mapping of conditions to color palettes.
        assign_colors(df, colormap): Assigns specific colors to clusters within a palette.
        create_color_dataframes(clust_hierarchy): Creates separate DataFrames for each color category.
        process_cluster_hierarchy(clust_hierarchy): Processes the entire cluster hierarchy.

    Usage:
        color_assigner = ClusterColorAssigner()
        processed_hierarchy, color_dict = color_assigner.process_cluster_hierarchy(cluster_hierarchy_df)
    """

    def __init__(self):
        self.colormaps = {
            "blue": plt.get_cmap("Blues"),
            "red": plt.get_cmap("Reds"),
            "green": plt.get_cmap("Greens"),
            "purple": plt.get_cmap("Purples"),
        }
        self.condition_names = {
            "blue": "pharmacology",
            "red": "indications",
            "green": "safety",
            "purple": "other",
        }

    def assign_color_categories(self, clust_hierarchy):
        conditions = [
            clust_hierarchy["pharmacology"] == 1,
            clust_hierarchy["indications"] == 1,
            clust_hierarchy["safety"] == 1,
            clust_hierarchy["other"] == 1,
        ]
        choices = ["blue", "red", "green", "purple"]
        clust_hierarchy["color_pal"] = np.select(conditions, choices, default="")
        return clust_hierarchy

    def print_color_mapping(self):
        print("Mapping of conditions to color palettes:")
        for color, condition in self.condition_names.items():
            print(f"{condition.capitalize()}: {color}")

    @staticmethod
    def assign_colors(df, colormap):
        num_colors = df.shape[0]
        colors = [to_hex(colormap(x)) for x in np.linspace(0.1, 0.9, num_colors)]
        df["color"] = colors
        return df

    def create_color_dataframes(self, clust_hierarchy):
        color_dfs = {}
        for color_name, colormap in self.colormaps.items():
            df_color = clust_hierarchy[
                clust_hierarchy["color_pal"] == color_name
            ].copy()
            if not df_color.empty:
                color_dfs[color_name] = self.assign_colors(df_color, colormap)
        return color_dfs

    def process_cluster_hierarchy(self, clust_hierarchy):
        clust_hierarchy = self.assign_color_categories(clust_hierarchy)
        self.print_color_mapping()
        color_dfs = self.create_color_dataframes(clust_hierarchy)
        clust_hierarchy = pd.concat(color_dfs.values())
        cluster_color_dict = dict(
            zip(clust_hierarchy["cluster"], clust_hierarchy["color"])
        )
        return clust_hierarchy, cluster_color_dict

    def save_dict_to_json(self, dict, path):
        with open(path, "w") as f:
            json.dump(dict, f)
        print(f"Cluster color dictionary saved to {path}")

In [20]:
# Assuming clust_hierarchy is your input DataFrame

clust_hierarchy = pd.read_excel(CLUSTER_INFO_LABEL_TREE)

color_assigner = ClusterColorAssigner()
clust_hierarchy, cluster_color_dict = color_assigner.process_cluster_hierarchy(
    clust_hierarchy
)

color_assigner.save_dict_to_json(
    cluster_color_dict, OUTPUT_DIR + "cluster_color_dict.json"
)

cluster_label_dict = dict(
    zip(clust_hierarchy["cluster"], clust_hierarchy["clusterlabel"])
)

color_assigner.save_dict_to_json(cluster_label_dict, CLUSTER_LABEL_DICT_PATH)

print("\nCluster color dictionary (first 5 items):")
print(dict(list(cluster_color_dict.items())[:5]))
print("\nCluster label dictionary (first 5 items):")
print(dict(list(cluster_label_dict.items())[:5]))

Mapping of conditions to color palettes:
Pharmacology: blue
Indications: red
Safety: green
Other: purple
Cluster color dictionary saved to ../data/99-testdata/cluster_color_dict.json
Cluster color dictionary saved to ../data/99-testdata/cluster_label_dict.json

Cluster color dictionary (first 5 items):
{0: '#e3eef9', 2: '#dfebf7', 3: '#dbe9f6', 5: '#d6e6f4', 6: '#d3e3f3'}

Cluster label dictionary (first 5 items):
{0: 'Serotonin Receptor Studies', 2: 'Risks of Prenatal Exposure', 3: 'Quantification of SSRIs in Biological Samples', 5: 'SSRIs and the Cytochrome P450 System', 6: 'SSRI Neuroscience'}


# legend json creation


In [21]:
def transform_dict_to_legend(cluster_hierachy_dict, cluster_label_dict):
    """
    Adds the cluster labels to the cluster hierarchy dictionary to create a legend.
    """
    # make sure keys in cluster_label_dict are integers
    cluster_label_dict = {int(k): v for k, v in cluster_label_dict.items()}
    legend = cluster_hierachy_dict.copy()
    for key, value in legend.items():
        if isinstance(value, dict):
            # Recursively transform dictionaries
            transform_dict(value, cluster_label_dict)
        elif isinstance(value, list):
            new_list = []
            for item in value:
                if isinstance(item, int) and item in cluster_label_dict:
                    new_list.append({item: cluster_label_dict[item]})
                else:
                    new_list.append(item)
            legend[key] = new_list
    return legend


with open(CLUSTER_HIERACHY_FOR_LEGEND_PATH, "r") as f:
    cluster_hierachy_dict = json.load(f)

with open(CLUSTER_LABEL_DICT_PATH, "r") as f:
    cluster_label_dict = json.load(f)

legend = transform_dict_to_legend(cluster_hierachy_dict, cluster_label_dict)

NameError: name 'transform_dict' is not defined

In [None]:
legend

In [None]:
# Save as JSON
with open(OUTPUT_DIR + "legend_full_label_tree_clusternr.json", "w") as json_file:
    json.dump(cat_tree, json_file, indent=4)

    # Save as JSON
with open(
    THREEJS_OUTPUT_DIR + "legend_tree.json",
    "w",
) as json_file:
    json.dump(cat_tree, json_file, indent=4)