Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion databend_py/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def sync_csv_file_into_table(self, filename, data, table):
start = time.time()
stage_path = self.stage_csv_file(filename, data)
copy_options = self.generate_copy_options()
print(copy_options)
_, _ = self.execute(
f"COPY INTO {table} FROM {stage_path} FILE_FORMAT = (type = CSV)\
PURGE = {copy_options['PURGE']} FORCE = {copy_options['FORCE']}\
Expand Down
38 changes: 25 additions & 13 deletions databend_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from . import log
from . import defines
from .context import Context
from databend_py.errors import WarehouseTimeoutException
from databend_py.retry import retry

headers = {'Content-Type': 'application/json', 'Accept': 'application/json', 'X-DATABEND-ROUTE': 'warehouse'}

Expand Down Expand Up @@ -88,6 +90,9 @@ def __init__(self, host, port=None, user=defines.DEFAULT_USER, password=defines.
print(os.getenv("ADDITIONAL_HEADERS"))
self.additional_headers = e.dict("ADDITIONAL_HEADERS")

def default_session(self):
return {"database": self.database}

def make_headers(self):
if "Authorization" not in self.additional_headers:
return {
Expand All @@ -105,30 +110,38 @@ def get_description(self):
def disconnect(self):
self.client_session = dict()

@retry(times=5, exceptions=WarehouseTimeoutException)
def do_query(self, url, query_sql):
response = requests.post(url,
data=json.dumps(query_sql),
headers=self.make_headers(),
auth=HTTPBasicAuth(self.user, self.password),
verify=True)
resp_dict = json.loads(response.content)
if resp_dict and resp_dict.get('error') and "no endpoint" in resp_dict.get('error'):
raise WarehouseTimeoutException

return resp_dict

def query(self, statement):
url = self.format_url()
log.logger.debug(f"http sql: {statement}")
query_sql = {'sql': statement, "string_fields": True}
if self.client_session is not None and len(self.client_session) != 0:
if "database" not in self.client_session:
self.client_session = {"database": self.database}
self.client_session = self.default_session()
query_sql['session'] = self.client_session
else:
self.client_session = {"database": self.database}
self.client_session = self.default_session()
query_sql['session'] = self.client_session
log.logger.debug(f"http headers {self.make_headers()}")
response = requests.post(url,
data=json.dumps(query_sql),
headers=self.make_headers(),
auth=HTTPBasicAuth(self.user, self.password),
verify=True)
try:
resp_dict = json.loads(response.content)
self.client_session = resp_dict["session"]
resp_dict = self.do_query(url, query_sql)
self.client_session = resp_dict.get("session", self.default_session())
return resp_dict
except Exception as err:
log.logger.error(
f"http error on {url}, SQL: {statement} content: {response.content} error msg:{str(err)}"
f"http error on {url}, SQL: {statement} error msg:{str(err)}"
)
raise

Expand All @@ -148,22 +161,21 @@ def next_page(self, next_uri):

# return a list of response util empty next_uri
def query_with_session(self, statement):
current_session = self.client_session
response_list = list()
response = self.query(statement)
log.logger.debug(f"response content: {response}")
response_list.append(response)
start_time = time.time()
time_limit = 12
session = response['session']
session = response.get("session", self.default_session())
if session:
self.client_session = session
while response['next_uri'] is not None:
resp = self.next_page(response['next_uri'])
response = json.loads(resp.content)
log.logger.debug(f"Sql in progress, fetch next_uri content: {response}")
self.check_error(response)
session = response['session']
session = response.get("session", self.default_session())
if session:
self.client_session = session
response_list.append(response)
Expand Down
10 changes: 10 additions & 0 deletions databend_py/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,13 @@ def __init__(self, message, code=None):

def __str__(self):
return 'Code: {}\n{}'.format(self.code, self.message)


class WarehouseTimeoutException(Error):
def __init__(self, message, code=None):
self.message = message
self.code = code
super(WarehouseTimeoutException, self).__init__(message)

def __str__(self):
return 'Provision warehouse timeout: \n{}'.format(self.message)
41 changes: 41 additions & 0 deletions databend_py/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from databend_py.errors import WarehouseTimeoutException


def retry(times, exceptions):
"""
Retry Decorator
Retries the wrapped function/method `times` times if the exceptions listed
in ``exceptions`` are thrown
:type times: Int
:param exceptions: Lists of exceptions that trigger a retry attempt
:type exceptions: Tuple of Exceptions
"""

def decorator(func):
def newfn(*args, **kwargs):
attempt = 0
while attempt < times:
try:
return func(*args, **kwargs)
except exceptions:
print(
'Exception thrown when attempting to run %s, attempt '
'%d of %d' % (func, attempt, times)
)
attempt += 1
return func(*args, **kwargs)

return newfn

return decorator


@retry(times=3, exceptions=WarehouseTimeoutException)
def foo1():
print('Some code here ....')
print('Oh no, we have exception')
raise WarehouseTimeoutException('Some error')


if __name__ == '__main__':
foo1()