Skip to content

Commit

Permalink
fix(impala): more aggressively clean up cursors internally
Browse files Browse the repository at this point in the history
BREAKING CHANGE: To prevent resource leakage, DDL operations no longer return cursors. Use `raw_sql` if you need a specialized operation that returns a cursor. Additionally, table-based DDL operations now return the table they operate on.
  • Loading branch information
cpcloud committed May 6, 2023
1 parent 888718b commit bf5687e
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 85 deletions.
6 changes: 1 addition & 5 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,7 @@ def raw_sql(self, query: str):
query
DDL or DML statement
"""
cursor = self.con.execute(query)
if cursor:
return cursor
cursor.release()
return None
return self.con.execute(query)

@contextlib.contextmanager
def _safe_raw_sql(self, *args, **kwargs):
Expand Down
118 changes: 59 additions & 59 deletions ibis/backends/impala/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import re
import weakref
from functools import cached_property
from pathlib import Path
from posixpath import join as pjoin
from typing import TYPE_CHECKING, Any, Literal
Expand Down Expand Up @@ -299,43 +300,51 @@ def do_connect(

self._ensure_temp_db_exists()

@property
@cached_property
def version(self):
cursor = self.raw_sql('select version()')
result = cursor.fetchone()[0]
cursor.release()
with self._safe_raw_sql('select version()') as cursor:
(result,) = cursor.fetchone()
return result

def list_databases(self, like=None):
cur = self.raw_sql('SHOW DATABASES')
databases = self._get_list(cur)
cur.release()
with self._safe_raw_sql('SHOW DATABASES') as cur:
databases = self._get_list(cur)
return self._filter_with_like(databases, like)

def list_tables(self, like=None, database=None):
statement = 'SHOW TABLES'
if database is not None:
statement += f' IN {database}'
if like:
m = fully_qualified_re.match(like)
if m:
database, quoted, unquoted = m.groups()
if match := fully_qualified_re.match(like):
database, quoted, unquoted = match.groups()
like = quoted or unquoted
return self.list_tables(like=like, database=database)
statement += f" LIKE '{like}'"

return self._filter_with_like(
[row[0] for row in self.raw_sql(statement).fetchall()]
)
with self._safe_raw_sql(statement) as cursor:
tables = [row[0] for row in cursor.fetchall()]
return self._filter_with_like(tables)

def fetch_from_cursor(self, cursor, schema):
batches = cursor.fetchall(columnar=True)
names = [x[0] for x in cursor.description]
names = [name for name, *_ in cursor.description]
df = _column_batches_to_dataframe(names, batches)
if schema:
return schema.apply_to(df)
return df

@contextlib.contextmanager
def _safe_raw_sql(self, *args, **kwargs):
    """Run ``raw_sql`` and yield its cursor, cleaning the cursor up on exit.

    All positional and keyword arguments are forwarded unchanged to
    ``self.raw_sql``.

    Yields
    ------
    cursor
        Whatever ``self.raw_sql`` returned.

    Notes
    -----
    The cursor is released (when it has a ``release`` method) and then
    closed, both on normal exit and when the body of the ``with`` statement
    raises. Previously the release step was skipped if the body raised.
    """
    with contextlib.closing(self.raw_sql(*args, **kwargs)) as cursor:
        try:
            yield cursor
        finally:
            # Cursors without a ``release`` attribute are tolerated.
            # NOTE(review): release happens before ``close`` here; confirm
            # this ordering matches what the cursor implementation expects.
            with contextlib.suppress(AttributeError):
                cursor.release()

def _safe_exec_sql(self, *args, **kwargs):
    """Execute a statement purely for its side effects.

    All arguments are forwarded unchanged to ``_safe_raw_sql``. The cursor
    that context manager yields is never read, and nothing is returned.
    """
    # Entering and immediately leaving the context manager runs the
    # statement and then the cursor cleanup that ``_safe_raw_sql`` performs
    # on exit.
    with self._safe_raw_sql(*args, **kwargs):
        pass

@property
def hdfs(self):
if self._hdfs is None:
Expand Down Expand Up @@ -410,7 +419,7 @@ def create_database(self, name, path=None, force=False):
# which is easier for manual cleanup, if necessary
self.hdfs.mkdir(path)
statement = CreateDatabase(name, path=path, can_exist=force)
return self.raw_sql(statement)
self._safe_exec_sql(statement)

def drop_database(self, name, force=False):
"""Drop an Impala database.
Expand All @@ -433,7 +442,7 @@ def drop_database(self, name, force=False):
udas = []
if force:
for table in tables:
util.log('Dropping {}'.format(f'{name}.{table}'))
util.log(f'Dropping {name}.{table}')
self.drop_table_or_view(table, database=name)
for func in udfs:
util.log(f'Dropping function {func.name}({func.inputs})')
Expand All @@ -457,13 +466,9 @@ def drop_database(self, name, force=False):
"being dropped, or set force=True"
)
statement = DropDatabase(name, must_exist=not force)
return self.raw_sql(statement)
self._safe_exec_sql(statement)

def get_schema(
self,
table_name: str,
database: str | None = None,
) -> sch.Schema:
def get_schema(self, table_name: str, database: str | None = None) -> sch.Schema:
"""Return a Schema object for the indicated table and database.
Parameters
Expand All @@ -479,7 +484,7 @@ def get_schema(
Ibis schema
"""
qualified_name = self._fully_qualified_name(table_name, database)
query = f'DESCRIBE {qualified_name}'
query = f"DESCRIBE {qualified_name}"

# only pull out the first two columns which are names and types
pairs = [row[:2] for row in self.con.fetchall(query)]
Expand All @@ -493,9 +498,9 @@ def get_schema(
def client_options(self):
return self.con.options

def get_options(self):
def get_options(self) -> dict[str, str]:
"""Return current query options for the Impala session."""
return dict(row[:2] for row in self.con.fetchall("SET"))
return {key: value for key, value, *_ in self.con.fetchall("SET")}

def set_options(self, options):
self.con.set_options(options)
Expand Down Expand Up @@ -526,12 +531,12 @@ def create_view(
ast = self.compiler.to_ast(obj)
select = ast.queries[0]
statement = CreateView(name, select, database=database, can_exist=overwrite)
self.raw_sql(statement)
self._safe_exec_sql(statement)
return self.table(name, database=database)

def drop_view(self, name, database=None, force=False):
statement = DropView(name, database=database, must_exist=not force)
return self.raw_sql(statement)
stmt = DropView(name, database=database, must_exist=not force)
self._safe_exec_sql(stmt)

@contextlib.contextmanager
def _setup_insert(self, obj):
Expand Down Expand Up @@ -609,7 +614,7 @@ def create_table(

if overwrite:
self.drop_table(name, force=True)
self.raw_sql(
self._safe_exec_sql(
CTAS(
name,
select,
Expand All @@ -623,7 +628,7 @@ def create_table(
else: # schema is not None
if overwrite:
self.drop_table(name, force=True)
self.raw_sql(
self._safe_exec_sql(
CreateTableWithSchema(
name,
schema,
Expand Down Expand Up @@ -672,7 +677,7 @@ def avro_file(
stmt = ddl.CreateTableAvro(
name, hdfs_dir, avro_schema, database=database, external=external
)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)
return self._wrap_new_table(name, database, persist)

def delimited_file(
Expand Down Expand Up @@ -738,7 +743,7 @@ def delimited_file(
lineterminator=lineterminator,
escapechar=escapechar,
)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)
return self._wrap_new_table(name, database, persist)

def parquet_file(
Expand Down Expand Up @@ -821,7 +826,7 @@ def parquet_file(
external=external,
can_exist=False,
)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)
return self._wrap_new_table(name, database, persist)

def _get_concrete_table_path(self, name, database, persist=False):
Expand Down Expand Up @@ -870,7 +875,7 @@ def _wrap_new_table(self, name, database, persist):
qualified_name, cardinality
)
)
self.raw_sql(set_card)
self._safe_exec_sql(set_card)

return t

Expand Down Expand Up @@ -931,7 +936,7 @@ def drop_table(
>>> con.drop_table(table, database=db, force=True) # doctest: +SKIP
"""
statement = DropTable(name, database=database, must_exist=not force)
self.raw_sql(statement)
self._safe_exec_sql(statement)

def truncate_table(self, name: str, database: str | None = None) -> None:
"""Delete all rows from an existing table.
Expand All @@ -944,7 +949,7 @@ def truncate_table(self, name: str, database: str | None = None) -> None:
Database name
"""
statement = TruncateTable(name, database=database)
self.raw_sql(statement)
self._safe_exec_sql(statement)

def drop_table_or_view(self, name, *, database=None, force=False):
"""Drop view or table."""
Expand Down Expand Up @@ -976,14 +981,13 @@ def cache_table(self, table_name, *, database=None, pool='default'):
>>> con.cache_table('my_table', database=db, pool=pool) # doctest: +SKIP
"""
statement = ddl.CacheTable(table_name, database=database, pool=pool)
self.raw_sql(statement)
self._safe_exec_sql(statement)

def _get_schema_using_query(self, query):
cur = self.raw_sql(f"SELECT * FROM ({query}) t0 LIMIT 0")
# resets the state of the cursor and closes operation
cur.fetchall()
ibis_fields = self._adapt_types(cur.description)
cur.release()
with self._safe_raw_sql(f"SELECT * FROM ({query}) t0 LIMIT 0") as cur:
# resets the state of the cursor and closes operation
cur.fetchall()
ibis_fields = self._adapt_types(cur.description)

return sch.Schema(ibis_fields)

Expand All @@ -1009,7 +1013,7 @@ def create_function(self, func, name=None, database=None):
stmt = ddl.CreateUDA(func, name=name, database=database)
else:
raise TypeError(func)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)

def drop_udf(
self,
Expand Down Expand Up @@ -1082,7 +1086,7 @@ def _drop_single_function(self, name, input_types, database=None, aggregate=Fals
aggregate=aggregate,
database=database,
)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)

def _drop_all_functions(self, database):
udfs = self.list_udfs(database=database)
Expand All @@ -1094,7 +1098,7 @@ def _drop_all_functions(self, database):
aggregate=False,
database=database,
)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)
udafs = self.list_udas(database=database)
for udaf in udafs:
stmt = ddl.DropFunction(
Expand All @@ -1104,28 +1108,23 @@ def _drop_all_functions(self, database):
aggregate=True,
database=database,
)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)

def list_udfs(self, database=None, like=None):
"""Lists all UDFs associated with given database."""
if not database:
database = self.current_database
statement = ddl.ListFunction(database, like=like, aggregate=False)
cur = self.raw_sql(statement)
result = self._get_udfs(cur, udf.ImpalaUDF)
cur.release()
return result
with self._safe_raw_sql(statement) as cur:
return self._get_udfs(cur, udf.ImpalaUDF)

def list_udas(self, database=None, like=None):
"""Lists all UDAFs associated with a given database."""
if not database:
database = self.current_database
statement = ddl.ListFunction(database, like=like, aggregate=True)
cur = self.raw_sql(statement)
result = self._get_udfs(cur, udf.ImpalaUDA)
cur.release()

return result
with self._safe_raw_sql(statement) as cur:
return self._get_udfs(cur, udf.ImpalaUDA)

def _get_udfs(self, cur, klass):
def _to_type(x):
Expand Down Expand Up @@ -1189,7 +1188,7 @@ def compute_stats(
cmd = f'COMPUTE {maybe_inc}STATS'

stmt = self._table_command(cmd, name, database=database)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)

def invalidate_metadata(
self,
Expand All @@ -1210,7 +1209,7 @@ def invalidate_metadata(
stmt = 'INVALIDATE METADATA'
if name is not None:
stmt = self._table_command(stmt, name, database=database)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)

def refresh(self, name: str, database: str | None = None) -> None:
"""Reload HDFS block location metadata for a table.
Expand All @@ -1229,7 +1228,7 @@ def refresh(self, name: str, database: str | None = None) -> None:
"""
# TODO(wesm): can this statement be cancelled?
stmt = self._table_command('REFRESH', name, database=database)
self.raw_sql(stmt)
self._safe_exec_sql(stmt)

def describe_formatted(
self,
Expand Down Expand Up @@ -1292,7 +1291,8 @@ def column_stats(self, name, database=None):
return self._exec_statement(stmt)

def _exec_statement(self, stmt):
return self.fetch_from_cursor(self.raw_sql(stmt), schema=None)
with self._safe_raw_sql(stmt) as cur:
return self.fetch_from_cursor(cur, schema=None)

def _table_command(self, cmd, name, database=None):
qualified_name = self._fully_qualified_name(name, database)
Expand Down
Loading

0 comments on commit bf5687e

Please sign in to comment.