In [2]:
import numpy as np
import torch
import networkx as nx
from torch_geometric.datasets import Planetoid
from sklearn.cluster import KMeans
from collections import Counter
from torch_geometric.utils import to_networkx
from transformers import AutoTokenizer, AutoModel
from collections import deque
from pyvis.network import Network
from pathlib import Path

dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
CORA_LABELS = [
    "Case_Based",
    "Genetic_Algorithms",
    "Neural_Networks",
    "Probabilistic_Methods",
    "Reinforcement_Learning",
    "Rule_Learning",
    "Theory"
]


In [4]:
G = to_networkx(data, to_undirected=False) 
print(type(G))
print("Nodes:", G.number_of_nodes())
print("Edges:", G.number_of_edges())


<class 'networkx.classes.digraph.DiGraph'>
Nodes: 2708
Edges: 10556


In [5]:
for node_id in range(data.y.shape[0]):
    y = int(data.y[node_id])
    G.nodes[node_id]["cora_label"] = CORA_LABELS[y]


In [6]:

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
bert_model = AutoModel.from_pretrained(model_name)


In [7]:
def encode_with_distilbert(texts):
    """
    texts: list of strings
    returns: torch.Tensor of shape [len(texts), hidden_size]
    """
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt"
    )
    with torch.no_grad():
        outputs = bert_model(**enc)
        # DistilBERT has no pooler; use the [CLS]-token representation
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # [batch, hidden]
    return cls_embeddings


In [8]:
def node_to_text(node_id):
    """
    Convert the Cora feature vector for this node into a fake text string.
    Each active feature index becomes a token like 'w123'.
    """
    x_vec = data.x[node_id]  # shape [1433]

    # Find indices of features that are 'on'
    active_idx = (x_vec > 0).nonzero(as_tuple=True)[0].tolist()

    # Turn indices into tokens
    tokens = [f"w{i}" for i in active_idx]

    # Join into a pseudo-sentence
    text = " ".join(tokens) if tokens else "empty"

    return text

In [9]:
def crawl_citation_ladder(G: nx.DiGraph, start_node: int, max_depth: int = 2, max_nodes: int | None = None):
    """
    Crawl 'downward' through the citation graph starting from start_node.

    - Only follows successors (outgoing edges).
    - Uses BFS by depth.
    - Stops at depth = max_depth.
    - Optionally stops if number of visited nodes reaches max_nodes.

    Returns:
        subgraph: nx.DiGraph containing visited nodes and edges
        depths: dict {node: depth_level}
    """
    if start_node not in G:
        raise ValueError(f"Start node {start_node} is not in the graph.")

    # BFS queue and depth tracking
    queue = deque([start_node])
    depths = {start_node: 0}
    visited = {start_node}

    ladder_nodes = {start_node}
    ladder_edges = []

    while queue:
        u = queue.popleft()
        d = depths[u]

        # don't go deeper than max_depth
        if d >= max_depth:
            continue

        for v in G.successors(u):  # only go "down" along outgoing citations
            ladder_edges.append((u, v))
            ladder_nodes.add(v)

            if v not in visited:
                visited.add(v)
                depths[v] = d + 1

                # enforce max_nodes if provided
                if max_nodes is None or len(ladder_nodes) < max_nodes:
                    queue.append(v)

    # Build a subgraph with just these nodes and edges
    sub = nx.DiGraph()
    # copy node attributes from original graph
    for n in ladder_nodes:
        sub.add_node(n, **G.nodes[n])
    for u, v in ladder_edges:
        if u in sub and v in sub and G.has_edge(u, v):
            sub.add_edge(u, v, **G.edges[u, v])

    # attach depth info as a node attribute
    for n, depth in depths.items():
        if n in sub.nodes:
            sub.nodes[n]["depth"] = depth

    return sub, depths


In [10]:
COLOR_PALETTE = [
    "#66b3ff",  # blue
    "#99ff99",  # green
    "#ffcc66",  # yellow-orange
    "#cc99ff",  # purple
    "#ff99cc",  # pink
    "#aaffc3",  # mint
    "#ffd8b1",  # peach
    "#8dd3c7",  # teal
    "#bebada",  # lavender
    "#fb8072",  # salmon
]

def get_color_for_depth(depth: int):
    """
    Returns a color for a given depth.
    - depth 0 is always red.
    - depths 1..N cycle through the palette.
    """
    if depth == 0:
        return "#ff6666"  # red for start node
    
    # Cycle through palette for depths 1+
    idx = (depth - 1) % len(COLOR_PALETTE)
    return COLOR_PALETTE[idx]


In [11]:

def get_color_for_cluster(cluster_id: int | None):
    if cluster_id is None:
        return "#dddddd"
    idx = cluster_id % len(COLOR_PALETTE)
    return COLOR_PALETTE[idx]


In [12]:
def add_pyvis_legend(net, legend_items, x_start=-500, y_start=500, spacing=80):
    """
    legend_items = list of (label, color)
    Creates non-physics, fixed-position nodes for a legend box.
    """
    x = x_start
    y = y_start
    
    for label, color in legend_items:
        net.add_node(
            f"legend_{label}", 
            label=label,
            color=color,
            shape="box",
            physics=False,
            x=x,
            y=y,
            font={"size": 20},
        )
        y -= spacing


In [13]:
def visualize_ladder_with_pyvis(subgraph, start_node, html_file="citation_ladder.html"):
    net = Network(
        height="750px",
        width="100%",
        directed=True,
        notebook=False,
        cdn_resources="in_line",
    )

    # (optional) simple physics config
    net.set_options("""
    {
      "physics": {
        "enabled": true,
        "barnesHut": {
          "gravitationalConstant": -2500,
          "springLength": 170,
          "springConstant": 0.04
        }
      }
    }
    """)

    for n, data in subgraph.nodes(data=True):
        depth = data.get("depth")
        cora_label = data.get("cora_label", "Unknown")

        # Visible label on the node
        label = f"{n}"

        # Tooltip (HTML, nicer than \n)
        title = (
            f"Node: {n}\n"
            f"Depth: {depth}\n"
            f"Cora Label: {cora_label}"
        )
        
        # Color by depth
        color = get_color_for_depth(depth)

        net.add_node(
            n,
            label=label,
            title=title,
            color=color if n != start_node else "#ff6666",  # keep start red
            borderWidth=3 if n == start_node else 1,
        )

    # Add edges (directed)
    for u, v in subgraph.edges():
        net.add_edge(u, v, arrows="to")

    # Generate HTML and write as UTF-8 (Windows-safe)
    html_str = net.generate_html()
    Path(html_file).write_text(html_str, encoding="utf-8")
    print(f"Saved visualization to {html_file}")

In [14]:
def add_bert_embeddings_to_subgraph(subgraph):
    """
    For each node in the subgraph:
    - build pseudo-text via node_to_text
    - encode with DistilBERT
    - store embedding on the node as 'bert_embedding'
    """
    node_list = list(subgraph.nodes())
    texts = [node_to_text(n) for n in node_list]

    # encode_with_distilbert should return a torch.Tensor
    embs = encode_with_distilbert(texts)      # shape: [N, hidden]
    embs_np = embs.cpu().numpy()              # convert to numpy

    for nid, e in zip(node_list, embs_np):
        subgraph.nodes[nid]["bert_embedding"] = e


In [15]:
def describe_clusters(subgraph):
    cluster_label_counts = {}

    for n, data in subgraph.nodes(data=True):
        cid = data.get("bert_cluster", None)
        label = data.get("cora_label", "Unknown")

        if cid is None:
            continue

        cluster_label_counts.setdefault(cid, Counter())
        cluster_label_counts[cid][label] += 1

    # Pretty print
    for cid, counter in sorted(cluster_label_counts.items()):
        print(f"\n=== Cluster {cid} ===")
        total = sum(counter.values())
        for label, count in counter.most_common():
            pct = 100 * count / total
            print(f"  {label:25s} {count:3d} papers ({pct:4.1f}%)")

    return cluster_label_counts

cluster_stats = describe_clusters(ladder_subgraph)

NameError: name 'ladder_subgraph' is not defined

In [16]:
def add_bert_clusters_to_subgraph(subgraph, n_clusters: int | None = None):
    """
    Run KMeans on BERT embeddings of nodes in the subgraph.
    - Adds 'bert_cluster' attribute to each node.
    - n_clusters defaults to min(5, number of nodes).
    """
    node_list = list(subgraph.nodes())

    # collect embeddings
    embs = []
    for n in node_list:
        emb = subgraph.nodes[n].get("bert_embedding", None)
        if emb is None:
            raise ValueError(f"Node {n} is missing 'bert_embedding'. Run add_bert_embeddings_to_subgraph first.")
        embs.append(emb)
    embs = np.vstack(embs)

    if n_clusters is None:
        n_clusters = min(5, len(node_list))  # up to 5 clusters by default

    kmeans = KMeans(n_clusters=n_clusters, random_state=0, n_init=10)
    cluster_ids = kmeans.fit_predict(embs)

    for nid, cid in zip(node_list, cluster_ids):
        subgraph.nodes[nid]["bert_cluster"] = int(cid)

    return n_clusters


In [17]:
start = 2  #interesting nodes nr: 1701, 1986   
max_depth = 4

ladder_subgraph, depths = crawl_citation_ladder(G, start_node=start, max_depth=max_depth)
print("Nodes:", ladder_subgraph.number_of_nodes())
print("Edges:", ladder_subgraph.number_of_edges())


Nodes: 859
Edges: 1536


In [18]:
add_bert_embeddings_to_subgraph(ladder_subgraph)
add_bert_clusters_to_subgraph(ladder_subgraph, n_clusters=clusters)

NameError: name 'clusters' is not defined

In [19]:
visualize_ladder_with_pyvis(ladder_subgraph,start_node=start, html_file="citation_ladder.html")

Saved visualization to citation_ladder.html
