In [1]:
try:
    import torch_geometric
except:
    !pip install torch-geometric

In [2]:
a={'a':1}
len(a)

1

In [3]:
import requests
import networkx as nx
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.loader import LinkNeighborLoader
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.notebook import tqdm
from torch_geometric.nn import HeteroConv, SAGEConv, Linear
from torch_geometric.loader import NeighborLoader
import copy

In [4]:
def get_abstract(inverted_index):
    if inverted_index is not None:
        abstract_words = []
        max_index = max(max(positions) for positions in inverted_index.values())
        abstract_words = [None] * (max_index + 1)
        for word, positions in inverted_index.items():
            for position in positions:
                abstract_words[position] = word

        abstract = " ".join(
            [word if word is not None else "" for word in abstract_words]
        )
        return abstract
    else:
        return "Unavailable"


def fetch_openalex_data(paper_limit: int = 300, degree_limit: int = 1):
    base_url = "https://api.openalex.org/works"
    response = requests.get(base_url)

    seed_works = response.json().get("results")
    ids = [
        work["id"].replace("openalex.org", "api.openalex.org/works")
        for work in seed_works
    ]

    # Remove after testing
    # ids = [ids[0]]

    papers_dict = {}
    topics_dict = {}
    degree = 0

    while degree < degree_limit:
        degree += 1
        new_ids = []

        for pid in ids:
            if(len(papers_dict)) > 90000:
                break
            try:
                response = requests.get(pid).json()
                new_ids.append(response["id"].replace("openalex.org", "api.openalex.org/works"))
                new_ids.extend([ref_work.replace("openalex.org", "api.openalex.org/works") for ref_work in response["referenced_works"]])
                new_ids.extend([ref_work.replace("openalex.org", "api.openalex.org/works") for ref_work in response["related_works"]])
                cited_by_url = response["cited_by_api_url"]
                cited_by_papers = requests.get(cited_by_url).json().get("results")
                new_ids.extend([work["id"].replace("openalex.org", "api.openalex.org/works") for work in cited_by_papers])

                dic_topic = response.get('primary_topic')
                dic_title = response.get("title")
                dic_abstract = response.get("abstract_inverted_index")
                topic = ""
                title = ""
                abstract = ""
                if dic_topic is not None:
                    topic = dic_topic.get("id", "")
                if dic_title is not None:
                    title = dic_title
                if dic_abstract is not None:
                    abstract = get_abstract(dic_abstract)

                papers_dict[pid] = {
                    "id": pid,
                    "citation_count": len(cited_by_papers),
                    "topic": topic,
                    "title": title,
                    "abstract": abstract,
                    "cites": [ref_work.replace("openalex.org", "api.openalex.org/works") for ref_work in response["referenced_works"]]
                }
                if degree == degree_limit:
                    for work in [ref_work.replace("openalex.org", "api.openalex.org/works") for ref_work in response["referenced_works"]]:
                        res = requests.get(work).json()
                        if res is not None:
                            dic_topic = res.get('primary_topic')
                            dic_title = res.get("title")
                            dic_abstract = res.get("abstract_inverted_index")
                            topic = ""
                            title = ""
                            abstract = ""
                            if dic_topic is not None:
                                topic = dic_topic.get("id", "")
                            if dic_title is not None:
                                title = dic_title
                            if dic_abstract is not None:
                                abstract = get_abstract(dic_abstract)
                            papers_dict[work] = {
                                "id": work,
                                "citation_count": 0,
                                "topic": topic,
                                "title": title,
                                "abstract": abstract,
                                "cites": []
                            }
                            if dic_topic is not None:
                                topics_dict[res["primary_topic"]["id"]] = res["primary_topic"]["display_name"]
                            else:
                                topics_dict[""] = ""
                if topic == "":
                    topics_dict[""] = ""
                else:
                    topics_dict[response["primary_topic"]["id"]] = response["primary_topic"]["display_name"]

            except Exception as e:
                print(f"Error fetching data for {pid}: {e}")
                continue

        ids = new_ids
    return papers_dict, topics_dict

In [5]:
# papers_dict, topics_dict = fetch_openalex_data(degree_limit=1)

In [6]:
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, SAGEConv
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import numpy as np

def create_heterogeneous_graph(papers_dict, topics_dict):
    # Collect unique topics
    unique_topics = set(paper_info['topic'] for paper_info in papers_dict.values() if paper_info['topic'])

    # Create mappings
    paper_ids = list(papers_dict.keys())
    topic_ids = list(unique_topics)

    # Create index mappings
    paper_id_to_idx = {pid: idx for idx, pid in enumerate(paper_ids)}
    topic_id_to_idx = {tid: idx + len(paper_ids) for idx, tid in enumerate(topic_ids)}

    # Use SentenceTransformer for text embeddings
    text_embedder = SentenceTransformer('all-MiniLM-L6-v2')

    # Prepare node features
    total_nodes = len(paper_ids) + len(topic_ids)
    feature_dim = 384  # SentenceTransformer embedding dimension
    x = torch.zeros(total_nodes, feature_dim)

    # Populate paper features
    for paper_id in paper_ids:
        paper_info = papers_dict[paper_id]
        text = f"{paper_info['title']} {paper_info['abstract']}"
        embedding = text_embedder.encode(text, convert_to_tensor=True)
        x[paper_id_to_idx[paper_id]] = embedding

    # Populate topic features
    for topic_id in topic_ids:
        topic_name = topics_dict.get(topic_id, topic_id)
        embedding = text_embedder.encode(topic_name, convert_to_tensor=True)
        x[topic_id_to_idx[topic_id]] = embedding

    # Prepare edge indices
    citation_src, citation_dst = [], []
    topic_paper_src, topic_paper_dst = [], []

    # Create citation edges
    for paper_id in paper_ids:
        paper_info = papers_dict[paper_id]
        src_idx = paper_id_to_idx[paper_id]
        for cited_paper in paper_info['cites']:
            if cited_paper in paper_id_to_idx:
                dst_idx = paper_id_to_idx[cited_paper]
                citation_src.append(src_idx)
                citation_dst.append(dst_idx)

    # Create topic-paper edges
    for paper_id in paper_ids:
        paper_info = papers_dict[paper_id]
        if paper_info['topic'] in topic_id_to_idx:
            paper_idx = paper_id_to_idx[paper_id]
            topic_idx = topic_id_to_idx[paper_info['topic']]

            # Bidirectional edges
            topic_paper_src.append(paper_idx)
            topic_paper_dst.append(topic_idx)
            topic_paper_src.append(topic_idx)
            topic_paper_dst.append(paper_idx)

    # Create edge indices
    citation_edge_index = torch.tensor([citation_src, citation_dst], dtype=torch.long)
    topic_paper_edge_index = torch.tensor([topic_paper_src, topic_paper_dst], dtype=torch.long)

    # Create HeteroData
    data = HeteroData()
    data['paper', 'cites', 'paper'].edge_index = citation_edge_index
    data['paper', 'has_topic', 'topic'].edge_index = topic_paper_edge_index
    data['topic', 'has_paper', 'paper'].edge_index = topic_paper_edge_index.flip(0)
    data.x = x

    return data, paper_id_to_idx, topic_id_to_idx

def print_graph_statistics(graph_data, papers_dict, topics_dict, edge_splits):
    """Print detailed statistics about the graph structure."""

    print("\n=== Graph Statistics ===")

    # Node statistics
    num_papers = len(papers_dict)
    num_topics = len(topics_dict)
    print("\nNode Counts:")
    print(f"Papers: {num_papers}")
    print(f"Topics: {num_topics}")
    print(f"Total Nodes: {num_papers + num_topics}")

    # Edge statistics
    paper_edges = graph_data['paper', 'cites', 'paper'].edge_index
    topic_edges = graph_data['paper', 'has_topic', 'topic'].edge_index

    print("\nTotal Edge Counts:")
    print(f"Citations (paper → paper): {paper_edges.size(1)}")
    print(f"Topic Associations (paper ↔ topic): {topic_edges.size(1)}")
    print(f"Total Edges: {paper_edges.size(1) + topic_edges.size(1)}")

    # Edge split statistics
    (paper_message, paper_supervision, paper_val, paper_test,
     topic_message, topic_supervision, topic_val, topic_test) = edge_splits

    print("\nCitation Edge Splits:")
    print(f"Message Edges: {paper_message.size(1)}")
    print(f"Supervision Edges: {paper_supervision.size(1)}")
    print(f"Validation Edges: {paper_val.size(1)}")
    print(f"Test Edges: {paper_test.size(1)}")
    print(f"Total: {paper_message.size(1) + paper_supervision.size(1) + paper_val.size(1) + paper_test.size(1)}")

    print("\nTopic Association Edge Splits:")
    print(f"Message Edges: {topic_message.size(1)}")
    print(f"Supervision Edges: {topic_supervision.size(1)}")
    print(f"Validation Edges: {topic_val.size(1)}")
    print(f"Test Edges: {topic_test.size(1)}")
    print(f"Total: {topic_message.size(1) + topic_supervision.size(1) + topic_val.size(1) + topic_test.size(1)}")

    # Additional statistics
    print("\nGraph Density:")
    possible_citations = num_papers * (num_papers - 1)  # Directed edges
    citation_density = paper_edges.size(1) / possible_citations
    print(f"Citation Density: {citation_density:.6f}")

    possible_topic_associations = num_papers * num_topics
    topic_density = topic_edges.size(1) / possible_topic_associations
    print(f"Topic Association Density: {topic_density:.6f}")

    # Topic distribution
    topic_counts = {}
    for paper_info in papers_dict.values():
        topic = paper_info['topic']
        topic_counts[topic] = topic_counts.get(topic, 0) + 1

    print("\nTopic Distribution:")
    print("Top 5 topics by paper count:")
    sorted_topics = sorted(topic_counts.items(), key=lambda x: x[1], reverse=True)
    for topic_id, count in sorted_topics[:5]:
        topic_name = topics_dict.get(topic_id, 'Unknown')
        print(f"Topic: {topic_name}, Papers: {count}")

In [7]:
class HeteroGCNMultiPredictor(torch.nn.Module):
    def __init__(self, feature_dim, hidden_channels=64):
        super().__init__()

        # Graph convolution layers
        self.conv1 = GCNConv(feature_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

        # Link predictor for paper citations
        self.citation_predictor = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * 2, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

        # Link predictor for topic associations
        self.topic_predictor = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * 2, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    def encode(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

    def decode_citations(self, z, edge_label_index):
        row, col = edge_label_index
        z_src = z[row]
        z_dst = z[col]
        z_combined = torch.cat([z_src, z_dst], dim=-1)
        return self.citation_predictor(z_combined).squeeze()

    def decode_topics(self, z, edge_label_index):
        row, col = edge_label_index
        z_src = z[row]
        z_dst = z[col]
        z_combined = torch.cat([z_src, z_dst], dim=-1)
        return self.topic_predictor(z_combined).squeeze()

def split_edges(edge_index, val_ratio=0.1, test_ratio=0.1, message_ratio=0.5):
    """
    Split edges into training message edges, training supervision edges, validation edges, and test edges.
    message_ratio determines split between message and supervision edges within training set.
    """
    num_edges = edge_index.size(1)
    num_val = int(num_edges * val_ratio)
    num_test = int(num_edges * test_ratio)
    num_train = num_edges - (num_val + num_test)
    num_message = int(num_train * message_ratio)

    # Randomly shuffle edges
    perm = torch.randperm(num_edges)

    # Split indices
    message_idx = perm[:num_message]
    supervision_idx = perm[num_message:num_train]
    val_idx = perm[num_train:num_train+num_val]
    test_idx = perm[num_train+num_val:]

    # Create edge sets
    message_edges = edge_index[:, message_idx]
    supervision_edges = edge_index[:, supervision_idx]
    val_edges = edge_index[:, val_idx]
    test_edges = edge_index[:, test_idx]

    return message_edges, supervision_edges, val_edges, test_edges

def split_edges_by_type(paper_edges, topic_edges, val_ratio=0.1, test_ratio=0.1, message_ratio=0.5):
    """Split both paper citation edges and topic association edges."""

    # Split citation edges
    paper_splits = split_edges(paper_edges, val_ratio, test_ratio, message_ratio)
    paper_message, paper_supervision, paper_val, paper_test = paper_splits

    # Split topic edges
    topic_splits = split_edges(topic_edges, val_ratio, test_ratio, message_ratio)
    topic_message, topic_supervision, topic_val, topic_test = topic_splits

    return (paper_message, paper_supervision, paper_val, paper_test,
            topic_message, topic_supervision, topic_val, topic_test)

def calculate_metrics(pred, target):
    """Calculate accuracy, precision, recall, and F1 score."""
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    pred_binary = (pred > 0.5).cpu().numpy()
    target = target.cpu().numpy()

    accuracy = accuracy_score(target, pred_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(target, pred_binary, average='binary')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def train_and_evaluate_multi(model, data, edge_splits, paper_id_to_idx, topic_id_to_idx, epochs=100):
    (paper_message, paper_supervision, paper_val, paper_test,
     topic_message, topic_supervision, topic_val, topic_test) = edge_splits

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()

    def evaluate(z, pos_edge_set, edge_type="citation", name=""):
        model.eval()
        with torch.no_grad():
            # Generate negative edges
            neg_edge_set = torch.randint(0,
                len(paper_id_to_idx) if edge_type == "citation" else len(topic_id_to_idx),
                (2, pos_edge_set.size(1)))

            # Get predictions
            if edge_type == "citation":
                pos_pred = model.decode_citations(z, pos_edge_set).sigmoid()
                neg_pred = model.decode_citations(z, neg_edge_set).sigmoid()
            else:
                pos_pred = model.decode_topics(z, pos_edge_set).sigmoid()
                neg_pred = model.decode_topics(z, neg_edge_set).sigmoid()

            # Combine predictions and create labels
            pred = torch.cat([pos_pred, neg_pred])
            target = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])

            metrics = calculate_metrics(pred, target)

            print(f"\n{name} {edge_type.capitalize()} Metrics:")
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"Precision: {metrics['precision']:.4f}")
            print(f"Recall: {metrics['recall']:.4f}")
            print(f"F1: {metrics['f1']:.4f}")

            return metrics

    best_val_f1 = {'citation': 0, 'topic': 0}
    best_model = None

    # Combine message edges for initial embedding
    combined_message = torch.cat([paper_message, topic_message], dim=1)

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Get node embeddings using message edges
        z = model.encode(data.x, combined_message)

        # Citation prediction loss
        neg_paper_edges = torch.randint(0, len(paper_id_to_idx), (2, paper_supervision.size(1)))
        paper_pos_pred = model.decode_citations(z, paper_supervision)
        paper_neg_pred = model.decode_citations(z, neg_paper_edges)
        paper_loss = criterion(paper_pos_pred, torch.ones_like(paper_pos_pred)) + \
                    criterion(paper_neg_pred, torch.zeros_like(paper_neg_pred))

        # Topic prediction loss
        neg_topic_edges = torch.randint(0, len(topic_id_to_idx), (2, topic_supervision.size(1)))
        topic_pos_pred = model.decode_topics(z, topic_supervision)
        topic_neg_pred = model.decode_topics(z, neg_topic_edges)
        topic_loss = criterion(topic_pos_pred, torch.ones_like(topic_pos_pred)) + \
                    criterion(topic_neg_pred, torch.zeros_like(topic_neg_pred))

        # Combined loss
        loss = paper_loss + topic_loss

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"\nEpoch {epoch+1:03d}:")

            # Evaluate citations
            train_citation_metrics = evaluate(z, paper_supervision, "citation", "Training")
            val_citation_metrics = evaluate(z, paper_val, "citation", "Validation")

            # Evaluate topic associations
            train_topic_metrics = evaluate(z, topic_supervision, "topic", "Training")
            val_topic_metrics = evaluate(z, topic_val, "topic", "Validation")

            # Save best model based on average F1 score
            current_val_f1 = {
                'citation': val_citation_metrics['f1'],
                'topic': val_topic_metrics['f1']
            }

            if (current_val_f1['citation'] + current_val_f1['topic'])/2 > \
               (best_val_f1['citation'] + best_val_f1['topic'])/2:
                best_val_f1 = current_val_f1
                best_model = copy.deepcopy(model)

    # Final evaluation on test set
    print("\nFinal Test Set Evaluation:")
    z = best_model.encode(data.x, combined_message)
    test_citation_metrics = evaluate(z, paper_test, "citation", "Test")
    test_topic_metrics = evaluate(z, topic_test, "topic", "Test")

    return best_model, test_citation_metrics, test_topic_metrics

In [8]:
# Load data
papers_dict, topics_dict = fetch_openalex_data(degree_limit=2)

# Create graph
graph_data, paper_id_to_idx, topic_id_to_idx = create_heterogeneous_graph(papers_dict, topics_dict)

# Get paper citation edges and topic association edges
paper_edges = graph_data['paper', 'cites', 'paper'].edge_index
topic_edges = graph_data['paper', 'has_topic', 'topic'].edge_index

# Split edges
edge_splits = split_edges_by_type(paper_edges, topic_edges)

print_graph_statistics(graph_data, papers_dict, topics_dict, edge_splits)

# Initialize and train model
model = HeteroGCNMultiPredictor(feature_dim=graph_data.x.size(1))
best_model, test_citation_metrics, test_topic_metrics = train_and_evaluate_multi(
    model, graph_data, edge_splits, paper_id_to_idx, topic_id_to_idx
)

# Print final results
print("\nFinal Test Results:")
print("\nCitation Prediction:")
print(f"Accuracy: {test_citation_metrics['accuracy']:.4f}")
print(f"Precision: {test_citation_metrics['precision']:.4f}")
print(f"Recall: {test_citation_metrics['recall']:.4f}")
print(f"F1: {test_citation_metrics['f1']:.4f}")

print("\nTopic Association Prediction:")
print(f"Accuracy: {test_topic_metrics['accuracy']:.4f}")
print(f"Precision: {test_topic_metrics['precision']:.4f}")
print(f"Recall: {test_topic_metrics['recall']:.4f}")
print(f"F1: {test_topic_metrics['f1']:.4f}")

Error fetching data for https://api.openalex.org/works/W1576013682: 'NoneType' object is not subscriptable
Error fetching data for https://api.openalex.org/works/W59369663: 'NoneType' object is not subscriptable
Error fetching data for https://api.openalex.org/works/W2778153218: Expecting value: line 1 column 1 (char 0)
Error fetching data for https://api.openalex.org/works/W2527526854: Expecting value: line 1 column 1 (char 0)
Error fetching data for https://api.openalex.org/works/W170893625: 'NoneType' object is not subscriptable
Error fetching data for https://api.openalex.org/works/W2159707944: Expecting value: line 1 column 1 (char 0)
Error fetching data for https://api.openalex.org/works/W2332681686: Expecting value: line 1 column 1 (char 0)
Error fetching data for https://api.openalex.org/works/W2029411830: Expecting value: line 1 column 1 (char 0)
Error fetching data for https://api.openalex.org/works/W1976806156: Expecting value: line 1 column 1 (char 0)
Error fetching data fo

  torch.utils._pytree._register_pytree_node(



=== Graph Statistics ===

Node Counts:
Papers: 49749
Topics: 2594
Total Nodes: 52343

Total Edge Counts:
Citations (paper → paper): 39346
Topic Associations (paper ↔ topic): 98246
Total Edges: 137592

Citation Edge Splits:
Message Edges: 15739
Supervision Edges: 15739
Validation Edges: 3934
Test Edges: 3934
Total: 39346

Topic Association Edge Splits:
Message Edges: 39299
Supervision Edges: 39299
Validation Edges: 9824
Test Edges: 9824
Total: 98246

Graph Density:
Citation Density: 0.000016
Topic Association Density: 0.000761

Topic Distribution:
Top 5 topics by paper count:
Topic: Advancements in Density Functional Theory, Papers: 2288
Topic: Diagnosis and Management of Alzheimer's Disease, Papers: 941
Topic: Methods for Evidence Synthesis in Research, Papers: 831
Topic: Molecular Mechanisms of Synaptic Plasticity and Neurological Disorders, Papers: 806
Topic: , Papers: 626

Epoch 010:

Training Citation Metrics:
Accuracy: 0.7592
Precision: 0.7611
Recall: 0.7555
F1: 0.7583

Validatio

In [None]:
def discover_cross_topic_links(model, data, papers_dict, paper_id_to_idx, threshold=0.8):
    """
    Discovers potential links between papers of different topics.

    Args:
        model: Trained GNN model
        data: Graph data object
        papers_dict: Dictionary containing paper information
        paper_id_to_idx: Mapping from paper IDs to indices
        threshold: Probability threshold for considering a link (default: 0.8)

    Returns:
        List of dictionaries containing discovered links and their probabilities
    """
    model.eval()
    with torch.no_grad():
        # Get node embeddings using all available edges
        z = model.encode(data.x, data['paper', 'cites', 'paper'].edge_index)
        z_papers = z[:len(paper_id_to_idx)]

        # Create reverse mapping from indices to paper IDs
        idx_to_paper_id = {idx: pid for pid, idx in paper_id_to_idx.items()}

        # Generate all possible pairs of papers from different topics
        potential_links = []
        paper_indices = list(range(len(paper_id_to_idx)))

        for i in range(len(paper_indices)):
            for j in range(i + 1, len(paper_indices)):
                paper1_id = idx_to_paper_id[paper_indices[i]]
                paper2_id = idx_to_paper_id[paper_indices[j]]

                # Check if papers are from different topics
                if papers_dict[paper1_id]['topic'] != papers_dict[paper2_id]['topic']:
                    potential_links.append([paper_indices[i], paper_indices[j]])

        if not potential_links:
            return []

        # Convert to tensor and predict links
        potential_edges = torch.tensor(potential_links, dtype=torch.long).t()
        predictions = model.decode(z_papers, potential_edges).sigmoid()

        # Collect high-probability links
        discovered_links = []
        for i, prob in enumerate(predictions):
            if prob > threshold:
                paper1_idx = potential_links[i][0]
                paper2_idx = potential_links[i][1]
                paper1_id = idx_to_paper_id[paper1_idx]
                paper2_id = idx_to_paper_id[paper2_idx]

                discovered_links.append({
                    'paper1': {
                        'id': paper1_id,
                        'title': papers_dict[paper1_id]['title'],
                        'topic': papers_dict[paper1_id]['topic']
                    },
                    'paper2': {
                        'id': paper2_id,
                        'title': papers_dict[paper2_id]['title'],
                        'topic': papers_dict[paper2_id]['topic']
                    },
                    'probability': prob.item()
                })

        # Sort by probability in descending order
        discovered_links.sort(key=lambda x: x['probability'], reverse=True)

        return discovered_links

def split_edges(edge_index, val_ratio=0.1, test_ratio=0.1, message_ratio=0.5):
    """
    Split edges into training message edges, training supervision edges, validation edges, and test edges.
    message_ratio determines split between message and supervision edges within training set.
    """
    num_edges = edge_index.size(1)
    num_val = int(num_edges * val_ratio)
    num_test = int(num_edges * test_ratio)
    num_train = num_edges - (num_val + num_test)
    num_message = int(num_train * message_ratio)

    # Randomly shuffle edges
    perm = torch.randperm(num_edges)

    # Split indices
    message_idx = perm[:num_message]
    supervision_idx = perm[num_message:num_train]
    val_idx = perm[num_train:num_train+num_val]
    test_idx = perm[num_train+num_val:]

    # Create edge sets
    message_edges = edge_index[:, message_idx]
    supervision_edges = edge_index[:, supervision_idx]
    val_edges = edge_index[:, val_idx]
    test_edges = edge_index[:, test_idx]

    return message_edges, supervision_edges, val_edges, test_edges

def calculate_metrics(pred, target):
    """Calculate accuracy, precision, recall, and F1 score."""
    from sklearn.metrics import accuracy_score, precision_recall_fscore_support

    pred_binary = (pred > 0.5).cpu().numpy()
    target = target.cpu().numpy()

    accuracy = accuracy_score(target, pred_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(target, pred_binary, average='binary')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1
    }

def train_and_evaluate(model, data, message_edges, supervision_edges, val_edges, test_edges, epochs, paper_id_to_idx):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()

    def evaluate(message_edge_set, pos_edge_set, name=""):
        model.eval()
        with torch.no_grad():
            z = model.encode(data.x, message_edge_set)

            # Generate negative edges
            neg_edge_set = torch.randint(0, len(paper_id_to_idx), (2, pos_edge_set.size(1)))

            # Get predictions
            pos_pred = model.decode(z, pos_edge_set).sigmoid()
            neg_pred = model.decode(z, neg_edge_set).sigmoid()

            # Combine predictions and create labels
            pred = torch.cat([pos_pred, neg_pred])
            target = torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])

            # Calculate metrics
            metrics = calculate_metrics(pred, target)

            print(f"\n{name} Metrics:")
            print(f"Accuracy: {metrics['accuracy']:.4f}")
            print(f"Precision: {metrics['precision']:.4f}")
            print(f"Recall: {metrics['recall']:.4f}")
            print(f"F1: {metrics['f1']:.4f}")

            return metrics

    best_val_f1 = 0
    best_model = None

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        # Training step using message edges to predict supervision edges
        z = model.encode(data.x, message_edges)

        # Generate negative edges for training
        neg_edges = torch.randint(0, len(paper_id_to_idx), (2, supervision_edges.size(1)))

        pos_pred = model.decode(z, supervision_edges)
        neg_pred = model.decode(z, neg_edges)

        loss = criterion(pos_pred, torch.ones_like(pos_pred)) + criterion(neg_pred, torch.zeros_like(neg_pred))

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f"\nEpoch {epoch+1:03d}:")

            # Training evaluation (using message edges to predict supervision edges)
            train_metrics = evaluate(message_edges, supervision_edges, "Training")

            # Validation evaluation (using message + supervision edges to predict validation edges)
            combined_train_edges = torch.cat([message_edges, supervision_edges], dim=1)
            val_metrics = evaluate(combined_train_edges, val_edges, "Validation")

            if val_metrics['f1'] > best_val_f1:
                best_val_f1 = val_metrics['f1']
                best_model = copy.deepcopy(model)

    # Final evaluation on test set using best model
    print("\nFinal Test Set Evaluation:")
    all_train_edges = torch.cat([message_edges, supervision_edges, val_edges], dim=1)
    test_metrics = evaluate(all_train_edges, test_edges, "Test")

    return best_model, test_metrics

import networkx as nx
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from tqdm import tqdm

def create_visualization(data, papers_dict, topics_dict, message_edges, supervision_edges,
                        val_edges, test_edges, model, paper_id_to_idx, topic_id_to_idx,
                        threshold=0.7, num_frames=50):
    """
    Creates a GIF showing the evolution of the graph as new edges are predicted.
    """
    # Create reverse mappings
    idx_to_paper = {idx: pid for pid, idx in paper_id_to_idx.items()}
    idx_to_topic = {idx: tid for tid, idx in topic_id_to_idx.items()}

    # Create base graph
    G = nx.Graph()

    # Add paper nodes
    for idx, paper_id in idx_to_paper.items():
        G.add_node(idx,
                  type='paper',
                  title=papers_dict[paper_id]['title'],
                  topic=papers_dict[paper_id]['topic'])

    # Add topic nodes
    for idx, topic_id in idx_to_topic.items():
        G.add_node(idx + len(paper_id_to_idx),
                  type='topic',
                  name=topics_dict[topic_id])

    # Function to get edge colors based on type
    def get_edge_colors(G):
        colors = []
        for u, v in G.edges():
            if G[u][v].get('edge_type') == 'message':
                colors.append('blue')
            elif G[u][v].get('edge_type') == 'supervision':
                colors.append('green')
            elif G[u][v].get('edge_type') == 'validation':
                colors.append('orange')
            elif G[u][v].get('edge_type') == 'test':
                colors.append('red')
            elif G[u][v].get('edge_type') == 'predicted':
                colors.append('purple')
            else:
                colors.append('gray')
        return colors

    def get_node_colors(G):
        colors = []
        for node in G.nodes():
            if G.nodes[node]['type'] == 'topic':
                colors.append('yellow')
            else:
                topic = G.nodes[node]['topic']
                # Generate a unique color for each topic
                hash_val = hash(topic)
                r = (hash_val & 0xFF) / 255.0
                g = ((hash_val >> 8) & 0xFF) / 255.0
                b = ((hash_val >> 16) & 0xFF) / 255.0
                colors.append([r, g, b, 0.7])
        return colors

    # Initialize plot
    fig, ax = plt.subplots(figsize=(15, 15))

    def update(frame):
        ax.clear()

        # Add edges progressively
        if frame == 0:
            # Add message edges
            for i in range(message_edges.size(1)):
                src, dst = message_edges[:, i].tolist()
                G.add_edge(src, dst, edge_type='message')

        elif frame == 1:
            # Add supervision edges
            for i in range(supervision_edges.size(1)):
                src, dst = supervision_edges[:, i].tolist()
                G.add_edge(src, dst, edge_type='supervision')

        elif frame == 2:
            # Add validation edges
            for i in range(val_edges.size(1)):
                src, dst = val_edges[:, i].tolist()
                G.add_edge(src, dst, edge_type='validation')

        elif frame == 3:
            # Add test edges
            for i in range(test_edges.size(1)):
                src, dst = test_edges[:, i].tolist()
                G.add_edge(src, dst, edge_type='test')

        else:
            # Predict new edges
            with torch.no_grad():
                z = model.encode(data.x, data['paper', 'cites', 'paper'].edge_index)
                z_papers = z[:len(paper_id_to_idx)]

                # Sample some random paper pairs
                num_predictions = 10
                paper_indices = list(range(len(paper_id_to_idx)))
                pairs = []
                for _ in range(num_predictions):
                    i, j = np.random.choice(paper_indices, 2, replace=False)
                    pairs.append([i, j])

                pairs = torch.tensor(pairs, dtype=torch.long).t()
                predictions = model.decode(z_papers, pairs).sigmoid()

                # Add high probability edges
                for i, prob in enumerate(predictions):
                    if prob > threshold:
                        src, dst = pairs[:, i].tolist()
                        G.add_edge(src, dst, edge_type='predicted')

        # Draw the graph
        pos = nx.spring_layout(G, k=1/np.sqrt(G.number_of_nodes()), iterations=50)

        # Draw nodes
        nx.draw_networkx_nodes(G, pos,
                             node_color=get_node_colors(G),
                             node_size=100)

        # Draw edges
        nx.draw_networkx_edges(G, pos,
                             edge_color=get_edge_colors(G),
                             width=1.0,
                             alpha=0.5)

        # Add title
        if frame == 0:
            plt.title("Message Edges (Training)", fontsize=16)
        elif frame == 1:
            plt.title("Supervision Edges Added", fontsize=16)
        elif frame == 2:
            plt.title("Validation Edges Added", fontsize=16)
        elif frame == 3:
            plt.title("Test Edges Added", fontsize=16)
        else:
            plt.title(f"Predicted Edges (Frame {frame})", fontsize=16)

        # Add legend
        legend_elements = [
            plt.Line2D([0], [0], color='blue', label='Message'),
            plt.Line2D([0], [0], color='green', label='Supervision'),
            plt.Line2D([0], [0], color='orange', label='Validation'),
            plt.Line2D([0], [0], color='red', label='Test'),
            plt.Line2D([0], [0], color='purple', label='Predicted')
        ]
        ax.legend(handles=legend_elements, loc='upper right')

        plt.axis('off')

    # Create animation
    anim = FuncAnimation(fig, update, frames=num_frames, interval=500, repeat=True)

    # Save as GIF
    writer = PillowWriter(fps=2)
    anim.save('graph_evolution.gif', writer=writer)
    plt.close()

def main():
    # Load data from OpenAlex
    print("Fetching data from OpenAlex...")
    papers_dict, topics_dict = fetch_openalex_data(degree_limit=6)

    # Create graph data
    print("Creating heterogeneous graph...")
    graph_data, paper_id_to_idx, topic_id_to_idx = create_heterogeneous_graph(papers_dict, topics_dict)

    # Ensure we have enough edges for meaningful splits
    edge_index = graph_data['paper', 'cites', 'paper'].edge_index
    num_edges = edge_index.size(1)

    if num_edges < 20:  # Minimum threshold for meaningful splitting
        print(f"Warning: Only {num_edges} edges found. Need more data for meaningful evaluation.")
        return

    print(f"Total number of papers: {len(papers_dict)}")
    print(f"Total number of topics: {len(topics_dict)}")
    print(f"Total number of citation edges: {num_edges}")

    # Split edges
    print("\nSplitting edges...")
    message_edges, supervision_edges, val_edges, test_edges = split_edges(
        edge_index,
        val_ratio=0.1,
        test_ratio=0.1,
        message_ratio=0.5
    )

    print(f"Number of message edges: {message_edges.size(1)}")
    print(f"Number of supervision edges: {supervision_edges.size(1)}")
    print(f"Number of validation edges: {val_edges.size(1)}")
    print(f"Number of test edges: {test_edges.size(1)}")

    # Initialize model
    print("\nInitializing model...")
    model = HeteroGCNLinkPredictor(
        feature_dim=graph_data.x.size(1),
        hidden_channels=64
    )

    # Train and evaluate model
    print("\nStarting training and evaluation...")
    best_model, test_metrics = train_and_evaluate(
        model=model,
        data=graph_data,
        message_edges=message_edges,
        supervision_edges=supervision_edges,
        val_edges=val_edges,
        test_edges=test_edges,
        epochs=100,
        paper_id_to_idx=paper_id_to_idx
    )

    # Print final test metrics
    print("\nFinal Test Set Performance:")
    print(f"Accuracy: {test_metrics['accuracy']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1: {test_metrics['f1']:.4f}")

    # Optionally: Use model to predict new cross-topic links
    print("\nDiscovering potential cross-topic links...")
    discovered_links = discover_cross_topic_links(
        best_model,
        graph_data,
        papers_dict,
        paper_id_to_idx,
        threshold=0.8
    )

    # Print top discovered links
    if discovered_links:
        print("\nTop 5 Discovered Cross-Topic Links:")
        for i, link in enumerate(discovered_links[:5], 1):
            print(f"\n{i}. Probability: {link['probability']:.4f}")
            print(f"Paper 1 ({topics_dict[link['paper1']['topic']]}): {link['paper1']['title']}")
            print(f"Paper 2 ({topics_dict[link['paper2']['topic']]}): {link['paper2']['title']}")
    else:
        print("No high-probability cross-topic links discovered.")

    # print("\nGenerating visualization...")
    # create_visualization(
    #     graph_data, papers_dict, topics_dict,
    #     message_edges, supervision_edges, val_edges, test_edges,
    #     best_model, paper_id_to_idx, topic_id_to_idx
    # )
    # print("Visualization saved as 'graph_evolution.gif'")

main()

Fetching data from OpenAlex...


KeyboardInterrupt: 

In [None]:
# graph_data, paper_id_to_idx, topic_id_to_idx = create_heterogeneous_graph(papers_dict, topics_dict)

# # Train models for same-topic and different-topic link prediction
# same_topic_model = train_link_prediction(graph_data, papers_dict, paper_id_to_idx, mode='same_topic')
# different_topic_model = train_link_prediction(graph_data, papers_dict, paper_id_to_idx, mode='different_topic')


In [None]:

# graph_data, paper_id_to_idx, topic_id_to_idx = create_heterogeneous_graph(papers_dict, topics_dict)

# # Train models
# different_topic_model, train_metrics, val_metrics = train_link_prediction(
#     graph_data, papers_dict, paper_id_to_idx, mode='different_topic'
# )

# # Discover new cross-topic links
# discovered_links = discover_cross_topic_links(
#     different_topic_model, graph_data, papers_dict, paper_id_to_idx
# )

# print("\nTop 5 Discovered Cross-Topic Links:")
# for link in discovered_links[:5]:
#     print(f"\nProbability: {link['probability']:.4f}")
#     print(f"Paper 1 ({link['paper1']['topic']}): {link['paper1']['title']}")
#     print(f"Paper 2 ({link['paper2']['topic']}): {link['paper2']['title']}")

In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation, PillowWriter
import matplotlib.patches as patches

def create_initial_graph():
    """Create initial heterogeneous graph with two node types and edge types"""
    G = nx.Graph()

    # Add nodes of type 1 (e.g., users)
    for i in range(5):
        G.add_node(f'U{i}', type='user', pos=(np.random.random(), np.random.random()))

    # Add nodes of type 2 (e.g., items)
    for i in range(5):
        G.add_node(f'I{i}', type='item', pos=(np.random.random(), np.random.random()))

    # Add initial edges of type 1 (e.g., user-item interactions)
    initial_edges = [
        ('U0', 'I1'), ('U1', 'I2'), ('U2', 'I0'),
        ('U3', 'I3'), ('U4', 'I4')
    ]
    G.add_edges_from(initial_edges, type='interaction')

    # Add initial edges of type 2 (e.g., user-user relationships)
    user_edges = [
        ('U0', 'U1'), ('U2', 'U3'), ('U3', 'U4')
    ]
    G.add_edges_from(user_edges, type='friendship')

    return G

def predict_new_edges(G, frame):
    """Simulate edge prediction by adding new edges based on frame number"""
    predictions = [
        ('U0', 'I3', 'interaction'),
        ('U1', 'I4', 'interaction'),
        ('U1', 'U3', 'friendship'),
        ('U2', 'I4', 'interaction'),
        ('U0', 'U4', 'friendship')
    ]

    if frame < len(predictions):
        new_edge = predictions[frame]
        G.add_edge(new_edge[0], new_edge[1], type=new_edge[2])

    return G

def update(frame, G, ax):
    """Update function for animation"""
    ax.clear()

    # Update graph with new predicted edges
    G = predict_new_edges(G, frame)

    # Get node positions
    pos = nx.get_node_attributes(G, 'pos')

    # Draw nodes by type
    user_nodes = [node for node, attr in G.nodes(data=True) if attr['type'] == 'user']
    item_nodes = [node for node, attr in G.nodes(data=True) if attr['type'] == 'item']

    nx.draw_networkx_nodes(G, pos, nodelist=user_nodes, node_color='lightblue',
                          node_size=500)
    nx.draw_networkx_nodes(G, pos, nodelist=item_nodes, node_color='lightgreen',
                          node_shape='s', node_size=500)

    # Draw edges by type
    interaction_edges = [(u, v) for (u, v, d) in G.edges(data=True) if d['type'] == 'interaction']
    friendship_edges = [(u, v) for (u, v, d) in G.edges(data=True) if d['type'] == 'friendship']

    nx.draw_networkx_edges(G, pos, edgelist=interaction_edges, edge_color='red',
                          width=2)
    nx.draw_networkx_edges(G, pos, edgelist=friendship_edges, edge_color='blue',
                          style='dashed', width=2)

    # Draw labels
    nx.draw_networkx_labels(G, pos)

    # Add title and legend
    # ax.set_title(f'Heterogeneous Graph Evolution (Step {frame+1})')
    # ax.legend(loc='upper left', bbox_to_anchor=(1, 1))

    # Set limits and remove axes
    ax.set_xlim(-0.1, 1.1)
    ax.set_ylim(-0.1, 1.1)
    ax.axis('off')

def create_graph_animation():
    """Create and save the graph animation as a GIF"""
    # Create initial graph
    G = create_initial_graph()

    # Set up the figure
    fig, ax = plt.subplots(figsize=(10, 8))
    plt.subplots_adjust(right=0.85)

    # Create animation
    frames = 6  # Number of prediction steps + 1
    anim = FuncAnimation(fig, update, frames=frames, fargs=(G, ax),
                        interval=1000, repeat=True)

    # Save as GIF
    writer = PillowWriter(fps=1)
    anim.save('hetero_graph_evolution.gif', writer=writer)
    plt.close()

# Generate the animation
create_graph_animation()