From fad24d00bf6428741578c0acac07ac7da563b680 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 24 Sep 2025 11:20:36 +0800 Subject: [PATCH 1/9] fix: update __init__.py in models --- graphgen/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 79111b00..3a4d9b63 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -43,4 +43,5 @@ "TraverseStrategy", # community models "CommunityDetector", + "read_file", ] From a430183a585450e8728bff043cb6d27cf71d26ac Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 24 Sep 2025 11:59:21 +0800 Subject: [PATCH 2/9] refactor: add datatypes --- graphgen/bases/base_splitter.py | 47 ++++++ graphgen/bases/datatypes.py | 18 +++ graphgen/evaluate.py | 134 +++++++++++------- graphgen/graphgen.py | 2 +- graphgen/models/__init__.py | 4 - graphgen/models/evaluate/base_evaluator.py | 16 ++- graphgen/models/evaluate/length_evaluator.py | 10 +- graphgen/models/evaluate/mtld_evaluator.py | 21 +-- graphgen/models/evaluate/reward_evaluator.py | 20 ++- graphgen/models/evaluate/uni_evaluator.py | 68 ++++++--- graphgen/models/text/__init__.py | 0 graphgen/models/text/chunk.py | 7 - graphgen/models/text/text_pair.py | 9 -- graphgen/operators/kg/extract_kg.py | 3 +- .../preprocess/resolute_coreference.py | 3 +- 15 files changed, 238 insertions(+), 124 deletions(-) create mode 100644 graphgen/bases/base_splitter.py create mode 100644 graphgen/bases/datatypes.py delete mode 100644 graphgen/models/text/__init__.py delete mode 100644 graphgen/models/text/chunk.py delete mode 100644 graphgen/models/text/text_pair.py diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py new file mode 100644 index 00000000..c298ee5d --- /dev/null +++ b/graphgen/bases/base_splitter.py @@ -0,0 +1,47 @@ +import copy +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional + +from graphgen.bases.datatypes import Chunk + + +@dataclass +class BaseSplitter(ABC): + """ + Abstract base class for splitting text into smaller chunks. + """ + + chunk_size: int = 1024 + chunk_overlap_size: int = 100 + length_function: Callable[[str], int] = len + keep_separator: bool = False + add_start_index: bool = False + + @abstractmethod + def split_text(self, text: str) -> List[Dict[str, Any]]: + """ + Split the input text into smaller chunks. + + :param text: The input text to be split. + :return: A list of dictionaries, each containing a chunk of text and optionally its start index. + """ + + def create_chunks( + self, texts: List[str], metadatas: Optional[List[dict]] = None + ) -> List[Chunk]: + """ + Turn a list of texts into a list of Chunks, with optional metadata. + :param texts: + :param metadatas: + :return: + """ + _metadatas = metadatas or [{}] * len(texts) + chunks = [] + for i, text in enumerate(texts): + chunks.append(Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))) + return chunks + + def split(self, text: str, metadata: Optional[dict] = None) -> List[Chunk]: + texts = self.split_text(text) + return self.create_chunks(texts, [metadata] * len(texts) if metadata else None) diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py new file mode 100644 index 00000000..fd7bc177 --- /dev/null +++ b/graphgen/bases/datatypes.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + + +@dataclass +class Chunk: + id: str + content: str + metadata: dict + + +@dataclass +class QAPair: + """ + A pair of question and answer. + """ + + question: str + answer: str diff --git a/graphgen/evaluate.py b/graphgen/evaluate.py index da74a308..c6737516 100644 --- a/graphgen/evaluate.py +++ b/graphgen/evaluate.py @@ -1,11 +1,15 @@ """Evaluate the quality of the generated text using various metrics""" -import os -import json import argparse +import json +import os + import pandas as pd from dotenv import load_dotenv -from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, TextPair, UniEvaluator + +from graphgen.bases.datatypes import QAPair + +from .models import LengthEvaluator, MTLDEvaluator, RewardEvaluator, UniEvaluator from .utils import logger, set_logger sys_path = os.path.abspath(os.path.dirname(__file__)) @@ -13,15 +17,15 @@ load_dotenv() + def evaluate_length(corpus, tokenizer_name): - length_evaluator = LengthEvaluator( - tokenizer_name=tokenizer_name - ) + length_evaluator = LengthEvaluator(tokenizer_name=tokenizer_name) logger.info("Length evaluator loaded") scores = length_evaluator.get_average_score(corpus) logger.info("Length scores: %s", scores) return scores + def evaluate_mtld(corpus): mtld_evaluator = MTLDEvaluator() logger.info("MTLD evaluator loaded") @@ -31,30 +35,30 @@ def evaluate_mtld(corpus): logger.info("MTLD min max scores: %s", min_max_scores) return scores, min_max_scores + def evaluate_reward(corpus, reward_model_names): scores = [] for reward_name in reward_model_names: - reward_evaluator = RewardEvaluator( - reward_name=reward_name - ) + reward_evaluator = RewardEvaluator(reward_name=reward_name) logger.info("Loaded reward model: %s", reward_name) average_score = reward_evaluator.get_average_score(corpus) logger.info("%s scores: %s", reward_name, average_score) min_max_scores = reward_evaluator.get_min_max_score(corpus) logger.info("%s min max scores: %s", reward_name, min_max_scores) - scores.append({ - 'reward_name': reward_name.split('/')[-1], - 'score': average_score, - 'min_max_scores': min_max_scores - }) + scores.append( + { + "reward_name": reward_name.split("/")[-1], + "score": average_score, + "min_max_scores": min_max_scores, + } + ) del reward_evaluator clean_gpu_cache() return scores + def evaluate_uni(corpus, uni_model_name): - uni_evaluator = UniEvaluator( - model_name=uni_model_name - ) + uni_evaluator = UniEvaluator(model_name=uni_model_name) logger.info("Uni evaluator loaded with model %s", uni_model_name) uni_scores = uni_evaluator.get_average_score(corpus) for key, value in uni_scores.items(): @@ -64,27 +68,47 @@ def evaluate_uni(corpus, uni_model_name): logger.info("Uni %s min max scores: %s", key, value) del uni_evaluator clean_gpu_cache() - return (uni_scores['naturalness'], uni_scores['coherence'], uni_scores['understandability'], - min_max_scores['naturalness'], min_max_scores['coherence'], min_max_scores['understandability']) + return ( + uni_scores["naturalness"], + uni_scores["coherence"], + uni_scores["understandability"], + min_max_scores["naturalness"], + min_max_scores["coherence"], + min_max_scores["understandability"], + ) def clean_gpu_cache(): import torch + if torch.cuda.is_available(): torch.cuda.empty_cache() -if __name__ == '__main__': +if __name__ == "__main__": import torch.multiprocessing as mp + parser = argparse.ArgumentParser() - parser.add_argument('--folder', type=str, default='cache/data', help='folder to load data') - parser.add_argument('--output', type=str, default='cache/output', help='path to save output') + parser.add_argument( + "--folder", type=str, default="cache/data", help="folder to load data" + ) + parser.add_argument( + "--output", type=str, default="cache/output", help="path to save output" + ) - parser.add_argument('--tokenizer', type=str, default='cl100k_base', help='tokenizer name') - parser.add_argument('--reward', type=str, default='OpenAssistant/reward-model-deberta-v3-large-v2', - help='Comma-separated list of reward models') - parser.add_argument('--uni', type=str, default='MingZhong/unieval-sum', help='uni model name') + parser.add_argument( + "--tokenizer", type=str, default="cl100k_base", help="tokenizer name" + ) + parser.add_argument( + "--reward", + type=str, + default="OpenAssistant/reward-model-deberta-v3-large-v2", + help="Comma-separated list of reward models", + ) + parser.add_argument( + "--uni", type=str, default="MingZhong/unieval-sum", help="uni model name" + ) args = parser.parse_args() @@ -94,49 +118,55 @@ def clean_gpu_cache(): if not os.path.exists(args.output): os.makedirs(args.output) - reward_models = args.reward.split(',') - + reward_models = args.reward.split(",") results = [] logger.info("Data loaded from %s", args.folder) - mp.set_start_method('spawn') + mp.set_start_method("spawn") for file in os.listdir(args.folder): - if file.endswith('.json'): + if file.endswith(".json"): logger.info("Processing %s", file) - with open(os.path.join(args.folder, file), 'r', encoding='utf-8') as f: + with open(os.path.join(args.folder, file), "r", encoding="utf-8") as f: data = json.load(f) - data = [TextPair( - question=data[key]['question'], - answer=data[key]['answer'] - ) for key in data] + data = [ + QAPair(question=data[key]["question"], answer=data[key]["answer"]) + for key in data + ] length_scores = evaluate_length(data, args.tokenizer) mtld_scores, min_max_mtld_scores = evaluate_mtld(data) reward_scores = evaluate_reward(data, reward_models) - uni_naturalness_scores, uni_coherence_scores, uni_understandability_scores, \ - min_max_uni_naturalness_scores, min_max_uni_coherence_scores, min_max_uni_understandability_scores \ - = evaluate_uni(data, args.uni) + ( + uni_naturalness_scores, + uni_coherence_scores, + uni_understandability_scores, + min_max_uni_naturalness_scores, + min_max_uni_coherence_scores, + min_max_uni_understandability_scores, + ) = evaluate_uni(data, args.uni) result = { - 'file': file, - 'number': len(data), - 'length': length_scores, - 'mtld': mtld_scores, - 'mtld_min_max': min_max_mtld_scores, - 'uni_naturalness': uni_naturalness_scores, - 'uni_coherence': uni_coherence_scores, - 'uni_understandability': uni_understandability_scores, - 'uni_naturalness_min_max': min_max_uni_naturalness_scores, - 'uni_coherence_min_max': min_max_uni_coherence_scores, - 'uni_understandability_min_max': min_max_uni_understandability_scores + "file": file, + "number": len(data), + "length": length_scores, + "mtld": mtld_scores, + "mtld_min_max": min_max_mtld_scores, + "uni_naturalness": uni_naturalness_scores, + "uni_coherence": uni_coherence_scores, + "uni_understandability": uni_understandability_scores, + "uni_naturalness_min_max": min_max_uni_naturalness_scores, + "uni_coherence_min_max": min_max_uni_coherence_scores, + "uni_understandability_min_max": min_max_uni_understandability_scores, } for reward_score in reward_scores: - result[reward_score['reward_name']] = reward_score['score'] - result[f"{reward_score['reward_name']}_min_max"] = reward_score['min_max_scores'] + result[reward_score["reward_name"]] = reward_score["score"] + result[f"{reward_score['reward_name']}_min_max"] = reward_score[ + "min_max_scores" + ] results.append(result) results = pd.DataFrame(results) - results.to_csv(os.path.join(args.output, 'evaluation.csv'), index=False) + results.to_csv(os.path.join(args.output, "evaluation.csv"), index=False) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index fcb62387..54f93b2b 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -8,8 +8,8 @@ from tqdm.asyncio import tqdm as tqdm_async from graphgen.bases.base_storage import StorageNameSpace +from graphgen.bases.datatypes import Chunk from graphgen.models import ( - Chunk, JsonKVStorage, JsonListStorage, NetworkXStorage, diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index 3a4d9b63..d555fcc9 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -14,8 +14,6 @@ from .storage.json_storage import JsonKVStorage, JsonListStorage from .storage.networkx_storage import NetworkXStorage from .strategy.travserse_strategy import TraverseStrategy -from .text.chunk import Chunk -from .text.text_pair import TextPair __all__ = [ # llm models @@ -24,7 +22,6 @@ "Token", "Tokenizer", # storage models - "Chunk", "NetworkXStorage", "JsonKVStorage", "JsonListStorage", @@ -34,7 +31,6 @@ "BingSearch", "UniProtSearch", # evaluate models - "TextPair", "LengthEvaluator", "MTLDEvaluator", "RewardEvaluator", diff --git a/graphgen/models/evaluate/base_evaluator.py b/graphgen/models/evaluate/base_evaluator.py index 6c5ae2d5..1359be6c 100644 --- a/graphgen/models/evaluate/base_evaluator.py +++ b/graphgen/models/evaluate/base_evaluator.py @@ -1,22 +1,24 @@ import asyncio - from dataclasses import dataclass + from tqdm.asyncio import tqdm as tqdm_async + +from graphgen.bases.datatypes import QAPair from graphgen.utils import create_event_loop -from graphgen.models.text.text_pair import TextPair + @dataclass class BaseEvaluator: max_concurrent: int = 100 results: list[float] = None - def evaluate(self, pairs: list[TextPair]) -> list[float]: + def evaluate(self, pairs: list[QAPair]) -> list[float]: """ Evaluate the text and return a score. """ return create_event_loop().run_until_complete(self.async_evaluate(pairs)) - async def async_evaluate(self, pairs: list[TextPair]) -> list[float]: + async def async_evaluate(self, pairs: list[QAPair]) -> list[float]: semaphore = asyncio.Semaphore(self.max_concurrent) async def evaluate_with_semaphore(pair): @@ -31,10 +33,10 @@ async def evaluate_with_semaphore(pair): results.append(await result) return results - async def evaluate_single(self, pair: TextPair) -> float: + async def evaluate_single(self, pair: QAPair) -> float: raise NotImplementedError() - def get_average_score(self, pairs: list[TextPair]) -> float: + def get_average_score(self, pairs: list[QAPair]) -> float: """ Get the average score of a batch of texts. """ @@ -42,7 +44,7 @@ def get_average_score(self, pairs: list[TextPair]) -> float: self.results = results return sum(self.results) / len(pairs) - def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]: + def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: """ Get the min and max score of a batch of texts. """ diff --git a/graphgen/models/evaluate/length_evaluator.py b/graphgen/models/evaluate/length_evaluator.py index ba53ff6b..bf7cc483 100644 --- a/graphgen/models/evaluate/length_evaluator.py +++ b/graphgen/models/evaluate/length_evaluator.py @@ -1,19 +1,19 @@ from dataclasses import dataclass + +from graphgen.bases.datatypes import QAPair from graphgen.models.evaluate.base_evaluator import BaseEvaluator from graphgen.models.llm.tokenizer import Tokenizer -from graphgen.models.text.text_pair import TextPair from graphgen.utils import create_event_loop @dataclass class LengthEvaluator(BaseEvaluator): tokenizer_name: str = "cl100k_base" + def __post_init__(self): - self.tokenizer = Tokenizer( - model_name=self.tokenizer_name - ) + self.tokenizer = Tokenizer(model_name=self.tokenizer_name) - async def evaluate_single(self, pair: TextPair) -> float: + async def evaluate_single(self, pair: QAPair) -> float: loop = create_event_loop() return await loop.run_in_executor(None, self._calculate_length, pair.answer) diff --git a/graphgen/models/evaluate/mtld_evaluator.py b/graphgen/models/evaluate/mtld_evaluator.py index 4ea68875..fc563d1c 100644 --- a/graphgen/models/evaluate/mtld_evaluator.py +++ b/graphgen/models/evaluate/mtld_evaluator.py @@ -1,22 +1,27 @@ -from dataclasses import dataclass, field +from dataclasses import dataclass, field from typing import Set +from graphgen.bases.datatypes import QAPair from graphgen.models.evaluate.base_evaluator import BaseEvaluator -from graphgen.models.text.text_pair import TextPair -from graphgen.utils import detect_main_language, NLTKHelper, create_event_loop - +from graphgen.utils import NLTKHelper, create_event_loop, detect_main_language nltk_helper = NLTKHelper() + @dataclass class MTLDEvaluator(BaseEvaluator): """ 衡量文本词汇多样性的指标 """ - stopwords_en: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("english"))) - stopwords_zh: Set[str] = field(default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))) - async def evaluate_single(self, pair: TextPair) -> float: + stopwords_en: Set[str] = field( + default_factory=lambda: set(nltk_helper.get_stopwords("english")) + ) + stopwords_zh: Set[str] = field( + default_factory=lambda: set(nltk_helper.get_stopwords("chinese")) + ) + + async def evaluate_single(self, pair: QAPair) -> float: loop = create_event_loop() return await loop.run_in_executor(None, self._calculate_mtld_score, pair.answer) @@ -71,6 +76,6 @@ def _compute_factors(tokens: list, threshold: float) -> float: if ttr <= threshold: factors += 1 else: - factors += (1 - (ttr - threshold) / (1 - threshold)) + factors += 1 - (ttr - threshold) / (1 - threshold) return len(tokens) / factors if factors > 0 else len(tokens) diff --git a/graphgen/models/evaluate/reward_evaluator.py b/graphgen/models/evaluate/reward_evaluator.py index 2e4c021c..4d2c2fb9 100644 --- a/graphgen/models/evaluate/reward_evaluator.py +++ b/graphgen/models/evaluate/reward_evaluator.py @@ -1,6 +1,8 @@ from dataclasses import dataclass + from tqdm import tqdm -from graphgen.models.text.text_pair import TextPair + +from graphgen.bases.datatypes import QAPair @dataclass @@ -9,19 +11,22 @@ class RewardEvaluator: Reward Model Evaluator. OpenAssistant/reward-model-deberta-v3-large-v2: 分数范围为[-inf, inf],越高越好 """ + reward_name: str = "OpenAssistant/reward-model-deberta-v3-large-v2" max_length: int = 2560 results: list[float] = None def __post_init__(self): import torch + self.num_gpus = torch.cuda.device_count() @staticmethod def process_chunk(rank, pairs, reward_name, max_length, return_dict): import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer - device = f'cuda:{rank}' + + device = f"cuda:{rank}" torch.cuda.set_device(rank) rank_model = AutoModelForSequenceClassification.from_pretrained(reward_name) @@ -37,7 +42,7 @@ def process_chunk(rank, pairs, reward_name, max_length, return_dict): pair.answer, return_tensors="pt", max_length=max_length, - truncation=True + truncation=True, ) inputs = {k: v.to(device) for k, v in inputs.items()} score = rank_model(**inputs).logits[0].item() @@ -45,8 +50,9 @@ def process_chunk(rank, pairs, reward_name, max_length, return_dict): return_dict[rank] = results - def evaluate(self, pairs: list[TextPair]) -> list[float]: + def evaluate(self, pairs: list[QAPair]) -> list[float]: import torch.multiprocessing as mp + chunk_size = len(pairs) // self.num_gpus chunks = [] for i in range(self.num_gpus): @@ -64,7 +70,7 @@ def evaluate(self, pairs: list[TextPair]) -> list[float]: for rank, chunk in enumerate(chunks): p = mp.Process( target=self.process_chunk, - args=(rank, chunk, self.reward_name, self.max_length, return_dict) + args=(rank, chunk, self.reward_name, self.max_length, return_dict), ) p.start() processes.append(p) @@ -84,7 +90,7 @@ def evaluate(self, pairs: list[TextPair]) -> list[float]: return results - def get_average_score(self, pairs: list[TextPair]) -> float: + def get_average_score(self, pairs: list[QAPair]) -> float: """ Get the average score of a batch of texts. """ @@ -92,7 +98,7 @@ def get_average_score(self, pairs: list[TextPair]) -> float: self.results = results return sum(self.results) / len(pairs) - def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]: + def get_min_max_score(self, pairs: list[QAPair]) -> tuple[float, float]: """ Get the min and max score of a batch of texts. """ diff --git a/graphgen/models/evaluate/uni_evaluator.py b/graphgen/models/evaluate/uni_evaluator.py index a334f0a9..20fa3517 100644 --- a/graphgen/models/evaluate/uni_evaluator.py +++ b/graphgen/models/evaluate/uni_evaluator.py @@ -1,40 +1,58 @@ # https://github.com/maszhongming/UniEval/tree/main from dataclasses import dataclass, field + from tqdm import tqdm -from graphgen.models.text.text_pair import TextPair + +from graphgen.bases.datatypes import QAPair def _add_questions(dimension: str, question: str, answer: str): if dimension == "naturalness": - cur_input = 'question: Is this a natural response in the dialogue? response: ' + answer + cur_input = ( + "question: Is this a natural response in the dialogue? response: " + + answer + ) elif dimension == "coherence": - cur_input = 'question: Is this a coherent response given the dialogue history? response: ' \ - + answer + ' dialogue history: ' + question + cur_input = ( + "question: Is this a coherent response given the dialogue history? response: " + + answer + + " dialogue history: " + + question + ) elif dimension == "understandability": - cur_input = 'question: Is this an understandable response in the dialogue? response: ' + answer + cur_input = ( + "question: Is this an understandable response in the dialogue? response: " + + answer + ) else: raise NotImplementedError( - 'The input format for this dimension is still undefined. Please customize it first.') + "The input format for this dimension is still undefined. Please customize it first." + ) return cur_input + @dataclass class UniEvaluator: model_name: str = "MingZhong/unieval-sum" - dimensions: list = field(default_factory=lambda: ['naturalness', 'coherence', 'understandability']) + dimensions: list = field( + default_factory=lambda: ["naturalness", "coherence", "understandability"] + ) max_length: int = 2560 results: dict = None def __post_init__(self): import torch + self.num_gpus = torch.cuda.device_count() self.results = {} @staticmethod def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict): import torch - from transformers import AutoTokenizer, AutoModelForSeq2SeqLM - device = f'cuda:{rank}' + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer + + device = f"cuda:{rank}" torch.cuda.set_device(rank) rank_model = AutoModelForSeq2SeqLM.from_pretrained(model_name) @@ -59,26 +77,26 @@ def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict): max_length=max_length, truncation=True, padding=True, - return_tensors='pt' + return_tensors="pt", ) encoded_tgt = tokenizer( tgt, max_length=max_length, truncation=True, padding=True, - return_tensors='pt' + return_tensors="pt", ) - src_tokens = encoded_src['input_ids'].to(device) - src_mask = encoded_src['attention_mask'].to(device) + src_tokens = encoded_src["input_ids"].to(device) + src_mask = encoded_src["attention_mask"].to(device) - tgt_tokens = encoded_tgt['input_ids'].to(device)[:, 0].unsqueeze(-1) + tgt_tokens = encoded_tgt["input_ids"].to(device)[:, 0].unsqueeze(-1) output = rank_model( input_ids=src_tokens, attention_mask=src_mask, labels=tgt_tokens, - use_cache = False + use_cache=False, ) logits = output.logits.view(-1, rank_model.config.vocab_size) @@ -91,8 +109,9 @@ def process_chunk(rank, pairs, model_name, max_length, dimension, return_dict): return_dict[rank] = results - def evaluate(self, pairs: list[TextPair]) -> list[dict]: + def evaluate(self, pairs: list[QAPair]) -> list[dict]: import torch.multiprocessing as mp + final_results = [] for dimension in self.dimensions: chunk_size = len(pairs) // self.num_gpus @@ -112,7 +131,14 @@ def evaluate(self, pairs: list[TextPair]) -> list[dict]: for rank, chunk in enumerate(chunks): p = mp.Process( target=self.process_chunk, - args=(rank, chunk, self.model_name, self.max_length, dimension, return_dict) + args=( + rank, + chunk, + self.model_name, + self.max_length, + dimension, + return_dict, + ), ) p.start() processes.append(p) @@ -130,12 +156,10 @@ def evaluate(self, pairs: list[TextPair]) -> list[dict]: p.terminate() p.join() - final_results.append({ - dimension: results - }) + final_results.append({dimension: results}) return final_results - def get_average_score(self, pairs: list[TextPair]) -> dict: + def get_average_score(self, pairs: list[QAPair]) -> dict: """ Get the average score of a batch of texts. """ @@ -147,7 +171,7 @@ def get_average_score(self, pairs: list[TextPair]) -> dict: self.results[key] = value return final_results - def get_min_max_score(self, pairs: list[TextPair]) -> dict: + def get_min_max_score(self, pairs: list[QAPair]) -> dict: """ Get the min and max score of a batch of texts. """ diff --git a/graphgen/models/text/__init__.py b/graphgen/models/text/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graphgen/models/text/chunk.py b/graphgen/models/text/chunk.py deleted file mode 100644 index 9678949f..00000000 --- a/graphgen/models/text/chunk.py +++ /dev/null @@ -1,7 +0,0 @@ -from dataclasses import dataclass - - -@dataclass -class Chunk: - id : str - content: str diff --git a/graphgen/models/text/text_pair.py b/graphgen/models/text/text_pair.py deleted file mode 100644 index f9a971f1..00000000 --- a/graphgen/models/text/text_pair.py +++ /dev/null @@ -1,9 +0,0 @@ -from dataclasses import dataclass - -@dataclass -class TextPair: - """ - A pair of input data. - """ - question: str - answer: str diff --git a/graphgen/operators/kg/extract_kg.py b/graphgen/operators/kg/extract_kg.py index ec1f959c..ed64f223 100644 --- a/graphgen/operators/kg/extract_kg.py +++ b/graphgen/operators/kg/extract_kg.py @@ -7,7 +7,8 @@ from tqdm.asyncio import tqdm as tqdm_async from graphgen.bases.base_storage import BaseGraphStorage -from graphgen.models import Chunk, OpenAIModel, Tokenizer +from graphgen.bases.datatypes import Chunk +from graphgen.models import OpenAIModel, Tokenizer from graphgen.operators.kg.merge_kg import merge_edges, merge_nodes from graphgen.templates import KG_EXTRACTION_PROMPT from graphgen.utils import ( diff --git a/graphgen/operators/preprocess/resolute_coreference.py b/graphgen/operators/preprocess/resolute_coreference.py index cdf702e2..e3c498da 100644 --- a/graphgen/operators/preprocess/resolute_coreference.py +++ b/graphgen/operators/preprocess/resolute_coreference.py @@ -1,6 +1,7 @@ from typing import List -from graphgen.models import Chunk, OpenAIModel +from graphgen.bases.datatypes import Chunk +from graphgen.models import OpenAIModel from graphgen.templates import COREFERENCE_RESOLUTION_PROMPT from graphgen.utils import detect_main_language From 0bffcfc772822a89d422d15144cc275876c6814b Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 24 Sep 2025 16:14:44 +0800 Subject: [PATCH 3/9] feat: add splitter classes --- graphgen/bases/base_splitter.py | 116 ++++++++++++-- graphgen/models/splitter/__init__.py | 0 .../models/splitter/character_splitter.py | 26 +++ graphgen/models/splitter/markdown_splitter.py | 32 ++++ .../splitter/recursive_character_splitter.py | 150 ++++++++++++++++++ tests/__init__.py | 0 tests/integration_tests/__init__.py | 0 .../splitter/test_character_splitter.py | 30 ++++ .../models/splitter/test_markdown_splitter.py | 40 +++++ .../test_recursive_character_splitter.py | 49 ++++++ 10 files changed, 429 insertions(+), 14 deletions(-) create mode 100644 graphgen/models/splitter/__init__.py create mode 100644 graphgen/models/splitter/character_splitter.py create mode 100644 graphgen/models/splitter/markdown_splitter.py create mode 100644 graphgen/models/splitter/recursive_character_splitter.py create mode 100644 tests/__init__.py create mode 100644 tests/integration_tests/__init__.py create mode 100644 tests/integration_tests/models/splitter/test_character_splitter.py create mode 100644 tests/integration_tests/models/splitter/test_markdown_splitter.py create mode 100644 tests/integration_tests/models/splitter/test_recursive_character_splitter.py diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index c298ee5d..dcec006e 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -1,9 +1,11 @@ import copy +import re from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional +from typing import Callable, Iterable, List, Literal, Optional, Union from graphgen.bases.datatypes import Chunk +from graphgen.utils import logger @dataclass @@ -13,35 +15,121 @@ class BaseSplitter(ABC): """ chunk_size: int = 1024 - chunk_overlap_size: int = 100 + chunk_overlap: int = 100 length_function: Callable[[str], int] = len keep_separator: bool = False add_start_index: bool = False + strip_whitespace: bool = True @abstractmethod - def split_text(self, text: str) -> List[Dict[str, Any]]: + def split_text(self, text: str) -> List[str]: """ Split the input text into smaller chunks. :param text: The input text to be split. - :return: A list of dictionaries, each containing a chunk of text and optionally its start index. + :return: A list of text chunks. """ def create_chunks( self, texts: List[str], metadatas: Optional[List[dict]] = None ) -> List[Chunk]: - """ - Turn a list of texts into a list of Chunks, with optional metadata. - :param texts: - :param metadatas: - :return: - """ + """Create chunks from a list of texts.""" _metadatas = metadatas or [{}] * len(texts) chunks = [] for i, text in enumerate(texts): - chunks.append(Chunk(content=text, metadata=copy.deepcopy(_metadatas[i]))) + index = 0 + previous_chunk_len = 0 + for chunk in self.split_text(text): + metadata = copy.deepcopy(_metadatas[i]) + if self.add_start_index: + offset = index + previous_chunk_len - self.chunk_overlap + index = text.find(chunk, max(0, offset)) + metadata["start_index"] = index + previous_chunk_len = len(chunk) + new_chunk = Chunk(content=chunk, metadata=metadata) + chunks.append(new_chunk) + return chunks + + def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]: + text = separator.join(chunks) + if self.strip_whitespace: + text = text.strip() + if text == "": + return None + return text + + def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]: + # We now want to combine these smaller pieces into medium size chunks to send to the LLM. + separator_len = self.length_function(separator) + + chunks = [] + current_chunk: List[str] = [] + total = 0 + for d in splits: + _len = self.length_function(d) + if ( + total + _len + (separator_len if len(current_chunk) > 0 else 0) + > self.chunk_size + ): + if total > self.chunk_size: + logger.warning( + "Created a chunk of size %s, which is longer than the specified %s", + total, + self.chunk_size, + ) + if len(current_chunk) > 0: + chunk = self._join_chunks(current_chunk, separator) + if chunk is not None: + chunks.append(chunk) + # Keep on popping if: + # - we have a larger chunk than in the chunk overlap + # - or if we still have any chunks and the length is long + while total > self.chunk_overlap or ( + total + _len + (separator_len if len(current_chunk) > 0 else 0) + > self.chunk_size + and total > 0 + ): + total -= self.length_function(current_chunk[0]) + ( + separator_len if len(current_chunk) > 1 else 0 + ) + current_chunk = current_chunk[1:] + current_chunk.append(d) + total += _len + (separator_len if len(current_chunk) > 1 else 0) + chunk = self._join_chunks(current_chunk, separator) + if chunk is not None: + chunks.append(chunk) return chunks - def split(self, text: str, metadata: Optional[dict] = None) -> List[Chunk]: - texts = self.split_text(text) - return self.create_chunks(texts, [metadata] * len(texts) if metadata else None) + @staticmethod + def _split_text_with_regex( + text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]] + ) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = ( + ( + [ + _splits[i] + _splits[i + 1] + for i in range(0, len(_splits) - 1, 2) + ] + ) + if keep_separator == "end" + else ( + [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)] + ) + ) + if len(_splits) % 2 == 0: + splits += _splits[-1:] + splits = ( + (splits + [_splits[-1]]) + if keep_separator == "end" + else ([_splits[0]] + splits) + ) + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] diff --git a/graphgen/models/splitter/__init__.py b/graphgen/models/splitter/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/models/splitter/character_splitter.py b/graphgen/models/splitter/character_splitter.py new file mode 100644 index 00000000..1c91877e --- /dev/null +++ b/graphgen/models/splitter/character_splitter.py @@ -0,0 +1,26 @@ +import re +from typing import Any, List + +from graphgen.bases.base_splitter import BaseSplitter + + +class CharacterSplitter(BaseSplitter): + """Splitting text that looks at characters.""" + + def __init__( + self, separator: str = "\n\n", is_separator_regex: bool = False, **kwargs: Any + ) -> None: + """Create a new TextSplitter.""" + super().__init__(**kwargs) + self._separator = separator + self._is_separator_regex = is_separator_regex + + def split_text(self, text: str) -> List[str]: + """Split incoming text and return chunks.""" + # First we naively split the large input into a bunch of smaller ones. + separator = ( + self._separator if self._is_separator_regex else re.escape(self._separator) + ) + splits = self._split_text_with_regex(text, separator, self.keep_separator) + _separator = "" if self.keep_separator else self._separator + return self._merge_splits(splits, _separator) diff --git a/graphgen/models/splitter/markdown_splitter.py b/graphgen/models/splitter/markdown_splitter.py new file mode 100644 index 00000000..baa8620b --- /dev/null +++ b/graphgen/models/splitter/markdown_splitter.py @@ -0,0 +1,32 @@ +from typing import Any + +from graphgen.models.splitter.recursive_character_splitter import ( + RecursiveCharacterSplitter, +) + + +class MarkdownTextRefSplitter(RecursiveCharacterSplitter): + """Attempts to split the text along Markdown-formatted headings.""" + + def __init__(self, **kwargs: Any) -> None: + """Initialize a MarkdownTextRefSplitter.""" + separators = [ + # First, try to split along Markdown headings (starting with level 2) + "\n#{1,6} ", + # Note the alternative syntax for headings (below) is not handled here + # Heading level 2 + # --------------- + # End of code block + "```\n", + # Horizontal lines + "\n\\*\\*\\*+\n", + "\n---+\n", + "\n___+\n", + # Note that this splitter doesn't handle horizontal lines defined + # by *three or more* of ***, ---, or ___, but this is not handled + "\n\n", + "\n", + " ", + "", + ] + super().__init__(separators=separators, **kwargs) diff --git a/graphgen/models/splitter/recursive_character_splitter.py b/graphgen/models/splitter/recursive_character_splitter.py new file mode 100644 index 00000000..78f82449 --- /dev/null +++ b/graphgen/models/splitter/recursive_character_splitter.py @@ -0,0 +1,150 @@ +import re +from typing import Any, List, Optional + +from graphgen.bases.base_splitter import BaseSplitter + + +class RecursiveCharacterSplitter(BaseSplitter): + """Splitting text by recursively look at characters. + + Recursively tries to split by different characters to find one that works. + """ + + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = False, + **kwargs: Any, + ) -> None: + """Create a new TextSplitter.""" + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or ["\n\n", "\n", " ", ""] + self._is_separator_regex = is_separator_regex + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = self._split_text_with_regex(text, _separator, self.keep_separator) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self.keep_separator else separator + for s in splits: + if self.length_function(s) < self.chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return final_chunks + + def split_text(self, text: str) -> List[str]: + return self._split_text(text, self._separators) + + +class ChineseRecursiveTextSplitter(RecursiveCharacterSplitter): + def __init__( + self, + separators: Optional[List[str]] = None, + keep_separator: bool = True, + is_separator_regex: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(keep_separator=keep_separator, **kwargs) + self._separators = separators or [ + "\n\n", + "\n", + "。|!|?", + r"\.\s|\!\s|\?\s", + r";|;\s", + r",|,\s", + ] + self._is_separator_regex = is_separator_regex + + def _split_text_with_regex_from_end( + self, text: str, separator: str, keep_separator: bool + ) -> List[str]: + # Now that we have the separator, split the text + if separator: + if keep_separator: + # The parentheses in the pattern keep the delimiters in the result. + _splits = re.split(f"({separator})", text) + splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] + if len(_splits) % 2 == 1: + splits += _splits[-1:] + # splits = [_splits[0]] + splits + else: + splits = re.split(separator, text) + else: + splits = list(text) + return [s for s in splits if s != ""] + + def _split_text(self, text: str, separators: List[str]) -> List[str]: + """Split incoming text and return chunks.""" + final_chunks = [] + # Get appropriate separator to use + separator = separators[-1] + new_separators = [] + for i, _s in enumerate(separators): + _separator = _s if self._is_separator_regex else re.escape(_s) + if _s == "": + separator = _s + break + if re.search(_separator, text): + separator = _s + new_separators = separators[i + 1 :] + break + + _separator = separator if self._is_separator_regex else re.escape(separator) + splits = self._split_text_with_regex_from_end( + text, _separator, self.keep_separator + ) + + # Now go merging things, recursively splitting longer texts. + _good_splits = [] + _separator = "" if self.keep_separator else separator + for s in splits: + if self.length_function(s) < self.chunk_size: + _good_splits.append(s) + else: + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + _good_splits = [] + if not new_separators: + final_chunks.append(s) + else: + other_info = self._split_text(s, new_separators) + final_chunks.extend(other_info) + if _good_splits: + merged_text = self._merge_splits(_good_splits, _separator) + final_chunks.extend(merged_text) + return [ + re.sub(r"\n{2,}", "\n", chunk.strip()) + for chunk in final_chunks + if chunk.strip() != "" + ] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/__init__.py b/tests/integration_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/models/splitter/test_character_splitter.py b/tests/integration_tests/models/splitter/test_character_splitter.py new file mode 100644 index 00000000..5d41547d --- /dev/null +++ b/tests/integration_tests/models/splitter/test_character_splitter.py @@ -0,0 +1,30 @@ +import pytest + +from graphgen.models.splitter.character_splitter import CharacterSplitter + + +@pytest.mark.parametrize( + "text,chunk_size,chunk_overlap,expected", + [ + ( + "This is a test.\n\nThis is only a test.\n\nIn the event of an actual emergency...", + 25, + 5, + [ + "This is a test.", + "This is only a test.", + "In the event of an actual emergency...", + ], + ), + ], +) +def test_character_splitter(text, chunk_size, chunk_overlap, expected): + splitter = CharacterSplitter( + separator="\n\n", + is_separator_regex=False, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + keep_separator=False, + ) + chunks = splitter.split_text(text) + assert chunks == expected diff --git a/tests/integration_tests/models/splitter/test_markdown_splitter.py b/tests/integration_tests/models/splitter/test_markdown_splitter.py new file mode 100644 index 00000000..d2f50ad7 --- /dev/null +++ b/tests/integration_tests/models/splitter/test_markdown_splitter.py @@ -0,0 +1,40 @@ +from graphgen.models.splitter.markdown_splitter import MarkdownTextRefSplitter + + +def test_split_markdown_structures(): + md = ( + "# Header1\n\n" + "Some introduction here.\n\n" + "## Header2\n\n" + "```python\nprint('hello')\n```\n" + "Paragraph under code block.\n\n" + "***\n" + "### Header3\n\n" + "More text after horizontal rule.\n\n" + "#### Header4\n\n" + "Final paragraph." + ) + + splitter = MarkdownTextRefSplitter( + chunk_size=120, + chunk_overlap=0, + keep_separator=True, + is_separator_regex=True, + ) + chunks = splitter.split_text(md) + assert len(chunks) > 1 + + for chk in chunks: + assert len(chk) <= 120 + + assert any("## Header2" in c for c in chunks) + assert any("***" in c for c in chunks) + assert any("```" in c for c in chunks) + + +def test_split_size_less_than_single_char(): + """极端情况:chunk_size 比任何单段都小,应仍能返回原文""" + short = "# A\n\nB" + splitter = MarkdownTextRefSplitter(chunk_size=1, chunk_overlap=0) + chunks = splitter.split_text(short) + assert "".join(chunks) == short diff --git a/tests/integration_tests/models/splitter/test_recursive_character_splitter.py b/tests/integration_tests/models/splitter/test_recursive_character_splitter.py new file mode 100644 index 00000000..7d104f3e --- /dev/null +++ b/tests/integration_tests/models/splitter/test_recursive_character_splitter.py @@ -0,0 +1,49 @@ +from graphgen.models.splitter.recursive_character_splitter import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, +) + + +def test_split_english_paragraph(): + text = ( + "Natural language processing (NLP) is a subfield of linguistics, computer science, " + "and artificial intelligence. It focuses on the interaction between computers and " + "humans through natural language. The ultimate objective of NLP is to read, decipher, " + "understand, and make sense of human languages in a manner that is valuable.\n\n" + "Most NLP techniques rely on machine learning." + ) + + splitter = RecursiveCharacterSplitter( + chunk_size=150, + chunk_overlap=0, + keep_separator=True, + is_separator_regex=False, + ) + chunks = splitter.split_text(text) + + assert len(chunks) > 1 + for chk in chunks: + assert len(chk) <= 150 + + +def test_split_chinese_with_punctuation(): + text = ( + "自然语言处理是人工智能的重要分支。它研究能实现人与计算机之间用自然语言" + "进行有效通信的各种理论和方法!融合语言学、计算机科学、数学于一体?" + "近年来,深度学习极大推动了NLP的发展;Transformer、BERT、GPT等模型层出不穷," + ",,,甚至出现了多模态大模型。\n\n" + "未来,NLP 将继续向通用人工智能迈进。" + ) + + splitter = ChineseRecursiveTextSplitter( + chunk_size=60, + chunk_overlap=0, + keep_separator=True, + is_separator_regex=True, + ) + chunks = splitter.split_text(text) + + assert len(chunks) > 1 + for chk in chunks: + assert len(chk) <= 60 + assert "\n\n\n" not in chk From b02307facbca74f78ea1f078a1dbf5add54b8331 Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Wed, 24 Sep 2025 16:17:23 +0800 Subject: [PATCH 4/9] Update graphgen/bases/datatypes.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- graphgen/bases/datatypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/graphgen/bases/datatypes.py b/graphgen/bases/datatypes.py index fd7bc177..4cdc9d2a 100644 --- a/graphgen/bases/datatypes.py +++ b/graphgen/bases/datatypes.py @@ -1,11 +1,11 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field @dataclass class Chunk: id: str content: str - metadata: dict + metadata: dict = field(default_factory=dict) @dataclass From 86e9082935cb4667279e2278143f5ebe33284d9b Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Wed, 24 Sep 2025 16:17:44 +0800 Subject: [PATCH 5/9] Update tests/integration_tests/models/splitter/test_markdown_splitter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../integration_tests/models/splitter/test_markdown_splitter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests/models/splitter/test_markdown_splitter.py b/tests/integration_tests/models/splitter/test_markdown_splitter.py index d2f50ad7..8d02e1b2 100644 --- a/tests/integration_tests/models/splitter/test_markdown_splitter.py +++ b/tests/integration_tests/models/splitter/test_markdown_splitter.py @@ -33,7 +33,7 @@ def test_split_markdown_structures(): def test_split_size_less_than_single_char(): - """极端情况:chunk_size 比任何单段都小,应仍能返回原文""" + """Edge case: chunk_size is smaller than any segment; should still return the original text.""" short = "# A\n\nB" splitter = MarkdownTextRefSplitter(chunk_size=1, chunk_overlap=0) chunks = splitter.split_text(short) From 797781d902ba39a18ce40179cd5f6574140ddd8e Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Wed, 24 Sep 2025 16:17:56 +0800 Subject: [PATCH 6/9] Update graphgen/models/splitter/recursive_character_splitter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- graphgen/models/splitter/recursive_character_splitter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/graphgen/models/splitter/recursive_character_splitter.py b/graphgen/models/splitter/recursive_character_splitter.py index 78f82449..c9d7c543 100644 --- a/graphgen/models/splitter/recursive_character_splitter.py +++ b/graphgen/models/splitter/recursive_character_splitter.py @@ -96,7 +96,6 @@ def _split_text_with_regex_from_end( splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])] if len(_splits) % 2 == 1: splits += _splits[-1:] - # splits = [_splits[0]] + splits else: splits = re.split(separator, text) else: From 6a6cb34ec68465ca0fab6fa2af79165596906664 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Wed, 24 Sep 2025 17:40:06 +0800 Subject: [PATCH 7/9] feat(webui): update webui with splitter config --- graphgen/configs/__init__.py | 0 graphgen/configs/aggregated_config.yaml | 6 +++- graphgen/configs/atomic_config.yaml | 6 +++- graphgen/configs/cot_config.yaml | 6 +++- graphgen/configs/multi_hop_config.yaml | 6 +++- graphgen/graphgen.py | 27 ++++++++++------- graphgen/models/__init__.py | 28 +----------------- graphgen/models/splitter/__init__.py | 31 ++++++++++++++++++++ webui/app.py | 39 +++++++++++++++++-------- webui/base.py | 3 +- 10 files changed, 97 insertions(+), 55 deletions(-) create mode 100644 graphgen/configs/__init__.py diff --git a/graphgen/configs/__init__.py b/graphgen/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/graphgen/configs/aggregated_config.yaml b/graphgen/configs/aggregated_config.yaml index a65cf2ac..bb444623 100644 --- a/graphgen/configs/aggregated_config.yaml +++ b/graphgen/configs/aggregated_config.yaml @@ -1,4 +1,8 @@ -input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples +read: + input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples +split: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting output_data_type: aggregated # atomic, aggregated, multi_hop, cot output_data_format: ChatML # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/configs/atomic_config.yaml b/graphgen/configs/atomic_config.yaml index 4e8c4e29..009fcc67 100644 --- a/graphgen/configs/atomic_config.yaml +++ b/graphgen/configs/atomic_config.yaml @@ -1,4 +1,8 @@ -input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples +read: + input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples +split: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting output_data_type: atomic # atomic, aggregated, multi_hop, cot output_data_format: Alpaca # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/configs/cot_config.yaml b/graphgen/configs/cot_config.yaml index 6aa6bf52..e4b0db38 100644 --- a/graphgen/configs/cot_config.yaml +++ b/graphgen/configs/cot_config.yaml @@ -1,4 +1,8 @@ -input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples +read: + input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples +split: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting output_data_type: cot # atomic, aggregated, multi_hop, cot output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/configs/multi_hop_config.yaml b/graphgen/configs/multi_hop_config.yaml index 02e5e787..7a1f52d5 100644 --- a/graphgen/configs/multi_hop_config.yaml +++ b/graphgen/configs/multi_hop_config.yaml @@ -1,4 +1,8 @@ -input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples +read: + input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples +split: + chunk_size: 1024 # chunk size for text splitting + chunk_overlap: 100 # chunk overlap for text splitting output_data_type: multi_hop # atomic, aggregated, multi_hop, cot output_data_format: ChatML # Alpaca, Sharegpt, ChatML tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index 54f93b2b..ff05cd83 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -17,6 +17,7 @@ Tokenizer, TraverseStrategy, read_file, + split_chunks, ) from .operators import ( @@ -32,6 +33,7 @@ from .utils import ( compute_content_hash, create_event_loop, + detect_main_language, format_generation_results, logger, ) @@ -50,11 +52,6 @@ class GraphGen: synthesizer_llm_client: OpenAIModel = None trainee_llm_client: OpenAIModel = None - # text chunking - # TODO: make it configurable - chunk_size: int = 1024 - chunk_overlap_size: int = 100 - # search search_config: dict = field( default_factory=lambda: {"enabled": False, "search_types": ["wikipedia"]} @@ -136,14 +133,22 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict: async for doc_key, doc in tqdm_async( new_docs.items(), desc="[1/4]Chunking documents", unit="doc" ): + doc_language = detect_main_language(doc["content"]) + text_chunks = split_chunks( + doc["content"], + language=doc_language, + chunk_size=self.config["split"]["chunk_size"], + chunk_overlap=self.config["split"]["chunk_overlap"], + ) + chunks = { - compute_content_hash(dp["content"], prefix="chunk-"): { - **dp, + compute_content_hash(txt, prefix="chunk-"): { + "content": txt, "full_doc_id": doc_key, + "length": len(self.tokenizer_instance.encode_string(txt)), + "language": "en", } - for dp in self.tokenizer_instance.chunk_by_token_size( - doc["content"], self.chunk_overlap_size, self.chunk_size - ) + for txt in text_chunks } inserting_chunks.update(chunks) @@ -171,7 +176,7 @@ async def async_insert(self): insert chunks into the graph """ - input_file = self.config["input_file"] + input_file = self.config["read"]["input_file"] data = read_file(input_file) inserting_chunks = await self.async_split_chunks(data) diff --git a/graphgen/models/__init__.py b/graphgen/models/__init__.py index d555fcc9..9650909a 100644 --- a/graphgen/models/__init__.py +++ b/graphgen/models/__init__.py @@ -11,33 +11,7 @@ from .search.kg.wiki_search import WikiSearch from .search.web.bing_search import BingSearch from .search.web.google_search import GoogleSearch +from .splitter import split_chunks from .storage.json_storage import JsonKVStorage, JsonListStorage from .storage.networkx_storage import NetworkXStorage from .strategy.travserse_strategy import TraverseStrategy - -__all__ = [ - # llm models - "OpenAIModel", - "TopkTokenModel", - "Token", - "Tokenizer", - # storage models - "NetworkXStorage", - "JsonKVStorage", - "JsonListStorage", - # search models - "WikiSearch", - "GoogleSearch", - "BingSearch", - "UniProtSearch", - # evaluate models - "LengthEvaluator", - "MTLDEvaluator", - "RewardEvaluator", - "UniEvaluator", - # strategy models - "TraverseStrategy", - # community models - "CommunityDetector", - "read_file", -] diff --git a/graphgen/models/splitter/__init__.py b/graphgen/models/splitter/__init__.py index e69de29b..0743654a 100644 --- a/graphgen/models/splitter/__init__.py +++ b/graphgen/models/splitter/__init__.py @@ -0,0 +1,31 @@ +from functools import lru_cache +from typing import Union + +from .recursive_character_splitter import ( + ChineseRecursiveTextSplitter, + RecursiveCharacterSplitter, +) + +_MAPPING = { + "en": RecursiveCharacterSplitter, + "zh": ChineseRecursiveTextSplitter, +} + +SplitterT = Union[RecursiveCharacterSplitter, ChineseRecursiveTextSplitter] + + +@lru_cache(maxsize=None) +def _get_splitter(language: str, frozen_kwargs: frozenset) -> SplitterT: + cls = _MAPPING[language] + kwargs = dict(frozen_kwargs) + return cls(**kwargs) + + +def split_chunks(text: str, language: str = "en", **kwargs) -> list: + if language not in _MAPPING: + raise ValueError( + f"Unsupported language: {language}. " + f"Supported languages are: {list(_MAPPING.keys())}" + ) + splitter = _get_splitter(language, frozenset(kwargs.items())) + return splitter.split_text(text) diff --git a/webui/app.py b/webui/app.py index 6822405f..07f239f6 100644 --- a/webui/app.py +++ b/webui/app.py @@ -12,7 +12,7 @@ from graphgen.models import OpenAIModel, Tokenizer from graphgen.models.llm.limitter import RPM, TPM from graphgen.utils import set_logger -from webui.base import GraphGenParams +from webui.base import WebuiParams from webui.cache_utils import cleanup_workspace, setup_workspace from webui.count_tokens import count_tokens from webui.i18n import Translate @@ -66,13 +66,19 @@ def init_graph_gen(config: dict, env: dict) -> GraphGen: # pylint: disable=too-many-statements -def run_graphgen(params, progress=gr.Progress()): +def run_graphgen(params: WebuiParams, progress=gr.Progress()): def sum_tokens(client): return sum(u["total_tokens"] for u in client.token_usage) config = { "if_trainee_model": params.if_trainee_model, - "input_file": params.input_file, + "read": { + "input_file": params.input_file, + }, + "split": { + "chunk_size": params.chunk_size, + "chunk_overlap": params.chunk_overlap, + }, "output_data_type": params.output_data_type, "output_data_format": params.output_data_format, "tokenizer": params.tokenizer, @@ -91,7 +97,6 @@ def sum_tokens(client): "isolated_node_strategy": params.isolated_node_strategy, "loss_strategy": params.loss_strategy, }, - "chunk_size": params.chunk_size, } env = { @@ -284,10 +289,18 @@ def sum_tokens(client): label="Chunk Size", minimum=256, maximum=4096, - value=512, + value=1024, step=256, interactive=True, ) + chunk_overlap = gr.Slider( + label="Chunk Overlap", + minimum=0, + maximum=500, + value=100, + step=100, + interactive=True, + ) tokenizer = gr.Textbox( label="Tokenizer", value="cl100k_base", interactive=True ) @@ -499,7 +512,7 @@ def sum_tokens(client): submit_btn.click( lambda *args: run_graphgen( - GraphGenParams( + WebuiParams( if_trainee_model=args[0], input_file=args[1], tokenizer=args[2], @@ -518,12 +531,13 @@ def sum_tokens(client): trainee_model=args[15], api_key=args[16], chunk_size=args[17], - rpm=args[18], - tpm=args[19], - quiz_samples=args[20], - trainee_url=args[21], - trainee_api_key=args[22], - token_counter=args[23], + chunk_overlap=args[18], + rpm=args[19], + tpm=args[20], + quiz_samples=args[21], + trainee_url=args[22], + trainee_api_key=args[23], + token_counter=args[24], ) ), inputs=[ @@ -545,6 +559,7 @@ def sum_tokens(client): trainee_model, api_key, chunk_size, + chunk_overlap, rpm, tpm, quiz_samples, diff --git a/webui/base.py b/webui/base.py index f87d7d9b..95fd7c22 100644 --- a/webui/base.py +++ b/webui/base.py @@ -3,7 +3,7 @@ @dataclass -class GraphGenParams: +class WebuiParams: """ GraphGen parameters """ @@ -26,6 +26,7 @@ class GraphGenParams: trainee_model: str api_key: str chunk_size: int + chunk_overlap: int rpm: int tpm: int quiz_samples: int From d4392620355a66bbb3afcd07cb21313050a1cbff Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:49:33 +0800 Subject: [PATCH 8/9] Update graphgen/models/splitter/markdown_splitter.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- graphgen/models/splitter/markdown_splitter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/graphgen/models/splitter/markdown_splitter.py b/graphgen/models/splitter/markdown_splitter.py index baa8620b..03def6ae 100644 --- a/graphgen/models/splitter/markdown_splitter.py +++ b/graphgen/models/splitter/markdown_splitter.py @@ -22,8 +22,9 @@ def __init__(self, **kwargs: Any) -> None: "\n\\*\\*\\*+\n", "\n---+\n", "\n___+\n", - # Note that this splitter doesn't handle horizontal lines defined - # by *three or more* of ***, ---, or ___, but this is not handled + # Note: horizontal lines defined by three or more of ***, ---, or ___ + # are handled by the regexes above, but alternative syntaxes (e.g., with spaces) + # are not handled. "\n\n", "\n", " ", From fdaef0e336a545013ca580ba4c962f2aef918d4c Mon Sep 17 00:00:00 2001 From: chenzihong <58508660+ChenZiHong-Gavin@users.noreply.github.com> Date: Wed, 24 Sep 2025 17:49:59 +0800 Subject: [PATCH 9/9] Update graphgen/graphgen.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- graphgen/graphgen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphgen/graphgen.py b/graphgen/graphgen.py index ff05cd83..44d530bc 100644 --- a/graphgen/graphgen.py +++ b/graphgen/graphgen.py @@ -146,7 +146,7 @@ async def async_split_chunks(self, data: List[Union[List, Dict]]) -> dict: "content": txt, "full_doc_id": doc_key, "length": len(self.tokenizer_instance.encode_string(txt)), - "language": "en", + "language": doc_language, } for txt in text_chunks }