In [None]:
from pyvis.network import Network
import os
import random

In [None]:
def visualize_triples(triples, output_file="knowledge_graph.html", seed=None, open_browser=True):
    """Render knowledge-graph triples as an interactive pyvis HTML page.

    Each triple is a dict with required keys ``"head"``, ``"tail"`` and
    ``"relation"``, and optional keys ``"head_type"``, ``"tail_type"``,
    ``"head_cui"`` and ``"tail_cui"`` (CUIs default to ``"N/A"``).

    Nodes are keyed by their name. A name that appears in several triples
    therefore becomes a single node, whose tooltip, colour and group come from
    its first occurrence. Identical ``(head, tail, relation)`` edges are drawn
    only once.

    Parameters
    ----------
    triples : iterable of dict
        Triples to draw.
    output_file : str or os.PathLike, default "knowledge_graph.html"
        Destination HTML file. Missing parent directories are created.
    seed : int or None, default None
        Seed for the per-type node colours. Pass an int to get identical
        colours on every run; ``None`` draws fresh colours on each call.
    open_browser : bool, default True
        Whether to try opening the saved page in the default web browser.

    Raises
    ------
    KeyError
        If a triple lacks ``"head"``, ``"tail"`` or ``"relation"``.
    """
    import os
    import random
    import webbrowser
    from pathlib import Path

    from pyvis.network import Network

    net = Network(
        height="1200px",
        width="100%",
        directed=True,
        notebook=False,
        bgcolor="#111111",
        font_color="white",
    )

    # Use a private generator, so that `seed` gives reproducible colours
    # without touching the global `random` state.
    rng = random.Random(seed)
    type_colors = {}

    def get_color_for_type(node_type):
        """Return one colour per node type, generated on first use."""
        if node_type not in type_colors:
            # Draw each RGB channel from 0x44-0xFF so that no channel is
            # near-black. A single randint over 0x444444-0xFFFFFF only bounded
            # the red channel, and could yield colours that are barely visible
            # on the #111111 background.
            type_colors[node_type] = "#" + "".join(
                f"{rng.randint(0x44, 0xFF):02x}" for _ in range(3)
            )
        return type_colors[node_type]

    added_nodes = set()
    added_edges = set()

    def add_node_once(name, node_type, cui):
        """Add `name` as a node unless it already exists (first occurrence wins)."""
        if name in added_nodes:
            return
        net.add_node(
            name,
            label=name,  # or: f"{name}\n({cui})"
            title=f"<b>{name}</b><br>Type: {node_type}<br>CUI: {cui}",
            color=get_color_for_type(node_type),
            group=node_type,
        )
        added_nodes.add(name)

    for t in triples:
        head = t["head"]
        tail = t["tail"]
        relation = t["relation"]

        add_node_once(head, t.get("head_type"), t.get("head_cui", "N/A"))
        add_node_once(tail, t.get("tail_type"), t.get("tail_cui", "N/A"))

        # Parallel edges with different relations are kept; exact duplicates are not.
        edge_key = (head, tail, relation)
        if edge_key not in added_edges:
            net.add_edge(head, tail, label=relation, title=relation)
            added_edges.add(edge_key)

    # Physics options
    net.set_options("""
    {
        "physics": {
            "forceAtlas2Based": {
                "gravitationalConstant": -50,
                "centralGravity": 0.01,
                "springLength": 150,
                "springConstant": 0.05
            },
            "minVelocity": 0.75,
            "solver": "forceAtlas2Based"
        }
    }
    """)

    out_path = Path(os.path.abspath(output_file))
    # Create the destination folder up front, so that writing into a
    # not-yet-existing directory does not fail.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    net.save_graph(str(out_path))
    print(f"Graph saved to: {out_path}")

    if open_browser:
        # webbrowser.open reports failure by returning False, and may also raise.
        # as_uri() builds a correct file:// URL on every platform.
        try:
            opened = webbrowser.open(out_path.as_uri())
        except (webbrowser.Error, OSError):
            opened = False
        if not opened:
            print("Could not open browser automatically")


In [3]:
import json
import os

# JSON list of triple dicts. The path suggests the cleaned, node-standardised
# output of an earlier pipeline step. Set KG_TRIPLES_PATH to read the file
# from another location (e.g. when running on a different machine).
OUT_PATH = os.environ.get(
    "KG_TRIPLES_PATH",
    "/Users/cj2837/Documents/Courses/Project/outputs/std_nodes_triples_cleaned.json",
)

# JSON text is UTF-8. Relying on the platform default encoding can mis-decode
# non-ASCII terms.
with open(OUT_PATH, "r", encoding="utf-8") as f:
    triples = json.load(f)

# visualize_triples indexes these keys directly, so fail here with a clear
# message instead of a KeyError halfway through building the graph.
_REQUIRED_TRIPLE_KEYS = {"head", "tail", "relation"}
assert isinstance(triples, list), f"expected a JSON list, got {type(triples).__name__}"
_bad_rows = [
    i for i, t in enumerate(triples)
    if not (isinstance(t, dict) and _REQUIRED_TRIPLE_KEYS <= t.keys())
]
assert not _bad_rows, (
    f"{len(_bad_rows)} triples lack {sorted(_REQUIRED_TRIPLE_KEYS)}; first at index {_bad_rows[0]}"
)

print("Loaded", len(triples), "triples")


Loaded 1052 triples


In [None]:
# Destination of the rendered graph. Set KG_GRAPH_PATH to write it somewhere
# other than the author's local project folder.
GRAPH_PATH = os.environ.get(
    "KG_GRAPH_PATH",
    "/Users/cj2837/Documents/Courses/Project/graph/final_menopause_kg.html",
)
visualize_triples(triples, output_file=GRAPH_PATH)
