Merged
5 changes: 2 additions & 3 deletions graphgen/bases/base_generator.py
@@ -1,17 +1,16 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any

from graphgen.bases.base_llm_client import BaseLLMClient


@dataclass
class BaseGenerator(ABC):
"""
Generate QAs based on given prompts.
"""

llm_client: BaseLLMClient
def __init__(self, llm_client: BaseLLMClient):
self.llm_client = llm_client

@staticmethod
@abstractmethod
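For context, here is a minimal self-contained sketch of the pattern this PR applies across the base classes: the `@dataclass` field on an abstract base is replaced by an explicit `__init__`, so concrete subclasses stay plain classes and simply inherit the constructor. All names below (`Client`, `Base`, `Concrete`) are illustrative stand-ins, not graphgen APIs.

```python
# Illustrative only: Client, Base and Concrete are stand-in names, not graphgen APIs.
from abc import ABC, abstractmethod


class Client:
    """Stand-in for BaseLLMClient."""


class Base(ABC):
    """Abstract base that stores its collaborator via an explicit __init__."""

    def __init__(self, client: Client):
        self.client = client

    @abstractmethod
    def run(self) -> str:
        ...


class Concrete(Base):
    def run(self) -> str:
        return f"using {type(self.client).__name__}"


print(Concrete(Client()).run())  # -> "using Client"
```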
12 changes: 4 additions & 8 deletions graphgen/bases/base_kg_builder.py
@@ -1,21 +1,17 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

from graphgen.bases.base_llm_client import BaseLLMClient
from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Chunk


@dataclass
class BaseKGBuilder(ABC):
llm_client: BaseLLMClient

_nodes: Dict[str, List[dict]] = field(default_factory=lambda: defaultdict(list))
_edges: Dict[Tuple[str, str], List[dict]] = field(
default_factory=lambda: defaultdict(list)
)
def __init__(self, llm_client: BaseLLMClient):
self.llm_client = llm_client
self._nodes: Dict[str, List[dict]] = defaultdict(list)
self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)

@abstractmethod
async def extract(
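One behavioral note worth spelling out: `_nodes` and `_edges` are now created inside `BaseKGBuilder.__init__` rather than via `field(default_factory=...)`, so a subclass that defines its own `__init__` must chain to the base constructor or those attributes will never exist. A hedged sketch with stand-in names (`FakeLLMClient`, `KGBuilderBase`, `MyBuilder`); the real subclass that chains like this is `LightRAGKGBuilder`, further down in this diff.

```python
# Stand-in names only; mirrors the constructor introduced in BaseKGBuilder above.
from collections import defaultdict
from typing import Dict, List, Tuple


class FakeLLMClient:
    """Stand-in for BaseLLMClient."""


class KGBuilderBase:
    """Mirrors the BaseKGBuilder.__init__ shown in this diff."""

    def __init__(self, llm_client: FakeLLMClient):
        self.llm_client = llm_client
        self._nodes: Dict[str, List[dict]] = defaultdict(list)
        self._edges: Dict[Tuple[str, str], List[dict]] = defaultdict(list)


class MyBuilder(KGBuilderBase):
    def __init__(self, llm_client: FakeLLMClient, max_loop: int = 3):
        super().__init__(llm_client)  # skip this and self._nodes/_edges never exist
        self.max_loop = max_loop


builder = MyBuilder(FakeLLMClient(), max_loop=5)
builder._nodes["entity-1"].append({"description": "example"})  # defaultdict behaves as before
```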
2 changes: 0 additions & 2 deletions graphgen/bases/base_partitioner.py
@@ -1,12 +1,10 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List

from graphgen.bases.base_storage import BaseGraphStorage
from graphgen.bases.datatypes import Community


@dataclass
class BasePartitioner(ABC):
@abstractmethod
async def partition(
23 changes: 15 additions & 8 deletions graphgen/bases/base_splitter.py
@@ -1,25 +1,32 @@
import copy
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger


@dataclass
class BaseSplitter(ABC):
"""
Abstract base class for splitting text into smaller chunks.
"""

chunk_size: int = 1024
chunk_overlap: int = 100
length_function: Callable[[str], int] = len
keep_separator: bool = False
add_start_index: bool = False
strip_whitespace: bool = True
def __init__(
self,
chunk_size: int = 1024,
chunk_overlap: int = 100,
length_function: Callable[[str], int] = len,
keep_separator: bool = False,
add_start_index: bool = False,
strip_whitespace: bool = True,
):
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.length_function = length_function
self.keep_separator = keep_separator
self.add_start_index = add_start_index
self.strip_whitespace = strip_whitespace

@abstractmethod
def split_text(self, text: str) -> List[str]:
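A small, self-contained sketch mirroring the explicit `BaseSplitter.__init__` above, as a toy fixed-size splitter. It is illustrative only (graphgen's real splitters presumably handle separators and whitespace via the other parameters) and exists just to show how `chunk_size` and `chunk_overlap` are consumed through the new constructor.

```python
# Toy fixed-size splitter with the same constructor shape; illustrative only.
from typing import Callable, List


class FixedSizeSplitter:
    def __init__(
        self,
        chunk_size: int = 1024,
        chunk_overlap: int = 100,
        length_function: Callable[[str], int] = len,
    ):
        if chunk_overlap >= chunk_size:
            raise ValueError("chunk_overlap must be smaller than chunk_size")
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.length_function = length_function

    def split_text(self, text: str) -> List[str]:
        # Slide a window of chunk_size characters, stepping by size minus overlap.
        step = self.chunk_size - self.chunk_overlap
        total = self.length_function(text)
        return [text[i : i + self.chunk_size] for i in range(0, total, step)]


print(FixedSizeSplitter(chunk_size=8, chunk_overlap=2).split_text("abcdefghijkl"))
# -> ['abcdefgh', 'ghijkl']
```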
3 changes: 0 additions & 3 deletions graphgen/bases/base_storage.py
@@ -16,7 +16,6 @@ async def query_done_callback(self):
"""commit the storage operations after querying"""


@dataclass
class BaseListStorage(Generic[T], StorageNameSpace):
async def all_items(self) -> list[T]:
raise NotImplementedError
@@ -34,7 +33,6 @@ async def drop(self):
raise NotImplementedError


@dataclass
class BaseKVStorage(Generic[T], StorageNameSpace):
async def all_keys(self) -> list[str]:
raise NotImplementedError
@@ -58,7 +56,6 @@ async def drop(self):
raise NotImplementedError


@dataclass
class BaseGraphStorage(StorageNameSpace):
async def has_node(self, node_id: str) -> bool:
raise NotImplementedError
5 changes: 2 additions & 3 deletions graphgen/bases/base_tokenizer.py
@@ -1,13 +1,12 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List


@dataclass
class BaseTokenizer(ABC):
model_name: str = "cl100k_base"
def __init__(self, model_name: str = "cl100k_base"):
self.model_name = model_name

@abstractmethod
def encode(self, text: str) -> List[int]:
7 changes: 3 additions & 4 deletions graphgen/models/evaluator/base_evaluator.py
@@ -1,16 +1,15 @@
import asyncio
from dataclasses import dataclass

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.bases.datatypes import QAPair
from graphgen.utils import create_event_loop


@dataclass
class BaseEvaluator:
max_concurrent: int = 100
results: list[float] = None
def __init__(self, max_concurrent: int = 100):
self.max_concurrent = max_concurrent
self.results: list[float] = None

def evaluate(self, pairs: list[QAPair]) -> list[float]:
"""
9 changes: 3 additions & 6 deletions graphgen/models/evaluator/length_evaluator.py
@@ -1,16 +1,13 @@
from dataclasses import dataclass

from graphgen.bases.datatypes import QAPair
from graphgen.models.evaluator.base_evaluator import BaseEvaluator
from graphgen.models.tokenizer import Tokenizer
from graphgen.utils import create_event_loop


@dataclass
class LengthEvaluator(BaseEvaluator):
tokenizer_name: str = "cl100k_base"

def __post_init__(self):
def __init__(self, tokenizer_name: str = "cl100k_base", max_concurrent: int = 100):
super().__init__(max_concurrent)
self.tokenizer_name = tokenizer_name
self.tokenizer = Tokenizer(model_name=self.tokenizer_name)

async def evaluate_single(self, pair: QAPair) -> float:
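A brief usage sketch of the new constructor: both parameters are now explicit keyword arguments, with `max_concurrent` forwarded to `BaseEvaluator` and the tokenizer built eagerly in `__init__` instead of `__post_init__`. The snippet assumes the package is importable and that the `cl100k_base` encoding is available to `Tokenizer`.

```python
from graphgen.models.evaluator.length_evaluator import LengthEvaluator

# Explicit keyword arguments replace the old dataclass fields / __post_init__.
evaluator = LengthEvaluator(tokenizer_name="cl100k_base", max_concurrent=50)
assert evaluator.max_concurrent == 50  # forwarded via super().__init__(max_concurrent)
```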
12 changes: 4 additions & 8 deletions graphgen/models/evaluator/mtld_evaluator.py
@@ -1,4 +1,3 @@
from dataclasses import dataclass, field
from typing import Set

from graphgen.bases.datatypes import QAPair
@@ -8,18 +7,15 @@
nltk_helper = NLTKHelper()


@dataclass
class MTLDEvaluator(BaseEvaluator):
"""
Metric for measuring the lexical diversity of a text.
"""

stopwords_en: Set[str] = field(
default_factory=lambda: set(nltk_helper.get_stopwords("english"))
)
stopwords_zh: Set[str] = field(
default_factory=lambda: set(nltk_helper.get_stopwords("chinese"))
)
def __init__(self, max_concurrent: int = 100):
super().__init__(max_concurrent)
self.stopwords_en: Set[str] = set(nltk_helper.get_stopwords("english"))
self.stopwords_zh: Set[str] = set(nltk_helper.get_stopwords("chinese"))

async def evaluate_single(self, pair: QAPair) -> float:
loop = create_event_loop()
2 changes: 0 additions & 2 deletions graphgen/models/generator/aggregated_generator.py
@@ -1,12 +1,10 @@
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import AGGREGATED_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


@dataclass
class AggregatedGenerator(BaseGenerator):
"""
Aggregated Generator follows a TWO-STEP process:
2 changes: 0 additions & 2 deletions graphgen/models/generator/atomic_generator.py
@@ -1,12 +1,10 @@
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import ATOMIC_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


@dataclass
class AtomicGenerator(BaseGenerator):
@staticmethod
def build_prompt(
2 changes: 0 additions & 2 deletions graphgen/models/generator/cot_generator.py
@@ -1,12 +1,10 @@
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import COT_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


@dataclass
class CoTGenerator(BaseGenerator):
@staticmethod
def build_prompt(
2 changes: 0 additions & 2 deletions graphgen/models/generator/multi_hop_generator.py
@@ -1,12 +1,10 @@
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import MULTI_HOP_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


@dataclass
class MultiHopGenerator(BaseGenerator):
@staticmethod
def build_prompt(
2 changes: 0 additions & 2 deletions graphgen/models/generator/vqa_generator.py
@@ -1,12 +1,10 @@
from dataclasses import dataclass
from typing import Any

from graphgen.bases import BaseGenerator
from graphgen.templates import VQA_GENERATION_PROMPT
from graphgen.utils import compute_content_hash, detect_main_language, logger


@dataclass
class VQAGenerator(BaseGenerator):
@staticmethod
def build_prompt(
7 changes: 3 additions & 4 deletions graphgen/models/kg_builder/light_rag_kg_builder.py
@@ -1,6 +1,5 @@
import re
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

from graphgen.bases import BaseGraphStorage, BaseKGBuilder, BaseLLMClient, Chunk
@@ -15,10 +14,10 @@
)


@dataclass
class LightRAGKGBuilder(BaseKGBuilder):
llm_client: BaseLLMClient = None
max_loop: int = 3
def __init__(self, llm_client: BaseLLMClient, max_loop: int = 3):
super().__init__(llm_client)
self.max_loop = max_loop

async def extract(
self, chunk: Chunk
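Call sites keep the same shape; a short usage sketch (the client object is assumed, any concrete `BaseLLMClient` implementation will do):

```python
from graphgen.models.kg_builder.light_rag_kg_builder import LightRAGKGBuilder

llm_client = ...  # assumed: any concrete BaseLLMClient implementation
builder = LightRAGKGBuilder(llm_client=llm_client, max_loop=5)
# super().__init__(llm_client) sets llm_client, _nodes and _edges;
# max_loop stays a plain instance attribute, defaulting to 3.
```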
4 changes: 1 addition & 3 deletions graphgen/models/kg_builder/mm_kg_builder.py
@@ -2,7 +2,7 @@
from collections import defaultdict
from typing import Dict, List, Tuple

from graphgen.bases import BaseLLMClient, Chunk
from graphgen.bases import Chunk
from graphgen.templates import MMKG_EXTRACTION_PROMPT
from graphgen.utils import (
detect_main_language,
@@ -16,8 +16,6 @@


class MMKGBuilder(LightRAGKGBuilder):
llm_client: BaseLLMClient = None

async def extract(
self, chunk: Chunk
) -> Tuple[Dict[str, List[dict]], Dict[Tuple[str, str], List[dict]]]:
Expand Down
36 changes: 24 additions & 12 deletions graphgen/models/llm/topk_token_model.py
@@ -1,28 +1,39 @@
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Optional

from graphgen.bases import Token


@dataclass
class TopkTokenModel:
do_sample: bool = False
temperature: float = 0
max_tokens: int = 4096
repetition_penalty: float = 1.05
num_beams: int = 1
topk: int = 50
topp: float = 0.95

topk_per_token: int = 5 # number of topk tokens to generate for each token
class TopkTokenModel(ABC):
def __init__(
self,
do_sample: bool = False,
temperature: float = 0,
max_tokens: int = 4096,
repetition_penalty: float = 1.05,
num_beams: int = 1,
topk: int = 50,
topp: float = 0.95,
topk_per_token: int = 5,
):
self.do_sample = do_sample
self.temperature = temperature
self.max_tokens = max_tokens
self.repetition_penalty = repetition_penalty
self.num_beams = num_beams
self.topk = topk
self.topp = topp
self.topk_per_token = topk_per_token

@abstractmethod
async def generate_topk_per_token(self, text: str) -> List[Token]:
"""
Generate prob, text and candidates for each token of the model's output.
This function is used to visualize the inference process.
"""
raise NotImplementedError

@abstractmethod
async def generate_inputs_prob(
self, text: str, history: Optional[List[str]] = None
) -> List[Token]:
@@ -32,6 +43,7 @@ async def generate_inputs_prob(
"""
raise NotImplementedError

@abstractmethod
async def generate_answer(
self, text: str, history: Optional[List[str]] = None
) -> str:
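Since `TopkTokenModel` is now an ABC with three abstract coroutines, it can no longer be instantiated directly; a backend has to implement all three. A minimal stub sketch follows (`StubTopkTokenModel` is hypothetical, useful mainly for tests, and not a real backend):

```python
from typing import List, Optional

from graphgen.bases import Token
from graphgen.models.llm.topk_token_model import TopkTokenModel


class StubTopkTokenModel(TopkTokenModel):
    """Hypothetical stub backend: returns empty candidate lists and echoes the input."""

    async def generate_topk_per_token(self, text: str) -> List[Token]:
        return []  # a real backend would return per-token candidates here

    async def generate_inputs_prob(
        self, text: str, history: Optional[List[str]] = None
    ) -> List[Token]:
        return []

    async def generate_answer(
        self, text: str, history: Optional[List[str]] = None
    ) -> str:
        return f"echo: {text}"


# Constructor keywords come straight from the new explicit __init__.
model = StubTopkTokenModel(temperature=0.2, topk_per_token=5)
```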
2 changes: 0 additions & 2 deletions graphgen/models/partitioner/bfs_partitioner.py
@@ -1,6 +1,5 @@
import random
from collections import deque
from dataclasses import dataclass
from typing import Any, List

from graphgen.bases import BaseGraphStorage, BasePartitioner
@@ -10,7 +9,6 @@
EDGE_UNIT: str = "e"


@dataclass
class BFSPartitioner(BasePartitioner):
"""
BFS partitioner that partitions the graph into communities of a fixed size.