From b2184489eea6550be469c0bf4bf42a920fa58a70 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Oct 2025 07:40:49 +0000 Subject: [PATCH 1/9] Initial plan From 35fbbfe6ea4649600ccc492b84ef8e68253f1e38 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Oct 2025 07:49:43 +0000 Subject: [PATCH 2/9] Remove @dataclass from ABC base classes and fix search classes Co-authored-by: ChenZiHong-Gavin <58508660+ChenZiHong-Gavin@users.noreply.github.com> --- graphgen/bases/base_generator.py | 5 ++- graphgen/bases/base_kg_builder.py | 12 +++---- graphgen/bases/base_partitioner.py | 2 -- graphgen/bases/base_splitter.py | 23 ++++++++----- graphgen/bases/base_storage.py | 10 ++---- graphgen/bases/base_tokenizer.py | 5 ++- graphgen/models/evaluator/base_evaluator.py | 7 ++-- graphgen/models/llm/topk_token_model.py | 36 ++++++++++++++------- graphgen/models/search/db/uniprot_search.py | 3 -- graphgen/models/search/kg/wiki_search.py | 2 -- graphgen/models/search/web/bing_search.py | 6 ++-- graphgen/models/search/web/google_search.py | 3 -- 12 files changed, 55 insertions(+), 59 deletions(-) diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py index 271f20f..d6148ce 100644 --- a/graphgen/bases/base_generator.py +++ b/graphgen/bases/base_generator.py @@ -1,17 +1,16 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any from graphgen.bases.base_llm_client import BaseLLMClient -@dataclass class BaseGenerator(ABC): """ Generate QAs based on given prompts. """ - llm_client: BaseLLMClient + def __init__(self, llm_client: BaseLLMClient): + self.llm_client = llm_client @staticmethod @abstractmethod diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py index af15486..e234d8d 100644 --- a/graphgen/bases/base_kg_builder.py +++ b/graphgen/bases/base_kg_builder.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import defaultdict -from dataclasses import dataclass, field from typing import Dict, List, Tuple from graphgen.bases.base_llm_client import BaseLLMClient @@ -8,14 +7,11 @@ from graphgen.bases.datatypes import Chunk -@dataclass class BaseKGBuilder(ABC): - llm_client: BaseLLMClient - - _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list)) - _edges: Dict[Tuple[str, str], List[dict]] = field( - default_factory=lambda: defaultdict(list) - ) + def __init__(self, llm_client: BaseLLMClient): + self.llm_client = llm_client + self._nodes: Dict[str, List[dict]] = defaultdict(list) + self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list) @abstractmethod async def extract( diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py index a3739e5..78baddd 100644 --- a/graphgen/bases/base_partitioner.py +++ b/graphgen/bases/base_partitioner.py @@ -1,12 +1,10 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any, List from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Community -@dataclass class BasePartitioner(ABC): @abstractmethod async def partition( diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index dcec006..b2d1ad3 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -1,25 +1,32 @@ import copy import re from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Callable, Iterable, List, Literal, Optional, Union from graphgen.bases.datatypes import Chunk from graphgen.utils import logger -@dataclass class BaseSplitter(ABC): """ Abstract base class for splitting text into smaller chunks. """ - chunk_size: int = 1024 - chunk_overlap: int = 100 - length_function: Callable[[str], int] = len - keep_separator: bool = False - add_start_index: bool = False - strip_whitespace: bool = True + def __init__( + self, + chunk_size: int = 1024, + chunk_overlap: int = 100, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.length_function = length_function + self.keep_separator = keep_separator + self.add_start_index = add_start_index + self.strip_whitespace = strip_whitespace @abstractmethod def split_text(self, text: str) -> List[str]: diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index 6968dca..2a2e0b1 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -1,13 +1,12 @@ -from dataclasses import dataclass from typing import Generic, TypeVar, Union T = TypeVar("T") -@dataclass class StorageNameSpace: - working_dir: str = None - namespace: str = None + def __init__(self, working_dir: str = None, namespace: str = None): + self.working_dir = working_dir + self.namespace = namespace async def index_done_callback(self): """commit the storage operations after indexing""" @@ -16,7 +15,6 @@ async def query_done_callback(self): """commit the storage operations after querying""" -@dataclass class BaseListStorage(Generic[T], StorageNameSpace): async def all_items(self) -> list[T]: raise NotImplementedError @@ -34,7 +32,6 @@ async def drop(self): raise NotImplementedError -@dataclass class BaseKVStorage(Generic[T], StorageNameSpace): async def all_keys(self) -> list[str]: raise NotImplementedError @@ -58,7 +55,6 @@ async def drop(self): raise NotImplementedError -@dataclass class BaseGraphStorage(StorageNameSpace): async def has_node(self, node_id: str) -> bool: raise NotImplementedError diff --git a/graphgen/bases/base_tokenizer.py b/graphgen/bases/base_tokenizer.py index 958b142..346d500 100644 --- a/graphgen/bases/base_tokenizer.py +++ b/graphgen/bases/base_tokenizer.py @@ -1,13 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import List -@dataclass class BaseTokenizer(ABC): - model_name: str = "cl100k_base" + def __init__(self, model_name: str = "cl100k_base"): + self.model_name = model_name @abstractmethod def encode(self, text: str) -> List[int]: diff --git a/graphgen/models/evaluator/base_evaluator.py b/graphgen/models/evaluator/base_evaluator.py index 1359be6..e24cfa4 100644 --- a/graphgen/models/evaluator/base_evaluator.py +++ b/graphgen/models/evaluator/base_evaluator.py @@ -1,5 +1,4 @@ import asyncio -from dataclasses import dataclass from tqdm.asyncio import tqdm as tqdm_async @@ -7,10 +6,10 @@ from graphgen.utils import create_event_loop -@dataclass class BaseEvaluator: - max_concurrent: int = 100 - results: list[float] = None + def __init__(self, max_concurrent: int = 100): + self.max_concurrent = max_concurrent + self.results: list[float] = None def evaluate(self, pairs: list[QAPair]) -> list[float]: """ diff --git a/graphgen/models/llm/topk_token_model.py b/graphgen/models/llm/topk_token_model.py index 94719cf..e93ca01 100644 --- a/graphgen/models/llm/topk_token_model.py +++ b/graphgen/models/llm/topk_token_model.py @@ -1,21 +1,31 @@ -from dataclasses import dataclass +from abc import ABC, abstractmethod from typing import List, Optional from graphgen.bases import Token -@dataclass -class TopkTokenModel: - do_sample: bool = False - temperature: float = 0 - max_tokens: int = 4096 - repetition_penalty: float = 1.05 - num_beams: int = 1 - topk: int = 50 - topp: float = 0.95 - - topk_per_token: int = 5 # number of topk tokens to generate for each token +class TopkTokenModel(ABC): + def __init__( + self, + do_sample: bool = False, + temperature: float = 0, + max_tokens: int = 4096, + repetition_penalty: float = 1.05, + num_beams: int = 1, + topk: int = 50, + topp: float = 0.95, + topk_per_token: int = 5, + ): + self.do_sample = do_sample + self.temperature = temperature + self.max_tokens = max_tokens + self.repetition_penalty = repetition_penalty + self.num_beams = num_beams + self.topk = topk + self.topp = topp + self.topk_per_token = topk_per_token + @abstractmethod async def generate_topk_per_token(self, text: str) -> List[Token]: """ Generate prob, text and candidates for each token of the model's output. @@ -23,6 +33,7 @@ async def generate_topk_per_token(self, text: str) -> List[Token]: """ raise NotImplementedError + @abstractmethod async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None ) -> List[Token]: @@ -32,6 +43,7 @@ async def generate_inputs_prob( """ raise NotImplementedError + @abstractmethod async def generate_answer( self, text: str, history: Optional[List[str]] = None ) -> str: diff --git a/graphgen/models/search/db/uniprot_search.py b/graphgen/models/search/db/uniprot_search.py index 96bdd99..daf4224 100644 --- a/graphgen/models/search/db/uniprot_search.py +++ b/graphgen/models/search/db/uniprot_search.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - import requests from fastapi import HTTPException @@ -8,7 +6,6 @@ UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search" -@dataclass class UniProtSearch: """ UniProt Search client to search with UniProt. diff --git a/graphgen/models/search/kg/wiki_search.py b/graphgen/models/search/kg/wiki_search.py index e9513f2..2d8686c 100644 --- a/graphgen/models/search/kg/wiki_search.py +++ b/graphgen/models/search/kg/wiki_search.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import List, Union import wikipedia @@ -7,7 +6,6 @@ from graphgen.utils import detect_main_language, logger -@dataclass class WikiSearch: @staticmethod def set_language(language: str): diff --git a/graphgen/models/search/web/bing_search.py b/graphgen/models/search/web/bing_search.py index a769ba7..d52815d 100644 --- a/graphgen/models/search/web/bing_search.py +++ b/graphgen/models/search/web/bing_search.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - import requests from fastapi import HTTPException @@ -9,13 +7,13 @@ BING_MKT = "en-US" -@dataclass class BingSearch: """ Bing Search client to search with Bing. """ - subscription_key: str + def __init__(self, subscription_key: str): + self.subscription_key = subscription_key def search(self, query: str, num_results: int = 1): """ diff --git a/graphgen/models/search/web/google_search.py b/graphgen/models/search/web/google_search.py index 1abfcdf..0b04572 100644 --- a/graphgen/models/search/web/google_search.py +++ b/graphgen/models/search/web/google_search.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - import requests from fastapi import HTTPException @@ -8,7 +6,6 @@ GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1" -@dataclass class GoogleSearch: def __init__(self, subscription_key: str, cx: str): """ From 5202d73def72d35d0b474782c23346a84127a186 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 18 Oct 2025 07:52:45 +0000 Subject: [PATCH 3/9] Fix storage implementations to declare parent fields Co-authored-by: ChenZiHong-Gavin <58508660+ChenZiHong-Gavin@users.noreply.github.com> --- graphgen/models/storage/json_storage.py | 4 ++++ graphgen/models/storage/networkx_storage.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index e37d033..d305ed2 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -7,6 +7,8 @@ @dataclass class JsonKVStorage(BaseKVStorage): + working_dir: str = None + namespace: str = None _data: dict[str, str] = None def __post_init__(self): @@ -53,6 +55,8 @@ async def drop(self): @dataclass class JsonListStorage(BaseListStorage): + working_dir: str = None + namespace: str = None _data: list = None def __post_init__(self): diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py index 539ab84..824c72c 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/networkx_storage.py @@ -11,6 +11,9 @@ @dataclass class NetworkXStorage(BaseGraphStorage): + working_dir: str = None + namespace: str = None + @staticmethod def load_nx_graph(file_name) -> Optional[nx.Graph]: if os.path.exists(file_name): From a5ed2f46c05e15664053d0177f275a6f2e553d72 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 20 Oct 2025 08:57:47 +0000 Subject: [PATCH 4/9] Add missing field declarations to dataclass child classes Co-authored-by: ChenZiHong-Gavin <58508660+ChenZiHong-Gavin@users.noreply.github.com> --- graphgen/models/generator/aggregated_generator.py | 4 +++- graphgen/models/generator/atomic_generator.py | 4 +++- graphgen/models/generator/cot_generator.py | 4 +++- graphgen/models/generator/multi_hop_generator.py | 4 +++- graphgen/models/tokenizer/hf_tokenizer.py | 2 ++ graphgen/models/tokenizer/tiktoken_tokenizer.py | 2 ++ 6 files changed, 16 insertions(+), 4 deletions(-) diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py index 37c54c7..8bb1236 100644 --- a/graphgen/models/generator/aggregated_generator.py +++ b/graphgen/models/generator/aggregated_generator.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator +from graphgen.bases import BaseGenerator, BaseLLMClient from graphgen.templates import AGGREGATED_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger @@ -15,6 +15,8 @@ class AggregatedGenerator(BaseGenerator): 2. question generation: Generate relevant questions based on the rephrased text. """ + llm_client: BaseLLMClient = None + @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py index cb566fd..5017a07 100644 --- a/graphgen/models/generator/atomic_generator.py +++ b/graphgen/models/generator/atomic_generator.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator +from graphgen.bases import BaseGenerator, BaseLLMClient from graphgen.templates import ATOMIC_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger @dataclass class AtomicGenerator(BaseGenerator): + llm_client: BaseLLMClient = None + @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py index 2fc4fe8..a4aed4c 100644 --- a/graphgen/models/generator/cot_generator.py +++ b/graphgen/models/generator/cot_generator.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator +from graphgen.bases import BaseGenerator, BaseLLMClient from graphgen.templates import COT_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger @dataclass class CoTGenerator(BaseGenerator): + llm_client: BaseLLMClient = None + @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py index 257fc1d..b938909 100644 --- a/graphgen/models/generator/multi_hop_generator.py +++ b/graphgen/models/generator/multi_hop_generator.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator +from graphgen.bases import BaseGenerator, BaseLLMClient from graphgen.templates import MULTI_HOP_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger @dataclass class MultiHopGenerator(BaseGenerator): + llm_client: BaseLLMClient = None + @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/tokenizer/hf_tokenizer.py b/graphgen/models/tokenizer/hf_tokenizer.py index e5511a9..be3af5e 100644 --- a/graphgen/models/tokenizer/hf_tokenizer.py +++ b/graphgen/models/tokenizer/hf_tokenizer.py @@ -8,6 +8,8 @@ @dataclass class HFTokenizer(BaseTokenizer): + model_name: str = "cl100k_base" + def __post_init__(self): self.enc = AutoTokenizer.from_pretrained(self.model_name) diff --git a/graphgen/models/tokenizer/tiktoken_tokenizer.py b/graphgen/models/tokenizer/tiktoken_tokenizer.py index 3c84edd..3fe49bc 100644 --- a/graphgen/models/tokenizer/tiktoken_tokenizer.py +++ b/graphgen/models/tokenizer/tiktoken_tokenizer.py @@ -8,6 +8,8 @@ @dataclass class TiktokenTokenizer(BaseTokenizer): + model_name: str = "cl100k_base" + def __post_init__(self): self.enc = tiktoken.get_encoding(self.model_name) From c0df1439554d1354be22294c3a12df445f16ea77 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 23 Oct 2025 11:21:51 +0000 Subject: [PATCH 5/9] Fix dataclass decorator issues in new VQA and MM modules Co-authored-by: ChenZiHong-Gavin <58508660+ChenZiHong-Gavin@users.noreply.github.com> --- graphgen/models/generator/vqa_generator.py | 4 +++- graphgen/models/kg_builder/mm_kg_builder.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index b0c29d2..fb76276 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator +from graphgen.bases import BaseGenerator, BaseLLMClient from graphgen.templates import VQA_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger @dataclass class VQAGenerator(BaseGenerator): + llm_client: BaseLLMClient = None + @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/kg_builder/mm_kg_builder.py b/graphgen/models/kg_builder/mm_kg_builder.py index c554729..f633bfc 100644 --- a/graphgen/models/kg_builder/mm_kg_builder.py +++ b/graphgen/models/kg_builder/mm_kg_builder.py @@ -1,5 +1,6 @@ import re from collections import defaultdict +from dataclasses import dataclass from typing import Dict, List, Tuple from graphgen.bases import BaseLLMClient, Chunk @@ -15,8 +16,10 @@ from .light_rag_kg_builder import LightRAGKGBuilder +@dataclass class MMKGBuilder(LightRAGKGBuilder): llm_client: BaseLLMClient = None + max_loop: int = 3 async def extract( self, chunk: Chunk From e5954e98aa6c9cb6eea1ea7abf034b84d983b096 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 23 Oct 2025 19:35:29 +0800 Subject: [PATCH 6/9] fix: fix dataclass decorator for generators --- graphgen/models/generator/aggregated_generator.py | 6 +----- graphgen/models/generator/atomic_generator.py | 6 +----- graphgen/models/generator/cot_generator.py | 6 +----- graphgen/models/generator/multi_hop_generator.py | 6 +----- graphgen/models/generator/vqa_generator.py | 6 +----- 5 files changed, 5 insertions(+), 25 deletions(-) diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py index 2fd585a..4bad8e9 100644 --- a/graphgen/models/generator/aggregated_generator.py +++ b/graphgen/models/generator/aggregated_generator.py @@ -1,12 +1,10 @@ -from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator, BaseLLMClient +from graphgen.bases import BaseGenerator from graphgen.templates import AGGREGATED_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class AggregatedGenerator(BaseGenerator): """ Aggregated Generator follows a TWO-STEP process: @@ -15,8 +13,6 @@ class AggregatedGenerator(BaseGenerator): 2. question generation: Generate relevant questions based on the rephrased text. """ - llm_client: BaseLLMClient = None - @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py index e086926..713140d 100644 --- a/graphgen/models/generator/atomic_generator.py +++ b/graphgen/models/generator/atomic_generator.py @@ -1,15 +1,11 @@ -from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator, BaseLLMClient +from graphgen.bases import BaseGenerator from graphgen.templates import ATOMIC_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class AtomicGenerator(BaseGenerator): - llm_client: BaseLLMClient = None - @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py index 8b94010..a111a6f 100644 --- a/graphgen/models/generator/cot_generator.py +++ b/graphgen/models/generator/cot_generator.py @@ -1,15 +1,11 @@ -from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator, BaseLLMClient +from graphgen.bases import BaseGenerator from graphgen.templates import COT_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class CoTGenerator(BaseGenerator): - llm_client: BaseLLMClient = None - @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py index 04e16f0..9098b10 100644 --- a/graphgen/models/generator/multi_hop_generator.py +++ b/graphgen/models/generator/multi_hop_generator.py @@ -1,15 +1,11 @@ -from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator, BaseLLMClient +from graphgen.bases import BaseGenerator from graphgen.templates import MULTI_HOP_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class MultiHopGenerator(BaseGenerator): - llm_client: BaseLLMClient = None - @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index fb76276..eefbdd1 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -1,15 +1,11 @@ -from dataclasses import dataclass from typing import Any -from graphgen.bases import BaseGenerator, BaseLLMClient +from graphgen.bases import BaseGenerator from graphgen.templates import VQA_GENERATION_PROMPT from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class VQAGenerator(BaseGenerator): - llm_client: BaseLLMClient = None - @staticmethod def build_prompt( batch: tuple[list[tuple[str, dict]], list[tuple[Any, Any, dict]]] From 0ed7f4925805b017dcbf4a3870dd257b44b9f9ba Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Thu, 23 Oct 2025 11:46:23 +0000 Subject: [PATCH 7/9] Remove @dataclass from all subclasses following generator pattern Co-authored-by: ChenZiHong-Gavin <58508660+ChenZiHong-Gavin@users.noreply.github.com> --- graphgen/models/evaluator/length_evaluator.py | 9 +++------ graphgen/models/evaluator/mtld_evaluator.py | 12 ++++-------- graphgen/models/kg_builder/light_rag_kg_builder.py | 7 +++---- graphgen/models/kg_builder/mm_kg_builder.py | 5 ----- graphgen/models/partitioner/bfs_partitioner.py | 2 -- graphgen/models/partitioner/dfs_partitioner.py | 2 -- graphgen/models/partitioner/ece_partitioner.py | 2 -- graphgen/models/partitioner/leiden_partitioner.py | 2 -- graphgen/models/tokenizer/__init__.py | 8 ++------ graphgen/models/tokenizer/hf_tokenizer.py | 7 ++----- graphgen/models/tokenizer/tiktoken_tokenizer.py | 7 ++----- 11 files changed, 16 insertions(+), 47 deletions(-) diff --git a/graphgen/models/evaluator/length_evaluator.py b/graphgen/models/evaluator/length_evaluator.py index a7e9989..d5c3321 100644 --- a/graphgen/models/evaluator/length_evaluator.py +++ b/graphgen/models/evaluator/length_evaluator.py @@ -1,16 +1,13 @@ -from dataclasses import dataclass - from graphgen.bases.datatypes import QAPair from graphgen.models.evaluator.base_evaluator import BaseEvaluator from graphgen.models.tokenizer import Tokenizer from graphgen.utils import create_event_loop -@dataclass class LengthEvaluator(BaseEvaluator): - tokenizer_name: str = "cl100k_base" - - def __post_init__(self): + def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100): + super().__init__(max_concurrent) + self.tokenizer_name = tokenizer_name self.tokenizer = Tokenizer(model_name=self.tokenizer_name) async def evaluate_single(self, pair: QAPair) -> float: diff --git a/graphgen/models/evaluator/mtld_evaluator.py b/graphgen/models/evaluator/mtld_evaluator.py index 79924fe..c106d86 100644 --- a/graphgen/models/evaluator/mtld_evaluator.py +++ b/graphgen/models/evaluator/mtld_evaluator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field from typing import Set from graphgen.bases.datatypes import QAPair @@ -8,18 +7,15 @@ nltk_helper = NLTKHelper() -@dataclass class MTLDEvaluator(BaseEvaluator): """ 衡量文本词汇多样性的指标 """ - stopwords_en: Set[str] = field( - default_factory=lambda: set(nltk_helper.get_stopwords("english")) - ) - stopwords_zh: Set[str] = field( - default_factory=lambda: set(nltk_helper.get_stopwords("chinese")) - ) + def __init__(self, max_concurrent: int = 100): + super().__init__(max_concurrent) + self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english")) + self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese")) async def evaluate_single(self, pair: QAPair) -> float: loop = create_event_loop() diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py index e734eca..cde42d2 100644 --- a/graphgen/models/kg_builder/light_rag_kg_builder.py +++ b/graphgen/models/kg_builder/light_rag_kg_builder.py @@ -1,6 +1,5 @@ import re from collections import Counter, defaultdict -from dataclasses import dataclass from typing import Dict, List, Tuple from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk @@ -15,10 +14,10 @@ ) -@dataclass class LightRAGKGBuilder(BaseKGBuilder): - llm_client: BaseLLMClient = None - max_loop: int = 3 + def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3): + super().__init__(llm_client) + self.max_loop = max_loop async def extract( self, chunk: Chunk diff --git a/graphgen/models/kg_builder/mm_kg_builder.py b/graphgen/models/kg_builder/mm_kg_builder.py index f633bfc..0eeb394 100644 --- a/graphgen/models/kg_builder/mm_kg_builder.py +++ b/graphgen/models/kg_builder/mm_kg_builder.py @@ -1,6 +1,5 @@ import re from collections import defaultdict -from dataclasses import dataclass from typing import Dict, List, Tuple from graphgen.bases import BaseLLMClient, Chunk @@ -16,11 +15,7 @@ from .light_rag_kg_builder import LightRAGKGBuilder -@dataclass class MMKGBuilder(LightRAGKGBuilder): - llm_client: BaseLLMClient = None - max_loop: int = 3 - async def extract( self, chunk: Chunk ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]: diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py index 7cc5148..7b7b421 100644 --- a/graphgen/models/partitioner/bfs_partitioner.py +++ b/graphgen/models/partitioner/bfs_partitioner.py @@ -1,6 +1,5 @@ import random from collections import deque -from dataclasses import dataclass from typing import Any, List from graphgen.bases import BaseGraphStorage, BasePartitioner @@ -10,7 +9,6 @@ EDGE_UNIT: str = "e" -@dataclass class BFSPartitioner(BasePartitioner): """ BFS partitioner that partitions the graph into communities of a fixed size. diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py index a9a64a9..01df509 100644 --- a/graphgen/models/partitioner/dfs_partitioner.py +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ -1,5 +1,4 @@ import random -from dataclasses import dataclass from typing import Any, List from graphgen.bases import BaseGraphStorage, BasePartitioner @@ -9,7 +8,6 @@ EDGE_UNIT: str = "e" -@dataclass class DFSPartitioner(BasePartitioner): """ DFS partitioner that partitions the graph into communities of a fixed size. diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index 4352e62..e874f56 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -1,6 +1,5 @@ import asyncio import random -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple from tqdm.asyncio import tqdm as tqdm_async @@ -13,7 +12,6 @@ EDGE_UNIT: str = "e" -@dataclass class ECEPartitioner(BFSPartitioner): """ ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE). diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index ffa38ae..28dfc1d 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -1,5 +1,4 @@ from collections import defaultdict -from dataclasses import dataclass from typing import Any, Dict, List, Set, Tuple import igraph as ig @@ -9,7 +8,6 @@ from graphgen.bases.datatypes import Community -@dataclass class LeidenPartitioner(BasePartitioner): """ Leiden partitioner that partitions the graph into communities using the Leiden algorithm. diff --git a/graphgen/models/tokenizer/__init__.py b/graphgen/models/tokenizer/__init__.py index 1df5039..6712f91 100644 --- a/graphgen/models/tokenizer/__init__.py +++ b/graphgen/models/tokenizer/__init__.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field from typing import List from graphgen.bases import BaseTokenizer @@ -30,16 +29,13 @@ def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer: ) -@dataclass class Tokenizer(BaseTokenizer): """ Encapsulates different tokenization implementations based on the specified model name. """ - model_name: str = "cl100k_base" - _impl: BaseTokenizer = field(init=False, repr=False) - - def __post_init__(self): + def __init__(self, model_name: str = "cl100k_base"): + super().__init__(model_name) if not self.model_name: raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.") self._impl = get_tokenizer_impl(self.model_name) diff --git a/graphgen/models/tokenizer/hf_tokenizer.py b/graphgen/models/tokenizer/hf_tokenizer.py index be3af5e..c43ddd7 100644 --- a/graphgen/models/tokenizer/hf_tokenizer.py +++ b/graphgen/models/tokenizer/hf_tokenizer.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import List from transformers import AutoTokenizer @@ -6,11 +5,9 @@ from graphgen.bases import BaseTokenizer -@dataclass class HFTokenizer(BaseTokenizer): - model_name: str = "cl100k_base" - - def __post_init__(self): + def __init__(self, model_name: str = "cl100k_base"): + super().__init__(model_name) self.enc = AutoTokenizer.from_pretrained(self.model_name) def encode(self, text: str) -> List[int]: diff --git a/graphgen/models/tokenizer/tiktoken_tokenizer.py b/graphgen/models/tokenizer/tiktoken_tokenizer.py index 3fe49bc..6145d07 100644 --- a/graphgen/models/tokenizer/tiktoken_tokenizer.py +++ b/graphgen/models/tokenizer/tiktoken_tokenizer.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import List import tiktoken @@ -6,11 +5,9 @@ from graphgen.bases import BaseTokenizer -@dataclass class TiktokenTokenizer(BaseTokenizer): - model_name: str = "cl100k_base" - - def __post_init__(self): + def __init__(self, model_name: str = "cl100k_base"): + super().__init__(model_name) self.enc = tiktoken.get_encoding(self.model_name) def encode(self, text: str) -> List[int]: From e5dde68dbf9fbeab2bb7037ec04f28bff1b991d2 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 23 Oct 2025 20:25:16 +0800 Subject: [PATCH 8/9] fix: delete duplicate init --- graphgen/models/storage/json_storage.py | 2 -- graphgen/models/storage/networkx_storage.py | 3 --- 2 files changed, 5 deletions(-) diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index d305ed2..171eb98 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -7,8 +7,6 @@ @dataclass class JsonKVStorage(BaseKVStorage): - working_dir: str = None - namespace: str = None _data: dict[str, str] = None def __post_init__(self): diff --git a/graphgen/models/storage/networkx_storage.py b/graphgen/models/storage/networkx_storage.py index 824c72c..539ab84 100644 --- a/graphgen/models/storage/networkx_storage.py +++ b/graphgen/models/storage/networkx_storage.py @@ -11,9 +11,6 @@ @dataclass class NetworkXStorage(BaseGraphStorage): - working_dir: str = None - namespace: str = None - @staticmethod def load_nx_graph(file_name) -> Optional[nx.Graph]: if os.path.exists(file_name): From b2522268d9d8f84d0fbb1b7b985eeac1e6aeca39 Mon Sep 17 00:00:00 2001 From: chenzihong-gavin Date: Thu, 23 Oct 2025 20:32:27 +0800 Subject: [PATCH 9/9] fix: fix lint problem --- graphgen/bases/base_storage.py | 7 ++++--- graphgen/models/kg_builder/mm_kg_builder.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index 2a2e0b1..f82e6f6 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -1,12 +1,13 @@ +from dataclasses import dataclass from typing import Generic, TypeVar, Union T = TypeVar("T") +@dataclass class StorageNameSpace: - def __init__(self, working_dir: str = None, namespace: str = None): - self.working_dir = working_dir - self.namespace = namespace + working_dir: str = None + namespace: str = None async def index_done_callback(self): """commit the storage operations after indexing""" diff --git a/graphgen/models/kg_builder/mm_kg_builder.py b/graphgen/models/kg_builder/mm_kg_builder.py index 0eeb394..f352cb2 100644 --- a/graphgen/models/kg_builder/mm_kg_builder.py +++ b/graphgen/models/kg_builder/mm_kg_builder.py @@ -2,7 +2,7 @@ from collections import defaultdict from typing import Dict, List, Tuple -from graphgen.bases import BaseLLMClient, Chunk +from graphgen.bases import Chunk from graphgen.templates import MMKG_EXTRACTION_PROMPT from graphgen.utils import ( detect_main_language,