In [1]:
try:
    import torch_geometric
except:
    !pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m12.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.6.1


In [18]:
import requests
import networkx as nx
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.loader import LinkNeighborLoader
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.notebook import tqdm
from torch_geometric.nn import HeteroConv, SAGEConv, Linear
from torch_geometric.loader import NeighborLoader
from sentence_transformers import SentenceTransformer
import copy

import pickle
import json

# Using Pickle
def save_data_pickle(papers_dict, topics_dict, filename="openalex_data.pkl"):
    """Write both dictionaries into a single pickle file.

    Args:
        papers_dict: Object stored under the ``'papers'`` key.
        topics_dict: Object stored under the ``'topics'`` key.
        filename: Destination path; opened in binary write mode, so an
            existing file at that path is overwritten.
    """
    payload = dict(papers=papers_dict, topics=topics_dict)
    with open(filename, mode='wb') as handle:
        pickle.dump(payload, handle)

def load_data_pickle(filename="openalex_data.pkl"):
    """Read back the two dictionaries stored in a pickle file.

    Args:
        filename: Path of the pickle file; it must contain a mapping with
            ``'papers'`` and ``'topics'`` keys.

    Returns:
        Tuple ``(papers, topics)`` taken from those two keys.
    """
    # Unpickling can execute arbitrary code: only open files you produced yourself.
    with open(filename, mode='rb') as handle:
        payload = pickle.load(handle)
    papers, topics = payload['papers'], payload['topics']
    return papers, topics

# Using JSON
def save_data_json(papers_dict, topics_dict, filename="openalex_data.json"):
    """Write both dictionaries into a single JSON file (4-space indented).

    Args:
        papers_dict: Object stored under the ``'papers'`` key.
        topics_dict: Object stored under the ``'topics'`` key.
        filename: Destination path; opened in text write mode, so an
            existing file at that path is overwritten.
    """
    payload = dict(papers=papers_dict, topics=topics_dict)
    with open(filename, mode='w') as handle:
        json.dump(payload, handle, indent=4)

def load_data_json(filename="openalex_data.json"):
    """Read back the two dictionaries stored in a JSON file.

    Args:
        filename: Path of the JSON file; its top-level object must have
            ``'papers'`` and ``'topics'`` keys.

    Returns:
        Tuple ``(papers, topics)`` taken from those two keys.
    """
    with open(filename, mode='r') as handle:
        payload = json.load(handle)
    papers, topics = payload['papers'], payload['topics']
    return papers, topics

In [19]:
def get_abstract(inverted_index):
    """Rebuild an abstract string from an inverted index.

    Args:
        inverted_index: Mapping ``word -> list of integer positions`` at which
            the word occurs, or ``None`` when no abstract is available.

    Returns:
        ``"Unavailable"`` when ``inverted_index`` is ``None``; otherwise the
        words joined by single spaces in position order. Positions that no
        word claims become empty strings (so gaps show up as double spaces),
        and an index that contains no positions at all yields ``""``.
    """
    if inverted_index is None:
        return "Unavailable"

    # Flatten to (position, word) pairs, keeping the mapping's iteration order so
    # that a later word still wins if two words claim the same position.
    placements = [
        (position, word)
        for word, positions in inverted_index.items()
        for position in positions
    ]
    if not placements:
        # Empty index (or only empty position lists): max() below would raise.
        return ""

    last_position = max(position for position, _ in placements)
    abstract_words = [""] * (last_position + 1)
    for position, word in placements:
        abstract_words[position] = word

    return " ".join(abstract_words)


def _to_api_url(openalex_id):
    """Rewrite an ``openalex.org`` work id into its ``api.openalex.org/works`` URL."""
    return openalex_id.replace("openalex.org", "api.openalex.org/works")


def _record_work(papers_dict, topics_dict, key, work_json, citation_count, cites):
    """Store one work under ``key`` and register its primary topic in the same step."""
    primary_topic = work_json.get('primary_topic')
    title = work_json.get("title")
    inverted_index = work_json.get("abstract_inverted_index")

    topic_id = primary_topic.get("id", "") if primary_topic is not None else ""

    # Register the topic before the paper so every stored paper always has a
    # matching entry in topics_dict (works without a topic share the "" entry).
    if topic_id == "":
        topics_dict[""] = ""
    else:
        topics_dict[topic_id] = primary_topic.get("display_name") or ""

    papers_dict[key] = {
        "id": key,
        "citation_count": citation_count,
        "topic": topic_id,
        "title": title if title is not None else "",
        "abstract": get_abstract(inverted_index) if inverted_index is not None else "",
        "cites": cites,
    }


def fetch_openalex_data(paper_limit: int = 300, degree_limit: int = 1):
    """Crawl OpenAlex works outward from the default ``/works`` listing.

    Starting from the works returned by ``https://api.openalex.org/works``, each
    round fetches every queued work, stores it, and queues its referenced,
    related and citing works for the next round. During the final round the
    referenced works of each paper are also fetched and stored as stubs
    (``citation_count`` 0, empty ``cites``) unless they are already stored.

    Args:
        paper_limit: NOTE(review): accepted but not consulted; a round stops
            taking new papers once more than 90000 are stored. Confirm whether
            this parameter was meant to replace that constant.
        degree_limit: Number of crawl rounds.

    Returns:
        ``(papers_dict, topics_dict)``. ``papers_dict`` maps a work's API URL to
        a dict with keys ``id``, ``citation_count``, ``topic``, ``title``,
        ``abstract`` and ``cites``. ``topics_dict`` maps topic id to display
        name, with ``""`` mapped to ``""`` for works lacking a primary topic.

    Per-paper failures are printed and skipped; a failure of the initial seed
    request propagates to the caller.
    """
    base_url = "https://api.openalex.org/works"
    response = requests.get(base_url)

    seed_works = response.json().get("results")
    ids = [_to_api_url(work["id"]) for work in seed_works]

    papers_dict = {}
    topics_dict = {}
    degree = 0

    while degree < degree_limit:
        degree += 1
        new_ids = []

        for pid in ids:
            if len(papers_dict) > 90000:
                break
            try:
                response = requests.get(pid).json()
                new_ids.append(_to_api_url(response["id"]))
                referenced = [_to_api_url(ref_work) for ref_work in response["referenced_works"]]
                new_ids.extend(referenced)
                new_ids.extend(_to_api_url(ref_work) for ref_work in response["related_works"])
                cited_by_url = response["cited_by_api_url"]
                cited_by_papers = requests.get(cited_by_url).json().get("results")
                new_ids.extend(_to_api_url(work["id"]) for work in cited_by_papers)

                # NOTE(review): this counts only the citing works returned by this
                # single request, which may be fewer than the work's total citations.
                _record_work(
                    papers_dict, topics_dict, pid, response,
                    citation_count=len(cited_by_papers), cites=referenced,
                )

                if degree == degree_limit:
                    for work in referenced:
                        # Never replace an already stored paper (which may carry its
                        # own citation list) with an empty stub.
                        if work in papers_dict:
                            continue
                        res = requests.get(work).json()
                        if res is None:
                            continue
                        _record_work(
                            papers_dict, topics_dict, work, res,
                            citation_count=0, cites=[],
                        )

            except Exception as e:
                print(f"Error fetching data for {pid}: {e}")
                continue

        ids = new_ids
    return papers_dict, topics_dict

In [20]:
# Load the previously saved paper/topic dictionaries from the default pickle file
# ("openalex_data.pkl") instead of re-crawling the API.
# NOTE(review): unpickling can execute arbitrary code; only load a file you created.
papers_dict, topics_dict = load_data_pickle()

In [21]:
def create_edge_split(edge_index, num_nodes, val_ratio=0.1, test_ratio=0.1, message_ratio=0.5):
    """
    Split edges into training message passing edges, training supervision edges,
    validation edges, and test edges, and sample negative edges for the
    supervision, validation and test splits.

    NOTE: `create_edge_split` is defined again in a later cell of this notebook;
    when the cells run in order, that later definition replaces this one.

    Args:
        edge_index: Integer tensor of shape (2, num_edges); row 0 holds sources
            and row 1 holds destinations. An empty tensor is treated as (2, 0).
        num_nodes: Node indices for negative edges are drawn from [0, num_nodes).
        val_ratio: Fraction of all edges used for validation.
        test_ratio: Fraction of all edges used for testing.
        message_ratio: Fraction of the remaining edges used for message passing;
            the rest become training supervision edges.

    Returns:
        Dict with 'train_msg' (tensor) and 'train_sup', 'val', 'test'
        (each a (positive_edges, negative_edges) tuple of (2, k) tensors).

    Raises:
        ValueError: If the graph has fewer unused node pairs than the number of
            negative edges required (the sampler could never finish).
    """
    if edge_index.numel() == 0:
        # Normalise any empty input (e.g. a 1-D empty tensor) to a proper (2, 0) layout.
        edge_index = edge_index.new_empty((2, 0), dtype=torch.long)
    num_edges = edge_index.size(1)

    # Create random permutation of edges
    perm = torch.randperm(num_edges)
    edge_index = edge_index[:, perm]

    # Calculate split sizes
    test_size = int(num_edges * test_ratio)
    val_size = int(num_edges * val_ratio)
    remaining_size = num_edges - test_size - val_size
    message_size = int(remaining_size * message_ratio)
    supervision_size = remaining_size - message_size

    # Split edges
    val_end = test_size + val_size
    msg_end = val_end + message_size
    test_edges = edge_index[:, :test_size]
    val_edges = edge_index[:, test_size:val_end]
    train_msg_edges = edge_index[:, val_end:msg_end]
    train_sup_edges = edge_index[:, msg_end:]

    # One shared set of "taken" pairs: all positive edges plus every negative drawn
    # so far. Sharing it keeps the negatives of the three splits disjoint.
    taken_pairs = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    occupied = sum(
        1 for src, dst in taken_pairs
        if src != dst and 0 <= src < num_nodes and 0 <= dst < num_nodes
    )
    available = num_nodes * (num_nodes - 1) - occupied
    required = test_size + val_size + supervision_size
    if required > available:
        raise ValueError(
            f"Cannot sample {required} negative edges: only {available} "
            f"unused node pairs exist for num_nodes={num_nodes}"
        )

    def sample_negative_edges(num_samples):
        """Rejection-sample `num_samples` new non-edges without self-loops."""
        sampled = []
        while len(sampled) < num_samples:
            src, dst = torch.randint(0, num_nodes, (2,)).tolist()
            if src != dst and (src, dst) not in taken_pairs:
                sampled.append((src, dst))
                taken_pairs.add((src, dst))

        if not sampled:
            # Keep the (2, k) integer layout even when nothing was requested.
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(sampled, dtype=torch.long).t()

    # Sample negative edges for each split
    test_neg = sample_negative_edges(test_size)
    val_neg = sample_negative_edges(val_size)
    train_sup_neg = sample_negative_edges(supervision_size)

    return {
        'train_msg': train_msg_edges,
        'train_sup': (train_sup_edges, train_sup_neg),
        'val': (val_edges, val_neg),
        'test': (test_edges, test_neg)
    }

In [24]:
import torch
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, to_hetero
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import random
import numpy as np

def create_edge_split(edge_index, num_nodes, val_ratio=0.1, test_ratio=0.1, message_ratio=0.5):
    """
    Split edges into training message passing edges, training supervision edges,
    validation edges, and test edges, and sample negative edges for the
    supervision, validation and test splits.

    NOTE: this re-defines the `create_edge_split` from an earlier cell of this
    notebook; when the cells run in order, this definition is the one in effect.

    Args:
        edge_index: Integer tensor of shape (2, num_edges); row 0 holds sources
            and row 1 holds destinations. An empty tensor is treated as (2, 0).
        num_nodes: Node indices for negative edges are drawn from [0, num_nodes).
        val_ratio: Fraction of all edges used for validation.
        test_ratio: Fraction of all edges used for testing.
        message_ratio: Fraction of the remaining edges used for message passing;
            the rest become training supervision edges.

    Returns:
        Dict with 'train_msg' (tensor) and 'train_sup', 'val', 'test'
        (each a (positive_edges, negative_edges) tuple of (2, k) tensors).

    Raises:
        ValueError: If the graph has fewer unused node pairs than the number of
            negative edges required (the sampler could never finish).
    """
    if edge_index.numel() == 0:
        # Normalise any empty input (e.g. a 1-D empty tensor) to a proper (2, 0) layout.
        edge_index = edge_index.new_empty((2, 0), dtype=torch.long)
    num_edges = edge_index.size(1)

    # Create random permutation of edges
    perm = torch.randperm(num_edges)
    edge_index = edge_index[:, perm]

    # Calculate split sizes
    test_size = int(num_edges * test_ratio)
    val_size = int(num_edges * val_ratio)
    remaining_size = num_edges - test_size - val_size
    message_size = int(remaining_size * message_ratio)
    supervision_size = remaining_size - message_size

    # Split edges
    val_end = test_size + val_size
    msg_end = val_end + message_size
    test_edges = edge_index[:, :test_size]
    val_edges = edge_index[:, test_size:val_end]
    train_msg_edges = edge_index[:, val_end:msg_end]
    train_sup_edges = edge_index[:, msg_end:]

    # One shared set of "taken" pairs: all positive edges plus every negative drawn
    # so far. Sharing it keeps the negatives of the three splits disjoint.
    taken_pairs = set(zip(edge_index[0].tolist(), edge_index[1].tolist()))

    occupied = sum(
        1 for src, dst in taken_pairs
        if src != dst and 0 <= src < num_nodes and 0 <= dst < num_nodes
    )
    available = num_nodes * (num_nodes - 1) - occupied
    required = test_size + val_size + supervision_size
    if required > available:
        raise ValueError(
            f"Cannot sample {required} negative edges: only {available} "
            f"unused node pairs exist for num_nodes={num_nodes}"
        )

    def sample_negative_edges(num_samples):
        """Rejection-sample `num_samples` new non-edges without self-loops."""
        sampled = []
        while len(sampled) < num_samples:
            src, dst = torch.randint(0, num_nodes, (2,)).tolist()
            if src != dst and (src, dst) not in taken_pairs:
                sampled.append((src, dst))
                taken_pairs.add((src, dst))

        if not sampled:
            # Keep the (2, k) integer layout even when nothing was requested.
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(sampled, dtype=torch.long).t()

    # Sample negative edges for each split
    test_neg = sample_negative_edges(test_size)
    val_neg = sample_negative_edges(val_size)
    train_sup_neg = sample_negative_edges(supervision_size)

    return {
        'train_msg': train_msg_edges,
        'train_sup': (train_sup_edges, train_sup_neg),
        'val': (val_edges, val_neg),
        'test': (test_edges, test_neg)
    }

def create_hetero_graph(papers_dict, topics_dict):
    """Build the paper/topic heterogeneous graph used for citation link prediction.

    Args:
        papers_dict: Mapping paper id -> dict with at least the keys 'title',
            'abstract', 'topic' (a topic id) and 'cites' (iterable of paper ids).
        topics_dict: Mapping topic id -> topic text to embed.

    Returns:
        A HeteroData object with
        - data['paper'].x / data['paper'].paper_ids: sentence embeddings of
          "title abstract" and the ids in matching row order;
        - data['topic'].x / data['topic'].topic_ids: sentence embeddings of the
          topic texts and the ids in matching row order;
        - data['paper', 'cites', 'paper'].edge_index_splits: the dict returned by
          create_edge_split, with .edge_index set to its 'train_msg' edges;
        - data['paper', 'has_topic', 'topic'].edge_index: one edge per paper.
    """
    # The graph container must be created here; relying on a global `data` fails
    # on a fresh kernel because the global is only assigned from this function's result.
    data = HeteroData()

    # Initialize sentence transformer
    encoder = SentenceTransformer('all-MiniLM-L6-v2')

    # Create paper nodes
    paper_ids = list(papers_dict.keys())
    id_to_idx = {pid: idx for idx, pid in enumerate(paper_ids)}

    # Paper features: embed the concatenated title and abstract, all papers in one call.
    paper_texts = [
        f"{papers_dict[pid]['title']} {papers_dict[pid]['abstract']}"
        for pid in paper_ids
    ]
    # .cpu() keeps the features on the same device as the index tensors built below.
    data['paper'].x = encoder.encode(paper_texts, convert_to_tensor=True).cpu()
    data['paper'].paper_ids = paper_ids

    # Create topic nodes. A topic referenced by a paper but absent from topics_dict
    # gets an empty text instead of causing a KeyError when the edges are built.
    topic_texts_by_id = dict(topics_dict)
    for paper_info in papers_dict.values():
        topic_texts_by_id.setdefault(paper_info['topic'], "")
    topic_ids = list(topic_texts_by_id.keys())
    topic_id_to_idx = {tid: idx for idx, tid in enumerate(topic_ids)}

    # Create topic features using sentence transformer
    topic_texts = [topic_texts_by_id[tid] for tid in topic_ids]
    data['topic'].x = encoder.encode(topic_texts, convert_to_tensor=True).cpu()
    data['topic'].topic_ids = topic_ids

    # Create citation edges; citations pointing outside papers_dict are dropped.
    cite_pairs = [
        [id_to_idx[pid], id_to_idx[cited_id]]
        for pid, paper_info in papers_dict.items()
        for cited_id in paper_info['cites']
        if cited_id in id_to_idx
    ]
    # reshape(len, 2) keeps a valid (2, 0) layout when there are no citation edges.
    edge_index_cites = (
        torch.tensor(cite_pairs, dtype=torch.long)
        .reshape(len(cite_pairs), 2)
        .t()
        .contiguous()
    )

    # Split citation edges
    edge_splits = create_edge_split(edge_index_cites, len(paper_ids))

    # Store splits in the data object; only the message-passing edges are exposed
    # as the graph's citation edge_index.
    data['paper', 'cites', 'paper'].edge_index_splits = edge_splits
    data['paper', 'cites', 'paper'].edge_index = edge_splits['train_msg']

    # Create paper-topic edges (exactly one per paper)
    topic_pairs = [
        [id_to_idx[pid], topic_id_to_idx[paper_info['topic']]]
        for pid, paper_info in papers_dict.items()
    ]
    data['paper', 'has_topic', 'topic'].edge_index = (
        torch.tensor(topic_pairs, dtype=torch.long)
        .reshape(len(topic_pairs), 2)
        .t()
        .contiguous()
    )

    return data

class DirectionalPredictor(torch.nn.Module):
    """Scores a directed (source paper -> target paper) pair from four embeddings.

    Source paper, target paper, source topic and target topic embeddings each go
    through their own linear layer; the four results are concatenated in that
    fixed order and passed to an MLP that outputs one logit per pair. Because the
    source and target sides use separate layers and fixed positions, swapping
    source and target generally changes the score.

    Args:
        hidden_channels: Size of every input embedding and of the hidden layers.
        dropout: Dropout probability inside the scoring MLP (default 0.5).
    """

    def __init__(self, hidden_channels, dropout=0.5):
        super().__init__()

        # Directional transform for source paper
        self.source_transform = torch.nn.Linear(hidden_channels, hidden_channels)

        # Directional transform for target paper
        self.target_transform = torch.nn.Linear(hidden_channels, hidden_channels)

        # Directional transform for source topic
        self.source_topic_transform = torch.nn.Linear(hidden_channels, hidden_channels)

        # Directional transform for target topic
        self.target_topic_transform = torch.nn.Linear(hidden_channels, hidden_channels)

        # Asymmetric scoring function (layer order is kept so state_dict keys stay stable)
        self.score = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * 4, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            # Second hidden layer (a plain linear layer, not an attention mechanism)
            torch.nn.Linear(hidden_channels, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    def forward(self, z_src, z_dst, z_topic_src, z_topic_dst):
        """Return one logit per pair.

        All four inputs must share their leading dimensions and have
        ``hidden_channels`` as last dimension. The output drops only that last
        dimension, so a batch of one pair yields shape ``(1,)``.
        """
        # Transform source and target differently
        h_src = self.source_transform(z_src)  # Source-specific transform
        h_dst = self.target_transform(z_dst)  # Target-specific transform

        # Transform topics differently based on source/target role
        h_topic_src = self.source_topic_transform(z_topic_src)
        h_topic_dst = self.target_topic_transform(z_topic_dst)

        # Concatenate with explicit directional order
        z = torch.cat([h_src, h_dst, h_topic_src, h_topic_dst], dim=-1)

        # Squeeze only the trailing logit dimension: a bare squeeze() would also
        # remove a batch dimension of size 1 and break shape-checked losses.
        return self.score(z).squeeze(-1)

class GNN(torch.nn.Module):
    """Two stacked SAGEConv layers with ReLU and dropout in between.

    Args:
        hidden_channels: Output width of both convolution layers.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # First layer is built with in_channels=-1, the second maps hidden -> hidden.
        self.conv1 = SAGEConv(-1, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Return the second layer's output (no activation applied after it)."""
        hidden = F.relu(self.conv1(x, edge_index))
        # Dropout (p=0.5) is active only while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)

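# NOTE: the commented-out block below is an earlier HeteroGNN that scored edges with a
# plain MLP predictor; it is superseded by the active HeteroGNN class defined right after it.
# class HeteroGNN(torch.nn.Module):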
#     def __init__(self, hidden_channels, data):
#         super().__init__()
#         self.gnn = GNN(hidden_channels)
#         self.model = to_hetero(self.gnn, data.metadata())

#         self.predictor = torch.nn.Sequential(
#             torch.nn.Linear(hidden_channels * 4, hidden_channels),
#             torch.nn.ReLU(),
#             torch.nn.Dropout(0.5),
#             torch.nn.Linear(hidden_channels, hidden_channels // 2),
#             torch.nn.ReLU(),
#             torch.nn.Linear(hidden_channels // 2, 1)
#         )

#     def forward(self, x_dict, edge_index_dict):
#         return self.model(x_dict, edge_index_dict)

#     def predict_edge(self, z_dict, edge_index, paper_to_topic_edges):
#         row, col = edge_index
#         z_paper_src = z_dict['paper'][row]
#         z_paper_dst = z_dict['paper'][col]

#         topic_idx_src = self.get_topic_idx(row, paper_to_topic_edges)
#         topic_idx_dst = self.get_topic_idx(col, paper_to_topic_edges)

#         z_topic_src = z_dict['topic'][topic_idx_src]
#         z_topic_dst = z_dict['topic'][topic_idx_dst]

#         z = torch.cat([z_paper_src, z_paper_dst, z_topic_src, z_topic_dst], dim=-1)
#         return self.predictor(z).squeeze()

#     staticmethod
#     def get_topic_idx(paper_idx, paper_to_topic_edges):
#         topic_indices = []
#         paper_idx_list = paper_idx.tolist()
#         if not isinstance(paper_idx_list, list):
#             paper_idx_list = [paper_idx_list]

#         for idx in paper_idx_list:
#             mask = (paper_to_topic_edges[0] == idx)
#             topic_idx = paper_to_topic_edges[1][mask][0]
#             topic_indices.append(topic_idx)

#         return torch.tensor(topic_indices, device=paper_idx.device)
class HeteroGNN(torch.nn.Module):
    """Heterogeneous encoder plus directional citation scorer.

    Wraps a GNN converted with to_hetero (using ``data.metadata()``) to produce
    per-node-type embeddings, and a DirectionalPredictor that scores directed
    paper pairs from the paper embeddings and the embeddings of their topics.

    Args:
        hidden_channels: Embedding width used by the encoder and the predictor.
        data: Graph object whose ``metadata()`` describes node and edge types.
    """

    def __init__(self, hidden_channels, data):
        super().__init__()
        self.gnn = GNN(hidden_channels)
        self.model = to_hetero(self.gnn, data.metadata())

        # Replace simple predictor with directional predictor
        self.predictor = DirectionalPredictor(hidden_channels)

    def forward(self, x_dict, edge_index_dict):
        """Run the heterogeneous encoder and return its embedding dict."""
        return self.model(x_dict, edge_index_dict)

    def predict_edge(self, z_dict, edge_index, paper_to_topic_edges):
        """Score directed paper pairs.

        Args:
            z_dict: Embeddings with 'paper' and 'topic' entries.
            edge_index: (2, k) tensor; row 0 = source papers, row 1 = target papers.
            paper_to_topic_edges: (2, m) tensor of (paper index, topic index) pairs.

        Returns:
            The predictor's output for the k pairs (logits).
        """
        row, col = edge_index

        # Get paper embeddings
        z_paper_src = z_dict['paper'][row]
        z_paper_dst = z_dict['paper'][col]

        # Get topic embeddings
        topic_idx_src = self.get_topic_idx(row, paper_to_topic_edges)
        topic_idx_dst = self.get_topic_idx(col, paper_to_topic_edges)
        z_topic_src = z_dict['topic'][topic_idx_src]
        z_topic_dst = z_dict['topic'][topic_idx_dst]

        # Use directional predictor
        return self.predictor(z_paper_src, z_paper_dst, z_topic_src, z_topic_dst)

    @staticmethod
    def get_topic_idx(paper_idx, paper_to_topic_edges):
        """Look up the topic index of each paper index.

        Args:
            paper_idx: 0-d or 1-D integer tensor of paper indices.
            paper_to_topic_edges: (2, m) tensor; row 0 = paper indices,
                row 1 = topic indices.

        Returns:
            1-D long tensor on ``paper_idx.device`` holding, for every queried
            paper, the topic of its first edge in ``paper_to_topic_edges``.

        Raises:
            IndexError: If a queried paper has no edge in ``paper_to_topic_edges``.
        """
        device = paper_idx.device
        query = paper_idx.reshape(-1).to(torch.long).contiguous()
        if query.numel() == 0:
            return torch.empty(0, dtype=torch.long, device=device)

        edge_papers = paper_to_topic_edges[0].to(device=device, dtype=torch.long)
        edge_topics = paper_to_topic_edges[1].to(device=device, dtype=torch.long)
        if edge_papers.numel() == 0:
            raise IndexError("paper_to_topic_edges is empty; no topic can be looked up")

        # A stable sort keeps edges of the same paper in their original order, so
        # the left-most match found by searchsorted is that paper's first edge.
        sorted_papers, order = torch.sort(edge_papers, stable=True)
        position = torch.searchsorted(sorted_papers, query).clamp(max=sorted_papers.numel() - 1)

        matched = sorted_papers[position] == query
        if not bool(matched.all()):
            missing = query[~matched].tolist()
            raise IndexError(f"No topic edge found for paper indices {missing[:5]}")

        return edge_topics[order[position]]

def evaluate_edges(model, z_dict, edge_index_pos, edge_index_neg, paper_to_topic_edges):
    """Score positive and negative edges together and compute classification metrics.

    Positive edges are labelled 1 and negative edges 0; an edge is predicted
    positive when the model's output for it is greater than 0.

    Returns:
        Tuple (accuracy, precision, recall, f1).
    """
    candidates = torch.cat((edge_index_pos, edge_index_neg), dim=1)
    positives = torch.ones(edge_index_pos.size(1))
    negatives = torch.zeros(edge_index_neg.size(1))
    targets = torch.cat((positives, negatives))

    logits = model.predict_edge(z_dict, candidates, paper_to_topic_edges)
    decisions = (logits > 0).float()

    # Same metric order as the returned tuple: accuracy, precision, recall, f1.
    metric_functions = (accuracy_score, precision_score, recall_score, f1_score)
    return tuple(metric(targets, decisions) for metric in metric_functions)

def suggest_new_citations(model, data, k=5, threshold=0.8):
    """
    Suggest up to k new citation pairs, checking predictions in both directions.

    Every unordered paper pair is examined once. A direction that already exists
    as a positive edge in any stored split is never recommended; a pair is kept
    when at least one not-yet-existing direction scores above `threshold`.

    Args:
        model: Trained HeteroGNN model
        data: HeteroData object containing paper information
        k: Number of suggestions to return (default=5)
        threshold: Minimum sigmoid score a new direction needs (default=0.8)

    Returns:
        List of at most k tuples (paper_a_idx, paper_b_idx, pred_AB, pred_BA),
        sorted by the best recommendable score in descending order. The same
        information is printed.
    """
    model.eval()
    paper_to_topic_edges = data['paper', 'has_topic', 'topic'].edge_index

    # Get existing edges to avoid suggesting them
    existing_edges = set()
    edge_splits = data['paper', 'cites', 'paper'].edge_index_splits
    for split_name, edges in edge_splits.items():
        if isinstance(edges, tuple):  # Handle (pos_edge, neg_edge) tuples
            edges = edges[0]  # Take only positive edges
        if edges.numel() == 0:
            continue
        existing_edges.update(zip(edges[0].tolist(), edges[1].tolist()))

    # Get number of papers
    num_papers = data['paper'].x.size(0)

    candidates = []
    with torch.no_grad():
        # Get node embeddings
        node_embeddings = model(data.x_dict, data.edge_index_dict)

        for i in range(num_papers - 1):
            # Score paper i against all later papers in one batch per direction;
            # visiting only j > i means each unordered pair is handled once.
            others = torch.arange(i + 1, num_papers, device=paper_to_topic_edges.device)
            anchor = torch.full_like(others, i)

            scores_AB = torch.sigmoid(model.predict_edge(
                node_embeddings, torch.stack([anchor, others]), paper_to_topic_edges
            )).reshape(-1).tolist()
            scores_BA = torch.sigmoid(model.predict_edge(
                node_embeddings, torch.stack([others, anchor]), paper_to_topic_edges
            )).reshape(-1).tolist()

            for j, pred_AB, pred_BA in zip(range(i + 1, num_papers), scores_AB, scores_BA):
                forward_is_new = (i, j) not in existing_edges
                reverse_is_new = (j, i) not in existing_edges
                if not (forward_is_new or reverse_is_new):
                    continue

                # Only directions that do not exist yet can qualify or be recommended.
                best_score = max(
                    pred_AB if forward_is_new else float("-inf"),
                    pred_BA if reverse_is_new else float("-inf"),
                )
                if best_score > threshold:
                    recommend_forward = forward_is_new and (
                        not reverse_is_new or pred_AB > pred_BA
                    )
                    candidates.append((best_score, i, j, pred_AB, pred_BA, recommend_forward))

    # Sort by best recommendable prediction score
    candidates.sort(key=lambda item: item[0], reverse=True)
    top_candidates = candidates[:k]

    # Print top k suggestions
    print(f"\nTop {k} Citation Suggestions:")
    print("-" * 80)

    for _, src_idx, dst_idx, pred_AB, pred_BA, recommend_forward in top_candidates:
        src_id = data['paper'].paper_ids[src_idx]
        dst_id = data['paper'].paper_ids[dst_idx]

        print(f"Paper {src_id} → Paper {dst_id}:")
        print(f"Forward prediction (A→B): {pred_AB:.4f}")
        print(f"Reverse prediction (B→A): {pred_BA:.4f}")

        # Show which (not yet existing) direction is recommended
        if recommend_forward:
            print(f"Suggestion: Paper {src_id} should cite Paper {dst_id}")
        else:
            print(f"Suggestion: Paper {dst_id} should cite Paper {src_id}")
        print("-" * 80)

    return [
        (src_idx, dst_idx, pred_AB, pred_BA)
        for _, src_idx, dst_idx, pred_AB, pred_BA, _ in top_candidates
    ]

def train_and_evaluate(data, hidden_channels=64, lr=0.01, epochs=100):
    """Train a HeteroGNN link predictor, report test metrics and print suggestions.

    Each epoch computes node embeddings from the graph's edges, scores the
    'train_sup' positive and negative edges with BCE-with-logits loss, takes an
    Adam step, then evaluates on the 'val' split in eval mode. The weights with
    the highest validation F1 are restored before the 'test' evaluation and the
    call to suggest_new_citations.

    Args:
        data: Graph with 'paper'/'topic' nodes, a ('paper', 'has_topic', 'topic')
            edge_index and ('paper', 'cites', 'paper').edge_index_splits holding
            'train_sup', 'val' and 'test' (positive, negative) edge tuples.
        hidden_channels: Embedding width of the model (default 64).
        lr: Adam learning rate (default 0.01).
        epochs: Number of training epochs (default 100).
    """
    model = HeteroGNN(hidden_channels=hidden_channels, data=data)

    # Dry run so layers created with in_channels=-1 get concrete parameter shapes
    # before the optimizer is built and before the first checkpoint is copied.
    with torch.no_grad():
        model(data.x_dict, data.edge_index_dict)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.BCEWithLogitsLoss()

    edge_splits = data['paper', 'cites', 'paper'].edge_index_splits
    paper_to_topic_edges = data['paper', 'has_topic', 'topic'].edge_index

    # Supervision edges and labels do not change between epochs.
    train_pos, train_neg = edge_splits['train_sup']
    train_edge_index = torch.cat([train_pos, train_neg], dim=1)
    train_labels = torch.cat([torch.ones(train_pos.size(1)),
                            torch.zeros(train_neg.size(1))])
    val_pos, val_neg = edge_splits['val']

    # state_dict() returns references to the live parameters, so a checkpoint must
    # be deep-copied or it silently follows later optimizer steps. Starting from a
    # copy of the initial weights also guarantees there is always a state to restore.
    best_val_f1 = -1.0
    best_model = copy.deepcopy(model.state_dict())

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Forward pass using message passing edges
        node_embeddings = model(data.x_dict, data.edge_index_dict)

        # Get predictions for supervision edges
        pred = model.predict_edge(node_embeddings, train_edge_index, paper_to_topic_edges)
        loss = criterion(pred.reshape(-1), train_labels)

        loss.backward()
        optimizer.step()

        # Validation: recompute embeddings in eval mode with the updated weights,
        # instead of reusing the train-mode (dropout-affected) embeddings above.
        model.eval()
        with torch.no_grad():
            val_embeddings = model(data.x_dict, data.edge_index_dict)
            val_acc, val_prec, val_rec, val_f1 = evaluate_edges(
                model, val_embeddings, val_pos, val_neg, paper_to_topic_edges)

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_model = copy.deepcopy(model.state_dict())

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1:03d}, Loss: {loss:.4f}, Val F1: {val_f1:.4f}')

    # Load best model and evaluate on test set
    model.load_state_dict(best_model)
    model.eval()
    with torch.no_grad():
        final_embeddings = model(data.x_dict, data.edge_index_dict)
        test_pos, test_neg = edge_splits['test']
        test_acc, test_prec, test_rec, test_f1 = evaluate_edges(
            model, final_embeddings, test_pos, test_neg, paper_to_topic_edges)

        print("\nTest Metrics:")
        print(f"Accuracy: {test_acc:.4f}")
        print(f"Precision: {test_prec:.4f}")
        print(f"Recall: {test_rec:.4f}")
        print(f"F1 Score: {test_f1:.4f}")
    suggest_new_citations(model, data)



In [None]:
# Build the heterogeneous paper/topic graph from the loaded dictionaries, then train
# the link-prediction model; training prints validation/test metrics and citation suggestions.
data = create_hetero_graph(papers_dict, topics_dict)
train_and_evaluate(data)