Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 36 additions & 61 deletions backend/apps/db/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,16 +404,7 @@ def get_tables(ds: CoreDatasource):
"excel") else get_engine_config()
db = DB.get_db(ds.type)
sql, sql_param = get_table_sql(ds, conf, get_version(ds))
if equals_ignore_case(ds.type, "sqlite"):
engine = get_engine(ds)
with engine.raw_connection() as conn:
cursor = conn.cursor()
cursor.execute(sql)
res = cursor.fetchall()
cursor.close()
res_list = [TableSchema(*item) for item in res]
return res_list
elif db.connect_type == ConnectType.sqlalchemy:
if db.connect_type == ConnectType.sqlalchemy:
with get_session(ds) as session:
with session.execute(text(sql), {"param": sql_param}) as result:
res = result.fetchall()
Expand Down Expand Up @@ -460,36 +451,27 @@ def get_tables(ds: CoreDatasource):
res_list = [TableSchema(*item) for item in res]
return res_list
elif equals_ignore_case(ds.type, 'hive'):
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict)
cursor = conn.cursor()
cursor.execute(sql)
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
cursor.close()
conn.close()
return res_list
with hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql)
res = cursor.fetchall()
res_list = [TableSchema(*item) for item in res]
return res_list


def get_fields(ds: CoreDatasource, table_name: str = None):
conf = DatasourceConf(**json.loads(aes_decrypt(ds.configuration))) if not equals_ignore_case(ds.type,
"excel") else get_engine_config()
db = DB.get_db(ds.type)
sql, p1, p2 = get_field_sql(ds, conf, table_name)
if equals_ignore_case(ds.type, "sqlite"):
engine = get_engine(ds)
with engine.raw_connection() as conn:
cursor = conn.cursor()
cursor.execute(sql)
res = cursor.fetchall()
cursor.close()
res_list = [ColumnSchema(item[1], item[2], '') for item in res]
return res_list
elif db.connect_type == ConnectType.sqlalchemy:
if db.connect_type == ConnectType.sqlalchemy:
with get_session(ds) as session:
with session.execute(text(sql), {"param1": p1, "param2": p2}) as result:
res = result.fetchall()
res_list = [ColumnSchema(*item) for item in res]
if equals_ignore_case(ds.type, "sqlite"):
res_list = [ColumnSchema(item[1], item[2], '') for item in res]
else:
res_list = [ColumnSchema(*item) for item in res]
return res_list
else:
extra_config_dict = get_extra_config(conf)
Expand Down Expand Up @@ -532,15 +514,12 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
res_list = [ColumnSchema(*item) for item in res]
return res_list
elif equals_ignore_case(ds.type, 'hive'):
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict)
cursor = conn.cursor()
cursor.execute(sql)
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
cursor.close()
conn.close()
return res_list
with hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict) as conn, conn.cursor() as cursor:
cursor.execute(sql)
res = cursor.fetchall()
res_list = [ColumnSchema(*item) for item in res]
return res_list


def convert_value(value, datetime_format='space'):
Expand Down Expand Up @@ -730,28 +709,24 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
except Exception as ex:
raise Exception(str(ex))
elif equals_ignore_case(ds.type, 'hive'):
conn = hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict)
cursor = conn.cursor()
try:
# Hive uses backticks for identifiers; normalize quoted identifiers as a compatibility fallback.
hive_sql = re.sub(r'"([A-Za-z_][A-Za-z0-9_]*)"', r'`\1`', sql)
cursor.execute(hive_sql)
res = cursor.fetchall()
columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for
field in
cursor.description]
result_list = [
{str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in
res
]
return {"fields": columns, "data": result_list,
"sql": bytes.decode(base64.b64encode(bytes(hive_sql, 'utf-8')))}
except Exception as ex:
raise ParseSQLResultError(str(ex))
finally:
cursor.close()
conn.close()
with hive.connect(host=conf.host, port=conf.port, username=conf.username,
database=conf.database, **extra_config_dict) as conn, conn.cursor() as cursor:
try:
# Hive uses backticks for identifiers; normalize quoted identifiers as a compatibility fallback.
hive_sql = re.sub(r'"([A-Za-z_][A-Za-z0-9_]*)"', r'`\1`', sql)
cursor.execute(hive_sql)
res = cursor.fetchall()
columns = [field[0] for field in cursor.description] if origin_column else [field[0].lower() for
field in
cursor.description]
result_list = [
{str(columns[i]): convert_value(value) for i, value in enumerate(tuple_item)} for tuple_item in
res
]
return {"fields": columns, "data": result_list,
"sql": bytes.decode(base64.b64encode(bytes(hive_sql, 'utf-8')))}
except Exception as ex:
raise ParseSQLResultError(str(ex))


def check_sql_read(sql: str, ds: CoreDatasource | AssistantOutDsSchema):
Expand Down
Loading