Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 2 additions & 30 deletions backend/apps/chat/api/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,42 +206,14 @@ def _err(_e: Exception):
@router.post("/excel/export")
async def export_excel(excel_data: ExcelData, trans: Trans):
def inner():
_fields_list = []
data = []

if not excel_data.data:
raise HTTPException(
status_code=500,
detail=trans("i18n_excel_export.data_is_empty")
)

# 预处理数据并记录每列的格式类型
col_formats = {} # 格式类型:'text'(文本)、'number'(数字)、'default'(默认)
for field_idx, field in enumerate(excel_data.axis):
_fields_list.append(field.name)
col_formats[field_idx] = 'default' # 默认不特殊处理

for _data in excel_data.data:
_row = []
for field_idx, field in enumerate(excel_data.axis):
value = _data.get(field.value)
if value is not None:
# 检查是否为数字且需要特殊处理
if isinstance(value, (int, float)):
# 整数且超过15位 → 转字符串并标记为文本列
if isinstance(value, int) and len(str(abs(value))) > 15:
value = str(value)
col_formats[field_idx] = 'text'
# 小数且超过15位有效数字 → 转字符串并标记为文本列
elif isinstance(value, float):
decimal_str = format(value, '.16f').rstrip('0').rstrip('.')
if len(decimal_str) > 15:
value = str(value)
col_formats[field_idx] = 'text'
# 其他数字列标记为数字格式(避免科学记数法)
elif col_formats[field_idx] != 'text':
col_formats[field_idx] = 'number'
_row.append(value)
data.append(_row)
data, _fields_list, col_formats = LLMService.format_pd_data(excel_data.axis, excel_data.data)

df = pd.DataFrame(data, columns=_fields_list)

Expand Down
100 changes: 75 additions & 25 deletions backend/apps/chat/task/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
get_chat_chart_data, list_generate_sql_logs, list_generate_chart_logs, start_log, end_log, \
get_last_execute_sql_error
from apps.chat.models.chat_model import ChatQuestion, ChatRecord, Chat, RenameChat, ChatLog, OperationEnum, \
ChatFinishStep
ChatFinishStep, AxisObj
from apps.data_training.curd.data_training import get_training_template
from apps.datasource.crud.datasource import get_table_schema
from apps.datasource.crud.permission import get_row_permission_filters, is_normal_user
Expand Down Expand Up @@ -414,7 +414,7 @@ def select_datasource(self, _session: Session):
if settings.TABLE_EMBEDDING_ENABLED and (
not self.current_assistant or (self.current_assistant and self.current_assistant.type != 1)):
_ds_list = get_ds_embedding(_session, self.current_user, _ds_list, self.out_ds_instance,
self.chat_question.question, self.current_assistant)
self.chat_question.question, self.current_assistant)
# yield {'content': '{"id":' + str(ds.get('id')) + '}'}

_ds_list_dict = []
Expand Down Expand Up @@ -1056,23 +1056,18 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
if in_chat:
yield 'data:' + orjson.dumps({'type': 'finish'}).decode() + '\n\n'
else:
data = []
_fields_list = []
_fields_skip = False
for _data in result.get('data'):
_row = []
for field in result.get('fields'):
_row.append(_data.get(field))
if not _fields_skip:
_fields_list.append(field)
data.append(_row)
_fields_skip = True
_column_list = []
for field in result.get('fields'):
_column_list.append(AxisObj(name=field, value=field))

data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data'))

if not data or not _fields_list:
yield 'The SQL execution result is empty.\n\n'
else:
df = pd.DataFrame(data, columns=_fields_list)
markdown_table = df.to_markdown(index=False)
df_safe = self.safe_convert_to_string(df)
markdown_table = df_safe.to_markdown(index=False)
yield markdown_table + '\n\n'
else:
yield json_result
Expand Down Expand Up @@ -1117,22 +1112,19 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
if chart.get('axis').get('series'):
_fields[chart.get('axis').get('series').get('value')] = chart.get('axis').get('series').get(
'name')
_fields_list = []
_fields_skip = False
for _data in result.get('data'):
_row = []
for field in result.get('fields'):
_row.append(_data.get(field))
if not _fields_skip:
_fields_list.append(field if not _fields.get(field) else _fields.get(field))
data.append(_row)
_fields_skip = True
_column_list = []
for field in result.get('fields'):
_column_list.append(
AxisObj(name=field if not _fields.get(field) else _fields.get(field), value=field))

data, _fields_list, col_formats = self.format_pd_data(_column_list, result.get('data'))

if not data or not _fields_list:
yield 'The SQL execution result is empty.\n\n'
else:
df = pd.DataFrame(data, columns=_fields_list)
markdown_table = df.to_markdown(index=False)
df_safe = self.safe_convert_to_string(df)
markdown_table = df_safe.to_markdown(index=False)
yield markdown_table + '\n\n'

if in_chat:
Expand Down Expand Up @@ -1179,6 +1171,64 @@ def run_task(self, in_chat: bool = True, stream: bool = True,
self.finish(_session)
session_maker.remove()

@staticmethod
def safe_convert_to_string(df):
"""
安全地将数值列转换为字符串,避免科学记数法
"""
df_copy = df.copy()

for col in df_copy.columns:
# 只处理数值类型的列
if pd.api.types.is_numeric_dtype(df_copy[col]):
try:
df_copy[col] = df_copy[col].astype(str)
except Exception as e:
print(f"列 {col} 转换失败: {e}")
# 如果转换失败,保持原样
continue

return df_copy

@staticmethod
def format_pd_data(column_list: list, data_list: list, col_formats: dict = None):
# 预处理数据并记录每列的格式类型
# 格式类型:'text'(文本)、'number'(数字)、'default'(默认)
_fields_list = []

if col_formats is None:
col_formats = {}
for field_idx, field in enumerate(column_list):
_fields_list.append(field.name)
col_formats[field_idx] = 'default' # 默认不特殊处理

data = []

for _data in data_list:
_row = []
for field_idx, field in enumerate(column_list):
value = _data.get(field.value)
if value is not None:
# 检查是否为数字且需要特殊处理
if isinstance(value, (int, float)):
# 整数且超过15位 → 转字符串并标记为文本列
if isinstance(value, int) and len(str(abs(value))) > 15:
value = str(value)
col_formats[field_idx] = 'text'
# 小数且超过15位有效数字 → 转字符串并标记为文本列
elif isinstance(value, float):
decimal_str = format(value, '.16f').rstrip('0').rstrip('.')
if len(decimal_str) > 15:
value = str(value)
col_formats[field_idx] = 'text'
# 其他数字列标记为数字格式(避免科学记数法)
elif col_formats[field_idx] != 'text':
col_formats[field_idx] = 'number'
_row.append(value)
data.append(_row)

return data, _fields_list, col_formats

def run_recommend_questions_task_async(self):
self.future = executor.submit(self.run_recommend_questions_task_cache)

Expand Down