Skip to content
This repository was archived by the owner on Jan 22, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions databend_py/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@ def from_url(cls, url):
settings[name] = asbool(value)
elif name in timeouts:
kwargs[name] = float(value)
elif name == 'persist_cookies':
kwargs[name] = asbool(value)
else:
settings[name] = value # settings={'copy_purge':False}
secure = kwargs.get("secure", False)
Expand Down
8 changes: 6 additions & 2 deletions databend_py/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Connection(object):
# 'database': 'default'
# }
def __init__(self, host, port=None, user=defines.DEFAULT_USER, password=defines.DEFAULT_PASSWORD,
database=defines.DEFAULT_DATABASE, secure=False, copy_purge=False, session_settings=None):
database=defines.DEFAULT_DATABASE, secure=False, copy_purge=False, session_settings=None, persist_cookies=False):
self.host = host
self.port = port
self.user = user
Expand All @@ -90,6 +90,8 @@ def __init__(self, host, port=None, user=defines.DEFAULT_USER, password=defines.
if os.getenv("ADDITIONAL_HEADERS") is not None:
print(os.getenv("ADDITIONAL_HEADERS"))
self.additional_headers = e.dict("ADDITIONAL_HEADERS")
self.persist_cookies = persist_cookies
self.cookies = None

def default_session(self):
return {"database": self.database}
Expand Down Expand Up @@ -124,6 +126,8 @@ def do_query(self, url, query_sql):
raise UnexpectedException("failed to parse response: %s" % response.content)
if resp_dict and resp_dict.get('error') and "no endpoint" in resp_dict.get('error'):
raise WarehouseTimeoutException
if self.persist_cookies:
self.cookies = response.cookies
return resp_dict

def query(self, statement):
Expand Down Expand Up @@ -160,7 +164,7 @@ def reset_session(self):

def next_page(self, next_uri):
url = "{}://{}:{}{}".format(self.schema, self.host, self.port, next_uri)
return self.requests_session.get(url=url, headers=self.make_headers())
return self.requests_session.get(url=url, headers=self.make_headers(), cookies=self.cookies)

# return a list of response util empty next_uri
def query_with_session(self, statement):
Expand Down
1 change: 1 addition & 0 deletions docs/connection.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@ client = Client(
| secure | Enable SSL | false | http://root@localhost:8000/db?secure=False |
| copy_purge | If True, the command will purge the files in the stage after they are loaded successfully into the table | false | http://root@localhost:8000/db?copy_purge=False |
| debug | Enable debug log | False | http://root@localhost:8000/db?debug=True |
| persist_cookies | if using cookies set by server to perform following requests. | False | http://root@localhost:8000/db?persist_cookies=True|

17 changes: 17 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def test_simple(self):
c = Client.from_url("databend://root:root@localhost:8000/default?compress=True")
self.assertEqual(c._uploader._compress, True)

self.assertEqual(c.connection.persist_cookies, False)
c = Client.from_url('https://root:root@localhost:8000?persist_cookies=True')
self.assertEqual(c.connection.persist_cookies, True)

def test_session_settings(self):
session_settings = {"db": "database"}
c = Client(host="localhost", port=8000, user="root", password="root", session_settings={"db": "database"})
Expand Down Expand Up @@ -126,6 +130,18 @@ def tearDown(self):
client.execute('DROP TABLE IF EXISTS test')
client.disconnect()

def test_cookies(self):
client = Client.from_url(self.databend_url)
client.execute("select 1")
self.assertIsNone(client.connection.cookies)

if "?" in self.databend_url:
url_with_persist_cookies = f"{self.databend_url}&persist_cookies=true"
else:
url_with_persist_cookies = f"{self.databend_url}?persist_cookies=true"
client = Client.from_url(url_with_persist_cookies)
client.execute("select 1")
# self.assertIsNotNone(client.connection.cookies)

if __name__ == '__main__':
print("start test......")
Expand All @@ -137,5 +153,6 @@ def tearDown(self):
dt.test_iter_query()
dt.test_insert()
dt.test_insert_with_compress()
dt.test_cookies()
dt.tearDown()
print("end test.....")