In [50]:
%matplotlib notebook


In [51]:
import networkx as nx
import numpy as np
import json
import pandas as pd
import pickle
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # Registers the 3D projection

def load_embeddings(base_path):
    """Read the title and abstract embedding CSVs from ``base_path``.

    Each CSV is read with its first column as the index and then transposed, so the
    former index labels become column labels (``create_node_attributes`` looks papers
    up by column).

    Returns:
        tuple ``(title_emb, abstract_emb)`` of transposed DataFrames.
    """
    csv_paths = (
        f"{base_path}/title_embeddings.csv",
        f"{base_path}/abstract_embeddings.csv",
    )
    title_emb, abstract_emb = (pd.read_csv(path, index_col=0).T for path in csv_paths)
    return title_emb, abstract_emb

def load_paper_sources(base_path, rules=True):
    """Load paper records from the JSON files in ``base_path``.

    Args:
        base_path: directory containing the JSON files.
        rules: if True, also append the papers from ``paper_source_gen_by_rule.json``,
            each converted to ``{"_id": pid, "references": [...]}``.

    Returns:
        list of paper dicts: validation papers, then training papers, then (when
        ``rules`` is True) the rule-generated papers.
    """
    def _read_json(file_name):
        # Decode explicitly as UTF-8: open()'s default encoding is platform-dependent
        # (e.g. cp1252 on Windows) and text fields may contain non-ASCII characters.
        with open(f"{base_path}/{file_name}", "r", encoding="utf-8") as f:
            return json.load(f)

    paper_source = _read_json("paper_source_trace_valid_wo_ans.json")
    paper_source_train = _read_json("paper_source_trace_train_ans.json")
    if not rules:
        return paper_source + paper_source_train

    # The rule file maps paper id -> dict; that dict's values are used as the references.
    paper_source_rule = _read_json("paper_source_gen_by_rule.json")
    rule_papers = [{"_id": pid, "references": list(refs.values())}
                   for pid, refs in paper_source_rule.items()]
    return paper_source + paper_source_train + rule_papers

def create_node_attributes(paper, title_emb, abstract_emb):
    """Build the node-attribute dict for one paper using the CSV embedding tables.

    Args:
        paper: paper record; must contain ``_id``. Missing ``title``, ``authors``,
            ``year``, ``venue`` and ``refs_trace``/``ref_trace`` default to ``[]``.
        title_emb, abstract_emb: DataFrames with one column per paper id.

    Returns:
        dict of node attributes. The embedding lists are filled only when the paper id
        is a column of BOTH tables; otherwise both are ``[]``.
    """
    paper_id = paper["_id"]
    embedded = all(paper_id in table.columns for table in (title_emb, abstract_emb))
    if embedded:
        title_vector = title_emb[paper_id].tolist()
        abstract_vector = abstract_emb[paper_id].tolist()
    else:
        title_vector, abstract_vector = [], []

    # `refs_trace` takes precedence over the alternative spelling `ref_trace`.
    trace = paper.get("refs_trace", paper.get("ref_trace", []))
    return {
        "title": paper.get("title", []),
        "authors": paper.get("authors", []),
        "year": paper.get("year", []),
        "venue": paper.get("venue", []),
        "most_important_references": trace,
        "title_embeddings": title_vector,
        "abstract_embeddings": abstract_vector,
        "paper_id": paper_id,
    }
# NOTE(review): this module-level graph is never populated; main() builds and uses its
# own local graph. Kept so the notebook namespace is unchanged.
G = nx.Graph()


def _build_citation_graph(paper_source, title_embeddings, abstract_embeddings):
    """Build an undirected graph: one node per paper, one edge per (paper, reference).

    Edge attributes: ``most_important_references`` (1 if the reference appears in the
    paper's ``refs_trace``/``ref_trace``, else 0), ``source`` (the referenced id) and
    ``target`` (the citing paper id).
    """
    graph = nx.Graph()
    for paper in paper_source:
        graph.add_node(paper["_id"],
                       **create_node_attributes(paper, title_embeddings, abstract_embeddings))

    for paper in paper_source:
        citing_id = paper["_id"]
        refs_trace = paper.get("refs_trace", paper.get("ref_trace", []))
        important_refs = {ref["_id"] for ref in refs_trace}
        # Edge attributes are added in the order of the references.
        for source in paper.get("references", []):
            flag = 1 if source in important_refs else 0
            existing = graph.get_edge_data(citing_id, source)
            # The graph is undirected: a reverse citation (or a repeated paper record)
            # hits the same edge, and add_edge would overwrite its attributes. Never
            # downgrade an edge already labelled important.
            if existing is not None and existing.get("most_important_references") == 1 and flag == 0:
                continue
            edge_attributes = {
                "most_important_references": flag,
                "source": source,
                "target": citing_id,
            }
            graph.add_edge(citing_id, source, **edge_attributes)
    return graph


def _plot_node_positions(positions):
    """Scatter-plot the raw layout coordinates as a quick visual check of the layout."""
    xs = [pos[0] for pos in positions.values()]
    ys = [pos[1] for pos in positions.values()]
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(xs, ys, s=200)
    ax.set_title('Node Positions')
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.grid(True)
    ax.axis('equal')
    plt.show()


def _plot_important_network(graph, positions, margin=0.1):
    """Draw ``graph`` at ``positions``; important edges red, all others blue."""
    node_xs = [pos[0] for pos in positions.values()]
    node_ys = [pos[1] for pos in positions.values()]

    fig, ax = plt.subplots(figsize=(15, 10))
    nx.draw_networkx_nodes(graph, positions, node_size=100, alpha=0.9, ax=ax)

    for u, v, data in graph.edges(data=True):
        color = 'red' if data['most_important_references'] == 1 else 'blue'
        ax.plot([positions[u][0], positions[v][0]],
                [positions[u][1], positions[v][1]],
                color=color, alpha=0.7, linewidth=2)

    ax.set_title("Paper Citation Network\nRed edges indicate important references")
    ax.axis('equal')  # Ensure equal scaling
    # Limits must span every node, not just the coordinates of one edge.
    ax.set_xlim(min(node_xs) - margin, max(node_xs) + margin)
    ax.set_ylim(min(node_ys) - margin, max(node_ys) + margin)
    plt.show()


def main(base_path="/Users/gabesmithline/Desktop/gnn_project/data", layout_seed=None):
    """Build the citation graph, extract its important-reference core and plot it.

    Args:
        base_path: directory holding the embedding CSVs and the paper JSON files.
            Defaults to the original author's local path; pass your own data directory.
        layout_seed: optional seed forwarded to ``nx.spring_layout`` for a reproducible
            layout; ``None`` keeps the unseeded behaviour.

    Returns:
        The largest connected component of the subgraph formed by important-reference
        edges (an empty graph if there are no such edges).
    """
    title_embeddings, abstract_embeddings = load_embeddings(base_path)
    paper_source = load_paper_sources(base_path, rules=False)

    G = _build_citation_graph(paper_source, title_embeddings, abstract_embeddings)

    print(G.number_of_edges())
    important_edges = [
        (u, v) for u, v, data in G.edges(data=True) if data['most_important_references'] == 1
    ]
    print("Edges with most_important_references=1:", len(important_edges))
    print(G.number_of_nodes())

    # Subgraph made only of the important-reference edges.
    G_important = G.edge_subgraph(important_edges).copy()
    print(f"Number of edges in G_important: {G_important.number_of_edges()}")
    print(f"Number of nodes in G_important: {G_important.number_of_nodes()}")

    # edge_subgraph keeps only nodes incident to a kept edge, so this is normally a
    # no-op; kept as a safeguard.
    isolated_nodes = list(nx.isolates(G_important))
    G_important.remove_nodes_from(isolated_nodes)
    print(f"Removed {len(isolated_nodes)} isolated nodes from G_important.")

    if G_important.number_of_nodes() == 0:
        # nx.is_connected raises on an empty graph, and there is nothing to draw.
        print("No important-reference edges found; nothing to plot.")
        return G_important

    # Keep only the largest connected component.
    if not nx.is_connected(G_important):
        largest_cc = max(nx.connected_components(G_important), key=len)
        G_important = G_important.subgraph(largest_cc).copy()
        print(f"Using largest connected component with {G_important.number_of_nodes()} nodes.")

    print(f"Number of edges in G_important: {G_important.number_of_edges()}")
    print(f"Number of nodes in G_important: {G_important.number_of_nodes()}")

    pos_important = nx.spring_layout(G_important, k=5, iterations=500, seed=layout_seed)
    print("Node positions:")
    for node, position in pos_important.items():
        print(f"{node}: {position}")

    _plot_node_positions(pos_important)
    _plot_important_network(G_important, pos_important)
    return G_important


if __name__ == "__main__":
    G_important = main()

32607
Edges with most_important_references=1: 610
18624


<IPython.core.display.Javascript object>

Number of edges in G_important: 610
Number of nodes in G_important: 943
Removed 0 isolated nodes from G_important.
Using largest connected component with 54 nodes.
Number of edges in G_important: 58
Number of nodes in G_important: 54
Node positions:
599c7988601a182cd2648a09: [0.42799105 0.04922537]
5dc3eb4e3a55ac3c4bb65817: [-0.03749316 -0.41621817]
599c7987601a182cd2648373: [-0.08841344  0.21557952]
5b67b47917c44aac1c8637c6: [0.59709685 0.24847384]
5ce2d032ced107d4c635260c: [-0.51254888 -0.43710015]
5ede0553e06a4c1b26a83f63: [-0.14289129 -0.19495273]
5a260c8117c44a4ba8a30f54: [-0.50718192  0.0953269 ]
5e5e18e493d709897ce3a0f2: [-0.6333889   0.49435914]
58d82fcbd649053542fd6482: [ 0.91876435 -0.16881302]
5f03f3b611dc83056223205d: [-0.50625699  0.67136705]
5c6a37d03a69b1c9e12a9fc4: [-0.10535243  0.40417276]
5f8d6be69fced0a24bbab01e: [ 0.36636196 -0.10864887]
5ede0553e06a4c1b26a83ff5: [-0.68523443  0.21160161]
5f7fdd328de39f0828397afd: [0.02855635 0.37928867]
5ea6adfa91e011a546871d52: [-

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [52]:

# 3D view of the important-reference network; expects `G_important` from the cell above.
# Keep only the largest connected component. nx.is_connected raises on an empty graph,
# hence the node-count guard.
if G_important.number_of_nodes() > 0 and not nx.is_connected(G_important):
    largest_cc = max(nx.connected_components(G_important), key=len)
    G_important = G_important.subgraph(largest_cc).copy()
    print(f"Using largest connected component with {G_important.number_of_nodes()} nodes.")

# Compute 3D positions.
pos_3d = nx.spring_layout(G_important, dim=3, k=5, iterations=500)

# Per-axis coordinates, in node order.
x_values = [pos[0] for pos in pos_3d.values()]
y_values = [pos[1] for pos in pos_3d.values()]
z_values = [pos[2] for pos in pos_3d.values()]

fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(111, projection='3d')

# Draw nodes.
node_colors = '#1f78b4'
ax.scatter(x_values, y_values, z_values, s=100, c=node_colors, alpha=0.9)

# Draw edges: red for important references, blue otherwise.
for u, v, data in G_important.edges(data=True):
    x = [pos_3d[u][0], pos_3d[v][0]]
    y = [pos_3d[u][1], pos_3d[v][1]]
    z = [pos_3d[u][2], pos_3d[v][2]]
    color = 'red' if data['most_important_references'] == 1 else 'blue'
    ax.plot(x, y, z, color=color, alpha=0.7, linewidth=2)

ax.set_title("Paper Citation Network (3D Visualization)\nRed edges indicate important references")
ax.axis('off')
plt.show()

<IPython.core.display.Javascript object>

In [53]:
import os
import json
def load_openai_paper_sources(directory="data/paper_embed_openai"):
    """Load per-paper OpenAI embeddings from a directory of ``.npy`` files.

    Each non-hidden file is read with ``np.load``. Three layouts are understood:

    * a 2-element tuple/list ``(_id, embedding)``;
    * a dict with an ``_id`` key and an ``embedding`` (or ``paper_embeddings``) key;
    * anything else (e.g. a plain vector). In this case the file name, minus a
      ``.npy`` suffix, is used as ``_id``.

    Files that fail to load are reported and skipped (best effort).

    Args:
        directory: folder containing the embedding files. The default is the
            previous hard-coded path, relative to the current working directory.

    Returns:
        DataFrame with columns ``_id`` and ``paper_embeddings``, one row per loaded
        file in file-name order. If nothing was loaded, the frame is empty but still
        has those columns.
    """
    import os
    import numpy as np
    import pandas as pd

    data = []
    # os.listdir order is arbitrary, so sort for a reproducible row order. Skip hidden
    # entries such as macOS's .DS_Store, which are not embedding files.
    files = sorted(name for name in os.listdir(directory) if not name.startswith("."))

    for file in files:
        try:
            with open(os.path.join(directory, file), "rb") as f:
                # SECURITY: allow_pickle=True lets a crafted file execute arbitrary code
                # on load; only use this on trusted, locally generated files.
                paper_source = np.load(f, allow_pickle=True)

                # np.save stores non-array objects (dicts, tuples) as a 0-d object
                # array; unwrap it so the tuple/dict branches below can match.
                if (isinstance(paper_source, np.ndarray)
                        and paper_source.dtype == object
                        and paper_source.ndim == 0):
                    paper_source = paper_source.item()

                if isinstance(paper_source, (tuple, list)) and len(paper_source) == 2:
                    _id, embedding = paper_source
                elif isinstance(paper_source, dict):
                    _id = paper_source.get('_id', file)  # Use filename as fallback
                    # Compare with None instead of using `or`: truth-testing a
                    # multi-element numpy array raises ValueError.
                    embedding = paper_source.get('embedding')
                    if embedding is None:
                        embedding = paper_source.get('paper_embeddings', paper_source)
                else:
                    _id = file  # Fallback to filename as _id
                    embedding = paper_source

                # Remove the '.npy' extension from _id if present.
                if isinstance(_id, str) and _id.endswith(".npy"):
                    _id = _id[:-len(".npy")]

                data.append({"_id": _id, "paper_embeddings": embedding})
        except Exception as e:
            # Best effort: report and skip files that cannot be loaded or parsed.
            print(f"Failed to load {file}: {e}")

    # Explicit columns keep the schema stable even when no file was loaded.
    return pd.DataFrame(data, columns=["_id", "paper_embeddings"])

# Usage
openai_embeddings = load_openai_paper_sources()
print(openai_embeddings.head())
# DataFrame.info() prints its own report and returns None; wrapping it in print()
# only adds a stray "None" line.
openai_embeddings.info()





                        _id                                   paper_embeddings
0  5582be480cf2fcbbc5f1c886  [0.031248431652784348, -0.006423081737011671, ...
1  53e9a7f1b7602d9703132a2c  [-0.026169415563344955, -0.008621706627309322,...
2  6034f66f91e01122c046f9dc  [-0.0154993562027812, 0.008253294974565506, -0...
3  61fb47e05aee126c0f8739ae  [-0.0031305551528930664, 0.024443166330456734,...
4  53e9b07db7602d9703ae6b08  [-0.03515831381082535, 0.029901135712862015, -...
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7541 entries, 0 to 7540
Data columns (total 2 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   _id               7541 non-null   object
 1   paper_embeddings  7541 non-null   object
dtypes: object(2)
memory usage: 118.0+ KB
None


In [54]:
# Spot-check: show the stored embedding(s) for one known paper id.
print(
    openai_embeddings.loc[
        openai_embeddings["_id"] == "5582be480cf2fcbbc5f1c886", "paper_embeddings"
    ].tolist()
)



[array([ 0.03124843, -0.00642308, -0.00843204, ..., -0.00556077,
       -0.00683792, -0.00154168])]


In [55]:
def create_node_attributes(paper, openai_embeddings):
    """Build the node-attribute dict for one paper using the OpenAI embedding table.

    Args:
        paper: paper record; must contain ``_id``. Missing ``title``, ``authors``,
            ``year``, ``venue`` and ``refs_trace``/``ref_trace`` default to ``[]``.
        openai_embeddings: DataFrame with ``_id`` and ``paper_embeddings`` columns
            (as returned by ``load_openai_paper_sources``).

    Returns:
        dict of node attributes. ``paper_embeddings`` holds every embedding whose
        ``_id`` matches (a list, normally of one array), or ``[]`` when none match.
    """
    paper_id = paper["_id"]
    # One vectorised comparison serves as both the membership test and the row
    # selector; .loc[mask, col] also avoids chained indexing.
    matches = openai_embeddings.loc[openai_embeddings["_id"] == paper_id, "paper_embeddings"]
    # NOTE(review): this keeps the list-of-arrays wrapping, which differs from the flat
    # float lists stored by the CSV-based variant. Confirm what downstream code
    # expects before unwrapping.
    return {
        "title": paper.get("title", []),
        "authors": paper.get("authors", []),
        "year": paper.get("year", []),
        "venue": paper.get("venue", []),
        "most_important_references": paper.get("refs_trace", paper.get("ref_trace", [])),
        "paper_embeddings": matches.tolist(),
        "paper_id": paper_id,
    }

In [56]:
base_path = "/Users/gabesmithline/Desktop/gnn_project/data"
paper_sources = load_paper_sources(base_path, rules=False)

# Build the citation graph with OpenAI-embedding node attributes.
# NOTE(review): this needs the two-argument create_node_attributes defined in In[55].
# Re-running In[51] afterwards rebinds the three-argument version and this cell fails.
G = nx.Graph()

for paper in paper_sources:
    G.add_node(paper["_id"], **create_node_attributes(paper, openai_embeddings))

for paper in paper_sources:
    citing_id = paper["_id"]
    refs_trace = paper.get("refs_trace", paper.get("ref_trace", []))
    important_refs = {ref["_id"] for ref in refs_trace}
    # Edge attributes are added in the order of the references.
    for source in paper.get("references", []):
        flag = 1 if source in important_refs else 0
        existing = G.get_edge_data(citing_id, source)
        # G is undirected: a reverse citation (or a repeated paper record) hits the same
        # edge, and add_edge would overwrite its attributes. Never downgrade an edge
        # already labelled important.
        if existing is not None and existing.get("most_important_references") == 1 and flag == 0:
            continue
        edge_attributes = {
            "most_important_references": flag,
            "source": source,
            "target": citing_id,
        }
        G.add_edge(citing_id, source, **edge_attributes)

print(G.number_of_edges())
important_edges = [
    (u, v) for u, v, data in G.edges(data=True) if data['most_important_references'] == 1
]
print("Edges with most_important_references=1:", len(important_edges))
print(G.number_of_nodes())

# Loading this pickle later executes code from the file; only load trusted copies.
with open(f'{base_path}/adj_matrix_openai_no_rules.pkl', 'wb') as f:
    pickle.dump(G, f)
       



44540
Edges with most_important_references=1: 610
34391
