feat: paragraph management supports batch migration and batch deletion of paragraphs 1Panel-dev#113, 1Panel-dev#103
shaohuzhang1 committed May 8, 2024
1 parent 8204d5f commit d4e742f
Showing 13 changed files with 645 additions and 55 deletions.
25 changes: 21 additions & 4 deletions apps/common/event/listener_manage.py
@@ -51,8 +51,15 @@ def __init__(self, problem_id: str, problem_content: str):


class UpdateEmbeddingDatasetIdArgs:
-    def __init__(self, source_id_list: List[str], target_dataset_id: str):
-        self.source_id_list = source_id_list
+    def __init__(self, paragraph_id_list: List[str], target_dataset_id: str):
+        self.paragraph_id_list = paragraph_id_list
self.target_dataset_id = target_dataset_id


class UpdateEmbeddingDocumentIdArgs:
def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str):
self.paragraph_id_list = paragraph_id_list
self.target_document_id = target_document_id
self.target_dataset_id = target_dataset_id


@@ -213,13 +220,23 @@ def update_problem(args: UpdateProblemArgs):

@staticmethod
def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs):
-        VectorStore.get_embedding_vector().update_by_source_ids(args.source_id_list,
-                                                                {'dataset_id': args.target_dataset_id})
+        VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
+                                                                   {'dataset_id': args.target_dataset_id})

@staticmethod
def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs):
VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list,
{'document_id': args.target_document_id,
'dataset_id': args.target_dataset_id})

@staticmethod
def delete_embedding_by_source_ids(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM)

@staticmethod
def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids)

@staticmethod
def delete_embedding_by_dataset_id_list(source_ids: List[str]):
VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids)
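For orientation, a minimal sketch of how a caller might drive these new entry points during a batch migration followed by a batch delete (the argument class and method names are from this diff; the call site itself is illustrative and the UUIDs are placeholders):

```python
# Illustrative call site only; ids are placeholders.
from common.event.listener_manage import (ListenerManagement,
                                          UpdateEmbeddingDocumentIdArgs)

paragraph_id_list = ['<paragraph-uuid-1>', '<paragraph-uuid-2>']

# Re-point the selected paragraphs' embeddings at the target document and
# dataset in one bulk update; nothing is re-embedded.
ListenerManagement.update_embedding_document_id(
    UpdateEmbeddingDocumentIdArgs(paragraph_id_list,
                                  '<target-document-uuid>',
                                  '<target-dataset-uuid>'))

# Drop the embeddings of paragraphs removed by a batch delete.
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list)
```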
6 changes: 4 additions & 2 deletions apps/dataset/serializers/document_serializers.py
@@ -166,10 +166,12 @@ def migrate(self, with_valid=True):
meta={})
else:
document_list.update(dataset_id=target_dataset_id)
-                paragraph_list.update(dataset_id=target_dataset_id)
                 # update the vector records
                 ListenerManagement.update_embedding_dataset_id(UpdateEmbeddingDatasetIdArgs(
-                    [problem_paragraph_mapping.id for problem_paragraph_mapping in problem_paragraph_mapping_list],
+                    [paragraph.id for paragraph in paragraph_list],
                     target_dataset_id))
+                # update the paragraph records
+                paragraph_list.update(dataset_id=target_dataset_id)
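The reorder matters: paragraph_list is a lazy QuerySet, so when its filter involves the source dataset_id, the paragraph ids must be read before .update() rewrites the rows. A minimal sketch of the pitfall being avoided (illustrative filter):

```python
# Illustrative only: a queryset filtered on the *source* dataset id.
paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id)

# Wrong order: after update() the filter matches nothing, so evaluating
# paragraph_list afterwards yields [] and no vectors would be migrated.
#   paragraph_list.update(dataset_id=target_dataset_id)
#   ids = [p.id for p in paragraph_list]  # -> []

# Right order (what this hunk does): evaluate first, then update.
ids = [p.id for p in paragraph_list]
paragraph_list.update(dataset_id=target_dataset_id)
```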

@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
165 changes: 163 additions & 2 deletions apps/dataset/serializers/paragraph_serializers.py
@@ -15,13 +15,13 @@
from rest_framework import serializers

from common.db.search import page_search
-from common.event.listener_manage import ListenerManagement
+from common.event.listener_manage import ListenerManagement, UpdateEmbeddingDocumentIdArgs, UpdateEmbeddingDatasetIdArgs
from common.exception.app_exception import AppApiException
from common.mixins.api_mixin import ApiMixin
from common.util.common import post
from common.util.field_message import ErrMessage
from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping
-from dataset.serializers.common_serializers import update_document_char_length
+from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer
from dataset.serializers.problem_serializers import ProblemInstanceSerializer, ProblemSerializer, ProblemSerializers
from embedding.models import SourceType

@@ -272,6 +272,167 @@ def get_request_params_api():
description='问题id')
]

class Batch(serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

@transaction.atomic
def batch_delete(self, instance: Dict, with_valid=True):
if with_valid:
BatchSerializer(data=instance).is_valid(model=Paragraph, raise_exception=True)
self.is_valid(raise_exception=True)
paragraph_id_list = instance.get("id_list")
QuerySet(Paragraph).filter(id__in=paragraph_id_list).delete()
QuerySet(ProblemParagraphMapping).filter(paragraph_id__in=paragraph_id_list).delete()
            # delete from the vector store
ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_id_list)
return True
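As a usage sketch, the serializer above can be driven like this (the id_list key comes from instance.get("id_list") above; ids are placeholders):

```python
# Illustrative invocation, mirroring the Batch view added in paragraph.py below.
ParagraphSerializers.Batch(
    data={'dataset_id': '<dataset-uuid>', 'document_id': '<document-uuid>'}
).batch_delete({'id_list': ['<paragraph-uuid-1>', '<paragraph-uuid-2>']})
```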

class Migrate(ApiMixin, serializers.Serializer):
dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))
document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))
target_dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("目标知识库id"))
target_document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("目标文档id"))
paragraph_id_list = serializers.ListField(required=True, error_messages=ErrMessage.char("段落列表"),
child=serializers.UUIDField(required=True,
error_messages=ErrMessage.uuid("段落id")))

def is_valid(self, *, raise_exception=False):
super().is_valid(raise_exception=True)
document_list = QuerySet(Document).filter(
id__in=[self.data.get('document_id'), self.data.get('target_document_id')])
document_id = self.data.get('document_id')
target_document_id = self.data.get('target_document_id')
if document_id == target_document_id:
raise AppApiException(5000, "需要迁移的文档和目标文档一致")
if len([document for document in document_list if str(document.id) == self.data.get('document_id')]) < 1:
raise AppApiException(5000, f"文档id不存在【{self.data.get('document_id')}】")
if len([document for document in document_list if
str(document.id) == self.data.get('target_document_id')]) < 1:
raise AppApiException(5000, f"目标文档id不存在【{self.data.get('target_document_id')}】")

@transaction.atomic
def migrate(self, with_valid=True):
if with_valid:
self.is_valid(raise_exception=True)
dataset_id = self.data.get('dataset_id')
target_dataset_id = self.data.get('target_dataset_id')
document_id = self.data.get('document_id')
target_document_id = self.data.get('target_document_id')
paragraph_id_list = self.data.get('paragraph_id_list')
paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id=document_id,
id__in=paragraph_id_list)
problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list)
            # migrate within the same dataset
if target_dataset_id == dataset_id:
                if len(problem_paragraph_mapping_list):
                    problem_paragraph_mapping_list = [
                        self.update_problem_paragraph_mapping(target_document_id, problem_paragraph_mapping)
                        for problem_paragraph_mapping in problem_paragraph_mapping_list]
                    # update the mappings
                    QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
                                                                  ['document_id'])
                # update the paragraph vectors
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
                # update the paragraph records
paragraph_list.update(document_id=target_document_id)
            # migrate across datasets
else:
                problem_list = QuerySet(Problem).filter(
                    id__in=[problem_paragraph_mapping.problem_id
                            for problem_paragraph_mapping in problem_paragraph_mapping_list])
                # problems already present in the target dataset
target_problem_list = list(
QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list],
dataset_id=target_dataset_id))

                target_handle_problem_list = [
                    self.get_target_dataset_problem(target_dataset_id, target_document_id,
                                                    problem_paragraph_mapping, problem_list,
                                                    target_problem_list)
                    for problem_paragraph_mapping in problem_paragraph_mapping_list]

create_problem_list = [problem for problem, is_create in target_handle_problem_list if
is_create is not None and is_create]
                # insert the newly created problems
QuerySet(Problem).bulk_create(create_problem_list)
                # update the mappings
QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list,
['problem_id', 'dataset_id', 'document_id'])
                # update the paragraph vectors
ListenerManagement.update_embedding_document_id(UpdateEmbeddingDocumentIdArgs(
[paragraph.id for paragraph in paragraph_list],
target_document_id, target_dataset_id))
                # update the paragraph records
paragraph_list.update(dataset_id=target_dataset_id, document_id=target_document_id)

@staticmethod
def update_problem_paragraph_mapping(target_document_id: str, problem_paragraph_mapping):
problem_paragraph_mapping.document_id = target_document_id
return problem_paragraph_mapping

@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
target_document_id: str,
problem_paragraph_mapping,
source_problem_list,
target_problem_list):
source_problem_list = [source_problem for source_problem in source_problem_list if
source_problem.id == problem_paragraph_mapping.problem_id]
problem_paragraph_mapping.dataset_id = target_dataset_id
problem_paragraph_mapping.document_id = target_document_id
if len(source_problem_list) > 0:
problem_content = source_problem_list[-1].content
problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
if len(problem_list) > 0:
problem = problem_list[-1]
problem_paragraph_mapping.problem_id = problem.id
return problem, False
else:
problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content)
target_problem_list.append(problem)
problem_paragraph_mapping.problem_id = problem.id
return problem, True
            return None, None  # keep the (problem, is_create) shape so the caller's unpacking cannot fail
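To make the reuse-or-create contract concrete, here is a small self-contained sketch of the same dedup logic (simplified stand-ins for the Problem model; the real method also rewrites the mapping rows):

```python
import uuid
from dataclasses import dataclass, field


@dataclass
class FakeProblem:  # simplified stand-in for dataset.models.Problem
    content: str
    dataset_id: str
    id: uuid.UUID = field(default_factory=uuid.uuid1)


def reuse_or_create(content, target_dataset_id, target_problem_list):
    """Return (problem, is_create): reuse a target-dataset problem with
    identical content, otherwise create one for a later bulk_create."""
    hits = [p for p in target_problem_list if p.content == content]
    if hits:
        return hits[-1], False
    problem = FakeProblem(content=content, dataset_id=target_dataset_id)
    # Appending lets later duplicates in the same batch reuse this problem,
    # exactly like target_problem_list.append(problem) above.
    target_problem_list.append(problem)
    return problem, True


targets = [FakeProblem('existing question', 'ds-2')]
print(reuse_or_create('existing question', 'ds-2', targets)[1])  # False: reused
print(reuse_or_create('new question', 'ds-2', targets)[1])       # True: created
print(reuse_or_create('new question', 'ds-2', targets)[1])       # False: in-batch dedup
```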

@staticmethod
def get_request_params_api():
return [openapi.Parameter(name='dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
                                      description='知识库id'),
openapi.Parameter(name='document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='文档id'),
openapi.Parameter(name='target_dataset_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
description='目标知识库id'),
openapi.Parameter(name='target_document_id',
in_=openapi.IN_PATH,
type=openapi.TYPE_STRING,
required=True,
                                      description='目标文档id')
]

@staticmethod
def get_request_body_api():
return openapi.Schema(
type=openapi.TYPE_ARRAY,
items=openapi.Schema(type=openapi.TYPE_STRING),
title='段落id列表',
description="段落id列表"
)

class Operate(ApiMixin, serializers.Serializer):
        # paragraph id
paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char(
4 changes: 4 additions & 0 deletions apps/dataset/urls.py
@@ -25,6 +25,10 @@
path('dataset/<str:dataset_id>/document/migrate/<str:target_dataset_id>', views.Document.Migrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/refresh', views.Document.Refresh.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph', views.Paragraph.as_view()),
path(
'dataset/<str:dataset_id>/document/<str:document_id>/paragraph/migrate/dataset/<str:target_dataset_id>/document/<str:target_document_id>',
views.Paragraph.BatchMigrate.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/_batch', views.Paragraph.Batch.as_view()),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<int:current_page>/<int:page_size>',
views.Paragraph.Page.as_view(), name='paragraph_page'),
path('dataset/<str:dataset_id>/document/<str:document_id>/paragraph/<paragraph_id>',
47 changes: 46 additions & 1 deletion apps/dataset/views/paragraph.py
@@ -12,9 +12,10 @@
from rest_framework.views import Request

from common.auth import TokenAuth, has_permissions
-from common.constants.permission_constants import Permission, Group, Operate
+from common.constants.permission_constants import Permission, Group, Operate, CompareConstants
from common.response import result
from common.util.common import query_params_to_single_dict
from dataset.serializers.common_serializers import BatchSerializer
from dataset.serializers.paragraph_serializers import ParagraphSerializers


@@ -168,6 +169,50 @@ def delete(self, request: Request, dataset_id: str, document_id: str, paragraph_
o.is_valid(raise_exception=True)
return result.success(o.delete())

class Batch(APIView):
authentication_classes = [TokenAuth]

@action(methods=['DELETE'], detail=False)
@swagger_auto_schema(operation_summary="批量删除段落",
operation_id="批量删除段落",
                             request_body=BatchSerializer.get_request_body_api(),
manual_parameters=ParagraphSerializers.Create.get_request_params_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落"])
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')))
def delete(self, request: Request, dataset_id: str, document_id: str):
return result.success(ParagraphSerializers.Batch(
data={"dataset_id": dataset_id, 'document_id': document_id}).batch_delete(request.data))

class BatchMigrate(APIView):
authentication_classes = [TokenAuth]

@action(methods=['PUT'], detail=False)
@swagger_auto_schema(operation_summary="批量迁移段落",
operation_id="批量迁移段落",
manual_parameters=ParagraphSerializers.Migrate.get_request_params_api(),
request_body=ParagraphSerializers.Migrate.get_request_body_api(),
responses=result.get_default_response(),
tags=["知识库/文档/段落"]
)
@has_permissions(
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('dataset_id')),
lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE,
dynamic_tag=k.get('target_dataset_id')),
compare=CompareConstants.AND
)
        def put(self, request: Request, dataset_id: str, target_dataset_id: str, document_id: str, target_document_id: str):
return result.success(
ParagraphSerializers.Migrate(
data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id,
'document_id': document_id,
'target_document_id': target_document_id,
'paragraph_id_list': request.data}).migrate())
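Putting the routes and views together, the two new endpoints might be exercised like this (host, API prefix, and auth header are placeholders; the body shapes follow the serializers above: an {"id_list": [...]} wrapper for delete, a bare id array for migrate):

```python
import requests

BASE = 'http://localhost:8080/api'        # placeholder host and prefix
HEADERS = {'AUTHORIZATION': '<api-key>'}  # placeholder auth header

# Batch delete: DELETE .../paragraph/_batch with an id_list wrapper.
requests.delete(
    f'{BASE}/dataset/<dataset-uuid>/document/<document-uuid>/paragraph/_batch',
    json={'id_list': ['<paragraph-uuid-1>', '<paragraph-uuid-2>']},
    headers=HEADERS)

# Batch migrate: PUT with the paragraph id array as the whole JSON body;
# the view passes request.data straight through as paragraph_id_list.
requests.put(
    f'{BASE}/dataset/<dataset-uuid>/document/<document-uuid>'
    f'/paragraph/migrate/dataset/<target-dataset-uuid>/document/<target-document-uuid>',
    json=['<paragraph-uuid-1>', '<paragraph-uuid-2>'],
    headers=HEADERS)
```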

class Page(APIView):
authentication_classes = [TokenAuth]

10 changes: 9 additions & 1 deletion apps/embedding/vector/base_vector.py
@@ -113,7 +113,7 @@ def search(self, query_text, dataset_id_list: list[str], exclude_document_id_lis
return result[0]

@abstractmethod
-    def query(self, query_text:str,query_embedding: List[float], dataset_id_list: list[str],
+    def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str],
exclude_document_id_list: list[str],
exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float,
search_mode: SearchMode):
@@ -130,6 +130,10 @@ def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list:
def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
pass

@abstractmethod
    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
pass

@abstractmethod
def update_by_source_id(self, source_id: str, instance: Dict):
pass
@@ -173,3 +177,7 @@ def delete_by_source_ids(self, source_ids: List[str], source_type: str):
@abstractmethod
def delete_by_paragraph_id(self, paragraph_id: str):
pass

@abstractmethod
def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
pass
6 changes: 6 additions & 0 deletions apps/embedding/vector/pg_vector.py
@@ -119,6 +119,9 @@ def update_by_source_id(self, source_id: str, instance: Dict):
def update_by_paragraph_id(self, paragraph_id: str, instance: Dict):
QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance)

    def update_by_paragraph_ids(self, paragraph_ids: List[str], instance: Dict):
        QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).update(**instance)

def delete_by_dataset_id(self, dataset_id: str):
QuerySet(Embedding).filter(dataset_id=dataset_id).delete()

@@ -139,6 +142,9 @@ def delete_by_source_id(self, source_id: str, source_type: str):
def delete_by_paragraph_id(self, paragraph_id: str):
QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete()

def delete_by_paragraph_ids(self, paragraph_ids: List[str]):
QuerySet(Embedding).filter(paragraph_id__in=paragraph_ids).delete()
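The point of the plural helpers is one SQL statement per batch instead of one query per paragraph. A quick way to see this (illustrative; run inside the project's Django context):

```python
# Inspect the SQL Django generates for the batch filter (illustrative).
from django.db.models import QuerySet

from embedding.models import Embedding

ids = ['<paragraph-uuid-1>', '<paragraph-uuid-2>']
qs = QuerySet(Embedding).filter(paragraph_id__in=ids)
print(qs.query)  # SELECT ... WHERE "embedding"."paragraph_id" IN (...)

# .update(**instance) / .delete() on this queryset emit a single
# UPDATE / DELETE with the same IN clause, one round trip per batch.
```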


class ISearch(ABC):
@abstractmethod
