# ML with Graphs

## Cross Domain Knowledge Discovery

- Harish Varma Siravuri
- Saurav Mukhopadhyay

In [3]:
import requests
import networkx as nx
import pandas as pd
import numpy as np
import torch
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.loader import LinkNeighborLoader
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.notebook import tqdm
from torch_geometric.nn import HeteroConv, SAGEConv, Linear
from torch_geometric.loader import NeighborLoader

In [4]:
def get_abstract(inverted_index):
    """Rebuild an abstract's plain text from an OpenAlex abstract inverted index.

    Args:
        inverted_index: mapping ``word -> list of positions`` where the word
            occurs, or ``None`` when the work has no abstract.

    Returns:
        The words joined by single spaces in position order. Positions that no
        word covers contribute an empty string, which shows up as an extra
        space. Returns ``"Unavailable"`` for ``None``, and ``""`` for an empty
        index (or one whose words carry no positions).
    """
    if inverted_index is None:
        return "Unavailable"

    # default=-1: an empty index yields "" instead of max() raising ValueError.
    max_index = max(
        (position for positions in inverted_index.values() for position in positions),
        default=-1,
    )
    abstract_words = [""] * (max_index + 1)
    for word, positions in inverted_index.items():
        for position in positions:
            abstract_words[position] = word
    return " ".join(abstract_words)


def _to_api_url(openalex_id):
    """Rewrite an ``https://openalex.org/<ID>`` identifier into its API URL."""
    return openalex_id.replace("openalex.org", "api.openalex.org/works")


def _extract_work_fields(work):
    """Return ``(topic_id, title, abstract)`` for an OpenAlex work JSON dict.

    Missing or null values are returned as ``""``.
    """
    primary_topic = work.get("primary_topic")
    raw_title = work.get("title")
    inverted_index = work.get("abstract_inverted_index")
    topic_id = primary_topic.get("id", "") if primary_topic is not None else ""
    title = raw_title if raw_title is not None else ""
    abstract = get_abstract(inverted_index) if inverted_index is not None else ""
    return topic_id, title, abstract


def _record_topic(topics_dict, work, topic_id):
    """Store ``topic_id -> display name`` in *topics_dict* ("" -> "" when unknown)."""
    if topic_id == "":
        topics_dict[""] = ""
    else:
        topics_dict[topic_id] = work["primary_topic"].get("display_name", "")


def fetch_openalex_data(paper_limit: int = 300, degree_limit: int = 1, timeout: float = 30.0):
    """Crawl OpenAlex starting from the default ``/works`` listing.

    Each round ("degree") fetches every queued work and stores it in
    ``papers_dict``. It then queues the work's own id, its referenced works,
    its related works and the citing works returned by one request to its
    ``cited_by_api_url``. In the final round, each fetched paper's referenced
    works are also downloaded and stored as leaf entries
    (``citation_count`` 0, empty ``cites``).

    Args:
        paper_limit: currently not enforced; kept for API compatibility.
            NOTE(review): wire this up as a cap if it is meant to be one.
        degree_limit: number of crawl rounds.
        timeout: per-request timeout in seconds, passed to ``requests.get``
            so a stalled connection cannot hang the crawl.

    Returns:
        ``(papers_dict, topics_dict)``. ``papers_dict`` maps API URL to a dict
        with keys ``id``, ``citation_count``, ``topic``, ``title``,
        ``abstract`` and ``cites``. ``topics_dict`` maps topic id to display
        name, with ``"" -> ""`` for works without a primary topic.
    """
    base_url = "https://api.openalex.org/works"
    response = requests.get(base_url, timeout=timeout)

    # `or []`: a payload without "results" gives no seeds instead of a TypeError.
    seed_works = response.json().get("results") or []
    ids = [_to_api_url(work["id"]) for work in seed_works]

    papers_dict = {}
    topics_dict = {}

    for degree in range(1, degree_limit + 1):
        new_ids = []

        for pid in ids:
            try:
                response = requests.get(pid, timeout=timeout).json()
                new_ids.append(_to_api_url(response["id"]))
                referenced = [_to_api_url(ref) for ref in response["referenced_works"]]
                new_ids.extend(referenced)
                new_ids.extend(_to_api_url(ref) for ref in response["related_works"])
                cited_by_papers = (
                    requests.get(response["cited_by_api_url"], timeout=timeout)
                    .json()
                    .get("results")
                    or []
                )
                new_ids.extend(_to_api_url(work["id"]) for work in cited_by_papers)

                topic, title, abstract = _extract_work_fields(response)
                papers_dict[pid] = {
                    "id": pid,
                    # NOTE(review): counts only the results of a single request
                    # to cited_by_api_url (one page), not the total citations.
                    "citation_count": len(cited_by_papers),
                    "topic": topic,
                    "title": title,
                    "abstract": abstract,
                    "cites": referenced,
                }
                # Uses this paper's own topic (previously the variable was
                # clobbered by the reference loop below before being read).
                _record_topic(topics_dict, response, topic)
            except Exception as e:
                print(f"Error fetching data for {pid}: {e}")
                continue

            if degree != degree_limit:
                continue

            for work_url in referenced:
                # Skip known works: avoids refetching and never replaces a fully
                # crawled entry (with its `cites`) by a leaf placeholder.
                if work_url in papers_dict:
                    continue
                # Per-reference try so one bad reference doesn't drop the rest.
                try:
                    res = requests.get(work_url, timeout=timeout).json()
                    if res is None:
                        continue
                    ref_topic, ref_title, ref_abstract = _extract_work_fields(res)
                    papers_dict[work_url] = {
                        "id": work_url,
                        "citation_count": 0,
                        "topic": ref_topic,
                        "title": ref_title,
                        "abstract": ref_abstract,
                        "cites": [],
                    }
                    _record_topic(topics_dict, res, ref_topic)
                except Exception as e:
                    print(f"Error fetching data for {work_url}: {e}")
                    continue

        # Deduplicate while keeping first-seen order.
        ids = list(dict.fromkeys(new_ids))
    return papers_dict, topics_dict

In [5]:
# Fetch seed works (plus, in the final round, their referenced works) from OpenAlex.
papers_dict, topics_dict = fetch_openalex_data(paper_limit=300, degree_limit=1)

In [6]:
import torch
import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, SAGEConv
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
import numpy as np

def create_heterogeneous_graph(papers_dict, topics_dict):
    """Build a paper/topic HeteroData graph with SentenceTransformer features.

    Node indexing:
      * Paper nodes take global indices ``0 .. P-1``, in ``papers_dict`` order.
      * Topic nodes take global indices ``P .. P+T-1``, in first-seen order.

    ``data.x`` holds all node features in that global order. The
    ``('paper', 'cites', 'paper')`` edges use paper indices, which are the
    same globally and locally; ``train_link_prediction`` relies on this
    layout.

    The topic relations use per-type (local) indices, as HeteroData expects:
      * ``('paper', 'has_topic', 'topic')`` points paper -> topic.
      * ``('topic', 'has_paper', 'paper')`` points topic -> paper.

    Per-type features are also stored as ``data['paper'].x`` and
    ``data['topic'].x``.

    Args:
        papers_dict: paper id -> dict with 'topic', 'title', 'abstract' and
            'cites' (a list of paper ids).
        topics_dict: topic id -> display name. Ids missing here are embedded
            using the id string itself.

    Returns:
        ``(data, paper_id_to_idx, topic_id_to_idx)``. ``topic_id_to_idx``
        maps to global indices, i.e. offset by the number of papers.
    """
    paper_ids = list(papers_dict.keys())
    # dict.fromkeys keeps first-seen order, so topic indices are reproducible
    # across runs (set iteration order of str varies with hash randomization).
    topic_ids = list(dict.fromkeys(
        info['topic'] for info in papers_dict.values() if info['topic']
    ))
    num_papers = len(paper_ids)

    paper_id_to_idx = {pid: idx for idx, pid in enumerate(paper_ids)}
    topic_id_to_idx = {tid: num_papers + idx for idx, tid in enumerate(topic_ids)}

    text_embedder = SentenceTransformer('all-MiniLM-L6-v2')
    feature_dim = text_embedder.get_sentence_embedding_dimension()

    def _embed(texts):
        # Batch-encode on whatever device the model uses, then move to CPU.
        # An empty input yields an empty (0, feature_dim) tensor.
        if not texts:
            return torch.zeros(0, feature_dim)
        return text_embedder.encode(texts, convert_to_tensor=True).cpu()

    paper_x = _embed([
        f"{papers_dict[pid]['title']} {papers_dict[pid]['abstract']}" for pid in paper_ids
    ])
    topic_x = _embed([topics_dict.get(tid, tid) for tid in topic_ids])
    x = torch.cat([paper_x, topic_x], dim=0)

    def _edge_index(pairs):
        # (src, dst) pairs -> LongTensor of shape [2, num_edges].
        if not pairs:
            return torch.empty((2, 0), dtype=torch.long)
        return torch.tensor(pairs, dtype=torch.long).t().contiguous()

    # Citations whose target is outside the crawled set are dropped.
    citation_pairs = [
        (paper_id_to_idx[pid], paper_id_to_idx[cited])
        for pid in paper_ids
        for cited in papers_dict[pid]['cites']
        if cited in paper_id_to_idx
    ]
    # paper -> topic, using the topic's local index (global minus num_papers).
    topic_pairs = [
        (paper_id_to_idx[pid], topic_id_to_idx[papers_dict[pid]['topic']] - num_papers)
        for pid in paper_ids
        if papers_dict[pid]['topic'] in topic_id_to_idx
    ]

    has_topic_edge_index = _edge_index(topic_pairs)

    data = HeteroData()
    data['paper'].x = paper_x
    data['topic'].x = topic_x
    data['paper', 'cites', 'paper'].edge_index = _edge_index(citation_pairs)
    data['paper', 'has_topic', 'topic'].edge_index = has_topic_edge_index
    data['topic', 'has_paper', 'paper'].edge_index = has_topic_edge_index.flip(0)
    # Graph-level matrix of all nodes in global order (used by the GCN path).
    data.x = x

    return data, paper_id_to_idx, topic_id_to_idx

class HeteroGCNLinkPredictor(torch.nn.Module):
    """Two-layer GCN encoder plus an MLP edge scorer for link prediction.

    NOTE(review): despite the name, the encoder is homogeneous. It applies
    ``GCNConv`` to a single feature matrix / ``edge_index`` pair.
    """

    def __init__(self, feature_dim, hidden_channels=64):
        super().__init__()

        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = GCNConv(feature_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)

        # Scores a concatenated (source, destination) embedding pair.
        self.link_predictor = torch.nn.Sequential(
            torch.nn.Linear(hidden_channels * 2, hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_channels, 1)
        )

    def encode(self, x, edge_index):
        """Return node embeddings of width ``hidden_channels``."""
        x = F.relu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        """Score candidate edges.

        Args:
            z: node embeddings, indexed by the node ids in ``edge_label_index``.
            edge_label_index: ``[2, num_edges]`` tensor of (source, destination).

        Returns:
            Raw logits of shape ``[num_edges]`` (apply a sigmoid for
            probabilities).
        """
        row, col = edge_label_index
        z_combined = torch.cat([z[row], z[col]], dim=-1)
        # squeeze(-1), not squeeze(): a single edge still yields a 1-D tensor.
        return self.link_predictor(z_combined).squeeze(-1)

def train_link_prediction(data, papers_dict, paper_id_to_idx, mode='same_topic', epochs=100):
    """Train a HeteroGCNLinkPredictor on paper pairs defined by topic agreement.

    Positive pairs are all ordered pairs of distinct papers whose primary
    topics match (``mode='same_topic'``) or differ (``mode='different_topic'``).
    Papers with an unknown (empty) topic are excluded, because agreement
    cannot be decided for them. Negatives are drawn uniformly at random each
    epoch. The encoder runs on ``data.x`` with the
    ``('paper', 'cites', 'paper')`` edges. Paper nodes are assumed to occupy
    the first ``len(paper_id_to_idx)`` rows of ``data.x``.

    Returns:
        The trained model, or ``None`` when there are no positive pairs or too
        few to leave any for training. The held-out splits are attached as
        ``model.val_edge_index`` and ``model.test_edge_index``.

    Raises:
        ValueError: if ``mode`` is not ``'same_topic'`` or ``'different_topic'``.
    """
    if mode not in ('same_topic', 'different_topic'):
        raise ValueError(
            f"mode must be 'same_topic' or 'different_topic', got {mode!r}"
        )

    # Resolve indices once instead of inside the O(n^2) pair loop.
    indexed = [
        (paper_id_to_idx[pid], info['topic'])
        for pid, info in papers_dict.items()
        if pid in paper_id_to_idx and info['topic']
    ]
    want_same = mode == 'same_topic'
    link_edges = [
        [src, dst]
        for src, src_topic in indexed
        for dst, dst_topic in indexed
        if src != dst and (src_topic == dst_topic) == want_same
    ]

    if not link_edges:
        print(f"No {mode} edges found. Skipping training.")
        return None

    edge_index = torch.tensor(link_edges, dtype=torch.long).t()
    num_edges = edge_index.size(1)

    # 10% validation, 10% test (at least one edge each), remainder training.
    num_val = max(int(0.1 * num_edges), 1)
    num_test = max(int(0.1 * num_edges), 1)

    perm = torch.randperm(num_edges)
    val_idx = perm[:num_val]
    test_idx = perm[num_val:num_val + num_test]
    train_idx = perm[num_val + num_test:]

    # With < 3 edges nothing is left for training; the loss would be NaN.
    if train_idx.numel() == 0:
        print(f"Too few {mode} edges ({num_edges}) for a training split. Skipping training.")
        return None

    train_edge_index = edge_index[:, train_idx]
    num_papers = len(paper_id_to_idx)

    model = HeteroGCNLinkPredictor(feature_dim=data.x.size(1))
    # Expose held-out splits so callers can evaluate the model.
    model.val_edge_index = edge_index[:, val_idx]
    model.test_edge_index = edge_index[:, test_idx]

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    criterion = torch.nn.BCEWithLogitsLoss()

    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        z = model.encode(data.x, data['paper', 'cites', 'paper'].edge_index)
        z_papers = z[:num_papers]

        # Uniform random negatives; may occasionally hit a positive pair.
        neg_edge_index = torch.randint(0, num_papers, (2, train_edge_index.size(1)))

        pos_link_pred = model.decode(z_papers, train_edge_index)
        neg_link_pred = model.decode(z_papers, neg_edge_index)

        pos_loss = criterion(pos_link_pred, torch.ones_like(pos_link_pred))
        neg_loss = criterion(neg_link_pred, torch.zeros_like(neg_link_pred))
        loss = pos_loss + neg_loss

        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'{mode.capitalize()} Topic Link Prediction - Epoch {epoch+1:03d}, Loss: {loss.item():.4f}')

    return model

In [7]:
graph_data, paper_id_to_idx, topic_id_to_idx = create_heterogeneous_graph(papers_dict, topics_dict)

# One link predictor per pairing rule, trained in order: papers sharing a
# primary topic first, then papers whose primary topics differ.
same_topic_model, different_topic_model = (
    train_link_prediction(graph_data, papers_dict, paper_id_to_idx, mode=link_mode)
    for link_mode in ('same_topic', 'different_topic')
)


AttributeError: type object 'torch._C.Tag' has no attribute 'pt2_compliant_tag'