1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
@@ -43,4 +43,5 @@ jobs:
env:
TEST_DATABEND_DSN: "http://databend:databend@localhost:8000/default"
run: |
make lint
make ci
5 changes: 4 additions & 1 deletion Makefile
@@ -4,6 +4,9 @@ test:
ci:
python tests/test_client.py

lint:
pyflakes .

install:
pip install -r requirements.txt
pip install -e .
pip install -e .
2 changes: 1 addition & 1 deletion databend_py/__init__.py
@@ -2,7 +2,7 @@
from .connection import Connection
from .datetypes import DatabendDataType

VERSION = (0, 4, 1)
VERSION = (0, 4, 2)
__version__ = '.'.join(str(x) for x in VERSION)

__all__ = ['Client', 'Connection', 'DatabendDataType']
78 changes: 46 additions & 32 deletions databend_py/client.py
@@ -5,7 +5,7 @@
from .util.helper import asbool, Helper
from .util.escape import escape_params
from .result import QueryResult
import json, operator, csv, uuid, requests, time, os
import json, csv, uuid, requests, time


class Client(object):
@@ -19,6 +19,7 @@ def __init__(self, *args, **kwargs):
self.connection = Connection(*args, **kwargs)
self.query_result_cls = QueryResult
self.helper = Helper
self._debug = asbool(self.settings.get('debug', False))

def __enter__(self):
return self
@@ -29,39 +30,39 @@ def disconnect(self):
def disconnect_connection(self):
self.connection.disconnect()

def data_generator(self, raw_data):
def _data_generator(self, raw_data):
while raw_data['next_uri'] is not None:
try:
raw_data = self.receive_data(raw_data['next_uri'])
raw_data = self._receive_data(raw_data['next_uri'])
yield raw_data
except (Exception, KeyboardInterrupt):
self.disconnect()
raise

def receive_data(self, next_uri: str):
def _receive_data(self, next_uri: str):
resp = self.connection.next_page(next_uri)
raw_data = json.loads(resp.content)
helper = self.helper()
helper.response = raw_data
helper.check_error()
return raw_data

def receive_result(self, query, query_id=None, with_column_types=False):
def _receive_result(self, query, query_id=None, with_column_types=False):
raw_data = self.connection.query(query)
helper = self.helper()
helper.response = raw_data
helper.check_error()
gen = self.data_generator(raw_data)
gen = self._data_generator(raw_data)
result = self.query_result_cls(
gen, raw_data, with_column_types=with_column_types)
return result.get_result()

def iter_receive_result(self, query, query_id=None, with_column_types=False):
def _iter_receive_result(self, query, query_id=None, with_column_types=False):
raw_data = self.connection.query(query)
helper = self.helper()
helper.response = raw_data
helper.check_error()
gen = self.data_generator(raw_data)
gen = self._data_generator(raw_data)
result = self.query_result_cls(
gen, raw_data, with_column_types=with_column_types)
_, rows = result.get_result()
@@ -104,16 +105,16 @@ def execute(self, query, params=None, with_column_types=False,
if is_insert:
# remove the `\n` '\s' `\t` in the SQL
query = " ".join([s.strip() for s in query.splitlines()]).strip()
rv = self.process_insert_query(query, params)
rv = self._process_insert_query(query, params)
return [], rv

column_types, rv = self.process_ordinary_query(
column_types, rv = self._process_ordinary_query(
query, params=params, with_column_types=with_column_types,
query_id=query_id)
return column_types, rv

# params = [(1,),(2,)] or params = [(1,2),(2,3)]
def process_insert_query(self, query, params):
def _process_insert_query(self, query, params):
insert_rows = 0
if "values" in query:
query = query.split("values")[0] + 'values'
@@ -128,32 +129,32 @@ def process_insert_query(self, query, params):
batch_size = query.count(',') + 1
if params is not None:
tuple_ls = [tuple(params[i:i + batch_size]) for i in range(0, len(params), batch_size)]
csv_data, filename = self.generate_csv_data(tuple_ls)
self.sync_csv_file_into_table(filename, csv_data, table_name, "CSV")
csv_data, filename = self._generate_csv_data(tuple_ls)
self._sync_csv_file_into_table(filename, csv_data, table_name, "CSV")
insert_rows = len(tuple_ls)

return insert_rows
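The insert path above chunks the parameter list into per-row tuples, writes them to an in-memory CSV, and loads the file through a presigned stage upload plus `COPY INTO`. A rough usage sketch, assuming a flat parameter list that the code above splits by the number of value slots (the table and values are made up; the DSN is the one used in this repo's CI):

```python
from databend_py import Client

# hypothetical table and data; DSN taken from the CI workflow above
client = Client.from_url("http://databend:databend@localhost:8000/default")
client.execute("CREATE TABLE IF NOT EXISTS sample (id INT, name VARCHAR)")

# the flat list is split into rows of batch_size values, serialized to
# CSV in memory, staged via a presigned upload and COPY'd into the table;
# for inserts, execute() returns ([], <rows inserted>)
_, inserted = client.execute(
    "insert into sample (id, name) values",
    [1, "a", 2, "b"],
)
print(inserted)  # 2
```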

def process_ordinary_query(self, query, params=None, with_column_types=False,
def _process_ordinary_query(self, query, params=None, with_column_types=False,
query_id=None):
if params is not None:
query = self.substitute_params(
query = self._substitute_params(
query, params, self.connection.context
)
return self.receive_result(query, query_id=query_id, with_column_types=with_column_types, )
return self._receive_result(query, query_id=query_id, with_column_types=with_column_types, )

def execute_iter(self, query, params=None, with_column_types=False,
query_id=None, settings=None):
if params is not None:
query = self.substitute_params(
query = self._substitute_params(
query, params, self.connection.context
)
return self.iter_receive_result(query, query_id=query_id, with_column_types=with_column_types)
return self._iter_receive_result(query, query_id=query_id, with_column_types=with_column_types)

def iter_process_ordinary_query(self, query, with_column_types=False, query_id=None):
return self.iter_receive_result(query, query_id=query_id, with_column_types=with_column_types)
def _iter_process_ordinary_query(self, query, with_column_types=False, query_id=None):
return self._iter_receive_result(query, query_id=query_id, with_column_types=with_column_types)

def substitute_params(self, query, params, context):
def _substitute_params(self, query, params, context):
if not isinstance(params, dict):
raise ValueError('Parameters are expected in dict form')
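For ordinary queries, `_substitute_params` only accepts a dict of parameters; a small sketch of how that surfaces through `execute` and `execute_iter` (the `%(name)s` placeholder style is an assumption based on comparable drivers, not something this diff shows):

```python
from databend_py import Client

client = Client.from_url("http://databend:databend@localhost:8000/default")

# params must be a dict, otherwise _substitute_params raises ValueError;
# the %(limit)s placeholder syntax is assumed, not confirmed by this diff
columns, rows = client.execute(
    "SELECT number FROM numbers(10) WHERE number < %(limit)s",
    params={"limit": 3},
    with_column_types=True,
)

# execute_iter streams rows one by one via _iter_receive_result
for row in client.execute_iter("SELECT number FROM numbers(5)"):
    print(row)
```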

@@ -197,6 +198,8 @@ def from_url(cls, url):
elif name == 'copy_purge':
kwargs[name] = asbool(value)
settings[name] = asbool(value)
elif name == 'debug':
settings[name] = asbool(value)
elif name in timeouts:
kwargs[name] = float(value)
else:
@@ -224,37 +227,48 @@ def from_url(cls, url):

return cls(host, **kwargs)

def generate_csv_data(self, bindings):
def _generate_csv_data(self, bindings):
file_name = f'{uuid.uuid4()}.csv'
buffer = io.StringIO()
csvwriter = csv.writer(buffer, delimiter=',', quoting=csv.QUOTE_MINIMAL)
csvwriter.writerows(bindings)
buffer.seek(0) # Move the buffer's position to the beginning
return buffer.getvalue(), file_name

def get_file_data(self, filename):
def _get_file_data(self, filename):
with open(filename, "r") as f:
return io.StringIO(f.read())

def stage_csv_file(self, filename, data):
stage_path = "@~/%s" % filename

start_presign_time = time.time()
_, row = self.execute('presign upload %s' % stage_path)
if self._debug:
print("upload: presign file:%s duration:%ss" % (filename, time.time() - start_presign_time))

presigned_url = row[0][2]
headers = json.loads(row[0][1])
resp = requests.put(presigned_url, headers=headers, data=data)
resp.raise_for_status()

start_upload_time = time.time()
try:
resp = requests.put(presigned_url, headers=headers, data=data)
resp.raise_for_status()
finally:
if self._debug:
print("upload: put file:%s duration:%ss" % (filename, time.time() - start_upload_time))
return stage_path

def sync_csv_file_into_table(self, filename, data, table, file_type):
def _sync_csv_file_into_table(self, filename, data, table, file_type):
start = time.time()
stage_path = self.stage_csv_file(filename, data)
copy_options = self.generate_copy_options()
copy_options = self._generate_copy_options()
_, _ = self.execute(
f"COPY INTO {table} FROM {stage_path} FILE_FORMAT = (type = {file_type} RECORD_DELIMITER = '\r\n')\
PURGE = {copy_options['PURGE']} FORCE = {copy_options['FORCE']}\
SIZE_LIMIT={copy_options['SIZE_LIMIT']} ON_ERROR = {copy_options['ON_ERROR']}")
print("sync %s duration:%ss" % (filename, int(time.time() - start)))
# os.remove(filename)
if self._debug:
print("upload: copy %s duration:%ss" % (filename, int(time.time() - start)))

def upload(self, file_name, table_name, file_type=None):
"""
@@ -268,10 +282,10 @@ def upload(self, file_name, table_name, file_type=None):
file_type = file_name.split(".")[1].upper()
else:
file_type = "CSV"
file_data = self.get_file_data(file_name)
self.sync_csv_file_into_table(file_name, file_data, table_name, file_type)
file_data = self._get_file_data(file_name)
self._sync_csv_file_into_table(file_name, file_data, table_name, file_type)

def generate_copy_options(self):
def _generate_copy_options(self):
# copy options docs: https://databend.rs/doc/sql-commands/dml/dml-copy-into-table#copyoptions
copy_options = {}
if "copy_purge" in self.settings:
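Net effect of the client.py changes: the staging helpers become private (`_`-prefixed), and the upload timing `print`s are gated behind a new `debug` setting that `from_url` parses with `asbool`, with extra presign/PUT timings added. A minimal sketch of the new flag together with the public `upload` helper (file and table names are hypothetical; the DSN matches CI):

```python
from databend_py import Client

# debug=true is parsed by from_url via asbool() and enables the
# presign / PUT / COPY timing prints added in this PR
client = Client.from_url(
    "http://databend:databend@localhost:8000/default?debug=true"
)

# upload() infers the file format from the extension (CSV here),
# stages the file with a presigned PUT and COPYs it into the table;
# the target table is assumed to exist already
client.upload("sample.csv", "sample")
```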
1 change: 0 additions & 1 deletion databend_py/result.py
@@ -1,4 +1,3 @@
import ast
from .datetypes import DatabendDataType
import re

2 changes: 2 additions & 0 deletions requirements.txt
@@ -3,3 +3,5 @@ mysql_connector_repackaged==0.3.1
pytz==2022.5
requests==2.28.1
setuptools==62.3.2
black==23.3.0
pyflakes==3.0.1