In [29]:
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx

from torch_geometric.nn import SAGEConv
from torch_geometric.utils import from_networkx
from sklearn.metrics import roc_auc_score, precision_score, recall_score

In [30]:
try:
    from sentence_transformers import SentenceTransformer
    EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
    embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    print(f"Loaded embedding model: {EMBEDDING_MODEL_NAME}")
except ImportError:
    embedder = None
    print("sentence-transformers not installed. Please install it to embed text.")

Loaded embedding model: sentence-transformers/all-MiniLM-L6-v2


In [31]:
def compute_author_embedding(abstracts):
    if not embedder or len(abstracts) == 0:
        return torch.zeros(384, dtype=torch.float)  # 384 is dimension of all-MiniLM-L6-v2
    abs_embeddings = embedder.encode(abstracts, convert_to_tensor=True)
    return abs_embeddings.mean(dim=0)  # shape: [embedding_dim]

def compute_topic_embedding(node_id, topic_name=""):
    if not embedder:
        return torch.zeros(384, dtype=torch.float)
    text_to_embed = node_id
    
    topic_emb = embedder.encode([text_to_embed], convert_to_tensor=True)
    return topic_emb[0]  # shape: [embedding_dim]


def add_node_id_attribute(G):
    # If G has N nodes and each node is labeled by something (e.g. "A123" or 123),
    # we store that label in the 'node_id' attribute.
    nx.set_node_attributes(
        G,
        {node: str(node) for node in G.nodes()},
        name="node_id"
    )
    return G


def convert_graph_to_pyg(G):
    # Ensure node_id is set
    G = add_node_id_attribute(G)

    # Now from_networkx will pick up the node_id attribute
    pyg_data = from_networkx(G)

    return pyg_data


def load_and_unify_attributes(graph_pickle_path="niu_graph.pkl"):
    # Load the NetworkX graph from a pickle and ensure every node has the same set of attributes
    with open(graph_pickle_path, "rb") as f:
        G = pickle.load(f)

    for node_id, data in G.nodes(data=True):
        # node_type
        data.setdefault("node_type", "unknown")

        # If it's an author, fill in missing fields with sensible defaults:
        if data["node_type"] == "author":
            data.setdefault("works_count", 0.0)
            data.setdefault("cited_by_count", 0.0)
            data.setdefault("abstracts", [])
        elif data["node_type"] == "topic":
            # For topics, we won't have numeric or abstract data
            data.setdefault("works_count", 0.0)
            data.setdefault("cited_by_count", 0.0)
            data.setdefault("abstracts", [])
        else:
            # fallback for unknown node types
            data.setdefault("works_count", 0.0)
            data.setdefault("cited_by_count", 0.0)
            data.setdefault("abstracts", [])

    return G


def build_pyg_data(G):
    print("Embedding node information...")

    # Embed each node and store in a dict
    node_embeddings = {}
    embedding_dim = 384  # dimension for all-MiniLM-L6-v2; change if using a different model

    for node_id, attrs in G.nodes(data=True):
        node_type = attrs.get("node_type", "unknown")

        if node_type == "author":
            # numeric features
            wc = float(attrs.get("works_count", 0))
            cc = float(attrs.get("cited_by_count", 0))

            # text embedding (from abstracts)
            abstracts = attrs.get("abstracts", [])
            author_emb = compute_author_embedding(abstracts)

            # combine numeric + text (2 + embedding_dim)
            numeric_vec = torch.tensor([wc, cc], dtype=torch.float)
            full_vec = torch.cat([numeric_vec, author_emb], dim=0)

        elif node_type == "topic":
            # numeric features: [0,0]
            numeric_vec = torch.zeros(2, dtype=torch.float)

            # text embedding (from node_id)
            topic_emb = compute_topic_embedding(node_id)

            full_vec = torch.cat([numeric_vec, topic_emb], dim=0)
        else:
            # fallback for unknown node types
            # just zero everything
            numeric_vec = torch.zeros(2, dtype=torch.float)
            zero_emb = torch.zeros(embedding_dim, dtype=torch.float)
            full_vec = torch.cat([numeric_vec, zero_emb], dim=0)

        node_embeddings[node_id] = full_vec

    # Now convert to PyG
    pyg_data = from_networkx(G)

    # Build final feature matrix pyg_data.x
    num_nodes = pyg_data.num_nodes
    all_features = []

    # from_networkx typically keeps a 'node' attribute array with the original node_id
    for i in range(num_nodes):
        node_id = pyg_data["node_id"][i]  # the original ID
        # Retrieve the precomputed vector
        feat_vec = node_embeddings[node_id]
        all_features.append(feat_vec)

    x = torch.stack(all_features, dim=0)
    pyg_data.x = x

    return pyg_data

In [32]:
class GraphSAGEModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels=64):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x  # shape [num_nodes, out_channels]

    def predict_link(self, emb, edge_pairs):
        """
        Dot-product similarity -> sigmoid as probability of link existence.
        """
        u = edge_pairs[0]
        v = edge_pairs[1]
        score = (emb[u] * emb[v]).sum(dim=-1)
        return torch.sigmoid(score)


def train_test_link_prediction(pyg_data, epochs=20, lr=0.01, hidden_dim=32):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # -------------------------------------------
    # A) Identify authors vs. topics from node_type
    #    If node_type is a Python list, we can do:
    node_types = list(pyg_data["node_type"])  # e.g. 100-length list of b'author' or b'topic'
    # Convert bytes to str if needed
    node_types = [nt.decode("utf-8") if isinstance(nt, bytes) else nt for nt in node_types]

    # We'll gather indices of authors
    author_indices_list = [i for i, nt in enumerate(node_types) if nt == "author"]
    author_indices_torch = torch.tensor(author_indices_list, dtype=torch.long)

    # -------------------------------------------
    # B) Filter edges for co_author (author-author)
    edge_index = pyg_data.edge_index
    relations = pyg_data["relation"]  # likely a list of b'co_author', b'has_topic', etc.

    # Convert relations to a list of strings
    if isinstance(relations, torch.Tensor):
        # Might be a tensor of bytes or python objects. Let's turn them into normal Python strings
        relations = [r.decode("utf-8") if isinstance(r, bytes) else str(r) for r in relations.tolist()]
    else:
        # if it's already a list, just ensure they're strings
        relations = [r.decode("utf-8") if isinstance(r, bytes) else str(r) for r in relations]

    coauthor_edges = []
    for i, rel in enumerate(relations):
        if rel == "co_author":
            src = edge_index[0, i].item()
            dst = edge_index[1, i].item()
            # check that both are authors
            if node_types[src] == "author" and node_types[dst] == "author":
                coauthor_edges.append((src, dst))

    if len(coauthor_edges) < 2:
        print("Not enough co_author edges for training.")
        return None, None, None

    pos_edge_count = len(coauthor_edges)
    # Turn into a torch tensor of shape [2, n]
    coauthor_edges_t = torch.tensor(coauthor_edges, dtype=torch.long).t()

    # -------------------------------------------
    # C) Negative sampling: random author-author pairs not in co_author
    existing_pairs = set()
    for a, b in coauthor_edges:
        if a > b:
            a, b = b, a
        existing_pairs.add((a, b))

    needed_neg_count = pos_edge_count
    author_list = author_indices_list
    negative_edges_list = []
    attempts = 0
    max_attempts = needed_neg_count * 10

    while len(negative_edges_list) < needed_neg_count and attempts < max_attempts:
        a = random.choice(author_list)
        b = random.choice(author_list)
        if a == b:
            attempts += 1
            continue
        if a > b:
            a, b = b, a
        if (a, b) not in existing_pairs:
            negative_edges_list.append((a, b))
        attempts += 1

    if len(negative_edges_list) < needed_neg_count:
        needed_neg_count = len(negative_edges_list)

    neg_edges_t = torch.tensor(negative_edges_list[:needed_neg_count], dtype=torch.long).t()

    # -------------------------------------------
    # D) Train/test split
    num_pos = coauthor_edges_t.size(1)
    perm_pos = torch.randperm(num_pos)
    split_pos = int(num_pos * 0.8)
    train_pos = coauthor_edges_t[:, perm_pos[:split_pos]]
    test_pos = coauthor_edges_t[:, perm_pos[split_pos:]]

    num_neg = neg_edges_t.size(1)
    perm_neg = torch.randperm(num_neg)
    split_neg = int(num_neg * 0.8)
    train_neg = neg_edges_t[:, perm_neg[:split_neg]]
    test_neg = neg_edges_t[:, perm_neg[split_neg:]]

    # -------------------------------------------
    # E) Build & train model
    # We'll do a trivial x=pyg_data.x if you have numeric/embedding features
    # or we can do x=some dummy if you haven't built your features yet.
    x = pyg_data.x
    if x is None:
        print("Warning: No features found in pyg_data.x. Using a random feature vector.")
        x = torch.rand(pyg_data.num_nodes, 8)  # random fallback

    in_channels = x.size(1)
    model = GraphSAGEModel(in_channels, hidden_dim, out_channels=64).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    x = x.to(device)
    full_edge_index = pyg_data.edge_index.to(device)
    model.train()
    for epoch in range(1, epochs+1):
        optimizer.zero_grad()

        emb = model(x, full_edge_index)
        pos_score = model.predict_link(emb, train_pos.to(device))
        neg_score = model.predict_link(emb, train_neg.to(device))

        pos_labels = torch.ones(pos_score.size(0), device=device)
        neg_labels = torch.zeros(neg_score.size(0), device=device)

        all_scores = torch.cat([pos_score, neg_score], dim=0)
        all_labels = torch.cat([pos_labels, neg_labels], dim=0)

        loss = F.binary_cross_entropy(all_scores, all_labels)
        loss.backward()
        optimizer.step()

        if epoch % 5 == 0:
            print(f"Epoch {epoch}/{epochs}, Loss = {loss.item():.4f}")

    # -------------------------------------------
    # F) Evaluate
    model.eval()
    with torch.no_grad():
        emb = model(x, full_edge_index)
        test_pos_score = model.predict_link(emb, test_pos.to(device))
        test_neg_score = model.predict_link(emb, test_neg.to(device))

        # combine
        test_scores = torch.cat([test_pos_score, test_neg_score], dim=0).cpu().numpy()
        test_labels = torch.cat([
            torch.ones(test_pos_score.size(0)),
            torch.zeros(test_neg_score.size(0))
        ]).cpu().numpy()

        # AUC
        auc = roc_auc_score(test_labels, test_scores)

        # Precision & Recall at threshold 0.5
        binary_preds = (test_scores >= 0.5).astype(int)
        precision = precision_score(test_labels, binary_preds)
        recall = recall_score(test_labels, binary_preds)

        print(f"\nTest AUC: {auc:.4f}")
        print(f"Precision: {precision:.4f}, Recall: {recall:.4f}")

    return model, auc, (precision, recall)

In [33]:
# Load pickled NetworkX graph
print("Loading graph from niu_graph.pkl...")
G = load_and_unify_attributes("niu_graph.pkl")

Loading graph from niu_graph.pkl...


In [34]:
# Build PyG data, embedding abstracts & topic IDs
pyg_data = convert_graph_to_pyg(G)

In [35]:
# Train GNN for link prediction
print("Training GNN for co-author link prediction (with abstract & topic embeddings)...")
model, auc, (precision, recall) = train_test_link_prediction(
    pyg_data,
    epochs=20,
    lr=0.01,
    hidden_dim=32
)
if model:
    print("\nDone with training!")
    print(f"Final metrics: AUC={auc:.4f}, Precision={precision:.4f}, Recall={recall:.4f}")

Training GNN for co-author link prediction (with abstract & topic embeddings)...
Epoch 5/20, Loss = 0.7751
Epoch 10/20, Loss = 0.7070
Epoch 15/20, Loss = 0.6999
Epoch 20/20, Loss = 0.6978

Test AUC: 0.5666
Precision: 0.5000, Recall: 1.0000

Done with training!
Final metrics: AUC=0.5666, Precision=0.5000, Recall=1.0000
