diff --git a/backend/apps/data_training/api/data_training.py b/backend/apps/data_training/api/data_training.py
index 540eaf84..58f4a1f9 100644
--- a/backend/apps/data_training/api/data_training.py
+++ b/backend/apps/data_training/api/data_training.py
@@ -1,9 +1,15 @@
+import asyncio
+import io
from typing import Optional
+import pandas as pd
from fastapi import APIRouter, Query
+from fastapi.responses import StreamingResponse
+from apps.chat.models.chat_model import AxisObj
+from apps.chat.task.llm import LLMService
from apps.data_training.curd.data_training import page_data_training, create_training, update_training, delete_training, \
- enable_training
+ enable_training, get_all_data_training
from apps.data_training.models.data_training_model import DataTrainingInfo
from common.core.deps import SessionDep, CurrentUser, Trans
@@ -43,3 +49,44 @@ async def delete(session: SessionDep, id_list: list[int]):
@router.get("/{id}/enable/{enabled}")
async def enable(session: SessionDep, id: int, enabled: bool, trans: Trans):
enable_training(session, id, enabled, trans)
+
+
+@router.get("/export")
+async def export_excel(session: SessionDep, trans: Trans, current_user: CurrentUser,
+ word: Optional[str] = Query(None, description="搜索术语(可选)")):
+ def inner():
+ _list = get_all_data_training(session, word, oid=current_user.oid)
+
+ data_list = []
+ for obj in _list:
+ _data = {
+ "question": obj.question,
+ "description": obj.description,
+ "datasource_name": obj.datasource_name,
+ "advanced_application_name": obj.advanced_application_name,
+ }
+ data_list.append(_data)
+
+ fields = []
+ fields.append(AxisObj(name=trans('i18n_data_training.data_training'), value='question'))
+ fields.append(AxisObj(name=trans('i18n_data_training.problem_description'), value='description'))
+ fields.append(AxisObj(name=trans('i18n_data_training.effective_data_sources'), value='datasource_name'))
+ if current_user.oid == 1:
+ fields.append(
+ AxisObj(name=trans('i18n_data_training.advanced_application'), value='advanced_application_name'))
+
+ md_data, _fields_list = LLMService.convert_object_array_for_pandas(fields, data_list)
+
+ df = pd.DataFrame(md_data, columns=_fields_list)
+
+ buffer = io.BytesIO()
+
+ with pd.ExcelWriter(buffer, engine='xlsxwriter',
+ engine_kwargs={'options': {'strings_to_numbers': False}}) as writer:
+ df.to_excel(writer, sheet_name='Sheet1', index=False)
+
+ buffer.seek(0)
+ return io.BytesIO(buffer.getvalue())
+
+ result = await asyncio.to_thread(inner)
+ return StreamingResponse(result, media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")
diff --git a/backend/apps/data_training/curd/data_training.py b/backend/apps/data_training/curd/data_training.py
index 011b7fdb..13291e56 100644
--- a/backend/apps/data_training/curd/data_training.py
+++ b/backend/apps/data_training/curd/data_training.py
@@ -18,42 +18,61 @@
from common.utils.embedding_threads import run_save_data_training_embeddings
-def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10, name: Optional[str] = None,
- oid: Optional[int] = 1):
- _list: List[DataTrainingInfoResult] = []
-
- current_page = max(1, current_page)
- page_size = max(10, page_size)
-
- total_count = 0
- total_pages = 0
-
+def get_data_training_base_query(oid: int, name: Optional[str] = None):
+ """
+ 获取数据训练查询的基础查询结构
+ """
if name and name.strip() != "":
keyword_pattern = f"%{name.strip()}%"
parent_ids_subquery = (
select(DataTraining.id)
- .where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid)) # LIKE查询条件
+ .where(and_(DataTraining.question.ilike(keyword_pattern), DataTraining.oid == oid))
)
else:
parent_ids_subquery = (
select(DataTraining.id).where(and_(DataTraining.oid == oid))
)
+ return parent_ids_subquery
+
+
+def build_data_training_query(session: SessionDep, oid: int, name: Optional[str] = None,
+ paginate: bool = True, current_page: int = 1, page_size: int = 10):
+ """
+ 构建数据训练查询的通用方法
+ """
+ parent_ids_subquery = get_data_training_base_query(oid, name)
+
+ # 计算总数
count_stmt = select(func.count()).select_from(parent_ids_subquery.subquery())
total_count = session.execute(count_stmt).scalar()
- total_pages = (total_count + page_size - 1) // page_size
- if current_page > total_pages:
+ if paginate:
+ # 分页处理
+ page_size = max(10, page_size)
+ total_pages = (total_count + page_size - 1) // page_size
+ current_page = max(1, min(current_page, total_pages)) if total_pages > 0 else 1
+
+ paginated_parent_ids = (
+ parent_ids_subquery
+ .order_by(DataTraining.create_time.desc())
+ .offset((current_page - 1) * page_size)
+ .limit(page_size)
+ .subquery()
+ )
+ else:
+ # 不分页,获取所有数据
+ total_pages = 1
current_page = 1
+ page_size = total_count if total_count > 0 else 1
- paginated_parent_ids = (
- parent_ids_subquery
- .order_by(DataTraining.create_time.desc())
- .offset((current_page - 1) * page_size)
- .limit(page_size)
- .subquery()
- )
+ paginated_parent_ids = (
+ parent_ids_subquery
+ .order_by(DataTraining.create_time.desc())
+ .subquery()
+ )
+ # 构建主查询
stmt = (
select(
DataTraining.id,
@@ -74,6 +93,14 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
.order_by(DataTraining.create_time.desc())
)
+ return stmt, total_count, total_pages, current_page, page_size
+
+
+def execute_data_training_query(session: SessionDep, stmt) -> List[DataTrainingInfoResult]:
+ """
+ 执行查询并返回数据训练信息列表
+ """
+ _list = []
result = session.execute(stmt)
for row in result:
@@ -90,9 +117,34 @@ def page_data_training(session: SessionDep, current_page: int = 1, page_size: in
advanced_application_name=row.advanced_application_name,
))
+ return _list
+
+
+def page_data_training(session: SessionDep, current_page: int = 1, page_size: int = 10,
+ name: Optional[str] = None, oid: Optional[int] = 1):
+ """
+ 分页查询数据训练(原方法保持不变)
+ """
+ stmt, total_count, total_pages, current_page, page_size = build_data_training_query(
+ session, oid, name, True, current_page, page_size
+ )
+ _list = execute_data_training_query(session, stmt)
+
return current_page, page_size, total_count, total_pages, _list
+def get_all_data_training(session: SessionDep, name: Optional[str] = None, oid: Optional[int] = 1):
+ """
+ 获取所有数据训练(不分页)
+ """
+ stmt, total_count, total_pages, current_page, page_size = build_data_training_query(
+ session, oid, name, False
+ )
+ _list = execute_data_training_query(session, stmt)
+
+ return _list
+
+
def create_training(session: SessionDep, info: DataTrainingInfo, oid: int, trans: Trans):
create_time = datetime.datetime.now()
if info.datasource is None and info.advanced_application is None:
diff --git a/backend/locales/en.json b/backend/locales/en.json
index 131a5ec7..44cc60ed 100644
--- a/backend/locales/en.json
+++ b/backend/locales/en.json
@@ -53,7 +53,12 @@
"datasource_cannot_be_none": "Datasource cannot be empty",
"datasource_assistant_cannot_be_none": "Datasource or advanced application cannot both be empty",
"data_training_not_exists": "This example does not exist",
- "exists_in_db": "This question already exists"
+ "exists_in_db": "This question already exists",
+ "data_training": "SQL Example Library",
+ "problem_description": "Problem Description",
+ "sample_sql": "Sample SQL",
+ "effective_data_sources": "Effective Data Sources",
+ "advanced_application": "Advanced Application"
},
"i18n_custom_prompt": {
"exists_in_db": "Template name already exists",
diff --git a/backend/locales/ko-KR.json b/backend/locales/ko-KR.json
index e7a06ba3..ce87967f 100644
--- a/backend/locales/ko-KR.json
+++ b/backend/locales/ko-KR.json
@@ -53,7 +53,12 @@
"datasource_cannot_be_none": "데이터 소스는 비울 수 없습니다",
"datasource_assistant_cannot_be_none": "데이터 소스와 고급 애플리케이션을 모두 비울 수 없습니다",
"data_training_not_exists": "이 예시가 존재하지 않습니다",
- "exists_in_db": "이 질문이 이미 존재합니다"
+ "exists_in_db": "이 질문이 이미 존재합니다",
+ "data_training": "SQL 예시 라이브러리",
+ "problem_description": "문제 설명",
+ "sample_sql": "예시 SQL",
+ "effective_data_sources": "유효 데이터 소스",
+ "advanced_application": "고급 애플리케이션"
},
"i18n_custom_prompt": {
"exists_in_db": "템플릿 이름이 이미 존재합니다",
diff --git a/backend/locales/zh-CN.json b/backend/locales/zh-CN.json
index c2dbdcaf..1b057e0a 100644
--- a/backend/locales/zh-CN.json
+++ b/backend/locales/zh-CN.json
@@ -53,7 +53,12 @@
"datasource_cannot_be_none": "数据源不能为空",
"datasource_assistant_cannot_be_none": "数据源或高级应用不能都为空",
"data_training_not_exists": "该示例不存在",
- "exists_in_db": "该问题已存在"
+ "exists_in_db": "该问题已存在",
+ "data_training": "SQL 示例库",
+ "problem_description": "问题描述",
+ "sample_sql": "示例 SQL",
+ "effective_data_sources": "生效数据源",
+ "advanced_application": "高级应用"
},
"i18n_custom_prompt": {
"exists_in_db": "模版名称已存在",
diff --git a/frontend/src/api/training.ts b/frontend/src/api/training.ts
index a4342145..ebc68e69 100644
--- a/frontend/src/api/training.ts
+++ b/frontend/src/api/training.ts
@@ -9,4 +9,10 @@ export const trainingApi = {
deleteEmbedded: (params: any) => request.delete('/system/data-training', { data: params }),
getOne: (id: any) => request.get(`/system/data-training/${id}`),
enable: (id: any, enabled: any) => request.get(`/system/data-training/${id}/enable/${enabled}`),
+ export2Excel: (params: any) =>
+ request.get(`/system/data-training/export`, {
+ params,
+ responseType: 'blob',
+ requestOptions: { customError: true },
+ }),
}
diff --git a/frontend/src/i18n/en.json b/frontend/src/i18n/en.json
index e12906dd..f0c0edcd 100644
--- a/frontend/src/i18n/en.json
+++ b/frontend/src/i18n/en.json
@@ -40,6 +40,7 @@
"training_data_items": "Do you want to delete the {msg} selected SQL Sample items?",
"sql_statement": "SQL Statement",
"edit_training_data": "Edit SQL Sample",
+ "all_236_terms": "Export all {msg} sample SQL records?",
"sales_this_year": "Do you want to delete the SQL Sample: {msg}?"
},
"professional": {
diff --git a/frontend/src/i18n/ko-KR.json b/frontend/src/i18n/ko-KR.json
index 342b47ba..09f19bbd 100644
--- a/frontend/src/i18n/ko-KR.json
+++ b/frontend/src/i18n/ko-KR.json
@@ -40,6 +40,7 @@
"training_data_items": "선택된 {msg}개의 예제 SQL을 삭제하시겠습니까?",
"sql_statement": "SQL 문",
"edit_training_data": "예제 SQL 편집",
+ "all_236_terms": "모든 {msg}개의 예시 SQL 기록을 내보내시겠습니까?",
"sales_this_year": "예제 SQL을 삭제하시겠습니까: {msg}?"
},
"professional": {
diff --git a/frontend/src/i18n/zh-CN.json b/frontend/src/i18n/zh-CN.json
index 4ffb3f58..e95b659b 100644
--- a/frontend/src/i18n/zh-CN.json
+++ b/frontend/src/i18n/zh-CN.json
@@ -40,6 +40,7 @@
"training_data_items": "是否删除选中的 {msg} 条示例 SQL?",
"sql_statement": "SQL 语句",
"edit_training_data": "编辑示例 SQL",
+ "all_236_terms": "是否导出全部 {msg} 条示例 SQL?",
"sales_this_year": "是否删除示例 SQL:{msg}?"
},
"professional": {
diff --git a/frontend/src/views/system/training/index.vue b/frontend/src/views/system/training/index.vue
index 381ee218..772e489a 100644
--- a/frontend/src/views/system/training/index.vue
+++ b/frontend/src/views/system/training/index.vue
@@ -85,46 +85,60 @@ const cancelDelete = () => {
checkAll.value = false
isIndeterminate.value = false
}
-const exportBatchUser = () => {
- ElMessageBox.confirm(
- t('professional.selected_2_terms_de', { msg: multipleSelectionAll.value.length }),
- {
- confirmButtonType: 'primary',
- confirmButtonText: t('professional.export'),
- cancelButtonText: t('common.cancel'),
- customClass: 'confirm-no_icon',
- autofocus: false,
- }
- ).then(() => {
- trainingApi.deleteEmbedded(multipleSelectionAll.value.map((ele) => ele.id)).then(() => {
- ElMessage({
- type: 'success',
- message: t('dashboard.delete_success'),
- })
- multipleSelectionAll.value = []
- search()
- })
- })
-}
-const exportAllUser = () => {
- ElMessageBox.confirm(t('professional.all_236_terms', { msg: pageInfo.total }), {
+const exportExcel = () => {
+ ElMessageBox.confirm(t('training.all_236_terms', { msg: pageInfo.total }), {
confirmButtonType: 'primary',
confirmButtonText: t('professional.export'),
cancelButtonText: t('common.cancel'),
customClass: 'confirm-no_icon',
autofocus: false,
}).then(() => {
- trainingApi.deleteEmbedded(multipleSelectionAll.value.map((ele) => ele.id)).then(() => {
- ElMessage({
- type: 'success',
- message: t('dashboard.delete_success'),
+ searchLoading.value = true
+ trainingApi
+ .export2Excel(keywords.value ? { question: keywords.value } : {})
+ .then((res) => {
+ const blob = new Blob([res], {
+ type: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
+ })
+ const link = document.createElement('a')
+ link.href = URL.createObjectURL(blob)
+ link.download = `${t('training.data_training')}.xlsx`
+ document.body.appendChild(link)
+ link.click()
+ document.body.removeChild(link)
+ })
+ .catch(async (error) => {
+ if (error.response) {
+ try {
+ let text = await error.response.data.text()
+ try {
+ text = JSON.parse(text)
+ } finally {
+ ElMessage({
+ message: text,
+ type: 'error',
+ showClose: true,
+ })
+ }
+ } catch (e) {
+ console.error('Error processing error response:', e)
+ }
+ } else {
+ console.error('Other error:', error)
+ ElMessage({
+ message: error,
+ type: 'error',
+ showClose: true,
+ })
+ }
+ })
+ .finally(() => {
+ searchLoading.value = false
})
- multipleSelectionAll.value = []
- search()
- })
})
}
+
const deleteBatchUser = () => {
ElMessageBox.confirm(
t('training.training_data_items', { msg: multipleSelectionAll.value.length }),
@@ -363,20 +377,18 @@ const onRowFormClose = () => {
-
-
-
-
-
- {{ $t('professional.export_all') }}
-
-
-
-
-
- {{ $t('user.batch_import') }}
-
-
+
+
+
+
+ {{ $t('professional.export_all') }}
+
+
+
+
+
+ {{ $t('user.batch_import') }}
+
@@ -501,9 +513,6 @@ const onRowFormClose = () => {
>
{{ $t('datasource.select_all') }}
-
@@ -646,7 +655,7 @@ const onRowFormClose = () => {
position: relative;
:deep(.ed-table__empty-text) {
- padding-top: 160px;
+ padding-top: 160px;
}
.datasource-yet {