From 04e928c18cc78fd1eda2d240ffbc9a7c65fe9f24 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 10 Sep 2025 16:28:10 +0800 Subject: [PATCH 1/2] fix(webui): sync gradio demo --- .github/workflows/push-to-hf.yml | 2 +- graphgen/graphgen.py | 9 +- graphgen/operators/__init__.py | 8 +- graphgen/operators/traverse_graph.py | 42 +++++--- webui/app.py | 150 +++++++++++++-------------- webui/base.py | 5 +- webui/i18n.py | 1 + 7 files changed, 115 insertions(+), 102 deletions(-) diff --git a/.github/workflows/push-to-hf.yml b/.github/workflows/push-to-hf.yml index 1ddf87b8..6ae84cab 100644 --- a/.github/workflows/push-to-hf.yml +++ b/.github/workflows/push-to-hf.yml @@ -43,7 +43,7 @@ jobs: [[ -d hf-repo ]] && rm -rf hf-repo git clone https://huggingface.co/${HF_REPO_TYPE}/${HF_REPO_ID} hf-repo - rsync -a --delete --exclude='.git' ./ hf-repo/ || true + rsync -a --delete --exclude='.git' --exclude='hf-repo' ./ hf-repo/ cd hf-repo git add . diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 7b7b302a..68fb19d6 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -23,8 +23,8 @@ judge_statement, quiz, search_all, - traverse_graph_atomically, - traverse_graph_by_edge, + traverse_graph_for_aggregated, + traverse_graph_for_atomic, traverse_graph_for_multi_hop, ) from .utils import ( @@ -69,6 +69,7 @@ def __post_init__(self): self.tokenizer_instance: Tokenizer = Tokenizer( model_name=self.config["tokenizer"] ) + print(os.getenv("SYNTHESIZER_MODEL"), os.getenv("SYNTHESIZER_API_KEY")) self.synthesizer_llm_client: OpenAIModel = OpenAIModel( model_name=os.getenv("SYNTHESIZER_MODEL"), api_key=os.getenv("SYNTHESIZER_API_KEY"), @@ -326,7 +327,7 @@ async def async_traverse(self): output_data_type = self.config["output_data_type"] if output_data_type == "atomic": - results = await traverse_graph_atomically( + results = await traverse_graph_for_atomic( self.synthesizer_llm_client, self.tokenizer_instance, self.graph_storage, @@ -344,7 +345,7 @@ async def async_traverse(self): self.progress_bar, ) elif output_data_type == "aggregated": - results = await traverse_graph_by_edge( + results = await traverse_graph_for_aggregated( self.synthesizer_llm_client, self.tokenizer_instance, self.graph_storage, diff --git a/graphgen/operators/__init__.py b/graphgen/operators/__init__.py index f74e013a..b3329704 100644 --- a/graphgen/operators/__init__.py +++ b/graphgen/operators/__init__.py @@ -5,8 +5,8 @@ from .judge import judge_statement from .quiz import quiz from .traverse_graph import ( - traverse_graph_atomically, - traverse_graph_by_edge, + traverse_graph_for_aggregated, + traverse_graph_for_atomic, traverse_graph_for_multi_hop, ) @@ -15,8 +15,8 @@ "quiz", "judge_statement", "search_all", - "traverse_graph_by_edge", - "traverse_graph_atomically", + "traverse_graph_for_aggregated", + "traverse_graph_for_atomic", "traverse_graph_for_multi_hop", "generate_cot", ] diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/traverse_graph.py index da1b6685..dfc815f4 100644 --- a/graphgen/operators/traverse_graph.py +++ b/graphgen/operators/traverse_graph.py @@ -158,7 +158,7 @@ def _post_process_synthetic_data(data): return qas -async def traverse_graph_by_edge( +async def traverse_graph_for_aggregated( llm_client: OpenAIModel, tokenizer: Tokenizer, graph_storage: NetworkXStorage, @@ -251,7 +251,6 @@ async def _process_single_batch( qas = _post_process_synthetic_data(content) if len(qas) == 0: - print(content) logger.error( "Error occurred while processing batch, question or answer is None" ) @@ -307,7 +306,8 @@ async def _process_single_batch( return results -async def traverse_graph_atomically( +# pylint: disable=too-many-branches, too-many-statements +async def traverse_graph_for_atomic( llm_client: OpenAIModel, tokenizer: Tokenizer, graph_storage: NetworkXStorage, @@ -328,17 +328,28 @@ async def traverse_graph_atomically( :param max_concurrent :return: question and answer """ - assert traverse_strategy.qa_form == "atomic" + assert traverse_strategy.qa_form == "atomic" semaphore = asyncio.Semaphore(max_concurrent) + def _parse_qa(qa: str) -> tuple: + if "Question:" in qa and "Answer:" in qa: + question = qa.split("Question:")[1].split("Answer:")[0].strip() + answer = qa.split("Answer:")[1].strip() + elif "问题:" in qa and "答案:" in qa: + question = qa.split("问题:")[1].split("答案:")[0].strip() + answer = qa.split("答案:")[1].strip() + else: + return None, None + return question.strip('"'), answer.strip('"') + async def _generate_question(node_or_edge: tuple): if len(node_or_edge) == 2: des = node_or_edge[0] + ": " + node_or_edge[1]["description"] - loss = node_or_edge[1]["loss"] + loss = node_or_edge[1]["loss"] if "loss" in node_or_edge[1] else -1.0 else: des = node_or_edge[2]["description"] - loss = node_or_edge[2]["loss"] + loss = node_or_edge[2]["loss"] if "loss" in node_or_edge[2] else -1.0 async with semaphore: try: @@ -350,13 +361,8 @@ async def _generate_question(node_or_edge: tuple): ) ) - if "Question:" in qa and "Answer:" in qa: - question = qa.split("Question:")[1].split("Answer:")[0].strip() - answer = qa.split("Answer:")[1].strip() - elif "问题:" in qa and "答案:" in qa: - question = qa.split("问题:")[1].split("答案:")[0].strip() - answer = qa.split("答案:")[1].strip() - else: + question, answer = _parse_qa(qa) + if question is None or answer is None: return {} question = question.strip('"') @@ -386,16 +392,18 @@ async def _generate_question(node_or_edge: tuple): if "" in node[1]["description"]: description_list = node[1]["description"].split("") for item in description_list: - tasks.append((node[0], {"description": item, "loss": node[1]["loss"]})) + tasks.append((node[0], {"description": item})) + if "loss" in node[1]: + tasks[-1][1]["loss"] = node[1]["loss"] else: tasks.append((node[0], node[1])) for edge in edges: if "" in edge[2]["description"]: description_list = edge[2]["description"].split("") for item in description_list: - tasks.append( - (edge[0], edge[1], {"description": item, "loss": edge[2]["loss"]}) - ) + tasks.append((edge[0], edge[1], {"description": item})) + if "loss" in edge[2]: + tasks[-1][2]["loss"] = edge[2]["loss"] else: tasks.append((edge[0], edge[1], edge[2])) diff --git a/webui/app.py b/webui/app.py index 50f57131..917495e0 100644 --- a/webui/app.py +++ b/webui/app.py @@ -1,4 +1,3 @@ -# pylint: skip-file import json import os import sys @@ -6,6 +5,7 @@ import gradio as gr import pandas as pd +from dotenv import load_dotenv from webui.base import GraphGenParams from webui.cache_utils import cleanup_workspace, setup_workspace @@ -19,10 +19,12 @@ sys.path.append(root_dir) from graphgen.graphgen import GraphGen -from graphgen.models import OpenAIModel, Tokenizer, TraverseStrategy +from graphgen.models import OpenAIModel, Tokenizer from graphgen.models.llm.limitter import RPM, TPM from graphgen.utils import set_logger +load_dotenv() + css = """ .center-row { display: flex; @@ -37,7 +39,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) set_logger(log_file, if_stream=False) - graph_gen = GraphGen(working_dir=working_dir) + graph_gen = GraphGen(working_dir=working_dir, config=config) # Set up LLM clients graph_gen.synthesizer_llm_client = OpenAIModel( @@ -60,19 +62,6 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: graph_gen.tokenizer_instance = Tokenizer(config.get("tokenizer", "cl100k_base")) - strategy_config = config.get("traverse_strategy", {}) - graph_gen.traverse_strategy = TraverseStrategy( - qa_form=strategy_config.get("qa_form"), - expand_method=strategy_config.get("expand_method"), - bidirectional=strategy_config.get("bidirectional"), - max_extra_edges=strategy_config.get("max_extra_edges"), - max_tokens=strategy_config.get("max_tokens"), - max_depth=strategy_config.get("max_depth"), - edge_sampling=strategy_config.get("edge_sampling"), - isolated_node_strategy=strategy_config.get("isolated_node_strategy"), - loss_strategy=str(strategy_config.get("loss_strategy")), - ) - return graph_gen @@ -84,10 +73,15 @@ def sum_tokens(client): config = { "if_trainee_model": params.if_trainee_model, "input_file": params.input_file, + "output_data_type": params.output_data_type, + "output_data_format": params.output_data_format, "tokenizer": params.tokenizer, - "quiz_samples": params.quiz_samples, + "search": {"enabled": False}, + "quiz_and_judge_strategy": { + "enabled": params.if_trainee_model, + "quiz_samples": params.quiz_samples, + }, "traverse_strategy": { - "qa_form": params.qa_form, "bidirectional": params.bidirectional, "expand_method": params.expand_method, "max_extra_edges": params.max_extra_edges, @@ -122,6 +116,35 @@ def sum_tokens(client): env["TRAINEE_BASE_URL"], env["TRAINEE_API_KEY"], env["TRAINEE_MODEL"] ) + # Load input data + file = config["input_file"] + if isinstance(file, list): + file = file[0] + + data = [] + + if file.endswith(".jsonl"): + config["input_data_type"] = "raw" + with open(file, "r", encoding="utf-8") as f: + data.extend(json.loads(line) for line in f) + elif file.endswith(".json"): + config["input_data_type"] = "chunked" + with open(file, "r", encoding="utf-8") as f: + data.extend(json.load(f)) + elif file.endswith(".txt"): + # 读取文件后根据chunk_size转成raw格式的数据 + config["input_data_type"] = "raw" + content = "" + with open(file, "r", encoding="utf-8") as f: + lines = f.readlines() + for line in lines: + content += line.strip() + " " + size = int(config.get("chunk_size", 512)) + chunks = [content[i : i + size] for i in range(0, len(content), size)] + data.extend([{"content": chunk} for chunk in chunks]) + else: + raise ValueError(f"Unsupported file type: {file}") + # Initialize GraphGen graph_gen = init_graph_gen(config, env) graph_gen.clear() @@ -129,51 +152,20 @@ def sum_tokens(client): graph_gen.progress_bar = progress try: - # Load input data - file = config["input_file"] - if isinstance(file, list): - file = file[0] - - data = [] - - if file.endswith(".jsonl"): - data_type = "raw" - with open(file, "r", encoding="utf-8") as f: - data.extend(json.loads(line) for line in f) - elif file.endswith(".json"): - data_type = "chunked" - with open(file, "r", encoding="utf-8") as f: - data.extend(json.load(f)) - elif file.endswith(".txt"): - # 读取文件后根据chunk_size转成raw格式的数据 - data_type = "raw" - content = "" - with open(file, "r", encoding="utf-8") as f: - lines = f.readlines() - for line in lines: - content += line.strip() + " " - size = int(config.get("chunk_size", 512)) - chunks = [content[i : i + size] for i in range(0, len(content), size)] - data.extend([{"content": chunk} for chunk in chunks]) - else: - raise ValueError(f"Unsupported file type: {file}") - # Process the data - graph_gen.insert(data, data_type) + graph_gen.insert() if config["if_trainee_model"]: # Generate quiz - graph_gen.quiz(max_samples=config["quiz_samples"]) + graph_gen.quiz() # Judge statements graph_gen.judge() else: graph_gen.traverse_strategy.edge_sampling = "random" - # Skip judge statements - graph_gen.judge(skip=True) # Traverse graph - graph_gen.traverse(traverse_strategy=graph_gen.traverse_strategy) + graph_gen.traverse() # Save output output_data = graph_gen.qa_storage.data @@ -328,12 +320,18 @@ def sum_tokens(client): tokenizer = gr.Textbox( label="Tokenizer", value="cl100k_base", interactive=True ) - qa_form = gr.Radio( + output_data_type = gr.Radio( choices=["atomic", "multi_hop", "aggregated"], - label="QA Form", + label="Output Data Type", value="aggregated", interactive=True, ) + output_data_format = gr.Radio( + choices=["Alpaca", "Sharegpt", "ChatML"], + label="Output Data Format", + value="Alpaca", + interactive=True, + ) quiz_samples = gr.Number( label="Quiz Samples", value=2, @@ -533,33 +531,35 @@ def sum_tokens(client): if_trainee_model=args[0], input_file=args[1], tokenizer=args[2], - qa_form=args[3], - bidirectional=args[4], - expand_method=args[5], - max_extra_edges=args[6], - max_tokens=args[7], - max_depth=args[8], - edge_sampling=args[9], - isolated_node_strategy=args[10], - loss_strategy=args[11], - synthesizer_url=args[12], - synthesizer_model=args[13], - trainee_model=args[14], - api_key=args[15], - chunk_size=args[16], - rpm=args[17], - tpm=args[18], - quiz_samples=args[19], - trainee_url=args[20], - trainee_api_key=args[21], - token_counter=args[22], + output_data_type=args[3], + output_data_format=args[4], + bidirectional=args[5], + expand_method=args[6], + max_extra_edges=args[7], + max_tokens=args[8], + max_depth=args[9], + edge_sampling=args[10], + isolated_node_strategy=args[11], + loss_strategy=args[12], + synthesizer_url=args[13], + synthesizer_model=args[14], + trainee_model=args[15], + api_key=args[16], + chunk_size=args[17], + rpm=args[18], + tpm=args[19], + quiz_samples=args[20], + trainee_url=args[21], + trainee_api_key=args[22], + token_counter=args[23], ) ), inputs=[ if_trainee_model, upload_file, tokenizer, - qa_form, + output_data_type, + output_data_format, bidirectional, expand_method, max_extra_edges, diff --git a/webui/base.py b/webui/base.py index 32f3ed10..f87d7d9b 100644 --- a/webui/base.py +++ b/webui/base.py @@ -1,15 +1,18 @@ from dataclasses import dataclass from typing import Any + @dataclass class GraphGenParams: """ GraphGen parameters """ + if_trainee_model: bool input_file: str tokenizer: str - qa_form: str + output_data_type: str + output_data_format: str bidirectional: bool expand_method: str max_extra_edges: int diff --git a/webui/i18n.py b/webui/i18n.py index e15acd89..ce6bb40e 100644 --- a/webui/i18n.py +++ b/webui/i18n.py @@ -1,3 +1,4 @@ +# pylint: skip-file import functools import inspect import json From dbdb541f79764f095a633169f8711f2e7b76ba88 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 10 Sep 2025 16:54:00 +0800 Subject: [PATCH 2/2] fix: stream log to cmd --- graphgen/configs/multi_hop_config.yaml | 2 +- graphgen/operators/traverse_graph.py | 4 +++- webui/app.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml index 530edcd1..bb75d0a9 100644 --- a/graphgen/configs/multi_hop_config.yaml +++ b/graphgen/configs/multi_hop_config.yaml @@ -7,7 +7,7 @@ search: # web search configuration enabled: false # whether to enable web search search_types: ["google"] # search engine types, support: google, bing, uniprot, wikipedia quiz_and_judge_strategy: # quiz and test whether the LLM masters the knowledge points - enabled: true + enabled: false quiz_samples: 2 # number of quiz samples to generate re_judge: false # whether to re-judge the existing quiz samples traverse_strategy: # strategy for clustering sub-graphs using comprehension loss diff --git a/graphgen/operators/traverse_graph.py b/graphgen/operators/traverse_graph.py index dfc815f4..16e2d25b 100644 --- a/graphgen/operators/traverse_graph.py +++ b/graphgen/operators/traverse_graph.py @@ -135,7 +135,9 @@ def get_average_loss(batch: tuple, loss_strategy: str) -> float: ) / (len(batch[0]) + len(batch[1])) raise ValueError("Invalid loss strategy") except Exception as e: # pylint: disable=broad-except - logger.error("Error calculating average loss: %s", e) + logger.warning( + "Loss not found in some nodes or edges, setting loss to -1.0: %s", e + ) return -1.0 diff --git a/webui/app.py b/webui/app.py index 917495e0..f2148531 100644 --- a/webui/app.py +++ b/webui/app.py @@ -38,7 +38,7 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: # Set up working directory log_file, working_dir = setup_workspace(os.path.join(root_dir, "cache")) - set_logger(log_file, if_stream=False) + set_logger(log_file, if_stream=True) graph_gen = GraphGen(working_dir=working_dir, config=config) # Set up LLM clients