diff --git a/dlt/common/configuration/specs/connection_string_credentials.py b/dlt/common/configuration/specs/connection_string_credentials.py index e7b0e5f900..9dd6f00942 100644 --- a/dlt/common/configuration/specs/connection_string_credentials.py +++ b/dlt/common/configuration/specs/connection_string_credentials.py @@ -1,5 +1,5 @@ from typing import Any, ClassVar, Dict, List, Optional -from sqlalchemy.engine import URL, make_url +from dlt.common.libs.sql_alchemy import URL, make_url from dlt.common.configuration.specs.exceptions import InvalidConnectionString from dlt.common.typing import TSecretValue @@ -26,6 +26,7 @@ def parse_native_representation(self, native_value: Any) -> None: # update only values that are not None self.update({k: v for k, v in url._asdict().items() if v is not None}) if self.query is not None: + # query may be immutable so make it mutable self.query = dict(self.query) except Exception: raise InvalidConnectionString(self.__class__, native_value, self.drivername) diff --git a/dlt/common/libs/sql_alchemy.py b/dlt/common/libs/sql_alchemy.py new file mode 100644 index 0000000000..2f3b51ec0d --- /dev/null +++ b/dlt/common/libs/sql_alchemy.py @@ -0,0 +1,446 @@ +""" +Ports fragments of URL class from Sql Alchemy to use them when dependency is not available. +""" + +from typing import cast + + +try: + import sqlalchemy +except ImportError: + # port basic functionality without the whole Sql Alchemy + + import re + from typing import ( + Any, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Tuple, + TypeVar, + Union, + overload, + ) + import collections.abc as collections_abc + from urllib.parse import ( + quote_plus, + parse_qsl, + quote, + unquote, + ) + + _KT = TypeVar("_KT", bound=Any) + _VT = TypeVar("_VT", bound=Any) + + class ImmutableDict(Dict[_KT, _VT]): + """Not a real immutable dict""" + + def __setitem__(self, __key: _KT, __value: _VT) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def __delitem__(self, _KT: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + def update(self, *arg: Any, **kw: Any) -> None: + raise NotImplementedError("Cannot modify immutable dict") + + EMPTY_DICT: ImmutableDict[Any, Any] = ImmutableDict() + + def to_list(value: Any, default: Optional[List[Any]] = None) -> List[Any]: + if value is None: + return default + if not isinstance(value, collections_abc.Iterable) or isinstance(value, str): + return [value] + elif isinstance(value, list): + return value + else: + return list(value) + + class URL(NamedTuple): + """ + Represent the components of a URL used to connect to a database. + + Based on SqlAlchemy URL class with copyright as below: + + # engine/url.py + # Copyright (C) 2005-2023 the SQLAlchemy authors and contributors + # + # This module is part of SQLAlchemy and is released under + # the MIT License: https://www.opensource.org/licenses/mit-license.php + """ + + drivername: str + """database backend and driver name, such as `postgresql+psycopg2`""" + username: Optional[str] + "username string" + password: Optional[str] + """password, which is normally a string but may also be any object that has a `__str__()` method.""" + host: Optional[str] + """hostname or IP number. May also be a data source name for some drivers.""" + port: Optional[int] + """integer port number""" + database: Optional[str] + """database name""" + query: ImmutableDict[str, Union[Tuple[str, ...], str]] + """an immutable mapping representing the query string. contains strings + for keys and either strings or tuples of strings for values""" + + @classmethod + def create( + cls, + drivername: str, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Mapping[str, Union[Sequence[str], str]] = None, + ) -> "URL": + """Create a new `URL` object.""" + return cls( + cls._assert_str(drivername, "drivername"), + cls._assert_none_str(username, "username"), + password, + cls._assert_none_str(host, "host"), + cls._assert_port(port), + cls._assert_none_str(database, "database"), + cls._str_dict(query or EMPTY_DICT), + ) + + @classmethod + def _assert_port(cls, port: Optional[int]) -> Optional[int]: + if port is None: + return None + try: + return int(port) + except TypeError: + raise TypeError("Port argument must be an integer or None") + + @classmethod + def _assert_str(cls, v: str, paramname: str) -> str: + if not isinstance(v, str): + raise TypeError("%s must be a string" % paramname) + return v + + @classmethod + def _assert_none_str(cls, v: Optional[str], paramname: str) -> Optional[str]: + if v is None: + return v + + return cls._assert_str(v, paramname) + + @classmethod + def _str_dict( + cls, + dict_: Optional[ + Union[ + Sequence[Tuple[str, Union[Sequence[str], str]]], + Mapping[str, Union[Sequence[str], str]], + ] + ], + ) -> ImmutableDict[str, Union[Tuple[str, ...], str]]: + if dict_ is None: + return EMPTY_DICT + + @overload + def _assert_value( + val: str, + ) -> str: ... + + @overload + def _assert_value( + val: Sequence[str], + ) -> Union[str, Tuple[str, ...]]: ... + + def _assert_value( + val: Union[str, Sequence[str]], + ) -> Union[str, Tuple[str, ...]]: + if isinstance(val, str): + return val + elif isinstance(val, collections_abc.Sequence): + return tuple(_assert_value(elem) for elem in val) + else: + raise TypeError( + "Query dictionary values must be strings or sequences of strings" + ) + + def _assert_str(v: str) -> str: + if not isinstance(v, str): + raise TypeError("Query dictionary keys must be strings") + return v + + dict_items: Iterable[Tuple[str, Union[Sequence[str], str]]] + if isinstance(dict_, collections_abc.Sequence): + dict_items = dict_ + else: + dict_items = dict_.items() + + return ImmutableDict( + { + _assert_str(key): _assert_value( + value, + ) + for key, value in dict_items + } + ) + + def set( # noqa + self, + drivername: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + database: Optional[str] = None, + query: Optional[Mapping[str, Union[Sequence[str], str]]] = None, + ) -> "URL": + """return a new `URL` object with modifications.""" + + kw: Dict[str, Any] = {} + if drivername is not None: + kw["drivername"] = drivername + if username is not None: + kw["username"] = username + if password is not None: + kw["password"] = password + if host is not None: + kw["host"] = host + if port is not None: + kw["port"] = port + if database is not None: + kw["database"] = database + if query is not None: + kw["query"] = query + + return self._assert_replace(**kw) + + def _assert_replace(self, **kw: Any) -> "URL": + """argument checks before calling _replace()""" + + if "drivername" in kw: + self._assert_str(kw["drivername"], "drivername") + for name in "username", "host", "database": + if name in kw: + self._assert_none_str(kw[name], name) + if "port" in kw: + self._assert_port(kw["port"]) + if "query" in kw: + kw["query"] = self._str_dict(kw["query"]) + + return self._replace(**kw) + + def update_query_string(self, query_string: str, append: bool = False) -> "URL": + return self.update_query_pairs(parse_qsl(query_string), append=append) + + def update_query_pairs( + self, + key_value_pairs: Iterable[Tuple[str, Union[str, List[str]]]], + append: bool = False, + ) -> "URL": + """Return a new `URL` object with the `query` parameter dictionary updated by the given sequence of key/value pairs""" + existing_query = self.query + new_keys: Dict[str, Union[str, List[str]]] = {} + + for key, value in key_value_pairs: + if key in new_keys: + new_keys[key] = to_list(new_keys[key]) + cast("List[str]", new_keys[key]).append(cast(str, value)) + else: + new_keys[key] = to_list(value) if isinstance(value, (list, tuple)) else value + + new_query: Mapping[str, Union[str, Sequence[str]]] + if append: + new_query = {} + + for k in new_keys: + if k in existing_query: + new_query[k] = tuple(to_list(existing_query[k]) + to_list(new_keys[k])) + else: + new_query[k] = new_keys[k] + + new_query.update( + {k: existing_query[k] for k in set(existing_query).difference(new_keys)} + ) + else: + new_query = ImmutableDict( + { + **self.query, + **{k: tuple(v) if isinstance(v, list) else v for k, v in new_keys.items()}, + } + ) + return self.set(query=new_query) + + def update_query_dict( + self, + query_parameters: Mapping[str, Union[str, List[str]]], + append: bool = False, + ) -> "URL": + return self.update_query_pairs(query_parameters.items(), append=append) + + def render_as_string(self, hide_password: bool = True) -> str: + """Render this `URL` object as a string.""" + s = self.drivername + "://" + if self.username is not None: + s += quote(self.username, safe=" +") + if self.password is not None: + s += ":" + ("***" if hide_password else quote(str(self.password), safe=" +")) + s += "@" + if self.host is not None: + if ":" in self.host: + s += f"[{self.host}]" + else: + s += self.host + if self.port is not None: + s += ":" + str(self.port) + if self.database is not None: + s += "/" + self.database + if self.query: + keys = to_list(self.query) + keys.sort() + s += "?" + "&".join( + f"{quote_plus(k)}={quote_plus(element)}" + for k in keys + for element in to_list(self.query[k]) + ) + return s + + def __repr__(self) -> str: + return self.render_as_string() + + def __copy__(self) -> "URL": + return self.__class__.create( + self.drivername, + self.username, + self.password, + self.host, + self.port, + self.database, + self.query.copy(), + ) + + def __deepcopy__(self, memo: Any) -> "URL": + return self.__copy__() + + def __hash__(self) -> int: + return hash(str(self)) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and self.drivername == other.drivername + and self.username == other.username + and self.password == other.password + and self.host == other.host + and self.database == other.database + and self.query == other.query + and self.port == other.port + ) + + def __ne__(self, other: Any) -> bool: + return not self == other + + def get_backend_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the database backend in + use, and is the portion of the `drivername` + that is to the left of the plus sign. + + """ + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[0] + + def get_driver_name(self) -> str: + """Return the backend name. + + This is the name that corresponds to the DBAPI driver in + use, and is the portion of the `drivername` + that is to the right of the plus sign. + """ + + if "+" not in self.drivername: + return self.drivername + else: + return self.drivername.split("+")[1] + + def make_url(name_or_url: Union[str, URL]) -> URL: + """Given a string, produce a new URL instance. + + The format of the URL generally follows `RFC-1738`, with some exceptions, including + that underscores, and not dashes or periods, are accepted within the + "scheme" portion. + + If a `URL` object is passed, it is returned as is.""" + + if isinstance(name_or_url, str): + return _parse_url(name_or_url) + elif not isinstance(name_or_url, URL): + raise ValueError(f"Expected string or URL object, got {name_or_url!r}") + else: + return name_or_url + + def _parse_url(name: str) -> URL: + pattern = re.compile( + r""" + (?P[\w\+]+):// + (?: + (?P[^:/]*) + (?::(?P[^@]*))? + @)? + (?: + (?: + \[(?P[^/\?]+)\] | + (?P[^/:\?]+) + )? + (?::(?P[^/\?]*))? + )? + (?:/(?P[^\?]*))? + (?:\?(?P.*))? + """, + re.X, + ) + + m = pattern.match(name) + if m is not None: + components = m.groupdict() + query: Optional[Dict[str, Union[str, List[str]]]] + if components["query"] is not None: + query = {} + + for key, value in parse_qsl(components["query"]): + if key in query: + query[key] = to_list(query[key]) + cast("List[str]", query[key]).append(value) + else: + query[key] = value + else: + query = None + + components["query"] = query + if components["username"] is not None: + components["username"] = unquote(components["username"]) + + if components["password"] is not None: + components["password"] = unquote(components["password"]) + + ipv4host = components.pop("ipv4host") + ipv6host = components.pop("ipv6host") + components["host"] = ipv4host or ipv6host + name = components.pop("name") + + if components["port"]: + components["port"] = int(components["port"]) + + return URL.create(name, **components) # type: ignore + + else: + raise ValueError("Could not parse SQLAlchemy URL from string '%s'" % name) + +else: + from sqlalchemy.engine import URL, make_url # type: ignore[assignment] diff --git a/dlt/destinations/impl/mssql/configuration.py b/dlt/destinations/impl/mssql/configuration.py index f00998cfb2..45c448fab7 100644 --- a/dlt/destinations/impl/mssql/configuration.py +++ b/dlt/destinations/impl/mssql/configuration.py @@ -1,5 +1,5 @@ from typing import Final, ClassVar, Any, List, Dict, Optional, TYPE_CHECKING -from sqlalchemy.engine import URL +from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials diff --git a/dlt/destinations/impl/postgres/configuration.py b/dlt/destinations/impl/postgres/configuration.py index f1d30a7342..109d422650 100644 --- a/dlt/destinations/impl/postgres/configuration.py +++ b/dlt/destinations/impl/postgres/configuration.py @@ -1,5 +1,5 @@ from typing import Final, ClassVar, Any, List, TYPE_CHECKING -from sqlalchemy.engine import URL +from dlt.common.libs.sql_alchemy import URL from dlt.common.configuration import configspec from dlt.common.configuration.specs import ConnectionStringCredentials diff --git a/dlt/destinations/impl/snowflake/configuration.py b/dlt/destinations/impl/snowflake/configuration.py index 01f5ca6e03..4f97f08700 100644 --- a/dlt/destinations/impl/snowflake/configuration.py +++ b/dlt/destinations/impl/snowflake/configuration.py @@ -3,7 +3,7 @@ from typing import Final, Optional, Any, Dict, ClassVar, List, TYPE_CHECKING -from sqlalchemy.engine import URL +from dlt.common.libs.sql_alchemy import URL from dlt import version from dlt.common.exceptions import MissingDependencyException diff --git a/docs/website/docs/intro.md b/docs/website/docs/intro.md index 04af626566..27b896ba2a 100644 --- a/docs/website/docs/intro.md +++ b/docs/website/docs/intro.md @@ -141,7 +141,7 @@ from sqlalchemy import create_engine # Use any SQL database supported by SQLAlchemy, below we use a public # MySQL instance to get data. -# NOTE: you'll need to install pymysql with `pip install pymysql` +# NOTE: you'll need to install pymysql with `pip install sqlalchemy pymysql` # NOTE: loading data from public mysql instance may take several seconds engine = create_engine("mysql+pymysql://rfamro@mysql-rfam-public.ebi.ac.uk:4497/Rfam") @@ -165,7 +165,7 @@ print(load_info) Install **pymysql** driver: ```sh -pip install pymysql +pip install sqlalchemy pymysql ``` diff --git a/poetry.lock b/poetry.lock index 9c1c9b4226..cad68180dc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -8986,4 +8986,4 @@ weaviate = ["weaviate-client"] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.13" -content-hash = "3380a5a646776e0fc0d895b5271bb769872ac1cdb09a842af61ba1741d1c03b3" +content-hash = "a7aa3e523522ab3260a7a19f097a34349b66cf046289db1e17b48f88f7fd189f" diff --git a/pyproject.toml b/pyproject.toml index ee98c4d1e0..0eb33d885d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,6 @@ hexbytes = ">=0.2.2" tzdata = ">=2022.1" tomlkit = ">=0.11.3" pathvalidate = ">=2.5.2" -SQLAlchemy = ">=1.4.0" typing-extensions = ">=4.0.0" makefun = ">=1.15.0" click = ">=7.1" @@ -137,7 +136,6 @@ types-python-dateutil = ">=2.8.15" flake8-tidy-imports = ">=4.8.0" flake8-encodings = "^0.5.0" flake8-builtins = "^1.5.3" -types-SQLAlchemy = ">=1.4.53" boto3-stubs = "^1.28.28" types-tqdm = "^4.66.0.2" types-psutil = "^5.9.5.16" @@ -145,6 +143,7 @@ types-psycopg2 = "^2.9.21.14" cryptography = "^41.0.7" google-api-python-client = ">=1.7.11" pytest-asyncio = "^0.23.5" +types-sqlalchemy = "^1.4.53.38" [tool.poetry.group.pipeline] optional=true @@ -186,6 +185,7 @@ sentry-sdk = "^1.5.6" optional = true [tool.poetry.group.docs.dependencies] +SQLAlchemy = ">=1.4.0" pymysql = "^1.1.0" pypdf2 = "^3.0.1" pydoc-markdown = "^4.8.2" diff --git a/tests/common/configuration/test_configuration.py b/tests/common/configuration/test_configuration.py index 81d49432d7..a883f76ddb 100644 --- a/tests/common/configuration/test_configuration.py +++ b/tests/common/configuration/test_configuration.py @@ -778,10 +778,11 @@ def test_values_serialization() -> None: assert deserialize_value("K", v, Wei) == Wei("0.01") # test credentials - credentials_str = "databricks+connector://token:@:443/?conn_timeout=15&search_path=a%2Cb%2Cc" + credentials_str = "databricks+connector://token:-databricks_token-@:443/?conn_timeout=15&search_path=a%2Cb%2Cc" credentials = deserialize_value("credentials", credentials_str, ConnectionStringCredentials) assert credentials.drivername == "databricks+connector" assert credentials.query == {"conn_timeout": "15", "search_path": "a,b,c"} + assert credentials.password == "-databricks_token-" assert serialize_value(credentials) == credentials_str # using dict also works credentials_dict = dict(credentials) diff --git a/tests/common/configuration/test_spec_union.py b/tests/common/configuration/test_spec_union.py index 25c32920bc..4892967ab7 100644 --- a/tests/common/configuration/test_spec_union.py +++ b/tests/common/configuration/test_spec_union.py @@ -1,7 +1,6 @@ import itertools import os import pytest -from sqlalchemy.engine import Engine, create_engine from typing import Optional, Union, Any import dlt @@ -236,6 +235,10 @@ def test_google_auth_union(environment: Any) -> None: assert isinstance(credentials, GcpServiceAccountCredentials) +class Engine: + pass + + @dlt.source def sql_database(credentials: Union[ConnectionStringCredentials, Engine, str] = dlt.secrets.value): yield dlt.resource([credentials], name="creds") @@ -243,7 +246,7 @@ def sql_database(credentials: Union[ConnectionStringCredentials, Engine, str] = def test_union_concrete_type(environment: Any) -> None: # we can pass engine explicitly - engine = create_engine("sqlite:///:memory:", echo=True) + engine = Engine() db = sql_database(credentials=engine) creds = list(db)[0] assert isinstance(creds, Engine) diff --git a/tests/load/snowflake/test_snowflake_configuration.py b/tests/load/snowflake/test_snowflake_configuration.py index fb8ff925c0..d0ca4de41b 100644 --- a/tests/load/snowflake/test_snowflake_configuration.py +++ b/tests/load/snowflake/test_snowflake_configuration.py @@ -1,7 +1,7 @@ import os import pytest from pathlib import Path -from sqlalchemy.engine import make_url +from dlt.common.libs.sql_alchemy import make_url pytest.importorskip("snowflake")