From a034ad8d3e7ef65279ba55f3d68a70852be12953 Mon Sep 17 00:00:00 2001 From: Yang Xiufeng Date: Tue, 3 Dec 2024 16:28:58 +0800 Subject: [PATCH 1/2] fix session state update --- databend_py/connection.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/databend_py/connection.py b/databend_py/connection.py index 8d0368a..fdc68a2 100644 --- a/databend_py/connection.py +++ b/databend_py/connection.py @@ -223,7 +223,9 @@ def query(self, statement): log.logger.debug(f"http headers {self.make_headers()}") try: resp_dict = self.do_query(url, query_sql) - self.client_session = resp_dict.get("session", self.default_session()) + new_session_state = resp_dict.get("session", self.default_session()) + if new_session_state: + self.client_session = new_session_state if self.additional_headers: self.additional_headers.update( {XDatabendQueryIDHeader: resp_dict.get(QueryID)} @@ -286,7 +288,7 @@ def query_with_session(self, statement): response_list.append(response) start_time = time.time() time_limit = 12 - session = response.get("session", self.default_session()) + session = response.get("session") if session: self.client_session = session while response["next_uri"] is not None: @@ -294,7 +296,7 @@ def query_with_session(self, statement): response = json.loads(resp.content) log.logger.debug(f"Sql in progress, fetch next_uri content: {response}") self.check_error(response) - session = response.get("session", self.default_session()) + session = response.get("session") if session: self.client_session = session response_list.append(response) From 0fd3aa05138837cdc2693edcad4a4e3c49514f89 Mon Sep 17 00:00:00 2001 From: Yang Xiufeng Date: Wed, 4 Dec 2024 07:13:26 +0800 Subject: [PATCH 2/2] feat: support cookie --- databend_py/connection.py | 17 +++++++++++++++++ tests/test_client.py | 8 ++++++++ tests/test_simple.py | 6 +++--- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/databend_py/connection.py b/databend_py/connection.py index fdc68a2..c6e1d7d 100644 --- a/databend_py/connection.py +++ b/databend_py/connection.py @@ -4,7 +4,9 @@ import time import uuid +from http.cookiejar import Cookie from requests.auth import HTTPBasicAuth +from requests.cookies import RequestsCookieJar import environs import requests @@ -75,6 +77,17 @@ def get_error(response): return ServerException(response["error"]["message"], response["error"]["code"]) +class GlobalCookieJar(RequestsCookieJar): + + def __init__(self): + super().__init__() + + def set_cookie(self, cookie: Cookie, *args, **kwargs): + cookie.domain = "" + cookie.path = "/" + super().set_cookie(cookie, *args, **kwargs) + + class Connection(object): # Databend http handler doc: https://databend.rs/doc/reference/api/rest @@ -120,6 +133,10 @@ def __init__( self.context = Context() self.requests_session = requests.Session() self.schema = "http" + cookie_jar = GlobalCookieJar() + cookie_jar.set("cookie_enabled", "true") + self.requests_session.cookies = cookie_jar + self.schema = 'http' if self.secure: self.schema = "https" e = environs.Env() diff --git a/tests/test_client.py b/tests/test_client.py index b7a7695..3d40ad4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -293,6 +293,14 @@ def test_cast_bool(self): _, data = client.execute("select 'False'::boolean union select 'True'::boolean") self.assertEqual(data, [(True,), (False,)]) + def test_temp_table(self): + client = Client.from_url(self.databend_url) + client.execute("create temp table t1(a int)") + client.execute("insert into t1 values (1)") + _, data = client.execute("select * from t1") + self.assertEqual(data, [(1,)]) + client.execute("drop table t1") + if __name__ == "__main__": unittest.main() diff --git a/tests/test_simple.py b/tests/test_simple.py index e63fb2b..9eb682f 100644 --- a/tests/test_simple.py +++ b/tests/test_simple.py @@ -16,15 +16,15 @@ def __setattr__(self, key, value): class TestDict(unittest.TestCase): - databend_url = None # 使用类属性来存储 databend_url + databend_url = None @classmethod def setUpClass(cls): - cls.databend_url = "test_url" # 在类级别设置 databend_url + cls.databend_url = "test_url" def test_init(self): d = Dict(a=1, b="test") - self.assertEqual(self.databend_url, "test_url") # 使用类属性 + self.assertEqual(self.databend_url, "test_url") self.assertEqual(d.a, 1) self.assertEqual(d.b, "test") self.assertTrue(isinstance(d, dict))