In [None]:
import pandas as pd
import numpy as np
import sqlite3
import torch
import regex as re
import requests
import json
import spacy
import matplotlib.pyplot as plt
import pickle
import networkx as nx
from tqdm.auto import tqdm
import pickle
from sklearn.metrics.pairwise import cosine_similarity
from itertools import combinations
from collections import defaultdict
import matplotlib.pyplot as plt
from networkx.algorithms.community import greedy_modularity_communities
from networkx.algorithms.community.quality import modularity
import requests
import time
from torch_geometric.data import Data
import torch
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import torch.optim as optim
from sklearn.metrics import roc_auc_score, precision_score, recall_score, average_precision_score
import random
from spacy.lang.en.stop_words import STOP_WORDS
import joblib
from bertopic import BERTopic
from sklearn.feature_extraction.text import TfidfVectorizer
from umap import UMAP
from matplotlib.lines import Line2D

# Graph Creation

## Load Dataframes

In [None]:
# Load the per-document table from SQLite.
con = sqlite3.connect("docs_with_subtopics.db")
try:
    df = pd.read_sql_query("SELECT * from docs_with_subtopics", con)
finally:
    # Release the database handle even when the query raises.
    con.close()

# topic_list_str stores each row's topic list as JSON text; decode it into a Python object.
df['topic_list'] = df['topic_list_str'].apply(json.loads)

df

In [None]:
# Load the topic_counts table (columns used later: title, topic_number, topic_count).
con = sqlite3.connect("topic_counts.db")
try:
    topic_counts = pd.read_sql_query("SELECT * from topic_counts", con)
finally:
    # Release the database handle even when the query raises.
    con.close()

topic_counts

## Check Subtopic Co-Occurrence

In [None]:
# title x topic_number indicator matrix: 1 where the pivoted topic_count is positive,
# 0 otherwise (missing combinations are filled with 0 first).
binary_topic_matrix = (
    topic_counts
    .pivot(index='title', columns='topic_number', values='topic_count')
    .fillna(0)
    .astype(int)
    .gt(0)
    .astype(int)
)

# topic x topic matrix: entry (a, b) counts the titles flagged for both a and b;
# the diagonal therefore counts the titles flagged for each single topic.
co_occurrence_matrix = binary_topic_matrix.T @ binary_topic_matrix

In [None]:
# Display the topic x topic co-occurrence counts computed in the previous cell.
co_occurrence_matrix

In [None]:
# Drop the axis names in place before flattening the matrix
# (presumably so reset_index() does not see two levels with the same name).
co_occurrence_matrix.index.name = None
co_occurrence_matrix.columns.name = None

# Long format: one row per (row topic, column topic) cell of the matrix.
co_occur_pairs = co_occurrence_matrix.stack().reset_index()
co_occur_pairs.columns = ['topic_1','topic_2','count']

# Keep each unordered pair once and drop the diagonal.
co_occur_pairs = co_occur_pairs.loc[co_occur_pairs['topic_1'].lt(co_occur_pairs['topic_2'])]

top_pairs = co_occur_pairs.sort_values(by="count", ascending=False)

top_pairs.head(10)

In [None]:
# Per-topic co-occurrence with *other* topics = row total minus the diagonal entry.
self_counts = np.diag(co_occurrence_matrix.values)
row_sums = co_occurrence_matrix.sum(axis=1)
co_occurs_with_others = row_sums - self_counts

# Topics whose off-diagonal total is exactly zero never share a title with another topic.
isolated_topics = co_occurs_with_others.loc[co_occurs_with_others.eq(0)].index.tolist()

print("Topics with no co-occurrence:", isolated_topics)

In [None]:
# Show the 20 topics with the largest off-diagonal co-occurrence totals.
print(co_occurs_with_others.sort_values(ascending=False).head(20))

In [None]:
def verify_co_occ_counts(topic_counts, binary_matrix, co_occurrence_series):
    """Cross-check per-topic co-occurrence totals against a manual recount.

    For every topic in ``topic_counts['topic_number']`` the rows of
    ``binary_matrix`` flagged with that topic are selected, and the number of
    *other* flags on those rows is summed. That total is compared with the
    value stored for the topic in ``co_occurrence_series``.

    Args:
        topic_counts: DataFrame with a ``topic_number`` column.
        binary_matrix: 0/1 DataFrame with one column per topic.
        co_occurrence_series: Series (or mapping with ``.get``) of stored totals
            keyed by topic.

    Returns:
        list[dict]: one record per disagreeing topic with keys ``topic``,
        ``stored``, ``manual`` and ``difference`` (manual - stored). When the
        topic is absent from ``co_occurrence_series`` the record carries
        ``stored=None`` and ``difference=None`` instead of raising.
    """
    mismatches = []

    for topic in tqdm(topic_counts['topic_number'].unique(), desc='Verifying topics'):
        if topic not in binary_matrix.columns:
            print(f"topic {topic} not in binary matrix. Skipping")
            continue

        # Rows that carry this topic; each row contributes its other flags (row sum - 1).
        docs = binary_matrix[binary_matrix[topic]==1]
        manual_total = (docs.sum(axis=1) - 1).sum()

        stored_value = co_occurrence_series.get(topic, None)

        if stored_value is None:
            # Missing entry: report it as a mismatch without attempting arithmetic on None.
            mismatches.append({
                "topic": topic,
                "stored": None,
                "manual": manual_total,
                "difference" : None
            })
        elif stored_value != manual_total:
            mismatches.append({
                "topic": topic,
                "stored": stored_value,
                "manual": manual_total,
                "difference" : manual_total-stored_value
            })

    return mismatches

In [None]:
# Recount co-occurrence totals per topic and collect any disagreement with co_occurs_with_others.
mismatches = verify_co_occ_counts(topic_counts, binary_topic_matrix, co_occurs_with_others)

mismatches_df = pd.DataFrame(mismatches)

# An empty frame here means every stored total matched the manual recount.
print(mismatches_df.head())

## Prepare Graph Inputs

### Subtopic Embedding Cosine Similarity Matrix

In [None]:
# SECURITY: pickle.load executes arbitrary code embedded in the file; only load
# a subtopic_embeddings.pkl that was produced by a trusted run of this project.
with open("subtopic_embeddings.pkl", "rb") as f:
    subtopic_embeddings = pickle.load(f)

In [None]:
# Fix the subtopic order once, then stack the embeddings in that same order so that
# row/column k of the similarity matrix corresponds to subtopic_ids[k].
subtopic_ids = list(subtopic_embeddings.keys())
embedding_matrix = np.stack(
    [subtopic_embeddings[subtopic_id].cpu().numpy() for subtopic_id in subtopic_ids]
)

similarity_matrix = cosine_similarity(embedding_matrix)

In [None]:
# Label the cosine-similarity matrix with subtopic ids on both axes for .loc lookups.
similarity_df = pd.DataFrame(similarity_matrix, index=subtopic_ids, columns=subtopic_ids)
similarity_df

In [None]:
# Mean over the whole similarity table, self-similarity diagonal included.
# NOTE(review): axis=None is expected to reduce over both axes only on recent pandas
# releases; confirm the installed version returns a scalar here.
mean_value = similarity_df.mean(axis=None)
print(mean_value)

### Compute Directed Edge Weights

In [None]:
# The diagonal of the co-occurrence matrix holds each topic's own document count.
total_topic_counts = np.diag(co_occurrence_matrix.values)
topic_labels = co_occurrence_matrix.index.tolist()

topic_count_dict = {
    label: count for label, count in zip(topic_labels, total_topic_counts)
}

In [None]:
# Directed subtopic edges as (source, target, weight) triples.
edges = []

# Visit every unordered topic pair exactly once. Iterating all ordered pairs would
# emit an edge for both (i, j) and (j, i) and defeat the "stronger direction only" rule.
for topic_i, topic_j in combinations(topic_labels, 2):
    count_i = topic_count_dict[topic_i]
    count_j = topic_count_dict[topic_j]

    # A topic with no documents has no defined conditional probability.
    if count_i == 0 or count_j == 0:
        continue

    co_occur = co_occurrence_matrix.loc[topic_i, topic_j]

    # Directional probabilities: share of i's documents that also contain j, and vice versa.
    p_ij = co_occur / count_i
    p_ji = co_occur / count_j

    if p_ij > p_ji:
        # i -> j is the stronger direction.
        edges.append((topic_i, topic_j, p_ij))
    elif p_ij < p_ji:
        # j -> i is the stronger direction, so the edge starts at j.
        edges.append((topic_j, topic_i, p_ji))
    else:
        # Equal probabilities: keep the relation in both directions.
        # NOTE(review): this also covers pairs that never co-occur (both weights 0);
        # confirm whether such zero-weight edges should be kept.
        edges.append((topic_i, topic_j, p_ij))
        edges.append((topic_j, topic_i, p_ji))

## Graph Construction - MultiDiGraph

In [None]:
graph_df = df[['title', 'cluster', 'topic_list']]
similarity_threshold = 0.5
G = nx.MultiDiGraph()

# Single pass over the documents: add the node and index the document under each subtopic.
topic_to_docs = defaultdict(set)
for _, row in graph_df.iterrows():
    G.add_node(row['title'], cluster=row['cluster'], subtopics=row['topic_list'])
    for topic in row['topic_list']:
        topic_to_docs[topic].add(row['title'])

# (source subtopic, target subtopic) -> directed weight; later duplicates overwrite earlier ones.
subtopic_edge_dict = dict(((src, tgt), weight) for src, tgt, weight in edges)

titles = graph_df['title'].tolist()

# For every document pair, compare each cross pair of distinct subtopics. A pair that is
# similar enough contributes one edge whose direction follows subtopic_edge_dict.
for doc1, doc2 in combinations(titles, 2):
    subs1 = set(G.nodes[doc1]['subtopics'])
    subs2 = set(G.nodes[doc2]['subtopics'])

    for s1 in subs1:
        for s2 in subs2:
            if s1 == s2:
                continue

            # Subtopics without an embedding are absent from similarity_df; skip them.
            try:
                sim = similarity_df.loc[s1,s2]
            except KeyError:
                continue
            if sim < similarity_threshold:
                continue

            # The (s1, s2) direction is preferred; the reverse is only used when it is missing.
            if (s1, s2) in subtopic_edge_dict:
                weight = subtopic_edge_dict[(s1,s2)]
                G.add_edge(doc1, doc2, weight=weight, sim=sim, subtopic_pair=(s1, s2))
            elif (s2, s1) in subtopic_edge_dict:
                weight = subtopic_edge_dict[(s2, s1)]
                G.add_edge(doc2, doc1, weight=weight, sim=sim, subtopic_pair=(s2,s1))

In [None]:
# Size of the document graph; parallel edges between the same two documents are counted separately.
print("Number of nodes:", G.number_of_nodes())
print("Number of edges:", G.number_of_edges())

## Comparing clustering in graph to document K-means clustering

In [None]:
# Node -> K-means cluster label, read back from the 'cluster' node attribute.
kmeans_labels = {node: attrs['cluster'] for node, attrs in G.nodes(data=True)}

In [None]:
import community as co
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

# Collapse the multi-digraph into a simple undirected graph: all parallel and
# opposite-direction edges between two documents are merged by summing their weights.
undirected_G = nx.Graph()
undirected_G.add_nodes_from(G.nodes(data=True))

for u, v, data in G.edges(data=True):
    current_weight = undirected_G.get_edge_data(u, v, default={'weight':0}).get('weight', 0)
    undirected_G.add_edge(u, v, weight=current_weight + data.get('weight', 1))
    
print(f"Converted to undirected Graph with {undirected_G.number_of_nodes()} nodes and {undirected_G.number_of_edges()} edges.")

# Louvain is randomized; seed it so the partition (and the scores below) are
# reproducible across Restart & Run All.
partition = co.best_partition(undirected_G, weight='weight', random_state=42)

graph_clustering_labels = partition

kmeans_labels = {node: G.nodes[node]['cluster'] for node in G.nodes()}

# Compare the two labelings on the nodes they share, in a fixed (sorted) order.
common_nodes = sorted(list(set(kmeans_labels.keys()) & set(graph_clustering_labels.keys())))

if common_nodes:
    kmeans_labels_ordered = [kmeans_labels[node] for node in common_nodes]
    graph_labels_ordered = [graph_clustering_labels[node] for node in common_nodes]
    
    ari_score = adjusted_rand_score(kmeans_labels_ordered, graph_labels_ordered)
    nmi_score = normalized_mutual_info_score(kmeans_labels_ordered, graph_labels_ordered)
    print(f"\nAdjusted Rand Index (ARI) between K-Means and Louvain: {ari_score:.3f}")
    print(f"Normalized Mutual Information (NMI) between K-Means and Louvain: {nmi_score:.3f}")
else:
    print("\nNo common nodes for comparison. Something might have gone wrong with graph construction or clustering.")
    
    

In [None]:
# Connected components of the graph with edge direction ignored.
components = list(nx.connected_components(G.to_undirected()))
print(f"Number of connected components: {len(components)}")

# One line per component with its node count.
for i, comp in enumerate(components):
    print(f"Component {i}: {len(comp)} nodes")


## Measuring Performance

### Modularity

In [None]:
# Community detection and scoring both run on the undirected view of G.
G_undirected = G.to_undirected()

# NOTE(review): no weight argument is passed here although edges carry a 'weight'
# attribute; confirm whether unweighted community detection is intended and whether
# modularity() below scores with the same weighting.
communities = list(greedy_modularity_communities(G_undirected))

mod_score = modularity(G_undirected, communities)

print(f"Modularity score: {mod_score:.4f}")
print(f"Detected Communities: {len(communities)}")

In [None]:
# PageRank on the directed graph, using the 'weight' edge attribute.
pagerank_scores = nx.pagerank(G, weight="weight")
# Ten highest-scoring documents.
top_pagerank = sorted(pagerank_scores.items(), key=lambda x: x[1], reverse=True)[:10]
for node, score in top_pagerank:
    print(f"{node}: {score:.4f}")

In [None]:
# G_largest is not defined by any earlier cell, so build it here.
# NOTE(review): assumed to mean the largest weakly connected component of G - confirm.
G_largest = G.subgraph(max(nx.weakly_connected_components(G), key=len)).copy()

# NOTE(review): older networkx releases are believed to reject multigraphs in hits();
# if this raises, collapse parallel edges into a DiGraph first.
hits_hubs, hits_authorities = nx.hits(G_largest, max_iter=1000)

print("Authority Nodes")
for node, score in sorted(hits_authorities.items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f"{node}: {score:.4f}")
    
print("Hub Nodes")
for node, score in sorted(hits_hubs.items(), key=lambda x: x[1], reverse=True)[:10]:
    print(f"{node}: {score:.4f}")

In [None]:
# Work on the largest weakly connected component of G. Path length and diameter are
# only defined on a connected graph, so all three metrics use the same subgraph.
wcc = max(nx.weakly_connected_components(G), key=len)
G_wcc = G.subgraph(wcc).copy()
G_wcc_undirected = G_wcc.to_undirected()

avg_path_len = nx.average_shortest_path_length(G_wcc_undirected)
diameter = nx.diameter(G_wcc_undirected)
# NOTE(review): G_wcc keeps parallel edges, so this density may exceed 1 - confirm
# that is the intended reading.
density = nx.density(G_wcc)

print(f" Average Path Length: {avg_path_len:.4f}")
print(f" Diameter {diameter}")
print(f" Density: {density:.4f}")

In [None]:
# Component counts on the directed graph: weak ignores edge direction, strong respects it.
num_weak_components = nx.number_weakly_connected_components(G)
num_strong_components = nx.number_strongly_connected_components(G)

print(f"Weakly connected components: {num_weak_components}")
print(f"Strongly connected components: {num_strong_components}")

## Hits@N / MRR

### Extracting links between articles in our database

In [None]:
def format_titles(title):
    """Turn a document title into the underscore form used for the API ``bltitle``.

    Surrounding whitespace is stripped, spaces become underscores and only the
    first character is upper-cased. The rest of the title keeps its original
    casing (``str.capitalize`` would lower-case it and break titles such as
    ``"New York City"``).

    Args:
        title: Title string.

    Returns:
        str: The formatted title; an empty string stays empty.
    """
    cleaned = title.strip().replace(" ", "_")
    return cleaned[:1].upper() + cleaned[1:]

def get_backlinks(title, all_titles_set, sleep_time=0.5, timeout=30):
    """Collect the pages in ``all_titles_set`` that link to ``title`` on English Wikipedia.

    Queries the MediaWiki ``backlinks`` list and follows ``continue`` tokens
    until the result set is exhausted.

    Args:
        title: Title of the target page (formatted with ``format_titles``).
        all_titles_set: Container of titles to keep; other backlinks are dropped.
        sleep_time: Seconds to wait between continuation requests.
        timeout: Seconds before a single HTTP request is abandoned, so a stalled
            connection cannot hang the crawl.

    Returns:
        set[str] | None: The matching backlink titles (possibly empty), or
        ``None`` when a request fails, returns an HTTP error status, or does
        not contain valid JSON.
    """
    backlinks = set()
    base_url = "https://en.wikipedia.org/w/api.php"
    formatted_title = format_titles(title)
    params = {
        "action" : "query",
        "list"   : "backlinks",
        "bltitle": formatted_title,
        "bllimit": "max",
        "format" : "json"
    }

    while True:
        try:
            http_response = requests.get(base_url, params=params, timeout=timeout)
            http_response.raise_for_status()
            response = http_response.json()
        except (requests.RequestException, ValueError) as e:
            # Network failure, HTTP error status, or a body that is not JSON.
            print(f"Error fetching {title}: {e}")
            return None

        if 'query' in response:
            backlinks.update(
                entry['title']
                for entry in response['query']['backlinks']
                if entry['title'] in all_titles_set
            )

        if 'continue' in response:
            # Merge the continuation token into the next request and pause between calls.
            params.update(response['continue'])
            time.sleep(sleep_time)
        else:
            break

    return backlinks

In [None]:
# Creating and saving backlinks
# Takes a while; only run once

#all_titles_set = set(df['title'].str.strip())
#
#ground_truth_links = {}
#
#for title in df['title']:
#    incoming = get_backlinks(title, all_titles_set)
#    ground_truth_links[title] = incoming
#    
#with open('ground_truth_links.pkl', 'wb') as f:
#    pickle.dump(ground_truth_links, f)

In [None]:
# Cached result of the backlink crawl in the commented-out cell above.
# SECURITY: pickle.load executes arbitrary code embedded in the file; only load a
# ground_truth_links.pkl that this project produced.
with open('ground_truth_links.pkl', 'rb') as f:
    ground_truth_links = pickle.load(f)

In [None]:
# Total number of stored backlinks; entries saved as None are skipped.
total_backlinks = sum(len(links) for links in ground_truth_links.values() if links is not None)
print(f"Total Backlinks: {total_backlinks}")

### Hits@N

In [None]:
# target document -> list of (source document, edge weight, subtopic pair), one entry
# per parallel edge arriving at that target.
predicted_incoming = defaultdict(list)

for src, tgt, key, data in G.edges(keys=True, data=True):
    weight, subtopic_pair = data.get('weight', 0.0), data.get('subtopic_pair')
    predicted_incoming[tgt].append((src, weight, subtopic_pair))

In [None]:
def compute_hits_at_n(ground_truth, predicted_incoming, N=10):
    """Fraction of targets whose top-N predicted sources contain a true source.

    Args:
        ground_truth: Mapping target -> collection of true sources. Targets with
            an empty or ``None`` collection are ignored.
        predicted_incoming: Mapping target -> list of ``(source, weight, extra)``
            tuples; a source may appear several times.
        N: Number of highest-weighted distinct sources to inspect per target.

    Returns:
        Hits divided by the number of evaluated targets, or ``0`` when no
        target was evaluated.
    """
    evaluated = 0
    hit_count = 0

    for target, sources in ground_truth.items():
        if not sources:
            continue
        evaluated += 1

        # Highest weight seen per source; first-seen order is kept so ties rank stably.
        strongest = {}
        for candidate, score, _ in predicted_incoming.get(target, []):
            if candidate in strongest and not (score > strongest[candidate]):
                continue
            strongest[candidate] = score

        shortlist = sorted(strongest, key=strongest.get, reverse=True)[:N]

        for candidate in shortlist:
            if candidate in sources:
                hit_count += 1
                break

    if not evaluated:
        return 0
    return hit_count / evaluated

In [None]:
def compute_mrr(ground_truth, predicted_incoming):
    """Mean reciprocal rank of the first true source among the ranked predictions.

    Args:
        ground_truth: Mapping target -> collection of true sources. Targets with
            an empty or ``None`` collection are ignored.
        predicted_incoming: Mapping target -> list of ``(source, weight, extra)``
            tuples; a source may appear several times.

    Returns:
        Average over evaluated targets of ``1 / rank`` of the best-ranked true
        source (``0`` for a target with no true source among its predictions),
        or ``0`` when no target was evaluated.
    """
    per_target_scores = []

    for target, sources in ground_truth.items():
        if not sources:
            continue

        # Highest weight seen per source; first-seen order is kept so ties rank stably.
        strongest = {}
        for candidate, score, _ in predicted_incoming.get(target, []):
            if candidate in strongest and not (score > strongest[candidate]):
                continue
            strongest[candidate] = score

        ordering = sorted(strongest, key=strongest.get, reverse=True)

        first_hit = next(
            (1/position for position, candidate in enumerate(ordering, start=1)
             if candidate in sources),
            0,
        )
        per_target_scores.append(first_hit)

    if not per_target_scores:
        return 0
    return sum(per_target_scores) / len(per_target_scores)

In [None]:
# Hits@N for several cut-offs plus MRR, all against the stored backlinks.
hits1, hits3, hits5, hits10 = (
    compute_hits_at_n(ground_truth_links, predicted_incoming, N=_cutoff)
    for _cutoff in (1, 3, 5, 10)
)
mrr = compute_mrr(ground_truth_links, predicted_incoming)

for _metric_name, _metric_value in (
    ("Hits@1", hits1),
    ("Hits@3", hits3),
    ("Hits@5", hits5),
    ("Hits@10", hits10),
    ("MRR", mrr),
):
    print(f"{_metric_name}: {_metric_value:.4f}")

# Link Induction

## PyTorch Conversions

In [None]:
# Mapping of node name to integer index, following the iteration order of G.nodes()
node_to_idx = {node: i for i, node in enumerate(G.nodes())}

In [None]:
#Load UMAP reduced TF-IDF embeddings:
embedding = np.load("umap_embeddings.npy")
print("UMAP embedding shape:", embedding.shape)

In [None]:
doc_titles = df['title'].values
title_to_idx ={title: i for i, title in enumerate(doc_titles)}

# One-hot encoded cluster labels
encoder = OneHotEncoder(sparse_output=False)
cluster_onehot = encoder.fit_transform(df[['cluster']])

# Standardize the embedding columns before concatenating them with the one-hot block.
scaler = StandardScaler()
scaled_embedding = scaler.fit_transform(embedding)

full_features = np.hstack([scaled_embedding, cluster_onehot])

# Node features for the PyG graph, one row per node in G.nodes() order.
# NOTE(review): edge indices elsewhere come from title_to_idx (df row order); the two
# orders agree only if G's nodes were inserted in df order with unique titles - confirm.
node_features = []
for node in G.nodes():
    idx = title_to_idx.get(node)
    if idx is not None:
        vec = full_features[idx]
        node_features.append(vec)
    else:
        # Unknown node: zero vector of the same width as a real feature row.
        node_features.append(np.zeros(full_features.shape[1]))

node_features_array = np.array(node_features)
x = torch.tensor(node_features_array, dtype=torch.float)

In [None]:
# Parallel lists, one entry per edge of G whose two endpoints are known to title_to_idx.
edge_sources, edge_targets = [], []
edge_weights, edge_sims, edge_subtopics = [], [], []

for src_title, tgt_title, key, attr in G.edges(data=True, keys=True):
    src_idx = title_to_idx.get(src_title)
    tgt_idx = title_to_idx.get(tgt_title)

    # Only edges with both endpoints present are recorded.
    if src_idx is not None and tgt_idx is not None:
        edge_sources.append(src_idx)
        edge_targets.append(tgt_idx)

        # Edge attributes, with fallbacks for edges that lack them.
        edge_weights.append(attr.get('weight',1.0))
        edge_sims.append(attr.get('sim',0.0))
        edge_subtopics.append(attr.get('subtopic_pair', ('NA', 'NA')))

In [None]:
# edge_index has two rows: row 0 holds source indices, row 1 holds target indices.
edge_index = torch.tensor([edge_sources, edge_targets], dtype=torch.long)
# Per-edge attributes, aligned with the columns of edge_index.
edge_weight = torch.tensor(edge_weights, dtype=torch.float)
edge_sim = torch.tensor(edge_sims, dtype=torch.float)

In [None]:
# Vocabulary of all subtopic labels from topic_labels
all_subtopics = set()

# Add subtopics  to vocab
for topics in df['topic_list']:
    all_subtopics.update(topics)
    
subtopic_to_idx = {sub: i for i, sub in enumerate(sorted(all_subtopics))}

# Convert subtopic pairs to pairs of indices
edge_subtopic_idx = torch.tensor(
    [[subtopic_to_idx[s1], subtopic_to_idx[s2]] for s1, s2 in edge_subtopics],
    dtype=torch.long
)

In [None]:
# Bundle node features and the per-edge tensors into one PyG Data object.
data = Data(
    x=x,
    edge_index=edge_index,
    edge_weight=edge_weight,
    edge_sim=edge_sim,
    # index pairs built from each edge's subtopic_pair
    edge_subtopic=edge_subtopic_idx
)

In [None]:
# Save Graph
#torch.save(data,"multidigraph.pt")

## GAT Model

In [None]:
# Sanity check: edge_subtopic should have one row per edge and two columns.
# The comparison is only displayed as the cell output; it does not raise on failure.
num_edges = data.edge_subtopic.shape[0]
data.edge_subtopic.shape == (num_edges,2)

### Train/Test Split

In [None]:
import torch
import numpy as np
import random
from collections import defaultdict
from sklearn.model_selection import train_test_split

def create_stratified_link_split_nodes_disjointed(
    G, df, title_to_idx, similarity_df, subtopic_to_idx, subtopic_edge_dict,
    hard_negative_sim_threshold=0.5,
    test_size=0.2, random_state=42
):
    """Build a node-disjoint train/test split of positive and negative links.

    Nodes (rows of ``df``) are split with stratification on
    ``cluster`` x ``isolated/connected``. A positive edge of ``G`` goes to the
    train set when both endpoints are train nodes, to the test set when both
    are test nodes, and is dropped otherwise. Negative pairs are then sampled
    inside each node set, never colliding with a positive edge (in either
    direction) or with a negative that was already drawn.

    Args:
        G: Graph whose nodes are document titles and whose edges carry
            ``weight``, ``sim`` and ``subtopic_pair`` attributes.
        df: DataFrame with ``title``, ``cluster`` and ``topic_list`` columns;
            node ids are row positions in this frame.
        title_to_idx: Mapping title -> node id.
        similarity_df: Subtopic x subtopic similarity table.
        subtopic_to_idx: Mapping subtopic label -> integer index (unknown labels
            fall back to index 0).
        subtopic_edge_dict: Mapping (source subtopic, target subtopic) -> weight.
        hard_negative_sim_threshold: A sampled pair is a "hard" negative when its
            best subtopic similarity is at least this value, and an "easy"
            negative when it is below half of it.
        test_size: Fraction of nodes assigned to the test set.
        random_state: Seed for ``random``, ``numpy``, ``torch`` and the node split.

    Returns:
        tuple: ``(splits, train_node_ids_set, test_node_ids_set)`` where
        ``splits`` maps ``train_pos``, ``test_pos``, ``train_neg_easy``,
        ``train_neg_hard``, ``test_neg_easy`` and ``test_neg_hard`` to dicts of
        tensors with keys ``edge_index``, ``subtopic``, ``weight`` and ``sim``.
    """
    from collections import Counter

    random.seed(random_state)
    np.random.seed(random_state)
    torch.manual_seed(random_state)

    doc_titles_list = df['title'].tolist()
    all_global_node_ids = list(range(len(doc_titles_list)))

    # Per-title lookups built in one pass instead of scanning df for every node/edge.
    # For duplicate titles the first row wins.
    title_to_cluster = {}
    title_to_topics = {}
    for title, cluster, topics in zip(df['title'], df['cluster'], df['topic_list']):
        title_to_cluster.setdefault(title, cluster)
        title_to_topics.setdefault(title, topics)

    node_connectivity_status = {}
    for node_id in all_global_node_ids:
        node_title = doc_titles_list[node_id]
        if G.degree(node_title) == 0:
            node_connectivity_status[node_id] = 'isolated'
        else:
            node_connectivity_status[node_id] = 'connected'

    isolated_count = sum(1 for status in node_connectivity_status.values() if status == 'isolated')
    connected_count = len(all_global_node_ids) - isolated_count
    print(f"Total Nodes: {len(all_global_node_ids)}")
    print(f"Isolated Nodes: {isolated_count}")
    print(f"Connected Nodes: {connected_count}")
    print(f"Ratio Isolated:Connected = {isolated_count / connected_count:.2f}" if connected_count > 0 else "N/A")
    print(f"--- End Node Connectivity Summary ---")

    raw_stratify_keys = [
        f"cluster_{title_to_cluster[doc_titles_list[nid]]}_"
        f"{node_connectivity_status[nid]}"
        for nid in all_global_node_ids
    ]

    key_counts = Counter(raw_stratify_keys)

    # Strata with fewer than three members are pooled into one catch-all stratum.
    min_members_for_stratify = 3
    final_stratify_keys = []
    for key in raw_stratify_keys:
        if key_counts[key] < min_members_for_stratify:
            final_stratify_keys.append('mis_stratify_category')
        else:
            final_stratify_keys.append(key)

    train_node_ids, test_node_ids = train_test_split(
        all_global_node_ids,
        test_size=test_size,
        stratify=final_stratify_keys,
        random_state=random_state
    )

    train_node_ids_set = set(train_node_ids)
    test_node_ids_set = set(test_node_ids)

    # Check for Node Overlap
    node_overlap_count = len(train_node_ids_set.intersection(test_node_ids_set))
    print(f"\n--- Node Overlap Check ---")
    print(f"Nodes in Training Set: {len(train_node_ids_set)}")
    print(f"Nodes in Test Set: {len(test_node_ids_set)}")
    print(f"Overlap (Train Nodes & Test Nodes): {node_overlap_count} (Expected: 0)")
    assert node_overlap_count == 0, "Node sets overlap! Split is not node-disjoint."
    print(f"--- Node Overlap Check Complete ---")

    train_edges = []
    test_edges = []
    # Every positive (src, tgt) pair of G plus its reverse; negatives must avoid all of them.
    all_pos_edge_pairs_set_global = set()

    for src_title, tgt_title, data in G.edges(data=True):
        src_id = title_to_idx.get(src_title)
        tgt_id = title_to_idx.get(tgt_title)
        if src_id is None or tgt_id is None:
            continue

        s1_name, s2_name = data.get('subtopic_pair', (None, None))
        s_idx1 = subtopic_to_idx.get(s1_name, 0)
        s_idx2 = subtopic_to_idx.get(s2_name, 0)

        edge_data_dict = {
            'src': src_id, 'tgt': tgt_id, 'subtopic' : [s_idx1, s_idx2],
            'weight' : data.get('weight', 0.0), 'sim' : data.get('sim', 0.0),
            'cluster': (title_to_cluster[src_title], title_to_cluster[tgt_title])
        }
        all_pos_edge_pairs_set_global.add((src_id, tgt_id))
        all_pos_edge_pairs_set_global.add((tgt_id, src_id))

        # Edges that straddle the two node sets belong to neither split.
        if src_id in train_node_ids_set and tgt_id in train_node_ids_set:
            train_edges.append(edge_data_dict)
        elif src_id in test_node_ids_set and tgt_id in test_node_ids_set:
            test_edges.append(edge_data_dict)

    def extract_tensor_dict(edges_list_of_dicts):
        # Convert a list of edge dicts into tensors; edge_index is 2 x E.
        if not edges_list_of_dicts:
            return {
                'edge_index': torch.empty((2, 0), dtype=torch.long), 'subtopic' : torch.empty((0, 2), dtype=torch.long),
                'weight': torch.empty(0, dtype=torch.float), 'sim': torch.empty(0, dtype=torch.float),
            }
        return {
            'edge_index': torch.tensor([[e['src'], e['tgt']] for e in edges_list_of_dicts], dtype=torch.long).T,
            'subtopic'  : torch.tensor([e['subtopic'] for e in edges_list_of_dicts], dtype=torch.long),
            'weight'    : torch.tensor([e['weight'] for e in edges_list_of_dicts], dtype=torch.float),
            'sim'       : torch.tensor([e['sim'] for e in edges_list_of_dicts], dtype=torch.float),
        }

    # Samples node pairs from the given pools and avoids every pair in current_avoid_set.
    def generate_negatives(num_samples_needed, hard_type, source_node_pool, target_node_pool, current_avoid_set):
        negatives_list_of_dicts = []
        attempts_count = 0
        max_attempts_for_gen = num_samples_needed * 100

        source_node_pool_list = list(source_node_pool)
        target_node_pool_list = list(target_node_pool)

        while len(negatives_list_of_dicts) < num_samples_needed and attempts_count < max_attempts_for_gen:
            i = random.choice(source_node_pool_list)
            j = random.choice(target_node_pool_list)

            if i == j: # Skip self-loops
                attempts_count += 1
                continue

            if (i, j) in current_avoid_set or (j, i) in current_avoid_set:
                attempts_count += 1
                continue

            subs_i = title_to_topics[doc_titles_list[i]]
            subs_j = title_to_topics[doc_titles_list[j]]

            # Most similar cross pair of subtopics between the two documents.
            best_sim = -1.0
            best_pair_names = (None, None)
            if subs_i and subs_j:
                for s1_name in subs_i:
                    for s2_name in subs_j:
                        sim_val_from_df = similarity_df.get(s1_name, {}).get(s2_name)
                        if sim_val_from_df is not None:
                            if sim_val_from_df > best_sim:
                                best_sim = sim_val_from_df
                                best_pair_names = (s1_name, s2_name)

            if best_sim == -1.0:
                attempts_count += 1
                continue
            # NOTE(review): a pair rejected by the similarity band is not counted as an
            # attempt, so this loop is not bounded by max_attempts_for_gen when the pool
            # holds too few eligible pairs - confirm whether that is intended.
            if (hard_type and best_sim < hard_negative_sim_threshold) or \
               (not hard_type and best_sim >= hard_negative_sim_threshold/2): continue

            weight = subtopic_edge_dict.get(best_pair_names, 0.0)

            negatives_list_of_dicts.append({
                'src': i, 'tgt': j,
                'subtopic': [subtopic_to_idx.get(best_pair_names[0], 0), subtopic_to_idx.get(best_pair_names[1], 0)],
                'weight': weight, 'sim': best_sim
            })
            # Reserve the pair in both directions so no later call can draw it again.
            current_avoid_set.add((i,j))
            current_avoid_set.add((j,i))
            attempts_count += 1

        print(f"Generated {len(negatives_list_of_dicts)} negatives for is_hard={hard_type} after {attempts_count} attempts.")
        return negatives_list_of_dicts

    # One avoid set is shared by all four calls, so train and test negatives never overlap.
    current_global_avoid_set = all_pos_edge_pairs_set_global.copy() # Start with all existing graph edges

    # Generate TRAIN Easy Negatives
    train_neg_edges_easy_list = generate_negatives(len(train_edges), hard_type=False,
                                                   source_node_pool=train_node_ids_set,
                                                   target_node_pool=train_node_ids_set,
                                                   current_avoid_set=current_global_avoid_set)

    # Generate TRAIN Hard Negatives
    train_neg_edges_hard_list = generate_negatives(len(train_edges), hard_type=True,
                                                   source_node_pool=train_node_ids_set,
                                                   target_node_pool=train_node_ids_set,
                                                   current_avoid_set=current_global_avoid_set)

    # Generate TEST Easy Negatives (sampled among test nodes only, keeping the split node-disjoint)
    test_neg_edges_easy_list = generate_negatives(len(test_edges), hard_type=False,
                                                  source_node_pool=test_node_ids_set,
                                                  target_node_pool=test_node_ids_set,
                                                  current_avoid_set=current_global_avoid_set)

    # Generate TEST Hard Negatives (sampled among test nodes only)
    test_neg_edges_hard_list = generate_negatives(len(test_edges), hard_type=True,
                                                  source_node_pool=test_node_ids_set,
                                                  target_node_pool=test_node_ids_set,
                                                  current_avoid_set=current_global_avoid_set)

    return {
        'train_pos': extract_tensor_dict(train_edges),
        'test_pos': extract_tensor_dict(test_edges),
        'train_neg_easy': extract_tensor_dict(train_neg_edges_easy_list),
        'train_neg_hard': extract_tensor_dict(train_neg_edges_hard_list),
        'test_neg_easy' : extract_tensor_dict(test_neg_edges_easy_list),
        'test_neg_hard' : extract_tensor_dict(test_neg_edges_hard_list)
    }, train_node_ids_set, test_node_ids_set

In [None]:
# Build the node-disjoint split. Returns the tensor dicts plus the two node-id sets.
splits, train_node_ids_set, test_node_ids_set = create_stratified_link_split_nodes_disjointed(
    G=G,
    df=df,
    similarity_df=similarity_df,
    subtopic_to_idx=subtopic_to_idx,
    subtopic_edge_dict=subtopic_edge_dict,
    title_to_idx=title_to_idx,
    test_size=0.2,
    # overrides the function's default of 0.5
    hard_negative_sim_threshold=0.45,
    random_state=42
)

In [None]:
splits


def _symmetric_pair_set(edge_index):
    """Return the (src, tgt) pairs of an edge_index tensor together with every reversed pair."""
    forward_pairs = {tuple(edge) for edge in edge_index.T.cpu().tolist()}
    return forward_pairs | {(tgt, src) for src, tgt in forward_pairs}


# Direction-agnostic pair sets, one per split.
train_pos_set_check = _symmetric_pair_set(splits['train_pos']['edge_index'])
test_pos_set_check = _symmetric_pair_set(splits['test_pos']['edge_index'])
train_neg_easy_set_check = _symmetric_pair_set(splits['train_neg_easy']['edge_index'])
train_neg_hard_set_check = _symmetric_pair_set(splits['train_neg_hard']['edge_index'])
test_neg_easy_set_check = _symmetric_pair_set(splits['test_neg_easy']['edge_index'])
test_neg_hard_set_check = _symmetric_pair_set(splits['test_neg_hard']['edge_index'])


print("\n Positive Edge Leakage")
overlap_train_pos_test_pos = len(train_pos_set_check & test_pos_set_check)
print(f"Train Pos vs Test Pos overlap: {overlap_train_pos_test_pos}")

print("\nNegative Edge Leakage")
overlap_train_neg_easy_test_neg_easy = len(train_neg_easy_set_check & test_neg_easy_set_check)
print(f"Train Neg Easy vs Test Neg Easy overlap: {overlap_train_neg_easy_test_neg_easy}")

overlap_train_neg_hard_test_neg_hard = len(train_neg_hard_set_check & test_neg_hard_set_check)
print(f"Train Neg Hard vs Test Neg Hard overlap: {overlap_train_neg_hard_test_neg_hard}")

print("\nCross-Split Leakage (e.g., test in training negs, or vice-versa")
# Test positives that were sampled as training negatives
overlap_test_pos_train_neg_easy = len(test_pos_set_check & train_neg_easy_set_check)
print(f"Test Pos vs Train Neg Easy overlap: {overlap_test_pos_train_neg_easy}")

overlap_test_pos_train_neg_hard = len(test_pos_set_check & train_neg_hard_set_check)
print(f"Test Pos vs Train Neg Hard overlap: {overlap_test_pos_train_neg_hard}")

# Test negatives that are training positives
overlap_test_neg_easy_train_pos = len(test_neg_easy_set_check & train_pos_set_check)
print(f"Test Neg Easy vs Train Pos overlap: {overlap_test_neg_easy_train_pos}")

overlap_test_neg_hard_train_pos = len(test_neg_hard_set_check & train_pos_set_check)
print(f"Test Neg Hard vs Train Pos overlap: {overlap_test_neg_hard_train_pos}")

# Any test negative that also appears among the training negatives
all_train_neg_set_check = train_neg_easy_set_check | train_neg_hard_set_check
all_test_neg_set_check = test_neg_easy_set_check | test_neg_hard_set_check
overlap_all_train_neg_all_test_neg = len(all_train_neg_set_check & all_test_neg_set_check)
print(f"All Train Neg vs All Test Neg overlap: {overlap_all_train_neg_all_test_neg}")

In [None]:
import torch
import numpy as np 

print("\n--- Verifying Node-Disjoint Negative Edge Integrity ---")

# Flatten the easy + hard training negatives into one list of [src, tgt] pairs
# (edge_index is transposed so that each row is a single edge).
all_train_neg_edges = [
    edge
    for split_name in ('train_neg_easy', 'train_neg_hard')
    for edge in splits[split_name]['edge_index'].T.cpu().tolist()
]

# Same flattening for the test negatives.
all_test_neg_edges = [
    edge
    for split_name in ('test_neg_easy', 'test_neg_hard')
    for edge in splits[split_name]['edge_index'].T.cpu().tolist()
]

# Node-id sets that the edges are checked against.
train_nodes, test_nodes = train_node_ids_set, test_node_ids_set

# Function to check for cross-split edges
def check_cross_split_edges(edge_list, source_set, target_set, description):
    """Report edges whose two endpoints fall on opposite sides of a node split.

    An edge is "cross-split" when one endpoint is in ``source_set`` and the
    other is in ``target_set`` (in either direction).

    Args:
        edge_list: Iterable of ``(src, tgt)`` pairs (tuples or 2-element lists).
        source_set: Container of node ids for one side of the split.
        target_set: Container of node ids for the other side of the split.
        description: Label used as the prefix of the printed summary.

    Returns:
        List of the offending ``(src, tgt)`` tuples, in input order. The
        summary (and up to five samples) is also printed.
    """
    # The membership tests use the arguments, not notebook globals, so the
    # function gives the same answer regardless of kernel state.
    cross_split_edges_found = [
        (src, tgt)
        for src, tgt in edge_list
        if (src in source_set and tgt in target_set)
        or (src in target_set and tgt in source_set)
    ]

    print(f"{description}: Found {len(cross_split_edges_found)} cross-split edges (Expected: 0)")
    if cross_split_edges_found:
        print(f"  Sample cross-split edges: {cross_split_edges_found[:5]}")

    return cross_split_edges_found


# Check Training Negatives
# Both calls pass the same (train_nodes, test_nodes) pair: the question in each
# case is whether an edge joins a train node to a test node.
print("\n--- Checking Training Negative Edges ---")
check_cross_split_edges(all_train_neg_edges, train_nodes, test_nodes, "All Train Negatives")

# Check Test Negatives
print("\n--- Checking Test Negative Edges ---")
check_cross_split_edges(all_test_neg_edges, train_nodes, test_nodes, "All Test Negatives")

print("\n--- Node-Disjoint Negative Edge Integrity Check Complete ---")

#### Early Stopping

In [None]:
class EarlyStopping:
    """Patience-based early-stopping tracker for a loss that should decrease.

    Call the instance once per evaluation with the current validation loss.
    ``stop`` becomes True once ``patience`` consecutive calls have failed to
    strictly improve on the best loss seen so far.
    """

    def __init__(self, patience=10):
        self.patience = patience
        self.best_score = None   # lowest loss observed so far
        self.counter = 0         # consecutive non-improving calls
        self.stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``stop``."""
        improved = self.best_score is None or val_loss < self.best_score
        if improved:
            # New best: remember it and restart the patience window.
            self.best_score = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.stop = True

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import roc_auc_score, precision_score, recall_score, average_precision_score
import matplotlib.pyplot as plt

def train_gat_model(
    model,
    data,
    splits,
    num_epochs=1000,
    hard_neg_start_epoch=30,
    patience=10,
    min_epochs=40,
    lr=0.001,
    neg_cycle_length_epochs=10,
    device='cpu'
):
    """Train an edge-scoring model with an easy/hard negative curriculum.

    Each epoch is one full-batch step on ``train_pos`` plus one negative pool.
    Epochs before ``hard_neg_start_epoch`` use the easy negatives. From
    ``hard_neg_start_epoch`` onwards the pool alternates every
    ``neg_cycle_length_epochs`` epochs, starting with the HARD negatives.

    NOTE(review): the previous implementation started the alternation with
    another *easy* cycle, so hard negatives were first seen at
    ``hard_neg_start_epoch + neg_cycle_length_epochs``. That contradicted the
    parameter name and has been corrected here.

    NOTE(review): "validation" is computed on the ``test_*`` splits, and both
    checkpoint selection and early stopping use the hard-negative loss on
    that data, so it is not an untouched hold-out set.

    Args:
        model: Module called as ``model(x, edge_index, subtopic, weight, sim)``
            and returning one logit per edge.
        data: Object exposing node features as ``data.x``.
        splits: Dict with keys ``train_pos``, ``test_pos``, ``train_neg_easy``,
            ``train_neg_hard``, ``test_neg_easy``, ``test_neg_hard``; each value
            is a dict with ``edge_index``, ``subtopic``, ``weight`` and ``sim``.
        num_epochs: Maximum number of epochs.
        hard_neg_start_epoch: First epoch (0-based) trained on hard negatives.
        patience: Early-stopping patience, in epochs.
        min_epochs: Early stopping is only consulted from this epoch on.
        lr: Adam learning rate.
        neg_cycle_length_epochs: Length of each hard/easy cycle.
        device: Torch device string.

    Returns:
        ``model`` with the weights of the epoch that had the lowest
        hard-negative validation loss (written to ``best_model.pt``). If no
        checkpoint was written during this call, the model is returned with
        its current weights rather than loading a file left by an earlier run.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()
    early_stopper = EarlyStopping(patience)
    checkpoint_path = 'best_model.pt'
    best_val_loss = float('inf')
    checkpoint_saved = False

    # Tracking
    history = {
        'train_loss'    : [],
        'val_easy_loss' : [],
        'val_easy_auc'  : [],
        'val_hard_loss' : [],
        'val_hard_auc'  : []
    }

    # Combine positive & negative edges into one labelled batch
    def prepare_batch(pos, neg):
        edge_index = torch.cat([pos['edge_index'], neg['edge_index']], dim=1).to(device)
        edge_subtopic = torch.cat([pos['subtopic'], neg['subtopic']], dim=0).to(device)
        edge_weight = torch.cat([pos['weight'], neg['weight']], dim=0).to(device)
        edge_sim = torch.cat([pos['sim'], neg['sim']], dim=0).to(device)

        labels = torch.cat([
            torch.ones(pos['edge_index'].shape[1], dtype=torch.float),
            torch.zeros(neg['edge_index'].shape[1], dtype=torch.float)
        ]).to(device)

        return edge_index, edge_subtopic, edge_weight, edge_sim, labels

    # Everything below is loop-invariant, so build it (and move it to the
    # device) once instead of once per epoch.
    node_feats = data.x.to(device)
    train_batches = {
        'easy': prepare_batch(splits['train_pos'], splits['train_neg_easy']),
        'hard': prepare_batch(splits['train_pos'], splits['train_neg_hard']),
    }
    val_batches = {
        'Easy': prepare_batch(splits['test_pos'], splits['test_neg_easy']),
        'Hard': prepare_batch(splits['test_pos'], splits['test_neg_hard']),
    }

    for epoch in range(num_epochs):
        model.train()

        # Curriculum: easy negatives first, then alternate hard/easy cycles.
        if epoch < hard_neg_start_epoch:
            neg_kind = 'easy'
        else:
            cycle_number = (epoch - hard_neg_start_epoch) // neg_cycle_length_epochs
            neg_kind = 'hard' if cycle_number % 2 == 0 else 'easy'

        edge_index, subtopic, weight, sim, labels = train_batches[neg_kind]

        preds = model(node_feats, edge_index, subtopic, weight, sim)
        loss = loss_fn(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss = loss.item()

        # Validation
        model.eval()
        val_results = {}
        with torch.no_grad():
            for neg_type, batch in val_batches.items():
                val_edge_index, val_subtopic, val_weight, val_sim, val_labels = batch
                val_logits = model(node_feats, val_edge_index, val_subtopic, val_weight, val_sim)

                val_labels_np = val_labels.cpu().numpy()
                val_probs = torch.sigmoid(val_logits).cpu().numpy()
                val_bin_preds = (val_probs >= 0.5).astype(int)

                val_results[neg_type] = {
                    'loss' : loss_fn(val_logits, val_labels).item(),
                    # AUC is rank-based, so raw logits are fine here.
                    'auc' : roc_auc_score(val_labels_np, val_logits.cpu().numpy()),
                    'precision' : precision_score(val_labels_np, val_bin_preds),
                    'recall' : recall_score(val_labels_np, val_bin_preds),
                    # average_precision_score is average precision (area under
                    # the PR curve), not mean reciprocal rank.
                    'avg_precision' : average_precision_score(val_labels_np, val_probs)
                }

        print("="*60)
        print(f"Epoch {epoch+1}")
        print(f"Train Loss       : {train_loss:.4f}")

        for neg_type in ('Easy', 'Hard'):
            metrics = val_results[neg_type]
            print(f"\n[Validation - {neg_type} Negatives]")
            print(f"  Loss         : {metrics['loss']:.4f}")
            print(f"  AUC          : {metrics['auc']:.4f}")
            print(f"  Precision    : {metrics['precision']:.4f}")
            print(f"  Recall       : {metrics['recall']:.4f}")
            print(f"  Avg Precision: {metrics['avg_precision']:.4f}")
        print("=" * 60)

        # Save Best Model (selected on the hard-negative validation loss)
        if val_results['Hard']['loss'] < best_val_loss:
            best_val_loss = val_results['Hard']['loss']
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        # Update History
        history['train_loss'].append(train_loss)
        history['val_easy_loss'].append(val_results['Easy']['loss'])
        history['val_easy_auc'].append(val_results['Easy']['auc'])
        history['val_hard_loss'].append(val_results['Hard']['loss'])
        history['val_hard_auc'].append(val_results['Hard']['auc'])

        # Early stopping based on validation with hard negatives
        if epoch >= min_epochs:
            early_stopper(val_results['Hard']['loss'])
            if early_stopper.stop:
                print("Early stopping triggered.")
                break

    # Plot Training Curve
    plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    plt.plot(history['train_loss'], label='Train loss')
    plt.plot(history['val_easy_loss'], label='Val Easy Loss')
    plt.plot(history['val_hard_loss'], label = 'Val Hard Loss')
    plt.title('Loss Curve')
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend()

    plt.subplot(1,2,2)
    plt.plot(history['val_easy_auc'], label='Val Easy AUC')
    plt.plot(history['val_hard_auc'], label='Val Hard AUC')
    plt.title('Validation AUC')
    plt.xlabel('Epoch'); plt.ylabel('AUC'); plt.legend()

    plt.tight_layout()
    plt.show()

    # Return best model. Only reload when this call actually wrote the
    # checkpoint; otherwise a file from a previous run (or none at all) would
    # be picked up.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class SubtopicGAT(torch.nn.Module):
    """Two-layer GAT encoder followed by an MLP that scores edges.

    Each edge is described by the embeddings of its two endpoints, the
    embeddings of its two subtopic indices, and two scalars (edge weight and
    edge similarity). The decoder maps that vector to a single logit.
    """

    def __init__(self, in_channels, hidden_channels, num_subtopics, edge_emb_dim=16, heads=2, dropout=0.5):
        super().__init__()
        self.dropout = dropout

        # Lookup table shared by both subtopics of an edge.
        self.subtopic_embed = nn.Embedding(num_subtopics, edge_emb_dim)

        # Encoder: multi-head layer (heads concatenated), then a single-head layer.
        self.gat1 = GATConv(in_channels, hidden_channels, heads=heads, concat=True, dropout=dropout)
        self.gat2 = GATConv(hidden_channels * heads, hidden_channels, heads=1, concat=False, dropout=dropout)

        # Decoder input = [h_src | h_tgt | subtopic_1 | subtopic_2 | weight | sim].
        decoder_in_dim = 2 * hidden_channels + 2 * edge_emb_dim + 2
        self.edge_decoder = nn.Sequential(
            nn.Linear(decoder_in_dim, 64),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 1)
        )

    def forward(self, x, edge_index, edge_subtopic, edge_weight, edge_sim):
        """Return one logit per column of ``edge_index``.

        The same ``edge_index`` is used for GAT message passing and as the set
        of edges to score.
        """
        # Node encoder
        hidden = F.relu(self.gat1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        node_emb = self.gat2(hidden, edge_index)

        # Endpoint embeddings for every edge
        src, tgt = edge_index
        endpoint_feats = (node_emb[src], node_emb[tgt])

        # Embeddings of the two subtopic indices stored per edge
        subtopic_feats = (
            self.subtopic_embed(edge_subtopic[:, 0]),
            self.subtopic_embed(edge_subtopic[:, 1]),
        )

        # Scalars become column vectors so they can be concatenated
        scalar_feats = (edge_weight.unsqueeze(1), edge_sim.unsqueeze(1))

        edge_feats = torch.cat([*endpoint_feats, *subtopic_feats, *scalar_feats], dim=1)
        logits = self.edge_decoder(edge_feats)
        return logits.squeeze(1)

In [None]:
# Model dimensions: the input width comes from the node-feature matrix and the
# embedding-table size from the subtopic vocabulary.
in_channels = data.x.shape[1] 
hidden_channels = 64         
num_subtopics   = len(subtopic_to_idx)

# dropout is left at the class default.
model = SubtopicGAT(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    num_subtopics=num_subtopics,
    edge_emb_dim=16,
    heads=2
)

In [None]:
# Train on CPU. min_epochs and neg_cycle_length_epochs are not passed, so the
# function defaults apply.
# train_gat_model calls model.to(device) and load_state_dict on the object it
# is given, so `trained_model` and `model` presumably refer to the same
# module -- confirm if they are used separately later.
trained_model = train_gat_model(
    model=model,
    data=data,
    splits=splits,
    num_epochs=1000,
    hard_neg_start_epoch = 30,
    patience = 20,
    lr = .001,
    device='cpu'
)

## Predictions & Queries

In [None]:
# Load Models
# NOTE(review): this reads 'best_model_final.pt', whereas train_gat_model
# writes 'best_model.pt' -- presumably the checkpoint was renamed by hand;
# confirm the file matches the architecture of `model`.
# NOTE(security): torch.load and joblib.load both unpickle their input; only
# load files from a trusted source.
model_path = 'best_model_final.pt'

state_dict = torch.load(model_path)

model.load_state_dict(state_dict)


import joblib

#TFIDF
# Fitted TF-IDF vectorizer, used at query time to vectorise the query.
tfidf = joblib.load('fitted_tfidf_vectorizer.joblib')

#UMAP
# Fitted UMAP reducer, used to project TF-IDF vectors.
umap_model = joblib.load('fitted_umap_model.joblib')

In [None]:
from scipy.sparse import load_npz, csr_matrix

# Load the sparse matrix of cluster centroids; rows are indexed by cluster id.
cluster_centroids = load_npz("cluster_centroids.npz")

# Push every centroid through the same UMAP projection and scaler as the
# query embedding it is later compared with. Centroids are transformed one
# row at a time.
umap_reduced_cluster_centroids = {}
for cluster_idx in range(cluster_centroids.shape[0]):
    centroid_dense = cluster_centroids[cluster_idx, :].toarray()
    centroid_umap = umap_model.transform(centroid_dense)
    umap_reduced_cluster_centroids[cluster_idx] = scaler.transform(centroid_umap)
    

In [None]:
# One fitted BERTopic model per cluster, keyed by integer cluster id.
# NOTE(review): the cluster count (9) is hard-coded; presumably it matches the
# number of rows in cluster_centroids -- confirm.
# This loop also rebinds the notebook-level `model_path` variable.
fitted_bertopic_models = {}
for cluster_id in range(9):
    model_path = f"Models/bertopic_cluster_{cluster_id}.model"
    
    fitted_bertopic_models[cluster_id] = BERTopic.load(model_path)

In [None]:
import spacy
import re

# Shared spaCy pipeline used by the query-preprocessing helpers below.
nlp = spacy.load("en_core_web_lg")

# Analyzing and POS tagging to match tfidf preprocessing
def analyze_query(doc_query: "spacy.tokens.Doc"):
    """Extract content-word lemmas and named-entity strings from a parsed query.

    Args:
        doc_query: A spaCy ``Doc`` (anything iterable over tokens exposing
            ``is_space``, ``is_punct``, ``pos_`` and ``lemma_``, with an
            ``ents`` attribute).

    Returns:
        ``(lemmas, entities)``: the lemmas of every noun, proper noun,
        adjective, adverb, verb and auxiliary in document order, and the text
        of every named entity.
    """
    # spaCy's coarse tag for common nouns is "NOUN"; the previous "Non" never
    # matched, so every common noun was silently dropped.
    kept_pos = {'NOUN', 'PROPN', 'ADJ', 'ADV', 'VERB', 'AUX'}

    query_nav_lemmas = [
        token.lemma_
        for token in doc_query
        if not token.is_space and not token.is_punct and token.pos_ in kept_pos
    ]
    query_entities_strings = [str(ent) for ent in doc_query.ents]

    return query_nav_lemmas, query_entities_strings

# To match preprocessing
def resentece_query(words_list_or_entity_list):
    """Join the non-blank strings of the input with ``|`` as separator."""
    kept = (item for item in words_list_or_entity_list if item.strip())
    return "|".join(kept)

def preprocess_lemmatize_query(query_text: str, STOP_WORDS_SET: set) -> str:
    """Turn raw query text into a space-separated string of cleaned terms.

    The text is lower-cased and parsed with ``nlp``. Content-word lemmas and
    entity strings are extracted and pipe-joined, then split again on ``|``
    or ``#``. Each fragment is cleaned, and fragments that are at most one
    character long or that appear in ``STOP_WORDS_SET`` are dropped.
    """
    parsed = nlp(query_text.lower())
    lemmas, entities = analyze_query(parsed)

    # Lemmas first, entities second, all pipe-delimited.
    combined_text = f"{resentece_query(lemmas) or ''}|{resentece_query(entities) or ''}"

    def _clean(fragment):
        # Drop list markers such as "a." and stand-alone single letters.
        fragment = re.sub(r'\b[A-Za-z]\.', '', fragment)
        fragment = re.sub(r'\b[A-Za-z]\b', '', fragment)
        # Normalise non-breaking / zero-width / narrow spaces to a plain space.
        fragment = re.sub(r'[\xa0\u200b\u202f]', ' ', fragment)
        return fragment.strip()

    cleaned_fragments = (_clean(piece) for piece in re.split(r'\||\#', combined_text))
    return ' '.join(
        fragment
        for fragment in cleaned_fragments
        if len(fragment) > 1 and fragment not in STOP_WORDS_SET
    )

In [None]:
def get_query_node_feature_vector(
    query_text: str,
    fitted_tfidf_vectorizer: TfidfVectorizer,
    fitted_umap_model: UMAP,
    fitted_scaler_node_features: StandardScaler,
    fitted_cluster_encoder: OneHotEncoder,
    df_global_for_cluster_dim: pd.DataFrame,
    STOP_WORDS_SET: set,
) -> np.ndarray:
    """Build the node-feature row for a free-text query.

    Pipeline: preprocess/lemmatize -> TF-IDF -> UMAP -> standard scaling, then
    append an all-zero block as wide as the cluster one-hot encoding (the
    query's cluster is not encoded here).

    Args:
        query_text: Raw query text.
        fitted_tfidf_vectorizer: Fitted vectorizer; only ``transform`` is used.
        fitted_umap_model: Fitted reducer; only ``transform`` is used.
        fitted_scaler_node_features: Fitted scaler; only ``transform`` is used.
        fitted_cluster_encoder: Fitted encoder; only ``categories_[0].size``
            is read, to size the zero block.
        df_global_for_cluster_dim: Unused; kept for interface compatibility.
        STOP_WORDS_SET: Stop words removed during preprocessing.

    Returns:
        Array of shape ``(1, umap_dim + n_cluster_categories)``.
    """
    # Preprocess & Lemmatize
    processed_query_string = preprocess_lemmatize_query(query_text, STOP_WORDS_SET)

    # TF-IDF vectorize the *preprocessed* text. Previously the raw query was
    # vectorized and the preprocessing result was thrown away.
    query_tfidf_vector = fitted_tfidf_vectorizer.transform([processed_query_string])

    # UMAP Transform
    query_umap_embedding = fitted_umap_model.transform(query_tfidf_vector)

    # Standardize
    query_scaled_embedding = fitted_scaler_node_features.transform(query_umap_embedding)

    # Cluster One-Hot: left as zeros for the query node.
    cluster_dim = fitted_cluster_encoder.categories_[0].size
    query_cluster_onehot = np.zeros((1, cluster_dim))

    # Concatenate
    return np.hstack([query_scaled_embedding, query_cluster_onehot])

In [None]:
def predict_related_documents(
    query_text: str,
    trained_model: torch.nn.Module,
    global_data_obj: Data,
    original_G_nx: nx.MultiDiGraph,
    df_global: pd.DataFrame,
    title_to_idx_global: dict,
    doc_titles_list: list,
    subtopic_to_idx_global: dict,
    similarity_df_global: pd.DataFrame,
    subtopic_edge_dict_global: dict,
    fitted_tfidf_vectorizer: TfidfVectorizer,
    fitted_umap_model: UMAP,
    fitted_scaler_node_features: StandardScaler,
    fitted_cluster_encoder: OneHotEncoder,
    umap_reduced_cluster_centroids: dict,
    STOP_WORDS: set,
    fitted_bertopic_models: dict,
    #global_subtopic_centroids_umap: dict,
    top_k_results: int = 10,
    prediction_probability_threshold: float = 0.5,
    include_neighbors: bool =True,
    max_neighbors_per_doc: int = 5,
    device: str = 'cpu',
    query_sim_threshold: float = 0.5
):
    """Rank corpus documents against a free-text query via link prediction.

    Steps:
      1. Build the query's node features and pick the cluster whose centroid
         is most cosine-similar to the query embedding.
      2. Ask that cluster's BERTopic model for a subtopic; return ``[]`` if
         none can be assigned.
      3. Add a temporary query node with one candidate edge to every document
         of the cluster, run the model over existing + candidate edges, and
         keep candidates whose probability reaches
         ``prediction_probability_threshold``.
      4. If ``include_neighbors`` is set and fewer than ``top_k_results``
         documents were kept, pad with graph neighbours of those documents
         (at most ``max_neighbors_per_doc`` per document, probability 0.0).

    Args:
        query_text: Raw query text.
        trained_model: Model called with keyword arguments ``x``,
            ``edge_index``, ``edge_subtopic``, ``edge_weight``, ``edge_sim``.
        global_data_obj: Graph data exposing ``x``, ``edge_index``,
            ``edge_weight``, ``edge_sim``, ``edge_subtopic``, ``num_nodes``.
        original_G_nx: Directed graph keyed by document title; read only.
        df_global: Frame with ``title``, ``cluster`` and ``topic_list`` columns.
        title_to_idx_global: Title -> node id.
        doc_titles_list: Node id -> title.
        subtopic_to_idx_global: Subtopic name -> embedding index.
        similarity_df_global: Subtopic similarity lookup, read as
            ``.get(query_topic, {}).get(doc_topic)``.
        subtopic_edge_dict_global: ``(subtopic, subtopic)`` -> edge weight.
        fitted_tfidf_vectorizer, fitted_umap_model,
        fitted_scaler_node_features, fitted_cluster_encoder: Fitted
            transformers forwarded to ``get_query_node_feature_vector``.
        umap_reduced_cluster_centroids: Cluster id -> centroid embedding.
        STOP_WORDS: Stop words used during query preprocessing.
        fitted_bertopic_models: Cluster id -> BERTopic model.
        top_k_results: Maximum number of results printed and returned.
        prediction_probability_threshold: Minimum probability for a direct hit.
        include_neighbors: Whether to pad with graph neighbours.
        max_neighbors_per_doc: Cap on neighbours (in + out) taken per document.
        device: Torch device string.
        query_sim_threshold: Currently unused; kept for interface compatibility.

    Returns:
        Up to ``top_k_results`` dicts with keys ``doc_title``,
        ``prediction_probability``, ``node_id`` and ``source_type``, sorted by
        probability (descending). Returns ``[]`` when no subtopic or no
        candidate document is available.
    """
    trained_model.eval()
    trained_model.to(device)

    # The query becomes one extra node appended after the existing ones.
    query_node_id = global_data_obj.num_nodes

    query_node_features_np = get_query_node_feature_vector(
        query_text=query_text,
        fitted_tfidf_vectorizer=fitted_tfidf_vectorizer,
        fitted_umap_model=fitted_umap_model,
        fitted_scaler_node_features=fitted_scaler_node_features,
        fitted_cluster_encoder=fitted_cluster_encoder,
        df_global_for_cluster_dim=df_global,  # was the notebook-global `df`
        STOP_WORDS_SET=STOP_WORDS
    )
    query_node_features_tensor = torch.tensor(query_node_features_np, dtype=torch.float).to(device)

    # The leading columns are the scaled UMAP embedding (the rest is the
    # zeroed cluster one-hot block).
    query_scaled_embedding_np = query_node_features_np[:, :fitted_umap_model.n_components]

    # Assign query to cluster via centroid similarity
    best_cluster = None
    highest_similarity_to_centroid = -1.0
    for cluster_id, centroid_embedding_umap in umap_reduced_cluster_centroids.items():
        similarity = cosine_similarity(query_scaled_embedding_np, centroid_embedding_umap)[0][0]
        if similarity > highest_similarity_to_centroid:
            highest_similarity_to_centroid = similarity
            best_cluster = cluster_id

    print(f"Query '{query_text}' is most similar to cluster: {best_cluster} (Similarity: {highest_similarity_to_centroid:.4f})")

    # Assign a subtopic with the chosen cluster's BERTopic model
    query_subtopics_list = []
    if best_cluster is not None and fitted_bertopic_models.get(best_cluster) is not None:
        bertopic_model_for_query = fitted_bertopic_models[best_cluster]
        query_topic_ids, _ = bertopic_model_for_query.transform([query_text])

        # len() works for both lists and arrays of topic ids.
        if len(query_topic_ids) > 0 and query_topic_ids[0] != -1:
            full_subtopic_name = f"{best_cluster}_{query_topic_ids[0]}"
            if full_subtopic_name in subtopic_to_idx_global:
                query_subtopics_list.append(full_subtopic_name)
            else:
                print(f"Warning: BERTopic topic '{full_subtopic_name}' not found in global subtopic_to_idx. Skipping.")
        else:
            print("No specific topic assigned by BERTopic for this query (id -1)")
    else:
        print("Warning: No BERTopic model found for predicted cluster")

    # Final check before candidate edges
    if not query_subtopics_list:
        print("Query could not be assigned to any meaningful subtopics, returning empty list")
        return []

    # Generate candidate edges: only documents inside the predicted cluster
    target_docs_in_cluster_df = df_global[df_global['cluster'] == best_cluster]
    filtered_target_doc_ids = [title_to_idx_global[title] for title in target_docs_in_cluster_df['title']]

    if not filtered_target_doc_ids:
        print(f"No documents found in predicted cluster {best_cluster}")
        return []

    # Title -> topic list, built once (first occurrence wins) instead of
    # scanning the whole frame for every candidate.
    topic_lists_by_title = {}
    for title, topics in zip(df_global['title'], df_global['topic_list']):
        topic_lists_by_title.setdefault(title, topics)

    candidate_edges_sources = []
    candidate_edges_targets = []
    candidate_edges_weights = []
    candidate_edge_sims = []
    candidate_edge_subtopic_indices = []

    for target_doc_id in filtered_target_doc_ids:
        target_doc_title = doc_titles_list[target_doc_id]
        target_doc_subtopics_list = topic_lists_by_title[target_doc_title]

        best_sim_for_query_target = -1.0
        best_subtopic_pair_for_query_target = (None, None)

        # Find best subtopic pair between the query's subtopics and the target document's
        if query_subtopics_list and target_doc_subtopics_list:
            for qs_query_topic in query_subtopics_list:
                for ts_doc_topic in target_doc_subtopics_list:
                    sim_val = similarity_df_global.get(qs_query_topic, {}).get(ts_doc_topic)
                    if sim_val is not None and sim_val > best_sim_for_query_target:
                        best_sim_for_query_target = sim_val
                        best_subtopic_pair_for_query_target = (qs_query_topic, ts_doc_topic)

        # No usable similarity -> fall back to 0.0 (and subtopic index 0 below).
        sim_val_final = best_sim_for_query_target if best_sim_for_query_target >= 0 else 0.0
        weight_val = subtopic_edge_dict_global.get(best_subtopic_pair_for_query_target, 0.0)

        candidate_edges_sources.append(query_node_id)
        candidate_edges_targets.append(target_doc_id)
        candidate_edges_weights.append(weight_val)
        candidate_edge_sims.append(sim_val_final)
        candidate_edge_subtopic_indices.append([
            subtopic_to_idx_global.get(best_subtopic_pair_for_query_target[0], 0),
            subtopic_to_idx_global.get(best_subtopic_pair_for_query_target[1], 0)
        ])

    if not candidate_edges_sources:
        print("No relevant documents found for the query above the similarity threshold")
        return []

    # Expanded graph = existing graph + query node + candidate edges. Every
    # tensor is moved to `device` before concatenation so that CPU-resident
    # graph data can be combined with query tensors on another device.
    expanded_x = torch.cat([global_data_obj.x.to(device), query_node_features_tensor], dim=0)

    candidate_edge_index = torch.tensor([candidate_edges_sources, candidate_edges_targets], dtype=torch.long, device=device)
    candidate_edge_weight = torch.tensor(candidate_edges_weights, dtype=torch.float, device=device)
    candidate_edge_sim = torch.tensor(candidate_edge_sims, dtype=torch.float, device=device)
    candidate_edge_subtopic = torch.tensor(candidate_edge_subtopic_indices, dtype=torch.long, device=device)

    expanded_edge_index = torch.cat([global_data_obj.edge_index.to(device), candidate_edge_index], dim=1)
    expanded_edge_weight = torch.cat([global_data_obj.edge_weight.to(device), candidate_edge_weight], dim=0)
    expanded_edge_sim = torch.cat([global_data_obj.edge_sim.to(device), candidate_edge_sim], dim=0)
    expanded_edge_subtopic = torch.cat([global_data_obj.edge_subtopic.to(device), candidate_edge_subtopic], dim=0)

    # Link prediction with trained model
    with torch.no_grad():
        all_edge_logits = trained_model(
            x=expanded_x,
            edge_index=expanded_edge_index,
            edge_subtopic=expanded_edge_subtopic,
            edge_weight=expanded_edge_weight,
            edge_sim=expanded_edge_sim
        )
        all_edge_probs = torch.sigmoid(all_edge_logits).cpu().numpy()

    # Ranking and retrieval: candidate edges were appended after the existing
    # edges, so their scores are the tail of the output.
    start_index_of_query_edges = global_data_obj.edge_index.size(1)
    end_index_of_query_edges = start_index_of_query_edges + len(candidate_edges_sources)
    candidate_probs = all_edge_probs[start_index_of_query_edges:end_index_of_query_edges]

    all_query_doc_predictions = [
        {
            'doc_title': doc_titles_list[target_doc_id],
            'prediction_probability': float(prob),
            'node_id': target_doc_id,
            'source_type': 'direct_prediction'
        }
        for target_doc_id, prob in zip(candidate_edges_targets, candidate_probs)
    ]
    all_query_doc_predictions.sort(key=lambda x: x['prediction_probability'], reverse=True)

    final_unique_results = []
    seen_doc_ids = set()
    for res in all_query_doc_predictions:
        # Stop once there are enough direct predictions to fill top-k
        if len(final_unique_results) >= top_k_results and not include_neighbors:
            break
        if res['prediction_probability'] >= prediction_probability_threshold and \
           res['node_id'] not in seen_doc_ids:
            final_unique_results.append(res)
            seen_doc_ids.add(res['node_id'])

    # Graph traversal: pad with neighbours of the direct hits
    traversal_results = []
    if include_neighbors and len(seen_doc_ids) < top_k_results:
        print("Traversing Graph for Related Documents")
        for direct_res in final_unique_results:
            if len(final_unique_results) + len(traversal_results) >= top_k_results:
                break

            source_doc_id = direct_res['node_id']
            source_doc_title = direct_res['doc_title']
            if source_doc_title not in original_G_nx:
                # Document has no node in the graph: nothing to traverse.
                continue

            neighbors_added_from_this_doc = 0
            for direction, neighbors_of in (('out', original_G_nx.successors),
                                            ('in', original_G_nx.predecessors)):
                for neighbor_title in neighbors_of(source_doc_title):
                    if neighbors_added_from_this_doc >= max_neighbors_per_doc:
                        break  # per-document cap, shared by both directions
                    if len(final_unique_results) + len(traversal_results) >= top_k_results:
                        break  # overall budget exhausted

                    neighbor_id = title_to_idx_global.get(neighbor_title)
                    if neighbor_id is None or neighbor_id in seen_doc_ids:
                        continue

                    traversal_results.append({
                        'doc_title': neighbor_title,
                        'prediction_probability': 0.0,  # placeholder: not scored by the model
                        'node_id': neighbor_id,
                        'source_type': f"{direction}_neighbor_from_id_{source_doc_id}"
                    })
                    neighbors_added_from_this_doc += 1
                    seen_doc_ids.add(neighbor_id)

    final_results = final_unique_results + traversal_results
    final_results.sort(key=lambda x: x['prediction_probability'], reverse=True)
    top_results = final_results[:top_k_results]

    print(f"\nTop {top_k_results} related documents for query: '{query_text}'")
    for i, res in enumerate(top_results):
        print(f"{i+1}. Document: '{res['doc_title']}' (ID: {res['node_id']}) - Probability: {res['prediction_probability']:.4f} (Source: {res['source_type']})")

    # Previously an always-empty list was returned here.
    return top_results

In [None]:
# Example query: a passage of article-style text, including bracketed
# citation markers such as "[58]".
my_query = "Alcohol increases the risk of cancer of the breast (in women), throat, liver, oesophagus, mouth, larynx, and colon.[58] In Western Europe, 10% of cancers in males and 3% of cancers in females are attributed to alcohol exposure, especially liver and digestive tract cancers.[59] Cancer from work-related substance exposures may cause between 2 and 20% of cases,[60] causing at least 200,000 deaths.[61] Cancers such as lung cancer and mesothelioma can come from inhaling tobacco smoke or asbestos fibers, or leukemia from exposure to benzene"
# Every artefact is passed explicitly. G, title_to_idx, doc_titles,
# similarity_df, subtopic_edge_dict, scaler and encoder must already exist in
# the kernel from earlier cells.
prediction_results = predict_related_documents(
    query_text = my_query,
    trained_model = model,
    global_data_obj = data,
    original_G_nx = G,
    df_global = df,
    title_to_idx_global = title_to_idx,
    doc_titles_list = doc_titles,
    subtopic_to_idx_global = subtopic_to_idx,
    similarity_df_global = similarity_df,
    subtopic_edge_dict_global = subtopic_edge_dict,
    fitted_tfidf_vectorizer = tfidf,
    fitted_umap_model = umap_model,
    fitted_scaler_node_features = scaler,
    fitted_cluster_encoder = encoder,
    STOP_WORDS = STOP_WORDS,
    umap_reduced_cluster_centroids=umap_reduced_cluster_centroids,
    fitted_bertopic_models = fitted_bertopic_models,
    top_k_results = 100,
    prediction_probability_threshold = 0.5,
    include_neighbors=True,
    max_neighbors_per_doc = 5,
    device = 'cpu',
    query_sim_threshold = 0.0
)

In [None]:
!pip freeze

In [None]:
from torch_geometric.nn import MessagePassing
from torch.nn import Linear, LeakyReLU
import torch.nn.functional as F
from torch_geometric.utils import softmax as pyg_softmax

class EAGATConv(MessagePassing):
    """Edge-aware attention layer built on PyG's ``MessagePassing``.

    Node and edge features are linearly projected to a common width of
    ``hidden_channels * heads``. For every edge an attention logit is computed
    from the concatenation [x_i, x_j, edge_attr], normalised with a softmax
    over ``index``, and used to scale the message ``x_j``. Messages are summed
    (``aggr='add'``).

    Note that ``att`` outputs a single scalar per edge, so the same coefficient
    multiplies the whole ``hidden_channels * heads`` vector; ``heads`` only
    widens the projections. The layer output has width
    ``hidden_channels * heads``.
    """
    def __init__(self, in_channels, hidden_channels, edge_emb_dim=16, heads=2, dropout=0.3):
        super().__init__(aggr='add')
        self.heads = heads
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.dropout = dropout

        # Node projection: in_channels -> hidden_channels * heads.
        self.lin_node = Linear(in_channels, hidden_channels * heads, bias=False)
        # Edge projection: expects edge_attr of width 2 * edge_emb_dim + 2.
        self.lin_edge = Linear(2 * edge_emb_dim + 2, hidden_channels * heads, bias=False)
        # Attention scorer over [x_i | x_j | edge_attr] -> one logit per edge.
        self.att = Linear(3 * hidden_channels * heads, 1, bias=False)
        self.leaky_relu = LeakyReLU(0.2)
    def forward(self, x, edge_index, edge_attr):
        """Project node and edge features, then run message passing."""
        x = self.lin_node(x)
        edge_attr = self.lin_edge(edge_attr)
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr, index, ptr, size_i):
        """Compute the attention-weighted message for every edge.

        The parameter names are resolved by ``MessagePassing.propagate`` and
        must not be renamed.
        # presumably x_i is the receiving node and x_j the sending node
        # (PyG's default source_to_target flow) -- confirm if flow is changed.
        """
        x_cat = torch.cat([x_i, x_j, edge_attr], dim=-1)
        alpha = self.att(x_cat)
        alpha = self.leaky_relu(alpha)
        # Normalise the logits across the edges that share the same `index`.
        alpha = pyg_softmax(alpha, index, ptr, size_i)
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)
        # alpha has a trailing dimension of 1 and broadcasts over the features.
        return x_j * alpha

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class SubtopicEAGAT(torch.nn.Module):
    """Edge-scoring model: two ``EAGATConv`` layers followed by an MLP decoder.

    Each edge carries a pair of subtopic ids plus two scalar features
    (similarity and weight). The subtopic ids are embedded and concatenated
    with the scalars into an edge feature vector of width
    ``2 * edge_emb_dim + 2``. That vector conditions the attention in both
    convolution layers and is fed again to the decoder together with the two
    endpoint embeddings.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Per-head hidden width; also the final node embedding
            width (the second layer uses a single head).
        num_subtopics: Number of rows in the subtopic embedding table.
        edge_emb_dim: Width of each subtopic embedding.
        heads: Number of heads in the first convolution layer.
        dropout: Dropout probability used between the layers, inside the
            attention, and inside the decoder.
    """

    def __init__(self, in_channels, hidden_channels, num_subtopics, edge_emb_dim=16, heads=2, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        # One learned vector per subtopic id; looked up for both ids of an edge.
        self.subtopic_embed = nn.Embedding(num_subtopics, edge_emb_dim)

        # NOTE(review): this projection is never called in forward(). It is kept
        # so existing state_dict keys and the random-init order stay unchanged.
        self.lin_edge = nn.Linear(2 * edge_emb_dim + 2, edge_emb_dim)

        # Edge-aware attention layers. These must be the EAGATConv class defined
        # for this model; its edge projection expects exactly the
        # (2 * edge_emb_dim + 2)-wide features built in forward().
        self.egat1 = EAGATConv(in_channels, hidden_channels, heads=heads, edge_emb_dim=edge_emb_dim, dropout=dropout)
        self.egat2 = EAGATConv(hidden_channels * heads, hidden_channels, heads=1, edge_emb_dim=edge_emb_dim, dropout=dropout)

        # Decoder MLP for edge scoring: [h_src, h_tgt, edge_feats] -> 1 score.
        self.edge_decoder = nn.Sequential(
            nn.Linear(2 * hidden_channels + 2 * edge_emb_dim + 2, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(64, 1)
        )

    def forward(self, x, edge_index, edge_subtopic, edge_weight, edge_sim):
        """Score every edge in ``edge_index``.

        Args:
            x: Node feature matrix.
            edge_index: Two-row tensor of (source, target) node indices.
            edge_subtopic: Two-column integer tensor of subtopic ids per edge
                (used as indices into the embedding table).
            edge_weight: One scalar per edge.
            edge_sim: One scalar per edge.

        Returns:
            A 1-D tensor with one raw (un-activated) score per edge.
        """
        # Build per-edge features: both subtopic embeddings plus the two scalars.
        s1 = self.subtopic_embed(edge_subtopic[:, 0])
        s2 = self.subtopic_embed(edge_subtopic[:, 1])
        edge_feats = torch.cat([s1, s2, edge_sim.unsqueeze(1), edge_weight.unsqueeze(1)], dim=-1)

        # Encode node features with two edge-aware attention layers.
        x = self.egat1(x, edge_index, edge_feats).relu()
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.egat2(x, edge_index, edge_feats)

        # Gather the endpoint embeddings of every edge.
        src, tgt = edge_index
        h_src = x[src]
        h_tgt = x[tgt]

        # Decode each edge from its endpoints and its own features.
        edge_repr = torch.cat([h_src, h_tgt, edge_feats], dim=-1)
        return self.edge_decoder(edge_repr).squeeze(1)

In [None]:
# Size the model from the data: input width is the second dimension of
# `data.x`, and the embedding table gets one row per entry of
# `subtopic_to_idx`.
in_channels = data.x.shape[1]
hidden_channels = 64
num_subtopics = len(subtopic_to_idx)

# edge_emb_dim and dropout override the class defaults (16 and 0.3).
model = SubtopicEAGAT(
    in_channels,
    hidden_channels,
    num_subtopics,
    edge_emb_dim=32,
    heads=2,
    dropout=0.5,
)

In [None]:
# Train the edge-scoring model on CPU.
# NOTE(review): `patience` and `hard_neg_start_epoch` are interpreted by
# train_gat_model, which is defined outside this section. Presumably they are
# the early-stopping patience and the epoch at which hard-negative sampling
# begins; confirm against that function.
trained_model = train_gat_model(
    model=model,
    data=data,
    splits=splits,
    device='cpu',
    lr=1e-4,
    num_epochs=1000,
    patience=15,
    hard_neg_start_epoch=30,
)