diff --git a/chdb/state/sqlitelike.py b/chdb/state/sqlitelike.py index 7694cb42ece..78cc6711336 100644 --- a/chdb/state/sqlitelike.py +++ b/chdb/state/sqlitelike.py @@ -564,6 +564,28 @@ def send_query(self, query: str, format: str = "CSV") -> StreamingResult: c_stream_result = self._conn.send_query(query, format) return StreamingResult(c_stream_result, self._conn, result_func, supports_record_batch) + def __enter__(self): + """Enter the context manager and return the connection. + + Returns: + Connection: The connection object itself + """ + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Exit the context manager and close the connection. + + Args: + exc_type: Exception type if an exception was raised + exc_val: Exception value if an exception was raised + exc_tb: Exception traceback if an exception was raised + + Returns: + False to propagate any exception that occurred + """ + self.close() + return False + def close(self) -> None: """Close the connection and cleanup resources. diff --git a/tests/test_query_py.py b/tests/test_query_py.py index ea6a2074665..12e8298b617 100644 --- a/tests/test_query_py.py +++ b/tests/test_query_py.py @@ -110,6 +110,14 @@ def test_query_df(self): ret = chdb.query("SELECT b, sum(a) FROM Python(df) GROUP BY b ORDER BY b") self.assertEqual(str(ret), EXPECTED) + def test_auto_cleanup(self): + with chdb.connect("data.db") as conn: + result = conn.query("SELECT 1") + self.assertEqual(str(result), "1\n") + with chdb.connect("data.db") as conn: + result = conn.query("SELECT 2") + self.assertEqual(str(result), "2\n") + def test_query_df_with_index(self): df = pd.DataFrame( {