Skip to content

Commit

Permalink
Connection class hierarchy (#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
James Robinson committed Jun 23, 2023
1 parent 9dfcc51 commit ee15ce4
Show file tree
Hide file tree
Showing 11 changed files with 758 additions and 677 deletions.
394 changes: 0 additions & 394 deletions noteable/datasource_postprocessing.py

This file was deleted.

74 changes: 24 additions & 50 deletions noteable/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,23 @@
import sys
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import pkg_resources
import structlog
from sqlalchemy.engine import URL

# Import all our known concrete Connection implementations.
import noteable.sql.sqlalchemy # noqa
# ipython-sql thinks mighty highly of itself with this package name.
from noteable.sql.connection import Connection, ConnectionRegistry, get_connection_registry
from noteable.sql.run import add_commit_blacklist_dialect
from noteable.sql.connection import (
Connection,
ConnectionRegistry,
get_connection_class,
get_connection_registry,
)

DEFAULT_SECRETS_DIR = Path('/vault/secrets')

from noteable.datasource_postprocessing import post_processor_by_drivername

logger = structlog.get_logger(__name__)

Expand Down Expand Up @@ -120,12 +124,12 @@ def bootstrap_datasource(
pre_process_dict(dsn_dict)
pre_process_dict(connect_args)

# Do any per-drivername post-processing of dsn_dict and create_engine_kwargs
# before we make use of any of their contents. Post-processors may end up rejecting this
# configuration, so catch and handle just like a failure when calling Connection.set().
if drivername in post_processor_by_drivername:
post_processor: Callable[[str, dict, dict], None] = post_processor_by_drivername[drivername]
post_processor(datasource_id, dsn_dict, create_engine_kwargs)
# Late lookup the Connection subclass implementation registered for this drivername.
# Will raise KeyError if none are registered.
connection_class = get_connection_class(drivername)

if hasattr(connection_class, 'preprocess_configuration'):
connection_class.preprocess_configuration(datasource_id, dsn_dict, create_engine_kwargs)

# Ensure the required driver packages are installed already, or, if allowed,
# install them on the fly.
Expand All @@ -135,41 +139,12 @@ def bootstrap_datasource(
metadata['allow_datasource_dialect_autoinstall'],
)

# Prepare connection URL string.
url_obj = URL.create(**dsn_dict)
connection_url = str(url_obj)

# XXX TODO, make a mixin for future SQLAlchemy DisableAutoCommit subclasses incorporating
# this particular need. A good look for the end game here may be that most all of this
# 'bootstrapping datasource' code will be within either the Connection base class stuff, unifying
# this module with Connection module, or perhaps a slightly parallel class hierarchy for
# the bootstrapping class corresponding to the Connection subtype registered for the
# drivername field?

# Do we need to tell sql-magic to not try to emit a COMMIT after each statement
# according to the needs of this driver?
if not metadata['sqlmagic_autocommit']:
# A sqlalchemy drivername may be comprised of 'dialect+drivername', such as
# 'databricks+connector'.
# If so, then we must only pass along the LHS of the '+'.
dialect = metadata['drivername'].split('+')[0]
add_commit_blacklist_dialect(dialect)

# Register the connection + return it.
sql_cell_handle = f'@{datasource_id}'

# XXX Todo: polymorphy / mapping connection subclass to construct based on driver name
# will happen here once we have a class hierarchy. Until then, only exactly one class
# to construct!

connection = Connection(
sql_cell_handle=sql_cell_handle,
human_name=metadata['name'],
connection_url=connection_url,
**create_engine_kwargs,
)
# Individual Connection classes don't need to be bothered with these.
del metadata['required_python_modules']
del metadata['allow_datasource_dialect_autoinstall']

return connection
# Construct + return Connection subclass instance.
return connection_class(f'@{datasource_id}', metadata, dsn_dict, create_engine_kwargs)


##
Expand Down Expand Up @@ -251,15 +226,14 @@ def pre_process_dict(the_dict: Dict[str, Any]) -> None:

LOCAL_DB_CONN_HANDLE = "@noteable"
LOCAL_DB_CONN_NAME = "Local Database"
DUCKDB_LOCATION = "duckdb:///:memory:"


def local_duckdb_bootstrapper() -> Connection:
"""Return the noteable.sql.connection.Connection to use for local memory DuckDB."""
return Connection(
sql_cell_handle=LOCAL_DB_CONN_HANDLE,
human_name=LOCAL_DB_CONN_NAME,
connection_url=DUCKDB_LOCATION,
return noteable.sql.sqlalchemy.DuckDBConnection(
LOCAL_DB_CONN_HANDLE,
{'name': LOCAL_DB_CONN_NAME},
{'drivername': 'duckdb', 'database': ':memory:'},
)


Expand Down
145 changes: 93 additions & 52 deletions noteable/sql/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

from typing import Callable, Dict, Optional, TypeVar
from typing import Any, Callable, Dict, List, Optional, Protocol, Type, TypeVar, runtime_checkable

import pandas as pd
import sqlalchemy
import sqlalchemy.engine.base
import structlog
Expand All @@ -13,6 +14,13 @@
'get_sqla_connection',
'get_sqla_engine',
'ConnectionBootstrapper',
'UnknownConnectionError',
'SQLAlchemyUnsupportedError',
'ResultSet',
'Connection',
'ConnectionRegistry',
'connection_class',
'get_connection_class',
)

logger = structlog.get_logger(__name__)
Expand All @@ -30,37 +38,86 @@ class SQLAlchemyUnsupportedError(Exception):
pass


class Connection:
class ResultSet(Protocol):
    """
    Results of a query against any kind of data connection / connection type.

    Implementations provide the four data attributes below; the derived
    properties and ``to_dataframe`` are computed purely from those attributes.
    """

    keys: Optional[List[str]]
    """Column names from the result, if any"""

    rows: Optional[list]
    """List of rows from the result, if any. Each row should be len(keys) long."""

    # In case of an INSERT, UPDATE, or DELETE statement.
    rowcount: Optional[int]
    """How many rows were affected by an INSERT/UPDATE/DELETE sort of statement?"""

    has_results_to_report: bool
    """Most queries will have results to report, but CREATE TABLE and other DDLs may not."""

    @property
    def is_scalar_value(self) -> bool:
        """Is result expressible as a single scalar value w/o losing any information?

        True when there are results to report and either ``rowcount`` is set, or
        the result is exactly one row holding exactly one column.
        """
        if not self.has_results_to_report:
            return False

        if self.rowcount is not None:
            return True

        # Guard against results that carry neither a rowcount nor any rows:
        # calling len() on None would raise TypeError rather than answer False.
        rows = self.rows
        return rows is not None and len(rows) == 1 and len(rows[0]) == 1

    @property
    def scalar_value(self):
        """Return either the only row / column value, or the affected num of rows
        from an INSERT/DELETE/UPDATE statement as bare scalar"""

        # Should only be called if self.is_scalar_value
        if self.rowcount is not None:
            return self.rowcount

        return self.rows[0][0]

    @property
    def can_become_dataframe(self) -> bool:
        """Are there reportable results with a (possibly empty) list of rows?"""
        return self.has_results_to_report and self.rows is not None

    def to_dataframe(self) -> pd.DataFrame:
        """Returns a Pandas DataFrame instance built from the result set, if possible."""

        # Should only be called if self.can_become_dataframe is True

        # Worst case will be a zero row but defined columns dataframe.
        return pd.DataFrame(self.rows, columns=self.keys)


@runtime_checkable
class Connection(Protocol):
"""Protocol defining all Noteable Data Connection implementations"""

sql_cell_handle: str
"""Machine-accessible name/id, aka @35647345345345 ..."""

human_name: str
"""Human assigned datasource name"""

def __init__(
self, sql_cell_handle: str, human_name: str, connection_url: str, **create_engine_kwargs
):
"""
Construct a new 'connection', which in reality is a sqla Engine
plus some convenient metadata.
is_sqlalchemy_based: bool
"""Is this connection implemented on top of SQLAlchemy?"""

# Lifecycle methods

Common args to go into the create_engine call (and therefore need to be
passed in within `create_engine_kwargs`) include:
def execute(self, statement: str, bind_dict: Dict[str, Any]) -> ResultSet:
"""Execute this statement, possibly interpolating the values in bind_dict"""
... # pragma: no cover

* create_engine_kwargs: SQLA will pass these down to its call to create the DBAPI-level
connection class when new low-level connections are
established.
def close(self) -> None:
"""Close any resources currently allocated to this connection"""
... # pragma: no cover

No SQLA-level connection is immediately established (see the `sqla_connection` property).

'name' is what we call now the 'sql_cell_handle' -- starts with '@', followed by
the hex of the datasource uuid (usually -- the legacy "local database" (was sqlite, now duckdb)
and bigquery do not use the hex convention because they predate datasources)
class BaseConnection(Connection):
sql_cell_handle: str
human_name: str

'human_name' is the name that the user gave the datasource ('My PostgreSQL Connection')
(again, only for real datasource connections). There's a slight risk of name collision
due to having the same name used between user and space scopes, but so be it.
def __init__(self, sql_cell_handle: str, human_name: str):
super().__init__()

"""
if not sql_cell_handle.startswith("@"):
raise ValueError("sql_cell_handle values must start with '@'")

Expand All @@ -71,47 +128,31 @@ def __init__(
self.sql_cell_handle = sql_cell_handle
self.human_name = human_name

# SQLA-centric fields hereon down, to be pushed into SQLA subclass in the future.
self._engine = sqlalchemy.create_engine(connection_url, **create_engine_kwargs)
self._create_engine_kwargs = create_engine_kwargs

def close(self):
"""General-ish API method; SQLA-centric implementation"""
if self._sqla_connection:
self._sqla_connection.close()
self.reset_connection_pool()
# Dict of drivername -> Connection implementation
_drivername_to_connection_type: Dict[str, Type[Connection]] = {}

####
# SQLA-centric methods / properties here down
####

is_sqlalchemy_based = True
def connection_class(drivername: str):
"""Decorator to register a concrete Connection implementation to use for the given driver"""

@property
def sqla_engine(self) -> sqlalchemy.engine.base.Engine:
return self._engine
# Explicitly allows for overwriting any old binding so as to allow for notebook-side
# hotpatching.

@property
def dialect(self):
return self.sqla_engine.url.get_dialect()
def decorator_outer(clazz):
_drivername_to_connection_type[drivername] = clazz

_sqla_connection: Optional[sqlalchemy.engine.base.Connection] = None
return clazz

@property
def sqla_connection(self) -> sqlalchemy.engine.base.Connection:
"""Lazily connect to the database. Return a SQLA Connection object, or die trying."""
return decorator_outer

if not self._sqla_connection:
self._sqla_connection = self.sqla_engine.connect()

return self._sqla_connection
def get_connection_class(drivername: str) -> Type[Connection]:
"""Return the Connection implementation class registered for this driver.
def reset_connection_pool(self):
"""Reset the SQLA connection pool, such as after an exception suspected to indicate
a broken connection has been raised.
"""
self._engine.dispose()
self._sqla_connection = None
Raises KeyError if no implementation is registered.
"""
return _drivername_to_connection_type[drivername]


ConnectionBootstrapper: TypeVar = Callable[[], Connection]
Expand Down
5 changes: 5 additions & 0 deletions noteable/sql/meta_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ class MetaCommand:
include_in_help = True

def __init__(self, shell: InteractiveShell, conn: Connection, assign_to_varname: Optional[str]):
if not conn.is_sqlalchemy_based:
raise ValueError(
'Meta commands only working against SQLAlchemy-based connections at this time.'
)

self.shell = shell
self.conn = conn
self.assign_to_varname = assign_to_varname
Expand Down
Loading

0 comments on commit ee15ce4

Please sign in to comment.