# Data Generation

## Setup

In [1]:
import copy
from typing import List, Dict, Optional, Union, Tuple, Literal # Added Tuple

In [2]:
import os
import json

import sys
import os

parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

In [3]:
from graph.paper_graph import PaperGraph
from graph.graph_viz import GraphViz

In [4]:
graph_evolution = {}

graph_stats = {}

similarity_threshold = 0.7
top_k_similar_papers = 20
similar_papers = {}

top_l_key_authors = 20
key_authors = {}

crossref_papers = {}
top_m_corssref_papers = 20

In [5]:
candit_edges_pool = []

In [6]:
# driving examples
llm_api_key = os.getenv('GEMINI_API_KEY_3')
llm_model_name="gemini-2.0-flash"
embed_api_key = os.getenv('GEMINI_API_KEY_3')
embed_model_name="models/text-embedding-004"

research_topic = "llm literature review"
seed_dois = ['10.48550/arXiv.2406.10252',  # AutoSurvey: Large Language Models Can Automatically Write Surveys
            '10.48550/arXiv.2412.10415',  # Generative Adversarial Reviews: When LLMs Become the Critic
            '10.48550/arXiv.2402.12928',  # A Literature Review of Literature Reviews in Pattern Analysis and Machine Intelligence 
            ]
seed_titles = ['PaperRobot: Incremental Draft Generation of Scientific Ideas',
            'From Hypothesis to Publication: A Comprehensive Survey of AI-Driven Research Support Systems'
            ]

In [7]:
import numpy as np
from collections import Counter

def get_graph_stats(graph):
    """basic stats for graph"""
    graph_stats = {}
    graph_stats['node_cnt'] = len(graph.nodes)
    graph_stats['edge_cnt'] = len(graph.edges)
    print(f"Graph has {len(graph.nodes)} nodes and {len(graph.edges)} edges.")

    # check node types
    node_types = [node_data.get('nodeType') for _, node_data in graph.nodes(data=True)]
    node_types_cnt = Counter(node_types)
    sorted_node_counts = node_types_cnt.most_common()  # rank order by descending
    graph_stats['node_type'] = sorted_node_counts  # format like [(node type, nodes count), ...]
    print(f"There are {len(sorted_node_counts)} node types in this graph, they are:\n{sorted_node_counts}")

    # check edge types
    edge_types = [d.get('relationshipType') for _, _, d in graph.edges(data=True)]
    edge_types_cnt = Counter(edge_types)
    sorted_egdes_counts = edge_types_cnt.most_common()  # rank order by descending
    graph_stats['edge_type'] = sorted_egdes_counts  # format like [(node type, nodes count), ...]
    print(f"There are {len(sorted_egdes_counts)} edge types in this graph, they are:\n{sorted_egdes_counts}")

    return graph_stats

In [8]:
def get_paper_stats(graph, seed_paper_dois):
    """get paper statistic in paper graph"""
    papers_stats = []
    for nid, node_data in graph.nodes(data=True):
        if node_data.get('nodeType') == 'Paper':
            # paper infos
            title = graph.nodes[nid].get('title')
            in_seed = True if nid in seed_paper_dois else False
            overall_cite_cnt = node_data.get('citationCount')
            overall_inf_cite_cnt = node_data.get('influentialCitationCount')
            overall_ref_cnt = node_data.get('influentialCitationCount')

            # for in edges
            in_edges_info = graph.in_edges(nid, data=True)
            local_citation_cnt = 0  # local paper graph cites papers (other cites this one)
            sim_cnt_1 = 0  # local paper graph similar papers to this one
            max_sim_to_seed_1 = -1  # max similarity of this paper to seed papers
            for u, _, edge_data in in_edges_info:
                if edge_data.get('relationshipType') == 'CITES':
                    local_citation_cnt += 1
                elif edge_data.get('relationshipType') == 'SIMILAR_TO':
                    sim_cnt_1 += 1
                    if u in seed_paper_dois:
                        if edge_data.get('weight') > max_sim_to_seed_1:
                            max_sim_to_seed_1 = edge_data.get('weight')

            # for out edges
            out_edges_info = graph.out_edges(nid, data=True)
            local_ref_cnt = 0  # local paper graph cites papers (other cites this one)
            sim_cnt_2 = 0  # local paper graph similar papers to this one
            max_sim_to_seed_2 = -1  # max similarity of this paper to seed papers
            for _, v, edge_data in out_edges_info:
                if edge_data.get('relationshipType') == 'CITES':
                    local_ref_cnt += 1
                elif edge_data.get('relationshipType') == 'SIMILAR_TO':
                    sim_cnt_2 += 1
                    if v in seed_paper_dois:
                        if edge_data.get('weight') > max_sim_to_seed_2:
                            max_sim_to_seed_2 = edge_data.get('weight')

            # author infors
            author_ids_lst = [x['authorId'] for x in node_data.get('authors', []) if x.get('authorId') is not None]
            tot_author_cnt = len(author_ids_lst)

            # get author order and h-index
            h_index_lst, author_order_lst = [], []
            for idx, aid in enumerate(author_ids_lst):
                author_order = idx + 1
                h_index = graph.nodes[aid].get('hIndex')
                if h_index is not None:
                    h_index_lst.append(h_index)
                    author_order_lst.append(author_order)

            if len(h_index_lst) > 0:
                avg_h_index = np.average(h_index_lst)
                weight_h_index = sum([x / y for x, y in zip(h_index_lst, author_order_lst)]) / len(h_index_lst)
            else:
                avg_h_index = None
                weight_h_index = None

            paper_stats = {"doi":nid, "title":title, "if_seed": in_seed,
                           "local_citation_cnt":local_citation_cnt, "local_reference_cnt": local_ref_cnt, 
                           "local_similarity_cnt":sim_cnt_1+sim_cnt_2, "max_sim_to_seed":max(max_sim_to_seed_1, max_sim_to_seed_2),
                           "global_citaion_cnt":overall_cite_cnt, "influencial_citation_cnt":overall_inf_cite_cnt, "global_refence_cnt": overall_ref_cnt,
                           "author_cnt":tot_author_cnt, "avg_h_index":avg_h_index, 'weighted_h_index':weight_h_index}
            papers_stats.append(paper_stats)
    return papers_stats

In [9]:
def get_author_stats(graph, seed_author_ids):
    """get author statistic in paper graph"""

    h_index_ref = {nid:node_data['hIndex'] for nid, node_data in graph.nodes(data=True) if node_data.get('nodeType') == 'Author' 
                   and node_data.get('hIndex') is not None}

    authors_stats = []
    for nid, node_data in graph.nodes(data=True):
        if node_data.get('nodeType') == 'Author':
            # properties
            author_name = node_data.get('name')
            h_index = node_data.get('hIndex')
            in_seed = True if nid in seed_author_ids else False
            global_paper_cnt = node_data.get('paperCount')
            global_citation_cnt = node_data.get('citationCount')

            # local stats
            out_edges_info = graph.out_edges(nid, data=True)
            local_paper_cnt = sum([1 for _, _, data in out_edges_info if data.get('relationshipType') == 'WRITES'])
            # get coauthors
            coauthor_ids = []
            for u,v, edge_data in out_edges_info:
                if edge_data.get('relationshipType') == 'WRITES':
                    coauthors = edge_data.get('coauthors', [])
                    coauthor_ids.extend([x['authorId'] for x in coauthors if x.get('authorId') is not None])
            
            # get top coauthors
            coauthor_cnt = Counter(coauthor_ids)
            top_coauthors = coauthor_cnt.most_common()[0:5]  # rank order by descending

            # calculate top coauthor h-index
            coauthor_cnt = 0
            sum_coauthor_h_index = 0
            for idx, item in enumerate(top_coauthors):
                coauthor_id = item[0]
                coauthor_hindex = h_index_ref.get(coauthor_id)
                if coauthor_hindex is not None:
                    sum_coauthor_h_index += coauthor_hindex /(idx + 1)
                    coauthor_cnt += 1
            weighted_coauthor_h_index = sum_coauthor_h_index / coauthor_cnt if coauthor_cnt > 0 else None

            author_stat = {"author_id":nid, "author_name":author_name, "is_seed":in_seed,
                           "h_index":h_index, "global_paper_cnt":global_paper_cnt, "global_citation_cnt":global_citation_cnt,
                           "local_paper_cnt":local_paper_cnt, 
                           "top_coauthors":top_coauthors, "weighted_coauthor_h_index": weighted_coauthor_h_index
                          }
            authors_stats.append(author_stat)
    return authors_stats

In [10]:
import networkx as nx
from typing import List, Dict, Union, List, Set, Tuple, Hashable, Literal, Optional

NodeType = Hashable # 节点类型通常是可哈希的

def find_wcc_subgraphs(
    graph,
    target_nodes: Union[NodeType, List[NodeType], Set[NodeType], Tuple[NodeType]]
) -> List[nx.MultiDiGraph]:
    """查找包含一个或多个指定节点的弱连通分量对应的子图。
    Args:
        graph: NetworkX MultiDiGraph 图对象。
        target_nodes: 一个节点 ID，或一个包含节点 ID 的列表、集合或元组。
    Returns:
        一个包含所有找到的弱连通分量子图 (作为独立的 MultiDiGraph 副本) 的列表。
        如果目标节点不在图中或找不到对应的连通分量，则返回空列表。
        注意：如果多个目标节点在同一个连通分量中，该分量的子图只会被返回一次。
    """
    # 1. 标准化输入为集合
    if isinstance(target_nodes, (list, set, tuple)):
        target_nodes_set = set(target_nodes)
    else:
        # 假设是单个节点 ID
        target_nodes_set = {target_nodes}

    # 2. 检查所有目标节点是否存在于图中
    missing_nodes = target_nodes_set - set(graph.nodes())
    if missing_nodes:
        print(f"警告：以下目标节点不在图中，将被忽略: {missing_nodes}")
        target_nodes_set -= missing_nodes # 移除不存在的节点

    if not target_nodes_set:
        print("错误：没有有效的目标节点可供查找。")
        return []

    # 3. 查找并收集包含任何目标节点的弱连通分量
    found_subgraphs = []
    found_components_nodes = set() # 用于跟踪已添加的分量的节点集，避免重复

    for component_nodes in nx.weakly_connected_components(graph):
        component_set = set(component_nodes)
        # 4. 检查当前分量是否包含任何目标节点 (使用集合交集)
        if not target_nodes_set.isdisjoint(component_set): # 如果交集非空
            # 检查这个分量是否已经添加过 (基于其节点集合)
            # frozenset 是可哈希的，可以放入集合中
            component_frozenset = frozenset(component_set)
            if component_frozenset not in found_components_nodes:
                # 5. 提取子图并添加到结果列表
                subgraph = graph.subgraph(component_nodes).copy()
                found_subgraphs.append(subgraph)
                found_components_nodes.add(component_frozenset)

                # Optional: 如果我们确定一个目标节点只能属于一个WCC,
                # 可以在这里从 target_nodes_set 中移除 component_set 里的目标节点
                # 以可能稍微提高后续迭代的效率，但这通常不是必需的
                # target_nodes_set -= component_set

    return found_subgraphs

In [11]:
citation_limit = 100

if len(seed_dois) < 10 or len(seed_titles) < 10:
    search_limit = 100
    recommend_limit = 100
else:
    search_limit = 50
    recommend_limit = 50

In [12]:
from paper_extension import PaperCollector

ps = PaperCollector(   
    research_topic = research_topic,   
    seed_paper_titles = seed_titles, 
    seed_paper_dois = seed_dois,
    llm_api_key = llm_api_key,
    llm_model_name = llm_model_name,
    embed_api_key = embed_api_key,
    embed_model_name = embed_model_name,
    from_dt = '2020-01-01',
    to_dt = '2025-04-30',
    fields_of_study = ['Computer Science'],
    search_limit = search_limit,
    recommend_limit = recommend_limit,
    citation_limit = citation_limit
    )

  from .autonotebook import tqdm as notebook_tqdm


## Initial Search

In [13]:
round = 1

### Data Geneeration

It may take 30 seconds to 2 mins to complete data generation task.

How to get user actively invovled?  
- a progress bar?
- showing realtime progress? 

In [14]:
# --- INITIAL QUERY on SEED ---
# initial query for seed papers basic information
print("--- Running Initial Query for Seed Papers Information ---")
await ps.init_search(
    research_topic=ps.research_topic,
    seed_paper_titles=ps.seed_paper_titles,
    seed_paper_dois=ps.seed_paper_dois,
    round=round,
    search_limit=ps.search_limit,
    from_dt=ps.from_dt,
    to_dt=ps.to_dt
)

2025-04-21 14:42:29,106 - INFO - SemanticScholarKit initialized with max_concurrency=10, sleep_interval=3.0s
2025-04-21 14:42:29,107 - INFO - Fetching papers by 3 DOIs...
2025-04-21 14:42:29,108 - INFO - Fetching papers by title: 'PaperRobot: Incremental Draft Generation of Scientific Ideas...'
2025-04-21 14:42:29,108 - INFO - Fetching papers by title: 'From Hypothesis to Publication: A Comprehensive Survey of AI-Driven Research Support Systems...'
2025-04-21 14:42:29,109 - INFO - Fetching papers by topic: 'llm literature review...'
2025-04-21 14:42:29,110 - INFO - Running 4 initial query tasks concurrently...
2025-04-21 14:42:29,111 - INFO - async_search_paper_by_ids: Creating 1 tasks for 3 IDs.
2025-04-21 14:42:29,112 - INFO - async_search_paper_by_ids: Gathering 1 tasks...
2025-04-21 14:42:29,113 - INFO - async_search_paper_by_keywords: Searching papers by keyword: 'PaperRobot: Incremental Draft Generation of Scient...' with effective limit 100.
2025-04-21 14:42:29,114 - INFO - _syn

--- Running Initial Query for Seed Papers Information ---


2025-04-21 14:42:30,281 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/search?query=llm%20literature%20review&fields=abstract,authors,citationCount,citationStyles,corpusId,externalIds,fieldsOfStudy,influentialCitationCount,isOpenAccess,journal,openAccessPdf,paperId,publicationDate,publicationTypes,publicationVenue,referenceCount,s2FieldsOfStudy,title,url,venue,year&offset=0&limit=100 "HTTP/1.1 429 "
2025-04-21 14:42:30,330 - INFO - HTTP Request: GET https://api.semanticscholar.org/graph/v1/paper/search?query=From%20Hypothesis%20to%20Publication:%20A%20Comprehensive%20Survey%20of%20AI-Driven%20Research%20Support%20Systems&fields=abstract,authors,citationCount,citationStyles,corpusId,externalIds,fieldsOfStudy,influentialCitationCount,isOpenAccess,journal,openAccessPdf,paperId,publicationDate,publicationTypes,publicationVenue,referenceCount,s2FieldsOfStudy,title,url,venue,year&offset=0&limit=100 "HTTP/1.1 429 "
2025-04-21 14:42:30,449 - INFO - HTTP Request: POST

### Basic Stats

In [15]:
# get seed DOIs
seed_paper_dois = [node['id'] for node in ps.nodes_json if node['labels'] == ['Paper'] and node['properties'].get('from_seed')==True]
seed_author_ids = []
for node in ps.nodes_json:
    if node['labels'] == ['Paper'] and node['properties'].get('from_seed')==True and isinstance(node['properties'].get('authors'), list):
        authors_id = [x['authorId'] for x in node['properties']['authors'] if x['authorId'] is not None] 
        seed_author_ids.extend(authors_id)
seed_paper_json = [node for node in ps.nodes_json if node['labels'] == ['Paper'] and node['properties'].get('from_seed')==True]

print(len(seed_paper_dois), len(seed_author_ids))
print(len(ps.seeds['paper']), len(ps.seeds['author']))

4 27
4 27


In [16]:
# basic stats
G_init = PaperGraph(name='Paper Graph Init Search')
G_init.add_graph_nodes(ps.nodes_json)
G_init.add_graph_edges(ps.edges_json)
g_stat = get_graph_stats(G_init)

Graph has 664 nodes and 574 edges.
There are 4 node types in this graph, they are:
[('Author', 459), ('Paper', 102), ('Journal', 56), ('Venue', 47)]
There are 3 edge types in this graph, they are:
[('WRITES', 464), ('PRINTS_ON', 59), ('RELEASES_IN', 51)]


In [17]:
if 'init_search' not in graph_stats.keys():
    graph_stats['init_search'] = {}
graph_stats['init_search']['wo_similarity'] = g_stat

**Interactive**

In [18]:
print(f"I have successfully captured {len(seed_paper_dois)} seed papers.")
print("""I would recommend you further explore the following information:
    - seed paper citation chain to trace reference and citing papers;
    - seed paper authors to see if any related work
    - let me recommend similar papers on the topic""")

I have successfully captured 4 seed papers.
I would recommend you further explore the following information:
    - seed paper citation chain to trace reference and citing papers;
    - seed paper authors to see if any related work
    - let me recommend similar papers on the topic


### Similarity

In [19]:
# --- INTERMEDIATE: CALCULATE SIMILARITY ---
# get all paper infos
paper_nodes_json = [node for node in ps.nodes_json 
                    if node['labels'] == ['Paper'] and 
                    node['properties'].get('title') is not None and node['properties'].get('abstract') is not None]
paper_dois = [node['id'] for node in paper_nodes_json]

# calculate paper nodes similarity
semantic_similar_pool = await ps.cal_embed_and_similarity(
    paper_nodes_json=paper_nodes_json,
    paper_dois_1=paper_dois, 
    paper_dois_2=paper_dois,
    similarity_threshold=similarity_threshold,
    )

2025-04-21 14:45:29,319 - INFO - Generating embeddings for 94 papers...
2025-04-21 14:45:37,603 - INFO - Shape of embeds_1: (94, 768)
2025-04-21 14:45:37,603 - INFO - Shape of embeds_2: (94, 768)
2025-04-21 14:45:37,603 - INFO - Calculating similarity matrix...
2025-04-21 14:45:37,605 - INFO - Processing similarity matrix to create relationships...


In [None]:
edges_json = semantic_similar_pool
if type(edges_json) == dict:
    edges_json = [edges_json]

nx_edges_info = []
for item in edges_json:
    source_id = item['startNodeId']
    target_id = item['endNodeId']
    properties = item['properties']
    properties['relationshipType'] = item['relationshipType']
    # be aware that relationship shall take the form like (4, 5, dict(route=282)) for networkX
    nx_edges_info.append((source_id, target_id, properties))  
    item['dataGeneration'] = {'round': 1, 'source': 'init_search'}

G_init.add_edges_from(nx_edges_info)

### Filtering & Ranking

Basic Stats

In [21]:
# basic stats
g_stat = get_graph_stats(G_init)
graph_stats['init_search']['w_similarity'] = g_stat

Graph has 664 nodes and 1447 edges.
There are 4 node types in this graph, they are:
[('Author', 459), ('Paper', 102), ('Journal', 56), ('Venue', 47)]
There are 4 edge types in this graph, they are:
[('SIMILAR_TO', 873), ('WRITES', 464), ('PRINTS_ON', 59), ('RELEASES_IN', 51)]


In [22]:
# get all paper infos
paper_nodes_json = [node for node in ps.nodes_json 
                    if node['labels'] == ['Paper'] and 
                    node['properties'].get('title') is not None and node['properties'].get('abstract') is not None]
paper_dois = [node['id'] for node in paper_nodes_json]

Paper Stats

In [23]:
# for now the paper does not have citation chain
paper_stats = get_paper_stats(G_init, seed_paper_dois)

In [24]:
# filter similar papers to help build cross reference
sorted_paper_similarity = sorted(paper_stats, key=lambda x:x['max_sim_to_seed'], reverse=True)

candit_paper_dois = []
filtered_papers_stats = []
i = 0
for x in sorted_paper_similarity:
    if i < 20:
        if (x['if_seed'] == False  # exclude seed papers 
            and x['local_similarity_cnt'] > (len(paper_dois) / 5)):  # select most similar to others
            candit_paper_dois.append(x['doi'])
            filtered_papers_stats.append(x)
            i += 1
    else:
        break

similar_papers['init_search'] = filtered_papers_stats


Author Stats

In [25]:
author_stats = get_author_stats(G_init, seed_author_ids)

In [26]:
sorted_author_writes = sorted(author_stats, key=lambda x:x['local_paper_cnt'], reverse=True)
filtered_authors = [x for x in sorted_author_writes if x['is_seed'] == False][0:top_l_key_authors]
for item in filtered_authors:
    print(item)

{'author_id': '2341358168', 'author_name': 'Dilani Wickramaarachchi', 'is_seed': False, 'h_index': None, 'global_paper_cnt': None, 'global_citation_cnt': None, 'local_paper_cnt': 2, 'top_coauthors': [('2341358168', 2)], 'weighted_coauthor_h_index': None}
{'author_id': '74214430', 'author_name': 'R. Ślepaczuk', 'is_seed': False, 'h_index': None, 'global_paper_cnt': None, 'global_citation_cnt': None, 'local_paper_cnt': 2, 'top_coauthors': [('74214430', 2)], 'weighted_coauthor_h_index': None}
{'author_id': '144992211', 'author_name': 'Shubham Agarwal', 'is_seed': False, 'h_index': None, 'global_paper_cnt': None, 'global_citation_cnt': None, 'local_paper_cnt': 2, 'top_coauthors': [('144992211', 2), ('3266173', 2), ('2275240361', 2), ('1778839', 1), ('2267339519', 1)], 'weighted_coauthor_h_index': None}
{'author_id': '3266173', 'author_name': 'I. Laradji', 'is_seed': False, 'h_index': None, 'global_paper_cnt': None, 'global_citation_cnt': None, 'local_paper_cnt': 2, 'top_coauthors': [('3266

In [27]:
key_authors['init_search'] = filtered_authors

### Pruning

In [28]:
sub_graphs = find_wcc_subgraphs(graph=G_init, target_nodes=seed_dois)

In [29]:
g_stat = get_graph_stats(sub_graphs[0])
graph_stats['init_search']['after_pruning'] = g_stat

Graph has 552 nodes and 1352 edges.
There are 4 node types in this graph, they are:
[('Author', 384), ('Paper', 87), ('Journal', 44), ('Venue', 37)]
There are 4 edge types in this graph, they are:
[('SIMILAR_TO', 873), ('WRITES', 389), ('PRINTS_ON', 48), ('RELEASES_IN', 42)]


### Graph Viz

In [None]:
# from graph.graph_viz import GraphViz
viz = GraphViz(G_init, 'Paper Graph After Init Search')
viz.preprocessing()
viz.visulization()

In [None]:
# from graph.graph_viz import GraphViz
viz = GraphViz(sub_graphs[0], 'Paper Graph After Init Search - after prunning')
viz.preprocessing()
viz.visulization()

## Basic Search

Search for more related information based on seed papers.

In [None]:
# assume user want to continue with citation chain, author and recommendations
search_citation = 'both'
search_author = True
find_recommend = True

Based on current status (# of nodes and authors) to decide search limits

In [None]:
citation_limit = 100

if len(seed_dois) > 20 or len(seed_titles) > 20 or len(paper_dois) > 100:
    search_limit = 50
    recommend_limit = 50
else :
    search_limit = 100
    recommend_limit = 100

### Data Generation

In [None]:
# --- MORE INFORMATION on SEED ---
print("--- Getting More Information Related to Seed Papers ---")
# basic search for seed papers
# may include seed paper authors, seed paper citation chain, recommendations based on seed papers 
await ps.paper_search(
    seed_paper_dois=seed_paper_dois,
    seed_author_ids=seed_author_ids,
    search_citation = search_citation,
    search_author = search_author,
    round = 1,
    find_recommend = find_recommend,
    recommend_limit = recommend_limit,
    citation_limit = citation_limit,
    from_dt=ps.from_dt,
    to_dt=ps.to_dt,
    fields_of_study = ps.fields_of_study,
    )

### Basic Stats

In [None]:
# basic stats
G = PaperGraph(name='Paper Graph Basic Search')
G.add_graph_nodes(ps.nodes_json)
G.add_graph_edges(ps.edges_json)

g_stat = get_graph_stats(G)
graph_stats['basic_search'] = {}
graph_stats['basic_search']['wo_similarity'] = g_stat

### Calculate Similarity

In [None]:
# --- INTERMEDIATE: CALCULATE SIMILARITY ---
# get all paper infos
paper_nodes_json = [node for node in ps.nodes_json 
                    if node['labels'] == ['Paper'] and 
                    node['properties'].get('title') is not None and node['properties'].get('abstract') is not None]
paper_dois = [node['id'] for node in paper_nodes_json]

# calculate paper nodes similarity
semantic_similar_pool = await ps.cal_embed_and_similarity(
    paper_nodes_json=paper_nodes_json,
    paper_dois_1=paper_dois, 
    paper_dois_2=paper_dois,
    similarity_threshold=similarity_threshold,
    )

In [None]:
edges_json = semantic_similar_pool
if type(edges_json) == dict:
    edges_json = [edges_json]

nx_edges_info = []
for item in edges_json:
    source_id = item['startNodeId']
    target_id = item['endNodeId']
    properties = item['properties']
    properties['relationshipType'] = item['relationshipType']
    nx_edges_info.append((source_id, target_id, properties))  
    item['dataGeneration'] = {'round': 1, 'source': 'init_search'}

G.add_edges_from(nx_edges_info)

In [None]:
g_stat = get_graph_stats(G)
graph_stats['basic_search']['w_similarity'] = g_stat

### Filtering & Ranking

In [None]:
paper_stats = get_paper_stats(G, seed_paper_dois)

Paper By Local References

In [None]:
sorted_paper_similarity = sorted(paper_stats, key=lambda x:x['local_citation_cnt'], reverse=True)
filtered_papers_stats = [x for x in sorted_paper_similarity if x['if_seed'] == False][0:top_m_corssref_papers]
filtered_papers_dois = [x['doi'] for x in filtered_papers_stats]
for item in filtered_papers_stats:
    print(item)

crossref_papers['basic_search'] = filtered_papers_stats

Paper By Similarity

In [None]:
sorted_paper_similarity = sorted(paper_stats, key=lambda x:x['local_similarity_cnt'], reverse=True)
filtered_papers_stats = [x for x in sorted_paper_similarity if x['if_seed'] == False][0:100]
filtered_papers_dois = [x['doi'] for x in filtered_papers_stats]
for item in filtered_papers_stats:
    print(item)

crossref_papers['basic_search'] = filtered_papers_stats

Author By Writes

In [None]:
author_stats = get_author_stats(G, seed_author_ids)

In [None]:
sorted_author_writes = sorted(author_stats, key=lambda x:x['local_paper_cnt'], reverse=True)
filtered_authors = [x for x in sorted_author_writes if x['is_seed'] == False][0:top_l_key_authors]
for item in filtered_authors:
    print(item)

### Graph Viz

In [None]:
# from graph.graph_viz import GraphViz
viz = GraphViz(G, 'Paper Graph After Init Search')
viz.preprocessing()
viz.visulization()

## Expanded Search

### Expand Citations

In [None]:
crossref_info = crossref_papers['basic_search']

candit_crossref_cnt = 0 
for item in crossref_info:
    if item['local_citation_cnt'] > len(seed_dois):
        candit_crossref_cnt += 1
    else:
        break
print(candit_crossref_cnt)


In [None]:
if_expanded_citations = 'reference' 
citation_limit = 100

if candit_crossref_cnt > 10:
    top_k_similar_papers = 20
else:
    top_k_similar_papers = 50

candit_paper_dois = [x['doi'] for x in crossref_info][0:top_k_similar_papers]

In [None]:
# --- EXPAND to CITATIONS over SIMILAR PAPERS ---
# get most similar papers to seed papers
# track citation chain of these papers
if if_expanded_citations is not None:
    print(f"\n--- Get crossref papers: ---")
    await ps.citation_expansion(
        seed_paper_dois = seed_paper_dois,
        candit_paper_dois = candit_paper_dois,  # user input of candit paper dois to search for citations
        search_citation = 'reference',
        round = 1,
        citation_limit = citation_limit,
        from_dt = ps.from_dt,
        to_dt = ps.to_dt,
        fields_of_study = ps.fields_of_study,
    )

In [None]:
# basic stats
G_citation = PaperGraph(name='Paper Graph After Citation Expansion')
G_citation.add_graph_nodes(ps.nodes_json)
G_citation.add_graph_edges(ps.edges_json)
g_stat = get_graph_stats(G)

graph_stats['expanded_search'] = {}
graph_stats['expanded_search']['citation_expansion'] = g_stat

Filtering & Ranking

In [None]:
paper_stats = get_paper_stats(G_citation, seed_paper_dois)

In [None]:
sorted_paper_similarity = sorted(paper_stats, key=lambda x:x['local_citation_cnt'], reverse=True)
filtered_papers_stats = [x for x in sorted_paper_similarity if x['if_seed'] == False][0:top_m_corssref_papers]
filtered_papers_dois = [x['doi'] for x in filtered_papers_stats]
for item in filtered_papers_stats:
    print(item)



In [None]:
crossref_papers['expanded_search'] = filtered_papers_stats

In [None]:
author_stats = get_author_stats(G_citation, seed_author_ids)

sorted_author_writes = sorted(author_stats, key=lambda x:x['local_paper_cnt'], reverse=True)
filtered_authors = [x for x in sorted_author_writes if x['is_seed'] == False][0:100]
for item in filtered_authors:
    print(item)

### Expanded Topics

Topic Generation

In [None]:
keywords_topics_json = await ps.topic_generation(
    seed_paper_json = seed_paper_json,
    llm_api_key= ps.llm_api_key,
    llm_model_name = ps.llm_model_name,
    round = 1)

In [None]:
topic_queries = [x for x in keywords_topics_json.get('queries', []) if x not in ps.explored_nodes['topic']] 

In [None]:
from thefuzz import fuzz # pip install thefuzz  https://github.com/seatgeek/thefuzz
ration_1 = fuzz.ratio('llm literature review', 'LLM for automated literature review'.lower())
ration_2 = fuzz.ratio('llm literature review', 'AI assisted peer review'.lower())
ration_3 = fuzz.ratio('llm literature review', 'AI support for scientific research'.lower())
ration_4 = fuzz.ratio('llm literature review', 'evaluation of AI generated literature reviews'.lower())
print(ration_1, ration_2, ration_3, ration_4)

In [None]:
ps.explored_nodes['topic']

In [None]:
topic_queries

In [None]:
len(paper_dois)

In [None]:
# get all paper infos
paper_nodes_json = [node for node in ps.nodes_json 
                    if node['labels'] == ['Paper'] and 
                    node['properties'].get('title') is not None and node['properties'].get('abstract') is not None]
paper_dois = [node['id'] for node in paper_nodes_json]

In [None]:
if len(topic_queries) > 4 or len(paper_dois) > 1000:
    search_limit = 50

if_related_topic = True

In [None]:
# --- EXPAND to RELATED TOPICS over SEED ---
# get related topics based on abstracts of seed papers
# search for related topics for more papers
print("--- Extend Related Topics from Seed Papers ---")
if if_related_topic:
    await ps.topic_extension(
        candit_topic_queries = topic_queries,
        round = 1,
        search_limit = search_limit,
        from_dt = ps.from_dt,
        to_dt = ps.to_dt,
        fields_of_study = ps.fields_of_study,
    )

In [None]:
# --- INTERMEDIATE: CALCULATE SIMILARITY ---
# calculate paper nodes similarity
semantic_similar_pool = await ps.cal_embed_and_similarity(
    paper_nodes_json=paper_nodes_json,
    paper_dois_1=paper_dois, 
    paper_dois_2=paper_dois,
    similarity_threshold=similarity_threshold,
    )