# This notebook loads the previously built semantic network and interactively visualizes the neighbourhood of a chosen starting node

## 0. Imports and setup

In [27]:
# Cell 2: imports & config
from pathlib import Path

import networkx as nx
import numpy as np
import pandas as pd
from IPython.display import IFrame, display
from pyvis.network import Network
from sentence_transformers import SentenceTransformer

# --- Input files (exported previously) ---
GEXF_PATH = "materials_semantic_network.gexf"  # graph file read below; GraphML is an alternative
# The CSV exports are not needed for the visualization; they are only kept as debugging aids.
NODES_CSV = "materials_semantic_nodes.csv"
EDGES_CSV = "materials_semantic_edges.csv"

# --- Node colour (hex RGB) for each category ---
COLOR_BY_CATEGORY = dict(
    source="#1f77b4",       # blue
    function="#2ca02c",     # green
    application="#d62728",  # red
)

# Largest |weight - 1| for which an edge weight still counts as 1
W1_TOL = 0.1


## 1. Load graph and normalize attributes

In [28]:
# Cell 3: load graph and normalize attributes
# Read the exported GEXF file into a NetworkX graph; `G` is the full network the later cells query.
G = nx.read_gexf(GEXF_PATH)  # preserves node/edge attributes

def normalize_graph(G):
    """
    Normalize node and edge attributes of ``G`` in place.

    Nodes end up with:
      * ``label``    - kept when non-blank, otherwise the node id with any
                       ``"type::"`` prefix removed;
      * ``category`` - lower-cased and stripped; taken from ``category``,
                       else from ``type``, else from the ``"prefix::"`` part
                       of the node id, else ``"unknown"``.

    Edges end up with:
      * ``edge_type`` - lower-cased value of ``edge_type`` (or ``kind``), with
                        ``cooc`` / ``sim`` expanded to their long names and
                        ``"unknown"`` when empty;
      * ``weight``    - a float; ``1.0`` when missing or not convertible.

    Returns None; the attribute dicts handed out by the graph are modified.
    """
    long_names = {"cooc": "cooccurrence", "sim": "similarity"}

    def _blank(attrs, key):
        # True when the attribute is absent or whitespace-only once stringified
        return key not in attrs or not str(attrs[key]).strip()

    for node, attrs in G.nodes(data=True):
        node_id = str(node)

        if _blank(attrs, "label"):
            # node ids may look like "type::label"; keep the part after the prefix
            attrs["label"] = node_id.split("::", 1)[-1]

        if _blank(attrs, "category"):
            if not _blank(attrs, "type"):
                # some builds stored the category under 'type'
                fallback = str(attrs["type"]).strip()
            elif "::" in node_id:
                fallback = node_id.split("::", 1)[0]
            else:
                fallback = "unknown"
            attrs["category"] = fallback

        attrs["category"] = str(attrs["category"]).lower().strip()

    for _, _, attrs in G.edges(data=True):
        # some builds stored the edge type under 'kind' (e.g. 'cooc' or 'sim')
        raw_type = attrs.get("edge_type") or attrs.get("kind") or ""
        edge_type = str(raw_type).lower().strip()
        edge_type = long_names.get(edge_type, edge_type)
        attrs["edge_type"] = edge_type or "unknown"

        weight = 1.0
        if "weight" in attrs:
            try:
                weight = float(attrs["weight"])
            except Exception:
                weight = 1.0
        attrs["weight"] = weight

# Normalize attributes in place, then report the size of the loaded graph.
normalize_graph(G)
print(f"Loaded graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")


Loaded graph: 6800 nodes, 36695 edges


## 2. sBERT setup for node search

In [29]:
# Pick the same model you used for intra-category edges for consistency
SBERT_MODEL = "all-MiniLM-L6-v2"  # or your previous choice
# Instantiate the encoder once; `sbert` is reused for node labels and for search queries.
sbert = SentenceTransformer(SBERT_MODEL)

def _normalize(v):
    v = np.asarray(v, dtype=np.float32)
    n = np.linalg.norm(v, axis=-1, keepdims=True) + 1e-12
    return v / n

# Per-category search cache: node ids, their labels, and the matching
# normalized label embeddings ("emb" stays None until embeddings are built).
cat_cache = {
    category: {"nodes": [], "labels": [], "emb": None}
    for category in ("source", "function", "application")
}

def rebuild_category_embeddings(G):
    """
    Recompute the per-category search cache from the nodes of ``G``.

    For every category key in ``cat_cache``, collects the nodes whose
    ``category`` attribute equals that key, encodes their labels with
    ``sbert`` and stores node ids, labels and normalized embeddings back
    into ``cat_cache``. A category without nodes gets an empty (0, 384)
    float32 array. Returns None.
    """
    for cat, slot in cat_cache.items():
        # (node id, label) pairs of this category; the node id doubles as label when none is set
        members = [
            (node, attrs.get("label", node))
            for node, attrs in G.nodes(data=True)
            if attrs.get("category") == cat
        ]
        node_ids = [node for node, _ in members]
        texts = [text for _, text in members]

        if not texts:
            # nothing to encode; the width 384 is hard-coded (dimension depends on model)
            vectors = np.zeros((0, 384), dtype=np.float32)
        else:
            vectors = _normalize(
                sbert.encode(texts, batch_size=256, show_progress_bar=False)
            )

        slot["nodes"] = node_ids
        slot["labels"] = texts
        slot["emb"] = vectors

# Fill the cache from the loaded graph and show how many nodes each category holds.
rebuild_category_embeddings(G)
print({k: len(v["nodes"]) for k,v in cat_cache.items()})

{'source': 1314, 'function': 4144, 'application': 1342}


## 3. Helpers

In [32]:
# Cell 4: helpers for querying and filtering

def norm_text(s: str) -> str:
    """Return ``s`` as a lower-cased string, trimmed, with every whitespace run collapsed to one space."""
    words = str(s).strip().lower().split()
    return " ".join(words)

#def search_nodes(G, query, category=None, max_results=50):
#    """
#    Case-insensitive substring search on node 'label'.
#    Optionally filter by 'category' in {'source','function','application'}.
#    """
#    q = norm_text(query)
#    out = []
#    for n, d in G.nodes(data=True):
#        if category and d.get("category") != category:
#            continue
#        if q in norm_text(d.get("label", "")):
#            out.append(n)
#            if len(out) >= max_results:
#                break
#    return out

def search_nodes_semantic(G, query, category, top_k=1):
    """
    Find the node(s) whose label is most semantically similar to `query`,
    using sBERT cosine similarity against the cached label embeddings.

    Parameters
    ----------
    G : graph
        Not used by the search itself (it runs on `cat_cache`); kept so that
        existing calls keep working.
    query : str
        Free-text search term.
    category : str or None
        One of 'source', 'function', 'application' (case-insensitive), or
        None to search all cached categories together.
    top_k : int
        Maximum number of nodes to return.

    Returns
    -------
    list
        Node ids, best match first, at most `top_k` of them. Empty when the
        selected categories hold no nodes or `top_k` is not positive.

    Raises
    ------
    ValueError
        If `category` is a name that is not a key of `cat_cache`.
    """
    if category is None:
        selected = list(cat_cache)
    else:
        category = category.lower()
        if category not in cat_cache:
            raise ValueError("category must be one of {'source','function','application'}")
        selected = [category]

    # Gather node ids and embedding matrices of every selected, non-empty category
    nodes, emb_parts = [], []
    for cat in selected:
        entry = cat_cache[cat]
        if entry["emb"] is not None and len(entry["nodes"]) > 0:
            nodes.extend(entry["nodes"])
            emb_parts.append(entry["emb"])

    # Clamp the request to what is available; a non-positive request yields nothing
    k = min(top_k, len(nodes))
    if k <= 0:
        return []

    q_emb = sbert.encode([query], show_progress_bar=False)
    q_emb = _normalize(q_emb)[0]  # shape (d,)

    embs = emb_parts[0] if len(emb_parts) == 1 else np.vstack(emb_parts)
    # cosine similarity = dot product on normalized vectors
    sims = embs @ q_emb

    if k < len(sims):
        # select the k best without a full sort, then order just those k
        idx = np.argpartition(-sims, k - 1)[:k]
        idx = idx[np.argsort(-sims[idx])]
    else:
        idx = np.argsort(-sims)

    return [nodes[i] for i in idx]

def build_cooccurrence_only_graph(G, tol=None):
    """
    Build a new undirected graph holding every node of ``G`` but only its
    co-occurrence edges of weight ≈ 1.

    An edge is kept when ``edge_type == 'cooccurrence'`` and
    ``|weight - 1| <= tol`` (a missing weight counts as 1.0); all other
    edges are dropped. Node and edge attributes are copied to the new graph.

    Parameters
    ----------
    G : graph
        Source graph; it is not modified.
    tol : float or None
        Allowed deviation of the weight from 1. None (the default) uses the
        module-level ``W1_TOL``.

    Returns
    -------
    nx.Graph
        The filtered graph.
    """
    if tol is None:
        tol = W1_TOL  # looked up at call time so later changes to the constant apply

    H = nx.Graph()
    H.add_nodes_from(G.nodes(data=True))
    H.add_edges_from(
        (u, v, d)
        for u, v, d in G.edges(data=True)
        if d.get("edge_type") == "cooccurrence"
        and abs(d.get("weight", 1.0) - 1.0) <= tol
    )
    return H

def ego_subgraph_from_seeds(H, seeds, radius=2):
    """
    Merge the ego graphs of all seeds into a single graph.

    Seeds that are not nodes of ``H`` are ignored. For each remaining seed,
    the ego graph within ``radius`` hops is taken from ``H`` and composed
    into the result. With no usable seed an empty ``nx.Graph`` is returned.
    """
    merged = nx.Graph()  # stays empty when no seed is present in H
    for seed in seeds:
        if seed not in H:
            continue
        merged = nx.compose(merged, nx.ego_graph(H, seed, radius=radius))
    return merged


## 3. Pick query and build

In [41]:
# Cell 5: set your query and the category to search in
QUERY_TEXT   = "3D-printing"       # <-- change this to your search term
QUERY_CAT    = "application"       # category to search: "source", "function" or "application"
EGO_RADIUS   = 2                   # hops around the seed(s)
MAX_NODES    = 300                 # hard cap on the number of nodes kept for the visualization

# 1) Find matching nodes
matches = search_nodes_semantic(G, QUERY_TEXT, category=QUERY_CAT, top_k=1)
print(f"Matches found ({len(matches)}): {[G.nodes[n]['label'] for n in matches[:10]]}")

# 2) Build co-occurrence-only graph
G_cooc = build_cooccurrence_only_graph(G)
print(f"Co-occurrence-only graph: {G_cooc.number_of_nodes()} nodes, {G_cooc.number_of_edges()} edges")

# 3) Extract an ego subgraph from matches (using only co-occurrence edges)
S = ego_subgraph_from_seeds(G_cooc, matches, radius=EGO_RADIUS)
print(f"Subgraph for viz: {S.number_of_nodes()} nodes, {S.number_of_edges()} edges")

# Optional pruning if the subgraph is too large
if S.number_of_nodes() > MAX_NODES:
    # Seeds are kept first; their direct neighbours are then added, most
    # connected first, and the cap is checked before every single addition
    # so the result never exceeds MAX_NODES.
    keep = set()
    for m in matches:
        if m in S and len(keep) < MAX_NODES:
            keep.add(m)
    for m in matches:
        if m not in S:
            continue
        for nb in sorted(S.neighbors(m), key=lambda node: S.degree[node], reverse=True):
            if len(keep) >= MAX_NODES:
                break
            keep.add(nb)
    S = S.subgraph(keep).copy()
    print(f"Pruned to {S.number_of_nodes()} nodes, {S.number_of_edges()} edges")


Matches found (1): ['high precision 3d printed objects']
Co-occurrence-only graph: 6800 nodes, 10753 edges
Subgraph for viz: 22 nodes, 26 edges


## 4. Interactive visualization

In [42]:
# Cell 6: Interactive visualization (Pyvis)
HTML_OUT = "materials_cooccurrence_query.html"

# Undirected Pyvis canvas configured for notebook display
net = Network(
    height="800px",
    width="100%",
    bgcolor="#ffffff",
    notebook=True,
    directed=False,
    cdn_resources="in_line",
)
net.toggle_physics(False)
# Physics can be fine-tuned through options if needed, e.g.:
# net.set_options('{"physics":{"solver":"forceAtlas2Based","forceAtlas2Based":{"gravitationalConstant":-50}}}')

# Nodes: colour by category (grey when the category has no colour entry),
# size growing with the node's degree in the subgraph
for n, d in S.nodes(data=True):
    label, cat = d.get("label", str(n)), d.get("category", "unknown")
    color = COLOR_BY_CATEGORY.get(cat, "#888888")
    size = 10 + 2 * S.degree(n)
    title = f"{cat.capitalize()}: {label}"  # tooltip text: category and label
    net.add_node(n, label=label, color=color, title=title, size=size)

# Edges: uniform grey, constant value, edge type as tooltip
for u, v, d in S.edges(data=True):
    et = d.get("edge_type", "cooccurrence")
    net.add_edge(u, v, value=1, color="#555555", title=et)

net.show(HTML_OUT)
print("Graph written:", HTML_OUT)

# Legend overlay. The rows are generated from COLOR_BY_CATEGORY so the legend
# swatches always show the same colours as the nodes drawn above.
legend_rows = "\n".join(
    f'  <div><span style="color:{legend_color};">&#9679;</span> {legend_cat.capitalize()}</div>'
    for legend_cat, legend_color in COLOR_BY_CATEGORY.items()
)

overlay_html = f"""
<div id="legend-box" style="
  position:fixed; right:20px; bottom:20px;
  background:#fff; border:1px solid #ccc; border-radius:6px;
  padding:10px 12px; font: 13px/1.2 sans-serif; z-index: 999999;
  box-shadow: 0 2px 8px rgba(0,0,0,0.15);">
  <div style="font-weight:600; margin-bottom:6px;">Legend</div>
{legend_rows}
</div>
"""

html_path = Path(HTML_OUT)
html = html_path.read_text(encoding="utf-8")

# Insert the overlay just before the LAST </body> only: with in-line resources
# the file also embeds script text, and an earlier occurrence must not be touched.
head, closing_tag, tail = html.rpartition("</body>")
if closing_tag:
    html = f"{head}{overlay_html}\n{closing_tag}{tail}"
    html_path.write_text(html, encoding="utf-8")
    print("Legend injected into:", HTML_OUT)
else:
    # Report the miss instead of claiming success on an unchanged file
    print("No </body> tag found; legend NOT injected into:", HTML_OUT)

IFrame(src=HTML_OUT, width="100%", height="820")

materials_cooccurrence_query.html
Graph written: materials_cooccurrence_query.html
Legend injected into: materials_cooccurrence_query.html
