Skip to content
Merged
135 changes: 135 additions & 0 deletions graphgen/bases/base_splitter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import copy
import hashlib
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Literal, Optional, Union

from graphgen.bases.datatypes import Chunk
from graphgen.utils import logger


@dataclass
class BaseSplitter(ABC):
    """
    Abstract base class for splitting text into smaller chunks.

    Subclasses implement :meth:`split_text`; this base class provides the
    helpers to merge small splits into size-bounded chunks with overlap and
    to wrap the results into ``Chunk`` objects.
    """

    chunk_size: int = 1024  # target maximum chunk length, in length_function units
    chunk_overlap: int = 100  # desired overlap carried between consecutive chunks
    length_function: Callable[[str], int] = len  # e.g. a token counter
    keep_separator: bool = False  # whether separators are kept in the splits
    add_start_index: bool = False  # if True, record each chunk's start offset in metadata
    strip_whitespace: bool = True  # strip chunks before emitting them

    @abstractmethod
    def split_text(self, text: str) -> List[str]:
        """
        Split the input text into smaller chunks.

        :param text: The input text to be split.
        :return: A list of text chunks.
        """

    def create_chunks(
        self, texts: List[str], metadatas: Optional[List[dict]] = None
    ) -> List[Chunk]:
        """Create chunks from a list of texts.

        :param texts: input texts; each is split via :meth:`split_text`.
        :param metadatas: optional per-text metadata dicts, parallel to ``texts``.
        :return: list of ``Chunk`` objects; ids are deterministic content hashes.
        """
        _metadatas = metadatas or [{}] * len(texts)
        chunks = []
        for i, text in enumerate(texts):
            index = 0
            previous_chunk_len = 0
            for chunk in self.split_text(text):
                # Deep-copy so emitted chunks never share (and mutate) one metadata dict.
                metadata = copy.deepcopy(_metadatas[i])
                if self.add_start_index:
                    # Start searching just before the previous chunk's end minus
                    # the overlap so repeated substrings resolve to the right offset.
                    offset = index + previous_chunk_len - self.chunk_overlap
                    index = text.find(chunk, max(0, offset))
                    metadata["start_index"] = index
                previous_chunk_len = len(chunk)
                # Chunk requires an explicit id (it has no default); derive a
                # deterministic one from the content so identical text always
                # maps to the same chunk id.
                chunk_id = "chunk-" + hashlib.md5(chunk.encode("utf-8")).hexdigest()
                new_chunk = Chunk(id=chunk_id, content=chunk, metadata=metadata)
                chunks.append(new_chunk)
        return chunks

    def _join_chunks(self, chunks: List[str], separator: str) -> Optional[str]:
        """Join pieces with ``separator``; return None if the result is empty."""
        text = separator.join(chunks)
        if self.strip_whitespace:
            text = text.strip()
        if text == "":
            return None
        return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> List[str]:
        """Greedily combine small pieces into chunks of at most ``chunk_size``.

        Consecutive chunks keep roughly ``chunk_overlap`` worth of trailing
        pieces; a single oversized piece is emitted as-is with a warning.
        """
        # We now want to combine these smaller pieces into medium size chunks to send to the LLM.
        separator_len = self.length_function(separator)

        chunks = []
        current_chunk: List[str] = []
        total = 0
        for d in splits:
            _len = self.length_function(d)
            # Would appending this piece (plus a separator) overflow the chunk?
            if (
                total + _len + (separator_len if len(current_chunk) > 0 else 0)
                > self.chunk_size
            ):
                if total > self.chunk_size:
                    logger.warning(
                        "Created a chunk of size %s, which is longer than the specified %s",
                        total,
                        self.chunk_size,
                    )
                if len(current_chunk) > 0:
                    chunk = self._join_chunks(current_chunk, separator)
                    if chunk is not None:
                        chunks.append(chunk)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self.chunk_overlap or (
                        total + _len + (separator_len if len(current_chunk) > 0 else 0)
                        > self.chunk_size
                        and total > 0
                    ):
                        total -= self.length_function(current_chunk[0]) + (
                            separator_len if len(current_chunk) > 1 else 0
                        )
                        current_chunk = current_chunk[1:]
            current_chunk.append(d)
            total += _len + (separator_len if len(current_chunk) > 1 else 0)
        chunk = self._join_chunks(current_chunk, separator)
        if chunk is not None:
            chunks.append(chunk)
        return chunks

    @staticmethod
    def _split_text_with_regex(
        text: str, separator: str, keep_separator: Union[bool, Literal["start", "end"]]
    ) -> List[str]:
        """Split ``text`` on a regex ``separator``.

        :param keep_separator: ``False`` drops separators; ``True``/``"start"``
            glues each separator onto the following piece; ``"end"`` glues it
            onto the preceding piece. Empty pieces are filtered out.
        """
        # Now that we have the separator, split the text
        if separator:
            if keep_separator:
                # The parentheses in the pattern keep the delimiters in the result.
                _splits = re.split(f"({separator})", text)
                splits = (
                    (
                        [
                            _splits[i] + _splits[i + 1]
                            for i in range(0, len(_splits) - 1, 2)
                        ]
                    )
                    if keep_separator == "end"
                    else (
                        [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
                    )
                )
                if len(_splits) % 2 == 0:
                    splits += _splits[-1:]
                splits = (
                    (splits + [_splits[-1]])
                    if keep_separator == "end"
                    else ([_splits[0]] + splits)
                )
            else:
                splits = re.split(separator, text)
        else:
            # No separator: fall back to character-level splitting.
            splits = list(text)
        return [s for s in splits if s != ""]
18 changes: 18 additions & 0 deletions graphgen/bases/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class Chunk:
    """A piece of source text, addressed by a stable identifier."""

    # Unique identifier for the chunk. Required — it has no default, so every
    # constructor call must supply it (e.g. a content hash).
    id: str
    # The raw text content of this chunk.
    content: str
    # Arbitrary extra information; splitters may record e.g. "start_index" here.
    metadata: dict = field(default_factory=dict)


@dataclass
class QAPair:
    """
    A pair of question and answer.
    """

    # The question text.
    question: str
    # The answer text.
    answer: str
File renamed without changes.
6 changes: 5 additions & 1 deletion graphgen/configs/aggregated_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
read:
input_file: resources/input_examples/jsonl_demo.jsonl # input file path, support json, jsonl, txt. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
output_data_type: aggregated # atomic, aggregated, multi_hop, cot
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
Expand Down
6 changes: 5 additions & 1 deletion graphgen/configs/atomic_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
read:
input_file: resources/input_examples/json_demo.json # input file path, support json, jsonl, txt, csv. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
output_data_type: atomic # atomic, aggregated, multi_hop, cot
output_data_format: Alpaca # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
Expand Down
6 changes: 5 additions & 1 deletion graphgen/configs/cot_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
read:
input_file: resources/input_examples/txt_demo.txt # input file path, support json, jsonl, txt. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
output_data_type: cot # atomic, aggregated, multi_hop, cot
output_data_format: Sharegpt # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
Expand Down
6 changes: 5 additions & 1 deletion graphgen/configs/multi_hop_config.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
read:
input_file: resources/input_examples/csv_demo.csv # input file path, support json, jsonl, txt. See resources/input_examples for examples
split:
chunk_size: 1024 # chunk size for text splitting
chunk_overlap: 100 # chunk overlap for text splitting
output_data_type: multi_hop # atomic, aggregated, multi_hop, cot
output_data_format: ChatML # Alpaca, Sharegpt, ChatML
tokenizer: cl100k_base # tokenizer for counting tokens, support tiktoken tokenizer names and local tokenizer path
Expand Down
Loading