-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
index_processor_base.py
75 lines (62 loc) · 3.14 KB
/
index_processor_base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""Abstract interface for document loader implementations."""
from abc import ABC, abstractmethod
from typing import Optional
from flask import current_app
from core.model_manager import ModelInstance
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule
class BaseIndexProcessor(ABC):
    """Abstract base class for index processors.

    Subclasses must implement ``extract``, ``transform``, ``load`` and
    ``retrieve``. ``clean`` is not marked abstract; its default
    implementation raises ``NotImplementedError``. The base class also
    provides ``_get_splitter``, a shared factory for text splitters.
    """

    @abstractmethod
    def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
        """Return the documents obtained for ``extract_setting``."""
        raise NotImplementedError

    @abstractmethod
    def transform(self, documents: list[Document], **kwargs) -> list[Document]:
        """Return the list of documents derived from ``documents``."""
        raise NotImplementedError

    @abstractmethod
    def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True):
        """Load ``documents`` for ``dataset``; must be provided by subclasses."""
        raise NotImplementedError

    # NOTE(review): unlike its siblings this method carries no
    # @abstractmethod decorator, so a subclass can be instantiated without
    # overriding it; calling it then raises NotImplementedError.
    def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True):
        """Clean ``dataset`` for the given ``node_ids``; raises unless overridden."""
        raise NotImplementedError

    @abstractmethod
    def retrieve(self, retrival_method: str, query: str, dataset: Dataset, top_k: int,
                 score_threshold: float, reranking_model: dict) -> list[Document]:
        """Return the documents retrieved from ``dataset`` for ``query``."""
        raise NotImplementedError

    @staticmethod
    def _unescape_separator(separator: Optional[str]) -> Optional[str]:
        """Turn each literal backslash-n sequence in ``separator`` into a newline.

        Falsy values (``None`` or an empty string) are returned unchanged.
        """
        if not separator:
            return separator
        return separator.replace('\\n', '\n')

    @staticmethod
    def _resolve_chunk_overlap(segmentation: dict) -> int:
        """Return the ``chunk_overlap`` entry of ``segmentation``, defaulting to 0.

        A missing key and an explicit ``None`` both yield 0, so the splitter
        is never handed ``None`` as its overlap.
        """
        return segmentation.get('chunk_overlap') or 0

    def _get_splitter(self, processing_rule: dict,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """Build the text splitter selected by ``processing_rule``.

        When ``processing_rule['mode']`` is ``"custom"`` the segmentation
        settings under ``processing_rule['rules']['segmentation']`` are used
        to build a ``FixedRecursiveCharacterTextSplitter``. Any other mode
        builds an ``EnhanceRecursiveCharacterTextSplitter`` from
        ``DatasetProcessRule.AUTOMATIC_RULES``.

        :param processing_rule: mapping with a ``mode`` key and, for the
            custom mode, a ``rules`` key holding the segmentation settings.
        :param embedding_model_instance: forwarded unchanged to the
            splitter's ``from_encoder`` constructor.
        :return: the constructed splitter.
        :raises ValueError: if the custom ``max_tokens`` is below 50 or above
            the ``INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH`` app setting.
        """
        if processing_rule['mode'] != "custom":
            # Automatic segmentation
            automatic = DatasetProcessRule.AUTOMATIC_RULES['segmentation']
            return EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=automatic['max_tokens'],
                chunk_overlap=automatic['chunk_overlap'],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        # The user-defined segmentation rule
        segmentation = processing_rule['rules']["segmentation"]
        max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
        max_tokens = segmentation["max_tokens"]
        if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
            raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

        return FixedRecursiveCharacterTextSplitter.from_encoder(
            chunk_size=max_tokens,
            # A stored None must not reach the splitter; treat it like "missing".
            chunk_overlap=self._resolve_chunk_overlap(segmentation),
            fixed_separator=self._unescape_separator(segmentation["separator"]),
            separators=["\n\n", "。", ". ", " ", ""],
            embedding_model_instance=embedding_model_instance
        )