diff --git a/graphgen/bases/base_generator.py b/graphgen/bases/base_generator.py index 271f20f..d6148ce 100644 --- a/graphgen/bases/base_generator.py +++ b/graphgen/bases/base_generator.py @@ -1,17 +1,16 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any from graphgen.bases.base_llm_client import BaseLLMClient -@dataclass class BaseGenerator(ABC): """ Generate QAs based on given prompts. """ - llm_client: BaseLLMClient + def __init__(self, llm_client: BaseLLMClient): + self.llm_client = llm_client @staticmethod @abstractmethod diff --git a/graphgen/bases/base_kg_builder.py b/graphgen/bases/base_kg_builder.py index af15486..e234d8d 100644 --- a/graphgen/bases/base_kg_builder.py +++ b/graphgen/bases/base_kg_builder.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import defaultdict -from dataclasses import dataclass, field from typing import Dict, List, Tuple from graphgen.bases.base_llm_client import BaseLLMClient @@ -8,14 +7,11 @@ from graphgen.bases.datatypes import Chunk -@dataclass class BaseKGBuilder(ABC): - llm_client: BaseLLMClient - - _nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list)) - _edges: Dict[Tuple[str, str], List[dict]] = field( - default_factory=lambda: defaultdict(list) - ) + def __init__(self, llm_client: BaseLLMClient): + self.llm_client = llm_client + self._nodes: Dict[str, List[dict]] = defaultdict(list) + self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list) @abstractmethod async def extract( diff --git a/graphgen/bases/base_partitioner.py b/graphgen/bases/base_partitioner.py index a3739e5..78baddd 100644 --- a/graphgen/bases/base_partitioner.py +++ b/graphgen/bases/base_partitioner.py @@ -1,12 +1,10 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Any, List from graphgen.bases.base_storage import BaseGraphStorage from graphgen.bases.datatypes import Community -@dataclass class BasePartitioner(ABC): @abstractmethod async def partition( diff --git a/graphgen/bases/base_splitter.py b/graphgen/bases/base_splitter.py index dcec006..b2d1ad3 100644 --- a/graphgen/bases/base_splitter.py +++ b/graphgen/bases/base_splitter.py @@ -1,25 +1,32 @@ import copy import re from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import Callable, Iterable, List, Literal, Optional, Union from graphgen.bases.datatypes import Chunk from graphgen.utils import logger -@dataclass class BaseSplitter(ABC): """ Abstract base class for splitting text into smaller chunks. """ - chunk_size: int = 1024 - chunk_overlap: int = 100 - length_function: Callable[[str], int] = len - keep_separator: bool = False - add_start_index: bool = False - strip_whitespace: bool = True + def __init__( + self, + chunk_size: int = 1024, + chunk_overlap: int = 100, + length_function: Callable[[str], int] = len, + keep_separator: bool = False, + add_start_index: bool = False, + strip_whitespace: bool = True, + ): + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.length_function = length_function + self.keep_separator = keep_separator + self.add_start_index = add_start_index + self.strip_whitespace = strip_whitespace @abstractmethod def split_text(self, text: str) -> List[str]: diff --git a/graphgen/bases/base_storage.py b/graphgen/bases/base_storage.py index 6968dca..f82e6f6 100644 --- a/graphgen/bases/base_storage.py +++ b/graphgen/bases/base_storage.py @@ -16,7 +16,6 @@ async def query_done_callback(self): """commit the storage operations after querying""" -@dataclass class BaseListStorage(Generic[T], StorageNameSpace): async def all_items(self) -> list[T]: raise NotImplementedError @@ -34,7 +33,6 @@ async def drop(self): raise NotImplementedError -@dataclass class BaseKVStorage(Generic[T], StorageNameSpace): async def all_keys(self) -> list[str]: raise NotImplementedError @@ -58,7 +56,6 @@ async def drop(self): raise NotImplementedError -@dataclass class BaseGraphStorage(StorageNameSpace): async def has_node(self, node_id: str) -> bool: raise NotImplementedError diff --git a/graphgen/bases/base_tokenizer.py b/graphgen/bases/base_tokenizer.py index 958b142..346d500 100644 --- a/graphgen/bases/base_tokenizer.py +++ b/graphgen/bases/base_tokenizer.py @@ -1,13 +1,12 @@ from __future__ import annotations from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import List -@dataclass class BaseTokenizer(ABC): - model_name: str = "cl100k_base" + def __init__(self, model_name: str = "cl100k_base"): + self.model_name = model_name @abstractmethod def encode(self, text: str) -> List[int]: diff --git a/graphgen/models/evaluator/base_evaluator.py b/graphgen/models/evaluator/base_evaluator.py index 1359be6..e24cfa4 100644 --- a/graphgen/models/evaluator/base_evaluator.py +++ b/graphgen/models/evaluator/base_evaluator.py @@ -1,5 +1,4 @@ import asyncio -from dataclasses import dataclass from tqdm.asyncio import tqdm as tqdm_async @@ -7,10 +6,10 @@ from graphgen.utils import create_event_loop -@dataclass class BaseEvaluator: - max_concurrent: int = 100 - results: list[float] = None + def __init__(self, max_concurrent: int = 100): + self.max_concurrent = max_concurrent + self.results: list[float] = None def evaluate(self, pairs: list[QAPair]) -> list[float]: """ diff --git a/graphgen/models/evaluator/length_evaluator.py b/graphgen/models/evaluator/length_evaluator.py index a7e9989..d5c3321 100644 --- a/graphgen/models/evaluator/length_evaluator.py +++ b/graphgen/models/evaluator/length_evaluator.py @@ -1,16 +1,13 @@ -from dataclasses import dataclass - from graphgen.bases.datatypes import QAPair from graphgen.models.evaluator.base_evaluator import BaseEvaluator from graphgen.models.tokenizer import Tokenizer from graphgen.utils import create_event_loop -@dataclass class LengthEvaluator(BaseEvaluator): - tokenizer_name: str = "cl100k_base" - - def __post_init__(self): + def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100): + super().__init__(max_concurrent) + self.tokenizer_name = tokenizer_name self.tokenizer = Tokenizer(model_name=self.tokenizer_name) async def evaluate_single(self, pair: QAPair) -> float: diff --git a/graphgen/models/evaluator/mtld_evaluator.py b/graphgen/models/evaluator/mtld_evaluator.py index 79924fe..c106d86 100644 --- a/graphgen/models/evaluator/mtld_evaluator.py +++ b/graphgen/models/evaluator/mtld_evaluator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field from typing import Set from graphgen.bases.datatypes import QAPair @@ -8,18 +7,15 @@ nltk_helper = NLTKHelper() -@dataclass class MTLDEvaluator(BaseEvaluator): """ 衡量文本词汇多样性的指标 """ - stopwords_en: Set[str] = field( - default_factory=lambda: set(nltk_helper.get_stopwords("english")) - ) - stopwords_zh: Set[str] = field( - default_factory=lambda: set(nltk_helper.get_stopwords("chinese")) - ) + def __init__(self, max_concurrent: int = 100): + super().__init__(max_concurrent) + self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english")) + self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese")) async def evaluate_single(self, pair: QAPair) -> float: loop = create_event_loop() diff --git a/graphgen/models/generator/aggregated_generator.py b/graphgen/models/generator/aggregated_generator.py index bbf483e..4bad8e9 100644 --- a/graphgen/models/generator/aggregated_generator.py +++ b/graphgen/models/generator/aggregated_generator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from graphgen.bases import BaseGenerator @@ -6,7 +5,6 @@ from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class AggregatedGenerator(BaseGenerator): """ Aggregated Generator follows a TWO-STEP process: diff --git a/graphgen/models/generator/atomic_generator.py b/graphgen/models/generator/atomic_generator.py index bd152d3..713140d 100644 --- a/graphgen/models/generator/atomic_generator.py +++ b/graphgen/models/generator/atomic_generator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from graphgen.bases import BaseGenerator @@ -6,7 +5,6 @@ from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class AtomicGenerator(BaseGenerator): @staticmethod def build_prompt( diff --git a/graphgen/models/generator/cot_generator.py b/graphgen/models/generator/cot_generator.py index bd924b7..a111a6f 100644 --- a/graphgen/models/generator/cot_generator.py +++ b/graphgen/models/generator/cot_generator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from graphgen.bases import BaseGenerator @@ -6,7 +5,6 @@ from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class CoTGenerator(BaseGenerator): @staticmethod def build_prompt( diff --git a/graphgen/models/generator/multi_hop_generator.py b/graphgen/models/generator/multi_hop_generator.py index 3fd1824..9098b10 100644 --- a/graphgen/models/generator/multi_hop_generator.py +++ b/graphgen/models/generator/multi_hop_generator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from graphgen.bases import BaseGenerator @@ -6,7 +5,6 @@ from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class MultiHopGenerator(BaseGenerator): @staticmethod def build_prompt( diff --git a/graphgen/models/generator/vqa_generator.py b/graphgen/models/generator/vqa_generator.py index b0c29d2..eefbdd1 100644 --- a/graphgen/models/generator/vqa_generator.py +++ b/graphgen/models/generator/vqa_generator.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import Any from graphgen.bases import BaseGenerator @@ -6,7 +5,6 @@ from graphgen.utils import compute_content_hash, detect_main_language, logger -@dataclass class VQAGenerator(BaseGenerator): @staticmethod def build_prompt( diff --git a/graphgen/models/kg_builder/light_rag_kg_builder.py b/graphgen/models/kg_builder/light_rag_kg_builder.py index e734eca..cde42d2 100644 --- a/graphgen/models/kg_builder/light_rag_kg_builder.py +++ b/graphgen/models/kg_builder/light_rag_kg_builder.py @@ -1,6 +1,5 @@ import re from collections import Counter, defaultdict -from dataclasses import dataclass from typing import Dict, List, Tuple from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk @@ -15,10 +14,10 @@ ) -@dataclass class LightRAGKGBuilder(BaseKGBuilder): - llm_client: BaseLLMClient = None - max_loop: int = 3 + def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3): + super().__init__(llm_client) + self.max_loop = max_loop async def extract( self, chunk: Chunk diff --git a/graphgen/models/kg_builder/mm_kg_builder.py b/graphgen/models/kg_builder/mm_kg_builder.py index c554729..f352cb2 100644 --- a/graphgen/models/kg_builder/mm_kg_builder.py +++ b/graphgen/models/kg_builder/mm_kg_builder.py @@ -2,7 +2,7 @@ from collections import defaultdict from typing import Dict, List, Tuple -from graphgen.bases import BaseLLMClient, Chunk +from graphgen.bases import Chunk from graphgen.templates import MMKG_EXTRACTION_PROMPT from graphgen.utils import ( detect_main_language, @@ -16,8 +16,6 @@ class MMKGBuilder(LightRAGKGBuilder): - llm_client: BaseLLMClient = None - async def extract( self, chunk: Chunk ) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]: diff --git a/graphgen/models/llm/topk_token_model.py b/graphgen/models/llm/topk_token_model.py index 94719cf..e93ca01 100644 --- a/graphgen/models/llm/topk_token_model.py +++ b/graphgen/models/llm/topk_token_model.py @@ -1,21 +1,31 @@ -from dataclasses import dataclass +from abc import ABC, abstractmethod from typing import List, Optional from graphgen.bases import Token -@dataclass -class TopkTokenModel: - do_sample: bool = False - temperature: float = 0 - max_tokens: int = 4096 - repetition_penalty: float = 1.05 - num_beams: int = 1 - topk: int = 50 - topp: float = 0.95 - - topk_per_token: int = 5 # number of topk tokens to generate for each token +class TopkTokenModel(ABC): + def __init__( + self, + do_sample: bool = False, + temperature: float = 0, + max_tokens: int = 4096, + repetition_penalty: float = 1.05, + num_beams: int = 1, + topk: int = 50, + topp: float = 0.95, + topk_per_token: int = 5, + ): + self.do_sample = do_sample + self.temperature = temperature + self.max_tokens = max_tokens + self.repetition_penalty = repetition_penalty + self.num_beams = num_beams + self.topk = topk + self.topp = topp + self.topk_per_token = topk_per_token + @abstractmethod async def generate_topk_per_token(self, text: str) -> List[Token]: """ Generate prob, text and candidates for each token of the model's output. @@ -23,6 +33,7 @@ async def generate_topk_per_token(self, text: str) -> List[Token]: """ raise NotImplementedError + @abstractmethod async def generate_inputs_prob( self, text: str, history: Optional[List[str]] = None ) -> List[Token]: @@ -32,6 +43,7 @@ async def generate_inputs_prob( """ raise NotImplementedError + @abstractmethod async def generate_answer( self, text: str, history: Optional[List[str]] = None ) -> str: diff --git a/graphgen/models/partitioner/bfs_partitioner.py b/graphgen/models/partitioner/bfs_partitioner.py index 7cc5148..7b7b421 100644 --- a/graphgen/models/partitioner/bfs_partitioner.py +++ b/graphgen/models/partitioner/bfs_partitioner.py @@ -1,6 +1,5 @@ import random from collections import deque -from dataclasses import dataclass from typing import Any, List from graphgen.bases import BaseGraphStorage, BasePartitioner @@ -10,7 +9,6 @@ EDGE_UNIT: str = "e" -@dataclass class BFSPartitioner(BasePartitioner): """ BFS partitioner that partitions the graph into communities of a fixed size. diff --git a/graphgen/models/partitioner/dfs_partitioner.py b/graphgen/models/partitioner/dfs_partitioner.py index a9a64a9..01df509 100644 --- a/graphgen/models/partitioner/dfs_partitioner.py +++ b/graphgen/models/partitioner/dfs_partitioner.py @@ -1,5 +1,4 @@ import random -from dataclasses import dataclass from typing import Any, List from graphgen.bases import BaseGraphStorage, BasePartitioner @@ -9,7 +8,6 @@ EDGE_UNIT: str = "e" -@dataclass class DFSPartitioner(BasePartitioner): """ DFS partitioner that partitions the graph into communities of a fixed size. diff --git a/graphgen/models/partitioner/ece_partitioner.py b/graphgen/models/partitioner/ece_partitioner.py index 4352e62..e874f56 100644 --- a/graphgen/models/partitioner/ece_partitioner.py +++ b/graphgen/models/partitioner/ece_partitioner.py @@ -1,6 +1,5 @@ import asyncio import random -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple from tqdm.asyncio import tqdm as tqdm_async @@ -13,7 +12,6 @@ EDGE_UNIT: str = "e" -@dataclass class ECEPartitioner(BFSPartitioner): """ ECE partitioner that partitions the graph into communities based on Expected Calibration Error (ECE). diff --git a/graphgen/models/partitioner/leiden_partitioner.py b/graphgen/models/partitioner/leiden_partitioner.py index ffa38ae..28dfc1d 100644 --- a/graphgen/models/partitioner/leiden_partitioner.py +++ b/graphgen/models/partitioner/leiden_partitioner.py @@ -1,5 +1,4 @@ from collections import defaultdict -from dataclasses import dataclass from typing import Any, Dict, List, Set, Tuple import igraph as ig @@ -9,7 +8,6 @@ from graphgen.bases.datatypes import Community -@dataclass class LeidenPartitioner(BasePartitioner): """ Leiden partitioner that partitions the graph into communities using the Leiden algorithm. diff --git a/graphgen/models/search/db/uniprot_search.py b/graphgen/models/search/db/uniprot_search.py index 96bdd99..daf4224 100644 --- a/graphgen/models/search/db/uniprot_search.py +++ b/graphgen/models/search/db/uniprot_search.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - import requests from fastapi import HTTPException @@ -8,7 +6,6 @@ UNIPROT_BASE = "https://rest.uniprot.org/uniprotkb/search" -@dataclass class UniProtSearch: """ UniProt Search client to search with UniProt. diff --git a/graphgen/models/search/kg/wiki_search.py b/graphgen/models/search/kg/wiki_search.py index e9513f2..2d8686c 100644 --- a/graphgen/models/search/kg/wiki_search.py +++ b/graphgen/models/search/kg/wiki_search.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import List, Union import wikipedia @@ -7,7 +6,6 @@ from graphgen.utils import detect_main_language, logger -@dataclass class WikiSearch: @staticmethod def set_language(language: str): diff --git a/graphgen/models/search/web/bing_search.py b/graphgen/models/search/web/bing_search.py index a769ba7..d52815d 100644 --- a/graphgen/models/search/web/bing_search.py +++ b/graphgen/models/search/web/bing_search.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - import requests from fastapi import HTTPException @@ -9,13 +7,13 @@ BING_MKT = "en-US" -@dataclass class BingSearch: """ Bing Search client to search with Bing. """ - subscription_key: str + def __init__(self, subscription_key: str): + self.subscription_key = subscription_key def search(self, query: str, num_results: int = 1): """ diff --git a/graphgen/models/search/web/google_search.py b/graphgen/models/search/web/google_search.py index 1abfcdf..0b04572 100644 --- a/graphgen/models/search/web/google_search.py +++ b/graphgen/models/search/web/google_search.py @@ -1,5 +1,3 @@ -from dataclasses import dataclass - import requests from fastapi import HTTPException @@ -8,7 +6,6 @@ GOOGLE_SEARCH_ENDPOINT = "https://customsearch.googleapis.com/customsearch/v1" -@dataclass class GoogleSearch: def __init__(self, subscription_key: str, cx: str): """ diff --git a/graphgen/models/storage/json_storage.py b/graphgen/models/storage/json_storage.py index e37d033..171eb98 100644 --- a/graphgen/models/storage/json_storage.py +++ b/graphgen/models/storage/json_storage.py @@ -53,6 +53,8 @@ async def drop(self): @dataclass class JsonListStorage(BaseListStorage): + working_dir: str = None + namespace: str = None _data: list = None def __post_init__(self): diff --git a/graphgen/models/tokenizer/__init__.py b/graphgen/models/tokenizer/__init__.py index 1df5039..6712f91 100644 --- a/graphgen/models/tokenizer/__init__.py +++ b/graphgen/models/tokenizer/__init__.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass, field from typing import List from graphgen.bases import BaseTokenizer @@ -30,16 +29,13 @@ def get_tokenizer_impl(tokenizer_name: str = "cl100k_base") -> BaseTokenizer: ) -@dataclass class Tokenizer(BaseTokenizer): """ Encapsulates different tokenization implementations based on the specified model name. """ - model_name: str = "cl100k_base" - _impl: BaseTokenizer = field(init=False, repr=False) - - def __post_init__(self): + def __init__(self, model_name: str = "cl100k_base"): + super().__init__(model_name) if not self.model_name: raise ValueError("TOKENIZER_MODEL must be specified in the ENV variables.") self._impl = get_tokenizer_impl(self.model_name) diff --git a/graphgen/models/tokenizer/hf_tokenizer.py b/graphgen/models/tokenizer/hf_tokenizer.py index e5511a9..c43ddd7 100644 --- a/graphgen/models/tokenizer/hf_tokenizer.py +++ b/graphgen/models/tokenizer/hf_tokenizer.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import List from transformers import AutoTokenizer @@ -6,9 +5,9 @@ from graphgen.bases import BaseTokenizer -@dataclass class HFTokenizer(BaseTokenizer): - def __post_init__(self): + def __init__(self, model_name: str = "cl100k_base"): + super().__init__(model_name) self.enc = AutoTokenizer.from_pretrained(self.model_name) def encode(self, text: str) -> List[int]: diff --git a/graphgen/models/tokenizer/tiktoken_tokenizer.py b/graphgen/models/tokenizer/tiktoken_tokenizer.py index 3c84edd..6145d07 100644 --- a/graphgen/models/tokenizer/tiktoken_tokenizer.py +++ b/graphgen/models/tokenizer/tiktoken_tokenizer.py @@ -1,4 +1,3 @@ -from dataclasses import dataclass from typing import List import tiktoken @@ -6,9 +5,9 @@ from graphgen.bases import BaseTokenizer -@dataclass class TiktokenTokenizer(BaseTokenizer): - def __post_init__(self): + def __init__(self, model_name: str = "cl100k_base"): + super().__init__(model_name) self.enc = tiktoken.get_encoding(self.model_name) def encode(self, text: str) -> List[int]: