# Setup

In [1]:
import copy
from typing import List, Dict, Optional, Union, Tuple, Literal

In [2]:
import os
import json

import sys
import os

parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

In [3]:
from graph.paper_graph import PaperGraph
from graph.graph_viz import GraphViz

In [4]:
# driving examples
llm_api_key = os.getenv('GEMINI_API_KEY_3')
llm_model_name = "gemini-2.5-flash-preview-04-17"
embed_api_key = os.getenv('GEMINI_API_KEY_3')
embed_model_name = "models/text-embedding-004"

research_topics = ["llm literature review"]
seed_dois = ['10.48550/arXiv.2406.10252',  # AutoSurvey: Large Language Models Can Automatically Write Surveys
            '10.48550/arXiv.2412.10415',  # Generative Adversarial Reviews: When LLMs Become the Critic
            '10.48550/arXiv.2402.12928',  # A Literature Review of Literature Reviews in Pattern Analysis and Machine Intelligence 
            ]
seed_titles = ['PaperRobot: Incremental Draft Generation of Scientific Ideas',
            'From Hypothesis to Publication: A Comprehensive Survey of AI-Driven Research Support Systems'
            ]

In [5]:
citation_limit = 100
author_paper_limit = 10

if len(seed_dois) < 10 or len(seed_titles) < 10:
    search_limit = 100
    recommend_limit = 100
else:
    search_limit = 50
    recommend_limit = 50

In [6]:
from collect.paper_data_collect import PaperCollector

ps = PaperCollector(   
    seed_research_topics = research_topics,   
    seed_paper_titles = seed_titles, 
    seed_paper_ids = seed_dois,
    from_dt = '2020-01-01',
    to_dt = '2025-04-30',
    fields_of_study = ['Computer Science'],
    author_paper_limit = author_paper_limit,
    search_limit = search_limit,
    recommend_limit = recommend_limit,
    citation_limit = citation_limit
    )

2025-04-30 15:12:03,557 - SemanticScholarKit - INFO - SemanticScholarKit initialized with: max_concurrency=20, max_retry=20, sleep_interval=3.0s
INFO:SemanticScholarKit:SemanticScholarKit initialized with: max_concurrency=20, max_retry=20, sleep_interval=3.0s
2025-04-30 15:12:03,558 - SemanticScholarKit - INFO - SemanticScholarKit initialized with: max_concurrency=20, max_retry=20, sleep_interval=3.0s
INFO:SemanticScholarKit:SemanticScholarKit initialized with: max_concurrency=20, max_retry=20, sleep_interval=3.0s


In [None]:
import pandas as pd
from datetime import datetime
from dateutil.relativedelta import relativedelta
from typing import List, Dict, Optional, Union, Any, Set, Tuple

import sys
import os

parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

from graph.paper_graph import PaperGraph
from collect.paper_sim_calc import PaperSim


from datetime import datetime
from dateutil.relativedelta import relativedelta
import numpy as np


def graph_basic_stats(node_stats: Dict):
    """Summarize each numeric stats column with min/max/mean/median/quartiles.

    Args:
        node_stats: column-oriented stats dict (column name -> list of values),
            e.g. ``gen_nodes_stats(...)['paper_stats']``. The ``'id'`` column
            is skipped.

    Returns:
        dict: column name -> {'min', 'max', 'average', 'median',
        'quantile_25', 'quantile_75'}; an empty dict for a column that has no
        usable numeric values. ``None``, NaN and non-numeric entries are ignored.
    """
    def _is_usable(value) -> bool:
        # accept numpy scalars (e.g. np.int64) as well as Python numbers, and
        # drop NaN so a single missing value cannot turn every statistic into NaN
        if isinstance(value, (float, np.floating)):
            return not np.isnan(value)
        return isinstance(value, (int, np.integer))

    node_stats_index = {}
    for key, values in node_stats.items():
        if key == 'id':  # identifiers are not a measurement
            continue

        # only numeric values take part in the statistics
        valid_values = [v for v in values if _is_usable(v)]
        if not valid_values:
            node_stats_index[key] = {}
            continue

        node_stats_index[key] = {
            'min': np.min(valid_values),
            'max': np.max(valid_values),
            'average': np.mean(valid_values),
            'median': np.median(valid_values),
            'quantile_25': np.percentile(valid_values, 25),
            'quantile_75': np.percentile(valid_values, 75)
        }
    return node_stats_index


class PaperRouter:
    """Score paper similarity and flag significant paper / author nodes.

    Wraps a ``PaperSim`` instance for embedding-based similarity and provides
    rule-based significance checks over column-oriented stats (such as the
    ``paper_stats`` / ``author_stats`` dicts built by ``gen_nodes_stats``).
    """

    # (stats column, minimum value), checked in order; the first match wins
    PAPER_SIGNIFICANCE_RULES = (
        ('citationCount', 20),            # global citations
        ('influentialCitationCount', 3),  # influential citations
        ('monthlyCitationCount', 5),      # citations per month since publication
        ('localCitationCount', 5),        # citations inside the collected graph
    )

    def __init__(
        self,
        nodes_json: Optional[List[Dict]] = None,
        edges_json: Optional[List[Dict]] = None,
        paper_sim: Optional[PaperSim] = None,

        embed_api_key: Optional[str] = None,
        embed_model_name: Optional[str] = None
        ):
        """
        Args:
            nodes_json: graph nodes as dicts (read via ``id`` / ``labels`` / ``properties``).
            edges_json: graph edges as dicts.
            paper_sim: an existing ``PaperSim`` to reuse; when absent, a new one
                is created from ``embed_api_key`` / ``embed_model_name``.
            embed_api_key: API key for the embedding model.
            embed_model_name: embedding model name.
        """
        self.nodes_json = nodes_json
        self.edges_json = edges_json

        # reuse the caller's similarity calculator if given, else build one
        if paper_sim and isinstance(paper_sim, PaperSim):
            self.sim_calc = paper_sim
        else:
            self.sim_calc = PaperSim(
                embed_api_key = embed_api_key,
                embed_model_name = embed_model_name
            )


    ####################################################################################
    # helpers
    ####################################################################################
    @staticmethod
    def _is_missing(value) -> bool:
        """True for None / NaN (pandas stores None as NaN in numeric columns)."""
        return value is None or bool(pd.isna(value))

    @classmethod
    def _at_least(cls, value, threshold) -> bool:
        """True when ``value`` is present (not None/NaN) and >= ``threshold``."""
        return not cls._is_missing(value) and bool(value >= threshold)

    @staticmethod
    def _papers_with_abstract(nodes_json):
        """Paper nodes that have both a title and an abstract."""
        return [
            node for node in nodes_json
            if node['labels'] == ['Paper']
            and node['properties'].get('title') is not None
            and node['properties'].get('abstract') is not None
        ]

    @staticmethod
    def _topics_with_description(nodes_json):
        """Topic nodes that carry a description."""
        return [
            node for node in nodes_json
            if node['labels'] == ['Topic']
            and node['properties'].get('description') is not None
        ]


    ####################################################################################
    # similarity calculation and filter
    ####################################################################################
    async def score_paper2paper__sim(self, nodes_json, similarity_threshold: float = 0.7):
        """Calculate paper-paper similarity among papers that have an abstract.

        Args:
            nodes_json: graph nodes; only Paper nodes with title and abstract are used.
            similarity_threshold: minimum similarity passed to ``PaperSim``.

        Returns:
            whatever ``PaperSim.cal_embed_and_similarity`` returns (the similarity pool).
        """
        papers = self._papers_with_abstract(nodes_json)
        paper_ids = [node['id'] for node in papers]

        # all-pairs comparison within the same paper set
        return await self.sim_calc.cal_embed_and_similarity(
            paper_nodes_json = papers,
            paper_dois_1 = paper_ids,
            paper_dois_2 = paper_ids,
            similarity_threshold = similarity_threshold,
            )

    async def score_paper2topic_sim(self, nodes_json):
        """Calculate paper-topic similarity.

        Compares Paper nodes that have a title and abstract against Topic nodes
        that have a description.

        NOTE(review): the notebook computes topic and abstract embeddings
        (``get_topic_embeds`` / ``get_abstract_embeds``) before calling
        ``cal_p2t_similarity``; confirm whether those must already exist here.
        """
        papers = self._papers_with_abstract(nodes_json)
        topics = self._topics_with_description(nodes_json)
        paper_ids = [node['id'] for node in papers]
        topic_ids = [node['id'] for node in topics]

        return await self.sim_calc.cal_p2t_similarity(papers, topics, paper_ids, topic_ids)


    ####################################################################################
    # statistics calculation and significant paper identify
    ####################################################################################

    def identify_paper_significant(
            self,
            paper_stats,
            ):
        """Flag significant papers in ``paper_stats``.

        A paper is significant when it meets any rule in
        ``PAPER_SIGNIFICANCE_RULES``. Rules are checked in order and the first
        matching column name is recorded. Missing (None/NaN) or absent values
        never match.

        Args:
            paper_stats: column-oriented stats (e.g. ``node_stats['paper_stats']``).

        Returns:
            pd.DataFrame: ``paper_stats`` plus a ``significance`` column (1/0)
            and a ``sig_info`` column (matched rule name, or None).
        """
        paper_stats_df = pd.DataFrame(paper_stats)

        # rows yielded by iterrows() are copies, so results are collected here
        # and assigned back as whole columns
        significance, sig_info = [], []
        for _, row in paper_stats_df.iterrows():
            matched = next(
                (col for col, threshold in self.PAPER_SIGNIFICANCE_RULES
                 if self._at_least(row.get(col), threshold)),
                None)
            significance.append(0 if matched is None else 1)
            sig_info.append(matched)

        paper_stats_df['significance'] = significance
        paper_stats_df['sig_info'] = sig_info
        return paper_stats_df

    def identify_author_significant(
            self,
            author_stats,
            ):
        """Flag significant authors in ``author_stats``.

        Rules, checked in order (the first match is recorded in ``sig_info``):
          1. hIndex >= 10                      -> 'hIndex'
          2. citationCount / paperCount >= 20  -> 'avgPaperCitation'
             (only when both counts are present and paperCount > 0)
          3. localPaperCount >= 5              -> 'localPaperCount'

        Args:
            author_stats: column-oriented stats (e.g. ``node_stats['author_stats']``).

        Returns:
            pd.DataFrame: ``author_stats`` plus a ``significance`` column (1/0)
            and a ``sig_info`` column (matched rule name, or None).
        """
        author_stats_df = pd.DataFrame(author_stats)

        significance, sig_info = [], []
        for _, row in author_stats_df.iterrows():
            paper_cnt, cit_cnt = row.get('paperCount'), row.get('citationCount')
            # average citations per paper is only defined for a positive paper count
            has_avg = (not self._is_missing(paper_cnt)
                       and not self._is_missing(cit_cnt)
                       and paper_cnt > 0)

            if self._at_least(row.get('hIndex'), 10):
                rule = 'hIndex'
            elif has_avg and cit_cnt / paper_cnt >= 20:
                rule = 'avgPaperCitation'
            elif self._at_least(row.get('localPaperCount'), 5):
                rule = 'localPaperCount'
            else:
                rule = None

            significance.append(0 if rule is None else 1)
            sig_info.append(rule)

        author_stats_df['significance'] = significance
        author_stats_df['sig_info'] = sig_info
        return author_stats_df

    def identify_paper_significant_relative(self):
        """Placeholder; not implemented yet (returns None)."""
        pass

    def identify_paper_significant_lpa(self):
        """Placeholder, presumably for label-propagation significance; not implemented yet (returns None)."""
        pass

  from .autonotebook import tqdm as notebook_tqdm


# Data Collection

## Initial Search

### Data Generation

Get paper data

In [7]:
iteration = 1

In [8]:
await ps.consolidated_search(
    paper_titles = seed_titles,
    paper_ids = seed_dois
)

2025-04-30 15:12:21,612 - Paper Collector - INFO - consolidated_search: Starting...
INFO:Paper Collector:consolidated_search: Starting...
2025-04-30 15:12:21,614 - Paper Collector - INFO - consolidated_search: Running 1 sub-tasks concurrently...
INFO:Paper Collector:consolidated_search: Running 1 sub-tasks concurrently...
2025-04-30 15:12:21,615 - Paper Collector - INFO - Search 2 paper titles and 3 for paper information.
INFO:Paper Collector:Search 2 paper titles and 3 for paper information.
2025-04-30 15:12:21,616 - Paper Collector - INFO - paper_search: Creating task for 3 IDs...
INFO:Paper Collector:paper_search: Creating task for 3 IDs...
2025-04-30 15:12:21,617 - Paper Collector - INFO - paper_search: Creating 2 tasks for titles...
INFO:Paper Collector:paper_search: Creating 2 tasks for titles...
2025-04-30 15:12:21,618 - Paper Collector - INFO - paper_search: Running 3 query tasks concurrently...
INFO:Paper Collector:paper_search: Running 3 query tasks concurrently...
2025-04-30

### Post-Processing

Paper post-processing

In [9]:
await ps.post_process(if_supplement_abstract=True)

2025-04-30 15:12:34,019 - Paper Collector - INFO - post_process: Starting data processing...
INFO:Paper Collector:post_process: Starting data processing...
2025-04-30 15:12:34,020 - Paper Collector - INFO - Processing 5 raw paper entries...
INFO:Paper Collector:Processing 5 raw paper entries...
2025-04-30 15:12:34,022 - Paper Collector - INFO - Generated 75 nodes/edges from papers.
INFO:Paper Collector:Generated 75 nodes/edges from papers.
2025-04-30 15:12:34,023 - Paper Collector - INFO - Total items after paper processing: 75
INFO:Paper Collector:Total items after paper processing: 75
2025-04-30 15:12:34,024 - Paper Collector - INFO - No author data in pool to process.
INFO:Paper Collector:No author data in pool to process.
2025-04-30 15:12:34,026 - Paper Collector - INFO - No topic data in pool to process.
INFO:Paper Collector:No topic data in pool to process.
2025-04-30 15:12:34,027 - Paper Collector - INFO - No citation data (references or citings) in pool to process.
INFO:Paper C

### Routing

Plan next move

In [10]:
# core papers and core authors nodes
if iteration == 1:
    core_paper_ids = set(node['id'] for node in ps.nodes_json if node['labels'] == ['Paper'])
    core_author_ids = set(node['id'] for node in ps.nodes_json if node['labels'] == ['Author'])
    print(len(core_paper_ids), len(core_author_ids))

5 33


In [11]:
ps.explored_nodes['paper'].update(core_paper_ids)

In [12]:
# for authors
author_ids = [author_id for author_id in core_author_ids if author_id not in ps.explored_nodes['author']]

In [13]:
# for reference and citings
ref_ids = [pid for pid in core_paper_ids if pid not in ps.explored_nodes['reference']]
cit_ids = [pid for pid in core_paper_ids if pid not in ps.explored_nodes['citing']]

In [14]:
# recommendation: on the first round, use the core papers as positive examples
# (only when there are more than 3 of them).
# Both lists default to None, so the consolidated_search cell below no longer hits a
# NameError when this branch is skipped. None mirrors how other unused search
# inputs are passed (e.g. paper_titles = None).
# NOTE(review): assumes consolidated_search skips recommendations when
# pos_paper_ids is None — confirm.
pos_paper_ids, neg_paper_ids = None, None
if len(ps.explored_nodes['recommendation']) == 0 and len(core_paper_ids) > 3:
    pos_paper_ids = list(core_paper_ids)
    neg_paper_ids = []

In [15]:
# topics generation
core_paper_json = [x for x in ps.nodes_json if x['id'] in core_paper_ids]
if len(ps.explored_nodes['topic']) < 4:  # explored topic less than 4, generate new topics
    await ps.topic_generation(
        paper_json = core_paper_json,
        llm_api_key = llm_api_key,
        llm_model_name = llm_model_name,
        )

2025-04-30 15:12:40,185 - Paper Collector - INFO - Generating related topics for 5 seed papers...
INFO:Paper Collector:Generating related topics for 5 seed papers...
2025-04-30 15:12:40,188 - Paper Collector - INFO - Calling LLM to generate topics...
INFO:Paper Collector:Calling LLM to generate topics...
2025-04-30 15:12:54,312 - Paper Collector - INFO - LLM generated topics: [{"query": "AI research workflow automation support systems", "description": "This query targets papers providing a broad overview or discussing AI's role in automating or assisting multiple stages of the scientific research process, from literature search and hypothesis generation to writing and publication including peer review. Useful for understanding the landscape of AI in research."}, {"query": "LLM literature review survey generation", "description": "Focuses specifically on studies investigating the use of Large Language Models or other AI techniques for creating, generating, or assisting in the production

In [16]:
# identify unexplored topics
# group paper ids by topic: {topic: [paperId, ...]}
topic_pids = {}
for record in ps.data_pool['topic']:
    topic, paper_id = record['topic'], record['paperId']
    topic_pids.setdefault(topic, []).append(paper_id)

# topics backed by fewer than 10 papers still need searching
topics = [topic for topic, pids in topic_pids.items() if len(pids) < 10]
print(topics)
    

['AI research workflow automation support systems', 'LLM literature review survey generation', 'automated scientific peer review AI models', 'generative AI scientific paper writing assistance', 'evaluating AI generated academic content quality']


## Expanded Search

### Data Generation

In [17]:
await ps.consolidated_search(
    topics = topics,
    paper_titles = None,
    paper_ids = None,
    author_ids = author_ids,
    author_paper_ids = None,
    ref_paper_ids = ref_ids,
    cit_paper_ids = cit_ids,
    pos_paper_ids = pos_paper_ids,
    neg_paper_ids = neg_paper_ids,
    author_limit = 10,
    search_limit = ps.search_limit,
    citation_limit = ps.citation_limit,
    recommend_limit = ps.recommend_limit,
    from_dt = ps.from_dt,
    to_dt = ps.to_dt,
    fields_of_study = ps.fields_of_study
)

2025-04-30 15:13:17,839 - Paper Collector - INFO - consolidated_search: Starting...
INFO:Paper Collector:consolidated_search: Starting...
2025-04-30 15:13:17,841 - Paper Collector - INFO - consolidated_search: Running 5 sub-tasks concurrently...
INFO:Paper Collector:consolidated_search: Running 5 sub-tasks concurrently...
2025-04-30 15:13:17,842 - Paper Collector - INFO - topic_search: Searching 5 topics.
INFO:Paper Collector:topic_search: Searching 5 topics.
2025-04-30 15:13:17,844 - Paper Collector - INFO - topic_search: Running 5 topic search tasks concurrently...
INFO:Paper Collector:topic_search: Running 5 topic search tasks concurrently...
2025-04-30 15:13:17,845 - Paper Collector - INFO - authors_search: Searching 33 authors.
INFO:Paper Collector:authors_search: Searching 33 authors.
2025-04-30 15:13:17,846 - SemanticScholarKit - INFO - get_authors: Creating 1 tasks for 1 IDs.
INFO:SemanticScholarKit:get_authors: Creating 1 tasks for 1 IDs.
2025-04-30 15:13:17,847 - SemanticScho

Paper post-processing

In [18]:
await ps.post_process(if_supplement_abstract=True)

2025-04-30 15:13:55,969 - Paper Collector - INFO - post_process: Starting data processing...
INFO:Paper Collector:post_process: Starting data processing...
2025-04-30 15:13:55,971 - Paper Collector - INFO - Processing 2219 raw paper entries...
INFO:Paper Collector:Processing 2219 raw paper entries...
2025-04-30 15:13:56,023 - Paper Collector - INFO - Generated 22769 nodes/edges from papers.
INFO:Paper Collector:Generated 22769 nodes/edges from papers.
2025-04-30 15:13:56,026 - Paper Collector - INFO - Found 1368 paper nodes missing abstracts. Attempting to supplement...
INFO:Paper Collector:Found 1368 paper nodes missing abstracts. Attempting to supplement...
2025-04-30 15:13:56,027 - Paper Collector - INFO - supplement_abstract: Fetching abstracts for 1368 papers...
INFO:Paper Collector:supplement_abstract: Fetching abstracts for 1368 papers...
2025-04-30 15:13:56,027 - SemanticScholarKit - INFO - get_papers: Creating 14 tasks for 14 IDs.
INFO:SemanticScholarKit:get_papers: Creating 1

Similarity Calculation

In [20]:
# inputs for the embedding step: topic nodes with a description, and
# paper nodes with both a title and an abstract
topic_nodes_json = [
    node for node in ps.nodes_json
    if node['labels'] == ['Topic'] and node['properties'].get('description') is not None
]
paper_nodes_json = [
    node for node in ps.nodes_json
    if node['labels'] == ['Paper']
    and node['properties'].get('title') is not None
    and node['properties'].get('abstract') is not None
]

In [21]:
topic_ids = [x['id'] for x in topic_nodes_json]
paper_ids = [x['id'] for x in paper_nodes_json]

In [25]:
topic_name_ref = {x['id']:x['properties']['name'] for x in topic_nodes_json}
paper_title_ref = {x['id']:x['properties']['title'] for x in paper_nodes_json}

In [23]:
# similarity calculation
from collect.paper_sim_calc import PaperSim
sim_calc = PaperSim(embed_api_key, embed_model_name)

await sim_calc.get_topic_embeds(topic_nodes_json, embed_api_key, embed_model_name)
await sim_calc.get_abstract_embeds(paper_nodes_json, embed_api_key, embed_model_name)

In [24]:
p2p_sim_pool = await sim_calc.cal_p2p_similarity(paper_nodes_json, paper_ids, paper_ids)
p2t_sim_pool = await sim_calc.cal_p2t_similarity(paper_nodes_json, topic_nodes_json, paper_ids, topic_ids)

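2-Hop Filter
- 1-hop: papers similar to the core papers
- 1-hop: papers on the core papers' citation chain (citing, or cited by, a core paper)
- papers that discuss a generated topic and are also similar to it
- (if possible) 2-hop: papers similar to the 1-hop papers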

In [26]:
hop_1_paper_ids = set()

In [30]:
hop_1_paper_ids.update(core_paper_ids)  # first append core paper ids

In [None]:
# add papers similar to core paper ids
i = 0
for item in p2p_sim_pool:  # iterate paper -> SIMILAR_TO -> paper
    start_id = item['startNodeId']
    end_id = item['endNodeId']
    if start_id in core_paper_ids and end_id not in core_paper_ids:
        hop_1_paper_ids.add(end_id)
        i += 1
    elif start_id not in core_paper_ids and end_id in core_paper_ids:
        hop_1_paper_ids.add(start_id)
        i += 1
print(i, len(hop_1_paper_ids))

333 218


In [36]:
# add papers belong to core paper ids's citation chain
i = 0
for item in ps.edges_json:
    relationship = item.get('relationshipType')
    if relationship == 'CITES':
        start_id = item['startNodeId']
        end_id = item['endNodeId']
        if start_id in core_paper_ids and end_id not in core_paper_ids:
            hop_1_paper_ids.add(end_id)
            i += 1
        elif start_id not in core_paper_ids and end_id in core_paper_ids:
            hop_1_paper_ids.add(start_id)
            i += 1
print(i, len(hop_1_paper_ids))

435 554


In [None]:
# add papers similar to topic nodes
paper_topic_tuple = set((x['startNodeId'], x['endNodeId']) for x in p2t_sim_pool)  # paper similar to nodes

i = 0
for edge in ps.edges_json:
    if edge.get('relationshipType') == 'DISCUSS':  # paper discuss the topic
        if (edge['startNodeId'], edge['endNodeId']) in paper_topic_tuple:  # paper similar to topic
            # print(topic_name_ref.get(edge['endNodeId']), paper_title_ref.get(edge['startNodeId']))
            hop_1_paper_ids.add(edge['startNodeId'])
            i += 1
print(i, len(hop_1_paper_ids))

In [None]:
# 2-hop extension
hop_2_paper_ids = set()

i = 0
for item in p2p_sim_pool:  # iterate paper -> SIMILAR_TO -> paper
    start_id = item['startNodeId']
    end_id = item['endNodeId']
    if start_id in hop_1_paper_ids and end_id not in hop_1_paper_ids:
        hop_2_paper_ids.add(end_id)
        i += 1
    elif start_id not in hop_1_paper_ids and end_id in hop_1_paper_ids:
        hop_2_paper_ids.add(start_id)
        i += 1
print(i, len(hop_2_paper_ids))

6637 820


In [44]:
print(len(hop_1_paper_ids), len(hop_2_paper_ids), len(set(hop_1_paper_ids) | set(hop_2_paper_ids)))

569 820 1389


Graph pruning based on the 1-hop papers

In [46]:
# for nodes
hop_1_author_ids = set()

i = 0
for item in ps.edges_json:
    relationship = item.get('relationshipType')
    if relationship == 'WRITES':
        start_id = item['startNodeId']
        end_id = item['endNodeId']
        if end_id in hop_1_paper_ids:
            hop_1_author_ids.add(start_id)
            i += 1
print(i, len(hop_1_author_ids))

3152 2820


In [49]:
# for nodes
hop_1_topic_ids = set()

i = 0
for item in ps.edges_json:
    relationship = item.get('relationshipType')
    if relationship == 'DISCUSS':
        start_id = item['startNodeId']
        end_id = item['endNodeId']
        if start_id in hop_1_paper_ids:
            hop_1_topic_ids.add(end_id)
            i += 1
print(i, len(hop_1_topic_ids))

139 5


In [52]:
# filter node json
node_ids = hop_1_paper_ids | hop_1_author_ids | hop_1_topic_ids
nodes_json = [x for x in ps.nodes_json if x['id'] in node_ids]

In [67]:
# filter edge json
edges_json = []

i = 0
for edge in ps.edges_json:
    relationship = edge.get('relationshipType')
    if relationship in ['CITES', 'WRITES', 'DISCUSS']:
        start_id = edge['startNodeId']
        end_id = edge['endNodeId']
        if start_id in node_ids and end_id in node_ids:
            edges_json.append(edge)
            i += 1
print(i, len(edges_json))

3730 3730


In [69]:
# append similarity edges
i = 0
for edge in p2p_sim_pool:  # iterate paper -> SIMILAR_TO -> paper
    start_id = edge['startNodeId']
    end_id = edge['endNodeId']
    if start_id in node_ids and end_id in node_ids:
        edges_json.append(edge)
        i += 1
print(i, len(edges_json))

i = 0
for edge in p2t_sim_pool:  # iterate paper -> DISCUSS -> topic
    start_id = edge['startNodeId']
    end_id = edge['endNodeId']
    if start_id in node_ids and end_id in node_ids:
        edges_json.append(edge)
        i += 1
print(i, len(edges_json))

9004 12734
142 12876


In [70]:
print(len(ps.nodes_json), len(ps.edges_json), 
      len(nodes_json), len(edges_json))

9941 13917 3514 12876


Generate paper graph

In [71]:
# generate paper graph from nodes / edges json
G_pre = PaperGraph(name='Paper Graph Pre')
G_pre.add_graph_nodes(nodes_json)
G_pre.add_graph_edges(edges_json)

In [147]:
def gen_nodes_stats(paper_graph, seed_paper_ids: Optional[set] = None):
    """Calculate per-node statistics for the Paper and Author nodes of a graph.

    Every per-node list in the returned dicts has exactly one entry per node of
    that type (``None`` where a value cannot be computed). The dicts can
    therefore be passed straight to ``pd.DataFrame``.

    Args:
        paper_graph: graph exposing ``nodes(data=True)``, ``nodes[nid]``,
            ``in_edges`` / ``out_edges(nid, data=True)`` and ``successors``
            (e.g. the ``PaperGraph`` built above). The node type is read from
            the ``nodeType`` attribute and the edge type from ``relationshipType``.
        seed_paper_ids: ids treated as seed/core papers for the
            ``*SimilarityToSeedPapers`` stats. Defaults to the notebook-level
            ``core_paper_ids`` so existing calls keep working.

    Returns:
        dict: ``{'paper_stats': {...}, 'author_stats': {...}}`` of column lists.
    """
    # fall back to the notebook-level core paper ids for backward compatibility
    seed_ids = core_paper_ids if seed_paper_ids is None else set(seed_paper_ids)

    def _count_edges(edges, relationship_type):
        """Number of (u, v, data) edges whose ``relationshipType`` equals ``relationship_type``."""
        return sum(1 for _, _, edge_data in edges
                   if edge_data.get('relationshipType') == relationship_type)

    def _monthly_citations(pub_dt, cit_cnt):
        """Citations per full month since ``pub_dt`` ('%Y-%m-%d'); None when not computable."""
        if not pub_dt or cit_cnt is None:
            return None
        try:
            pub_date = datetime.strptime(pub_dt, '%Y-%m-%d').date()
        except (TypeError, ValueError):
            return None
        age = relativedelta(datetime.now().date(), pub_date)
        months = age.years * 12 + age.months
        # papers younger than one full month (or future-dated) get no monthly rate
        return cit_cnt / months if months > 0 else None

    # ---------- 1. Initiate stats  ------------
    # for paper stats
    paper_stats = {
        'id': [],
        'citationCount': [],
        'influentialCitationCount': [],
        'referenceCount': [],
        'monthlyCitationCount': [],
        'localCitationCount': [],
        'localReferenceCount': [],
        'localSimilarPaperCont': [],  # (sic) key name kept for existing readers
        'maxSimilarityToSeedPapers': [],
        'avgSimilarityToSeedPapers': []
    }

    # for author stats
    author_stats = {
        'id': [],
        'paperCount': [],
        'citationCount': [],
        'hIndex': [],
        'localPaperCount': [],
    }

    for nid, node_data in paper_graph.nodes(data=True):
        node_type = node_data.get('nodeType')

        # ---------- 2. Calculate paper stats  ------------
        if node_type == 'Paper':
            # materialize once: each edge list is scanned several times below
            in_edges = list(paper_graph.in_edges(nid, data=True))
            out_edges = list(paper_graph.out_edges(nid, data=True))
            cit_cnt = node_data.get('citationCount')

            # global measurement
            paper_stats['id'].append(nid)
            paper_stats['citationCount'].append(cit_cnt)
            paper_stats['influentialCitationCount'].append(node_data.get('influentialCitationCount'))
            paper_stats['referenceCount'].append(node_data.get('referenceCount'))

            # generated measurement: always one entry per paper, keeping columns aligned
            paper_stats['monthlyCitationCount'].append(
                _monthly_citations(node_data.get('publicationDate'), cit_cnt))

            # local measurement
            # other paper -> CITES -> this paper
            paper_stats['localCitationCount'].append(_count_edges(in_edges, 'CITES'))
            # this paper -> CITES -> other paper
            paper_stats['localReferenceCount'].append(_count_edges(out_edges, 'CITES'))
            # SIMILAR_TO edges in either direction
            paper_stats['localSimilarPaperCont'].append(
                _count_edges(in_edges, 'SIMILAR_TO') + _count_edges(out_edges, 'SIMILAR_TO'))

            # similarity scores on SIMILAR_TO edges shared with a seed paper (either direction)
            sim_scores = [edge_data.get('weight') for u, _, edge_data in in_edges
                          if edge_data.get('relationshipType') == 'SIMILAR_TO' and u in seed_ids]
            sim_scores += [edge_data.get('weight') for _, v, edge_data in out_edges
                           if edge_data.get('relationshipType') == 'SIMILAR_TO' and v in seed_ids]
            sim_scores = [score for score in sim_scores if score is not None]
            paper_stats['maxSimilarityToSeedPapers'].append(max(sim_scores) if sim_scores else None)
            paper_stats['avgSimilarityToSeedPapers'].append(np.average(sim_scores) if sim_scores else None)

        # ---------- 3. Calculate author stats  ------------
        elif node_type == 'Author':
            author_stats['id'].append(nid)
            # global measurement
            author_stats['paperCount'].append(node_data.get('paperCount'))
            author_stats['citationCount'].append(node_data.get('citationCount'))
            author_stats['hIndex'].append(node_data.get('hIndex'))

            # local measurement: successor nodes that are papers
            author_stats['localPaperCount'].append(
                sum(1 for succ in paper_graph.successors(nid)
                    if paper_graph.nodes[succ].get('nodeType') == 'Paper'))

    return {'paper_stats': paper_stats, 'author_stats': author_stats}

In [148]:
# calculate paper / author node stats
node_stats = gen_nodes_stats(G_pre)

In [149]:
paper_ids = node_stats['paper_stats']['id']
paper_local_cit_cnt = node_stats['paper_stats']['localCitationCount']
paper_local_cit_ref = {id: cit_cnt for id, cit_cnt in zip(paper_ids, paper_local_cit_cnt)}

In [150]:
# rank order by local citations
paper_local_cit_sorted = sorted(paper_local_cit_ref.items(), key=lambda x:x[1], reverse=True)

In [151]:
# papers cited (inside the collected graph) by more papers than there are core papers
crossref_paper_ids = {
    pid for pid, local_cit_cnt in paper_local_cit_sorted
    if local_cit_cnt > len(core_paper_ids)
}

In [152]:
if len(crossref_paper_ids) < 20:
    print("cross ref insufficient")

cross ref insufficient


In [153]:
# all papers linked to a core paper by a citation, in either direction.
# Only CITES edges count: WRITES / DISCUSS / SIMILAR_TO edges would otherwise
# pull authors, topics and merely-similar papers into this set.
core_cit_paper_ids = set()

for u, v, edge_data in G_pre.edges(data=True):
    if edge_data.get('relationshipType') != 'CITES':
        continue
    if u in core_paper_ids:
        core_cit_paper_ids.add(v)  # paper referenced by a core paper
    elif v in core_paper_ids:
        core_cit_paper_ids.add(u)  # paper citing a core paper

In [158]:
paper_ids = node_stats['paper_stats']['id']
paper_local_sim_scores = node_stats['paper_stats']['maxSimilarityToSeedPapers']
paper_local_sim_ref = {id: sim_score for id, sim_score in zip(paper_ids, paper_local_sim_scores) if sim_score is not None}

In [159]:
# rank order by local citations
paper_local_sim_sorted = sorted(paper_local_sim_ref.items(), key=lambda x:x[1], reverse=True)

In [161]:
G_pre.nodes['c890b28a001c885d1f7aa05f5d24ead9bf6ae058']

{'paperId': 'c890b28a001c885d1f7aa05f5d24ead9bf6ae058',
 'title': 'MARG: Multi-Agent Review Generation for Scientific Papers',
 'year': 2024,
 'referenceCount': 0,
 'citationCount': 27,
 'influentialCitationCount': 0,
 'isOpenAccess': True,
 'openAccessPdf': {'url': 'https://arxiv.org/pdf/2401.04259.pdf'},
 'fieldsOfStudy': ['Computer Science'],
 'publicationDate': '2024-01-08',
 'arxivUrl': 'https://arxiv.org/abs/2401.04259',
 'arxivId': '2401.04259',
 'doi': '10.48550/arXiv.2401.04259',
 'abstract': 'We study the ability of LLMs to generate feedback for scientific papers and develop MARG, a feedback generation approach using multiple LLM instances that engage in internal discussion. By distributing paper text across agents, MARG can consume the full text of papers beyond the input length limitations of the base LLM, and by specializing agents and incorporating sub-tasks tailored to different comment types (experiments, clarity, impact) it improves the helpfulness and specificity of f

In [None]:
def identify_paper_significant(paper_stats):
    """Flag papers that pass at least one absolute significance rule.

    Rules are checked in order, and the first match is recorded:
      1. citationCount >= 20             -> 'citationCount'
      2. influentialCitationCount >= 3   -> 'influentialCitationCount'
      3. monthlyCitationCount >= 5       -> 'monthlyCitationCount'
      4. localCitationCount >= 5         -> 'localCitationCount'
    Missing (None/NaN) values or absent columns never satisfy a rule.

    Args:
        paper_stats: column-oriented dict (e.g. ``node_stats['paper_stats']``)
            with an ``'id'`` column.

    Returns:
        tuple[list, list]: ``(sig_paper_ids, sig_paper_info)`` holds the ids of
        the significant papers and, at the same positions, the rule that matched.
    """
    rules = (
        ('citationCount', 20),
        ('influentialCitationCount', 3),
        ('monthlyCitationCount', 5),
        ('localCitationCount', 5),
    )

    def _at_least(value, threshold) -> bool:
        # None survives in all-None (object) columns; NaN appears in numeric ones
        if value is None or pd.isna(value):
            return False
        return bool(value >= threshold)

    paper_stats_df = pd.DataFrame(paper_stats)

    sig_paper_ids, sig_paper_info = [], []
    for _, row in paper_stats_df.iterrows():
        matched = next((col for col, threshold in rules if _at_least(row.get(col), threshold)), None)
        if matched is not None:
            sig_paper_ids.append(row['id'])
            sig_paper_info.append(matched)
    return sig_paper_ids, sig_paper_info


def identify_author_significant(author_stats):
    """Flag authors that pass at least one absolute significance rule.

    Rules are checked in order, and the first match is recorded:
      1. hIndex >= 10                      -> 'hIndex'
      2. citationCount / paperCount >= 20  -> 'avgPaperCitation'
         (only when both counts are present and paperCount > 0)
      3. localPaperCount >= 5              -> 'localPaperCount'
    Missing (None/NaN) values or absent columns never satisfy a rule.

    Args:
        author_stats: column-oriented dict (e.g. ``node_stats['author_stats']``)
            with an ``'id'`` column.

    Returns:
        tuple[list, list]: ``(sig_author_ids, sig_author_info)`` holds the ids
        of the significant authors and, at the same positions, the rule that matched.
    """
    def _missing(value) -> bool:
        # None survives in all-None (object) columns; NaN appears in numeric ones
        return value is None or bool(pd.isna(value))

    author_stats_df = pd.DataFrame(author_stats)

    sig_author_ids, sig_author_info = [], []
    for _, row in author_stats_df.iterrows():
        h_index = row.get('hIndex')
        paper_cnt = row.get('paperCount')
        cit_cnt = row.get('citationCount')
        local_cnt = row.get('localPaperCount')

        if not _missing(h_index) and h_index >= 10:
            rule = 'hIndex'
        elif (not _missing(paper_cnt) and not _missing(cit_cnt)
              and paper_cnt > 0 and cit_cnt / paper_cnt >= 20):
            rule = 'avgPaperCitation'
        elif not _missing(local_cnt) and local_cnt >= 5:
            rule = 'localPaperCount'
        else:
            continue

        sig_author_ids.append(row['id'])
        sig_author_info.append(rule)
    return sig_author_ids, sig_author_info


In [163]:
import networkx as nx
from collections import defaultdict
import time # For demo or debugging iteration process

def propagate_labels(
    paper_graph: nx.MultiDiGraph,
    sig_paper_ids: set,
    sig_author_ids: set,
    k: int,
    m: int,
    n: int,
    edge_type_key: str = 'type',
    max_iterations: Optional[int] = None,
    verbose: bool = True,
) -> tuple[list, dict]:
    """
    Label propagation over a paper/author MultiDiGraph to find potentially significant nodes.

    Starting from the given significant papers and authors, every other node is
    checked against three rules. Nodes satisfying at least one rule become
    significant and take part in the next round. Updates are synchronous: nodes
    found in round i only count as significant from round i+1 on. Rounds repeat
    until no new node is found (or ``max_iterations`` is reached). Termination
    is guaranteed because every non-final round adds at least one node.

    Rules (counts are of *distinct* significant neighbours, so parallel edges
    between the same pair of nodes are counted once):
      - 3.1: at least k significant nodes have a WRITES edge into this node.
      - 3.2: at least m significant nodes have a CITES edge into this node.
      - 3.3: this node has CITES edges to at least n significant nodes.
    Each rule also requires at least one such neighbour, so a threshold <= 0
    cannot turn nodes with no significant neighbours (e.g. every author or venue
    node) into candidates.

    Node kinds are not inspected. They are implied by edge types: only papers
    are expected to have incoming WRITES/CITES edges or outgoing CITES edges.

    Args:
        paper_graph: Graph with 'WRITES' (author -> paper) and 'CITES'
            (paper -> paper) edges. The edge type is read from the edge
            attribute named by ``edge_type_key``.
        sig_paper_ids: Initial significant paper node IDs (any iterable).
        sig_author_ids: Initial significant author node IDs (any iterable).
        k: Threshold for rule 3.1.
        m: Threshold for rule 3.2.
        n: Threshold for rule 3.3.
        edge_type_key: Edge attribute holding the relationship type. Defaults
            to 'type'. Note: the cells below read 'relationshipType' from
            G_post edges, so pass that key for those graphs.
        max_iterations: Optional cap on propagation rounds (None = run until stable).
        verbose: Print per-iteration progress messages.

    Returns:
        tuple[list, dict]:
            - candidate_sig_node_ids (list): newly identified node IDs (initial
              significant nodes excluded), in discovery order.
            - candidate_sig_node_info (dict): maps each candidate ID to the list
              of rule descriptions it satisfied in the round it was found.
    """
    initial_significant_nodes = set(sig_paper_ids) | set(sig_author_ids)
    current_significant_nodes = set(initial_significant_nodes)

    candidate_sig_node_ids = []
    candidate_sig_node_info = {}

    iteration = 0
    while max_iterations is None or iteration < max_iterations:
        iteration += 1
        newly_identified = {}  # node_id -> reasons; dict keeps graph order

        for node_id in paper_graph.nodes():
            if node_id in current_significant_nodes:
                continue

            # Rules 3.1 / 3.2: significant sources of incoming WRITES / CITES edges
            # (single pass over the in-edges; sets deduplicate parallel edges).
            sig_authors, sig_citing_papers = set(), set()
            for u, _, data in paper_graph.in_edges(node_id, data=True):
                if u not in current_significant_nodes:
                    continue
                edge_type = data.get(edge_type_key)
                if edge_type == 'WRITES':
                    sig_authors.add(u)
                elif edge_type == 'CITES':
                    sig_citing_papers.add(u)

            # Rule 3.3: significant targets of outgoing CITES edges.
            sig_cited_papers = {
                v for _, v, data in paper_graph.out_edges(node_id, data=True)
                if data.get(edge_type_key) == 'CITES' and v in current_significant_nodes
            }

            reasons = []
            if sig_authors and len(sig_authors) >= k:
                reasons.append(f"Rule 3.1: Written by {len(sig_authors)} significant authors (threshold k={k})")
            if sig_citing_papers and len(sig_citing_papers) >= m:
                reasons.append(f"Rule 3.2: Cited by {len(sig_citing_papers)} significant papers (threshold m={m})")
            if sig_cited_papers and len(sig_cited_papers) >= n:
                reasons.append(f"Rule 3.3: Cites {len(sig_cited_papers)} significant papers (threshold n={n})")

            if reasons:
                newly_identified[node_id] = reasons

        if not newly_identified:
            if verbose:
                print(f"Propagation stabilized at iteration {iteration-1}.")
            break

        if verbose:
            print(f"Iteration {iteration}: Identified {len(newly_identified)} new candidate significant nodes.")
        # Only now do this round's finds become significant (synchronous update).
        # Initial significant nodes are never re-checked, so none can appear here.
        current_significant_nodes.update(newly_identified)
        candidate_sig_node_ids.extend(newly_identified)
        candidate_sig_node_info.update(newly_identified)
    else:
        # Loop ended by the iteration cap rather than by stabilizing.
        if verbose:
            print(f"Reached maximum iteration limit ({max_iterations}).")

    return candidate_sig_node_ids, candidate_sig_node_info

# --- Example Usage ---
if __name__ == '__main__':
    # 1. Build a small example MultiDiGraph. Node kinds are implied by the edges
    #    only (A* = authors, P* = papers); no node attributes are set.
    paper_graph = nx.MultiDiGraph()

    # Authors
    authors = {'A1', 'A2', 'A3', 'A4', 'A5'}
    # Papers
    papers = {'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7'}

    paper_graph.add_nodes_from(authors)
    paper_graph.add_nodes_from(papers)

    # Add edges (author -> WRITES -> paper)
    paper_graph.add_edge('A1', 'P1', type='WRITES')
    paper_graph.add_edge('A2', 'P1', type='WRITES')
    paper_graph.add_edge('A3', 'P2', type='WRITES')
    paper_graph.add_edge('A1', 'P3', type='WRITES') # A1 wrote P3
    paper_graph.add_edge('A4', 'P3', type='WRITES') # A4 wrote P3
    paper_graph.add_edge('A5', 'P4', type='WRITES') # A5 wrote P4 (A5 is not significant)
    paper_graph.add_edge('A1', 'P5', type='WRITES') # A1 wrote P5
    paper_graph.add_edge('A2', 'P5', type='WRITES') # A2 wrote P5
    paper_graph.add_edge('A3', 'P5', type='WRITES') # A3 wrote P5
    paper_graph.add_edge('A4', 'P6', type='WRITES') # A4 wrote P6

    # Add edges (paper -> CITES -> paper)
    paper_graph.add_edge('P1', 'P2', type='CITES') # P1 cites P2
    paper_graph.add_edge('P3', 'P1', type='CITES') # P3 cites P1
    paper_graph.add_edge('P4', 'P1', type='CITES') # P4 cites P1
    paper_graph.add_edge('P4', 'P3', type='CITES') # P4 cites P3
    paper_graph.add_edge('P5', 'P6', type='CITES') # P5 cites P6
    paper_graph.add_edge('P6', 'P7', type='CITES') # P6 cites P7
    paper_graph.add_edge('P1', 'P7', type='CITES') # P1 cites P7

    # 2. Define initial significant nodes
    sig_author_ids = {'A1', 'A2', 'A3'}
    sig_paper_ids = {'P1'} # P1 is initially significant

    # 3. Set propagation thresholds
    k = 2 # At least 2 significant authors write a paper
    m = 1 # Cited by at least 1 significant paper
    n = 1 # Cites at least 1 significant paper

    print("Initial significant authors:", sig_author_ids)
    print("Initial significant papers:", sig_paper_ids)
    print(f"Propagation thresholds: k={k}, m={m}, n={n}\n")

    # 4. Execute label propagation
    start_time = time.time()
    candidate_ids, candidate_info = propagate_labels(
        paper_graph,
        sig_paper_ids,
        sig_author_ids,
        k, m, n
    )
    end_time = time.time()

    # 5. Print results
    print("\n--- Propagation Results ---")
    print(f"Found {len(candidate_ids)} candidate significant nodes:")
    print("Candidate significant node IDs:", candidate_ids)

    print("\nDetailed Reasons:")
    if not candidate_info:
        print("No new candidate significant nodes found.")
    else:
        # Sort by node ID for consistent output
        sorted_candidate_ids = sorted(list(candidate_ids))
        for node_id in sorted_candidate_ids:
            print(f"  Node {node_id}:")
            # Retrieve reasons, provide default if somehow missing
            reasons = candidate_info.get(node_id, ["No detailed reasons recorded"])
            # Sort reasons for consistent output
            for reason in sorted(reasons):
                print(f"    - {reason}")

    print(f"\nAlgorithm execution time: {end_time - start_time:.4f} seconds")

    # Expected outcome (checked by the assertions below). Propagation is
    # synchronous: a node found in one iteration only counts as significant
    # from the next iteration on.
    # Iteration 1 (significant: A1, A2, A3, P1):
    #   P2: cited by P1                                -> Rule 3.2 (1 >= m)
    #   P3: cites P1; only A1 of its authors is sig    -> Rule 3.3 (1 >= n)
    #   P4: cites P1 (P3 is not significant yet)       -> Rule 3.3 (1 >= n)
    #   P5: written by A1, A2, A3                      -> Rule 3.1 (3 >= k)
    #   P7: cited by P1 (P6 is not significant yet)    -> Rule 3.2 (1 >= m)
    # Iteration 2:
    #   P6: cited by P5 and cites P7 (both now sig)    -> Rules 3.2 and 3.3
    # A4 / A5 never qualify: they have no incoming edges and no outgoing CITES edges.
    expected_rules = {
        'P2': ['Rule 3.2'],
        'P3': ['Rule 3.3'],
        'P4': ['Rule 3.3'],
        'P5': ['Rule 3.1'],
        'P6': ['Rule 3.2', 'Rule 3.3'],
        'P7': ['Rule 3.2'],
    }
    assert set(candidate_ids) == set(expected_rules), candidate_ids
    for node_id, expected in expected_rules.items():
        found = sorted(reason.split(':')[0] for reason in candidate_info[node_id])
        assert found == expected, (node_id, found)

Initial significant authors: {'A1', 'A3', 'A2'}
Initial significant papers: {'P1'}
Propagation thresholds: k=2, m=1, n=1

Iteration 1: Identified 5 new candidate significant nodes.
Iteration 2: Identified 1 new candidate significant nodes.
Propagation stabilized at iteration 2.

--- Propagation Results ---
Found 6 candidate significant nodes:
Candidate significant node IDs: ['P7', 'P6', 'P4', 'P5', 'P2', 'P3']

Detailed Reasons:
  Node P2:
    - Rule 3.2: Cited by 1 significant papers (threshold m=1)
  Node P3:
    - Rule 3.3: Cites 1 significant papers (threshold n=1)
  Node P4:
    - Rule 3.3: Cites 1 significant papers (threshold n=1)
  Node P5:
    - Rule 3.1: Written by 3 significant authors (threshold k=2)
  Node P6:
    - Rule 3.2: Cited by 1 significant papers (threshold m=1)
    - Rule 3.3: Cites 1 significant papers (threshold n=1)
  Node P7:
    - Rule 3.2: Cited by 1 significant papers (threshold m=1)

Algorithm execution time: 0.0006 seconds


In [None]:
## TASK
请根据以下要求，构造一个图上的标签传播算法。请输出python代码。

## INSTRUCTION
1. 已有一个networkx的MultiDiGraph名为paper_graph. 其中的主要node类型为paper和author，主要的edge类型为 paper -> CITES -> paper, author -> WRITES -> paper；
2. 已知该paper_graph中存在关键节点sig_paper_ids和sig_author_ids，对应为paper和author节点的id；
3. 现希望基于以上关键节点做标签传播，找出图中更多的潜在关键节点来，传播规则如下：
    - 3.1. 在 author -> WRITES -> paper关系中，如果paper节点对应有k个或以上的author节点为关键节点，则此paper节点为潜在的关键节点；
    - 3.2. 在 other paper -> CITES -> this paper关系中，如果引用的other paper节点中有m个或以上为关键节点，则this paper节点为潜在的关键节点；
    - 3.3. 在 this paper -> CITES -> other paper关系中，如果被引用的other paper节点中有n个或以上为关键节点，则this paper节点为潜在的关键节点；
4. 关键节点/潜在关键节点可以进一步传播，直到稳定；
5. 最终输出潜在关键节点列表 candidate_sig_node_ids 和 识别为潜在关键节点的原因 candidate_sig_node_info，二者一一对应。


In [160]:
# Inspect (paper_id, similarity) pairs; the output shows them highest-similarity first.
# NOTE(review): paper_local_sim_sorted is not defined in the visible cells - presumably from an earlier/deleted cell.
paper_local_sim_sorted

[('c890b28a001c885d1f7aa05f5d24ead9bf6ae058', 0.8388),
 ('924956d6c788c9ea67ecdc80b63742d74350549e', 0.8349),
 ('d0b5194032451157f264db4a6da569f03347d1cb', 0.8272),
 ('987d0cbe751780b9b993ebf8e670fb0d18fdaabe', 0.8224),
 ('1191a81272747b2add72d16c6a1ff00d1d8b8b2f', 0.8153),
 ('9348b7b95982d0a675a767e92c23647aa6915a94', 0.8145),
 ('c7471d25c1cee7f1b4457a52190d0b1a69a024ea', 0.812),
 ('a0e76e07fec917f9e8dd11b096a6fd524c1a76f5', 0.8119),
 ('fa39f0cc3e1dcb001f53735ae3d174d308f34301', 0.8107),
 ('bfed3d4c959b64148811376965db84f77ea8292e', 0.8094),
 ('235992701c6e70872f5292fd5818d5c5719063de', 0.8049),
 ('94fb5a19f86d81a746bb5502a5debf2659814e8e', 0.8037),
 ('8a52ac9268ec522fe310c7b75df00d9aaa80efd0', 0.7977),
 ('0490dc8f2ee87418aa0a9d9f198a12895af15b98', 0.7961),
 ('62729cff7dda7614f648a84e8967076d8878a5ff', 0.7953),
 ('51b7b3ad7645a69e3c1c80cae69473b8bd472f67', 0.7949),
 ('e4eb81ad222ba047770d5a90bdd7406c138c6126', 0.7924),
 ('e7e46bc5ef80187084d2d2c626ca95e68ee6e74b', 0.7895),
 ('a5093974

Try to build cross-references (crossref)

Significance Identification
- absolute significance
- relative significance
- graph significance

In [None]:
# NOTE(review): scratch inspection; topic_ref is not defined in the visible cells.
topic_ref

In [None]:
# NOTE(review): scratch inspection; tmp_ids is not defined in the visible cells.
len(tmp_ids)

In [None]:
# Build the pre-pruning paper graph from the collector's node / edge JSON.
G_pre = PaperGraph(name='Paper Graph 1')
G_pre.add_graph_nodes(ps.nodes_json)
G_pre.add_graph_edges(ps.edges_json)

In [None]:
# Scratch check: convert a dict's values into a numpy array.
import numpy as np
np.array(list({'a':1, 'b':2}.values()))

In [None]:
# Inspect one node's attributes by node ID.
G_pre.nodes['2335569348']

In [None]:
# Inspect one raw node record from the collector.
ps.nodes_json[1]

In [None]:
# NOTE(review): author_ids is only assigned in the later key-author selection cell;
# on a fresh kernel (Restart & Run All) this raises NameError.
author_ids

In [None]:
from datetime import datetime
from dateutil.relativedelta import relativedelta

def calculate_months_since(date_str, reference_date=None):
    """
    Return the number of whole months between a 'yyyy-mm-dd' date and a reference date.

    Args:
        date_str (str): Date string in 'yyyy-mm-dd' format.
        reference_date (datetime.date, optional): Date to measure up to.
            Defaults to today (``datetime.now().date()``).

    Returns:
        int | None: Whole months elapsed (years * 12 + months of the
        ``relativedelta``; negative if ``date_str`` lies after
        ``reference_date``), or None if ``date_str`` is not a valid
        'yyyy-mm-dd' string - including None and other non-string values.
    """
    try:
        given_date = datetime.strptime(date_str, '%Y-%m-%d').date()
    except (TypeError, ValueError):
        # ValueError: wrong format; TypeError: not a string at all (e.g. None).
        return None
    if reference_date is None:
        reference_date = datetime.now().date()
    difference = relativedelta(reference_date, given_date)
    return difference.years * 12 + difference.months

# Example usage
date_to_calculate = '2024-01-15'
months = calculate_months_since(date_to_calculate)

if months is not None:
    print(f"日期 {date_to_calculate} 距今有 {months} 个月。")
else:
    print("输入的日期格式不正确，请使用 'yyyy-mm-dd' 格式。")

date_to_calculate_invalid = '2024/01/15'
months_invalid = calculate_months_since(date_to_calculate_invalid)

if months_invalid is not None:
    print(f"日期 {date_to_calculate_invalid} 距今有 {months_invalid} 个月。")
else:
    print(f"输入的日期格式 '{date_to_calculate_invalid}' 不正确，请使用 'yyyy-mm-dd' 格式。")

In [None]:
# Inspect one raw edge record from the collector.
ps.edges_json[0]

In [None]:
# --- Graph Stat ---
from graph.graph_stats import get_graph_stats, get_author_stats, get_paper_stats
g_stat = get_graph_stats(G_pre)   # graph-level stats of the pre-pruning graph

In [None]:
# NOTE(review): core_paper_ids / core_author_ids are not defined in the visible cells -
# presumably set in an earlier cell.
paper_stats = get_paper_stats(G_pre, core_paper_ids)  # per-paper stats on graph
author_stats = get_author_stats(G_pre, core_author_ids)  # per-author stats on graph

In [None]:
# Cross-reference candidates: non-seed papers whose local_citation_cnt exceeds
# min(number of core papers, 5).
crossref_stats = []
for x in paper_stats:
    if (x['if_seed'] == False  # exclude seed papers
        and x['local_citation_cnt'] > min(len(core_paper_ids),  5)):  # keep papers frequently cited within the graph
        crossref_stats.append(x)

In [None]:
# Calculate semantic similarity between papers and add SIMILAR_TO edges to G_pre.
from collect.paper_similarity_calculation import PaperSim

sim = PaperSim(
    embed_api_key = embed_api_key,
    embed_model_name = embed_model_name
)

# --- SIMILARITY CALCULATION ---
# Skip the (embedding) computation if the graph already has SIMILAR_TO edges.
# g_stat['edge_type'] appears to be a list of (edge_type, count) pairs - x[0] is the type name.
edge_types = [x[0] for x in g_stat['edge_type']]

# Only Paper nodes with both a title and an abstract are embedded.
complete_paper_json = [node for node in ps.nodes_json 
                        if node['labels'] == ['Paper'] 
                        and node['properties'].get('title') is not None and node['properties'].get('abstract') is not None]
complete_paper_dois = [node['id'] for node in complete_paper_json]

if 'SIMILAR_TO' not in edge_types:
    # All-pairs similarity among the complete papers; pairs below 0.7 are presumably
    # dropped by similarity_threshold - confirm against PaperSim.
    # Top-level `await` relies on the notebook's running event loop.
    semantic_similar_pool = await sim.cal_embed_and_similarity(
        paper_nodes_json = complete_paper_json,
        paper_dois_1 = complete_paper_dois, 
        paper_dois_2 = complete_paper_dois,
        similarity_threshold = 0.7,
        )

    # add similarity edges to graph
    G_pre.add_graph_edges(semantic_similar_pool)  

In [None]:
# --- PRUNING ---
# Prune by connectivity: keep a weakly connected component around the core papers.
# NOTE(review): assumes find_wcc_subgraphs returns a list of subgraphs with the
# relevant component first - confirm its ordering.
sub_graphs = G_pre.find_wcc_subgraphs(target_nodes=core_paper_ids)
if sub_graphs is not None and len(sub_graphs) > 0:
    G_post  = sub_graphs[0]
    # get stats after pruning
    g_stat = get_graph_stats(G_post)
else:
    # No component found: fall back to the unpruned graph (g_stat keeps the pre-pruning stats).
    G_post = G_pre

In [None]:
# Recompute per-paper / per-author stats on the pruned graph.
paper_stats = get_paper_stats(G_post, core_paper_ids)  # paper stats on graph
author_stats = get_author_stats(G_post, core_author_ids)  # author stats on graph

# check crossref
# NOTE(review): duplicates the crossref filter of the pre-pruning cell - consider a shared helper.
crossref_stats = []
for x in paper_stats:
    if (x['if_seed'] == False  # exclude seed papers
        and x['local_citation_cnt'] > min(len(core_paper_ids),  5)):  # keep papers frequently cited within the graph
        crossref_stats.append(x)

# check key authors
key_authors_stats = []
for x in author_stats:
    if (x['if_seed'] == False  # exclude seed authors
        and x['local_paper_cnt'] > min(len(core_paper_ids), 5)):  # keep authors with many papers in the graph (threshold based on #core papers - confirm intended)
        key_authors_stats.append(x)

In [None]:
# Rank all papers by their maximum similarity to any seed paper, highest first.
sorted_paper_similarity = sorted(paper_stats, key=lambda x:x['max_sim_to_seed'], reverse=True)

In [None]:
# If cross-references are insufficient, expand along the citation chain starting
# from the non-seed, not-yet-explored papers most similar to the seeds that are
# also cited more than 10 times globally.
# Seeds / explored papers are skipped with `continue`: the previous `else: break`
# ended the whole scan at the first such paper, dropping every later candidate.
# Missing metrics are read as 0 so a None value cannot raise TypeError.
ref_ids = []
if len(crossref_stats) < 20:
    for item in sorted_paper_similarity:
        is_new_candidate = (item['if_seed'] == False
                            and item['doi'] not in ps.explored_nodes['reference'])
        if not is_new_candidate:
            continue
        # NOTE(review): 'global_citaion_cnt' (sic) must match the key produced by get_paper_stats.
        if (item['max_sim_to_seed'] or 0) > 0.7 and (item['global_citaion_cnt'] or 0) > 10:
            ref_ids.append(item['doi'])

In [None]:
# Keep at most 20 reference candidates (collected in descending similarity order).
ref_ids = ref_ids[0:20]

In [None]:
# Select key authors (most local papers first) whose papers are not yet explored.
# NOTE(review): the original note says "if key authors not have complete information",
# but the condition only checks that there are more than 20 key authors - confirm
# the intended trigger (the crossref cell uses `< 20`).
author_ids = []
if len(key_authors_stats) > 20:
    sorted_key_authors = sorted(key_authors_stats, key=lambda x:x['local_paper_cnt'], reverse=True)
    # filter key authors (to amplify information); the if_seed check is redundant
    # because key_authors_stats already excludes seed authors
    for item in sorted_key_authors:
        if item['if_seed'] == False and item['author_id'] not in ps.explored_nodes['author']:
            author_ids.append(item['author_id'])

# Cap the number of authors to expand.
author_ids = author_ids[0:50]

In [None]:
# Print the title of every Paper node collected.
# NOTE(review): raises KeyError for a Paper node without a 'title' property
# (the complete_paper_json filter guards this case with .get('title')).
for item in ps.nodes_json:
    if item['labels'] == ['Paper']:
        print(item['properties']['title'])
        # print(item.get('title'))

In [None]:
hop_1_sim_paper_ids = []
for u, v, edge_data in G_post.edges(data=True):
    if edge_data.get('relationshipType') == 'SIMILAR_TO' and edge_data.get('weight') > 0.7:
        if u in core_paper_ids and v not in core_paper_ids:
            hop_1_sim_paper_ids.append(v)
        elif u not in core_paper_ids and v in core_paper_ids:
            hop_1_sim_paper_ids.append(u)

In [None]:
len(hop_1_sim_paper_ids)

In [None]:
hop_1_citation_paper_ids = []
for u, v, edge_data in G_post.edges(data=True):
    if edge_data.get('relationshipType') == 'CITES':
        if u in core_paper_ids and v not in core_paper_ids:
            hop_1_citation_paper_ids.append(v)
        elif u not in core_paper_ids and v in core_paper_ids:
            hop_1_citation_paper_ids.append(u)

In [None]:
hop_1_topic_paper_ids = []
for u, v, edge_data in G_post.edges(data=True):
    if edge_data.get('relationshipType') == 'DISCUSS':
        topic = G_post.nodes[v].get('name')
        if u not in core_paper_ids:
            title = G_post.nodes[u].get('title')
            gloabl_citation = G_post.nodes[u].get('citationCount')
            if gloabl_citation > 10 and u in hop_1_sim_paper_ids:
                print(topic, title, gloabl_citation)
                # hop_1_topic_paper_ids.append(u)

In [None]:
# recommendation, author papers
hop_1_expand_papers = set(list(core_paper_ids) + hop_1_sim_paper_ids + hop_1_citation_paper_ids)

In [None]:
hop_2_sim_paper_ids = []
for u, v, edge_data in G_post.edges(data=True):
    if edge_data.get('relationshipType') == 'SIMILAR_TO' and edge_data.get('weight') > 0.7:
        if u in hop_1_expand_papers and v not in hop_1_expand_papers:
            hop_2_sim_paper_ids.append(v)
        elif u not in hop_1_expand_papers and v in hop_1_expand_papers:
            hop_2_sim_paper_ids.append(u)

In [None]:
len(hop_1_expand_papers)

In [None]:
len(set(hop_2_sim_paper_ids))

In [None]:
# Scratch: `continue` skips key 'a', so only "b 2" is printed.
for key, values in {'a':1, 'b':2}.items():
    if key == 'a': continue
    print(key, values)

In [None]:
# Scratch: in Python 3 ordering comparisons between None and an int raise TypeError,
# which is why possibly-missing numeric attributes need a default before comparing.
None > 5

In [None]:
There are 6 node types in this graph, they are:
[('Author', 5560), ('Paper', 1933), ('Journal', 513), ('Venue', 380), ('Institution', 5), ('Topic', 4)]
There are 7 edge types in this graph, they are:
[('SIMILAR_TO', 20872), ('WRITES', 9366), ('RELEASES_IN', 1244), ('PRINTS_ON', 702), ('DISCUSS', 419), ('CITES', 288), ('WORKS_IN', 5)]

In [None]:
# Scratch: rank dict keys by value, highest value -> rank 0.
# NOTE(review): rebinds the global `x` also used as a loop variable in earlier cells.
x = {'a':2, 'b':1, 'c':3}

In [None]:
# (key, value) pairs sorted by value, descending; the lambda's `x` shadows the dict.
x_sorted = sorted(x.items(), key=lambda x:x[1], reverse=True)

In [None]:
# Map each key to its position in the sorted order.
ranked_keys = {}
for index, (key, value) in enumerate(x_sorted):
    ranked_keys[key] = index

In [None]:
# Expected: {'c': 0, 'a': 1, 'b': 2}
ranked_keys

In [None]:
# Ad-hoc check of SemanticScholarKit topic search, outside PaperCollector.
import asyncio

import sys
import os

# NOTE(review): repeats the setup cell; each re-run appends parent_dir to sys.path again.
parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

from apis.s2_api import SemanticScholarKit

s2 = SemanticScholarKit()


topics = ['llm security']

# --- 2. Create tasks ---
tasks = []

for topic in topics:
    tasks.append(s2.search_paper(
        query=topic,
        limit=100,
    ))

# --- 3. Execute and collect ---
# return_exceptions=True: a failed search shows up as an exception object in
# `results` instead of raising. Note `results` is only bound when tasks is non-empty.
if tasks:
    results = await asyncio.gather(*tasks, return_exceptions=True)


In [None]:
len(results[0])

In [None]:
# NOTE(review): accesses `items` without calling it; if it is a method this shows the
# bound method rather than the items - confirm against PaginatedResults.
results[0].items

In [None]:
# NOTE(review): copied from a collector method body - `topics_to_search`, `self`,
# `logging` and `PaginatedResults` are not defined in the visible cells, so this
# cell fails on a fresh kernel (the topic list here is `topics`).
for idx, result in enumerate(results):
    current_topic = topics_to_search[idx]
    if isinstance(result, Exception):
        logging.error(f"topic_search: Task for topic '{current_topic}' failed: {result}", exc_info=False)
        self.not_found_nodes['topic'].add(current_topic)

    elif isinstance(result, PaginatedResults) and hasattr(result, '_items') and result._items:
        papers_list = result._items

In [None]:
# Merge topic-search results into a data pool: add unseen papers and topic->paper links.
# NOTE(review): copied from a collector method body - `self`, `topics_to_search`,
# `logging`, `PaginatedResults` and `Paper` are not defined in the visible cells.
new_papers_count = 0
new_topic_links = 0
existing_paper_ids = {p.get('paperId') for p in self.data_pool['paper'] if p.get('paperId')}

for idx, result in enumerate(results):
    current_topic = topics_to_search[idx]
    if isinstance(result, Exception):
        logging.error(f"topic_search: Task for topic '{current_topic}' failed: {result}", exc_info=False)
        self.not_found_nodes['topic'].add(current_topic)

    elif isinstance(result, PaginatedResults) and hasattr(result, '_items') and result._items:
        papers_list = result._items
        # Add paper metadata to data pool (deduplicated); papers_dict is a list of raw dicts
        papers_dict = [item.raw_data for item in papers_list if isinstance(item, Paper) and hasattr(item, 'raw_data')]
        # NOTE(review): duplicates *within* one result are not removed, because
        # existing_paper_ids is only updated after the whole list is filtered.
        new_papers = [p for p in papers_dict if p.get('paperId') and p['paperId'] not in existing_paper_ids]
        self.data_pool['paper'].extend(new_papers)
        new_papers_count += len(new_papers)
        existing_paper_ids.update(p['paperId'] for p in new_papers) # Update seen IDs

        # Add topic metadata links (for every paper in the result, new or not)
        paper_ids_in_result = [paper.get('paperId') for paper in papers_dict if paper.get('paperId')]
        for pid in paper_ids_in_result:
            # Use 'paperId' key for consistency
            self.data_pool['topic'].append({'topic': current_topic, 'paperId': pid})
            new_topic_links += 1
    else:
        logging.warning(f"topic_search: Task for topic '{current_topic}' returned no results or failed silently.")
        self.not_found_nodes['topic'].add(current_topic) # Mark as not found if no results

logging.info(f"topic_search: Added {new_papers_count} new papers and {new_topic_links} topic links to data pool.")


In [None]:
from datetime import datetime
from dateutil.relativedelta import relativedelta
import numpy as np

def graph_basic_stats(paper_graph):
    """
    Collect raw per-node metric values for the Paper and Author nodes of a graph.

    Nodes are classified by their ``nodeType`` attribute ('Paper' / 'Author');
    other node types are ignored. A metric value is only appended when the
    underlying attribute is present (not None), so the lists can differ in
    length and are NOT index-aligned with each other or with node IDs.

    Paper metrics:
        citationCount, influentialCitationCount, referenceCount: node attributes.
        monthlyCitationCount: citationCount / whole months since
            ``publicationDate`` ('yyyy-mm-dd'). Skipped when the date is missing
            or unparseable, or the paper is less than one month old or dated in
            the future.
        localCitationCount / localReferenceCount: number of Paper predecessors /
            successors. NOTE(review): this counts every paper->paper edge,
            whatever its type (e.g. SIMILAR_TO as well as CITES) - confirm intended.

    Author metrics:
        paperCount, citationCount, hIndex: node attributes.
        localPaperCount: number of Paper successors.

    Args:
        paper_graph: A NetworkX-style graph supporting ``nodes(data=True)``,
            ``nodes[node_id]``, ``predecessors`` and ``successors``.

    Returns:
        dict: ``{'paper_stats': {...}, 'author_stats': {...}}``, each mapping a
        metric name to a list of values.
    """
    paper_stats = {
        'citationCount': [],
        'influentialCitationCount': [],
        'referenceCount': [],
        'monthlyCitationCount': [],
        'localCitationCount': [],
        'localReferenceCount': []
    }
    author_stats = {
        'paperCount': [],
        'citationCount': [],
        'hIndex': [],
        'localPaperCount': []
    }

    # One "today" for the whole pass (hoisted out of the per-node loop).
    today = datetime.now().date()

    def _count_paper_nodes(node_ids):
        """Count how many of the given node IDs are Paper nodes."""
        return sum(1 for x in node_ids if paper_graph.nodes[x].get('nodeType') == 'Paper')

    for nid, node_data in paper_graph.nodes(data=True):
        node_type = node_data.get('nodeType')
        if node_type == 'Paper':
            pub_dt = node_data.get('publicationDate')
            paper_cit_cnt = node_data.get('citationCount')
            paper_sig_cit_cnt = node_data.get('influentialCitationCount')
            paper_ref_cnt = node_data.get('referenceCount')

            if pub_dt:
                try:
                    pub_date = datetime.strptime(pub_dt, '%Y-%m-%d').date()
                except (TypeError, ValueError):
                    # ValueError: not 'yyyy-mm-dd'; TypeError: not a string at all.
                    print(f"Warning: Invalid date format for paper {nid}: {pub_dt}")
                else:
                    tm_to_dt = relativedelta(today, pub_date)
                    mth_to_dt = tm_to_dt.years * 12 + tm_to_dt.months
                    # mth_to_dt <= 0: younger than a month or future-dated - no rate.
                    if paper_cit_cnt is not None and mth_to_dt > 0:
                        paper_stats['monthlyCitationCount'].append(paper_cit_cnt / mth_to_dt)

            if paper_cit_cnt is not None:
                paper_stats['citationCount'].append(paper_cit_cnt)
            if paper_sig_cit_cnt is not None:
                paper_stats['influentialCitationCount'].append(paper_sig_cit_cnt)
            if paper_ref_cnt is not None:
                paper_stats['referenceCount'].append(paper_ref_cnt)

            paper_stats['localCitationCount'].append(_count_paper_nodes(paper_graph.predecessors(nid)))
            paper_stats['localReferenceCount'].append(_count_paper_nodes(paper_graph.successors(nid)))

        elif node_type == 'Author':
            author_paper_cnt = node_data.get('paperCount')
            author_citation_cnt = node_data.get('citationCount')
            author_h_index = node_data.get('hIndex')

            if author_paper_cnt is not None:
                author_stats['paperCount'].append(author_paper_cnt)
            if author_citation_cnt is not None:
                author_stats['citationCount'].append(author_citation_cnt)
            if author_h_index is not None:
                author_stats['hIndex'].append(author_h_index)

            author_stats['localPaperCount'].append(_count_paper_nodes(paper_graph.successors(nid)))

    return {'paper_stats': paper_stats, 'author_stats': author_stats}

def analyze_graph_stats(basic_stats):
    """
    Summarise each metric list produced by ``graph_basic_stats``.

    For every metric, computes min, max, average (mean), median and the 25th /
    75th percentiles (NumPy default linear interpolation) over its usable values.
    Usable values are Python or NumPy ints and floats. NumPy integer scalars were
    previously dropped, because they are not ``int`` subclasses. Bools and NaN
    are ignored; a single NaN would otherwise turn every statistic into NaN.
    A metric with no usable values maps to ``{}``.

    Args:
        basic_stats (dict): ``{'paper_stats': {...}, 'author_stats': {...}}``
            mapping metric names to lists of values; either key may be absent.

    Returns:
        dict: ``{'paper_analysis': {...}, 'author_analysis': {...}}`` keyed by the
        same metric names as the input.
    """
    def _is_usable(value):
        """True for finite-or-infinite real numbers, excluding bools and NaN."""
        if isinstance(value, (bool, np.bool_)):
            return False
        if isinstance(value, (int, np.integer)):
            return True
        if isinstance(value, (float, np.floating)):
            return not np.isnan(value)
        return False

    def _summarise(metrics):
        """Map each metric name to its summary statistics (or {} if no usable values)."""
        summary = {}
        for key, values in metrics.items():
            valid_values = [v for v in values if _is_usable(v)]
            if not valid_values:
                summary[key] = {}
                continue
            summary[key] = {
                'min': np.min(valid_values),
                'max': np.max(valid_values),
                'average': np.mean(valid_values),
                'median': np.median(valid_values),
                'quantile_25': np.percentile(valid_values, 25),
                'quantile_75': np.percentile(valid_values, 75)
            }
        return summary

    return {
        'paper_analysis': _summarise(basic_stats.get('paper_stats', {})),
        'author_analysis': _summarise(basic_stats.get('author_stats', {})),
    }

# Example usage (assumes a paper_graph object is available; a toy one is built here)
if __name__ == '__main__':
    import networkx as nx
    import random  # NOTE(review): unused in this example

    # Build a small example paper_graph: nodes 1-2 are papers, 3-4 are authors.
    paper_graph = nx.DiGraph()
    paper_graph.add_node(1, nodeType='Paper', publicationDate='2023-01-15', citationCount=10, influentialCitationCount=5, referenceCount=20)
    paper_graph.add_node(2, nodeType='Paper', publicationDate='2023-03-20', citationCount=15, influentialCitationCount=8, referenceCount=25)
    paper_graph.add_node(3, nodeType='Author', paperCount=2, citationCount=25, hIndex=3)
    paper_graph.add_node(4, nodeType='Author', paperCount=1, citationCount=5, hIndex=1)
    paper_graph.add_edge(1, 2)  # paper 1 -> paper 2
    paper_graph.add_edge(3, 1)  # author 3 -> paper 1
    paper_graph.add_edge(3, 2)  # author 3 -> paper 2
    # NOTE(review): paper -> author edge; the Paper-only neighbour counts in
    # graph_basic_stats do not count it.
    paper_graph.add_edge(1, 3)

    basic_stats = graph_basic_stats(paper_graph)
    print("Basic Stats:")
    print(basic_stats)

    analysis_results = analyze_graph_stats(basic_stats)
    print("\nAnalysis Results:")
    print(analysis_results)