In [9]:
import networkx as nx
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
def load_embeddings(base_path):
    title_emb = pd.read_csv(f"{base_path}/title_embeddings.csv", index_col=0).transpose()
    abstract_emb = pd.read_csv(f"{base_path}/abstract_embeddings.csv", index_col=0).transpose()
    return title_emb, abstract_emb

def load_paper_sources(base_path, rules=True):
    with open(f"{base_path}/paper_source_trace_valid_wo_ans.json", "r") as f:
        paper_source = json.load(f)
    with open(f"{base_path}/paper_source_trace_train_ans.json", "r") as f:
        paper_source_train = json.load(f)
    if rules:
        with open(f"{base_path}/paper_source_gen_by_rule.json", "r") as f:
            paper_source_rule = json.load(f)
            rule_papers = [{"_id": pid, "references": list(refs.values())} 
                      for pid, refs in paper_source_rule.items()]
        return paper_source + paper_source_train + rule_papers
    else:
        return paper_source + paper_source_train

def create_node_attributes(paper, title_emb, abstract_emb):
    paper_id = paper["_id"]
    has_embeddings = paper_id in title_emb.columns and paper_id in abstract_emb.columns
    print(paper_id)
    attributes = {
        "title": paper.get("title", []),
        "authors": paper.get("authors", []),
        "year": paper.get("year", []),
        "venue": paper.get("venue", []),
        "most_important_references": paper.get("refs_trace", paper.get("ref_trace", [])),
        "title_embeddings": title_emb[paper_id].tolist() if has_embeddings else [],
        "abstract_embeddings": abstract_emb[paper_id].tolist() if has_embeddings else [],
        "paper_id": paper_id
    }
    return attributes

def main():
    base_path = "/Users/gabesmithline/Desktop/gnn_project/data"
    
    title_embeddings, abstract_embeddings = load_embeddings(base_path)
    paper_source = load_paper_sources(base_path, rules=True)
    
    G = nx.Graph()
    
    for paper in paper_source:
        G.add_node(paper["_id"], 
                   **create_node_attributes(paper, title_embeddings, abstract_embeddings))
    
    for paper in paper_source:
        refs_trace = paper.get("refs_trace", paper.get("ref_trace", []))
        important_refs = set(ref["_id"] for ref in refs_trace)
        #edge attributes are added in the order of the references
        
        for source in paper["references"]:
            edge_attributes = {
                "most_important_references": 1 if source in important_refs else 0,
                "source": source,
                "target": paper["_id"]
            }
            G.add_edge(paper["_id"], source, **edge_attributes)
         
    print(len(G.edges(data=True)))
    important_edges = [(u,v) for u,v,d in G.edges(data=True) if d['most_important_references'] == 1]
    print("Edges with most_important_references=1:", len(important_edges))
    print(len(G.nodes(data=True)))
    '''
    # Visualize graph
    plt.clf() # Clear any existing plots
    plt.close('all') # Close all figures
    plt.figure(figsize=(12, 8))
    pos = nx.spring_layout(G, k=1, iterations=50)
    
    nx.draw_networkx_nodes(G, pos, node_size=100, alpha=0.7)
    
    edge_colors = ['red' if G[u][v]['most_important_references'] else 'gray' 
                  for u,v in G.edges()]
    nx.draw_networkx_edges(G, pos, alpha=0.5, edge_color=edge_colors, width=0.5)
    
    plt.title("Paper Citation Network\nRed edges indicate important references")
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    # Save graph
    '''
    
    with open(f'{base_path}/adj_matrix.pkl', 'wb') as f:
       pickle.dump(G, f)
    
if __name__ == "__main__":
    main()

61dbf1dcd18a2b6e00d9f311
61dbf06b6750f87b50ecd224
6042187291e0115d09aff2a7
60c31feb6750f85387887e7c
5ff8844791e011c832676679
60b9a4ebe4510cd7c8fc6b77
6180ac435244ab9dcb793a8f
60c33c6a91e01104fa0ef733
60d996c80abde95dc965f5c0
60c402f491e011d44febefa7
61850e9691e01121084ca0d6
618ddb455244ab9dcbda8f5f
6164fcc15244ab9dcb24cf0b
60cbeaaf91e011eef576dac1
612c4c285244ab9dcbca22ce
618c94bb6750f806be6689f3
607ffb8b91e011772654f6ef
60757d6d91e0110f6fe6843e
6103d7ba91e01159791b21df
618c94976750f806be6689d2
6098feeb91e011aa8bcb6dbf
60f2b1d05244ab9dcbbbdfe7
607d4e9e91e011bf62020909
60782d0091e011f5ecc9dc04
60a640a491e0115d932bfd75
60c2fb6191e0117e30ca2951
6034f66f91e01122c046f9dc
60641c5c9e795e72406b65cd
6076c80791e0113d72574489
600fe888d4150a363c24b1e4
60c312ac9e795e9243fd165e
61fc99465aee126c0fcdcbf7
619e189e6750f82b1e8c7102
606c685691e0114248cd042d
605dbaf191e0113c28655a7f
604b4c7891e0110eed64c4e2
6065b2ef91e011d10ad615a8
60bdde338585e32c38af510f
60cae12b91e011b32937419e
614012c15244ab9dcb8166d0
