diff --git a/databend_py/client.py b/databend_py/client.py index 32e6df7..07bc211 100644 --- a/databend_py/client.py +++ b/databend_py/client.py @@ -250,7 +250,6 @@ def sync_csv_file_into_table(self, filename, data, table): start = time.time() stage_path = self.stage_csv_file(filename, data) copy_options = self.generate_copy_options() - print(copy_options) _, _ = self.execute( f"COPY INTO {table} FROM {stage_path} FILE_FORMAT = (type = CSV)\ PURGE = {copy_options['PURGE']} FORCE = {copy_options['FORCE']}\ diff --git a/databend_py/connection.py b/databend_py/connection.py index 25c7b91..7b5e190 100644 --- a/databend_py/connection.py +++ b/databend_py/connection.py @@ -10,6 +10,8 @@ from . import log from . import defines from .context import Context +from databend_py.errors import WarehouseTimeoutException +from databend_py.retry import retry headers = {'Content-Type': 'application/json', 'Accept': 'application/json', 'X-DATABEND-ROUTE': 'warehouse'} @@ -88,6 +90,9 @@ def __init__(self, host, port=None, user=defines.DEFAULT_USER, password=defines. print(os.getenv("ADDITIONAL_HEADERS")) self.additional_headers = e.dict("ADDITIONAL_HEADERS") + def default_session(self): + return {"database": self.database} + def make_headers(self): if "Authorization" not in self.additional_headers: return { @@ -105,30 +110,38 @@ def get_description(self): def disconnect(self): self.client_session = dict() + @retry(times=5, exceptions=WarehouseTimeoutException) + def do_query(self, url, query_sql): + response = requests.post(url, + data=json.dumps(query_sql), + headers=self.make_headers(), + auth=HTTPBasicAuth(self.user, self.password), + verify=True) + resp_dict = json.loads(response.content) + if resp_dict and resp_dict.get('error') and "no endpoint" in resp_dict.get('error'): + raise WarehouseTimeoutException + + return resp_dict + def query(self, statement): url = self.format_url() log.logger.debug(f"http sql: {statement}") query_sql = {'sql': statement, "string_fields": True} if self.client_session is not None and len(self.client_session) != 0: if "database" not in self.client_session: - self.client_session = {"database": self.database} + self.client_session = self.default_session() query_sql['session'] = self.client_session else: - self.client_session = {"database": self.database} + self.client_session = self.default_session() query_sql['session'] = self.client_session log.logger.debug(f"http headers {self.make_headers()}") - response = requests.post(url, - data=json.dumps(query_sql), - headers=self.make_headers(), - auth=HTTPBasicAuth(self.user, self.password), - verify=True) try: - resp_dict = json.loads(response.content) - self.client_session = resp_dict["session"] + resp_dict = self.do_query(url, query_sql) + self.client_session = resp_dict.get("session", self.default_session()) return resp_dict except Exception as err: log.logger.error( - f"http error on {url}, SQL: {statement} content: {response.content} error msg:{str(err)}" + f"http error on {url}, SQL: {statement} error msg:{str(err)}" ) raise @@ -148,14 +161,13 @@ def next_page(self, next_uri): # return a list of response util empty next_uri def query_with_session(self, statement): - current_session = self.client_session response_list = list() response = self.query(statement) log.logger.debug(f"response content: {response}") response_list.append(response) start_time = time.time() time_limit = 12 - session = response['session'] + session = response.get("session", self.default_session()) if session: self.client_session = session while response['next_uri'] is not None: @@ -163,7 +175,7 @@ def query_with_session(self, statement): response = json.loads(resp.content) log.logger.debug(f"Sql in progress, fetch next_uri content: {response}") self.check_error(response) - session = response['session'] + session = response.get("session", self.default_session()) if session: self.client_session = session response_list.append(response) diff --git a/databend_py/errors.py b/databend_py/errors.py index 91c878f..ac97a4b 100644 --- a/databend_py/errors.py +++ b/databend_py/errors.py @@ -18,3 +18,13 @@ def __init__(self, message, code=None): def __str__(self): return 'Code: {}\n{}'.format(self.code, self.message) + + +class WarehouseTimeoutException(Error): + def __init__(self, message, code=None): + self.message = message + self.code = code + super(WarehouseTimeoutException, self).__init__(message) + + def __str__(self): + return 'Provision warehouse timeout: \n{}'.format(self.message) diff --git a/databend_py/retry.py b/databend_py/retry.py new file mode 100644 index 0000000..8d263fa --- /dev/null +++ b/databend_py/retry.py @@ -0,0 +1,41 @@ +from databend_py.errors import WarehouseTimeoutException + + +def retry(times, exceptions): + """ + Retry Decorator + Retries the wrapped function/method `times` times if the exceptions listed + in ``exceptions`` are thrown + :type times: Int + :param exceptions: Lists of exceptions that trigger a retry attempt + :type exceptions: Tuple of Exceptions + """ + + def decorator(func): + def newfn(*args, **kwargs): + attempt = 0 + while attempt < times: + try: + return func(*args, **kwargs) + except exceptions: + print( + 'Exception thrown when attempting to run %s, attempt ' + '%d of %d' % (func, attempt, times) + ) + attempt += 1 + return func(*args, **kwargs) + + return newfn + + return decorator + + +@retry(times=3, exceptions=WarehouseTimeoutException) +def foo1(): + print('Some code here ....') + print('Oh no, we have exception') + raise WarehouseTimeoutException('Some error') + + +if __name__ == '__main__': + foo1()