In [None]:
# Notebook parameters (presumably overridden by a runner such as papermill — TODO confirm).
# Graphviz shape applied to every node that is drawn.
node_shape: str = "box"
# Substring matched against node names; when empty, no filtering happens and the whole graph is drawn.
filter_to_relevant: str = ""

In [None]:
import graphviz
from code_data_science import data_table as dt
import code_data_science.palette as palette
import html

# Load the relationships CSV; later cells read its "dependantType" and
# "dependencyType" columns as the two ends of each edge.
df = dt.read_csv("../samples/spring_component_relationships.csv")

In [None]:
import networkx as nx

# Directed graph built from the edge list: each edge runs from the
# "dependantType" column to the "dependencyType" column.
G = nx.from_pandas_edgelist(
    df,
    source="dependantType",
    target="dependencyType",
    create_using=nx.DiGraph(),
)

relevant_nodes = set()
ancestor_nodes = set()
descendant_nodes = set()

# An empty filter selects nothing, which leaves all three sets empty.
matching_nodes = (
    [name for name in G.nodes if filter_to_relevant in name]
    if filter_to_relevant
    else []
)

for hit in matching_nodes:
    # The matching node itself ...
    relevant_nodes.add(hit)
    # ... plus everything reachable from it and everything that can reach it.
    descendant_nodes.update(nx.descendants(G, hit))
    ancestor_nodes.update(nx.ancestors(G, hit))

In [None]:
# Use SVG as the format for graphs displayed in the notebook.
graphviz.set_jupyter_format("svg")

# Graph-level attributes for the rendered diagram.
dot = graphviz.Digraph(
    "spring-relationships",
    comment="Spring component relationships",
    graph_attr={
        "overlap": "prism",
        "normalize": "true",
        "overlap_scaling": "100",
        "nodesep": "1",
    },
)

# Bookkeeping so that each node and each edge is emitted only once.
added_nodes = set()
added_edges = set()


def map_relationship(row):
    """Add one relationship row (its two nodes and the edge) to ``dot``.

    Reads ``row["dependantType"]`` and ``row["dependencyType"]`` and mutates
    the module-level ``dot``, ``added_nodes`` and ``added_edges``. Returns
    ``None``.

    When ``filter_to_relevant`` is empty, every node is drawn blue and every
    distinct edge is drawn once. When a filter is set:

    * a node in ``relevant_nodes`` is drawn red; red takes precedence, so a
      matching node that is also an ancestor/descendant of another matching
      node keeps its highlight instead of being re-declared blue;
    * otherwise the dependant is drawn blue if it is in ``ancestor_nodes``
      and the dependency is drawn blue if it is in ``descendant_nodes``;
    * the edge is drawn (once) if the dependant is a relevant or descendant
      node, or the dependency is a relevant or ancestor node.

    Args:
        row: Mapping-like row exposing the "dependantType" and
            "dependencyType" keys.
    """
    dependant = row["dependantType"]
    dependency = row["dependencyType"]
    filtering = filter_to_relevant != ""
    colors = palette.__moderneColorMap

    def declare(name, related_nodes):
        # Declare ``name`` at most once, styled by role; nodes outside the
        # filter are left undeclared (and unrecorded).
        if name in added_nodes:
            return
        if filtering and name in relevant_nodes:
            fill = colors["red"][200]
        elif not filtering or name in related_nodes:
            fill = colors["blue"][200]
        else:
            return
        dot.node(name, shape=node_shape, style="filled", fillcolor=fill)
        # Record the node that was actually declared, so it is not emitted
        # again on later rows.
        added_nodes.add(name)

    declare(dependant, ancestor_nodes)
    declare(dependency, descendant_nodes)

    edge = (dependant, dependency)
    if edge in added_edges:
        return

    # A single combined test: several of these conditions can hold at once,
    # and the edge must still be emitted only one time.
    if filtering and not (
        dependant in descendant_nodes
        or dependant in relevant_nodes
        or dependency in ancestor_nodes
        or dependency in relevant_nodes
    ):
        return

    dot.edge(dependant, dependency)
    added_edges.add(edge)


# Feed every relationship row to map_relationship, which populates `dot`.
for _row_label, relationship in df.iterrows():
    map_relationship(relationship)

# A bare last expression makes the notebook render the graph.
dot