Skip to content

Commit

Permalink
feat: Improve db integration (#1114)
Browse files Browse the repository at this point in the history
1. Moved `get_database_handler` to `context_manager`
2. `NativeStorageEngine` is now consistent with `SQLStorageEngine`.
  • Loading branch information
gaurav274 committed Sep 14, 2023
1 parent a40c72e commit 2662ca1
Show file tree
Hide file tree
Showing 14 changed files with 102 additions and 60 deletions.
28 changes: 14 additions & 14 deletions evadb/binder/binder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,12 @@ def check_data_source_and_table_are_valid(
db_catalog_entry = catalog.get_database_catalog_entry(database_name)

if db_catalog_entry is not None:
handler = get_database_handler(
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
)
handler.connect()
) as handler:
# Get table definition.
resp = handler.get_tables()

# Get table definition.
resp = handler.get_tables()
if resp.error is not None:
error = "There is no table in data source {}. Create the table using native query.".format(
database_name,
Expand All @@ -90,7 +89,7 @@ def check_data_source_and_table_are_valid(


def create_table_catalog_entry_for_data_source(
table_name: str, column_info: pd.DataFrame
table_name: str, database_name: str, column_info: pd.DataFrame
):
column_name_list = list(column_info["name"])
column_type_list = [
Expand All @@ -107,6 +106,7 @@ def create_table_catalog_entry_for_data_source(
file_url=None,
table_type=TableType.NATIVE_DATA,
columns=column_list,
database_name=database_name,
)
return table_catalog_entry

Expand Down Expand Up @@ -134,14 +134,14 @@ def bind_native_table_info(catalog: CatalogManager, table_info: TableInfo):
)

db_catalog_entry = catalog.get_database_catalog_entry(table_info.database_name)
handler = get_database_handler(db_catalog_entry.engine, **db_catalog_entry.params)
handler.connect()

# Assemble columns.
column_df = handler.get_columns(table_info.table_name).data
table_info.table_obj = create_table_catalog_entry_for_data_source(
table_info.table_name, column_df
)
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
) as handler:
# Assemble columns.
column_df = handler.get_columns(table_info.table_name).data
table_info.table_obj = create_table_catalog_entry_for_data_source(
table_info.table_name, table_info.database_name, column_df
)


def bind_evadb_table_info(catalog: CatalogManager, table_info: TableInfo):
Expand Down
16 changes: 7 additions & 9 deletions evadb/binder/statement_binder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,14 @@ def add_table_alias(self, alias: str, database_name: str, table_name: str):
)

db_catalog_entry = self._catalog().get_database_catalog_entry(database_name)
handler = get_database_handler(
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
)
handler.connect()

# Assemble columns.
column_df = handler.get_columns(table_name).data
table_obj = create_table_catalog_entry_for_data_source(
table_name, column_df
)
) as handler:
# Assemble columns.
column_df = handler.get_columns(table_name).data
table_obj = create_table_catalog_entry_for_data_source(
table_name, database_name, column_df
)
else:
table_obj = self._catalog().get_table_catalog_entry(table_name)

Expand Down
1 change: 1 addition & 0 deletions evadb/catalog/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class TableCatalogEntry:
identifier_column: str = "id"
columns: List[ColumnCatalogEntry] = field(compare=False, default_factory=list)
row_id: int = None
database_name: str = "EvaDB"


@dataclass(unsafe_hash=True)
Expand Down
6 changes: 2 additions & 4 deletions evadb/executor/create_database_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,8 @@ def exec(self, *args, **kwargs):
raise ExecutorError(f"{self.node.database_name} already exists.")

# Check the validity of database entry.
handler = get_database_handler(self.node.engine, **self.node.param_dict)
resp = handler.connect()
if not resp.status:
raise ExecutorError(f"Cannot establish connection due to {resp.error}")
with get_database_handler(self.node.engine, **self.node.param_dict):
pass

logger.debug(f"Creating database {self.node}")
self.catalog().insert_database_catalog_entry(
Expand Down
22 changes: 21 additions & 1 deletion evadb/executor/executor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@
from pathlib import Path
from typing import TYPE_CHECKING, Generator, List

from evadb.catalog.catalog_utils import xform_column_definitions_to_catalog_entries
from evadb.catalog.models.utils import TableCatalogEntry
from evadb.parser.create_statement import ColumnDefinition

if TYPE_CHECKING:
from evadb.catalog.catalog_manager import CatalogManager

from evadb.catalog.catalog_type import VectorStoreType
from evadb.catalog.catalog_type import TableType, VectorStoreType
from evadb.expression.abstract_expression import AbstractExpression
from evadb.expression.function_expression import FunctionExpression
from evadb.models.storage.batch import Batch
Expand Down Expand Up @@ -169,3 +173,19 @@ def handle_vector_store_params(
return {"index_db": str(Path(index_path).parent)}
else:
raise ValueError("Unsupported vector store type: {}".format(vector_store_type))


def create_table_catalog_entry_for_native_table(
table_info: TableInfo, column_list: List[ColumnDefinition]
):
column_catalog_entires = xform_column_definitions_to_catalog_entries(column_list)

# Assemble table.
table_catalog_entry = TableCatalogEntry(
name=table_info.table_name,
file_url=None,
table_type=TableType.NATIVE_DATA,
columns=column_catalog_entires,
database_name=table_info.database_name,
)
return table_catalog_entry
4 changes: 1 addition & 3 deletions evadb/executor/storage_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]:
elif self.node.table.table_type == TableType.STRUCTURED_DATA:
return storage_engine.read(self.node.table, self.node.batch_mem_size)
elif self.node.table.table_type == TableType.NATIVE_DATA:
return storage_engine.read(
self.node.table_ref.table.database_name, self.node.table
)
return storage_engine.read(self.node.table)
elif self.node.table.table_type == TableType.PDF_DATA:
return storage_engine.read(self.node.table)
else:
Expand Down
14 changes: 5 additions & 9 deletions evadb/executor/use_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,12 @@ def exec(self, *args, **kwargs) -> Iterator[Batch]:
f"{self._database_name} data source does not exist. Use CREATE DATABASE to add a new data source."
)

handler = get_database_handler(
db_catalog_entry.engine,
**db_catalog_entry.params,
)

handler.connect()
resp = handler.execute_native_query(self._query_string)
handler.disconnect()
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
) as handler:
resp = handler.execute_native_query(self._query_string)

if resp.error is None:
if resp and resp.error is None:
return Batch(resp.data)
else:
raise ExecutorError(resp.error)
35 changes: 16 additions & 19 deletions evadb/storage/native_storage_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,30 @@ class NativeStorageEngine(AbstractStorageEngine):
def __init__(self, db: EvaDBDatabase):
super().__init__(db)

def create(self, table: TableCatalogEntry):
pass

def write(self, table: TableCatalogEntry, rows: Batch):
pass

def read(self, database_name: str, table: TableCatalogEntry) -> Iterator[Batch]:
def read(self, table: TableCatalogEntry) -> Iterator[Batch]:
try:
db_catalog_entry = self.db.catalog().get_database_catalog_entry(
database_name
table.database_name
)
handler = get_database_handler(
with get_database_handler(
db_catalog_entry.engine, **db_catalog_entry.params
)
handler.connect()

data_df = handler.execute_native_query(f"SELECT * FROM {table.name}").data

# Handling case-sensitive databases like SQLite can be tricky. Currently,
# EvaDB converts all columns to lowercase, which may result in issues with
# these databases. As we move forward, we are actively working on improving
# this aspect within Binder.
# For more information, please refer to https://github.com/georgia-tech-db/evadb/issues/1079.
data_df.columns = data_df.columns.str.lower()
yield Batch(pd.DataFrame(data_df))
) as handler:
data_df = handler.execute_native_query(
f"SELECT * FROM {table.name}"
).data

# Handling case-sensitive databases like SQLite can be tricky.
# Currently, EvaDB converts all columns to lowercase, which may result
# in issues with these databases. As we move forward, we are actively
# working on improving this aspect within Binder. For more information,
# please refer to https://github.com/georgia-tech-db/evadb/issues/1079.
data_df.columns = data_df.columns.str.lower()
yield Batch(pd.DataFrame(data_df))

except Exception as e:
err_msg = f"Failed to read the table {table.name} in data source {database_name} with exception {str(e)}"
err_msg = f"Failed to read the table {table.name} in data source {table.database_name} with exception {str(e)}"
logger.exception(err_msg)
raise Exception(err_msg)
15 changes: 14 additions & 1 deletion evadb/third_party/databases/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
# limitations under the License.
import importlib
import os
from contextlib import contextmanager


def get_database_handler(engine: str, **kwargs):
def _get_database_handler(engine: str, **kwargs):
"""
Return the database handler. User should modify this function for
their new integrated handlers.
Expand All @@ -43,6 +44,18 @@ def get_database_handler(engine: str, **kwargs):
raise NotImplementedError(f"Engine {engine} is not supported")


@contextmanager
def get_database_handler(engine: str, **kwargs):
handler = _get_database_handler(engine, **kwargs)
try:
handler.connect()
yield handler
except Exception as e:
raise Exception(f"Error connecting to the database: {str(e)}")
finally:
handler.disconnect()


def dynamic_import(handler_dir):
import_path = f"evadb.third_party.databases.{handler_dir}.{handler_dir}_handler"
return importlib.import_module(import_path)
3 changes: 3 additions & 0 deletions evadb/third_party/databases/mariadb/mariadb_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
return f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"

def check_connection(self) -> DBHandlerStatus:
"""
Method for checking the status of database connection.
Expand Down
3 changes: 3 additions & 0 deletions evadb/third_party/databases/mysql/mysql_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
return f"mysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"

def check_connection(self) -> DBHandlerStatus:
if self.connection:
return DBHandlerStatus(status=True)
Expand Down
3 changes: 3 additions & 0 deletions evadb/third_party/databases/postgres/postgres_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"

def check_connection(self) -> DBHandlerStatus:
"""
Check connection to the handler.
Expand Down
3 changes: 3 additions & 0 deletions evadb/third_party/databases/sqlite/sqlite_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def disconnect(self):
if self.connection:
self.connection.close()

def get_sqlalchmey_uri(self) -> str:
return f"sqlite:///{self.database}"

def check_connection(self) -> DBHandlerStatus:
"""
Check connection to the handler.
Expand Down
9 changes: 9 additions & 0 deletions evadb/third_party/databases/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def disconnect(self):
"""
raise NotImplementedError()

def get_sqlalchmey_uri(self) -> str:
"""
Return the valid sqlalchemy uri to connect to the database.
Raises:
NotImplementedError: This method should be implemented in derived classes.
"""
raise NotImplementedError()

def check_connection(self) -> DBHandlerStatus:
"""
Checks the status of the database connection.
Expand Down

0 comments on commit 2662ca1

Please sign in to comment.