# In [6]:  (notebook cell prompt left over from export)
import networkx as nx
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
def load_embeddings(base_path):
    """Read the title and abstract embedding tables found under *base_path*.

    Each CSV is parsed with its first column as the index and is then
    transposed, so the labels that were the CSV's column headers become the
    row index of the returned frame.

    Args:
        base_path: Directory containing ``title_embeddings.csv`` and
            ``abstract_embeddings.csv``.

    Returns:
        A ``(title_emb, abstract_emb)`` tuple of transposed DataFrames.
    """
    title_csv = f"{base_path}/title_embeddings.csv"
    abstract_csv = f"{base_path}/abstract_embeddings.csv"

    title_emb = pd.read_csv(title_csv, index_col=0).T
    abstract_emb = pd.read_csv(abstract_csv, index_col=0).T
    return title_emb, abstract_emb

def load_paper_sources(base_path):
    """Load the three paper collections under *base_path* as one list.

    The validation and training files are read as JSON lists and used as-is.
    The rule-generated file is read as a JSON object mapping a paper id to a
    mapping of references; each entry is converted to a record of the form
    ``{"_id": <paper id>, "references": [<mapping values>]}`` so that it has
    the same keys the other records are accessed by.

    Args:
        base_path: Directory containing
            ``paper_source_trace_valid_wo_ans.json``,
            ``paper_source_trace_train_ans.json`` and
            ``paper_source_gen_by_rule.json``.

    Returns:
        A single list: validation records, then training records, then the
        records derived from the rule-generated file.
    """
    def _read_json(filename):
        # Decode explicitly as UTF-8 instead of relying on the platform's
        # locale encoding, which is not UTF-8 everywhere.
        with open(f"{base_path}/{filename}", "r", encoding="utf-8") as f:
            return json.load(f)

    paper_source = _read_json("paper_source_trace_valid_wo_ans.json")
    paper_source_train = _read_json("paper_source_trace_train_ans.json")
    paper_source_rule = _read_json("paper_source_gen_by_rule.json")

    # Only the values of each per-paper mapping are kept as the reference list.
    rule_papers = [
        {"_id": pid, "references": list(refs.values())}
        for pid, refs in paper_source_rule.items()
    ]

    return paper_source + paper_source_train + rule_papers

def create_node_attributes(paper, title_emb, abstract_emb):
    """Assemble the attribute dict stored on a paper's graph node.

    Args:
        paper: Paper record; must contain ``"_id"``. The metadata keys
            ``title``, ``authors``, ``year`` and ``venue`` are optional and
            default to an empty list when absent.
        title_emb: DataFrame of title embeddings, looked up by paper id in
            its row index.
        abstract_emb: DataFrame of abstract embeddings, looked up the same way.

    Returns:
        A dict with the four metadata fields, ``most_important_references``
        (taken from ``refs_trace``, else ``ref_trace``, else ``[]``) and the
        two embedding lists.
    """
    pid = paper["_id"]

    # Embeddings are attached only when BOTH tables know this paper;
    # otherwise both vectors stay empty.
    title_vec, abstract_vec = [], []
    if pid in title_emb.index and pid in abstract_emb.index:
        title_vec = title_emb.loc[pid].tolist()
        abstract_vec = abstract_emb.loc[pid].tolist()

    # "refs_trace" wins whenever the key is present; "ref_trace" is the
    # fallback spelling.
    if "refs_trace" in paper:
        important_refs = paper["refs_trace"]
    else:
        important_refs = paper.get("ref_trace", [])

    attributes = {
        field: paper.get(field, [])
        for field in ("title", "authors", "year", "venue")
    }
    attributes["most_important_references"] = important_refs
    attributes["title_embeddings"] = title_vec
    attributes["abstract_embeddings"] = abstract_vec
    return attributes

# Dataset directory used when main() is called without an argument.
DEFAULT_BASE_PATH = "/Users/gabesmithline/Desktop/gnn_project/data"


def _important_reference_ids(paper):
    """Return the entries of the paper's reference trace, normalised to ids."""
    refs_trace = paper.get("refs_trace", paper.get("ref_trace", []))
    # NOTE(review): trace entries may be plain ids or records carrying an
    # "_id" key -- confirm against the data files. Records are reduced to
    # their "_id" so the membership test below can match them; any other
    # entry (including a record without "_id") is kept unchanged.
    return [
        ref.get("_id", ref) if isinstance(ref, dict) else ref
        for ref in refs_trace
    ]


def _build_graph(paper_source, title_embeddings, abstract_embeddings):
    """Build the undirected citation graph for the given paper records."""
    G = nx.Graph()

    # First pass: one attributed node per paper record.
    for paper in paper_source:
        G.add_node(
            paper["_id"],
            **create_node_attributes(paper, title_embeddings, abstract_embeddings),
        )

    # Second pass: one edge per (paper, reference) pair, flagged 1 when the
    # reference appears in the paper's reference trace and 0 otherwise.
    # References that have no record of their own become attribute-less nodes.
    for paper in paper_source:
        important = _important_reference_ids(paper)
        for source in paper["references"]:
            G.add_edge(
                paper["_id"],
                source,
                most_important_references=int(source in important),
            )
    return G


def main(base_path=DEFAULT_BASE_PATH):
    """Build the citation graph from the files in *base_path* and pickle it.

    Loads the embeddings and paper records, builds the graph, prints the
    number of edges followed by the number of nodes, and writes the graph to
    ``<base_path>/adj_matrix.pkl``.

    Args:
        base_path: Directory holding the input CSV/JSON files; the output
            pickle is written to the same directory. Defaults to
            ``DEFAULT_BASE_PATH``.
    """
    title_embeddings, abstract_embeddings = load_embeddings(base_path)
    paper_source = load_paper_sources(base_path)

    G = _build_graph(paper_source, title_embeddings, abstract_embeddings)

    print(G.number_of_edges())
    print(G.number_of_nodes())

    # Optional visualisation, currently disabled:
    #
    # plt.figure(figsize=(12, 8))
    # pos = nx.spring_layout(G, k=1, iterations=50)
    #
    # # Draw nodes
    # nx.draw_networkx_nodes(G, pos, node_size=100, alpha=0.7)
    #
    # # Draw edges with different colors based on importance
    # edge_colors = ['red' if G[u][v]['most_important_references'] else 'gray'
    #                for u, v in G.edges()]
    # nx.draw_networkx_edges(G, pos, alpha=0.5, edge_color=edge_colors, width=0.5)
    #
    # plt.title("Paper Citation Network\nRed edges indicate important references")
    # plt.axis('off')
    # plt.tight_layout()
    # plt.show()

    # Save graph. Despite the file name, the pickle holds the whole
    # networkx Graph object, not an adjacency matrix.
    with open(f'{base_path}/adj_matrix.pkl', 'wb') as f:
        pickle.dump(G, f)
    
# Run the pipeline only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

# Output:
# 50996  (number of edges)
# 36752  (number of nodes)
