add keyword table s3 storage support #3065

Merged · 5 commits · Apr 1, 2024
Changes from 4 commits
1 change: 1 addition & 0 deletions api/.env.example
@@ -137,3 +137,4 @@
SSRF_PROXY_HTTP_URL=
SSRF_PROXY_HTTPS_URL=

BATCH_UPLOAD_LIMIT=10
KEYWORD_DATA_SOURCE_TYPE=database
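Left at database, nothing changes; any other value routes the keyword table through the shared storage extension instead of the dataset_keyword_tables row. A minimal sketch of an S3-backed setup, assuming the S3 storage settings that already exist in this file (bucket and credentials are placeholders):

KEYWORD_DATA_SOURCE_TYPE=s3
STORAGE_TYPE=s3
S3_BUCKET_NAME=your-bucket
S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key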
2 changes: 2 additions & 0 deletions api/config.py
@@ -62,6 +62,7 @@
    'KEYWORD_STORE': 'jieba',
    'BATCH_UPLOAD_LIMIT': 20,
    'TOOL_ICON_CACHE_MAX_AGE': 3600,
    'KEYWORD_DATA_SOURCE_TYPE': 'database',
}
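For context (not part of this diff): get_env falls back to this DEFAULTS dict, so the flag resolves to 'database' unless the environment overrides it — roughly:

import os

def get_env(key):
    # sketch of the existing helper in api/config.py
    return os.environ.get(key, DEFAULTS.get(key))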


@@ -303,6 +304,7 @@ def __init__(self):
        self.API_COMPRESSION_ENABLED = get_bool_env('API_COMPRESSION_ENABLED')
        self.TOOL_ICON_CACHE_MAX_AGE = get_env('TOOL_ICON_CACHE_MAX_AGE')

        self.KEYWORD_DATA_SOURCE_TYPE = get_env('KEYWORD_DATA_SOURCE_TYPE')

class CloudEditionConfig(Config):

30 changes: 24 additions & 6 deletions api/core/rag/datasource/keyword/jieba/jieba.py
@@ -2,13 +2,15 @@
from collections import defaultdict
from typing import Any, Optional

from flask import current_app
from pydantic import BaseModel

from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler
from core.rag.datasource.keyword.keyword_base import BaseKeyword
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment


@@ -108,6 +110,9 @@ def delete(self) -> None:
            if dataset_keyword_table:
                # read the backend type before deleting: after the commit the ORM
                # instance is expired, and refreshing a deleted row would fail
                data_source_type = dataset_keyword_table.data_source_type
                db.session.delete(dataset_keyword_table)
                db.session.commit()
                if data_source_type != 'database':
                    file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
                    storage.delete(file_key)

    def _save_dataset_keyword_table(self, keyword_table):
        keyword_table_dict = {
@@ -118,28 +123,41 @@ def _save_dataset_keyword_table(self, keyword_table):
                "table": keyword_table
            }
        }
        keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
        if keyword_data_source_type == 'database':
            self.dataset.dataset_keyword_table.keyword_table = json.dumps(keyword_table_dict, cls=SetEncoder)
            db.session.commit()
        else:
            file_key = 'keyword_files/' + self.dataset.tenant_id + '/' + self.dataset.id + '.txt'
            if storage.exists(file_key):
                storage.delete(file_key)
            storage.save(file_key, json.dumps(keyword_table_dict, cls=SetEncoder).encode('utf-8'))
    def _get_dataset_keyword_table(self) -> Optional[dict]:
        lock_name = 'keyword_indexing_lock_{}'.format(self.dataset.id)
        with redis_client.lock(lock_name, timeout=20):
            dataset_keyword_table = self.dataset.dataset_keyword_table
            if dataset_keyword_table:
                keyword_table_dict = dataset_keyword_table.keyword_table_dict
                if keyword_table_dict:
                    return keyword_table_dict['__data__']['table']
            else:
                keyword_data_source_type = current_app.config['KEYWORD_DATA_SOURCE_TYPE']
                dataset_keyword_table = DatasetKeywordTable(
                    dataset_id=self.dataset.id,
                    keyword_table='',
                    data_source_type=keyword_data_source_type,
                )
                if keyword_data_source_type == 'database':
                    dataset_keyword_table.keyword_table = json.dumps({
                        '__type__': 'keyword_table',
                        '__data__': {
                            "index_id": self.dataset.id,
                            "summary": None,
                            "table": {}
                        }
                    }, cls=SetEncoder)
                db.session.add(dataset_keyword_table)
                db.session.commit()

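Worth noting outside the diff: the keyword table maps each keyword to a set of segment ids, and sets are not JSON-serializable, hence the SetEncoder referenced above (defined earlier in jieba.py). A self-contained round-trip sketch of the envelope being written here:

import json

class SetEncoder(json.JSONEncoder):
    # as in jieba.py: serialize sets as JSON lists
    def default(self, obj):
        if isinstance(obj, set):
            return list(obj)
        return super().default(obj)

table = {'keyword': {'segment-1', 'segment-2'}}  # keyword -> set of segment ids
text = json.dumps({'__type__': 'keyword_table',
                   '__data__': {'index_id': 'dataset-id', 'summary': None, 'table': table}},
                  cls=SetEncoder)
restored = json.loads(text)  # plain lists here; SetDecoder in models/dataset.py turns them back into sets
print(sorted(restored['__data__']['table']['keyword']))  # ['segment-1', 'segment-2']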
14 changes: 14 additions & 0 deletions api/extensions/ext_storage.py
@@ -172,6 +172,20 @@ def exists(self, filename):

        return os.path.exists(filename)

    def delete(self, filename):
        if self.storage_type == 's3':
            self.client.delete_object(Bucket=self.bucket_name, Key=filename)
        elif self.storage_type == 'azure-blob':
            blob_container = self.client.get_container_client(container=self.bucket_name)
            blob_container.delete_blob(filename)
        else:
            if not self.folder or self.folder.endswith('/'):
                filename = self.folder + filename
            else:
                filename = self.folder + '/' + filename
            if os.path.exists(filename):
                os.remove(filename)


storage = Storage()

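A short usage sketch of the new delete helper alongside the existing save and exists, using the key layout the keyword code relies on (tenant_id and dataset_id are placeholders):

from extensions.ext_storage import storage

file_key = 'keyword_files/' + tenant_id + '/' + dataset_id + '.txt'
storage.save(file_key, b'{}')     # existing helper, used by jieba.py above
if storage.exists(file_key):      # existing helper
    storage.delete(file_key)      # added in this PR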
33 changes: 33 additions & 0 deletions — new Alembic migration in api/migrations/versions
@@ -0,0 +1,33 @@
"""add-keyworg-table-storage-type

Revision ID: 17b5ab037c40
Revises: a8f9b3c45e4a
Create Date: 2024-04-01 09:48:54.232201

"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = '17b5ab037c40'
down_revision = 'a8f9b3c45e4a'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op:
        batch_op.add_column(sa.Column('data_source_type', sa.String(length=255), server_default=sa.text("'database'::character varying"), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###

    with op.batch_alter_table('dataset_keyword_tables', schema=None) as batch_op:
        batch_op.drop_column('data_source_type')

    # ### end Alembic commands ###
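Assuming the usual Flask-Migrate workflow in api/, the column lands with the standard upgrade command:

cd api
flask db upgrade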
23 changes: 21 additions & 2 deletions api/models/dataset.py
@@ -1,11 +1,13 @@
import json
import logging
import pickle
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB, UUID

from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Account
from models.model import App, UploadFile

@@ -441,6 +443,7 @@ class DatasetKeywordTable(db.Model):
    id = db.Column(UUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(UUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False, server_default=db.text("'database'::character varying"))

    @property
    def keyword_table_dict(self):
@@ -454,8 +457,24 @@ def object_hook(self, dct):
                    if isinstance(node_idxs, list):
                        dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(
            id=self.dataset_id
        ).first()
        if not dataset:
            return None
        if self.data_source_type == 'database':
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt'
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder)
                return None
            except Exception as e:
                logging.exception(str(e))
                return None
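Callers keep reading the property as before; the branch on data_source_type hides where the table actually lives. A small usage sketch (dataset_id is a placeholder):

row = DatasetKeywordTable.query.filter_by(dataset_id=dataset_id).first()
table_dict = row.keyword_table_dict if row else None
table = table_dict['__data__']['table'] if table_dict else {}  # same shape for both backends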


class Embedding(db.Model):