From 05581dc9aad94ebe3d26f6742b96bb1bf7c29e66 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 30 Jul 2025 17:51:19 +0800 Subject: [PATCH 1/6] refactor: refact search_wiki --- graphgen/configs/graphgen_config.yaml | 4 +- graphgen/generate.py | 94 +++++---- graphgen/graphgen.py | 178 +++++++++++------- graphgen/models/__init__.py | 21 +-- graphgen/models/search/kg/__init__.py | 0 .../models/search/{ => kg}/wiki_search.py | 3 +- graphgen/models/search/web/bing_search.py | 0 graphgen/operators/__init__.py | 15 +- graphgen/operators/search/__init__.py | 0 graphgen/operators/search/db/__init__.py | 0 .../operators/search/db/search_mongodb.py | 0 .../operators/search/db/search_uniprot.py | 0 graphgen/operators/search/kg/__init__.py | 0 .../operators/search/kg/search_google_kg.py | 0 .../operators/search/kg/search_wikipedia.py | 84 +++++++++ graphgen/operators/search/search_all.py | 37 ++++ graphgen/operators/search/web/__init__.py | 0 graphgen/operators/search/web/search_bing.py | 0 .../operators/search/web/search_google.py | 0 graphgen/operators/search_wikipedia.py | 71 ------- 20 files changed, 312 insertions(+), 195 deletions(-) create mode 100644 graphgen/models/search/kg/__init__.py rename graphgen/models/search/{ => kg}/wiki_search.py (99%) create mode 100644 graphgen/models/search/web/bing_search.py create mode 100644 graphgen/operators/search/__init__.py create mode 100644 graphgen/operators/search/db/__init__.py create mode 100644 graphgen/operators/search/db/search_mongodb.py create mode 100644 graphgen/operators/search/db/search_uniprot.py create mode 100644 graphgen/operators/search/kg/__init__.py create mode 100644 graphgen/operators/search/kg/search_google_kg.py create mode 100644 graphgen/operators/search/kg/search_wikipedia.py create mode 100644 graphgen/operators/search/search_all.py create mode 100644 graphgen/operators/search/web/__init__.py create mode 100644 graphgen/operators/search/web/search_bing.py create mode 100644 graphgen/operators/search/web/search_google.py delete mode 100644 graphgen/operators/search_wikipedia.py diff --git a/graphgen/configs/graphgen_config.yaml b/graphgen/configs/graphgen_config.yaml index 4ddb66c7..b02eaf5f 100644 --- a/graphgen/configs/graphgen_config.yaml +++ b/graphgen/configs/graphgen_config.yaml @@ -12,5 +12,7 @@ traverse_strategy: max_extra_edges: 2 max_tokens: 256 loss_strategy: only_edge -web_search: false +search: + if_search: true + search_types: ["wikipedia", "google"] re_judge: false diff --git a/graphgen/generate.py b/graphgen/generate.py index 14693471..1165d63d 100644 --- a/graphgen/generate.py +++ b/graphgen/generate.py @@ -1,8 +1,9 @@ -import os +import argparse import json +import os import time -import argparse from importlib.resources import files + import yaml from dotenv import load_dotenv @@ -14,47 +15,63 @@ load_dotenv() + def set_working_dir(folder): os.makedirs(folder, exist_ok=True) os.makedirs(os.path.join(folder, "data", "graphgen"), exist_ok=True) os.makedirs(os.path.join(folder, "logs"), exist_ok=True) + def save_config(config_path, global_config): if not os.path.exists(os.path.dirname(config_path)): os.makedirs(os.path.dirname(config_path)) - with open(config_path, "w", encoding='utf-8') as config_file: - yaml.dump(global_config, config_file, default_flow_style=False, allow_unicode=True) + with open(config_path, "w", encoding="utf-8") as config_file: + yaml.dump( + global_config, config_file, default_flow_style=False, allow_unicode=True + ) + def main(): parser = argparse.ArgumentParser() - parser.add_argument('--config_file', - help='Config parameters for GraphGen.', - # default=os.path.join(sys_path, "configs", "graphgen_config.yaml"), - default=files('graphgen').joinpath("configs", "graphgen_config.yaml"), - type=str) - parser.add_argument('--output_dir', - help='Output directory for GraphGen.', - default=sys_path, - required=True, - type=str) + parser.add_argument( + "--config_file", + help="Config parameters for GraphGen.", + default=files("graphgen").joinpath("configs", "graphgen_config.yaml"), + type=str, + ) + parser.add_argument( + "--output_dir", + help="Output directory for GraphGen.", + default=sys_path, + required=True, + type=str, + ) args = parser.parse_args() working_dir = args.output_dir set_working_dir(working_dir) unique_id = int(time.time()) - set_logger(os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False) + set_logger( + os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), if_stream=False + ) + print( + "GraphGen with unique ID", + unique_id, + "logging to", + os.path.join(working_dir, "logs", f"graphgen_{unique_id}.log"), + ) - with open(args.config_file, "r", encoding='utf-8') as f: + with open(args.config_file, "r", encoding="utf-8") as f: config = yaml.load(f, Loader=yaml.FullLoader) - input_file = config['input_file'] + input_file = config["input_file"] - if config['data_type'] == 'raw': - with open(input_file, "r", encoding='utf-8') as f: + if config["data_type"] == "raw": + with open(input_file, "r", encoding="utf-8") as f: data = [json.loads(line) for line in f] - elif config['data_type'] == 'chunked': - with open(input_file, "r", encoding='utf-8') as f: + elif config["data_type"] == "chunked": + with open(input_file, "r", encoding="utf-8") as f: data = json.load(f) else: raise ValueError(f"Invalid data type: {config['data_type']}") @@ -62,40 +79,37 @@ def main(): synthesizer_llm_client = OpenAIModel( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), - base_url=os.getenv("SYNTHESIZER_BASE_URL") + base_url=os.getenv("SYNTHESIZER_BASE_URL"), ) trainee_llm_client = OpenAIModel( model_name=os.getenv("TRAINEE_MODEL"), api_key=os.getenv("TRAINEE_API_KEY"), - base_url=os.getenv("TRAINEE_BASE_URL") + base_url=os.getenv("TRAINEE_BASE_URL"), ) - traverse_strategy = TraverseStrategy( - **config['traverse_strategy'] - ) + traverse_strategy = TraverseStrategy(**config["traverse_strategy"]) graph_gen = GraphGen( working_dir=working_dir, unique_id=unique_id, synthesizer_llm_client=synthesizer_llm_client, trainee_llm_client=trainee_llm_client, - if_web_search=config['web_search'], - tokenizer_instance=Tokenizer( - model_name=config['tokenizer'] - ), - traverse_strategy=traverse_strategy + search=config["search"], + tokenizer_instance=Tokenizer(model_name=config["tokenizer"]), + traverse_strategy=traverse_strategy, ) - graph_gen.insert(data, config['data_type']) - - graph_gen.quiz(max_samples=config['quiz_samples']) - - graph_gen.judge(re_judge=config["re_judge"]) + graph_gen.insert(data, config["data_type"]) - graph_gen.traverse() + # graph_gen.quiz(max_samples=config['quiz_samples']) + # + # graph_gen.judge(re_judge=config["re_judge"]) + # + # graph_gen.traverse() + # + # path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml") + # save_config(path, config) - path = os.path.join(working_dir, "data", "graphgen", str(unique_id), f"config-{unique_id}.yaml") - save_config(path, config) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 265d32a9..9f384792 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -1,5 +1,3 @@ -# Adapt from https://github.com/HKUDS/LightRAG - import asyncio import os import time @@ -16,14 +14,13 @@ OpenAIModel, Tokenizer, TraverseStrategy, - WikiSearch, ) from .models.storage.base_storage import StorageNameSpace from .operators import ( extract_kg, judge_statement, quiz, - search_wikipedia, + search_all, skip_judge_statement, traverse_graph_atomically, traverse_graph_by_edge, @@ -33,6 +30,7 @@ sys_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + @dataclass class GraphGen: unique_id: int = int(time.time()) @@ -47,11 +45,12 @@ class GraphGen: trainee_llm_client: OpenAIModel = None tokenizer_instance: Tokenizer = None - # web search - if_web_search: bool = False - wiki_client: WikiSearch = field(default_factory=WikiSearch) + # search + search: dict = field( + default_factory=lambda: {"if_search": False, "search_types": ["wikipedia"]} + ) - # traverse strategy + # traverse traverse_strategy: TraverseStrategy = field(default_factory=TraverseStrategy) # webui @@ -64,20 +63,23 @@ def __post_init__(self): self.text_chunks_storage: JsonKVStorage = JsonKVStorage( self.working_dir, namespace="text_chunks" ) - self.wiki_storage: JsonKVStorage = JsonKVStorage( - self.working_dir, namespace="wiki" - ) self.graph_storage: NetworkXStorage = NetworkXStorage( self.working_dir, namespace="graph" ) + self.search_storage: JsonKVStorage = JsonKVStorage( + self.working_dir, namespace="search" + ) self.rephrase_storage: JsonKVStorage = JsonKVStorage( self.working_dir, namespace="rephrase" ) self.qa_storage: JsonKVStorage = JsonKVStorage( - os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)), namespace=f"qa-{self.unique_id}" + os.path.join(self.working_dir, "data", "graphgen", str(self.unique_id)), + namespace=f"qa-{self.unique_id}", ) - async def async_split_chunks(self, data: Union[List[list], List[dict]], data_type: str) -> dict: + async def async_split_chunks( + self, data: Union[List[list], List[dict]], data_type: str + ) -> dict: # TODO: 是否进行指代消解 if len(data) == 0: return {} @@ -88,9 +90,14 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ assert isinstance(data, list) and isinstance(data[0], dict) # compute hash for each document new_docs = { - compute_content_hash(doc['content'], prefix="doc-"): {'content': doc['content']} for doc in data + compute_content_hash(doc["content"], prefix="doc-"): { + "content": doc["content"] + } + for doc in data } - _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) + _add_doc_keys = await self.full_docs_storage.filter_keys( + list(new_docs.keys()) + ) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if len(new_docs) == 0: logger.warning("All docs are already in the storage") @@ -100,47 +107,62 @@ async def async_split_chunks(self, data: Union[List[list], List[dict]], data_typ cur_index = 1 doc_number = len(new_docs) async for doc_key, doc in tqdm_async( - new_docs.items(), desc="[1/4]Chunking documents", unit="doc" - ): + new_docs.items(), desc="[1/4]Chunking documents", unit="doc" + ): chunks = { compute_content_hash(dp["content"], prefix="chunk-"): { **dp, - 'full_doc_id': doc_key - } for dp in self.tokenizer_instance.chunk_by_token_size(doc["content"], - self.chunk_overlap_size, self.chunk_size) + "full_doc_id": doc_key, + } + for dp in self.tokenizer_instance.chunk_by_token_size( + doc["content"], self.chunk_overlap_size, self.chunk_size + ) } inserting_chunks.update(chunks) if self.progress_bar is not None: - self.progress_bar( - cur_index / doc_number, f"Chunking {doc_key}" - ) + self.progress_bar(cur_index / doc_number, f"Chunking {doc_key}") cur_index += 1 - _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys())) - inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys} + _add_chunk_keys = await self.text_chunks_storage.filter_keys( + list(inserting_chunks.keys()) + ) + inserting_chunks = { + k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + } elif data_type == "chunked": assert isinstance(data, list) and isinstance(data[0], list) new_docs = { - compute_content_hash("".join(chunk['content']), prefix="doc-"): {'content': "".join(chunk['content'])} - for doc in data for chunk in doc + compute_content_hash("".join(chunk["content"]), prefix="doc-"): { + "content": "".join(chunk["content"]) + } + for doc in data + for chunk in doc } - _add_doc_keys = await self.full_docs_storage.filter_keys(list(new_docs.keys())) + _add_doc_keys = await self.full_docs_storage.filter_keys( + list(new_docs.keys()) + ) new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys} if len(new_docs) == 0: logger.warning("All docs are already in the storage") return {} logger.info("[New Docs] inserting %d docs", len(new_docs)) - async for doc in tqdm_async(data, desc="[1/4]Chunking documents", unit="doc"): - doc_str = "".join([chunk['content'] for chunk in doc]) + async for doc in tqdm_async( + data, desc="[1/4]Chunking documents", unit="doc" + ): + doc_str = "".join([chunk["content"] for chunk in doc]) for chunk in doc: - chunk_key = compute_content_hash(chunk['content'], prefix="chunk-") + chunk_key = compute_content_hash(chunk["content"], prefix="chunk-") inserting_chunks[chunk_key] = { **chunk, - 'full_doc_id': compute_content_hash(doc_str, prefix="doc-") + "full_doc_id": compute_content_hash(doc_str, prefix="doc-"), } - _add_chunk_keys = await self.text_chunks_storage.filter_keys(list(inserting_chunks.keys())) - inserting_chunks = {k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys} + _add_chunk_keys = await self.text_chunks_storage.filter_keys( + list(inserting_chunks.keys()) + ) + inserting_chunks = { + k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys + } await self.full_docs_storage.upsert(new_docs) await self.text_chunks_storage.upsert(inserting_chunks) @@ -169,29 +191,39 @@ async def async_insert(self, data: Union[List[list], List[dict]], data_type: str llm_client=self.synthesizer_llm_client, kg_instance=self.graph_storage, tokenizer_instance=self.tokenizer_instance, - chunks=[Chunk(id=k, content=v['content']) for k, v in inserting_chunks.items()], - progress_bar = self.progress_bar, + chunks=[ + Chunk(id=k, content=v["content"]) for k, v in inserting_chunks.items() + ], + progress_bar=self.progress_bar, ) if not _add_entities_and_relations: logger.warning("No entities or relations extracted") return - logger.info("[Wiki Search] is %s", 'enabled' if self.if_web_search else 'disabled') - if self.if_web_search: - logger.info("[Wiki Search]...") - _add_wiki_data = await search_wikipedia( - llm_client= self.synthesizer_llm_client, - wiki_search_client=self.wiki_client, - knowledge_graph_instance=_add_entities_and_relations + logger.info( + "Search is %s", "enabled" if self.search["if_search"] else "disabled" + ) + if self.search["if_search"]: + logger.info("[Search] %s ...", ", ".join(self.search["search_types"])) + _add_search_data = await search_all( + llm_client=self.synthesizer_llm_client, + search_types=self.search["search_types"], + kg_instance=_add_entities_and_relations, ) - await self.wiki_storage.upsert(_add_wiki_data) + if _add_search_data: + await self.search_storage.upsert(_add_search_data) + logger.info("[Search] %d entities searched", len(_add_search_data)) await self._insert_done() async def _insert_done(self): tasks = [] - for storage_instance in [self.full_docs_storage, self.text_chunks_storage, - self.graph_storage, self.wiki_storage]: + for storage_instance in [ + self.full_docs_storage, + self.text_chunks_storage, + self.graph_storage, + self.search_storage, + ]: if storage_instance is None: continue tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback()) @@ -202,7 +234,12 @@ def quiz(self, max_samples=1): loop.run_until_complete(self.async_quiz(max_samples)) async def async_quiz(self, max_samples=1): - await quiz(self.synthesizer_llm_client, self.graph_storage, self.rephrase_storage, max_samples) + await quiz( + self.synthesizer_llm_client, + self.graph_storage, + self.rephrase_storage, + max_samples, + ) await self.rephrase_storage.index_done_callback() def judge(self, re_judge=False, skip=False): @@ -213,8 +250,12 @@ async def async_judge(self, re_judge=False, skip=False): if skip: _update_relations = await skip_judge_statement(self.graph_storage) else: - _update_relations = await judge_statement(self.trainee_llm_client, self.graph_storage, - self.rephrase_storage, re_judge) + _update_relations = await judge_statement( + self.trainee_llm_client, + self.graph_storage, + self.rephrase_storage, + re_judge, + ) await _update_relations.index_done_callback() def traverse(self): @@ -223,23 +264,32 @@ def traverse(self): async def async_traverse(self): if self.traverse_strategy.qa_form == "atomic": - results = await traverse_graph_atomically(self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - self.traverse_strategy, - self.text_chunks_storage, - self.progress_bar) + results = await traverse_graph_atomically( + self.synthesizer_llm_client, + self.tokenizer_instance, + self.graph_storage, + self.traverse_strategy, + self.text_chunks_storage, + self.progress_bar, + ) elif self.traverse_strategy.qa_form == "multi_hop": - results = await traverse_graph_for_multi_hop(self.synthesizer_llm_client, - self.tokenizer_instance, - self.graph_storage, - self.traverse_strategy, - self.text_chunks_storage, - self.progress_bar) + results = await traverse_graph_for_multi_hop( + self.synthesizer_llm_client, + self.tokenizer_instance, + self.graph_storage, + self.traverse_strategy, + self.text_chunks_storage, + self.progress_bar, + ) elif self.traverse_strategy.qa_form == "aggregated": - results = await traverse_graph_by_edge(self.synthesizer_llm_client, self.tokenizer_instance, - self.graph_storage, self.traverse_strategy, self.text_chunks_storage, - self.progress_bar) + results = await traverse_graph_by_edge( + self.synthesizer_llm_client, + self.tokenizer_instance, + self.graph_storage, + self.traverse_strategy, + self.text_chunks_storage, + self.progress_bar, + ) else: raise ValueError(f"Unknown qa_form: {self.traverse_strategy.qa_form}") await self.qa_storage.upsert(results) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index c2f9e714..a9190fa5 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,22 +1,17 @@ -from .text.chunk import Chunk -from .text.text_pair import TextPair - -from .llm.topk_token_model import Token, TopkTokenModel -from .llm.openai_model import OpenAIModel -from .llm.tokenizer import Tokenizer - -from .storage.networkx_storage import NetworkXStorage -from .storage.json_storage import JsonKVStorage - -from .search.wiki_search import WikiSearch +from graphgen.models.search.kg.wiki_search import WikiSearch from .evaluate.length_evaluator import LengthEvaluator from .evaluate.mtld_evaluator import MTLDEvaluator from .evaluate.reward_evaluator import RewardEvaluator from .evaluate.uni_evaluator import UniEvaluator - +from .llm.openai_model import OpenAIModel +from .llm.tokenizer import Tokenizer +from .llm.topk_token_model import Token, TopkTokenModel +from .storage.json_storage import JsonKVStorage +from .storage.networkx_storage import NetworkXStorage from .strategy.travserse_strategy import TraverseStrategy - +from .text.chunk import Chunk +from .text.text_pair import TextPair __all__ = [ # llm models diff --git a/graphgen/models/search/kg/__init__.py b/graphgen/models/search/kg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/search/wiki_search.py b/graphgen/models/search/kg/wiki_search.py similarity index 99% rename from graphgen/models/search/wiki_search.py rename to graphgen/models/search/kg/wiki_search.py index db312a2b..cb080bc0 100644 --- a/graphgen/models/search/wiki_search.py +++ b/graphgen/models/search/kg/wiki_search.py @@ -1,8 +1,9 @@ -from typing import List, Union from dataclasses import dataclass +from typing import List, Union import wikipedia from wikipedia import set_lang + from graphgen.utils import detect_main_language, logger diff --git a/graphgen/models/search/web/bing_search.py b/graphgen/models/search/web/bing_search.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index 8ef14fdc..a56d06b2 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -1,16 +1,21 @@ +from graphgen.operators.search.search_all import search_all + from .extract_kg import extract_kg -from .quiz import quiz from .judge import judge_statement, skip_judge_statement -from .search_wikipedia import search_wikipedia -from .traverse_graph import traverse_graph_by_edge, traverse_graph_atomically, traverse_graph_for_multi_hop +from .quiz import quiz +from .traverse_graph import ( + traverse_graph_atomically, + traverse_graph_by_edge, + traverse_graph_for_multi_hop, +) __all__ = [ "extract_kg", "quiz", "judge_statement", "skip_judge_statement", - "search_wikipedia", + "search_all", "traverse_graph_by_edge", "traverse_graph_atomically", - "traverse_graph_for_multi_hop" + "traverse_graph_for_multi_hop", ] diff --git a/graphgen/operators/search/__init__.py b/graphgen/operators/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/db/__init__.py b/graphgen/operators/search/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/db/search_mongodb.py b/graphgen/operators/search/db/search_mongodb.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/db/search_uniprot.py b/graphgen/operators/search/db/search_uniprot.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/kg/__init__.py b/graphgen/operators/search/kg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/kg/search_google_kg.py b/graphgen/operators/search/kg/search_google_kg.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/kg/search_wikipedia.py b/graphgen/operators/search/kg/search_wikipedia.py new file mode 100644 index 00000000..b594f5da --- /dev/null +++ b/graphgen/operators/search/kg/search_wikipedia.py @@ -0,0 +1,84 @@ +from tqdm.asyncio import tqdm_asyncio as tqdm_async + +from graphgen.models import NetworkXStorage, OpenAIModel, WikiSearch +from graphgen.templates import SEARCH_JUDGEMENT_PROMPT +from graphgen.utils import logger + + +async def _process_single_entity( + entity_name: str, + description: str, + llm_client: OpenAIModel, + wiki_search_client: WikiSearch, +) -> tuple[str, None] | tuple[str, str]: + """ + Process single entity + + """ + search_results = await wiki_search_client.search(entity_name) + if not search_results: + return entity_name, None + examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"]) + search_results.append("None of the above") + + search_results_str = "\n".join( + [f"{i + 1}. {sr}" for i, sr in enumerate(search_results)] + ) + prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format( + examples=examples, + entity_name=entity_name, + description=description, + search_results=search_results_str, + ) + response = await llm_client.generate_answer(prompt) + try: + response = response.strip() + response = int(response) + if response < 1 or response >= len(search_results): + response = None + else: + response = await wiki_search_client.summary(search_results[response - 1]) + except ValueError: + response = None + + logger.info( + "Entity %s search result: %s response: %s", + entity_name, + str(search_results), + response, + ) + + return entity_name, response + + +async def search_wikipedia( + llm_client: OpenAIModel, + wiki_search_client: WikiSearch, + kg_instance: NetworkXStorage, +) -> dict: + """ + Search wikipedia for entities + + :param llm_client: LLM model + :param wiki_search_client: wiki search client + :param kg_instance: knowledge graph instance + :return: nodes with search results + """ + nodes = await kg_instance.get_all_nodes() + nodes = list(nodes) + wiki_data = {} + + async for node in tqdm_async( + (node for node in nodes), desc="Searching Wikipedia", total=len(nodes) + ): + entity_name = node[0].strip('"') + description = node[1]["description"] + try: + entity, summary = await _process_single_entity( + entity_name, description, llm_client, wiki_search_client + ) + wiki_data[entity] = summary + logger.info("Searched entity: %s, Summary: %s", entity, summary) + except Exception as e: # pylint: disable=broad-except + logger.error("Error processing entity %s: %s", entity_name, str(e)) + return wiki_data diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py new file mode 100644 index 00000000..0d2093a4 --- /dev/null +++ b/graphgen/operators/search/search_all.py @@ -0,0 +1,37 @@ +from graphgen.models import NetworkXStorage, OpenAIModel +from graphgen.utils import logger + + +async def search_all( + llm_client: OpenAIModel, search_types: dict, kg_instance: NetworkXStorage +) -> dict[str, str]: + """ + :param llm_client + :param search_types + :param kg_instance + :return: nodes with search results + """ + + results = {} + + for search_type in search_types: + if search_type == "wikipedia": + from graphgen.models import WikiSearch + from graphgen.operators.search.kg.search_wikipedia import search_wikipedia + + wiki_search_client = WikiSearch() + + await search_wikipedia(llm_client, wiki_search_client, kg_instance) + # elif search_type == "google": + # from graphgen.operators.search.web.search_google import search_google + # return await search_google(llm_client, kg_instance) + # + # elif search_type == "bing": + # from graphgen.operators.search.web.search_bing import search_bing + # return await search_bing(llm_client, kg_instance) + + else: + logger.error("Search type %s is not supported yet.", search_type) + continue + + return results diff --git a/graphgen/operators/search/web/__init__.py b/graphgen/operators/search/web/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/web/search_bing.py b/graphgen/operators/search/web/search_bing.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search/web/search_google.py b/graphgen/operators/search/web/search_google.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/operators/search_wikipedia.py b/graphgen/operators/search_wikipedia.py deleted file mode 100644 index d3d7e283..00000000 --- a/graphgen/operators/search_wikipedia.py +++ /dev/null @@ -1,71 +0,0 @@ -import asyncio -from graphgen.models import WikiSearch, OpenAIModel -from graphgen.models.storage.base_storage import BaseGraphStorage -from graphgen.templates import SEARCH_JUDGEMENT_PROMPT -from graphgen.utils import logger - - -async def _process_single_entity(entity_name: str, - description: str, - llm_client: OpenAIModel, - wiki_search_client: WikiSearch) -> tuple[str, None] | tuple[str, str]: - """ - Process single entity - - """ - search_results = await wiki_search_client.search(entity_name) - if not search_results: - return entity_name, None - examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"]) - search_results.append("None of the above") - - search_results_str = "\n".join([f"{i + 1}. {sr}" for i, sr in enumerate(search_results)]) - prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format( - examples=examples, - entity_name=entity_name, - description=description, - search_results=search_results_str, - ) - response = await llm_client.generate_answer(prompt) - - try: - response = response.strip() - response = int(response) - if response < 1 or response >= len(search_results): - response = None - else: - response = await wiki_search_client.summary(search_results[response - 1]) - except ValueError: - response = None - - logger.info("Entity %s search result: %s response: %s", entity_name, str(search_results), response) - - return entity_name, response - -async def search_wikipedia(llm_client: OpenAIModel, - wiki_search_client: WikiSearch, - knowledge_graph_instance: BaseGraphStorage,) -> dict: - """ - Search wikipedia for entities - - :param llm_client: LLM model - :param wiki_search_client: wiki search client - :param knowledge_graph_instance: knowledge graph instance - :return: nodes with search results - """ - - - nodes = await knowledge_graph_instance.get_all_nodes() - nodes = list(nodes) - wiki_data = {} - - tasks = [ - _process_single_entity(node[0].strip('"'), node[1]["description"], llm_client, wiki_search_client) - for node in nodes - ] - - for task in asyncio.as_completed(tasks): - result = await task - wiki_data[result[0]] = result[1] - - return wiki_data From b54632addf38514b475d91f5494ecb76d8c4ddcc Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 31 Jul 2025 12:26:33 +0800 Subject: [PATCH 2/6] feat: use search results to enrich data --- graphgen/configs/graphgen_config.yaml | 4 +- graphgen/generate.py | 18 ++---- graphgen/graphgen.py | 61 +++++++++++++------ .../operators/search/kg/search_wikipedia.py | 5 +- graphgen/operators/search/search_all.py | 11 +++- graphgen/utils/__init__.py | 20 +++--- graphgen/utils/file.py | 24 ++++++++ resources/examples/keywords_demo.txt | 5 ++ 8 files changed, 104 insertions(+), 44 deletions(-) create mode 100644 graphgen/utils/file.py create mode 100644 resources/examples/keywords_demo.txt diff --git a/graphgen/configs/graphgen_config.yaml b/graphgen/configs/graphgen_config.yaml index b02eaf5f..6239b104 100644 --- a/graphgen/configs/graphgen_config.yaml +++ b/graphgen/configs/graphgen_config.yaml @@ -1,5 +1,5 @@ data_type: raw -input_file: resources/examples/raw_demo.jsonl +input_file: resources/examples/keywords_demo.txt tokenizer: cl100k_base quiz_samples: 2 traverse_strategy: @@ -13,6 +13,6 @@ traverse_strategy: max_tokens: 256 loss_strategy: only_edge search: - if_search: true + enabled: true search_types: ["wikipedia", "google"] re_judge: false diff --git a/graphgen/generate.py b/graphgen/generate.py index 1165d63d..7208d597 100644 --- a/graphgen/generate.py +++ b/graphgen/generate.py @@ -1,5 +1,4 @@ import argparse -import json import os import time from importlib.resources import files @@ -9,7 +8,7 @@ from .graphgen import GraphGen from .models import OpenAIModel, Tokenizer, TraverseStrategy -from .utils import set_logger +from .utils import read_file, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) @@ -66,15 +65,7 @@ def main(): config = yaml.load(f, Loader=yaml.FullLoader) input_file = config["input_file"] - - if config["data_type"] == "raw": - with open(input_file, "r", encoding="utf-8") as f: - data = [json.loads(line) for line in f] - elif config["data_type"] == "chunked": - with open(input_file, "r", encoding="utf-8") as f: - data = json.load(f) - else: - raise ValueError(f"Invalid data type: {config['data_type']}") + data = read_file(input_file) synthesizer_llm_client = OpenAIModel( model_name=os.getenv("SYNTHESIZER_MODEL"), @@ -94,13 +85,16 @@ def main(): unique_id=unique_id, synthesizer_llm_client=synthesizer_llm_client, trainee_llm_client=trainee_llm_client, - search=config["search"], + search_config=config["search"], tokenizer_instance=Tokenizer(model_name=config["tokenizer"]), traverse_strategy=traverse_strategy, ) graph_gen.insert(data, config["data_type"]) + if config["search"]["enabled"]: + graph_gen.search() + # graph_gen.quiz(max_samples=config['quiz_samples']) # # graph_gen.judge(re_judge=config["re_judge"]) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 9f384792..35df98db 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -46,8 +46,8 @@ class GraphGen: tokenizer_instance: Tokenizer = None # search - search: dict = field( - default_factory=lambda: {"if_search": False, "search_types": ["wikipedia"]} + search_config: dict = field( + default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]} ) # traverse @@ -84,7 +84,6 @@ async def async_split_chunks( if len(data) == 0: return {} - new_docs = {} inserting_chunks = {} if data_type == "raw": assert isinstance(data, list) and isinstance(data[0], dict) @@ -163,6 +162,8 @@ async def async_split_chunks( inserting_chunks = { k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys } + else: + raise ValueError(f"Unknown data type: {data_type}") await self.full_docs_storage.upsert(new_docs) await self.text_chunks_storage.upsert(inserting_chunks) @@ -200,20 +201,6 @@ async def async_insert(self, data: Union[List[list], List[dict]], data_type: str logger.warning("No entities or relations extracted") return - logger.info( - "Search is %s", "enabled" if self.search["if_search"] else "disabled" - ) - if self.search["if_search"]: - logger.info("[Search] %s ...", ", ".join(self.search["search_types"])) - _add_search_data = await search_all( - llm_client=self.synthesizer_llm_client, - search_types=self.search["search_types"], - kg_instance=_add_entities_and_relations, - ) - if _add_search_data: - await self.search_storage.upsert(_add_search_data) - logger.info("[Search] %d entities searched", len(_add_search_data)) - await self._insert_done() async def _insert_done(self): @@ -229,6 +216,46 @@ async def _insert_done(self): tasks.append(cast(StorageNameSpace, storage_instance).index_done_callback()) await asyncio.gather(*tasks) + def search(self): + loop = create_event_loop() + loop.run_until_complete(self.async_search()) + + async def async_search(self): + logger.info( + "Search is %s", "enabled" if self.search_config["enabled"] else "disabled" + ) + if self.search_config["enabled"]: + logger.info( + "[Search] %s ...", ", ".join(self.search_config["search_types"]) + ) + all_nodes = await self.graph_storage.get_all_nodes() + all_nodes_names = [node[0] for node in all_nodes] + new_search_entities = await self.full_docs_storage.filter_keys( + all_nodes_names + ) + logger.info( + "[Search] Found %d entities to search", len(new_search_entities) + ) + _add_search_data = await search_all( + llm_client=self.synthesizer_llm_client, + search_types=self.search_config["search_types"], + kg_instance=self.graph_storage, + ) + if _add_search_data: + await self.search_storage.upsert(_add_search_data) + logger.info("[Search] %d entities searched", len(_add_search_data)) + + # Format search results for inserting + search_results = [] + for _, search_data in _add_search_data.items(): + search_results.extend( + [ + {"content": search_data[key]} + for key in list(search_data.keys()) + ] + ) + await self.async_insert(search_results, "raw") + def quiz(self, max_samples=1): loop = create_event_loop() loop.run_until_complete(self.async_quiz(max_samples)) diff --git a/graphgen/operators/search/kg/search_wikipedia.py b/graphgen/operators/search/kg/search_wikipedia.py index b594f5da..7d25d767 100644 --- a/graphgen/operators/search/kg/search_wikipedia.py +++ b/graphgen/operators/search/kg/search_wikipedia.py @@ -68,9 +68,7 @@ async def search_wikipedia( nodes = list(nodes) wiki_data = {} - async for node in tqdm_async( - (node for node in nodes), desc="Searching Wikipedia", total=len(nodes) - ): + async for node in tqdm_async(nodes, desc="Searching Wikipedia", total=len(nodes)): entity_name = node[0].strip('"') description = node[1]["description"] try: @@ -78,7 +76,6 @@ async def search_wikipedia( entity_name, description, llm_client, wiki_search_client ) wiki_data[entity] = summary - logger.info("Searched entity: %s, Summary: %s", entity, summary) except Exception as e: # pylint: disable=broad-except logger.error("Error processing entity %s: %s", entity_name, str(e)) return wiki_data diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py index 0d2093a4..f8cbde2f 100644 --- a/graphgen/operators/search/search_all.py +++ b/graphgen/operators/search/search_all.py @@ -4,7 +4,7 @@ async def search_all( llm_client: OpenAIModel, search_types: dict, kg_instance: NetworkXStorage -) -> dict[str, str]: +) -> dict[str, dict[str, str]]: """ :param llm_client :param search_types @@ -12,6 +12,8 @@ async def search_all( :return: nodes with search results """ + # 增量建图时,只需要搜索新增实体 + results = {} for search_type in search_types: @@ -21,7 +23,12 @@ async def search_all( wiki_search_client = WikiSearch() - await search_wikipedia(llm_client, wiki_search_client, kg_instance) + wiki_results = await search_wikipedia( + llm_client, wiki_search_client, kg_instance + ) + for entity_name, description in wiki_results.items(): + if description: + results[entity_name] = {"wikipedia": description} # elif search_type == "google": # from graphgen.operators.search.web.search_google import search_google # return await search_google(llm_client, kg_instance) diff --git a/graphgen/utils/__init__.py b/graphgen/utils/__init__.py index 932f8df1..13881c10 100644 --- a/graphgen/utils/__init__.py +++ b/graphgen/utils/__init__.py @@ -1,9 +1,15 @@ -from .log import logger, set_logger, parse_log -from .loop import create_event_loop -from .format import (pack_history_conversations, split_string_by_multi_markers, - handle_single_entity_extraction, handle_single_relationship_extraction, - load_json, write_json) -from .hash import compute_content_hash, compute_args_hash -from .detect_lang import detect_main_language, detect_if_chinese from .calculate_confidence import yes_no_loss_entropy +from .detect_lang import detect_if_chinese, detect_main_language +from .file import read_file +from .format import ( + handle_single_entity_extraction, + handle_single_relationship_extraction, + load_json, + pack_history_conversations, + split_string_by_multi_markers, + write_json, +) +from .hash import compute_args_hash, compute_content_hash from .help_nltk import NLTKHelper +from .log import logger, parse_log, set_logger +from .loop import create_event_loop diff --git a/graphgen/utils/file.py b/graphgen/utils/file.py new file mode 100644 index 00000000..11298616 --- /dev/null +++ b/graphgen/utils/file.py @@ -0,0 +1,24 @@ +import json + + +def read_file(input_file: str) -> list: + """ + Read data from a file based on the specified data type. + :param input_file + :return: + """ + + if input_file.endswith(".jsonl"): + with open(input_file, "r", encoding="utf-8") as f: + data = [json.loads(line) for line in f] + elif input_file.endswith(".json"): + with open(input_file, "r", encoding="utf-8") as f: + data = json.load(f) + elif input_file.endswith(".txt"): + with open(input_file, "r", encoding="utf-8") as f: + data = [line.strip() for line in f if line.strip()] + data = [{"content": line} for line in data] + else: + raise ValueError(f"Unsupported file format: {input_file}") + + return data diff --git a/resources/examples/keywords_demo.txt b/resources/examples/keywords_demo.txt new file mode 100644 index 00000000..0d4c47eb --- /dev/null +++ b/resources/examples/keywords_demo.txt @@ -0,0 +1,5 @@ +TATA Box +TFBS +E-box +Enhancer +AP-1 From 760b27e3e3dcf251a5be279848d31e41af15b1d9 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 31 Jul 2025 16:17:30 +0800 Subject: [PATCH 3/6] feat: add google search --- graphgen/configs/graphgen_config.yaml | 2 +- graphgen/graphgen.py | 3 +- graphgen/models/__init__.py | 5 +- graphgen/models/search/kg/wiki_search.py | 4 +- graphgen/models/search/web/__init__.py | 0 graphgen/models/search/web/google_search.py | 45 +++++++++++ .../operators/search/kg/search_wikipedia.py | 77 +++++++------------ graphgen/operators/search/search_all.py | 43 +++++++---- graphgen/operators/search/web/search_bing.py | 10 +++ .../operators/search/web/search_google.py | 49 ++++++++++++ requirements.txt | 3 + 11 files changed, 171 insertions(+), 70 deletions(-) create mode 100644 graphgen/models/search/web/__init__.py create mode 100644 graphgen/models/search/web/google_search.py diff --git a/graphgen/configs/graphgen_config.yaml b/graphgen/configs/graphgen_config.yaml index 6239b104..b535ffbf 100644 --- a/graphgen/configs/graphgen_config.yaml +++ b/graphgen/configs/graphgen_config.yaml @@ -14,5 +14,5 @@ traverse_strategy: loss_strategy: only_edge search: enabled: true - search_types: ["wikipedia", "google"] + search_types: ["google"] re_judge: false diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 35df98db..896f25b9 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -237,9 +237,8 @@ async def async_search(self): "[Search] Found %d entities to search", len(new_search_entities) ) _add_search_data = await search_all( - llm_client=self.synthesizer_llm_client, search_types=self.search_config["search_types"], - kg_instance=self.graph_storage, + search_entities=new_search_entities, ) if _add_search_data: await self.search_storage.upsert(_add_search_data) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index a9190fa5..8012ab79 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -1,5 +1,3 @@ -from graphgen.models.search.kg.wiki_search import WikiSearch - from .evaluate.length_evaluator import LengthEvaluator from .evaluate.mtld_evaluator import MTLDEvaluator from .evaluate.reward_evaluator import RewardEvaluator @@ -7,6 +5,8 @@ from .llm.openai_model import OpenAIModel from .llm.tokenizer import Tokenizer from .llm.topk_token_model import Token, TopkTokenModel +from .search.kg.wiki_search import WikiSearch +from .search.web.google_search import GoogleSearch from .storage.json_storage import JsonKVStorage from .storage.networkx_storage import NetworkXStorage from .strategy.travserse_strategy import TraverseStrategy @@ -25,6 +25,7 @@ "JsonKVStorage", # search models "WikiSearch", + "GoogleSearch", # evaluate models "TextPair", "LengthEvaluator", diff --git a/graphgen/models/search/kg/wiki_search.py b/graphgen/models/search/kg/wiki_search.py index cb080bc0..e9513f21 100644 --- a/graphgen/models/search/kg/wiki_search.py +++ b/graphgen/models/search/kg/wiki_search.py @@ -14,9 +14,9 @@ def set_language(language: str): assert language in ["en", "zh"], "Only support English and Chinese" set_lang(language) - async def search(self, query: str) -> Union[List[str], None]: + async def search(self, query: str, num_results: int = 1) -> Union[List[str], None]: self.set_language(detect_main_language(query)) - return wikipedia.search(query) + return wikipedia.search(query, results=num_results, suggestion=False) async def summary(self, query: str) -> Union[str, None]: self.set_language(detect_main_language(query)) diff --git a/graphgen/models/search/web/__init__.py b/graphgen/models/search/web/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/search/web/google_search.py b/graphgen/models/search/web/google_search.py new file mode 100644 index 00000000..1abfcdf3 --- /dev/null +++ b/graphgen/models/search/web/google_search.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +import requests +from fastapi import HTTPException + +from graphgen.utils import logger + +GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1" + + +@dataclass +class GoogleSearch: + def __init__(self, subscription_key: str, cx: str): + """ + Initialize the Google Search client with the subscription key and custom search engine ID. + :param subscription_key: Your Google API subscription key. + :param cx: Your custom search engine ID. + """ + self.subscription_key = subscription_key + self.cx = cx + + def search(self, query: str, num_results: int = 1): + """ + Search with Google and return the contexts. + :param query: The search query. + :param num_results: The number of results to return. + :return: A list of search results. + """ + params = { + "key": self.subscription_key, + "cx": self.cx, + "q": query, + "num": num_results, + } + response = requests.get(GOOGLE_SEARCH_ENDPOINT, params=params, timeout=10) + if not response.ok: + logger.error("Search engine error: %s", response.text) + raise HTTPException(response.status_code, "Search engine error.") + json_content = response.json() + try: + contexts = json_content["items"][:num_results] + except KeyError: + logger.error("Error encountered: %s", json_content) + return [] + return contexts diff --git a/graphgen/operators/search/kg/search_wikipedia.py b/graphgen/operators/search/kg/search_wikipedia.py index 7d25d767..dd3a35ba 100644 --- a/graphgen/operators/search/kg/search_wikipedia.py +++ b/graphgen/operators/search/kg/search_wikipedia.py @@ -1,81 +1,58 @@ from tqdm.asyncio import tqdm_asyncio as tqdm_async -from graphgen.models import NetworkXStorage, OpenAIModel, WikiSearch -from graphgen.templates import SEARCH_JUDGEMENT_PROMPT +from graphgen.models import WikiSearch from graphgen.utils import logger async def _process_single_entity( entity_name: str, - description: str, - llm_client: OpenAIModel, wiki_search_client: WikiSearch, -) -> tuple[str, None] | tuple[str, str]: +) -> str | None: """ - Process single entity - + Process single entity by searching Wikipedia + :param entity_name + :param wiki_search_client + :return: summary of the entity or None if not found """ search_results = await wiki_search_client.search(entity_name) if not search_results: - return entity_name, None - examples = "\n".join(SEARCH_JUDGEMENT_PROMPT["EXAMPLES"]) - search_results.append("None of the above") + return None - search_results_str = "\n".join( - [f"{i + 1}. {sr}" for i, sr in enumerate(search_results)] - ) - prompt = SEARCH_JUDGEMENT_PROMPT["TEMPLATE"].format( - examples=examples, - entity_name=entity_name, - description=description, - search_results=search_results_str, - ) - response = await llm_client.generate_answer(prompt) + summary = None try: - response = response.strip() - response = int(response) - if response < 1 or response >= len(search_results): - response = None - else: - response = await wiki_search_client.summary(search_results[response - 1]) - except ValueError: - response = None - - logger.info( - "Entity %s search result: %s response: %s", - entity_name, - str(search_results), - response, - ) + summary = await wiki_search_client.summary(search_results[-1]) + logger.info( + "Entity %s search result: %s summary: %s", + entity_name, + str(search_results), + summary, + ) + except Exception as e: # pylint: disable=broad-except + logger.error("Error processing entity %s: %s", entity_name, str(e)) - return entity_name, response + return summary async def search_wikipedia( - llm_client: OpenAIModel, wiki_search_client: WikiSearch, - kg_instance: NetworkXStorage, + entities: set[str], ) -> dict: """ Search wikipedia for entities - :param llm_client: LLM model :param wiki_search_client: wiki search client - :param kg_instance: knowledge graph instance + :param entities: list of entities to search :return: nodes with search results """ - nodes = await kg_instance.get_all_nodes() - nodes = list(nodes) wiki_data = {} - async for node in tqdm_async(nodes, desc="Searching Wikipedia", total=len(nodes)): - entity_name = node[0].strip('"') - description = node[1]["description"] + async for entity in tqdm_async( + entities, desc="Searching Wikipedia", total=len(entities) + ): try: - entity, summary = await _process_single_entity( - entity_name, description, llm_client, wiki_search_client - ) - wiki_data[entity] = summary + entity, summary = await _process_single_entity(entity, wiki_search_client) + if summary: + wiki_data[entity] = summary except Exception as e: # pylint: disable=broad-except - logger.error("Error processing entity %s: %s", entity_name, str(e)) + logger.error("Error processing entity %s: %s", entity, str(e)) return wiki_data diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py index f8cbde2f..94d21e2c 100644 --- a/graphgen/operators/search/search_all.py +++ b/graphgen/operators/search/search_all.py @@ -1,19 +1,27 @@ -from graphgen.models import NetworkXStorage, OpenAIModel +""" +To use Google Web Search API, +follow the instructions [here](https://developers.google.com/custom-search/v1/overview) +to get your Google search api key. + +To use Bing Web Search API, +follow the instructions [here](https://www.microsoft.com/en-us/bing/apis/bing-web-search-api) +and obtain your Bing subscription key. +""" + +import os + from graphgen.utils import logger async def search_all( - llm_client: OpenAIModel, search_types: dict, kg_instance: NetworkXStorage + search_types: dict, search_entities: set[str] ) -> dict[str, dict[str, str]]: """ - :param llm_client :param search_types - :param kg_instance + :param search_entities: list of entities to search :return: nodes with search results """ - # 增量建图时,只需要搜索新增实体 - results = {} for search_type in search_types: @@ -23,16 +31,25 @@ async def search_all( wiki_search_client = WikiSearch() - wiki_results = await search_wikipedia( - llm_client, wiki_search_client, kg_instance - ) + wiki_results = await search_wikipedia(wiki_search_client, search_entities) for entity_name, description in wiki_results.items(): if description: results[entity_name] = {"wikipedia": description} - # elif search_type == "google": - # from graphgen.operators.search.web.search_google import search_google - # return await search_google(llm_client, kg_instance) - # + elif search_type == "google": + from graphgen.models import GoogleSearch + from graphgen.operators.search.web.search_google import search_google + + google_search_client = GoogleSearch( + subscription_key=os.environ["GOOGLE_SEARCH_API_KEY"], + cx=os.environ["GOOGLE_SEARCH_CX"], + ) + + google_results = await search_google(google_search_client, search_entities) + for entity_name, description in google_results.items(): + if description: + results[entity_name] = results.get(entity_name, {}) + results[entity_name]["google"] = description + # elif search_type == "bing": # from graphgen.operators.search.web.search_bing import search_bing # return await search_bing(llm_client, kg_instance) diff --git a/graphgen/operators/search/web/search_bing.py b/graphgen/operators/search/web/search_bing.py index e69de29b..1a58c2d2 100644 --- a/graphgen/operators/search/web/search_bing.py +++ b/graphgen/operators/search/web/search_bing.py @@ -0,0 +1,10 @@ +BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search" +BING_MKT = "en-US" + + +async def search_bing(): + """ + Search with Bing and return the contexts. + :return: + """ + raise NotImplementedError("Bing search is not implemented yet.") diff --git a/graphgen/operators/search/web/search_google.py b/graphgen/operators/search/web/search_google.py index e69de29b..803ce107 100644 --- a/graphgen/operators/search/web/search_google.py +++ b/graphgen/operators/search/web/search_google.py @@ -0,0 +1,49 @@ +import trafilatura +from tqdm.asyncio import tqdm_asyncio as tqdm_async + +from graphgen.models import GoogleSearch +from graphgen.utils import logger + + +async def _process_single_entity( + entity_name: str, google_search_client: GoogleSearch +) -> str | None: + search_results = google_search_client.search(entity_name) + if not search_results: + return None + + # Get more details from the first search result + first_result = search_results[0] + content = trafilatura.fetch_url(first_result["link"]) + summary = trafilatura.extract(content, include_comments=False, include_links=False) + summary = summary.strip() + logger.info( + "Entity %s search result: %s", + entity_name, + summary, + ) + return summary + + +async def search_google( + google_search_client: GoogleSearch, + entities: set[str], +) -> dict: + """ + Search with Google and return the contexts. + :param google_search_client: Google search client + :param entities: list of entities to search + :return: + """ + google_data = {} + + async for entity in tqdm_async( + entities, desc="Searching Google", total=len(entities) + ): + try: + summary = await _process_single_entity(entity, google_search_client) + if summary: + google_data[entity] = summary + except Exception as e: # pylint: disable=broad-except + logger.error("Error processing entity %s: %s", entity, str(e)) + return google_data diff --git a/requirements.txt b/requirements.txt index ab329cb5..f169cb09 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,3 +17,6 @@ gradio-i18n==0.3.0 kaleido pyyaml langcodes +requests +fastapi +trafilatura From e63968bb0bda9a478e8a6acfd7ab49842a012f9b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 31 Jul 2025 19:56:40 +0800 Subject: [PATCH 4/6] feat: uniprot search --- graphgen/models/__init__.py | 4 ++ .../search/db/__init__.py} | 0 graphgen/models/search/db/uniprot_search.py | 64 +++++++++++++++++++ graphgen/models/search/web/bing_search.py | 43 +++++++++++++ .../operators/search/kg/search_google_kg.py | 0 graphgen/operators/search/search_all.py | 27 +++++++- graphgen/operators/search/web/search_bing.py | 51 +++++++++++++-- 7 files changed, 182 insertions(+), 7 deletions(-) rename graphgen/{operators/search/db/search_mongodb.py => models/search/db/__init__.py} (100%) create mode 100644 graphgen/models/search/db/uniprot_search.py delete mode 100644 graphgen/operators/search/kg/search_google_kg.py diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 8012ab79..7e1f6e8a 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -5,7 +5,9 @@ from .llm.openai_model import OpenAIModel from .llm.tokenizer import Tokenizer from .llm.topk_token_model import Token, TopkTokenModel +from .search.db.uniprot_search import UniProtSearch from .search.kg.wiki_search import WikiSearch +from .search.web.bing_search import BingSearch from .search.web.google_search import GoogleSearch from .storage.json_storage import JsonKVStorage from .storage.networkx_storage import NetworkXStorage @@ -26,6 +28,8 @@ # search models "WikiSearch", "GoogleSearch", + "BingSearch", + "UniProtSearch", # evaluate models "TextPair", "LengthEvaluator", diff --git a/graphgen/operators/search/db/search_mongodb.py b/graphgen/models/search/db/__init__.py similarity index 100% rename from graphgen/operators/search/db/search_mongodb.py rename to graphgen/models/search/db/__init__.py diff --git a/graphgen/models/search/db/uniprot_search.py b/graphgen/models/search/db/uniprot_search.py new file mode 100644 index 00000000..96bdd99c --- /dev/null +++ b/graphgen/models/search/db/uniprot_search.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass + +import requests +from fastapi import HTTPException + +from graphgen.utils import logger + +UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search" + + +@dataclass +class UniProtSearch: + """ + UniProt Search client to search with UniProt. + 1) Get the protein by accession number. + 2) Search with keywords or protein names. + """ + + def get_entry(self, accession: str) -> dict: + """ + Get the UniProt entry by accession number(e.g., P04637). + """ + url = f"{UNIPROT_BASE}/{accession}.json" + return self._safe_get(url).json() + + def search( + self, + query: str, + *, + size: int = 10, + cursor: str = None, + fields: list[str] = None, + ) -> dict: + """ + Search UniProt with a query string. + :param query: The search query. + :param size: The number of results to return. + :param cursor: The cursor for pagination. + :param fields: The fields to return in the response. + :return: A dictionary containing the search results. + """ + params = { + "query": query, + "size": size, + } + if cursor: + params["cursor"] = cursor + if fields: + params["fields"] = ",".join(fields) + url = UNIPROT_BASE + return self._safe_get(url, params=params).json() + + @staticmethod + def _safe_get(url: str, params: dict = None) -> requests.Response: + r = requests.get( + url, + params=params, + headers={"Accept": "application/json"}, + timeout=10, + ) + if not r.ok: + logger.error("Search engine error: %s", r.text) + raise HTTPException(r.status_code, "Search engine error.") + return r diff --git a/graphgen/models/search/web/bing_search.py b/graphgen/models/search/web/bing_search.py index e69de29b..a769ba76 100644 --- a/graphgen/models/search/web/bing_search.py +++ b/graphgen/models/search/web/bing_search.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass + +import requests +from fastapi import HTTPException + +from graphgen.utils import logger + +BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search" +BING_MKT = "en-US" + + +@dataclass +class BingSearch: + """ + Bing Search client to search with Bing. + """ + + subscription_key: str + + def search(self, query: str, num_results: int = 1): + """ + Search with Bing and return the contexts. + :param query: The search query. + :param num_results: The number of results to return. + :return: A list of search results. + """ + params = {"q": query, "mkt": BING_MKT, "count": num_results} + response = requests.get( + BING_SEARCH_V7_ENDPOINT, + headers={"Ocp-Apim-Subscription-Key": self.subscription_key}, + params=params, + timeout=10, + ) + if not response.ok: + logger.error("Search engine error: %s", response.text) + raise HTTPException(response.status_code, "Search engine error.") + json_content = response.json() + try: + contexts = json_content["webPages"]["value"][:num_results] + except KeyError: + logger.error("Error encountered: %s", json_content) + return [] + return contexts diff --git a/graphgen/operators/search/kg/search_google_kg.py b/graphgen/operators/search/kg/search_google_kg.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/operators/search/search_all.py b/graphgen/operators/search/search_all.py index 94d21e2c..d7ecbea1 100644 --- a/graphgen/operators/search/search_all.py +++ b/graphgen/operators/search/search_all.py @@ -49,10 +49,31 @@ async def search_all( if description: results[entity_name] = results.get(entity_name, {}) results[entity_name]["google"] = description + elif search_type == "bing": + from graphgen.models import BingSearch + from graphgen.operators.search.web.search_bing import search_bing - # elif search_type == "bing": - # from graphgen.operators.search.web.search_bing import search_bing - # return await search_bing(llm_client, kg_instance) + bing_search_client = BingSearch( + subscription_key=os.environ["BING_SEARCH_API_KEY"] + ) + + bing_results = await search_bing(bing_search_client, search_entities) + for entity_name, description in bing_results.items(): + if description: + results[entity_name] = results.get(entity_name, {}) + results[entity_name]["bing"] = description + elif search_type == "uniprot": + # from graphgen.models import UniProtSearch + # from graphgen.operators.search.db.search_uniprot import search_uniprot + # + # uniprot_search_client = UniProtSearch() + # + # uniprot_results = await search_uniprot( + # uniprot_search_client, search_entities + # ) + raise NotImplementedError( + "Processing of UniProt search results is not implemented yet." + ) else: logger.error("Search type %s is not supported yet.", search_type) diff --git a/graphgen/operators/search/web/search_bing.py b/graphgen/operators/search/web/search_bing.py index 1a58c2d2..69f65f7b 100644 --- a/graphgen/operators/search/web/search_bing.py +++ b/graphgen/operators/search/web/search_bing.py @@ -1,10 +1,53 @@ -BING_SEARCH_V7_ENDPOINT = "https://api.bing.microsoft.com/v7.0/search" -BING_MKT = "en-US" +import trafilatura +from tqdm.asyncio import tqdm_asyncio as tqdm_async +from graphgen.models import BingSearch +from graphgen.utils import logger -async def search_bing(): + +async def _process_single_entity( + entity_name: str, bing_search_client: BingSearch +) -> str | None: + """ + Process single entity by searching Bing. + :param entity_name: The name of the entity to search. + :param bing_search_client: The Bing search client. + :return: Summary of the entity or None if not found. + """ + search_results = bing_search_client.search(entity_name) + if not search_results: + return None + + # Get more details from the first search result + first_result = search_results[0] + content = trafilatura.fetch_url(first_result["url"]) + summary = trafilatura.extract(content, include_comments=False, include_links=False) + summary = summary.strip() + logger.info( + "Entity %s search result: %s", + entity_name, + summary, + ) + return summary + + +async def search_bing( + bing_search_client: BingSearch, + entities: set[str], +) -> dict[str, str]: """ Search with Bing and return the contexts. :return: """ - raise NotImplementedError("Bing search is not implemented yet.") + bing_data = {} + + async for entity in tqdm_async( + entities, desc="Searching Bing", total=len(entities) + ): + try: + summary = await _process_single_entity(entity, bing_search_client) + if summary: + bing_data[entity] = summary + except Exception as e: # pylint: disable=broad-except + logger.error("Error processing entity %s: %s", entity, str(e)) + return bing_data From 4180b9b33be4c47e943cf48020dcc11a90195e04 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 31 Jul 2025 20:32:04 +0800 Subject: [PATCH 5/6] fix: fix async_clear() --- graphgen/graphgen.py | 2 +- webui/app.py | 481 +++++++++++++++++++++++------------------ webui/translation.json | 26 ++- 3 files changed, 291 insertions(+), 218 deletions(-) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 896f25b9..a8bfc63d 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -328,7 +328,7 @@ def clear(self): async def async_clear(self): await self.full_docs_storage.drop() await self.text_chunks_storage.drop() - await self.wiki_storage.drop() + await self.search_storage.drop() await self.graph_storage.clear() await self.rephrase_storage.drop() await self.qa_storage.drop() diff --git a/webui/app.py b/webui/app.py index 4fc5e517..153f159a 100644 --- a/webui/app.py +++ b/webui/app.py @@ -1,17 +1,16 @@ +import json import os import sys -import json import tempfile -import pandas as pd import gradio as gr - -from gradio_i18n import Translate, gettext as _ - +import pandas as pd from base import GraphGenParams -from test_api import test_api_connection -from cache_utils import setup_workspace, cleanup_workspace +from cache_utils import cleanup_workspace, setup_workspace from count_tokens import count_tokens +from gradio_i18n import Translate +from gradio_i18n import gettext as _ +from test_api import test_api_connection # pylint: disable=wrong-import-position root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -22,7 +21,6 @@ from graphgen.models.llm.limitter import RPM, TPM from graphgen.utils import set_logger - css = """ .center-row { display: flex; @@ -37,9 +35,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) set_logger(log_file, if_stream=False) - graph_gen = GraphGen( - working_dir=working_dir - ) + graph_gen = GraphGen(working_dir=working_dir) # Set up LLM clients graph_gen.synthesizer_llm_client = OpenAIModel( @@ -47,8 +43,8 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: base_url=env.get("SYNTHESIZER_BASE_URL", ""), api_key=env.get("SYNTHESIZER_API_KEY", ""), request_limit=True, - rpm= RPM(env.get("RPM", 1000)), - tpm= TPM(env.get("TPM", 50000)), + rpm=RPM(env.get("RPM", 1000)), + tpm=TPM(env.get("TPM", 50000)), ) graph_gen.trainee_llm_client = OpenAIModel( @@ -56,12 +52,11 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: base_url=env.get("TRAINEE_BASE_URL", ""), api_key=env.get("TRAINEE_API_KEY", ""), request_limit=True, - rpm= RPM(env.get("RPM", 1000)), - tpm= TPM(env.get("TPM", 50000)), + rpm=RPM(env.get("RPM", 1000)), + tpm=TPM(env.get("TPM", 50000)), ) - graph_gen.tokenizer_instance = Tokenizer( - config.get("tokenizer", "cl100k_base")) + graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base")) strategy_config = config.get("traverse_strategy", {}) graph_gen.traverse_strategy = TraverseStrategy( @@ -73,11 +68,12 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: max_depth=strategy_config.get("max_depth"), edge_sampling=strategy_config.get("edge_sampling"), isolated_node_strategy=strategy_config.get("isolated_node_strategy"), - loss_strategy=str(strategy_config.get("loss_strategy")) + loss_strategy=str(strategy_config.get("loss_strategy")), ) return graph_gen + # pylint: disable=too-many-statements def run_graphgen(params, progress=gr.Progress()): def sum_tokens(client): @@ -88,7 +84,6 @@ def sum_tokens(client): "input_file": params.input_file, "tokenizer": params.tokenizer, "qa_form": params.qa_form, - "web_search": False, "quiz_samples": params.quiz_samples, "traverse_strategy": { "bidirectional": params.bidirectional, @@ -98,7 +93,7 @@ def sum_tokens(client): "max_depth": params.max_depth, "edge_sampling": params.edge_sampling, "isolated_node_strategy": params.isolated_node_strategy, - "loss_strategy": params.loss_strategy + "loss_strategy": params.loss_strategy, }, "chunk_size": params.chunk_size, } @@ -115,11 +110,15 @@ def sum_tokens(client): } # Test API connection - test_api_connection(env["SYNTHESIZER_BASE_URL"], - env["SYNTHESIZER_API_KEY"], env["SYNTHESIZER_MODEL"]) - if config['if_trainee_model']: - test_api_connection(env["TRAINEE_BASE_URL"], - env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"]) + test_api_connection( + env["SYNTHESIZER_BASE_URL"], + env["SYNTHESIZER_API_KEY"], + env["SYNTHESIZER_MODEL"], + ) + if config["if_trainee_model"]: + test_api_connection( + env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"] + ) # Initialize GraphGen graph_gen = init_graph_gen(config, env) @@ -129,7 +128,7 @@ def sum_tokens(client): try: # Load input data - file = config['input_file'] + file = config["input_file"] if isinstance(file, list): file = file[0] @@ -137,24 +136,22 @@ def sum_tokens(client): if file.endswith(".jsonl"): data_type = "raw" - with open(file, "r", encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: data.extend(json.loads(line) for line in f) elif file.endswith(".json"): data_type = "chunked" - with open(file, "r", encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: data.extend(json.load(f)) elif file.endswith(".txt"): # 读取文件后根据chunk_size转成raw格式的数据 data_type = "raw" content = "" - with open(file, "r", encoding='utf-8') as f: + with open(file, "r", encoding="utf-8") as f: lines = f.readlines() for line in lines: content += line.strip() + " " size = int(config.get("chunk_size", 512)) - chunks = [ - content[i:i + size] for i in range(0, len(content), size) - ] + chunks = [content[i : i + size] for i in range(0, len(content), size)] data.extend([{"content": chunk} for chunk in chunks]) else: raise ValueError(f"Unsupported file type: {file}") @@ -162,9 +159,9 @@ def sum_tokens(client): # Process the data graph_gen.insert(data, data_type) - if config['if_trainee_model']: + if config["if_trainee_model"]: # Generate quiz - graph_gen.quiz(max_samples=config['quiz_samples']) + graph_gen.quiz(max_samples=config["quiz_samples"]) # Judge statements graph_gen.judge() @@ -179,42 +176,39 @@ def sum_tokens(client): # Save output output_data = graph_gen.qa_storage.data with tempfile.NamedTemporaryFile( - mode="w", - suffix=".jsonl", - delete=False, - encoding="utf-8") as tmpfile: + mode="w", suffix=".jsonl", delete=False, encoding="utf-8" + ) as tmpfile: json.dump(output_data, tmpfile, ensure_ascii=False) output_file = tmpfile.name synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client) - trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0 + trainee_tokens = ( + sum_tokens(graph_gen.trainee_llm_client) + if config["if_trainee_model"] + else 0 + ) total_tokens = synthesizer_tokens + trainee_tokens data_frame = params.token_counter try: _update_data = [ - [ - data_frame.iloc[0, 0], - data_frame.iloc[0, 1], - str(total_tokens) - ] + [data_frame.iloc[0, 0], data_frame.iloc[0, 1], str(total_tokens)] ] - new_df = pd.DataFrame( - _update_data, - columns=data_frame.columns - ) + new_df = pd.DataFrame(_update_data, columns=data_frame.columns) data_frame = new_df except Exception as e: raise gr.Error(f"DataFrame operation error: {str(e)}") - return output_file, gr.DataFrame(label='Token Stats', - headers=["Source Text Token Count", "Expected Token Usage", "Token Used"], - datatype="str", - interactive=False, - value=data_frame, - visible=True, - wrap=True) + return output_file, gr.DataFrame( + label="Token Stats", + headers=["Source Text Token Count", "Expected Token Usage", "Token Used"], + datatype="str", + interactive=False, + value=data_frame, + visible=True, + wrap=True, + ) except Exception as e: # pylint: disable=broad-except raise gr.Error(f"Error occurred: {str(e)}") @@ -223,16 +217,18 @@ def sum_tokens(client): # Clean up workspace cleanup_workspace(graph_gen.working_dir) -with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), - css=css) as demo): + +with gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), css=css) as demo: # Header - gr.Image(value=os.path.join(root_dir, 'resources', 'images', 'logo.png'), - label="GraphGen Banner", - elem_id="banner", - interactive=False, - container=False, - show_download_button=False, - show_fullscreen_button=False) + gr.Image( + value=os.path.join(root_dir, "resources", "images", "logo.png"), + label="GraphGen Banner", + elem_id="banner", + interactive=False, + container=False, + show_download_button=False, + show_fullscreen_button=False, + ) lang_btn = gr.Radio( choices=[ ("English", "en"), @@ -245,7 +241,8 @@ def sum_tokens(client): elem_classes=["center-row"], ) - gr.HTML(""" + gr.HTML( + """ - """) + """ + ) with Translate( - os.path.join(root_dir, 'webui', 'translation.json'), - lang_btn, - placeholder_langs=["en", "zh"], - persistant= - False, # True to save the language setting in the browser. Requires gradio >= 5.6.0 + os.path.join(root_dir, "webui", "translation.json"), + lang_btn, + placeholder_langs=["en", "zh"], + persistant=False, # True to save the language setting in the browser. Requires gradio >= 5.6.0 ): lang_btn.render() gr.Markdown( - value = "# " + _("Title") + "\n\n" + \ - "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + _("Intro") + value="# " + + _("Title") + + "\n\n" + + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) " + + _("Intro") ) - if_trainee_model = gr.Checkbox(label=_("Use Trainee Model"), - value=False, - interactive=True) + if_trainee_model = gr.Checkbox( + label=_("Use Trainee Model"), value=False, interactive=True + ) with gr.Accordion(label=_("Model Config"), open=False): - synthesizer_url = gr.Textbox(label="Synthesizer URL", - value="https://api.siliconflow.cn/v1", - info=_("Synthesizer URL Info"), - interactive=True) - synthesizer_model = gr.Textbox(label="Synthesizer Model", - value="Qwen/Qwen2.5-7B-Instruct", - info=_("Synthesizer Model Info"), - interactive=True) - trainee_url = gr.Textbox(label="Trainee URL", - value="https://api.siliconflow.cn/v1", - info=_("Trainee URL Info"), - interactive=True, - visible=if_trainee_model.value is True) + synthesizer_url = gr.Textbox( + label="Synthesizer URL", + value="https://api.siliconflow.cn/v1", + info=_("Synthesizer URL Info"), + interactive=True, + ) + synthesizer_model = gr.Textbox( + label="Synthesizer Model", + value="Qwen/Qwen2.5-7B-Instruct", + info=_("Synthesizer Model Info"), + interactive=True, + ) + trainee_url = gr.Textbox( + label="Trainee URL", + value="https://api.siliconflow.cn/v1", + info=_("Trainee URL Info"), + interactive=True, + visible=if_trainee_model.value is True, + ) trainee_model = gr.Textbox( label="Trainee Model", value="Qwen/Qwen2.5-7B-Instruct", info=_("Trainee Model Info"), interactive=True, - visible=if_trainee_model.value is True) + visible=if_trainee_model.value is True, + ) trainee_api_key = gr.Textbox( - label=_("SiliconFlow Token for Trainee Model"), - type="password", - value="", - info="https://cloud.siliconflow.cn/account/ak", - visible=if_trainee_model.value is True) - + label=_("SiliconFlow Token for Trainee Model"), + type="password", + value="", + info="https://cloud.siliconflow.cn/account/ak", + visible=if_trainee_model.value is True, + ) with gr.Accordion(label=_("Generation Config"), open=False): - chunk_size = gr.Slider(label="Chunk Size", - minimum=256, - maximum=4096, - value=512, - step=256, - interactive=True) - tokenizer = gr.Textbox(label="Tokenizer", - value="cl100k_base", - interactive=True) - qa_form = gr.Radio(choices=["atomic", "multi_hop", "aggregated"], - label="QA Form", - value="aggregated", - interactive=True) - quiz_samples = gr.Number(label="Quiz Samples", - value=2, - minimum=1, - interactive=True, - visible=if_trainee_model.value is True) - bidirectional = gr.Checkbox(label="Bidirectional", - value=True, - interactive=True) - - expand_method = gr.Radio(choices=["max_width", "max_tokens"], - label="Expand Method", - value="max_tokens", - interactive=True) + chunk_size = gr.Slider( + label="Chunk Size", + minimum=256, + maximum=4096, + value=512, + step=256, + interactive=True, + ) + tokenizer = gr.Textbox( + label="Tokenizer", value="cl100k_base", interactive=True + ) + qa_form = gr.Radio( + choices=["atomic", "multi_hop", "aggregated"], + label="QA Form", + value="aggregated", + interactive=True, + ) + quiz_samples = gr.Number( + label="Quiz Samples", + value=2, + minimum=1, + interactive=True, + visible=if_trainee_model.value is True, + ) + bidirectional = gr.Checkbox( + label="Bidirectional", value=True, interactive=True + ) + + expand_method = gr.Radio( + choices=["max_width", "max_tokens"], + label="Expand Method", + value="max_tokens", + interactive=True, + ) max_extra_edges = gr.Slider( minimum=1, maximum=10, @@ -341,36 +356,45 @@ def sum_tokens(client): label="Max Extra Edges", step=1, interactive=True, - visible=expand_method.value == "max_width") - max_tokens = gr.Slider(minimum=64, - maximum=1024, - value=256, - label="Max Tokens", - step=64, - interactive=True, - visible=(expand_method.value - != "max_width")) - - max_depth = gr.Slider(minimum=1, - maximum=5, - value=2, - label="Max Depth", - step=1, - interactive=True) + visible=expand_method.value == "max_width", + ) + max_tokens = gr.Slider( + minimum=64, + maximum=1024, + value=256, + label="Max Tokens", + step=64, + interactive=True, + visible=(expand_method.value != "max_width"), + ) + + max_depth = gr.Slider( + minimum=1, + maximum=5, + value=2, + label="Max Depth", + step=1, + interactive=True, + ) edge_sampling = gr.Radio( choices=["max_loss", "min_loss", "random"], label="Edge Sampling", value="max_loss", interactive=True, - visible=if_trainee_model.value is True) - isolated_node_strategy = gr.Radio(choices=["add", "ignore"], - label="Isolated Node Strategy", - value="ignore", - interactive=True) - loss_strategy = gr.Radio(choices=["only_edge", "both"], - label="Loss Strategy", - value="only_edge", - interactive=True) + visible=if_trainee_model.value is True, + ) + isolated_node_strategy = gr.Radio( + choices=["add", "ignore"], + label="Isolated Node Strategy", + value="ignore", + interactive=True, + ) + loss_strategy = gr.Radio( + choices=["only_edge", "both"], + label="Loss Strategy", + value="only_edge", + interactive=True, + ) with gr.Row(equal_height=True): with gr.Column(scale=3): @@ -378,7 +402,8 @@ def sum_tokens(client): label=_("SiliconFlow Token"), type="password", value="", - info="https://cloud.siliconflow.cn/account/ak") + info="https://cloud.siliconflow.cn/account/ak", + ) with gr.Column(scale=1): test_connection_btn = gr.Button(_("Test Connection")) @@ -392,7 +417,8 @@ def sum_tokens(client): value=1000, step=100, interactive=True, - visible=True) + visible=True, + ) with gr.Column(): tpm = gr.Slider( label="TPM", @@ -401,8 +427,8 @@ def sum_tokens(client): value=50000, step=1000, interactive=True, - visible=True) - + visible=True, + ) with gr.Blocks(): with gr.Row(equal_height=True): @@ -413,15 +439,17 @@ def sum_tokens(client): file_types=[".txt", ".json", ".jsonl"], interactive=True, ) - examples_dir = os.path.join(root_dir, 'webui', 'examples') - gr.Examples(examples=[ - [os.path.join(examples_dir, "txt_demo.txt")], - [os.path.join(examples_dir, "raw_demo.jsonl")], - [os.path.join(examples_dir, "chunked_demo.json")], - ], - inputs=upload_file, - label=_("Example Files"), - examples_per_page=3) + examples_dir = os.path.join(root_dir, "webui", "examples") + gr.Examples( + examples=[ + [os.path.join(examples_dir, "txt_demo.txt")], + [os.path.join(examples_dir, "raw_demo.jsonl")], + [os.path.join(examples_dir, "chunked_demo.json")], + ], + inputs=upload_file, + label=_("Example Files"), + examples_per_page=3, + ) with gr.Column(scale=1): output = gr.File( label="Output(See Github FAQ)", @@ -430,12 +458,18 @@ def sum_tokens(client): ) with gr.Blocks(): - token_counter = gr.DataFrame(label='Token Stats', - headers=["Source Text Token Count", "Estimated Token Usage", "Token Used"], - datatype="str", - interactive=False, - visible=False, - wrap=True) + token_counter = gr.DataFrame( + label="Token Stats", + headers=[ + "Source Text Token Count", + "Estimated Token Usage", + "Token Used", + ], + datatype="str", + interactive=False, + visible=False, + wrap=True, + ) submit_btn = gr.Button(_("Run GraphGen")) @@ -443,23 +477,36 @@ def sum_tokens(client): test_connection_btn.click( test_api_connection, inputs=[synthesizer_url, api_key, synthesizer_model], - outputs=[]) + outputs=[], + ) if if_trainee_model.value: - test_connection_btn.click(test_api_connection, - inputs=[trainee_url, api_key, trainee_model], - outputs=[]) + test_connection_btn.click( + test_api_connection, + inputs=[trainee_url, api_key, trainee_model], + outputs=[], + ) - expand_method.change(lambda method: - (gr.update(visible=method == "max_width"), - gr.update(visible=method != "max_width")), - inputs=expand_method, - outputs=[max_extra_edges, max_tokens]) + expand_method.change( + lambda method: ( + gr.update(visible=method == "max_width"), + gr.update(visible=method != "max_width"), + ), + inputs=expand_method, + outputs=[max_extra_edges, max_tokens], + ) if_trainee_model.change( lambda use_trainee: [gr.update(visible=use_trainee)] * 5, inputs=if_trainee_model, - outputs=[trainee_url, trainee_model, quiz_samples, edge_sampling, trainee_api_key]) + outputs=[ + trainee_url, + trainee_model, + quiz_samples, + edge_sampling, + trainee_api_key, + ], + ) upload_file.change( lambda x: (gr.update(visible=True)), @@ -479,41 +526,61 @@ def sum_tokens(client): ) submit_btn.click( - lambda *args: run_graphgen(GraphGenParams( - if_trainee_model=args[0], - input_file=args[1], - tokenizer=args[2], - qa_form=args[3], - bidirectional=args[4], - expand_method=args[5], - max_extra_edges=args[6], - max_tokens=args[7], - max_depth=args[8], - edge_sampling=args[9], - isolated_node_strategy=args[10], - loss_strategy=args[11], - synthesizer_url=args[12], - synthesizer_model=args[13], - trainee_model=args[14], - api_key=args[15], - chunk_size=args[16], - rpm=args[17], - tpm=args[18], - quiz_samples=args[19], - trainee_url=args[20], - trainee_api_key=args[21], - token_counter=args[22], - )), + lambda *args: run_graphgen( + GraphGenParams( + if_trainee_model=args[0], + input_file=args[1], + tokenizer=args[2], + qa_form=args[3], + bidirectional=args[4], + expand_method=args[5], + max_extra_edges=args[6], + max_tokens=args[7], + max_depth=args[8], + edge_sampling=args[9], + isolated_node_strategy=args[10], + loss_strategy=args[11], + synthesizer_url=args[12], + synthesizer_model=args[13], + trainee_model=args[14], + api_key=args[15], + chunk_size=args[16], + rpm=args[17], + tpm=args[18], + quiz_samples=args[19], + trainee_url=args[20], + trainee_api_key=args[21], + token_counter=args[22], + ) + ), inputs=[ - if_trainee_model, upload_file, tokenizer, qa_form, - bidirectional, expand_method, max_extra_edges, max_tokens, - max_depth, edge_sampling, isolated_node_strategy, - loss_strategy, synthesizer_url, synthesizer_model, trainee_model, - api_key, chunk_size, rpm, tpm, quiz_samples, trainee_url, trainee_api_key, token_counter + if_trainee_model, + upload_file, + tokenizer, + qa_form, + bidirectional, + expand_method, + max_extra_edges, + max_tokens, + max_depth, + edge_sampling, + isolated_node_strategy, + loss_strategy, + synthesizer_url, + synthesizer_model, + trainee_model, + api_key, + chunk_size, + rpm, + tpm, + quiz_samples, + trainee_url, + trainee_api_key, + token_counter, ], outputs=[output, token_counter], ) if __name__ == "__main__": demo.queue(api_open=False, default_concurrency_limit=2) - demo.launch(server_name='0.0.0.0') + demo.launch(server_name="0.0.0.0") diff --git a/webui/translation.json b/webui/translation.json index 583a8eba..29398013 100644 --- a/webui/translation.json +++ b/webui/translation.json @@ -1,36 +1,42 @@ { "en": { "Title": "✨Easy-to-use LLM Training Data Generation Framework✨", + "\n\n": "\n\n", + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ": "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ", "Intro": "is a framework for synthetic data generation guided by knowledge graphs, designed to tackle challenges for knowledge-intensive QA generation. \n\nBy uploading your text chunks (such as knowledge in agriculture, healthcare, or marine science) and filling in the LLM API key, you can generate the training data required by **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)** and **[xtuner](https://github.com/InternLM/xtuner)** online. We will automatically delete user information after completion.", + "# ": "# ", "Use Trainee Model": "Use Trainee Model to identify knowledge blind spots, please keep disable for SiliconCloud", "Synthesizer URL Info": "Base URL for the Synthesizer Model API, use SiliconFlow as default", - "Trainee URL Info": "Base URL for the Trainee Model API, use SiliconFlow as default", "Synthesizer Model Info": "Model for constructing KGs and generating QAs", + "Trainee URL Info": "Base URL for the Trainee Model API, use SiliconFlow as default", "Trainee Model Info": "Model for training", + "SiliconFlow Token for Trainee Model": "SiliconFlow API Key for Trainee Model", "Model Config": "Model Configuration", "Generation Config": "Generation Config", "SiliconFlow Token": "SiliconFlow API Key", - "SiliconFlow Token for Trainee Model": "SiliconFlow API Key for Trainee Model", "Test Connection": "Test Connection", - "Run GraphGen": "Run GraphGen", "Upload File": "Upload File", - "Example Files": "Example Files" + "Example Files": "Example Files", + "Run GraphGen": "Run GraphGen" }, "zh": { "Title": "✨开箱即用的LLM训练数据生成框架✨", + "\n\n": "\n\n", + "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ": "### [GraphGen](https://github.com/open-sciencelab/GraphGen) ", "Intro": "是一个基于知识图谱的数据合成框架,旨在知识密集型任务中生成问答。\n\n 上传你的文本块(如农业、医疗、海洋知识),填写 LLM api key,即可在线生成 **[LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory)**、**[xtuner](https://github.com/InternLM/xtuner)** 所需训练数据。结束后我们将自动删除用户信息。", + "# ": "# ", "Use Trainee Model": "使用Trainee Model来识别知识盲区,使用硅基流动时请保持禁用", "Synthesizer URL Info": "调用合成模型API的URL,默认使用硅基流动", - "Trainee URL Info": "调用学生模型API的URL,默认使用硅基流动", "Synthesizer Model Info": "用于构建知识图谱和生成问答的模型", + "Trainee URL Info": "调用学生模型API的URL,默认使用硅基流动", "Trainee Model Info": "用于训练的模型", + "SiliconFlow Token for Trainee Model": "SiliconFlow Token for Trainee Model", "Model Config": "模型配置", "Generation Config": "生成配置", - "SiliconCloud Token": "硅基流动 API Key", - "SiliconCloud Token for Trainee Model": "硅基流动 API Key (学生模型)", + "SiliconFlow Token": "SiliconFlow Token", "Test Connection": "测试接口", - "Run GraphGen": "运行GraphGen", "Upload File": "上传文件", - "Example Files": "示例文件" + "Example Files": "示例文件", + "Run GraphGen": "运行GraphGen" } -} +} \ No newline at end of file From e35162d79d3c96f2fa62a4eba3e5184bdb0b860e Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Thu, 31 Jul 2025 20:42:34 +0800 Subject: [PATCH 6/6] Update graphgen/operators/search/kg/search_wikipedia.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- graphgen/operators/search/kg/search_wikipedia.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphgen/operators/search/kg/search_wikipedia.py b/graphgen/operators/search/kg/search_wikipedia.py index dd3a35ba..05449fe1 100644 --- a/graphgen/operators/search/kg/search_wikipedia.py +++ b/graphgen/operators/search/kg/search_wikipedia.py @@ -50,7 +50,7 @@ async def search_wikipedia( entities, desc="Searching Wikipedia", total=len(entities) ): try: - entity, summary = await _process_single_entity(entity, wiki_search_client) + summary = await _process_single_entity(entity, wiki_search_client) if summary: wiki_data[entity] = summary except Exception as e: # pylint: disable=broad-except