Skip to content

Commit

Permalink
feat(api): thread kwargs around properly to support more complex conn…
Browse files Browse the repository at this point in the history
…ection arguments
  • Loading branch information
cpcloud authored and gforsyth committed Mar 27, 2023
1 parent db29e10 commit 7e0e15b
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 19 deletions.
12 changes: 7 additions & 5 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def connect(self, *args, **kwargs) -> BaseBackend:
new_backend.reconnect()
return new_backend

def _from_url(self, url: str) -> BaseBackend:
def _from_url(self, url: str, **kwargs) -> BaseBackend:
"""Construct an ibis backend from a SQLAlchemy-conforming URL."""
raise NotImplementedError(
f"`_from_url` not implemented for the {self.name} backend"
Expand Down Expand Up @@ -991,12 +991,14 @@ def connect(resource: Path | str, **kwargs: Any) -> BaseBackend:
parsed = urllib.parse.urlparse(url)
scheme = parsed.scheme or "file"

# Merge explicit kwargs with query string, explicit kwargs
# taking precedence
kwargs = dict(urllib.parse.parse_qsl(parsed.query), **kwargs)
orig_kwargs = kwargs.copy()
kwargs = dict(urllib.parse.parse_qsl(parsed.query))

if scheme == "file":
path = parsed.netloc + parsed.path
# Merge explicit kwargs with query string, explicit kwargs
# taking precedence
kwargs.update(orig_kwargs)
if path.endswith(".duckdb"):
return ibis.duckdb.connect(path, **kwargs)
elif path.endswith((".sqlite", ".db")):
Expand Down Expand Up @@ -1038,4 +1040,4 @@ def connect(resource: Path | str, **kwargs: Any) -> BaseBackend:
except AttributeError:
raise ValueError(f"Don't know how to connect to {resource!r}") from None

return backend._from_url(url)
return backend._from_url(url, **orig_kwargs)
13 changes: 9 additions & 4 deletions ibis/backends/base/sql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Iterable, Mapping

import toolz

import ibis.common.exceptions as exc
import ibis.expr.analysis as an
import ibis.expr.operations as ops
Expand All @@ -31,13 +33,15 @@ class BaseSQLBackend(BaseBackend):
table_class = ops.DatabaseTable
table_expr_class = ir.Table

def _from_url(self, url: str) -> BaseBackend:
def _from_url(self, url: str, **kwargs: Any) -> BaseBackend:
"""Connect to a backend using a URL `url`.
Parameters
----------
url
URL with which to connect to a backend.
kwargs
Additional keyword arguments passed to the `connect` method.
Returns
-------
Expand All @@ -47,7 +51,7 @@ def _from_url(self, url: str) -> BaseBackend:
import sqlalchemy as sa

url = sa.engine.make_url(url)

new_kwargs = kwargs.copy()
kwargs = {}

for name in ("host", "port", "database", "password"):
Expand All @@ -60,8 +64,9 @@ def _from_url(self, url: str) -> BaseBackend:
kwargs["user"] = username

kwargs.update(url.query)
self._convert_kwargs(kwargs)
return self.connect(**kwargs)
new_kwargs = toolz.merge(kwargs, new_kwargs)
self._convert_kwargs(new_kwargs)
return self.connect(**new_kwargs)

def table(self, name: str, database: str | None = None) -> ir.Table:
"""Construct a table expression.
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ class Backend(BaseSQLBackend):
database_class = BigQueryDatabase
table_class = BigQueryTable

def _from_url(self, url):
def _from_url(self, url: str, **kwargs):
result = urlparse(url)
params = parse_qs(result.query)
return self.connect(
project_id=result.netloc or params.get("project_id", [""])[0],
dataset_id=result.path[1:] or params.get("dataset_id", [""])[0],
**kwargs,
)

def do_connect(
Expand Down
17 changes: 11 additions & 6 deletions ibis/backends/clickhouse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def sql(self, query: str, schema=None) -> ir.Table:
schema = self._get_schema_using_query(query)
return ops.SQLQueryResult(query, ibis.schema(schema), self).to_expr()

def _from_url(self, url: str) -> BaseBackend:
def _from_url(self, url: str, **kwargs) -> BaseBackend:
"""Connect to a backend using a URL `url`.
Parameters
----------
url
URL with which to connect to a backend.
kwargs
Additional keyword arguments
Returns
-------
Expand All @@ -126,11 +128,14 @@ def _from_url(self, url: str) -> BaseBackend:
"""
url = sa.engine.make_url(url)

kwargs = {
name: value
for name in ("host", "port", "database", "password")
if (value := getattr(url, name, None))
}
kwargs = toolz.merge(
{
name: value
for name in ("host", "port", "database", "password")
if (value := getattr(url, name, None))
},
kwargs,
)
if username := url.username:
kwargs["user"] = username

Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/pyspark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class Options(ibis.config.Config):

treat_nan_as_null: bool = False

def _from_url(self, url: str) -> Backend:
def _from_url(self, url: str, **kwargs) -> Backend:
"""Construct a PySpark backend from a URL `url`."""
url = sa.engine.make_url(url)

Expand All @@ -136,7 +136,7 @@ def _from_url(self, url: str) -> Backend:

builder = SparkSession.builder.config(conf=conf)
session = builder.getOrCreate()
return self.connect(session)
return self.connect(session, **kwargs)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down
10 changes: 9 additions & 1 deletion ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
from ibis.backends.base.sql.alchemy.query_builder import _AlchemyTableSetFormatter

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

import ibis.expr.schema as sch


@contextlib.contextmanager
def _handle_pyarrow_warning(*, action: str):
Expand Down Expand Up @@ -181,7 +184,7 @@ def do_connect(
"STRICT_JSON_OUTPUT": "TRUE",
},
)
self._default_connector_format = connect_args["session_parameters"].get(
self._default_connector_format = connect_args["session_parameters"].setdefault(
PARAMETER_PYTHON_CONNECTOR_QUERY_RESULT_FORMAT, "JSON"
)
engine = sa.create_engine(
Expand Down Expand Up @@ -263,6 +266,11 @@ def to_pyarrow(
return res[expr.get_name()][0]
return res

def fetch_from_cursor(self, cursor, schema: sch.Schema) -> pd.DataFrame:
if _NATIVE_ARROW and self._default_connector_format == "ARROW":
return cursor.cursor.fetch_pandas_all()
return super().fetch_from_cursor(cursor, schema)

def to_pyarrow_batches(
self,
expr: ir.Expr,
Expand Down

0 comments on commit 7e0e15b

Please sign in to comment.