# Data Generation

## Setup

In [1]:
import copy
from typing import List, Dict, Optional, Union, Tuple, Literal # Added Tuple

In [2]:
import os
import json

import sys
import os

# Make the project root (the parent of the notebook's working directory) importable,
# so the `graph` and `paper_expansion` modules used below can be found.
parent_dir = os.path.dirname(os.getcwd())
# Guarded so re-running this cell does not keep appending the same entry to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [3]:
from graph.paper_graph import PaperGraph
from graph.graph_viz import GraphViz

In [4]:
# Threshold handed to `ps.cal_embed_and_similarity` when computing paper similarity.
similarity_threshold: float = 0.7
# Upper bound on how many non-seed papers are kept when filtering the ranked paper stats.
top_k_similar_papers: int = 20
# Selected similar papers, keyed by the stage that produced them.
similar_papers: dict = {}

# Upper bound on how many non-seed authors are kept when filtering the ranked author stats.
top_l_key_authors: int = 20
# Selected key authors, keyed by the stage that produced them.
key_authors: dict = {}


In [5]:
# Driving example: credentials, model names, research topic and seed papers for this run.
# API keys come from the environment; os.getenv returns None when the variable is unset.
llm_api_key = os.getenv('GEMINI_API_KEY_3')
llm_model_name="gemini-2.0-flash"
# The embedding key is read from the same environment variable as the LLM key.
embed_api_key = os.getenv('GEMINI_API_KEY_3')
embed_model_name="models/text-embedding-004"

research_topic = "llm literature review"
# Seed papers identified by DOI (paper title given in the trailing comment).
seed_dois = ['10.48550/arXiv.2406.10252',  # AutoSurvey: Large Language Models Can Automatically Write Surveys
            '10.48550/arXiv.2412.10415',  # Generative Adversarial Reviews: When LLMs Become the Critic
            '10.48550/arXiv.2402.12928',  # A Literature Review of Literature Reviews in Pattern Analysis and Machine Intelligence
            ]
# Seed papers identified by title only.
seed_titles = ['PaperRobot: Incremental Draft Generation of Scientific Ideas',
            'From Hypothesis to Publication: A Comprehensive Survey of AI-Driven Research Support Systems'
            ]

In [6]:
import numpy as np
from collections import Counter

def get_graph_stats(graph):
    """Print and return a size/type summary of a graph.

    Args:
        graph: graph object exposing ``nodes`` and ``edges`` that support both
            ``len()`` and being called with ``data=True``.

    Returns:
        dict with four entries:
            ``node_cnt``  - number of nodes
            ``edge_cnt``  - number of edges
            ``node_type`` - ``[(nodeType, count), ...]`` ordered by descending count
            ``edge_type`` - ``[(relationshipType, count), ...]`` ordered by descending count
        Nodes/edges lacking the attribute are counted under ``None``.
    """
    total_nodes = len(graph.nodes)
    total_edges = len(graph.edges)
    print(f"Graph has {total_nodes} nodes and {total_edges} edges.")

    # tally nodes per `nodeType`, most frequent first
    node_type_ranking = Counter(
        attrs.get('nodeType') for _node_id, attrs in graph.nodes(data=True)
    ).most_common()
    print(f"There are {len(node_type_ranking)} node types in this graph, they are:\n{node_type_ranking}")

    # tally edges per `relationshipType`, most frequent first
    edge_type_ranking = Counter(
        attrs.get('relationshipType') for _src, _dst, attrs in graph.edges(data=True)
    ).most_common()
    print(f"There are {len(edge_type_ranking)} edge types in this graph, they are:\n{edge_type_ranking}")

    return {
        'node_cnt': total_nodes,
        'edge_cnt': total_edges,
        'node_type': node_type_ranking,  # [(node type, node count), ...]
        'edge_type': edge_type_ranking,  # [(edge type, edge count), ...]
    }

In [7]:
def get_paper_stats(graph: "PaperGraph", seed_paper_dois, order_by=None):
    """Collect per-paper statistics from the paper graph.

    Args:
        graph: paper graph. The function uses ``nodes(data=True)``, ``in_edges`` /
            ``out_edges`` with ``data=True``, ``predecessors`` and ``graph[u][v]``.
        seed_paper_dois: iterable of node ids that are seed papers (``None`` is
            treated as empty).
        order_by: optional ranking criterion. ``None`` (default) keeps the graph's
            node order. ``'similarity'``, ``'citation'`` and ``'reference'`` rank by
            the local SIMILAR_TO / incoming CITES / outgoing CITES counts; any key of
            the returned dicts is accepted as well. Ranking is descending, ties keep
            node order, and entries whose value is ``None`` go last.

    Returns:
        list of dicts, one per node whose ``nodeType`` is ``'Paper'``. The keys
        ``global_citaion_cnt``, ``influencial_citation_cnt`` and ``global_refence_cnt``
        are misspelled but kept unchanged so existing consumers keep working.

        ``avg_h_index`` is the mean h-index of the paper's authors that have one
        (``None`` if no author has one). ``weighted_h_index`` is the mean of
        ``hIndex / authorOrder`` over the authors that additionally have a positive
        ``authorOrder`` on their edge to the paper (``None`` if there is none).

    Raises:
        ValueError: if ``order_by`` does not name an alias or a key of the result dicts.
    """
    order_aliases = {
        'similarity': 'local_similarity_cnt',
        'citation': 'local_citation_cnt',
        'reference': 'local_reference_cnt',
    }
    seed_ids = set(seed_paper_dois or ())

    papers_stats = []
    for nid, node_data in graph.nodes(data=True):
        if node_data.get('nodeType') != 'Paper':
            continue

        # relationship types of the edges touching this paper inside the graph
        in_rel_types = [d.get('relationshipType') for _, _, d in graph.in_edges(nid, data=True)]
        out_rel_types = [d.get('relationshipType') for _, _, d in graph.out_edges(nid, data=True)]
        local_citation_cnt = in_rel_types.count('CITES')
        local_ref_cnt = out_rel_types.count('CITES')
        local_sim_cnt = in_rel_types.count('SIMILAR_TO') + out_rel_types.count('SIMILAR_TO')

        # author information
        author_ids = [u for u in graph.predecessors(nid)
                      if graph.nodes[u].get('nodeType') == 'Author']
        h_index_lst, weighted_terms = [], []
        for author_id in author_ids:
            h_index = graph.nodes[author_id].get('hIndex')
            if h_index is None:  # an h-index of 0 is valid and must be counted
                continue
            h_index_lst.append(h_index)
            author_order = graph[author_id][nid].get('authorOrder')
            # assumes authorOrder is a 1-based position - TODO confirm; missing or
            # non-positive values are skipped instead of raising on the division
            if isinstance(author_order, (int, float)) and author_order > 0:
                weighted_terms.append(h_index / author_order)

        avg_h_index = np.average(h_index_lst) if h_index_lst else None
        weight_h_index = sum(weighted_terms) / len(weighted_terms) if weighted_terms else None

        papers_stats.append({
            "doi": nid,
            "title": node_data.get('title'),
            "if_seed": nid in seed_ids,
            "local_citation_cnt": local_citation_cnt,
            "local_reference_cnt": local_ref_cnt,
            "local_similarity_cnt": local_sim_cnt,
            "global_citaion_cnt": node_data.get('citationCount'),
            "influencial_citation_cnt": node_data.get('influentialCitationCount'),
            # was read from 'influentialCitationCount' by mistake
            "global_refence_cnt": node_data.get('referenceCount'),
            "author_cnt": len(author_ids),
            "avg_h_index": avg_h_index,
            'weighted_h_index': weight_h_index,
        })

    if order_by is not None:
        sort_key = order_aliases.get(order_by, order_by)
        if papers_stats and sort_key not in papers_stats[0]:
            raise ValueError(f"Unknown order_by value: {order_by!r}")
        # (has value, value) keeps None entries at the end of a descending sort
        papers_stats.sort(
            key=lambda s: (s[sort_key] is not None, s[sort_key] if s[sort_key] is not None else 0),
            reverse=True,
        )
    return papers_stats

In [8]:
def get_author_stats(graph, seed_author_ids):
    """Collect per-author statistics from the paper graph.

    Args:
        graph: paper graph. The function uses ``nodes(data=True)`` and
            ``out_edges(nid, data=True)``.
        seed_author_ids: iterable of author node ids that belong to seed papers
            (``None`` is treated as empty).

    Returns:
        list of dicts, one per node whose ``nodeType`` is ``'Author'``.
        ``top_coauthors`` holds up to five ``(author_id, shared_paper_count)`` pairs,
        most frequent first, gathered from the ``coauthors`` attribute of the
        author's WRITES edges. ``weighted_coauthor_h_index`` is the mean of
        ``hIndex / rank`` (rank 1 = most frequent coauthor) over the top coauthors
        whose h-index is known in this graph, or ``None`` if none is known.
    """
    # h-index lookup for every author node that has one
    h_index_ref = {nid: node_data['hIndex'] for nid, node_data in graph.nodes(data=True)
                   if node_data.get('nodeType') == 'Author' and node_data.get('hIndex') is not None}
    seed_ids = set(seed_author_ids or ())

    authors_stats = []
    for nid, node_data in graph.nodes(data=True):
        if node_data.get('nodeType') != 'Author':
            continue

        # local stats: papers written by this author inside the graph
        write_edges = [edge_data for _, _, edge_data in graph.out_edges(nid, data=True)
                       if edge_data.get('relationshipType') == 'WRITES']

        # `or []` also covers an explicit `coauthors: None` on the edge
        coauthor_ids = [coauthor['authorId']
                        for edge_data in write_edges
                        for coauthor in (edge_data.get('coauthors') or [])
                        if coauthor.get('authorId') is not None]
        top_coauthors = Counter(coauthor_ids).most_common()[0:5]  # descending by frequency

        # Rank-weighted h-index of the top coauthors. Ranks start at 1; the previous
        # 0-based index made the first division raise ZeroDivisionError.
        weighted_sum = 0
        known_cnt = 0
        for rank, (coauthor_id, _shared_papers) in enumerate(top_coauthors, start=1):
            coauthor_hindex = h_index_ref.get(coauthor_id)
            if coauthor_hindex is not None:
                weighted_sum += coauthor_hindex / rank
                known_cnt += 1
        weighted_coauthor_h_index = weighted_sum / known_cnt if known_cnt > 0 else None

        authors_stats.append({
            "author_id": nid,
            "author_name": node_data.get('name'),
            "is_seed": nid in seed_ids,
            "h_index": node_data.get('hIndex'),
            "global_paper_cnt": node_data.get('paperCount'),
            "global_citation_cnt": node_data.get('citationCount'),
            "local_paper_cnt": len(write_edges),
            "top_coauthors": top_coauthors,
            "weighted_coauthor_h_index": weighted_coauthor_h_index,
        })
    return authors_stats

In [9]:
from paper_expansion import PaperCollector

# Collector configured with the topic, seed papers and model credentials defined above.
# It is the `ps` object every later search / similarity cell operates on.
ps = PaperCollector(   
    research_topic = research_topic,   
    seed_paper_titles = seed_titles, 
    seed_paper_dois = seed_dois,
    llm_api_key = llm_api_key,
    llm_model_name = llm_model_name,
    embed_api_key = embed_api_key,
    embed_model_name = embed_model_name,
    # publication date window and field-of-study filter
    from_dt = '2020-01-01',
    to_dt = '2025-04-30',
    fields_of_study = ['Computer Science'],
    # NOTE(review): presumably per-query result caps - confirm in PaperCollector
    search_limit = 100,
    recommend_limit = 100,
    citation_limit = 100,
    paper_graph_name = 'paper_graph'
    )

## Initial Search

### Data Generation

In [10]:
# --- INITIAL QUERY on SEED ---
# Initial query: fetch the seed papers (by DOI and by title) and search the research topic,
# reusing the settings stored on the collector.
# Top-level `await` works here because the notebook kernel already runs an event loop.
print("--- Running Initial Query for Seed Papers Information ---")
await ps.init_search(
    ps.research_topic,
    ps.seed_paper_titles,
    ps.seed_paper_dois,
    ps.search_limit,
    ps.from_dt,
    ps.to_dt
)

2025-04-17 21:42:53,066 - INFO - SemanticScholarKit initialized with max_concurrency=10, sleep_interval=3.0s
2025-04-17 21:42:53,070 - INFO - Fetching papers by 3 DOIs...
2025-04-17 21:42:53,073 - INFO - Fetching papers by title: 'PaperRobot: Incremental Draft Generation of Scientific Ideas...'
2025-04-17 21:42:53,075 - INFO - Fetching papers by title: 'From Hypothesis to Publication: A Comprehensive Survey of AI-Driven Research Support Systems...'
2025-04-17 21:42:53,078 - INFO - Fetching papers by topic: 'llm literature review...'
2025-04-17 21:42:53,080 - INFO - Running 4 initial query tasks concurrently...
2025-04-17 21:42:53,083 - INFO - async_search_paper_by_ids: Creating 1 tasks for 3 IDs.
2025-04-17 21:42:53,087 - INFO - async_search_paper_by_ids: Gathering 1 tasks...
2025-04-17 21:42:53,091 - INFO - async_search_paper_by_keywords: Searching papers by keyword: 'PaperRobot: Incremental Draft Generation of Scient...' with effective limit 100.
2025-04-17 21:42:53,092 - INFO - _syn

--- Running Initial Query for Seed Papers Information ---


2025-04-17 21:42:53,371 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/search?query=llm+literature+review&fields=abstract%2Cauthors%2CcitationCount%2CcitationStyles%2CcorpusId%2CexternalIds%2CfieldsOfStudy%2CinfluentialCitationCount%2CisOpenAccess%2Cjournal%2CopenAccessPdf%2CpaperId%2CpublicationDate%2CpublicationTypes%2CpublicationVenue%2CreferenceCount%2Cs2FieldsOfStudy%2Ctitle%2Curl%2Cvenue%2Cyear&offset=0&limit=100 "HTTP/1.1 429 "
2025-04-17 21:42:53,565 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/search?query=From+Hypothesis+to+Publication%3A+A+Comprehensive+Survey+of+AI-Driven+Research+Support+Systems&fields=abstract%2Cauthors%2CcitationCount%2CcitationStyles%2CcorpusId%2CexternalIds%2CfieldsOfStudy%2CinfluentialCitationCount%2CisOpenAccess%2Cjournal%2CopenAccessPdf%2CpaperId%2CpublicationDate%2CpublicationTypes%2CpublicationVenue%2CreferenceCount%2Cs2FieldsOfStudy%2Ctitle%2Curl%2Cvenue%2Cyear&offset=0&limit=100 "HTTP/1.1 4

### Basic Stats

In [11]:
# get seed DOIs
# Select the seed paper nodes once, then derive the DOI list and the author ids from them.
# (`== True` is kept as-is so the match on `from_seed` is unchanged.)
seed_paper_json = [node for node in ps.nodes_json
                   if node['labels'] == ['Paper'] and node['properties'].get('from_seed') == True]
seed_paper_dois = [node['id'] for node in seed_paper_json]

seed_author_ids = []
for node in seed_paper_json:
    authors = node['properties'].get('authors')
    if isinstance(authors, list):
        # .get() tolerates author records that carry no `authorId` key at all
        authors_id = [x['authorId'] for x in authors if x.get('authorId') is not None]
        seed_author_ids.extend(authors_id)

# Mark the seeds as explored. DOIs already recorded are skipped so that re-running
# this cell does not append duplicates.
already_explored = set(ps.explored_nodes['seed'])
ps.explored_nodes['seed'].extend(doi for doi in seed_paper_dois if doi not in already_explored)

print(len(seed_paper_dois), len(seed_author_ids))

4 27


In [14]:
# basic stats
# Work on a deep copy so that edges added to G later in this notebook do not touch ps.pg.
G = copy.deepcopy(ps.pg)
graph_stat = get_graph_stats(G)

Graph has 661 nodes and 570 edges.
There are 4 node types in this graph, they are:
[('Author', 455), ('Paper', 102), ('Journal', 56), ('Venue', 48)]
There are 3 edge types in this graph, they are:
[('WRITES', 459), ('PRINTS_ON', 59), ('RELEASES_IN', 52)]


### Similarity

In [16]:
# --- INTERMEDIATE: CALCULATE SIMILARITY ---
# Keep only the paper nodes that have both a title and an abstract.
paper_nodes_json = [node for node in ps.nodes_json 
                    if node['labels'] == ['Paper'] and 
                    node['properties'].get('title') is not None and node['properties'].get('abstract') is not None]
paper_dois = [node['id'] for node in paper_nodes_json]

# Calculate similarity between paper nodes. The same DOI list is passed as both sides,
# so the selected papers are compared against each other.
# NOTE(review): this presumably also pairs each paper with itself - confirm how
# cal_embed_and_similarity handles identical pairs.
semantic_similar_pool = await ps.cal_embed_and_similarity(
    paper_nodes_json=paper_nodes_json,
    paper_dois_1=paper_dois, 
    paper_dois_2=paper_dois,
    similarity_threshold=similarity_threshold,
    )

2025-04-17 21:49:43,966 - INFO - Generating embeddings for 94 papers...
2025-04-17 21:50:02,698 - INFO - Shape of embeds_1: (94, 768)
2025-04-17 21:50:02,699 - INFO - Shape of embeds_2: (94, 768)
2025-04-17 21:50:02,699 - INFO - Calculating similarity matrix...
2025-04-17 21:50:02,713 - INFO - Processing similarity matrix to create relationships...


In [None]:
# Convert the similarity relationships into networkx edge tuples and add them to G.
edges_json = semantic_similar_pool
# Accept a single relationship dict as well as a list of them.
if isinstance(edges_json, dict):
    edges_json = [edges_json]

nx_edges_info = []
for item in edges_json or []:  # an empty / None pool simply adds no edges
    source_id = item['startNodeId']
    target_id = item['endNodeId']
    # Copy the properties so that adding `relationshipType` does not mutate the
    # records held in semantic_similar_pool.
    properties = dict(item.get('properties') or {})
    properties['relationshipType'] = item['relationshipType']
    # be aware that relationship shall take the form like (4, 5, dict(route=282)) for networkX
    nx_edges_info.append((source_id, target_id, properties))

G.add_edges_from(nx_edges_info)

### Filtering & Ranking

Basic Stats

In [None]:
# basic stats, recomputed on G as it stands at this point in the notebook
graph_stat = get_graph_stats(G)

Paper Stats

In [None]:
# The graph has no citation edges at this stage, so papers are ranked by their
# number of local SIMILAR_TO edges (most connected first).
paper_stats_similar = sorted(
    get_paper_stats(G, seed_paper_dois),
    key=lambda stats: stats['local_similarity_cnt'],
    reverse=True,
)

In [None]:
# Drop the seed papers and keep at most `top_k_similar_papers` entries,
# in the order `paper_stats_similar` already has.
filtered_papers_stats = [x for x in paper_stats_similar if x['if_seed'] == False][0:top_k_similar_papers]
filtered_papers_dois = [x['doi'] for x in filtered_papers_stats]
for item in filtered_papers_stats:
    print(item)

In [None]:
# Record the papers selected from the initial search.
# NOTE(review): key is 'from_init_search' here but key_authors uses 'init_search' - confirm which is intended.
similar_papers['from_init_search'] = filtered_papers_stats

Author Stats

In [None]:
# Per-author statistics on G, flagging the authors of the seed papers.
author_stats = get_author_stats(G, seed_author_ids)

In [None]:
# Rank authors by how many papers in G they wrote, most prolific first.
sorted_author_writes = sorted(author_stats, key=lambda x:x['local_paper_cnt'], reverse=True)

In [None]:
# Drop the seed authors, keep at most `top_l_key_authors` of the ranked authors,
# and record them as the key authors of the initial search.
filtered_authors = [x for x in sorted_author_writes if x['is_seed'] == False][0:top_l_key_authors]
key_authors['init_search'] = filtered_authors
for item in filtered_authors:
    print(item)

### Graph Viz

In [None]:
# Visualize G after the initial search; GraphViz is imported in the Setup section.
viz = GraphViz(G, 'Paper Graph After Init Search')
viz.preprocessing()
viz.visulization()