From b35845f63f6be5b535ea180266cf614d876559fe Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Tue, 13 Feb 2024 09:41:26 -0800 Subject: [PATCH 01/10] Prepare release environment --- .codegen.json | 17 + .github/workflows/acceptance.yml | 46 ++ .github/workflows/push.yml | 28 +- .github/workflows/release.yml | 48 ++ CHANGELOG.md | 5 + CONTRIBUTING.md | 118 +++- LICENSE | 82 ++- Makefile | 28 + pyproject.toml | 722 +++++++++++++++++--- src/databricks/labs/lsql/__about__.py | 5 +- src/databricks/labs/lsql/py.typed | 1 + tests/__init__.py | 3 - tests/integration/__init__.py | 0 tests/{ => integration}/test_integration.py | 0 tests/{ => unit}/test_lib.py | 0 15 files changed, 949 insertions(+), 154 deletions(-) create mode 100644 .codegen.json create mode 100644 .github/workflows/acceptance.yml create mode 100644 .github/workflows/release.yml create mode 100644 CHANGELOG.md create mode 100644 Makefile create mode 100644 src/databricks/labs/lsql/py.typed create mode 100644 tests/integration/__init__.py rename tests/{ => integration}/test_integration.py (100%) rename tests/{ => unit}/test_lib.py (100%) diff --git a/.codegen.json b/.codegen.json new file mode 100644 index 00000000..43712dca --- /dev/null +++ b/.codegen.json @@ -0,0 +1,17 @@ +{ + "version": { + "src/databricks/labs/lsql/__about__.py": "__version__ = \"$VERSION\"" + }, + "toolchain": { + "required": ["python3"], + "pre_setup": [ + "python3 -m pip install hatch==1.7.0", + "python3 -m hatch env create" + ], + "prepend_path": ".venv/bin", + "acceptance_path": "tests/integration", + "test": [ + "pytest -n 4 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20" + ] + } +} \ No newline at end of file diff --git a/.github/workflows/acceptance.yml b/.github/workflows/acceptance.yml new file mode 100644 index 00000000..55a38d83 --- /dev/null +++ b/.github/workflows/acceptance.yml @@ -0,0 +1,46 @@ +name: acceptance + +on: + pull_request: + types: [opened, synchronize] + +permissions: + id-token: write + contents: read + pull-requests: write + +jobs: + integration: + if: github.event_name == 'pull_request' + environment: runtime + runs-on: larger # needs to be explicitly enabled per repo + steps: + - name: Checkout Code + uses: actions/checkout@v2.5.0 + + - name: Unshallow + run: git fetch --prune --unshallow + + - name: Install Python + uses: actions/setup-python@v4 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Install hatch + run: pip install hatch==1.7.0 + + - uses: azure/login@v1 + with: + client-id: ${{ secrets.ARM_CLIENT_ID }} + tenant-id: ${{ secrets.ARM_TENANT_ID }} + allow-no-subscriptions: true + + - name: Run integration tests + run: hatch run integration + env: + CLOUD_ENV: azure + DATABRICKS_HOST: "${{ vars.DATABRICKS_HOST }}" + DATABRICKS_CLUSTER_ID: "${{ vars.DATABRICKS_CLUSTER_ID }}" + DATABRICKS_WAREHOUSE_ID: "${{ vars.DATABRICKS_WAREHOUSE_ID }}" diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index d8b0ad2b..0b54d8f7 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -14,14 +14,12 @@ on: branches: - main -env: - HATCH_VERSION: 1.7.0 - jobs: ci: strategy: + fail-fast: false matrix: - pyVersion: [ '3.8', '3.9', '3.10', '3.11', '3.12' ] + pyVersion: [ '3.10', '3.11', '3.12' ] runs-on: ubuntu-latest steps: - name: Checkout @@ -34,11 +32,10 @@ jobs: cache-dependency-path: '**/pyproject.toml' python-version: ${{ matrix.pyVersion }} - - name: Install hatch - run: pip install hatch==$HATCH_VERSION - - name: Run unit 
tests - run: hatch run unit:test + run: | + pip install hatch==1.7.0 + make test - name: Publish test coverage uses: codecov/codecov-action@v1 @@ -49,15 +46,8 @@ jobs: - name: Checkout uses: actions/checkout@v3 - - name: Install Python - uses: actions/setup-python@v4 - with: - cache: 'pip' - cache-dependency-path: '**/pyproject.toml' - python-version: 3.10.x - - - name: Install hatch - run: pip install hatch==$HATCH_VERSION + - name: Format all files + run: make dev fmt - - name: Verify linting - run: hatch run lint:verify + - name: Fail on differences + run: git diff --exit-code diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..d08f42cb --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,48 @@ +name: Release + +on: + push: + tags: + - 'v*' + +jobs: + publish: + runs-on: ubuntu-latest + environment: release + permissions: + # Used to authenticate to PyPI via OIDC and sign the release's artifacts with sigstore-python. + id-token: write + # Used to attach signing artifacts to the published release. + contents: write + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Build wheels + run: | + pip install hatch==1.7.0 + hatch build + + - name: Draft release + uses: softprops/action-gh-release@v1 + with: + draft: true + files: | + dist/databricks_*.whl + dist/databricks_*.tar.gz + + - uses: pypa/gh-action-pypi-publish@release/v1 + name: Publish package distributions to PyPI + + - name: Sign artifacts with Sigstore + uses: sigstore/gh-action-sigstore-python@v2.1.1 + with: + inputs: | + dist/databricks_*.whl + dist/databricks_*.tar.gz + release-signing-artifacts: true \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..2aa8409e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Version changelog + +## 0.0.0 + +Initial commit diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 967fdb5d..90809b5c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,117 @@ -To setup local dev environment, you have to install `hatch` tooling: `pip install hatch`. +# Contributing -After, you have to configure your IDE with it: `hatch run python -c "import sys; print(sys.executable)" | pbcopy` \ No newline at end of file +## First Principles + +Favoring standard libraries over external dependencies, especially in specific contexts like Databricks, is a best practice in software +development. + +There are several reasons why this approach is encouraged: +- Standard libraries are typically well-vetted, thoroughly tested, and maintained by the official maintainers of the programming language or platform. This ensures a higher level of stability and reliability. +- External dependencies, especially lesser-known or unmaintained ones, can introduce bugs, security vulnerabilities, or compatibility issues that can be challenging to resolve. Adding external dependencies increases the complexity of your codebase. +- Each dependency may have its own set of dependencies, potentially leading to a complex web of dependencies that can be difficult to manage. This complexity can lead to maintenance challenges, increased risk, and longer build times. +- External dependencies can pose security risks. If a library or package has known security vulnerabilities and is widely used, it becomes an attractive target for attackers. 
Minimizing external dependencies reduces the potential attack surface and makes it easier to keep your code secure.
+- Relying on standard libraries enhances code portability. It ensures your code can run on different platforms and environments without being tightly coupled to specific external dependencies. This is particularly important in settings like Databricks, where you may need to run your code on different clusters or setups.
+- External dependencies may have their own versioning schemes and compatibility issues. When using standard libraries, you have more control over versioning and can avoid conflicts between different dependencies in your project.
+- Fewer external dependencies mean faster build and deployment times. Downloading, installing, and managing external packages can slow down these processes, especially in large-scale projects or distributed computing environments like Databricks.
+- External dependencies can be abandoned or go unmaintained over time. This can lead to situations where your project relies on outdated or unsupported code. When you depend on standard libraries, you have confidence that the core functionality you rely on will continue to be maintained and improved.
+
+While minimizing external dependencies is essential, exceptions can be made on a case-by-case basis. There are situations where external dependencies are
+justified, such as when a well-established and actively maintained library provides significant benefits, like time savings, performance improvements,
+or specialized functionality unavailable in standard libraries.
+
+## Common fixes for `mypy` errors
+
+See https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html for more details.
+
+### ..., expression has type "None", variable has type "str"
+
+* Add `assert ... is not None` if it's in the body of a method. Example:
+
+```
+# error: Argument 1 to "delete" of "DashboardWidgetsAPI" has incompatible type "str | None"; expected "str"
+self._ws.dashboard_widgets.delete(widget.id)
+```
+
+after
+
+```
+assert widget.id is not None
+self._ws.dashboard_widgets.delete(widget.id)
+```
+
+* Add `... | None` if it's a field in a dataclass. Example: `cloud: str = None` -> `cloud: str | None = None`
+
+### ..., has incompatible type "Path"; expected "str"
+
+Add `.as_posix()` to convert a `Path` to `str`.
+
+### Argument 2 to "get" of "dict" has incompatible type "None"; expected ...
+
+Add a valid default value for the dictionary return.
+
+Example:
+```python
+def viz_type(self) -> str:
+    return self.viz.get("type", None)
+```
+
+after:
+
+Example:
+```python
+def viz_type(self) -> str:
+    return self.viz.get("type", "UNKNOWN")
+```
+
+## Local Setup
+
+This section provides a step-by-step guide to setting up the project environment and dependencies so that you can start working on the project efficiently.
+
+To begin, assuming you have already cloned the GitHub repo, run `make dev` to create the default environment and install the development dependencies.
+
+```shell
+make dev
+```
+
+Verify the installation with:
+```shell
+make test
+```
+
+Before every commit, apply consistent formatting to the code, as we want our codebase to look consistent:
+```shell
+make fmt
+```
+
+Before every commit, run the automated bug detector (`make lint`) and the unit tests (`make test`) to ensure that the automated
+pull request checks pass before your code is reviewed by others:
+```shell
+make test
+```
+
+## First contribution
+
+Here are the steps to submit your first contribution:
+
+1. 
Make a Fork from ucx repo (if you really want to contribute) +2. `git clone` +3. `git checkout main` (or `gcm` if you're using [ohmyzsh](https://ohmyz.sh/)). +4. `git pull` (or `gl` if you're using [ohmyzsh](https://ohmyz.sh/)). +5. `git checkout -b FEATURENAME` (or `gcb FEATURENAME` if you're using [ohmyzsh](https://ohmyz.sh/)). +6. .. do the work +7. `make fmt` +8. `make lint` +9. .. fix if any +10. `make test` +11. .. fix if any +12. `git commit -a`. Make sure to enter meaningful commit message title. +13. `git push origin FEATURENAME` +14. Go to GitHub UI and create PR. Alternatively, `gh pr create` (if you have [GitHub CLI](https://cli.github.com/) installed). + Use a meaningful pull request title because it'll appear in the release notes. Use `Resolves #NUMBER` in pull + request description to [automatically link it](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue) + to an existing issue. +15. announce PR for the review + +## Troubleshooting + +If you encounter any package dependency errors after `git pull`, run `make clean` diff --git a/LICENSE b/LICENSE index e02a93e6..c8d0d24a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,25 +1,69 @@ -DB license + Databricks License + Copyright (2024) Databricks, Inc. -Copyright (2023) Databricks, Inc. + Definitions. + + Agreement: The agreement between Databricks, Inc., and you governing + the use of the Databricks Services, as that term is defined in + the Master Cloud Services Agreement (MCSA) located at + www.databricks.com/legal/mcsa. + + Licensed Materials: The source code, object code, data, and/or other + works to which this license applies. -Definitions. + Scope of Use. You may not use the Licensed Materials except in + connection with your use of the Databricks Services pursuant to + the Agreement. Your use of the Licensed Materials must comply at all + times with any restrictions applicable to the Databricks Services, + generally, and must be used in accordance with any applicable + documentation. You may view, use, copy, modify, publish, and/or + distribute the Licensed Materials solely for the purposes of using + the Licensed Materials within or connecting to the Databricks Services. + If you do not agree to these terms, you may not view, use, copy, + modify, publish, and/or distribute the Licensed Materials. + + Redistribution. You may redistribute and sublicense the Licensed + Materials so long as all use is in compliance with these terms. + In addition: + + - You must give any other recipients a copy of this License; + - You must cause any modified files to carry prominent notices + stating that you changed the files; + - You must retain, in any derivative works that you distribute, + all copyright, patent, trademark, and attribution notices, + excluding those notices that do not pertain to any part of + the derivative works; and + - If a "NOTICE" text file is provided as part of its + distribution, then any derivative works that you distribute + must include a readable copy of the attribution notices + contained within such NOTICE file, excluding those notices + that do not pertain to any part of the derivative works. 
-Agreement: The agreement between Databricks, Inc., and you governing the use of the Databricks Services, which shall be, with respect to Databricks, the Databricks Terms of Service located at www.databricks.com/termsofservice, and with respect to Databricks Community Edition, the Community Edition Terms of Service located at www.databricks.com/ce-termsofuse, in each case unless you have entered into a separate written agreement with Databricks governing the use of the applicable Databricks Services. + You may add your own copyright statement to your modifications and may + provide additional license terms and conditions for use, reproduction, + or distribution of your modifications, or for any such derivative works + as a whole, provided your use, reproduction, and distribution of + the Licensed Materials otherwise complies with the conditions stated + in this License. -Software: The source code and object code to which this license applies. + Termination. This license terminates automatically upon your breach of + these terms or upon the termination of your Agreement. Additionally, + Databricks may terminate this license at any time on notice. Upon + termination, you must permanently delete the Licensed Materials and + all copies thereof. -Scope of Use. You may not use this Software except in connection with your use of the Databricks Services pursuant to the Agreement. Your use of the Software must comply at all times with any restrictions applicable to the Databricks Services, generally, and must be used in accordance with any applicable documentation. You may view, use, copy, modify, publish, and/or distribute the Software solely for the purposes of using the code within or connecting to the Databricks Services. If you do not agree to these terms, you may not view, use, copy, modify, publish, and/or distribute the Software. + DISCLAIMER; LIMITATION OF LIABILITY. -Redistribution. You may redistribute and sublicense the Software so long as all use is in compliance with these terms. In addition: - -You must give any other recipients a copy of this License; -You must cause any modified files to carry prominent notices stating that you changed the files; -You must retain, in the source code form of any derivative works that you distribute, all copyright, patent, trademark, and attribution notices from the source code form, excluding those notices that do not pertain to any part of the derivative works; and -If the source code form includes a "NOTICE" text file as part of its distribution, then any derivative works that you distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the derivative works. -You may add your own copyright statement to your modifications and may provide additional license terms and conditions for use, reproduction, or distribution of your modifications, or for any such derivative works as a whole, provided your use, reproduction, and distribution of the Software otherwise complies with the conditions stated in this License. - -Termination. This license terminates automatically upon your breach of these terms or upon the termination of your Agreement. Additionally, Databricks may terminate this license at any time on notice. Upon termination, you must permanently delete the Software and all copies thereof. - -DISCLAIMER; LIMITATION OF LIABILITY. - -THE SOFTWARE IS PROVIDED “AS-IS” AND WITH ALL FAULTS. 
DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY DISCLAIMS ALL WARRANTIES RELATING TO THE SOURCE CODE, EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE SOURCE CODE SHALL BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file + THE LICENSED MATERIALS ARE PROVIDED “AS-IS” AND WITH ALL FAULTS. + DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY + DISCLAIMS ALL WARRANTIES RELATING TO THE LICENSED MATERIALS, EXPRESS + AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, + CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR + FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND + ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF + YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE LICENSED MATERIALS SHALL + BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS OR + THE USE OR OTHER DEALINGS IN THE LICENSED MATERIALS. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..4b9667fe --- /dev/null +++ b/Makefile @@ -0,0 +1,28 @@ +all: clean lint fmt test coverage + +clean: + rm -fr .venv clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml + rm -fr **/*.pyc + +.venv/bin/python: + pip install hatch + hatch env create + +dev: .venv/bin/python + @hatch run which python + +lint: + hatch run verify + +fmt: + hatch run fmt + +test: + hatch run test + +integration: + hatch run integration + +coverage: + hatch run coverage && open htmlcov/index.html + diff --git a/pyproject.toml b/pyproject.toml index 0de2dbe6..eb8e45be 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,7 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "databricks-labs-lsql" dynamic = ["version"] -description = 'Lightweight stateless SQL execution for Databricks with minimal dependencies' +description = "Lightweight stateless SQL execution for Databricks with minimal dependencies" readme = "README.md" requires-python = ">=3.10" license-files = { paths = ["LICENSE", "NOTICE"] } @@ -15,17 +11,16 @@ authors = [ ] classifiers = [ "License :: Other/Proprietary License", - "Development Status :: 4 - Beta", + "Development Status :: 3 - Alpha", "Programming Language :: Python", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ + "databricks-labs-blueprint~=0.2.5", "databricks-sdk~=0.21.0", "PyYAML>=6.0.0,<7.0.0", "sqlglot~=22.2.1" @@ -40,123 +35,636 @@ 
Source = "https://github.com/databrickslabs/databricks-labs-lsql" path = "src/databricks/labs/lsql/__about__.py" [tool.hatch.envs.default] -python = "python3" dependencies = [ - "coverage[toml]>=6.5", - "pytest", + "coverage[toml]>=6.5", + "pytest", + "pylint", + "pytest-xdist", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.0.0,<4.0.0", + "pytest-timeout", + "ruff>=0.0.243", + "isort>=2.5.0", + "mypy", + "types-PyYAML", + "types-requests", ] + +python="3.10" + +# store virtual env as the child of this folder. Helps VSCode (and PyCharm) to run better +path = ".venv" + [tool.hatch.envs.default.scripts] -test = "pytest {args:tests}" -test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] +test = "pytest -n 2 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20" +coverage = "pytest -n 2 --cov src tests/unit --timeout 30 --cov-report=html --durations 20" +integration = "pytest -n 10 --cov src tests/integration --durations 20" +fmt = ["isort .", + "ruff format", + "ruff . --fix", + "mypy .", + "pylint --output-format=colorized -j 0 src"] +verify = ["black --check .", + "isort . --check-only", + "ruff .", + "mypy .", + "pylint --output-format=colorized -j 0 src"] -[[tool.hatch.envs.all.matrix]] -python = ["3.7", "3.8", "3.9", "3.10", "3.11"] +[tool.isort] +profile = "black" -[tool.hatch.envs.lint] -detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] -[tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/databricks/labs/lsql tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +[tool.pytest.ini_options] +addopts = "--no-header" +cache_dir = ".venv/pytest-cache" [tool.black] -target-version = ["py37"] +target-version = ["py310"] line-length = 120 skip-string-normalization = true [tool.ruff] -target-version = "py37" +cache-dir = ".venv/ruff-cache" +target-version = "py310" line-length = 120 -select = [ - "A", - "ARG", - "B", - "C", - "DTZ", - "E", - "EM", - "F", - "FBT", - "I", - "ICN", - "ISC", - "N", - "PLC", - "PLE", - "PLR", - "PLW", - "Q", - "RUF", - "S", - "T", - "TID", - "UP", - "W", - "YTT", -] -ignore = [ - # Allow non-abstract empty methods in abstract base classes - "B027", - # Allow boolean positional values in function calls, like `dict.get(... 
True)` - "FBT003", - # Ignore checks for possible passwords - "S105", "S106", "S107", - # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", -] -unfixable = [ - # Don't touch unused imports - "F401", -] [tool.ruff.isort] -known-first-party = ["databricks.labs"] - -[tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" - -[tool.ruff.per-file-ignores] -# Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +known-first-party = ["databricks.labs.blueprint"] [tool.coverage.run] branch = true parallel = true -omit = [ - "src/databricks/labs/lsql/__about__.py", -] - -[tool.coverage.paths] -databricks_labs_lsql = ["src/databricks/labs/lsql", "*/databricks/labs/lsql/src/databricks/labs/lsql"] -tests = ["tests", "*/databricks/labs/lsql/tests"] [tool.coverage.report] +omit = ["*/working-copy/*", 'src/databricks/labs/blueprint/__main__.py'] exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", ] + +[tool.pylint.main] +# PyLint configuration is adapted from Google Python Style Guide with modifications. +# Sources https://google.github.io/styleguide/pylintrc +# License: https://github.com/google/styleguide/blob/gh-pages/LICENSE + +# Analyse import fallback blocks. This can be used to support both Python 2 and 3 +# compatible code, which means that the block might have code that exists only in +# one or another interpreter, leading to false positives when analysed. +# analyse-fallback-blocks = + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint in +# a server-like mode. +# clear-cache-post-run = + +# Always return a 0 (non-error) status code, even if lint errors are found. This +# is primarily useful in continuous integration scripts. +# exit-zero = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +# extension-pkg-allow-list = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +# extension-pkg-whitelist = + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +# fail-on = + +# Specify a score threshold under which the program will exit with error. +fail-under = 10.0 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +# from-stdin = + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. +# ignore-paths = + +# Files or directories matching the regular expression patterns are skipped. The +# regex matches against base names, not paths. 
The default value ignores Emacs +# file locks +ignore-patterns = ["^\\.#"] + +# List of module names for which member attributes should not be checked (useful +# for modules/projects where namespaces are manipulated during runtime and thus +# existing member attributes cannot be deduced by static analysis). It supports +# qualified module names, as well as Unix pattern matching. +# ignored-modules = + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +# init-hook = + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +# jobs = + +# Control the amount of potential inferred values when inferring a single object. +# This can help the performance when dealing with large functions or complex, +# nested conditions. +limit-inference-results = 100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins = ["pylint.extensions.check_elif", "pylint.extensions.bad_builtin", "pylint.extensions.docparams", "pylint.extensions.for_any_all", "pylint.extensions.set_membership", "pylint.extensions.code_style", "pylint.extensions.overlapping_exceptions", "pylint.extensions.typing", "pylint.extensions.redefined_variable_type", "pylint.extensions.comparison_placement", "pylint.extensions.broad_try_clause", "pylint.extensions.dict_init_mutate", "pylint.extensions.consider_refactoring_into_while_condition"] + +# Pickle collected data for later comparisons. +persistent = true + +# Minimum Python version to use for version dependent checks. Will default to the +# version used to run pylint. +py-version = "3.10" + +# Discover python modules and packages in the file system subtree. +# recursive = + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +# source-roots = + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode = true + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +# unsafe-load-any-extension = + +[tool.pylint.basic] +# Naming style matching correct argument names. +argument-naming-style = "snake_case" + +# Regular expression matching correct argument names. Overrides argument-naming- +# style. If left empty, argument names will be checked with the set naming style. +argument-rgx = "[a-z_][a-z0-9_]{2,30}$" + +# Naming style matching correct attribute names. +attr-naming-style = "snake_case" + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +attr-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Bad variable names which should always be refused, separated by a comma. +bad-names = ["foo", "bar", "baz", "toto", "tutu", "tata"] + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +# bad-names-rgxs = + +# Naming style matching correct class attribute names. +class-attribute-naming-style = "any" + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. 
If left empty, class attribute names will be checked +# with the set naming style. +class-attribute-rgx = "([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$" + +# Naming style matching correct class constant names. +class-const-naming-style = "UPPER_CASE" + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +# class-const-rgx = + +# Naming style matching correct class names. +class-naming-style = "PascalCase" + +# Regular expression matching correct class names. Overrides class-naming-style. +# If left empty, class names will be checked with the set naming style. +class-rgx = "[A-Z_][a-zA-Z0-9]+$" + +# Naming style matching correct constant names. +const-naming-style = "UPPER_CASE" + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming style. +const-rgx = "(([A-Z_][A-Z0-9_]*)|(__.*__))$" + +# Minimum line length for functions/classes that require docstrings, shorter ones +# are exempt. +docstring-min-length = -1 + +# Naming style matching correct function names. +function-naming-style = "snake_case" + +# Regular expression matching correct function names. Overrides function-naming- +# style. If left empty, function names will be checked with the set naming style. +function-rgx = "[a-z_][a-z0-9_]{2,30}$" + +# Good variable names which should always be accepted, separated by a comma. +good-names = ["i", "j", "k", "ex", "Run", "_"] + +# Good variable names regexes, separated by a comma. If names match any regex, +# they will always be accepted +# good-names-rgxs = + +# Include a hint for the correct naming format with invalid-name. +# include-naming-hint = + +# Naming style matching correct inline iteration names. +inlinevar-naming-style = "any" + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +inlinevar-rgx = "[A-Za-z_][A-Za-z0-9_]*$" + +# Naming style matching correct method names. +method-naming-style = "snake_case" + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +method-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Naming style matching correct module names. +module-naming-style = "snake_case" + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +module-rgx = "(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$" + +# Colon-delimited sets of names that determine each other's naming style when the +# name regexes allow several styles. +# name-group = + +# Regular expression which should only match function or class names that do not +# require a docstring. +no-docstring-rgx = "__.*__" + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. These +# decorators are taken in consideration only for invalid-name. +property-classes = ["abc.abstractproperty"] + +# Regular expression matching correct type alias names. If left empty, type alias +# names will be checked with the set naming style. +# typealias-rgx = + +# Regular expression matching correct type variable names. 
If left empty, type +# variable names will be checked with the set naming style. +# typevar-rgx = + +# Naming style matching correct variable names. +variable-naming-style = "snake_case" + +# Regular expression matching correct variable names. Overrides variable-naming- +# style. If left empty, variable names will be checked with the set naming style. +variable-rgx = "[a-z_][a-z0-9_]{2,30}$" + +[tool.pylint.broad_try_clause] +# Maximum number of statements allowed in a try clause +max-try-statements = 7 + +[tool.pylint.classes] +# Warn about protected attribute access inside special methods +# check-protected-access-in-special-methods = + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods = ["__init__", "__new__", "setUp", "__post_init__"] + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make"] + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg = ["cls"] + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg = ["mcs"] + +[tool.pylint.deprecated_builtins] +# List of builtins function names that should not be used, separated by a comma +bad-functions = ["map", "input"] + +[tool.pylint.design] +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +# exclude-too-few-public-methods = + +# List of qualified class names to ignore when counting class parents (see R0901) +# ignored-parents = + +# Maximum number of arguments for function / method. +max-args = 9 + +# Maximum number of attributes for a class (see R0902). +max-attributes = 11 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr = 5 + +# Maximum number of branch for function / method body. +max-branches = 20 + +# Maximum number of locals for function / method body. +max-locals = 19 + +# Maximum number of parents for a class (see R0901). +max-parents = 7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods = 20 + +# Maximum number of return / yield for function / method body. +max-returns = 11 + +# Maximum number of statements in function / method body. +max-statements = 50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods = 2 + +[tool.pylint.exceptions] +# Exceptions that will emit a warning when caught. +overgeneral-exceptions = ["builtins.Exception"] + +[tool.pylint.format] +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +# expected-line-ending-format = + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines = "^\\s*(# )??$" + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren = 4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string = " " + +# Maximum number of characters on a single line. +max-line-length = 100 + +# Maximum number of lines in a module. +max-module-lines = 2000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +# single-line-class-stmt = + +# Allow the body of an if to be on the same line as the test if there is no else. +# single-line-if-stmt = + +[tool.pylint.imports] +# List of modules that can be imported at any level, not just the top level one. 
+# allow-any-import-level = + +# Allow explicit reexports by alias from a package __init__. +# allow-reexport-from-package = + +# Allow wildcard imports from modules that define __all__. +# allow-wildcard-with-all = + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec"] + +# Output a graph (.gv or any supported image format) of external dependencies to +# the given file (report RP0402 must not be disabled). +# ext-import-graph = + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be disabled). +# import-graph = + +# Output a graph (.gv or any supported image format) of internal dependencies to +# the given file (report RP0402 must not be disabled). +# int-import-graph = + +# Force import order to recognize a module as part of the standard compatibility +# libraries. +# known-standard-library = + +# Force import order to recognize a module as part of a third party library. +known-third-party = ["enchant"] + +# Couples of modules and preferred modules, separated by a comma. +# preferred-modules = + +[tool.pylint.logging] +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style = "old" + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules = ["logging"] + +[tool.pylint."messages control"] +# Only show warnings with the listed confidence levels. Leave empty to show all. +# Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED. +confidence = ["HIGH", "CONTROL_FLOW", "INFERENCE", "INFERENCE_FAILURE", "UNDEFINED"] + +# Disable the message, report, category or checker with the given id(s). You can +# either give multiple identifiers separated by comma (,) or put this option +# multiple times (only on the command line, not in the configuration file where +# it should appear only once). You can also use "--disable=all" to disable +# everything first and then re-enable specific checks. For example, if you want +# to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable = ["raw-checker-failed", "bad-inline-option", "locally-disabled", "file-ignored", "suppressed-message", "deprecated-pragma", "use-implicit-booleaness-not-comparison-to-string", "use-implicit-booleaness-not-comparison-to-zero", "consider-using-augmented-assign", "prefer-typing-namedtuple", "attribute-defined-outside-init", "invalid-name", "missing-module-docstring", "missing-class-docstring", "missing-function-docstring", "protected-access", "too-few-public-methods", "line-too-long", "too-many-lines", "trailing-whitespace", "missing-final-newline", "trailing-newlines", "bad-indentation", "unnecessary-semicolon", "multiple-statements", "superfluous-parens", "mixed-line-endings", "unexpected-line-ending-format", "fixme", "consider-using-assignment-expr", "logging-fstring-interpolation", "consider-using-any-or-all"] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where it +# should appear only once). 
See also the "--disable" option for examples. +enable = ["useless-suppression", "use-symbolic-message-instead"] + +[tool.pylint.method_args] +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods = ["requests.api.delete", "requests.api.get", "requests.api.head", "requests.api.options", "requests.api.patch", "requests.api.post", "requests.api.put", "requests.api.request"] + +[tool.pylint.miscellaneous] +# List of note tags to take in consideration, separated by a comma. +notes = ["FIXME", "XXX", "TODO"] + +# Regular expression of note tags to take in consideration. +# notes-rgx = + +[tool.pylint.parameter_documentation] +# Whether to accept totally missing parameter documentation in the docstring of a +# function that has parameters. +accept-no-param-doc = true + +# Whether to accept totally missing raises documentation in the docstring of a +# function that raises an exception. +accept-no-raise-doc = true + +# Whether to accept totally missing return documentation in the docstring of a +# function that returns a statement. +accept-no-return-doc = true + +# Whether to accept totally missing yields documentation in the docstring of a +# generator. +accept-no-yields-doc = true + +# If the docstring type cannot be guessed the specified docstring type will be +# used. +default-docstring-type = "default" + +[tool.pylint.refactoring] +# Maximum number of nested blocks for function / method body +max-nested-blocks = 5 + +# Complete name of functions that never returns. When checking for inconsistent- +# return-statements if a never returning function is called then it will be +# considered as an explicit return statement and no message will be printed. +never-returning-functions = ["sys.exit", "argparse.parse_error"] + +[tool.pylint.reports] +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each category, +# as well as 'statement' which is the total number of statements analyzed. This +# score is used by the global evaluation report (RP0004). +evaluation = "max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))" + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +# msg-template = + +# Set the output format. Available formats are: text, parseable, colorized, json2 +# (improved json format), json (old json format) and msvs (visual studio). You +# can also give a reporter class, e.g. mypackage.mymodule.MyReporterClass. +# output-format = + +# Tells whether to display a full report or only the messages. +# reports = + +# Activate the evaluation score. +score = true + +[tool.pylint.similarities] +# Comments are removed from the similarity computation +ignore-comments = true + +# Docstrings are removed from the similarity computation +ignore-docstrings = true + +# Imports are removed from the similarity computation +ignore-imports = true + +# Signatures are removed from the similarity computation +ignore-signatures = true + +# Minimum lines number of a similarity. +min-similarity-lines = 6 + +[tool.pylint.spelling] +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions = 2 + +# Spelling dictionary name. 
No available dictionaries : You need to install both +# the python package and the system dependency for enchant to work. +# spelling-dict = + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:,pragma:,# noinspection" + +# List of comma separated words that should not be checked. +# spelling-ignore-words = + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file = ".pyenchant_pylint_custom_dict.txt" + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +# spelling-store-unknown-words = + +[tool.pylint.typecheck] +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators = ["contextlib.contextmanager"] + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members = "REQUEST,acl_users,aq_parent,argparse.Namespace" + +# Tells whether missing members accessed in mixin class should be ignored. A +# class is considered mixin if its name matches the mixin-class-rgx option. +# Tells whether to warn about missing members when the owner of the attribute is +# inferred to be None. +ignore-none = true + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference can +# return multiple potential results while evaluating a Python object, but some +# branches might not be evaluated, which results in partial inference. In that +# case, it might be useful to still emit no-member and other checks for the rest +# of the inferred objects. +ignore-on-opaque-inference = true + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins = ["no-member", "not-async-context-manager", "not-context-manager", "attribute-defined-outside-init"] + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes = ["SQLObject", "optparse.Values", "thread._local", "_thread._local"] + +# Show a hint with possible names when a member name was not found. The aspect of +# finding the hint is based on edit distance. +missing-member-hint = true + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance = 1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices = 1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx = ".*MixIn" + +# List of decorators that change the signature of a decorated function. +# signature-mutators = + +[tool.pylint.variables] +# List of additional names supposed to be defined in builtins. Remember that you +# should avoid defining new builtins when possible. +# additional-builtins = + +# Tells whether unused global variables should be treated as a violation. 
+allow-global-unused-variables = true + +# List of names allowed to shadow builtins +# allowed-redefined-builtins = + +# List of strings which can identify a callback function by name. A callback name +# must start or end with one of those strings. +callbacks = ["cb_", "_cb"] + +# A regular expression matching the name of dummy variables (i.e. expected to not +# be used). +dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" + +# Argument names that match this expression will be ignored. +ignored-argument-names = "_.*|^ignored_|^unused_" + +# Tells whether we should check for unused import in __init__ files. +# init-import = + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules = ["six.moves", "past.builtins", "future.builtins", "builtins", "io"] diff --git a/src/databricks/labs/lsql/__about__.py b/src/databricks/labs/lsql/__about__.py index 97688a1d..6c8e6b97 100644 --- a/src/databricks/labs/lsql/__about__.py +++ b/src/databricks/labs/lsql/__about__.py @@ -1,4 +1 @@ -# SPDX-FileCopyrightText: 2023-present Serge Smertin -# -# SPDX-License-Identifier: MIT -__version__ = "0.0.1" +__version__ = "0.0.0" diff --git a/src/databricks/labs/lsql/py.typed b/src/databricks/labs/lsql/py.typed new file mode 100644 index 00000000..731adc94 --- /dev/null +++ b/src/databricks/labs/lsql/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The databricks-labs-lsql package uses inline types. diff --git a/tests/__init__.py b/tests/__init__.py index 255c9c76..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present Serge Smertin -# -# SPDX-License-Identifier: MIT diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_integration.py b/tests/integration/test_integration.py similarity index 100% rename from tests/test_integration.py rename to tests/integration/test_integration.py diff --git a/tests/test_lib.py b/tests/unit/test_lib.py similarity index 100% rename from tests/test_lib.py rename to tests/unit/test_lib.py From b8e985e935d899a64dffd8b85e8ccb98ca32bebc Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Tue, 13 Feb 2024 10:05:44 -0800 Subject: [PATCH 02/10] .. 
--- src/databricks/labs/lsql/__init__.py | 1 - src/databricks/labs/lsql/__main__.py | 1 - src/databricks/labs/lsql/deployment.py | 40 +++ src/databricks/labs/lsql/lib.py | 439 +++++++++++------------- src/databricks/labs/lsql/sql_backend.py | 195 +++++++---- tests/integration/test_integration.py | 23 +- tests/unit/test_lib.py | 279 ++++++++------- 7 files changed, 541 insertions(+), 437 deletions(-) create mode 100644 src/databricks/labs/lsql/deployment.py diff --git a/src/databricks/labs/lsql/__init__.py b/src/databricks/labs/lsql/__init__.py index 013f83dc..e69de29b 100644 --- a/src/databricks/labs/lsql/__init__.py +++ b/src/databricks/labs/lsql/__init__.py @@ -1 +0,0 @@ -from .lib import Row \ No newline at end of file diff --git a/src/databricks/labs/lsql/__main__.py b/src/databricks/labs/lsql/__main__.py index dbcf98df..e69de29b 100644 --- a/src/databricks/labs/lsql/__main__.py +++ b/src/databricks/labs/lsql/__main__.py @@ -1 +0,0 @@ -from databricks.sdk import Config diff --git a/src/databricks/labs/lsql/deployment.py b/src/databricks/labs/lsql/deployment.py new file mode 100644 index 00000000..56846347 --- /dev/null +++ b/src/databricks/labs/lsql/deployment.py @@ -0,0 +1,40 @@ +import logging +import pkgutil +from typing import Any + +from databricks.labs.lsql.sql_backend import Dataclass, SqlBackend + +logger = logging.getLogger(__name__) + + +class SchemaDeployer: + def __init__(self, sql_backend: SqlBackend, inventory_schema: str, mod: Any): + self._sql_backend = sql_backend + self._inventory_schema = inventory_schema + self._module = mod + + def deploy_schema(self): + logger.info(f"Ensuring {self._inventory_schema} database exists") + self._sql_backend.execute(f"CREATE SCHEMA IF NOT EXISTS hive_metastore.{self._inventory_schema}") + + def delete_schema(self): + logger.info(f"deleting {self._inventory_schema} database") + + self._sql_backend.execute(f"DROP SCHEMA IF EXISTS hive_metastore.{self._inventory_schema} CASCADE") + + def deploy_table(self, name: str, klass: Dataclass): + logger.info(f"Ensuring {self._inventory_schema}.{name} table exists") + self._sql_backend.create_table(f"hive_metastore.{self._inventory_schema}.{name}", klass) + + def deploy_view(self, name: str, relative_filename: str): + query = self._load(relative_filename) + logger.info(f"Ensuring {self._inventory_schema}.{name} view matches {relative_filename} contents") + ddl = f"CREATE OR REPLACE VIEW hive_metastore.{self._inventory_schema}.{name} AS {query}" + self._sql_backend.execute(ddl) + + def _load(self, relative_filename: str) -> str: + data = pkgutil.get_data(self._module.__name__, relative_filename) + assert data is not None + sql = data.decode("utf-8") + sql = sql.replace("$inventory", f"hive_metastore.{self._inventory_schema}") + return sql diff --git a/src/databricks/labs/lsql/lib.py b/src/databricks/labs/lsql/lib.py index a6eaee92..27503627 100644 --- a/src/databricks/labs/lsql/lib.py +++ b/src/databricks/labs/lsql/lib.py @@ -1,42 +1,56 @@ -import base64 -import datetime import functools import json import logging import random import time +from collections.abc import Iterator from datetime import timedelta -from typing import Any, Dict, Iterator, List, Optional +from typing import Any -from databricks.sdk.core import DatabricksError -from databricks.sdk.service.sql import (ColumnInfoTypeName, Disposition, - ExecuteStatementResponse, Format, - StatementExecutionAPI, StatementState, - StatementStatus) +from databricks.sdk import WorkspaceClient, errors +from databricks.sdk.errors import DataLoss, 
NotFound +from databricks.sdk.service.sql import ( + ColumnInfoTypeName, + Disposition, + ExecuteStatementResponse, + Format, + ResultData, + ServiceError, + ServiceErrorCode, + StatementState, + StatementStatus, +) -_LOG = logging.getLogger(__name__) +MAX_SLEEP_PER_ATTEMPT = 10 +MAX_PLATFORM_TIMEOUT = 50 -class Row(tuple): +MIN_PLATFORM_TIMEOUT = 5 - def __new__(cls, columns: List[str], values: List[Any]) -> 'Row': - row = tuple.__new__(cls, values) - row.__columns__ = columns - return row +_LOG = logging.getLogger("databricks.sdk") - # Python SDK convention - def as_dict(self) -> Dict[str, any]: - return dict(zip(self.__columns__, self)) - # PySpark convention - asDict = as_dict +class _RowCreator(tuple): + def __new__(cls, fields): + instance = super().__new__(cls, fields) + return instance + + def __repr__(self): + field_values = ", ".join(f"{field}={getattr(self, field)}" for field in self) + return f"{self.__class__.__name__}({field_values})" + + +class Row(tuple): + # Python SDK convention + def as_dict(self) -> dict[str, Any]: + return dict(zip(self.__columns__, self, strict=True)) # PySpark convention def __contains__(self, item): return item in self.__columns__ def __getitem__(self, col): - if isinstance(col, (int, slice)): + if isinstance(col, int | slice): return super().__getitem__(col) # if columns are named `2 + 2`, for example return self.__getattr__(col) @@ -46,281 +60,214 @@ def __getattr__(self, col): idx = self.__columns__.index(col) return self[idx] except IndexError: - raise AttributeError(col) + raise AttributeError(col) from None except ValueError: - raise AttributeError(col) + raise AttributeError(col) from None def __repr__(self): - return f"Row({', '.join(f'{k}={v}' for (k, v) in zip(self.__columns__, self))})" - - -class StatementExecutionExt(StatementExecutionAPI): - """ - Execute SQL statements in a stateless manner. - - Primary use-case of :py:meth:`iterate_rows` and :py:meth:`execute` methods is oriented at executing SQL queries in - a stateless manner straight away from Databricks SDK for Python, without requiring any external dependencies. - Results are fetched in JSON format through presigned external links. This is perfect for serverless applications - like AWS Lambda, Azure Functions, or any other containerised short-lived applications, where container startup - time is faster with the smaller dependency set. - - .. code-block: - - for (pickup_zip, dropoff_zip) in w.statement_execution.iterate_rows(warehouse_id, - 'SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10', catalog='samples'): - print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') + return f"Row({', '.join(f'{k}={v}' for (k, v) in zip(self.__columns__, self, strict=True))})" - Method :py:meth:`iterate_rows` returns an iterator of objects, that resemble :class:`pyspark.sql.Row` APIs, but full - compatibility is not the goal of this implementation. - .. code-block:: - - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') - - When you only need to execute the query and have no need to iterate over results, use the :py:meth:`execute`. - - .. 
code-block:: - - w.statement_execution.execute(warehouse_id, 'CREATE TABLE foo AS SELECT * FROM range(10)') - - Applications, that need to a more traditional SQL Python APIs with cursors, efficient data transfer of hundreds of - megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use - the stateful Databricks SQL Connector for Python. - """ - - def __init__(self, api_client): - super().__init__(api_client) - self._type_converters = { +class StatementExecutionExt: + def __init__(self, ws: WorkspaceClient): + self._api = ws.api_client + self.execute_statement = functools.partial(ws.statement_execution.execute_statement) + self.cancel_execution = functools.partial(ws.statement_execution.cancel_execution) + self.get_statement = functools.partial(ws.statement_execution.get_statement) + self.type_converters = { ColumnInfoTypeName.ARRAY: json.loads, - ColumnInfoTypeName.BINARY: base64.b64decode, + # ColumnInfoTypeName.BINARY: not_supported(ColumnInfoTypeName.BINARY), ColumnInfoTypeName.BOOLEAN: bool, + # ColumnInfoTypeName.BYTE: not_supported(ColumnInfoTypeName.BYTE), ColumnInfoTypeName.CHAR: str, - ColumnInfoTypeName.DATE: self._parse_date, + # ColumnInfoTypeName.DATE: not_supported(ColumnInfoTypeName.DATE), ColumnInfoTypeName.DOUBLE: float, ColumnInfoTypeName.FLOAT: float, ColumnInfoTypeName.INT: int, + # ColumnInfoTypeName.INTERVAL: not_supported(ColumnInfoTypeName.INTERVAL), ColumnInfoTypeName.LONG: int, ColumnInfoTypeName.MAP: json.loads, ColumnInfoTypeName.NULL: lambda _: None, ColumnInfoTypeName.SHORT: int, ColumnInfoTypeName.STRING: str, ColumnInfoTypeName.STRUCT: json.loads, - ColumnInfoTypeName.TIMESTAMP: self._parse_timestamp, + # ColumnInfoTypeName.TIMESTAMP: not_supported(ColumnInfoTypeName.TIMESTAMP), + # ColumnInfoTypeName.USER_DEFINED_TYPE: not_supported(ColumnInfoTypeName.USER_DEFINED_TYPE), } - @staticmethod - def _parse_date(value: str) -> datetime.date: - year, month, day = value.split('-') - return datetime.date(int(year), int(month), int(day)) - - @staticmethod - def _parse_timestamp(value: str) -> datetime.datetime: - # make it work with Python 3.7 to 3.10 as well - return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) - @staticmethod def _raise_if_needed(status: StatementStatus): if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: return - err = status.error - if err is not None: - message = err.message.strip() - error_code = err.error_code.value - raise DatabricksError(message, error_code=error_code) - raise DatabricksError(status.state.value) - - def execute(self, - warehouse_id: str, - statement: str, - *, - byte_limit: Optional[int] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, - timeout: timedelta = timedelta(minutes=20), - ) -> ExecuteStatementResponse: - """(Experimental) Execute a SQL statement and block until results are ready, - including starting the warehouse if needed. - - This is a high-level implementation that works with fetching records in JSON format. - It can be considered as a quick way to run SQL queries by just depending on - Databricks SDK for Python without the need of any other compiled library dependencies. - - This method is a higher-level wrapper over :py:meth:`execute_statement` and fetches results - in JSON format through the external link disposition, with client-side polling until - the statement succeeds in execution. 
Whenever the statement is failed, cancelled, or - closed, this method raises `DatabricksError` with the state message and the relevant - error code. - - To seamlessly iterate over the rows from query results, please use :py:meth:`iterate_rows`. - - :param warehouse_id: str - Warehouse upon which to execute a statement. - :param statement: str - SQL statement to execute - :param byte_limit: int (optional) - Applies the given byte limit to the statement's result size. Byte counts are based on internal - representations and may not match measurable sizes in the JSON format. - :param catalog: str (optional) - Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. - :param schema: str (optional) - Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. - :param timeout: timedelta (optional) - Timeout after which the query is cancelled. If timeout is less than 50 seconds, - it is handled on the server side. If the timeout is greater than 50 seconds, - Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. - :return: ExecuteStatementResponse - """ + status_error = status.error + if status_error is None: + status_error = ServiceError(message="unknown", error_code=ServiceErrorCode.UNKNOWN) + error_message = status_error.message + if error_message is None: + error_message = "" + if "SCHEMA_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "TABLE_OR_VIEW_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "DELTA_TABLE_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "DELTA_MISSING_TRANSACTION_LOG" in error_message: + raise DataLoss(error_message) + mapping = { + ServiceErrorCode.ABORTED: errors.Aborted, + ServiceErrorCode.ALREADY_EXISTS: errors.AlreadyExists, + ServiceErrorCode.BAD_REQUEST: errors.BadRequest, + ServiceErrorCode.CANCELLED: errors.Cancelled, + ServiceErrorCode.DEADLINE_EXCEEDED: errors.DeadlineExceeded, + ServiceErrorCode.INTERNAL_ERROR: errors.InternalError, + ServiceErrorCode.IO_ERROR: errors.InternalError, + ServiceErrorCode.NOT_FOUND: errors.NotFound, + ServiceErrorCode.RESOURCE_EXHAUSTED: errors.ResourceExhausted, + ServiceErrorCode.SERVICE_UNDER_MAINTENANCE: errors.TemporarilyUnavailable, + ServiceErrorCode.TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, + ServiceErrorCode.UNAUTHENTICATED: errors.Unauthenticated, + ServiceErrorCode.UNKNOWN: errors.Unknown, + ServiceErrorCode.WORKSPACE_TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, + } + error_code = status_error.error_code + if error_code is None: + error_code = ServiceErrorCode.UNKNOWN + error_class = mapping.get(error_code, errors.Unknown) + raise error_class(error_message) + + def execute( + self, + warehouse_id: str, + statement: str, + *, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta = timedelta(minutes=20), + ) -> ExecuteStatementResponse: # The wait_timeout field must be 0 seconds (disables wait), # or between 5 seconds and 50 seconds. wait_timeout = None - if 5 <= timeout.total_seconds() <= 50: + if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: # set server-side timeout wait_timeout = f"{timeout.total_seconds()}s" _LOG.debug(f"Executing SQL statement: {statement}") + # technically, we can do Disposition.EXTERNAL_LINKS, but let's push it further away. # format is limited to Format.JSON_ARRAY, but other iterations may include ARROW_STREAM. 
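                # In this revision the statement is submitted once with INLINE disposition and
                # JSON_ARRAY format (replacing the external-links fetch): results that arrive
                # within the 5s..50s server-side wait window are returned immediately, otherwise
                # the loop further down polls get_statement() with a sleep capped at
                # MAX_SLEEP_PER_ATTEMPT seconds plus random jitter and cancels the statement
                # once the client-side deadline is exceeded.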
- immediate_response = self.execute_statement(warehouse_id=warehouse_id, - statement=statement, - catalog=catalog, - schema=schema, - disposition=Disposition.EXTERNAL_LINKS, - format=Format.JSON_ARRAY, - byte_limit=byte_limit, - wait_timeout=wait_timeout) - - if immediate_response.status.state == StatementState.SUCCEEDED: + immediate_response = self.execute_statement( + warehouse_id=warehouse_id, + statement=statement, + catalog=catalog, + schema=schema, + disposition=Disposition.INLINE, + format=Format.JSON_ARRAY, + byte_limit=byte_limit, + wait_timeout=wait_timeout, + ) + + status = immediate_response.status + if status is None: + status = StatementStatus(state=StatementState.FAILED) + if status.state == StatementState.SUCCEEDED: return immediate_response - self._raise_if_needed(immediate_response.status) + self._raise_if_needed(status) attempt = 1 status_message = "polling..." deadline = time.time() + timeout.total_seconds() + statement_id = immediate_response.statement_id + if not statement_id: + msg = f"No statement id: {immediate_response}" + raise ValueError(msg) while time.time() < deadline: - res = self.get_statement(immediate_response.statement_id) - if res.status.state == StatementState.SUCCEEDED: - return ExecuteStatementResponse(manifest=res.manifest, - result=res.result, - statement_id=res.statement_id, - status=res.status) - status_message = f"current status: {res.status.state.value}" - self._raise_if_needed(res.status) - sleep = attempt - if sleep > 10: - # sleep 10s max per attempt - sleep = 10 - _LOG.debug(f"SQL statement {res.statement_id}: {status_message} (sleeping ~{sleep}s)") + res = self.get_statement(statement_id) + result_status = res.status + if not result_status: + msg = f"Result status is none: {res}" + raise ValueError(msg) + state = result_status.state + if not state: + state = StatementState.FAILED + if state == StatementState.SUCCEEDED: + return ExecuteStatementResponse( + manifest=res.manifest, result=res.result, statement_id=statement_id, status=result_status + ) + status_message = f"current status: {state.value}" + self._raise_if_needed(result_status) + sleep = min(attempt, MAX_SLEEP_PER_ATTEMPT) + _LOG.debug(f"SQL statement {statement_id}: {status_message} (sleeping ~{sleep}s)") time.sleep(sleep + random.random()) attempt += 1 - self.cancel_execution(immediate_response.statement_id) + self.cancel_execution(statement_id) msg = f"timed out after {timeout}: {status_message}" raise TimeoutError(msg) - def iterate_rows(self, - warehouse_id: str, - statement: str, - *, - byte_limit: Optional[int] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, - timeout: timedelta = timedelta(minutes=20), - ) -> Iterator[Row]: - """(Experimental) Execute a query and iterate over all available records. - - This method is a wrapper over :py:meth:`execute` with the handling of chunked result - processing and deserialization of those into separate rows, which are yielded from - a returned iterator. Every row API resembles those of :class:`pyspark.sql.Row`, - but full compatibility is not the goal of this implementation. - - .. 
code-block:: - - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') - - :param warehouse_id: str - Warehouse upon which to execute a statement. - :param statement: str - SQL statement to execute - :param byte_limit: int (optional) - Applies the given byte limit to the statement's result size. Byte counts are based on internal - representations and may not match measurable sizes in the JSON format. - :param catalog: str (optional) - Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. - :param schema: str (optional) - Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. - :param timeout: timedelta (optional) - Timeout after which the query is cancelled. If timeout is less than 50 seconds, - it is handled on the server side. If the timeout is greater than 50 seconds, - Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. - :return: Iterator[Row] - """ - execute_response = self.execute(warehouse_id, - statement, - byte_limit=byte_limit, - catalog=catalog, - schema=schema, - timeout=timeout) - if execute_response.result.external_links is None: - return [] - for row in self._iterate_external_disposition(execute_response): - yield row - - def _result_schema(self, execute_response: ExecuteStatementResponse): - col_names = [] - col_conv = [] - for col in execute_response.manifest.schema.columns: - col_names.append(col.name) - conv = self._type_converters.get(col.type_name, None) - if conv is None: - msg = f"{col.name} has no {col.type_name.value} converter" - raise ValueError(msg) - col_conv.append(conv) - row_factory = functools.partial(Row, col_names) - return row_factory, col_conv - - def _iterate_external_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: - # ensure that we close the HTTP session after fetching the external links + def execute_fetch_all( + self, + warehouse_id: str, + statement: str, + *, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta = timedelta(minutes=20), + ) -> Iterator[Row]: + execute_response = self.execute( + warehouse_id, statement, byte_limit=byte_limit, catalog=catalog, schema=schema, timeout=timeout + ) + col_conv, row_factory = self._row_converters(execute_response) result_data = execute_response.result - row_factory, col_conv = self._result_schema(execute_response) - with self._api._new_session() as http: - while True: - for external_link in result_data.external_links: - response = http.get(external_link.external_link) - response.raise_for_status() - for data in response.json(): - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - - if external_link.next_chunk_index is None: - return - - result_data = self.get_statement_result_chunk_n(execute_response.statement_id, - external_link.next_chunk_index) - - def _iterate_inline_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: - result_data = execute_response.result - row_factory, col_conv = self._result_schema(execute_response) + if result_data is None: + return while True: - # case for Disposition.INLINE, where we get rows embedded into a response - for data 
in result_data.data_array: + data_array = result_data.data_array + if not data_array: + data_array = [] + for data in data_array: # enumerate() + iterator + tuple constructor makes it more performant # on larger humber of records for Python, even though it's less # readable code. - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - + row = [] + for i, value in enumerate(data): + if value is None: + row.append(None) + else: + row.append(col_conv[i](value)) + yield row_factory(row) if result_data.next_chunk_index is None: return + if not result_data.next_chunk_internal_link: + continue + # TODO: replace once ES-828324 is fixed + json_response = self._api.do("GET", result_data.next_chunk_internal_link) + result_data = ResultData.from_dict(json_response) # type: ignore[arg-type] - result_data = self.get_statement_result_chunk_n(execute_response.statement_id, - result_data.next_chunk_index) \ No newline at end of file + def _row_converters(self, execute_response): + col_names = [] + col_conv = [] + manifest = execute_response.manifest + if not manifest: + msg = f"missing manifest: {execute_response}" + raise ValueError(msg) + manifest_schema = manifest.schema + if not manifest_schema: + msg = f"missing schema: {manifest}" + raise ValueError(msg) + columns = manifest_schema.columns + if not columns: + columns = [] + for col in columns: + col_names.append(col.name) + type_name = col.type_name + if not type_name: + type_name = ColumnInfoTypeName.NULL + conv = self.type_converters.get(type_name, None) + if conv is None: + msg = f"{col.name} has no {type_name.value} converter" + raise ValueError(msg) + col_conv.append(conv) + row_factory = type("Row", (Row,), {"__columns__": col_names}) + return col_conv, row_factory diff --git a/src/databricks/labs/lsql/sql_backend.py b/src/databricks/labs/lsql/sql_backend.py index 1c3a9c27..2182ad35 100644 --- a/src/databricks/labs/lsql/sql_backend.py +++ b/src/databricks/labs/lsql/sql_backend.py @@ -1,94 +1,148 @@ import dataclasses import logging +import os import re from abc import ABC, abstractmethod -from typing import Iterator, ClassVar +from collections.abc import Callable, Iterable, Iterator, Sequence +from types import UnionType +from typing import Any, ClassVar, Protocol, TypeVar from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import ( + BadRequest, + DatabricksError, + DataLoss, + NotFound, + PermissionDenied, + Unknown, +) -from databricks.labs.lsql.lib import StatementExecutionExt +from databricks.labs.lsql.lib import Row, StatementExecutionExt logger = logging.getLogger(__name__) + +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[dict] + + +Result = TypeVar("Result", bound=DataclassInstance) +Dataclass = type[DataclassInstance] +ResultFn = Callable[[], Iterable[Result]] + + class SqlBackend(ABC): @abstractmethod - def execute(self, sql): + def execute(self, sql: str) -> None: raise NotImplementedError @abstractmethod - def fetch(self, sql) -> Iterator[any]: + def fetch(self, sql: str) -> Iterator[Any]: raise NotImplementedError @abstractmethod - def save_table(self, full_name: str, rows: list[any], klass: type, mode: str = "append"): + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): raise NotImplementedError - def create_table(self, full_name: str, klass: type): + def create_table(self, full_name: str, klass: Dataclass): ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA" 
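        # For a hypothetical dataclass with fields `name: str` and `count: int | None = None`,
        # _schema_for() renders "name STRING NOT NULL, count LONG" (int now maps to LONG, and
        # the NOT NULL suffix is dropped only when the field's default is None), so the DDL
        # becomes: CREATE TABLE IF NOT EXISTS <full_name> (name STRING NOT NULL, count LONG) USING DELTA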
self.execute(ddl) - _builtin_type_mapping: ClassVar[dict[type, str]] = {str: "STRING", int: "INT", bool: "BOOLEAN", float: "FLOAT"} + _builtin_type_mapping: ClassVar[dict[type, str]] = { + str: "STRING", + int: "LONG", + bool: "BOOLEAN", + float: "FLOAT", + } @classmethod - def _schema_for(cls, klass): + def _schema_for(cls, klass: Dataclass): fields = [] for f in dataclasses.fields(klass): - if f.type not in cls._builtin_type_mapping: - msg = f"Cannot auto-convert {f.type}" + field_type = f.type + if isinstance(field_type, UnionType): + field_type = field_type.__args__[0] + if field_type not in cls._builtin_type_mapping: + msg = f"Cannot auto-convert {field_type}" raise SyntaxError(msg) not_null = " NOT NULL" if f.default is None: not_null = "" - spark_type = cls._builtin_type_mapping[f.type] + spark_type = cls._builtin_type_mapping[field_type] fields.append(f"{f.name} {spark_type}{not_null}") return ", ".join(fields) @classmethod - def _filter_none_rows(cls, rows, full_name): + def _filter_none_rows(cls, rows, klass): if len(rows) == 0: return rows results = [] - nullable_fields = set() - - for field in dataclasses.fields(rows[0]): - if field.default is None: - nullable_fields.add(field.name) - + class_fields = dataclasses.fields(klass) for row in rows: if row is None: continue - row_contains_none = False - for column, value in dataclasses.asdict(row).items(): - if value is None and column not in nullable_fields: - logger.warning(f"[{full_name}] Field {column} is None, filtering row") - row_contains_none = True - break - - if not row_contains_none: - results.append(row) + for field in class_fields: + if not hasattr(row, field.name): + logger.debug(f"Field {field.name} not present in row {dataclasses.asdict(row)}") + continue + if field.default is not None and getattr(row, field.name) is None: + msg = f"Not null constraint violated for column {field.name}, row = {dataclasses.asdict(row)}" + raise ValueError(msg) + results.append(row) return results + _whitespace = re.compile(r"\s{2,}") + + @classmethod + def _only_n_bytes(cls, j: str, num_bytes: int = 96) -> str: + j = cls._whitespace.sub(" ", j) + diff = len(j.encode("utf-8")) - num_bytes + if diff > 0: + return f"{j[:num_bytes]}... 
({diff} more bytes)" + return j + + @staticmethod + def _api_error_from_message(error_message: str) -> DatabricksError: + if "SCHEMA_NOT_FOUND" in error_message: + return NotFound(error_message) + if "TABLE_OR_VIEW_NOT_FOUND" in error_message: + return NotFound(error_message) + if "DELTA_TABLE_NOT_FOUND" in error_message: + return NotFound(error_message) + if "DELTA_MISSING_TRANSACTION_LOG" in error_message: + return DataLoss(error_message) + if "UNRESOLVED_COLUMN.WITH_SUGGESTION" in error_message: + return BadRequest(error_message) + if "PARSE_SYNTAX_ERROR" in error_message: + return BadRequest(error_message) + if "Operation not allowed" in error_message: + return PermissionDenied(error_message) + return Unknown(error_message) + class StatementExecutionBackend(SqlBackend): def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000): - self._sql = StatementExecutionExt(ws.api_client) + self._sql = StatementExecutionExt(ws) self._warehouse_id = warehouse_id self._max_records_per_batch = max_records_per_batch + debug_truncate_bytes = ws.config.debug_truncate_bytes + # while unit-testing, this value will contain a mock + self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96 - def execute(self, sql): - logger.debug(f"[api][execute] {sql}") + def execute(self, sql: str) -> None: + logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") self._sql.execute(self._warehouse_id, sql) - def fetch(self, sql) -> Iterator[any]: - logger.debug(f"[api][fetch] {sql}") + def fetch(self, sql: str) -> Iterator[Row]: + logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") return self._sql.execute_fetch_all(self._warehouse_id, sql) - def save_table(self, full_name: str, rows: list[any], klass: dataclasses.dataclass, mode="append"): + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"): if mode == "overwrite": msg = "Overwrite mode is not yet supported" raise NotImplementedError(msg) - rows = self._filter_none_rows(rows, full_name) + rows = self._filter_none_rows(rows, klass) self.create_table(full_name, klass) if len(rows) == 0: return @@ -105,88 +159,103 @@ def _row_to_sql(row, fields): data = [] for f in fields: value = getattr(row, f.name) + field_type = f.type + if isinstance(field_type, UnionType): + field_type = field_type.__args__[0] if value is None: data.append("NULL") - elif f.type == bool: + elif field_type == bool: data.append("TRUE" if value else "FALSE") - elif f.type == str: + elif field_type == str: value = str(value).replace("'", "''") data.append(f"'{value}'") - elif f.type == int: + elif field_type == int: data.append(f"{value}") else: - msg = f"unknown type: {f.type}" + msg = f"unknown type: {field_type}" raise ValueError(msg) return ", ".join(data) class RuntimeBackend(SqlBackend): - def __init__(self): - from pyspark.sql.session import SparkSession + def __init__(self, debug_truncate_bytes: int | None = None): + # pylint: disable-next=import-error,import-outside-toplevel + from pyspark.sql.session import SparkSession # type: ignore[import-not-found] if "DATABRICKS_RUNTIME_VERSION" not in os.environ: msg = "Not in the Databricks Runtime" raise RuntimeError(msg) self._spark = SparkSession.builder.getOrCreate() - - def execute(self, sql): - logger.debug(f"[spark][execute] {sql}") - self._spark.sql(sql) - - def fetch(self, sql) -> Iterator[any]: - logger.debug(f"[spark][fetch] {sql}") - return 
self._spark.sql(sql).collect() - - def save_table(self, full_name: str, rows: list[any], klass: dataclasses.dataclass, mode: str = "append"): - rows = self._filter_none_rows(rows, full_name) + self._debug_truncate_bytes = debug_truncate_bytes if debug_truncate_bytes else 96 + + def execute(self, sql: str) -> None: + logger.debug(f"[spark][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + try: + self._spark.sql(sql) + except Exception as e: + error_message = str(e) + raise self._api_error_from_message(error_message) from None + + def fetch(self, sql: str) -> Iterator[Row]: + logger.debug(f"[spark][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + try: + return self._spark.sql(sql).collect() + except Exception as e: + error_message = str(e) + raise self._api_error_from_message(error_message) from None + + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + rows = self._filter_none_rows(rows, klass) if len(rows) == 0: self.create_table(full_name, klass) return # pyspark deals well with lists of dataclass instances, as long as schema is provided - df = self._spark.createDataFrame(rows, self._schema_for(rows[0])) + df = self._spark.createDataFrame(rows, self._schema_for(klass)) df.write.saveAsTable(full_name, mode=mode) + class MockBackend(SqlBackend): - def __init__(self, *, fails_on_first: dict | None = None, rows: dict | None = None): + def __init__(self, *, fails_on_first: dict | None = None, rows: dict | None = None, debug_truncate_bytes=96): self._fails_on_first = fails_on_first if not rows: rows = {} self._rows = rows - self._save_table = [] - self.queries = [] + self._save_table: list[tuple[str, Sequence[DataclassInstance], str]] = [] + self._debug_truncate_bytes = debug_truncate_bytes + self.queries: list[str] = [] - def _sql(self, sql): - logger.debug(f"Mock backend.sql() received SQL: {sql}") + def _sql(self, sql: str): + logger.debug(f"Mock backend.sql() received SQL: {self._only_n_bytes(sql, self._debug_truncate_bytes)}") seen_before = sql in self.queries self.queries.append(sql) if not seen_before and self._fails_on_first is not None: for match, failure in self._fails_on_first.items(): if match in sql: - raise RuntimeError(failure) + raise self._api_error_from_message(failure) from None def execute(self, sql): self._sql(sql) - def fetch(self, sql) -> Iterator[any]: + def fetch(self, sql) -> Iterator[Row]: self._sql(sql) rows = [] if self._rows: for pattern in self._rows.keys(): r = re.compile(pattern) - if r.match(sql): + if r.search(sql): logger.debug(f"Found match: {sql}") rows.extend(self._rows[pattern]) logger.debug(f"Returning rows: {rows}") return iter(rows) - def save_table(self, full_name: str, rows: list[any], klass, mode: str = "append"): + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass, mode: str = "append"): if klass.__class__ == type: self._save_table.append((full_name, rows, mode)) - def rows_written_for(self, full_name: str, mode: str) -> list[any]: - rows = [] + def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance]: + rows: list[DataclassInstance] = [] for stub_full_name, stub_rows, stub_mode in self._save_table: if not (stub_full_name == full_name and stub_mode == mode): continue diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index e9fff9e0..9e45b7b3 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -3,29 +3,30 @@ def 
test_sql_execution_chunked(w): all_warehouses = w.warehouses.list() - assert len(w.warehouses.list()) > 0, 'at least one SQL warehouse required' + assert len(w.warehouses.list()) > 0, "at least one SQL warehouse required" warehouse_id = all_warehouses[0].id total = 0 fetch = functools.partial(w.statement_execution.iterate_rows, warehouse_id) - for (x, ) in fetch('SELECT id FROM range(2000000)'): + for (x,) in fetch("SELECT id FROM range(2000000)"): total += x print(total) def test_sql_execution(w, env_or_skip): - warehouse_id = env_or_skip('TEST_DEFAULT_WAREHOUSE_ID') - for (pickup_zip, dropoff_zip) in w.statement_execution.iterate_rows( - warehouse_id, 'SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10', catalog='samples'): - print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') + warehouse_id = env_or_skip("TEST_DEFAULT_WAREHOUSE_ID") + for pickup_zip, dropoff_zip in w.statement_execution.iterate_rows( + warehouse_id, "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", catalog="samples" + ): + print(f"pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}") def test_sql_execution_partial(w, env_or_skip): - warehouse_id = env_or_skip('TEST_DEFAULT_WAREHOUSE_ID') - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): + warehouse_id = env_or_skip("TEST_DEFAULT_WAREHOUSE_ID") + iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog="samples") + for row in iterate_rows("SELECT * FROM nyctaxi.trips LIMIT 10"): pickup_time, dropoff_time = row[0], row[1] pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] + dropoff_zip = row["dropoff_zip"] all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') \ No newline at end of file + print(f"{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}") diff --git a/tests/unit/test_lib.py b/tests/unit/test_lib.py index 4ea8ff8a..18494dee 100644 --- a/tests/unit/test_lib.py +++ b/tests/unit/test_lib.py @@ -1,92 +1,116 @@ import datetime import pytest - from databricks.sdk import WorkspaceClient from databricks.sdk.core import DatabricksError -from databricks.labs.lsql import Row -from databricks.sdk.service.sql import (ColumnInfo, ColumnInfoTypeName, - Disposition, ExecuteStatementResponse, - ExternalLink, Format, - GetStatementResponse, ResultData, - ResultManifest, ResultSchema, - ServiceError, ServiceErrorCode, - StatementState, StatementStatus, - timedelta) +from databricks.sdk.service.sql import ( + ColumnInfo, + ColumnInfoTypeName, + Disposition, + ExecuteStatementResponse, + ExternalLink, + Format, + GetStatementResponse, + ResultData, + ResultManifest, + ResultSchema, + ServiceError, + ServiceErrorCode, + StatementState, + StatementStatus, + timedelta, +) + +from databricks.labs.lsql.lib import Row def test_execute_poll_succeeds(config, mocker): w = WorkspaceClient(config=config) - execute_statement = mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.PENDING), - statement_id='bcd', - )) - - get_statement = mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.get_statement', - return_value=GetStatementResponse( - manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id='bcd', - status=StatementStatus(state=StatementState.SUCCEEDED))) - - response = 
w.statement_execution.execute('abc', 'SELECT 2+2') + execute_statement = mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", + ), + ) + + get_statement = mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.get_statement", + return_value=GetStatementResponse( + manifest=ResultManifest(), + result=ResultData(byte_count=100500), + statement_id="bcd", + status=StatementStatus(state=StatementState.SUCCEEDED), + ), + ) + + response = w.statement_execution.execute("abc", "SELECT 2+2") assert response.status.state == StatementState.SUCCEEDED assert response.result.byte_count == 100500 - execute_statement.assert_called_with(warehouse_id='abc', - statement='SELECT 2+2', - disposition=Disposition.EXTERNAL_LINKS, - format=Format.JSON_ARRAY, - byte_limit=None, - catalog=None, - schema=None, - wait_timeout=None) - get_statement.assert_called_with('bcd') + execute_statement.assert_called_with( + warehouse_id="abc", + statement="SELECT 2+2", + disposition=Disposition.EXTERNAL_LINKS, + format=Format.JSON_ARRAY, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + get_statement.assert_called_with("bcd") def test_execute_fails(config, mocker): w = WorkspaceClient(config=config) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - statement_id='bcd', - status=StatementStatus(state=StatementState.FAILED, - error=ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, - message='oops...')))) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + statement_id="bcd", + status=StatementStatus( + state=StatementState.FAILED, + error=ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="oops..."), + ), + ), + ) with pytest.raises(DatabricksError) as err: - w.statement_execution.execute('abc', 'SELECT 2+2') - assert err.value.error_code == 'RESOURCE_EXHAUSTED' - assert str(err.value) == 'oops...' + w.statement_execution.execute("abc", "SELECT 2+2") + assert err.value.error_code == "RESOURCE_EXHAUSTED" + assert str(err.value) == "oops..." 
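With this change a failed statement no longer surfaces as a generic `DatabricksError`: the error message is scanned for well-known markers (e.g. `SCHEMA_NOT_FOUND`) and the service error code is translated through a mapping table into typed SDK exceptions. A minimal sketch of how that mapping could be exercised directly — the test name and messages are illustrative and not part of this patch, and `_raise_if_needed` is a private helper:

    import pytest
    from databricks.sdk import errors
    from databricks.sdk.service.sql import (
        ServiceError,
        ServiceErrorCode,
        StatementState,
        StatementStatus,
    )

    from databricks.labs.lsql.lib import StatementExecutionExt


    def test_failed_status_maps_to_typed_errors():
        # a SCHEMA_NOT_FOUND marker in the message wins over the error code
        status = StatementStatus(
            state=StatementState.FAILED,
            error=ServiceError(error_code=ServiceErrorCode.BAD_REQUEST, message="... SCHEMA_NOT_FOUND ..."),
        )
        with pytest.raises(errors.NotFound):
            StatementExecutionExt._raise_if_needed(status)

        # otherwise the error code drives the mapping table
        status = StatementStatus(
            state=StatementState.FAILED,
            error=ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="oops..."),
        )
        with pytest.raises(errors.ResourceExhausted):
            StatementExecutionExt._raise_if_needed(status)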
def test_execute_poll_waits(config, mocker): w = WorkspaceClient(config=config) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse(status=StatementStatus(state=StatementState.PENDING), - statement_id='bcd', - )) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", + ), + ) runs = [] def _get_statement(statement_id): - assert statement_id == 'bcd' + assert statement_id == "bcd" if len(runs) == 0: runs.append(1) - return GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), - statement_id='bcd') + return GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), statement_id="bcd") - return GetStatementResponse(manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id='bcd', - status=StatementStatus(state=StatementState.SUCCEEDED)) + return GetStatementResponse( + manifest=ResultManifest(), + result=ResultData(byte_count=100500), + statement_id="bcd", + status=StatementStatus(state=StatementState.SUCCEEDED), + ) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.get_statement', wraps=_get_statement) + mocker.patch("databricks.sdk.service.sql.StatementExecutionAPI.get_statement", wraps=_get_statement) - response = w.statement_execution.execute('abc', 'SELECT 2+2') + response = w.statement_execution.execute("abc", "SELECT 2+2") assert response.status.state == StatementState.SUCCEEDED assert response.result.byte_count == 100500 @@ -95,125 +119,150 @@ def _get_statement(statement_id): def test_execute_poll_timeouts_on_client(config, mocker): w = WorkspaceClient(config=config) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse(status=StatementStatus(state=StatementState.PENDING), - statement_id='bcd', - )) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", + ), + ) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.get_statement', - return_value=GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), - statement_id='bcd')) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.get_statement", + return_value=GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), statement_id="bcd"), + ) - cancel_execution = mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.cancel_execution') + cancel_execution = mocker.patch("databricks.sdk.service.sql.StatementExecutionAPI.cancel_execution") with pytest.raises(TimeoutError): - w.statement_execution.execute('abc', 'SELECT 2+2', timeout=timedelta(seconds=1)) + w.statement_execution.execute("abc", "SELECT 2+2", timeout=timedelta(seconds=1)) - cancel_execution.assert_called_with('bcd') + cancel_execution.assert_called_with("bcd") def test_fetch_all_no_chunks(config, mocker): w = WorkspaceClient(config=config) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest(schema=ResultSchema(columns=[ - ColumnInfo(name='id', type_name=ColumnInfoTypeName.INT), - ColumnInfo(name='since', type_name=ColumnInfoTypeName.DATE), - 
ColumnInfo(name='now', type_name=ColumnInfoTypeName.TIMESTAMP), - ])), - result=ResultData(external_links=[ExternalLink(external_link='https://singed-url')]), - statement_id='bcd', - )) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema( + columns=[ + ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT), + ColumnInfo(name="since", type_name=ColumnInfoTypeName.DATE), + ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP), + ] + ) + ), + result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]), + statement_id="bcd", + ), + ) raw_response = mocker.Mock() - raw_response.json = lambda: [["1", "2023-09-01", "2023-09-01T13:21:53Z"], - ["2", "2023-09-01", "2023-09-01T13:21:53Z"]] + raw_response.json = lambda: [ + ["1", "2023-09-01", "2023-09-01T13:21:53Z"], + ["2", "2023-09-01", "2023-09-01T13:21:53Z"], + ] - http_get = mocker.patch('requests.sessions.Session.get', return_value=raw_response) + http_get = mocker.patch("requests.sessions.Session.get", return_value=raw_response) rows = list( - w.statement_execution.iterate_rows( - 'abc', 'SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)')) + w.statement_execution.iterate_rows("abc", "SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)") + ) assert len(rows) == 2 assert rows[0].id == 1 assert isinstance(rows[0].since, datetime.date) assert isinstance(rows[0].now, datetime.datetime) - http_get.assert_called_with('https://singed-url') + http_get.assert_called_with("https://singed-url") def test_fetch_all_no_chunks_no_converter(config, mocker): w = WorkspaceClient(config=config) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest(schema=ResultSchema( - columns=[ColumnInfo(name='id', type_name=ColumnInfoTypeName.INTERVAL), ])), - result=ResultData(external_links=[ExternalLink(external_link='https://singed-url')]), - statement_id='bcd', - )) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema( + columns=[ + ColumnInfo(name="id", type_name=ColumnInfoTypeName.INTERVAL), + ] + ) + ), + result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]), + statement_id="bcd", + ), + ) raw_response = mocker.Mock() raw_response.json = lambda: [["1"], ["2"]] - mocker.patch('requests.sessions.Session.get', return_value=raw_response) + mocker.patch("requests.sessions.Session.get", return_value=raw_response) with pytest.raises(ValueError): - list(w.statement_execution.iterate_rows('abc', 'SELECT id FROM range(2)')) + list(w.statement_execution.iterate_rows("abc", "SELECT id FROM range(2)")) def test_fetch_all_two_chunks(config, mocker): w = WorkspaceClient(config=config) - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest(schema=ResultSchema(columns=[ - ColumnInfo(name='id', type_name=ColumnInfoTypeName.INT), - ColumnInfo(name='now', type_name=ColumnInfoTypeName.TIMESTAMP), - ])), - result=ResultData( - 
external_links=[ExternalLink(external_link='https://first', next_chunk_index=1)]), - statement_id='bcd', - )) + mocker.patch( + "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", + return_value=ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema( + columns=[ + ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT), + ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP), + ] + ) + ), + result=ResultData(external_links=[ExternalLink(external_link="https://first", next_chunk_index=1)]), + statement_id="bcd", + ), + ) next_chunk = mocker.patch( - 'databricks.sdk.service.sql.StatementExecutionAPI.get_statement_result_chunk_n', - return_value=ResultData(external_links=[ExternalLink(external_link='https://second')])) + "databricks.sdk.service.sql.StatementExecutionAPI.get_statement_result_chunk_n", + return_value=ResultData(external_links=[ExternalLink(external_link="https://second")]), + ) raw_response = mocker.Mock() raw_response.json = lambda: [["1", "2023-09-01T13:21:53Z"], ["2", "2023-09-01T13:21:53Z"]] - http_get = mocker.patch('requests.sessions.Session.get', return_value=raw_response) + http_get = mocker.patch("requests.sessions.Session.get", return_value=raw_response) - rows = list(w.statement_execution.iterate_rows('abc', 'SELECT id, NOW() AS now FROM range(2)')) + rows = list(w.statement_execution.iterate_rows("abc", "SELECT id, NOW() AS now FROM range(2)")) assert len(rows) == 4 - assert http_get.call_args_list == [mocker.call('https://first'), mocker.call('https://second')] - next_chunk.assert_called_with('bcd', 1) + assert http_get.call_args_list == [mocker.call("https://first"), mocker.call("https://second")] + next_chunk.assert_called_with("bcd", 1) def test_row(): - row = Row(['a', 'b', 'c'], [1, 2, 3]) + row = Row(["a", "b", "c"], [1, 2, 3]) a, b, c = row assert a == 1 assert b == 2 assert c == 3 - assert 'a' in row + assert "a" in row assert row.a == 1 - assert row['a'] == 1 + assert row["a"] == 1 as_dict = row.as_dict() assert as_dict == row.asDict() - assert as_dict['b'] == 2 + assert as_dict["b"] == 2 - assert 'Row(a=1, b=2, c=3)' == str(row) + assert "Row(a=1, b=2, c=3)" == str(row) with pytest.raises(AttributeError): - print(row.x) \ No newline at end of file + print(row.x) From a7c26fe3823130344d601e446a059634ffe4ff4f Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Sun, 10 Mar 2024 17:08:11 +0100 Subject: [PATCH 03/10] .. 
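The next patch makes the warehouse id optional on `StatementExecutionExt` (resolved from `ws.config.warehouse_id` or a running warehouse), re-introduces external-links fetching, and makes the object itself callable. A rough usage sketch of the intended API, assuming a `WorkspaceClient` configured through unified auth and a workspace with at least one SQL warehouse (the queries are illustrative):

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.sql import Disposition

    from databricks.labs.lsql.lib import StatementExecutionExt

    ws = WorkspaceClient()  # credentials come from the environment
    see = StatementExecutionExt(ws)

    # no explicit warehouse id: resolved from ws.config.warehouse_id or the first running warehouse
    see.execute("CREATE TABLE foo AS SELECT * FROM range(10)")

    # calling the object iterates over all rows; EXTERNAL_LINKS suits larger result sets
    for pickup_zip, dropoff_zip in see(
        "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10",
        catalog="samples",
        disposition=Disposition.EXTERNAL_LINKS,
    ):
        print(f"pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}")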
--- src/databricks/labs/lsql/lib.py | 295 ++++++++++++++++++++------ tests/integration/conftest.py | 78 +++++++ tests/integration/test_integration.py | 11 +- 3 files changed, 312 insertions(+), 72 deletions(-) create mode 100644 tests/integration/conftest.py diff --git a/src/databricks/labs/lsql/lib.py b/src/databricks/labs/lsql/lib.py index 27503627..f9265e43 100644 --- a/src/databricks/labs/lsql/lib.py +++ b/src/databricks/labs/lsql/lib.py @@ -1,12 +1,16 @@ +import base64 +import datetime import functools import json import logging import random +import threading import time from collections.abc import Iterator from datetime import timedelta from typing import Any +import requests from databricks.sdk import WorkspaceClient, errors from databricks.sdk.errors import DataLoss, NotFound from databricks.sdk.service.sql import ( @@ -18,7 +22,7 @@ ServiceError, ServiceErrorCode, StatementState, - StatementStatus, + StatementStatus, State, ) MAX_SLEEP_PER_ATTEMPT = 10 @@ -27,7 +31,7 @@ MIN_PLATFORM_TIMEOUT = 5 -_LOG = logging.getLogger("databricks.sdk") +logger = logging.getLogger(__name__) class _RowCreator(tuple): @@ -69,32 +73,78 @@ def __repr__(self): class StatementExecutionExt: + """ + Execute SQL statements in a stateless manner. + + Primary use-case of :py:meth:`iterate_rows` and :py:meth:`execute` methods is oriented at executing SQL queries in + a stateless manner straight away from Databricks SDK for Python, without requiring any external dependencies. + Results are fetched in JSON format through presigned external links. This is perfect for serverless applications + like AWS Lambda, Azure Functions, or any other containerised short-lived applications, where container startup + time is faster with the smaller dependency set. + + .. code-block: + + for (pickup_zip, dropoff_zip) in w.statement_execution.iterate_rows(warehouse_id, + 'SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10', catalog='samples'): + print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') + + Method :py:meth:`iterate_rows` returns an iterator of objects, that resemble :class:`pyspark.sql.Row` APIs, but full + compatibility is not the goal of this implementation. + + .. code-block:: + + iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') + for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): + pickup_time, dropoff_time = row[0], row[1] + pickup_zip = row.pickup_zip + dropoff_zip = row['dropoff_zip'] + all_fields = row.as_dict() + print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') + + When you only need to execute the query and have no need to iterate over results, use the :py:meth:`execute`. + + .. code-block:: + + w.statement_execution.execute(warehouse_id, 'CREATE TABLE foo AS SELECT * FROM range(10)') + + Applications, that need to a more traditional SQL Python APIs with cursors, efficient data transfer of hundreds of + megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use + the stateful Databricks SQL Connector for Python. 
+ """ + def __init__(self, ws: WorkspaceClient): + self._ws = ws self._api = ws.api_client - self.execute_statement = functools.partial(ws.statement_execution.execute_statement) - self.cancel_execution = functools.partial(ws.statement_execution.cancel_execution) - self.get_statement = functools.partial(ws.statement_execution.get_statement) - self.type_converters = { + self._http = requests.Session() + self._lock = threading.Lock() + self._type_converters = { ColumnInfoTypeName.ARRAY: json.loads, - # ColumnInfoTypeName.BINARY: not_supported(ColumnInfoTypeName.BINARY), + ColumnInfoTypeName.BINARY: base64.b64decode, ColumnInfoTypeName.BOOLEAN: bool, - # ColumnInfoTypeName.BYTE: not_supported(ColumnInfoTypeName.BYTE), ColumnInfoTypeName.CHAR: str, - # ColumnInfoTypeName.DATE: not_supported(ColumnInfoTypeName.DATE), + ColumnInfoTypeName.DATE: self._parse_date, ColumnInfoTypeName.DOUBLE: float, ColumnInfoTypeName.FLOAT: float, ColumnInfoTypeName.INT: int, - # ColumnInfoTypeName.INTERVAL: not_supported(ColumnInfoTypeName.INTERVAL), ColumnInfoTypeName.LONG: int, ColumnInfoTypeName.MAP: json.loads, ColumnInfoTypeName.NULL: lambda _: None, ColumnInfoTypeName.SHORT: int, ColumnInfoTypeName.STRING: str, ColumnInfoTypeName.STRUCT: json.loads, - # ColumnInfoTypeName.TIMESTAMP: not_supported(ColumnInfoTypeName.TIMESTAMP), - # ColumnInfoTypeName.USER_DEFINED_TYPE: not_supported(ColumnInfoTypeName.USER_DEFINED_TYPE), + ColumnInfoTypeName.TIMESTAMP: self._parse_timestamp, } + @staticmethod + def _parse_date(value: str) -> datetime.date: + year, month, day = value.split('-') + return datetime.date(int(year), int(month), int(day)) + + @staticmethod + def _parse_timestamp(value: str) -> datetime.datetime: + # make it work with Python 3.7 to 3.10 as well + return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) + @staticmethod def _raise_if_needed(status: StatementStatus): if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: @@ -135,33 +185,86 @@ def _raise_if_needed(status: StatementStatus): error_class = mapping.get(error_code, errors.Unknown) raise error_class(error_message) + def _default_warehouse(self) -> str: + with self._lock: + if self._ws.config.warehouse_id: + return self._ws.config.warehouse_id + ids = [] + for v in self._ws.warehouses.list(): + if v.state in [State.DELETED, State.DELETING]: + continue + elif v.state == State.RUNNING: + self._ws.config.warehouse_id = v.id + return self._ws.config.warehouse_id + ids.append(v.id) + if self._ws.config.warehouse_id == "" and len(ids) > 0: + # otherwise - first warehouse + self._ws.config.warehouse_id = ids[0] + return self._ws.config.warehouse_id + raise ValueError("no warehouse id given") + def execute( self, - warehouse_id: str, statement: str, *, + warehouse_id: str | None = None, byte_limit: int | None = None, catalog: str | None = None, schema: str | None = None, timeout: timedelta = timedelta(minutes=20), + disposition: Disposition | None = None, ) -> ExecuteStatementResponse: + """(Experimental) Execute a SQL statement and block until results are ready, + including starting the warehouse if needed. + + This is a high-level implementation that works with fetching records in JSON format. + It can be considered as a quick way to run SQL queries by just depending on + Databricks SDK for Python without the need of any other compiled library dependencies. 
+ + This method is a higher-level wrapper over :py:meth:`execute_statement` and fetches results + in JSON format through the external link disposition, with client-side polling until + the statement succeeds in execution. Whenever the statement is failed, cancelled, or + closed, this method raises `DatabricksError` with the state message and the relevant + error code. + + To seamlessly iterate over the rows from query results, please use :py:meth:`iterate_rows`. + + :param warehouse_id: str + Warehouse upon which to execute a statement. + :param statement: str + SQL statement to execute + :param byte_limit: int (optional) + Applies the given byte limit to the statement's result size. Byte counts are based on internal + representations and may not match measurable sizes in the JSON format. + :param catalog: str (optional) + Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. + :param schema: str (optional) + Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. + :param timeout: timedelta (optional) + Timeout after which the query is cancelled. If timeout is less than 50 seconds, + it is handled on the server side. If the timeout is greater than 50 seconds, + Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. + :return: ExecuteStatementResponse + """ # The wait_timeout field must be 0 seconds (disables wait), # or between 5 seconds and 50 seconds. wait_timeout = None if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: # set server-side timeout wait_timeout = f"{timeout.total_seconds()}s" + if not warehouse_id: + warehouse_id = self._default_warehouse() - _LOG.debug(f"Executing SQL statement: {statement}") + logger.debug(f"Executing SQL statement: {statement}") # technically, we can do Disposition.EXTERNAL_LINKS, but let's push it further away. # format is limited to Format.JSON_ARRAY, but other iterations may include ARROW_STREAM. 
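        # Warehouse resolution order for the call below: the explicit warehouse_id argument,
        # then ws.config.warehouse_id, then the first RUNNING warehouse, then the first
        # warehouse that is not being deleted; with no candidates, _default_warehouse()
        # raises a ValueError. Timeouts within the 5s..50s window are delegated to the
        # platform via wait_timeout; longer ones are enforced by the client-side polling loop.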
- immediate_response = self.execute_statement( + immediate_response = self._ws.statement_execution.execute_statement( warehouse_id=warehouse_id, statement=statement, catalog=catalog, schema=schema, - disposition=Disposition.INLINE, + disposition=disposition, format=Format.JSON_ARRAY, byte_limit=byte_limit, wait_timeout=wait_timeout, @@ -183,7 +286,7 @@ def execute( msg = f"No statement id: {immediate_response}" raise ValueError(msg) while time.time() < deadline: - res = self.get_statement(statement_id) + res = self._ws.statement_execution.get_statement(statement_id) result_status = res.status if not result_status: msg = f"Result status is none: {res}" @@ -198,76 +301,136 @@ def execute( status_message = f"current status: {state.value}" self._raise_if_needed(result_status) sleep = min(attempt, MAX_SLEEP_PER_ATTEMPT) - _LOG.debug(f"SQL statement {statement_id}: {status_message} (sleeping ~{sleep}s)") + logger.debug(f"SQL statement {statement_id}: {status_message} (sleeping ~{sleep}s)") time.sleep(sleep + random.random()) attempt += 1 - self.cancel_execution(statement_id) + self._ws.statement_execution.cancel_execution(statement_id) msg = f"timed out after {timeout}: {status_message}" raise TimeoutError(msg) + def __call__(self, *args, **kwargs): + yield from self.execute_fetch_all(*args, **kwargs) + def execute_fetch_all( self, - warehouse_id: str, statement: str, *, + warehouse_id: str | None = None, byte_limit: int | None = None, catalog: str | None = None, schema: str | None = None, timeout: timedelta = timedelta(minutes=20), + disposition: Disposition | None = None, ) -> Iterator[Row]: - execute_response = self.execute( - warehouse_id, statement, byte_limit=byte_limit, catalog=catalog, schema=schema, timeout=timeout - ) - col_conv, row_factory = self._row_converters(execute_response) + """(Experimental) Execute a query and iterate over all available records. + + This method is a wrapper over :py:meth:`execute` with the handling of chunked result + processing and deserialization of those into separate rows, which are yielded from + a returned iterator. Every row API resembles those of :class:`pyspark.sql.Row`, + but full compatibility is not the goal of this implementation. + + .. code-block:: + + iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') + for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): + pickup_time, dropoff_time = row[0], row[1] + pickup_zip = row.pickup_zip + dropoff_zip = row['dropoff_zip'] + all_fields = row.as_dict() + print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') + + :param warehouse_id: str + Warehouse upon which to execute a statement. + :param statement: str + SQL statement to execute + :param byte_limit: int (optional) + Applies the given byte limit to the statement's result size. Byte counts are based on internal + representations and may not match measurable sizes in the JSON format. + :param catalog: str (optional) + Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. + :param schema: str (optional) + Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. + :param timeout: timedelta (optional) + Timeout after which the query is cancelled. If timeout is less than 50 seconds, + it is handled on the server side. If the timeout is greater than 50 seconds, + Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. 
+ :return: Iterator[Row] + """ + execute_response = self.execute(statement, + warehouse_id=warehouse_id, + byte_limit=byte_limit, + catalog=catalog, + schema=schema, + timeout=timeout, + disposition=disposition) + if execute_response.result.external_links is None: + return [] + # for row in self._iterate_external_disposition(execute_response): result_data = execute_response.result - if result_data is None: - return + row_factory, col_conv = self._result_schema(execute_response) while True: - data_array = result_data.data_array - if not data_array: - data_array = [] - for data in data_array: - # enumerate() + iterator + tuple constructor makes it more performant - # on larger humber of records for Python, even though it's less - # readable code. - row = [] - for i, value in enumerate(data): - if value is None: - row.append(None) - else: - row.append(col_conv[i](value)) - yield row_factory(row) - if result_data.next_chunk_index is None: + next_chunk_index = result_data.next_chunk_index + if result_data.data_array: + for data in result_data.data_array: + # enumerate() + iterator + tuple constructor makes it more performant + # on larger humber of records for Python, even though it's less + # readable code. + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + for external_link in result_data.external_links: + response = self._http.get(external_link.external_link) + response.raise_for_status() + for data in response.json(): + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + if next_chunk_index is None: return - if not result_data.next_chunk_internal_link: - continue - # TODO: replace once ES-828324 is fixed - json_response = self._api.do("GET", result_data.next_chunk_internal_link) - result_data = ResultData.from_dict(json_response) # type: ignore[arg-type] + result_data = self._ws.statement_execution.get_statement_result_chunk_n( + execute_response.statement_id, + next_chunk_index) - def _row_converters(self, execute_response): + def _result_schema(self, execute_response: ExecuteStatementResponse): col_names = [] col_conv = [] - manifest = execute_response.manifest - if not manifest: - msg = f"missing manifest: {execute_response}" - raise ValueError(msg) - manifest_schema = manifest.schema - if not manifest_schema: - msg = f"missing schema: {manifest}" - raise ValueError(msg) - columns = manifest_schema.columns - if not columns: - columns = [] - for col in columns: + for col in execute_response.manifest.schema.columns: col_names.append(col.name) - type_name = col.type_name - if not type_name: - type_name = ColumnInfoTypeName.NULL - conv = self.type_converters.get(type_name, None) + conv = self._type_converters.get(col.type_name, None) if conv is None: - msg = f"{col.name} has no {type_name.value} converter" + msg = f"{col.name} has no {col.type_name.value} converter" raise ValueError(msg) col_conv.append(conv) - row_factory = type("Row", (Row,), {"__columns__": col_names}) - return col_conv, row_factory + row_factory = functools.partial(Row, col_names) + return row_factory, col_conv + + def _iterate_external_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: + # ensure that we close the HTTP session after fetching the external links + result_data = execute_response.result + row_factory, col_conv = self._result_schema(execute_response) + with self._api._new_session() as http: + while True: + for external_link in result_data.external_links: + response = http.get(external_link.external_link) + response.raise_for_status() + 
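                # each external link is a presigned URL whose body is a JSON array of rows;
                # individual values are decoded with the column's registered type converter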
for data in response.json(): + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + + if external_link.next_chunk_index is None: + return + + result_data = self._ws.statement_execution.get_statement_result_chunk_n(execute_response.statement_id, + external_link.next_chunk_index) + + def _iterate_inline_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: + result_data = execute_response.result + row_factory, col_conv = self._result_schema(execute_response) + while True: + # case for Disposition.INLINE, where we get rows embedded into a response + for data in result_data.data_array: + # enumerate() + iterator + tuple constructor makes it more performant + # on larger humber of records for Python, even though it's less + # readable code. + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + + if result_data.next_chunk_index is None: + return + + result_data = self._ws.statement_execution.get_statement_result_chunk_n(execute_response.statement_id, + result_data.next_chunk_index) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..b8b61fe5 --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,78 @@ +import json +import logging +import os +import pathlib +import string +import sys +from typing import MutableMapping + +from databricks.sdk import WorkspaceClient +from pytest import fixture + +from databricks.labs.lsql.__about__ import __version__ +from databricks.labs.blueprint.logger import install_logger + +install_logger() +logging.getLogger("databricks").setLevel("DEBUG") + + +def _is_in_debug() -> bool: + return os.path.basename(sys.argv[0]) in {"_jb_pytest_runner.py", "testlauncher.py"} + + +@fixture # type: ignore[no-redef] +def debug_env_name(): + return "ucws" + + +@fixture +def debug_env(monkeypatch, debug_env_name) -> MutableMapping[str, str]: + if not _is_in_debug(): + return os.environ + conf_file = pathlib.Path.home() / ".databricks/debug-env.json" + if not conf_file.exists(): + return os.environ + with conf_file.open("r") as f: + conf = json.load(f) + if debug_env_name not in conf: + sys.stderr.write( + f"""{debug_env_name} not found in ~/.databricks/debug-env.json + + this usually means that you have to add the following fixture to + conftest.py file in the relevant directory: + + @fixture + def debug_env_name(): + return 'ENV_NAME' # where ENV_NAME is one of: {", ".join(conf.keys())} + """ + ) + msg = f"{debug_env_name} not found in ~/.databricks/debug-env.json" + raise KeyError(msg) + for k, v in conf[debug_env_name].items(): + monkeypatch.setenv(k, v) + return os.environ + + +@fixture +def make_random(): + import random + + def inner(k=16) -> str: + charset = string.ascii_uppercase + string.ascii_lowercase + string.digits + return "".join(random.choices(charset, k=int(k))) + + return inner + + +@fixture +def product_info(): + return "lsql", __version__ + + +@fixture +def ws(product_info, debug_env) -> WorkspaceClient: + # Use variables from Unified Auth + # See https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html + product_name, product_version = product_info + return WorkspaceClient(host=debug_env["DATABRICKS_HOST"], product=product_name, product_version=product_version) + diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 9e45b7b3..900f6d3f 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1,14 +1,13 @@ import functools +from 
databricks.labs.lsql.lib import StatementExecutionExt +from databricks.sdk.service.sql import Disposition -def test_sql_execution_chunked(w): - all_warehouses = w.warehouses.list() - assert len(w.warehouses.list()) > 0, "at least one SQL warehouse required" - warehouse_id = all_warehouses[0].id +def test_sql_execution_chunked(ws): + see = StatementExecutionExt(ws) total = 0 - fetch = functools.partial(w.statement_execution.iterate_rows, warehouse_id) - for (x,) in fetch("SELECT id FROM range(2000000)"): + for (x,) in see("SELECT id FROM range(2000000)", disposition=Disposition.EXTERNAL_LINKS): total += x print(total) From f6d47d2b7397e2d0e7e6daa811f40c6fbc1ca4dc Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Sun, 10 Mar 2024 18:40:46 +0100 Subject: [PATCH 04/10] .. --- src/databricks/labs/lsql/__init__.py | 3 + src/databricks/labs/lsql/lib.py | 104 ++++++++++++-------------- tests/integration/test_integration.py | 11 ++- tests/unit/test_lib.py | 57 ++++++++------ 4 files changed, 92 insertions(+), 83 deletions(-) diff --git a/src/databricks/labs/lsql/__init__.py b/src/databricks/labs/lsql/__init__.py index e69de29b..6dfe02b6 100644 --- a/src/databricks/labs/lsql/__init__.py +++ b/src/databricks/labs/lsql/__init__.py @@ -0,0 +1,3 @@ +from .lib import Row + +__all__ = ["Row"] diff --git a/src/databricks/labs/lsql/lib.py b/src/databricks/labs/lsql/lib.py index f9265e43..0c70bb31 100644 --- a/src/databricks/labs/lsql/lib.py +++ b/src/databricks/labs/lsql/lib.py @@ -6,6 +6,7 @@ import random import threading import time +import types from collections.abc import Iterator from datetime import timedelta from typing import Any @@ -34,17 +35,29 @@ logger = logging.getLogger(__name__) -class _RowCreator(tuple): - def __new__(cls, fields): - instance = super().__new__(cls, fields) - return instance - - def __repr__(self): - field_values = ", ".join(f"{field}={getattr(self, field)}" for field in self) - return f"{self.__class__.__name__}({field_values})" - - class Row(tuple): + def __new__(cls, *args, **kwargs): + if args and kwargs: + raise ValueError("cannot mix positional and keyword arguments") + if kwargs: + # PySpark's compatibility layer + row = tuple.__new__(cls, list(kwargs.values())) + row.__columns__ = list(kwargs.keys()) + return row + if len(args) == 1 and hasattr(cls, '__columns__') and isinstance(args[0], (types.GeneratorType, list, tuple)): + # this type returned by Row.factory() and we already know the column names + return cls(*args[0]) + if len(args) == 2 and isinstance(args[0], (list, tuple)) and isinstance(args[1], (list, tuple)): + # UCX's compatibility layer + row = tuple.__new__(cls, args[1]) + row.__columns__ = args[0] + return row + return tuple.__new__(cls, args) + + @classmethod + def factory(cls, col_names: list[str]) -> type: + return type("Row", (Row,), {"__columns__": col_names}) + # Python SDK convention def as_dict(self) -> dict[str, Any]: return dict(zip(self.__columns__, self, strict=True)) @@ -60,6 +73,8 @@ def __getitem__(self, col): return self.__getattr__(col) def __getattr__(self, col): + if col.startswith("__"): + raise AttributeError(col) try: idx = self.__columns__.index(col) return self[idx] @@ -69,7 +84,7 @@ def __getattr__(self, col): raise AttributeError(col) from None def __repr__(self): - return f"Row({', '.join(f'{k}={v}' for (k, v) in zip(self.__columns__, self, strict=True))})" + return f"Row({', '.join(f'{k}={v!r}' for (k, v) in zip(self.__columns__, self, strict=True))})" class StatementExecutionExt: @@ -112,11 +127,12 @@ class 
StatementExecutionExt: the stateful Databricks SQL Connector for Python. """ - def __init__(self, ws: WorkspaceClient): + def __init__(self, ws: WorkspaceClient, disposition: Disposition | None = None): self._ws = ws self._api = ws.api_client self._http = requests.Session() self._lock = threading.Lock() + self._disposition = disposition self._type_converters = { ColumnInfoTypeName.ARRAY: json.loads, ColumnInfoTypeName.BINARY: base64.b64decode, @@ -320,7 +336,6 @@ def execute_fetch_all( catalog: str | None = None, schema: str | None = None, timeout: timedelta = timedelta(minutes=20), - disposition: Disposition | None = None, ) -> Iterator[Row]: """(Experimental) Execute a query and iterate over all available records. @@ -362,75 +377,50 @@ def execute_fetch_all( catalog=catalog, schema=schema, timeout=timeout, - disposition=disposition) - if execute_response.result.external_links is None: - return [] - # for row in self._iterate_external_disposition(execute_response): + disposition=self._disposition) result_data = execute_response.result + if result_data is None: + return [] row_factory, col_conv = self._result_schema(execute_response) while True: - next_chunk_index = result_data.next_chunk_index if result_data.data_array: for data in result_data.data_array: # enumerate() + iterator + tuple constructor makes it more performant # on larger humber of records for Python, even though it's less # readable code. yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + next_chunk_index = result_data.next_chunk_index for external_link in result_data.external_links: + next_chunk_index = external_link.next_chunk_index response = self._http.get(external_link.external_link) response.raise_for_status() for data in response.json(): yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - if next_chunk_index is None: + if not next_chunk_index: return result_data = self._ws.statement_execution.get_statement_result_chunk_n( execute_response.statement_id, next_chunk_index) def _result_schema(self, execute_response: ExecuteStatementResponse): + manifest = execute_response.manifest + if not manifest: + msg = f"missing manifest: {execute_response}" + raise ValueError(msg) + manifest_schema = manifest.schema + if not manifest_schema: + msg = f"missing schema: {manifest}" + raise ValueError(msg) col_names = [] col_conv = [] - for col in execute_response.manifest.schema.columns: + columns = manifest_schema.columns + if not columns: + columns = [] + for col in columns: col_names.append(col.name) conv = self._type_converters.get(col.type_name, None) if conv is None: msg = f"{col.name} has no {col.type_name.value} converter" raise ValueError(msg) col_conv.append(conv) - row_factory = functools.partial(Row, col_names) - return row_factory, col_conv - - def _iterate_external_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: - # ensure that we close the HTTP session after fetching the external links - result_data = execute_response.result - row_factory, col_conv = self._result_schema(execute_response) - with self._api._new_session() as http: - while True: - for external_link in result_data.external_links: - response = http.get(external_link.external_link) - response.raise_for_status() - for data in response.json(): - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - - if external_link.next_chunk_index is None: - return - - result_data = self._ws.statement_execution.get_statement_result_chunk_n(execute_response.statement_id, - 
external_link.next_chunk_index) - - def _iterate_inline_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: - result_data = execute_response.result - row_factory, col_conv = self._result_schema(execute_response) - while True: - # case for Disposition.INLINE, where we get rows embedded into a response - for data in result_data.data_array: - # enumerate() + iterator + tuple constructor makes it more performant - # on larger humber of records for Python, even though it's less - # readable code. - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - - if result_data.next_chunk_index is None: - return - - result_data = self._ws.statement_execution.get_statement_result_chunk_n(execute_response.statement_id, - result_data.next_chunk_index) + return Row.factory(col_names), col_conv diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 900f6d3f..22b4cdd0 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1,15 +1,18 @@ import functools +import pytest + from databricks.labs.lsql.lib import StatementExecutionExt from databricks.sdk.service.sql import Disposition -def test_sql_execution_chunked(ws): - see = StatementExecutionExt(ws) +@pytest.mark.parametrize("disposition", [None, Disposition.INLINE, Disposition.EXTERNAL_LINKS]) +def test_sql_execution_chunked(ws, disposition): + see = StatementExecutionExt(ws, disposition=disposition) total = 0 - for (x,) in see("SELECT id FROM range(2000000)", disposition=Disposition.EXTERNAL_LINKS): + for (x,) in see("SELECT id FROM range(2000000)"): total += x - print(total) + assert total == 1999999000000 def test_sql_execution(w, env_or_skip): diff --git a/tests/unit/test_lib.py b/tests/unit/test_lib.py index 18494dee..e26036a1 100644 --- a/tests/unit/test_lib.py +++ b/tests/unit/test_lib.py @@ -24,6 +24,41 @@ from databricks.labs.lsql.lib import Row +@pytest.mark.parametrize('row', [ + Row(foo="bar", enabled=True), + Row(['foo', 'enabled'], ['bar', True]), +]) +def test_row_from_kwargs(row): + assert row.foo == "bar" + assert row["foo"] == "bar" + assert "foo" in row + assert len(row) == 2 + assert list(row) == ["bar", True] + assert row.as_dict() == {"foo": "bar", "enabled": True} + foo, enabled = row + assert foo == "bar" + assert enabled is True + assert str(row) == "Row(foo='bar', enabled=True)" + with pytest.raises(AttributeError): + print(row.x) + + +def test_row_factory(): + factory = Row.factory(["a", "b"]) + row = factory(1, 2) + a, b = row + assert a == 1 + assert b == 2 + + +def test_row_factory_with_generator(): + factory = Row.factory(["a", "b"]) + row = factory(_+1 for _ in range(2)) + a, b = row + assert a == 1 + assert b == 2 + + def test_execute_poll_succeeds(config, mocker): w = WorkspaceClient(config=config) @@ -244,25 +279,3 @@ def test_fetch_all_two_chunks(config, mocker): assert http_get.call_args_list == [mocker.call("https://first"), mocker.call("https://second")] next_chunk.assert_called_with("bcd", 1) - - -def test_row(): - row = Row(["a", "b", "c"], [1, 2, 3]) - - a, b, c = row - assert a == 1 - assert b == 2 - assert c == 3 - - assert "a" in row - assert row.a == 1 - assert row["a"] == 1 - - as_dict = row.as_dict() - assert as_dict == row.asDict() - assert as_dict["b"] == 2 - - assert "Row(a=1, b=2, c=3)" == str(row) - - with pytest.raises(AttributeError): - print(row.x) From bf4c1f1789b75456875e9746e5d8d15744015c0e Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Sun, 10 Mar 2024 19:38:59 +0100 
Subject: [PATCH 05/10] .. --- .github/workflows/acceptance.yml | 27 +++-- src/databricks/labs/lsql/__init__.py | 3 +- src/databricks/labs/lsql/{lib.py => core.py} | 113 +++++++++++++++---- src/databricks/labs/lsql/sql_backend.py | 4 +- tests/integration/conftest.py | 16 ++- tests/integration/test_integration.py | 76 +++++++++++-- tests/unit/test_lib.py | 2 +- 7 files changed, 187 insertions(+), 54 deletions(-) rename src/databricks/labs/lsql/{lib.py => core.py} (82%) diff --git a/.github/workflows/acceptance.yml b/.github/workflows/acceptance.yml index 55a38d83..495f86d6 100644 --- a/.github/workflows/acceptance.yml +++ b/.github/workflows/acceptance.yml @@ -2,18 +2,21 @@ name: acceptance on: pull_request: - types: [opened, synchronize] + types: [ opened, synchronize, ready_for_review ] permissions: id-token: write contents: read pull-requests: write +concurrency: + group: single-acceptance-job-per-repo + jobs: integration: - if: github.event_name == 'pull_request' + if: github.event_name == 'pull_request' && github.event.pull_request.draft == false environment: runtime - runs-on: larger # needs to be explicitly enabled per repo + runs-on: larger steps: - name: Checkout Code uses: actions/checkout@v2.5.0 @@ -31,16 +34,12 @@ jobs: - name: Install hatch run: pip install hatch==1.7.0 - - uses: azure/login@v1 - with: - client-id: ${{ secrets.ARM_CLIENT_ID }} - tenant-id: ${{ secrets.ARM_TENANT_ID }} - allow-no-subscriptions: true - - name: Run integration tests - run: hatch run integration + uses: databrickslabs/sandbox/acceptance@acceptance/v0.1.4 + with: + vault_uri: ${{ secrets.VAULT_URI }} + timeout: 45m env: - CLOUD_ENV: azure - DATABRICKS_HOST: "${{ vars.DATABRICKS_HOST }}" - DATABRICKS_CLUSTER_ID: "${{ vars.DATABRICKS_CLUSTER_ID }}" - DATABRICKS_WAREHOUSE_ID: "${{ vars.DATABRICKS_WAREHOUSE_ID }}" + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + ARM_CLIENT_ID: ${{ secrets.ARM_CLIENT_ID }} + ARM_TENANT_ID: ${{ secrets.ARM_TENANT_ID }} diff --git a/src/databricks/labs/lsql/__init__.py b/src/databricks/labs/lsql/__init__.py index 6dfe02b6..a8cb9877 100644 --- a/src/databricks/labs/lsql/__init__.py +++ b/src/databricks/labs/lsql/__init__.py @@ -1,3 +1,4 @@ -from .lib import Row +from .core import Row __all__ = ["Row"] + diff --git a/src/databricks/labs/lsql/lib.py b/src/databricks/labs/lsql/core.py similarity index 82% rename from src/databricks/labs/lsql/lib.py rename to src/databricks/labs/lsql/core.py index 0c70bb31..eed8699e 100644 --- a/src/databricks/labs/lsql/lib.py +++ b/src/databricks/labs/lsql/core.py @@ -12,6 +12,7 @@ from typing import Any import requests +import sqlglot from databricks.sdk import WorkspaceClient, errors from databricks.sdk.errors import DataLoss, NotFound from databricks.sdk.service.sql import ( @@ -127,11 +128,22 @@ class StatementExecutionExt: the stateful Databricks SQL Connector for Python. 
""" - def __init__(self, ws: WorkspaceClient, disposition: Disposition | None = None): + def __init__(self, ws: WorkspaceClient, + disposition: Disposition | None = None, + warehouse_id: str | None = None, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta = timedelta(minutes=20)): self._ws = ws self._api = ws.api_client self._http = requests.Session() self._lock = threading.Lock() + self._warehouse_id = warehouse_id + self._schema = schema + self._timeout = timeout + self._catalog = catalog + self._byte_limit = byte_limit self._disposition = disposition self._type_converters = { ColumnInfoTypeName.ARRAY: json.loads, @@ -153,16 +165,19 @@ def __init__(self, ws: WorkspaceClient, disposition: Disposition | None = None): @staticmethod def _parse_date(value: str) -> datetime.date: + """Parse date from string in ISO format.""" year, month, day = value.split('-') return datetime.date(int(year), int(month), int(day)) @staticmethod def _parse_timestamp(value: str) -> datetime.datetime: + """Parse timestamp from string in ISO format.""" # make it work with Python 3.7 to 3.10 as well return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) @staticmethod def _raise_if_needed(status: StatementStatus): + """Raise an exception if the statement status is failed, canceled, or closed.""" if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: return status_error = status.error @@ -202,8 +217,14 @@ def _raise_if_needed(status: StatementStatus): raise error_class(error_message) def _default_warehouse(self) -> str: + """Get the default warehouse id from the workspace client configuration + or DATABRICKS_WAREHOUSE_ID environment variable. If not set, it will use + the first available warehouse that is not in the DELETED or DELETING state.""" with self._lock: + if self._warehouse_id: + return self._warehouse_id if self._ws.config.warehouse_id: + self._warehouse_id = self._ws.config.warehouse_id return self._ws.config.warehouse_id ids = [] for v in self._ws.warehouses.list(): @@ -227,8 +248,7 @@ def execute( byte_limit: int | None = None, catalog: str | None = None, schema: str | None = None, - timeout: timedelta = timedelta(minutes=20), - disposition: Disposition | None = None, + timeout: timedelta | None = None, ) -> ExecuteStatementResponse: """(Experimental) Execute a SQL statement and block until results are ready, including starting the warehouse if needed. @@ -264,26 +284,20 @@ def execute( """ # The wait_timeout field must be 0 seconds (disables wait), # or between 5 seconds and 50 seconds. - wait_timeout = None - if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: - # set server-side timeout - wait_timeout = f"{timeout.total_seconds()}s" - if not warehouse_id: - warehouse_id = self._default_warehouse() - + timeout, wait_timeout = self._statement_timeouts(timeout) logger.debug(f"Executing SQL statement: {statement}") # technically, we can do Disposition.EXTERNAL_LINKS, but let's push it further away. # format is limited to Format.JSON_ARRAY, but other iterations may include ARROW_STREAM. 
immediate_response = self._ws.statement_execution.execute_statement( - warehouse_id=warehouse_id, statement=statement, - catalog=catalog, - schema=schema, - disposition=disposition, - format=Format.JSON_ARRAY, - byte_limit=byte_limit, + warehouse_id=warehouse_id or self._default_warehouse(), + catalog=catalog or self._catalog, + schema=schema or self._schema, + disposition=self._disposition, + byte_limit=byte_limit or self._byte_limit, wait_timeout=wait_timeout, + format=Format.JSON_ARRAY, ) status = immediate_response.status @@ -324,10 +338,19 @@ def execute( msg = f"timed out after {timeout}: {status_message}" raise TimeoutError(msg) - def __call__(self, *args, **kwargs): - yield from self.execute_fetch_all(*args, **kwargs) + def _statement_timeouts(self, timeout): + if timeout is None: + timeout = self._timeout + wait_timeout = None + if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: + # set server-side timeout + wait_timeout = f"{timeout.total_seconds()}s" + return timeout, wait_timeout - def execute_fetch_all( + def __call__(self, statement: str): + yield from self.fetch_all(statement) + + def fetch_all( self, statement: str, *, @@ -335,9 +358,9 @@ def execute_fetch_all( byte_limit: int | None = None, catalog: str | None = None, schema: str | None = None, - timeout: timedelta = timedelta(minutes=20), + timeout: timedelta | None = None, ) -> Iterator[Row]: - """(Experimental) Execute a query and iterate over all available records. + """Execute a query and iterate over all available records. This method is a wrapper over :py:meth:`execute` with the handling of chunked result processing and deserialization of those into separate rows, which are yielded from @@ -376,8 +399,7 @@ def execute_fetch_all( byte_limit=byte_limit, catalog=catalog, schema=schema, - timeout=timeout, - disposition=self._disposition) + timeout=timeout) result_data = execute_response.result if result_data is None: return [] @@ -402,6 +424,51 @@ def execute_fetch_all( execute_response.statement_id, next_chunk_index) + def fetch_one(self, statement: str, **kwargs) -> Row | None: + """Execute a query and fetch the first available record. + + This method is a wrapper over :py:meth:`fetch_all` and fetches only the first row + from the result set. If no records are available, it returns `None`. + + .. 
code-block:: + + row = see.fetch_one('SELECT * FROM samples.nyctaxi.trips LIMIT 1') + if row: + pickup_time, dropoff_time = row[0], row[1] + pickup_zip = row.pickup_zip + dropoff_zip = row['dropoff_zip'] + all_fields = row.as_dict() + print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') + + :param statement: str + SQL statement to execute + :return: Row | None + """ + statement = self._add_limit(statement) + for row in self.fetch_all(statement, **kwargs): + return row + return None + + def fetch_value(self, statement: str, **kwargs) -> Any | None: + """Execute a query and fetch the first available value.""" + for v, in self.fetch_all(statement, **kwargs): + return v + return None + + def _add_limit(self, statement: str) -> str: + # parse with sqlglot if there's a limit statement + statements = sqlglot.parse(statement, read="databricks") + if not statements: + raise ValueError(f"cannot parse statement: {statement}") + statement_ast = statements[0] + if isinstance(statement_ast, sqlglot.expressions.Select): + if statement_ast.limit is not None: + limit_text = statement_ast.args.get('limit').text('expression') + if limit_text != '1': + raise ValueError(f"limit is not 1: {limit_text}") + return statement_ast.limit(expression=1).sql('databricks') + return statement + def _result_schema(self, execute_response: ExecuteStatementResponse): manifest = execute_response.manifest if not manifest: diff --git a/src/databricks/labs/lsql/sql_backend.py b/src/databricks/labs/lsql/sql_backend.py index 2182ad35..c659bc46 100644 --- a/src/databricks/labs/lsql/sql_backend.py +++ b/src/databricks/labs/lsql/sql_backend.py @@ -17,7 +17,7 @@ Unknown, ) -from databricks.labs.lsql.lib import Row, StatementExecutionExt +from databricks.labs.lsql.core import Row, StatementExecutionExt logger = logging.getLogger(__name__) @@ -136,7 +136,7 @@ def execute(self, sql: str) -> None: def fetch(self, sql: str) -> Iterator[Row]: logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") - return self._sql.execute_fetch_all(self._warehouse_id, sql) + return self._sql.fetch_all(self._warehouse_id, sql) def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"): if mode == "overwrite": diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index b8b61fe5..d46a5952 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,8 +4,9 @@ import pathlib import string import sys -from typing import MutableMapping +from typing import MutableMapping, Callable +import pytest from databricks.sdk import WorkspaceClient from pytest import fixture @@ -76,3 +77,16 @@ def ws(product_info, debug_env) -> WorkspaceClient: product_name, product_version = product_info return WorkspaceClient(host=debug_env["DATABRICKS_HOST"], product=product_name, product_version=product_version) + +@pytest.fixture +def env_or_skip(debug_env) -> Callable[[str], str]: + skip = pytest.skip + if _is_in_debug(): + skip = pytest.fail # type: ignore[assignment] + + def inner(var: str) -> str: + if var not in debug_env: + skip(f"Environment variable {var} is missing") + return debug_env[var] + + return inner \ No newline at end of file diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index 22b4cdd0..d37b3df3 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1,10 +1,12 @@ -import functools +import logging import pytest -from 
databricks.labs.lsql.lib import StatementExecutionExt +from databricks.labs.lsql.core import StatementExecutionExt from databricks.sdk.service.sql import Disposition +logger = logging.getLogger(__name__) + @pytest.mark.parametrize("disposition", [None, Disposition.INLINE, Disposition.EXTERNAL_LINKS]) def test_sql_execution_chunked(ws, disposition): @@ -15,20 +17,70 @@ def test_sql_execution_chunked(ws, disposition): assert total == 1999999000000 -def test_sql_execution(w, env_or_skip): - warehouse_id = env_or_skip("TEST_DEFAULT_WAREHOUSE_ID") - for pickup_zip, dropoff_zip in w.statement_execution.iterate_rows( - warehouse_id, "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", catalog="samples" +def test_sql_execution(ws, env_or_skip): + results = [] + see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID")) + for pickup_zip, dropoff_zip in see.fetch_all( + "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", + catalog="samples" ): - print(f"pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}") + results.append((pickup_zip, dropoff_zip)) + assert results == [ + (10282, 10171), + (10110, 10110), + (10103, 10023), + (10022, 10017), + (10110, 10282), + (10009, 10065), + (10153, 10199), + (10112, 10069), + (10023, 10153), + (10012, 10003) + ] -def test_sql_execution_partial(w, env_or_skip): - warehouse_id = env_or_skip("TEST_DEFAULT_WAREHOUSE_ID") - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog="samples") - for row in iterate_rows("SELECT * FROM nyctaxi.trips LIMIT 10"): +def test_sql_execution_partial(ws, env_or_skip): + results = [] + see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"), catalog="samples") + for row in see("SELECT * FROM nyctaxi.trips LIMIT 10"): pickup_time, dropoff_time = row[0], row[1] pickup_zip = row.pickup_zip dropoff_zip = row["dropoff_zip"] all_fields = row.as_dict() - print(f"{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}") + logger.info(f"{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}") + results.append((pickup_zip, dropoff_zip)) + assert results == [ + (10282, 10171), + (10110, 10110), + (10103, 10023), + (10022, 10017), + (10110, 10282), + (10009, 10065), + (10153, 10199), + (10112, 10069), + (10023, 10153), + (10012, 10003) + ] + + +def test_fetch_one(ws): + see = StatementExecutionExt(ws) + assert see.fetch_one("SELECT 1") == (1,) + + +def test_fetch_one_fails_if_limit_is_bigger(ws): + see = StatementExecutionExt(ws) + with pytest.raises(ValueError): + see.fetch_one("SELECT * FROM samples.nyctaxi.trips LIMIT 100") + + +def test_fetch_one_works(ws): + see = StatementExecutionExt(ws) + row = see.fetch_one("SELECT * FROM samples.nyctaxi.trips LIMIT 1") + assert row.pickup_zip == 10282 + + +def test_fetch_value(ws): + see = StatementExecutionExt(ws) + count = see.fetch_value("SELECT COUNT(*) FROM samples.nyctaxi.trips") + assert count == 21932 \ No newline at end of file diff --git a/tests/unit/test_lib.py b/tests/unit/test_lib.py index e26036a1..85e1eeaf 100644 --- a/tests/unit/test_lib.py +++ b/tests/unit/test_lib.py @@ -21,7 +21,7 @@ timedelta, ) -from databricks.labs.lsql.lib import Row +from databricks.labs.lsql.core import Row @pytest.mark.parametrize('row', [ From 6e5b3f582f908d1bf8b81dd5c22954cf9bc12804 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Sun, 10 Mar 2024 19:43:10 +0100 Subject: [PATCH 06/10] .. 
--- tests/unit/test_lib.py | 41 ++++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/tests/unit/test_lib.py b/tests/unit/test_lib.py index 85e1eeaf..1e9173b8 100644 --- a/tests/unit/test_lib.py +++ b/tests/unit/test_lib.py @@ -1,4 +1,5 @@ import datetime +from unittest.mock import create_autospec import pytest from databricks.sdk import WorkspaceClient @@ -21,7 +22,7 @@ timedelta, ) -from databricks.labs.lsql.core import Row +from databricks.labs.lsql.core import Row, StatementExecutionExt @pytest.mark.parametrize('row', [ @@ -59,42 +60,36 @@ def test_row_factory_with_generator(): assert b == 2 -def test_execute_poll_succeeds(config, mocker): - w = WorkspaceClient(config=config) - - execute_statement = mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.PENDING), - statement_id="bcd", - ), +def test_execute_poll_succeeds(): + ws = create_autospec(WorkspaceClient) + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", ) - - get_statement = mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.get_statement", - return_value=GetStatementResponse( - manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id="bcd", - status=StatementStatus(state=StatementState.SUCCEEDED), - ), + ws.statement_execution.get_statement.return_value = GetStatementResponse( + manifest=ResultManifest(), + result=ResultData(byte_count=100500), + statement_id="bcd", + status=StatementStatus(state=StatementState.SUCCEEDED), ) - response = w.statement_execution.execute("abc", "SELECT 2+2") + see = StatementExecutionExt(ws) + + response = see.execute("SELECT 2+2", warehouse_id="abc") assert response.status.state == StatementState.SUCCEEDED assert response.result.byte_count == 100500 - execute_statement.assert_called_with( + ws.statement_execution.execute_statement.assert_called_with( warehouse_id="abc", statement="SELECT 2+2", - disposition=Disposition.EXTERNAL_LINKS, format=Format.JSON_ARRAY, + disposition=None, byte_limit=None, catalog=None, schema=None, wait_timeout=None, ) - get_statement.assert_called_with("bcd") + ws.statement_execution.get_statement.assert_called_with("bcd") def test_execute_fails(config, mocker): From 00891a3babb798c5ec3d105c39e1828dfbc92154 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Mon, 11 Mar 2024 13:13:17 +0100 Subject: [PATCH 07/10] .. --- Makefile | 2 +- NOTICE | 2 +- pyproject.toml | 8 + src/databricks/labs/lsql/core.py | 55 ++-- tests/{ => unit}/__init__.py | 0 tests/unit/test_core.py | 488 +++++++++++++++++++++++++++++++ tests/unit/test_lib.py | 276 ----------------- 7 files changed, 531 insertions(+), 300 deletions(-) rename tests/{ => unit}/__init__.py (100%) create mode 100644 tests/unit/test_core.py delete mode 100644 tests/unit/test_lib.py diff --git a/Makefile b/Makefile index 4b9667fe..c2b0aa79 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ clean: rm -fr **/*.pyc .venv/bin/python: - pip install hatch + pip install hatch==1.7.0 hatch env create dev: .venv/bin/python diff --git a/NOTICE b/NOTICE index 04a26a7c..605bb201 100644 --- a/NOTICE +++ b/NOTICE @@ -1,4 +1,4 @@ -Copyright (2023) Databricks, Inc. +Copyright (2024) Databricks, Inc. 
This Software includes software developed at Databricks (https://www.databricks.com/) and its use is subject to the included LICENSE file. diff --git a/pyproject.toml b/pyproject.toml index eb8e45be..c39b6e4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,14 @@ Documentation = "https://github.com/databrickslabs/databricks-labs-lsql#readme" Issues = "https://github.com/databrickslabs/databricks-labs-lsql/issues" Source = "https://github.com/databrickslabs/databricks-labs-lsql" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +sources = ["src"] +include = ["src"] + [tool.hatch.version] path = "src/databricks/labs/lsql/__about__.py" diff --git a/src/databricks/labs/lsql/core.py b/src/databricks/labs/lsql/core.py index eed8699e..9f5f9979 100644 --- a/src/databricks/labs/lsql/core.py +++ b/src/databricks/labs/lsql/core.py @@ -9,7 +9,7 @@ import types from collections.abc import Iterator from datetime import timedelta -from typing import Any +from typing import Any, Callable import requests import sqlglot @@ -134,15 +134,20 @@ def __init__(self, ws: WorkspaceClient, byte_limit: int | None = None, catalog: str | None = None, schema: str | None = None, - timeout: timedelta = timedelta(minutes=20)): + timeout: timedelta = timedelta(minutes=20), + disable_magic: bool = False, + http_session_factory: Callable[[], requests.Session] | None = None): + if not http_session_factory: + http_session_factory = requests.Session self._ws = ws self._api = ws.api_client - self._http = requests.Session() + self._http = http_session_factory() self._lock = threading.Lock() self._warehouse_id = warehouse_id self._schema = schema self._timeout = timeout self._catalog = catalog + self._disable_magic = disable_magic self._byte_limit = byte_limit self._disposition = disposition self._type_converters = { @@ -223,7 +228,8 @@ def _default_warehouse(self) -> str: with self._lock: if self._warehouse_id: return self._warehouse_id - if self._ws.config.warehouse_id: + # if we create_autospec(WorkspaceClient), the warehouse_id is a MagicMock + if isinstance(self._ws.config.warehouse_id, str) and self._ws.config.warehouse_id: self._warehouse_id = self._ws.config.warehouse_id return self._ws.config.warehouse_id ids = [] @@ -234,11 +240,13 @@ def _default_warehouse(self) -> str: self._ws.config.warehouse_id = v.id return self._ws.config.warehouse_id ids.append(v.id) - if self._ws.config.warehouse_id == "" and len(ids) > 0: + if len(ids) > 0: # otherwise - first warehouse self._ws.config.warehouse_id = ids[0] return self._ws.config.warehouse_id - raise ValueError("no warehouse id given") + raise ValueError("no warehouse_id=... given, " + "neither it is set in the WorkspaceClient(..., warehouse_id=...), " + "nor in the DATABRICKS_WAREHOUSE_ID environment variable") def execute( self, @@ -407,24 +415,22 @@ def fetch_all( while True: if result_data.data_array: for data in result_data.data_array: - # enumerate() + iterator + tuple constructor makes it more performant - # on larger humber of records for Python, even though it's less - # readable code. 
yield row_factory(col_conv[i](value) for i, value in enumerate(data)) next_chunk_index = result_data.next_chunk_index - for external_link in result_data.external_links: - next_chunk_index = external_link.next_chunk_index - response = self._http.get(external_link.external_link) - response.raise_for_status() - for data in response.json(): - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + if result_data.external_links: + for external_link in result_data.external_links: + next_chunk_index = external_link.next_chunk_index + response = self._http.get(external_link.external_link) + response.raise_for_status() + for data in response.json(): + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) if not next_chunk_index: return result_data = self._ws.statement_execution.get_statement_result_chunk_n( execute_response.statement_id, next_chunk_index) - def fetch_one(self, statement: str, **kwargs) -> Row | None: + def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Row | None: """Execute a query and fetch the first available record. This method is a wrapper over :py:meth:`fetch_all` and fetches only the first row @@ -441,10 +447,14 @@ def fetch_one(self, statement: str, **kwargs) -> Row | None: print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') :param statement: str - SQL statement to execute + SQL statement to execute + :param disable_magic: bool (optional) + Disables the magic of adding `LIMIT 1` to the statement. By default, it is `False`. :return: Row | None """ - statement = self._add_limit(statement) + disable_magic = disable_magic or self._disable_magic + if not disable_magic: + statement = self._add_limit(statement) for row in self.fetch_all(statement, **kwargs): return row return None @@ -455,7 +465,8 @@ def fetch_value(self, statement: str, **kwargs) -> Any | None: return v return None - def _add_limit(self, statement: str) -> str: + @staticmethod + def _add_limit(statement: str) -> str: # parse with sqlglot if there's a limit statement statements = sqlglot.parse(statement, read="databricks") if not statements: @@ -463,9 +474,9 @@ def _add_limit(self, statement: str) -> str: statement_ast = statements[0] if isinstance(statement_ast, sqlglot.expressions.Select): if statement_ast.limit is not None: - limit_text = statement_ast.args.get('limit').text('expression') - if limit_text != '1': - raise ValueError(f"limit is not 1: {limit_text}") + limit = statement_ast.args.get('limit', None) + if limit and limit.text('expression') != '1': + raise ValueError(f"limit is not 1: {limit.text('expression')}") return statement_ast.limit(expression=1).sql('databricks') return statement diff --git a/tests/__init__.py b/tests/unit/__init__.py similarity index 100% rename from tests/__init__.py rename to tests/unit/__init__.py diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py new file mode 100644 index 00000000..d90b6c33 --- /dev/null +++ b/tests/unit/test_core.py @@ -0,0 +1,488 @@ +import datetime +import requests +from unittest.mock import create_autospec + +import pytest +from databricks.sdk import WorkspaceClient +from databricks.sdk import errors +from databricks.sdk.service.sql import ( + ColumnInfo, + ColumnInfoTypeName, + Disposition, + ExecuteStatementResponse, + ExternalLink, + Format, + GetStatementResponse, + ResultData, + ResultManifest, + ResultSchema, + ServiceError, + ServiceErrorCode, + StatementState, + StatementStatus, + timedelta, EndpointInfo, State, +) + +from 
databricks.labs.lsql.core import Row, StatementExecutionExt + + +@pytest.mark.parametrize('row', [ + Row(foo="bar", enabled=True), + Row(['foo', 'enabled'], ['bar', True]), +]) +def test_row_from_kwargs(row): + assert row.foo == "bar" + assert row["foo"] == "bar" + assert "foo" in row + assert len(row) == 2 + assert list(row) == ["bar", True] + assert row.as_dict() == {"foo": "bar", "enabled": True} + foo, enabled = row + assert foo == "bar" + assert enabled is True + assert str(row) == "Row(foo='bar', enabled=True)" + with pytest.raises(AttributeError): + print(row.x) + + +def test_row_factory(): + factory = Row.factory(["a", "b"]) + row = factory(1, 2) + a, b = row + assert a == 1 + assert b == 2 + + +def test_row_factory_with_generator(): + factory = Row.factory(["a", "b"]) + row = factory(_+1 for _ in range(2)) + a, b = row + assert a == 1 + assert b == 2 + + +def test_selects_warehouse_from_config(): + ws = create_autospec(WorkspaceClient) + ws.config.warehouse_id = "abc" + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws) + see.execute("SELECT 2+2") + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="abc", + statement="SELECT 2+2", + format=Format.JSON_ARRAY, + disposition=None, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + + +def test_selects_warehouse_from_existing_first_running(): + ws = create_autospec(WorkspaceClient) + + ws.warehouses.list.return_value = [ + EndpointInfo(id="abc", state=State.DELETING), + EndpointInfo(id="bcd", state=State.RUNNING), + EndpointInfo(id="cde", state=State.RUNNING), + ] + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + ) + + see = StatementExecutionExt(ws) + see.execute("SELECT 2+2") + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="bcd", + statement="SELECT 2+2", + format=Format.JSON_ARRAY, + disposition=None, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + + +def test_selects_warehouse_from_existing_not_running(): + ws = create_autospec(WorkspaceClient) + + ws.warehouses.list.return_value = [ + EndpointInfo(id="efg", state=State.DELETING), + EndpointInfo(id="fgh", state=State.DELETED), + EndpointInfo(id="bcd", state=State.STOPPED), + EndpointInfo(id="cde", state=State.STARTING), + ] + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + ) + + see = StatementExecutionExt(ws) + see.execute("SELECT 2+2") + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="bcd", + statement="SELECT 2+2", + format=Format.JSON_ARRAY, + disposition=None, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + + +def test_no_warehouse_given(): + ws = create_autospec(WorkspaceClient) + + see = StatementExecutionExt(ws) + + with pytest.raises(ValueError): + see.execute("SELECT 2+2") + + +def test_execute_poll_succeeds(): + ws = create_autospec(WorkspaceClient) + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", + ) + ws.statement_execution.get_statement.return_value = GetStatementResponse( + manifest=ResultManifest(), + result=ResultData(byte_count=100500), + 
statement_id="bcd", + status=StatementStatus(state=StatementState.SUCCEEDED), + ) + + see = StatementExecutionExt(ws) + + response = see.execute("SELECT 2+2", warehouse_id="abc") + + assert response.status.state == StatementState.SUCCEEDED + assert response.result.byte_count == 100500 + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="abc", + statement="SELECT 2+2", + format=Format.JSON_ARRAY, + disposition=None, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + ws.statement_execution.get_statement.assert_called_with("bcd") + + +@pytest.mark.parametrize("status_error,platform_error_type", [ + (None, errors.Unknown), + (ServiceError(), errors.Unknown), + (ServiceError(message="..."), errors.Unknown), + (ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="..."), errors.ResourceExhausted), + (ServiceError(message="... SCHEMA_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... TABLE_OR_VIEW_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... DELTA_MISSING_TRANSACTION_LOG ..."), errors.DataLoss), +]) +def test_execute_fails(status_error, platform_error_type): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.FAILED, error=status_error), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + with pytest.raises(platform_error_type): + see.execute("SELECT 2+2") + +def test_execute_poll_waits(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", + ) + + ws.statement_execution.get_statement.side_effect = [ + GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), statement_id="bcd"), + GetStatementResponse( + manifest=ResultManifest(), + result=ResultData(byte_count=100500), + statement_id="bcd", + status=StatementStatus(state=StatementState.SUCCEEDED), + ), + ] + + see = StatementExecutionExt(ws, warehouse_id="abc") + + response = see.execute("SELECT 2+2") + + assert response.status.state == StatementState.SUCCEEDED + assert response.result.byte_count == 100500 + + +def test_execute_poll_timeouts_on_client(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.PENDING), + statement_id="bcd", + ) + + ws.statement_execution.get_statement.return_value = GetStatementResponse( + status=StatementStatus(state=StatementState.RUNNING), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + with pytest.raises(TimeoutError): + see.execute("SELECT 2+2", timeout=timedelta(seconds=1)) + + ws.statement_execution.cancel_execution.assert_called_with("bcd") + + +def test_fetch_all_no_chunks(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema( + columns=[ + ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT), + ColumnInfo(name="since", type_name=ColumnInfoTypeName.DATE), + ColumnInfo(name="now", 
type_name=ColumnInfoTypeName.TIMESTAMP), + ] + ) + ), + result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]), + statement_id="bcd", + ) + + http_session = create_autospec(requests.Session) + http_session.get("https://singed-url").json.return_value = [ + ["1", "2023-09-01", "2023-09-01T13:21:53Z"], + ["2", "2023-09-01", "2023-09-01T13:21:53Z"], + ] + + see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session) + + rows = list( + see.fetch_all("SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)") + ) + + assert len(rows) == 2 + assert rows[0].id == 1 + assert isinstance(rows[0].since, datetime.date) + assert isinstance(rows[0].now, datetime.datetime) + + http_session.get.assert_called_with("https://singed-url") + + +def test_fetch_all_no_chunks_no_converter(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema( + columns=[ + ColumnInfo(name="id", type_name=ColumnInfoTypeName.INTERVAL), + ] + ) + ), + result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]), + statement_id="bcd", + ) + + http_session = create_autospec(requests.Session) + http_session.get("https://singed-url").json.return_value = [["1"], ["2"]] + + see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session) + + with pytest.raises(ValueError, match="id has no INTERVAL converter"): + list(see.fetch_all("SELECT id FROM range(2)")) + + +def test_fetch_all_two_chunks(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema( + columns=[ + ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT), + ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP), + ] + ) + ), + result=ResultData(external_links=[ExternalLink(external_link="https://first", next_chunk_index=1)]), + statement_id="bcd", + ) + + ws.statement_execution.get_statement_result_chunk_n.return_value = ResultData( + external_links=[ExternalLink(external_link="https://second")] + ) + + http_session = create_autospec(requests.Session) + http_session.get(...).json.side_effect = [ + # https://first + [["1", "2023-09-01T13:21:53Z"], ["2", "2023-09-01T13:21:53Z"]], + # https://second + [["3", "2023-09-01T13:21:53Z"], ["4", "2023-09-01T13:21:53Z"]], + ] + + see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session) + + rows = list(see.fetch_all("SELECT id, NOW() AS now FROM range(4)")) + assert len(rows) == 4 + assert [_.id for _ in rows] == [1, 2, 3, 4] + + ws.statement_execution.get_statement_result_chunk_n.assert_called_with("bcd", 1) + + +def test_fetch_one(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + result=ResultData(data_array=[["4"]]), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + row = see.fetch_one("SELECT 2+2 AS id") + + assert row.id == 4 + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="abc", + 
statement="SELECT 2 + 2 AS id LIMIT 1", + format=Format.JSON_ARRAY, + disposition=None, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + + +def test_fetch_one_none(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + row = see.fetch_one("SELECT 2+2 AS id") + + assert row is None + + +def test_fetch_one_disable_magic(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + result=ResultData(data_array=[["4"], ["5"], ["6"]]), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + row = see.fetch_one("SELECT 2+2 AS id", disable_magic=True) + + assert row.id == 4 + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="abc", + statement="SELECT 2+2 AS id", + format=Format.JSON_ARRAY, + disposition=None, + byte_limit=None, + catalog=None, + schema=None, + wait_timeout=None, + ) + + +def test_fetch_value(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + result=ResultData(data_array=[["4"]]), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + value = see.fetch_value("SELECT 2+2 AS id") + + assert value == 4 + + +def test_fetch_value_none(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + value = see.fetch_value("SELECT 2+2 AS id") + + assert value is None + + +def test_callable_returns_iterator(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + result=ResultData(data_array=[["4"], ["5"], ["6"]]), + statement_id="bcd", + ) + + see = StatementExecutionExt(ws, warehouse_id="abc") + + rows = list(see("SELECT 2+2 AS id")) + + assert len(rows) == 3 + assert rows == [Row(id=4), Row(id=5), Row(id=6)] \ No newline at end of file diff --git a/tests/unit/test_lib.py b/tests/unit/test_lib.py deleted file mode 100644 index 1e9173b8..00000000 --- a/tests/unit/test_lib.py +++ /dev/null @@ -1,276 +0,0 @@ -import datetime -from unittest.mock import create_autospec - -import pytest -from databricks.sdk import WorkspaceClient -from databricks.sdk.core import DatabricksError -from databricks.sdk.service.sql import ( - ColumnInfo, - ColumnInfoTypeName, - Disposition, - 
ExecuteStatementResponse, - ExternalLink, - Format, - GetStatementResponse, - ResultData, - ResultManifest, - ResultSchema, - ServiceError, - ServiceErrorCode, - StatementState, - StatementStatus, - timedelta, -) - -from databricks.labs.lsql.core import Row, StatementExecutionExt - - -@pytest.mark.parametrize('row', [ - Row(foo="bar", enabled=True), - Row(['foo', 'enabled'], ['bar', True]), -]) -def test_row_from_kwargs(row): - assert row.foo == "bar" - assert row["foo"] == "bar" - assert "foo" in row - assert len(row) == 2 - assert list(row) == ["bar", True] - assert row.as_dict() == {"foo": "bar", "enabled": True} - foo, enabled = row - assert foo == "bar" - assert enabled is True - assert str(row) == "Row(foo='bar', enabled=True)" - with pytest.raises(AttributeError): - print(row.x) - - -def test_row_factory(): - factory = Row.factory(["a", "b"]) - row = factory(1, 2) - a, b = row - assert a == 1 - assert b == 2 - - -def test_row_factory_with_generator(): - factory = Row.factory(["a", "b"]) - row = factory(_+1 for _ in range(2)) - a, b = row - assert a == 1 - assert b == 2 - - -def test_execute_poll_succeeds(): - ws = create_autospec(WorkspaceClient) - ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( - status=StatementStatus(state=StatementState.PENDING), - statement_id="bcd", - ) - ws.statement_execution.get_statement.return_value = GetStatementResponse( - manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id="bcd", - status=StatementStatus(state=StatementState.SUCCEEDED), - ) - - see = StatementExecutionExt(ws) - - response = see.execute("SELECT 2+2", warehouse_id="abc") - - assert response.status.state == StatementState.SUCCEEDED - assert response.result.byte_count == 100500 - ws.statement_execution.execute_statement.assert_called_with( - warehouse_id="abc", - statement="SELECT 2+2", - format=Format.JSON_ARRAY, - disposition=None, - byte_limit=None, - catalog=None, - schema=None, - wait_timeout=None, - ) - ws.statement_execution.get_statement.assert_called_with("bcd") - - -def test_execute_fails(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - statement_id="bcd", - status=StatementStatus( - state=StatementState.FAILED, - error=ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="oops..."), - ), - ), - ) - - with pytest.raises(DatabricksError) as err: - w.statement_execution.execute("abc", "SELECT 2+2") - assert err.value.error_code == "RESOURCE_EXHAUSTED" - assert str(err.value) == "oops..." 
- - -def test_execute_poll_waits(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.PENDING), - statement_id="bcd", - ), - ) - - runs = [] - - def _get_statement(statement_id): - assert statement_id == "bcd" - if len(runs) == 0: - runs.append(1) - return GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), statement_id="bcd") - - return GetStatementResponse( - manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id="bcd", - status=StatementStatus(state=StatementState.SUCCEEDED), - ) - - mocker.patch("databricks.sdk.service.sql.StatementExecutionAPI.get_statement", wraps=_get_statement) - - response = w.statement_execution.execute("abc", "SELECT 2+2") - - assert response.status.state == StatementState.SUCCEEDED - assert response.result.byte_count == 100500 - - -def test_execute_poll_timeouts_on_client(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.PENDING), - statement_id="bcd", - ), - ) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.get_statement", - return_value=GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), statement_id="bcd"), - ) - - cancel_execution = mocker.patch("databricks.sdk.service.sql.StatementExecutionAPI.cancel_execution") - - with pytest.raises(TimeoutError): - w.statement_execution.execute("abc", "SELECT 2+2", timeout=timedelta(seconds=1)) - - cancel_execution.assert_called_with("bcd") - - -def test_fetch_all_no_chunks(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema( - columns=[ - ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT), - ColumnInfo(name="since", type_name=ColumnInfoTypeName.DATE), - ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP), - ] - ) - ), - result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]), - statement_id="bcd", - ), - ) - - raw_response = mocker.Mock() - raw_response.json = lambda: [ - ["1", "2023-09-01", "2023-09-01T13:21:53Z"], - ["2", "2023-09-01", "2023-09-01T13:21:53Z"], - ] - - http_get = mocker.patch("requests.sessions.Session.get", return_value=raw_response) - - rows = list( - w.statement_execution.iterate_rows("abc", "SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)") - ) - - assert len(rows) == 2 - assert rows[0].id == 1 - assert isinstance(rows[0].since, datetime.date) - assert isinstance(rows[0].now, datetime.datetime) - - http_get.assert_called_with("https://singed-url") - - -def test_fetch_all_no_chunks_no_converter(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema( - columns=[ - ColumnInfo(name="id", type_name=ColumnInfoTypeName.INTERVAL), - ] - ) - ), - 
result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]), - statement_id="bcd", - ), - ) - - raw_response = mocker.Mock() - raw_response.json = lambda: [["1"], ["2"]] - - mocker.patch("requests.sessions.Session.get", return_value=raw_response) - - with pytest.raises(ValueError): - list(w.statement_execution.iterate_rows("abc", "SELECT id FROM range(2)")) - - -def test_fetch_all_two_chunks(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.execute_statement", - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema( - columns=[ - ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT), - ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP), - ] - ) - ), - result=ResultData(external_links=[ExternalLink(external_link="https://first", next_chunk_index=1)]), - statement_id="bcd", - ), - ) - - next_chunk = mocker.patch( - "databricks.sdk.service.sql.StatementExecutionAPI.get_statement_result_chunk_n", - return_value=ResultData(external_links=[ExternalLink(external_link="https://second")]), - ) - - raw_response = mocker.Mock() - raw_response.json = lambda: [["1", "2023-09-01T13:21:53Z"], ["2", "2023-09-01T13:21:53Z"]] - http_get = mocker.patch("requests.sessions.Session.get", return_value=raw_response) - - rows = list(w.statement_execution.iterate_rows("abc", "SELECT id, NOW() AS now FROM range(2)")) - - assert len(rows) == 4 - - assert http_get.call_args_list == [mocker.call("https://first"), mocker.call("https://second")] - next_chunk.assert_called_with("bcd", 1) From 9fbaf7e711dd6d19c16c764bc660178e50458e57 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Mon, 11 Mar 2024 14:45:44 +0100 Subject: [PATCH 08/10] .. 
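
This patch renames sql_backend.py to backends.py, extracts a shared _SparkBackend used by RuntimeBackend and the new DatabricksConnectBackend, and teaches MockBackend to record written rows as Row objects. A minimal sketch of the MockBackend behaviour exercised by the new unit tests, assuming databricks-labs-lsql is installed; the table and dataclass names are illustrative:

    from dataclasses import dataclass

    from databricks.labs.lsql import Row
    from databricks.labs.lsql.backends import MockBackend


    @dataclass
    class Foo:
        first: str
        second: bool


    # canned rows are matched by regular expression over the SQL text
    backend = MockBackend(rows={r"SELECT id FROM range\(3\)": [Row(id=1), Row(id=2), Row(id=3)]})
    assert list(backend.fetch("SELECT id FROM range(3)")) == [Row(id=1), Row(id=2), Row(id=3)]

    # writes are captured in memory instead of reaching a warehouse
    backend.save_table("a.b.c", [Foo("aaa", True)], Foo)
    assert backend.rows_written_for("a.b.c", "append") == [Row(first="aaa", second=True)]
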
--- .../labs/lsql/{sql_backend.py => backends.py} | 60 +++- src/databricks/labs/lsql/deployment.py | 2 +- tests/unit/conftest.py | 6 + tests/unit/test_backends.py | 330 ++++++++++++++++++ tests/unit/test_deployment.py | 45 +++ tests/unit/views/__init__.py | 0 tests/unit/views/some.sql | 1 + 7 files changed, 426 insertions(+), 18 deletions(-) rename src/databricks/labs/lsql/{sql_backend.py => backends.py} (83%) create mode 100644 tests/unit/conftest.py create mode 100644 tests/unit/test_backends.py create mode 100644 tests/unit/test_deployment.py create mode 100644 tests/unit/views/__init__.py create mode 100644 tests/unit/views/some.sql diff --git a/src/databricks/labs/lsql/sql_backend.py b/src/databricks/labs/lsql/backends.py similarity index 83% rename from src/databricks/labs/lsql/sql_backend.py rename to src/databricks/labs/lsql/backends.py index c659bc46..9ae17b97 100644 --- a/src/databricks/labs/lsql/sql_backend.py +++ b/src/databricks/labs/lsql/backends.py @@ -123,7 +123,7 @@ def _api_error_from_message(error_message: str) -> DatabricksError: class StatementExecutionBackend(SqlBackend): def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000): - self._sql = StatementExecutionExt(ws) + self._sql = StatementExecutionExt(ws, warehouse_id=warehouse_id) self._warehouse_id = warehouse_id self._max_records_per_batch = max_records_per_batch debug_truncate_bytes = ws.config.debug_truncate_bytes @@ -132,11 +132,11 @@ def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: def execute(self, sql: str) -> None: logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") - self._sql.execute(self._warehouse_id, sql) + self._sql.execute(sql) def fetch(self, sql: str) -> Iterator[Row]: logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") - return self._sql.fetch_all(self._warehouse_id, sql) + return self._sql.fetch_all(sql) def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"): if mode == "overwrite": @@ -155,7 +155,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D self.execute(sql) @staticmethod - def _row_to_sql(row, fields): + def _row_to_sql(row: DataclassInstance, fields: list[dataclasses.Field]): data = [] for f in fields: value = getattr(row, f.name) @@ -177,17 +177,10 @@ def _row_to_sql(row, fields): return ", ".join(data) -class RuntimeBackend(SqlBackend): - def __init__(self, debug_truncate_bytes: int | None = None): - # pylint: disable-next=import-error,import-outside-toplevel - from pyspark.sql.session import SparkSession # type: ignore[import-not-found] - - if "DATABRICKS_RUNTIME_VERSION" not in os.environ: - msg = "Not in the Databricks Runtime" - raise RuntimeError(msg) - - self._spark = SparkSession.builder.getOrCreate() - self._debug_truncate_bytes = debug_truncate_bytes if debug_truncate_bytes else 96 +class _SparkBackend(SqlBackend): + def __init__(self, spark, debug_truncate_bytes): + self._spark = spark + self._debug_truncate_bytes = debug_truncate_bytes if debug_truncate_bytes is not None else 96 def execute(self, sql: str) -> None: logger.debug(f"[spark][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") @@ -216,8 +209,31 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D df.write.saveAsTable(full_name, mode=mode) +class RuntimeBackend(_SparkBackend): + def __init__(self, debug_truncate_bytes: int | None = None): + if 
"DATABRICKS_RUNTIME_VERSION" not in os.environ: + msg = "Not in the Databricks Runtime" + raise RuntimeError(msg) + try: + # pylint: disable-next=import-error,import-outside-toplevel + from pyspark.sql.session import SparkSession # type: ignore[import-not-found] + super().__init__(SparkSession.builder.getOrCreate(), debug_truncate_bytes) + except ImportError as e: + raise RuntimeError("pyspark is not available") from e + + +class DatabricksConnectBackend(_SparkBackend): + def __init__(self, ws: WorkspaceClient): + try: + from databricks.connect import DatabricksSession + spark = DatabricksSession.builder().sdk_config(ws.config).getOrCreate() + super().__init__(spark, ws.config.debug_truncate_bytes) + except ImportError as e: + raise RuntimeError("Please run `pip install databricks-connect`") from e + + class MockBackend(SqlBackend): - def __init__(self, *, fails_on_first: dict | None = None, rows: dict | None = None, debug_truncate_bytes=96): + def __init__(self, *, fails_on_first: dict[str,str] | None = None, rows: dict | None = None, debug_truncate_bytes=96): self._fails_on_first = fails_on_first if not rows: rows = {} @@ -250,8 +266,14 @@ def fetch(self, sql) -> Iterator[Row]: logger.debug(f"Returning rows: {rows}") return iter(rows) - def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass, mode: str = "append"): + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + if mode == "overwrite": + msg = "Overwrite mode is not yet supported" + raise NotImplementedError(msg) + rows = self._filter_none_rows(rows, klass) if klass.__class__ == type: + row_factory = self._row_factory(klass) + rows = [row_factory(*dataclasses.astuple(r)) for r in rows] self._save_table.append((full_name, rows, mode)) def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance]: @@ -261,3 +283,7 @@ def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance] continue rows += stub_rows return rows + + @staticmethod + def _row_factory(klass: Dataclass) -> type: + return Row.factory([f.name for f in dataclasses.fields(klass)]) \ No newline at end of file diff --git a/src/databricks/labs/lsql/deployment.py b/src/databricks/labs/lsql/deployment.py index 56846347..7dd42a67 100644 --- a/src/databricks/labs/lsql/deployment.py +++ b/src/databricks/labs/lsql/deployment.py @@ -2,7 +2,7 @@ import pkgutil from typing import Any -from databricks.labs.lsql.sql_backend import Dataclass, SqlBackend +from databricks.labs.lsql.backends import Dataclass, SqlBackend logger = logging.getLogger(__name__) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py new file mode 100644 index 00000000..368c7c1d --- /dev/null +++ b/tests/unit/conftest.py @@ -0,0 +1,6 @@ +import logging + +from databricks.labs.blueprint.logger import install_logger + +install_logger() +logging.getLogger("databricks").setLevel("DEBUG") diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py new file mode 100644 index 00000000..7608f969 --- /dev/null +++ b/tests/unit/test_backends.py @@ -0,0 +1,330 @@ +import os +import sys +from dataclasses import dataclass +from unittest import mock +from unittest.mock import create_autospec, MagicMock + +import pytest +from databricks.labs.lsql import Row +from databricks.labs.lsql.backends import StatementExecutionBackend, RuntimeBackend, MockBackend +from databricks.sdk import WorkspaceClient +from databricks.sdk.service.sql import ExecuteStatementResponse, StatementStatus, 
StatementState, Disposition, Format, \ + ResultManifest, ResultSchema, ColumnInfo, ColumnInfoTypeName, ResultData +from databricks.sdk.errors import ( + BadRequest, + DataLoss, + NotFound, + PermissionDenied, + Unknown, +) + +# pylint: disable=protected-access + + +@dataclass +class Foo: + first: str + second: bool + + +@dataclass +class Baz: + first: str + second: str | None = None + + +@dataclass +class Bar: + first: str + second: bool + third: float + + +def test_statement_execution_backend_execute_happy(): + ws = create_autospec(WorkspaceClient) + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED) + ) + + seb = StatementExecutionBackend(ws, "abc") + + seb.execute("CREATE TABLE foo") + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="abc", + statement="CREATE TABLE foo", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ) + + +def test_statement_execution_backend_fetch_happy(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED), + manifest=ResultManifest( + schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) + ), + result=ResultData(data_array=[["1"], ["2"], ["3"]]), + statement_id="bcd", + ) + + seb = StatementExecutionBackend(ws, "abc") + + result = list(seb.fetch("SELECT id FROM range(3)")) + + assert [Row(id=1), Row(id=2), Row(id=3)] == result + + +def test_statement_execution_backend_save_table_overwrite(mocker): + seb = StatementExecutionBackend(mocker.Mock(), "abc") + with pytest.raises(NotImplementedError): + seb.save_table("a.b.c", [1, 2, 3], Bar, mode="overwrite") + + +def test_statement_execution_backend_save_table_empty_records(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED) + ) + + seb = StatementExecutionBackend(ws, "abc") + + seb.save_table("a.b.c", [], Bar) + + ws.statement_execution.execute_statement.assert_called_with( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c " + "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ) + + +def test_statement_execution_backend_save_table_two_records(): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED) + ) + + seb = StatementExecutionBackend(ws, "abc") + + seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + + ws.statement_execution.execute_statement.assert_has_calls([ + mock.call( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + ]) + + +def 
test_statement_execution_backend_save_table_in_batches_of_two(mocker): + ws = create_autospec(WorkspaceClient) + + ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( + status=StatementStatus(state=StatementState.SUCCEEDED) + ) + + seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2) + + seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo) + + ws.statement_execution.execute_statement.assert_has_calls([ + mock.call( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + ]) + + +def test_runtime_backend_execute(): + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + runtime_backend = RuntimeBackend() + + runtime_backend.execute("CREATE TABLE foo") + + spark.sql.assert_called_with("CREATE TABLE foo") + + +def test_runtime_backend_fetch(): + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + spark.sql().collect.return_value = [Row(id=1), Row(id=2), Row(id=3)] + + runtime_backend = RuntimeBackend() + + result = runtime_backend.fetch("SELECT id FROM range(3)") + + assert [Row(id=1), Row(id=2), Row(id=3)] == result + + spark.sql.assert_called_with("SELECT id FROM range(3)") + + +def test_runtime_backend_save_table(): + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + runtime_backend = RuntimeBackend() + + runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + + spark.createDataFrame.assert_called_with( + [Foo(first="aaa", second=True), Foo(first="bbb", second=False)], + "first STRING NOT NULL, second BOOLEAN NOT NULL", + ) + spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append") + + +def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker): + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + runtime_backend = RuntimeBackend() + + runtime_backend.save_table("a.b.c", [Baz("aaa", "ccc"), Baz("bbb", None)], Baz) + + spark.createDataFrame.assert_called_with( + [Baz(first="aaa", second="ccc"), Baz(first="bbb", second=None)], + "first STRING NOT NULL, second STRING", + ) + 
spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append") + + +@dataclass +class DummyClass: + key: str + value: str | None = None + + +def test_save_table_with_not_null_constraint_violated(): + rows = [DummyClass("1", "test"), DummyClass("2", None), DummyClass(None, "value")] + + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + + runtime_backend = RuntimeBackend() + + with pytest.raises(Exception, + match="Not null constraint violated for column key, row = {'key': None, 'value': 'value'}"): + runtime_backend.save_table("a.b.c", rows, DummyClass) + + +@pytest.mark.parametrize('msg,err_t', [ + ("SCHEMA_NOT_FOUND foo schema does not exist", NotFound), + (".. TABLE_OR_VIEW_NOT_FOUND ..", NotFound), + (".. UNRESOLVED_COLUMN.WITH_SUGGESTION ..", BadRequest), + ("DELTA_TABLE_NOT_FOUND foo table does not exist", NotFound), + ("DELTA_MISSING_TRANSACTION_LOG foo table does not exist", DataLoss), + ("PARSE_SYNTAX_ERROR foo", BadRequest), + ("foo Operation not allowed", PermissionDenied), + ("foo error failure", Unknown), +]) +def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t): + with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): + pyspark_sql_session = MagicMock() + sys.modules["pyspark.sql.session"] = pyspark_sql_session + spark = pyspark_sql_session.SparkSession.builder.getOrCreate() + + spark.sql.side_effect = Exception(msg) + + runtime_backend = RuntimeBackend() + + with pytest.raises(err_t): + runtime_backend.execute("SELECT * from bar") + + with pytest.raises(err_t): + list(runtime_backend.fetch("SELECT * from bar")) + + +def test_mock_backend_fails_on_first(): + mock_backend = MockBackend(fails_on_first={"CREATE": '.. DELTA_TABLE_NOT_FOUND ..'}) + + with pytest.raises(NotFound): + mock_backend.execute("CREATE TABLE foo") + + +def test_mock_backend_rows(): + mock_backend = MockBackend(rows={ + r"SELECT id FROM range\(3\)": [Row(id=1), Row(id=2), Row(id=3)] + }) + + result = list(mock_backend.fetch("SELECT id FROM range(3)")) + + assert [Row(id=1), Row(id=2), Row(id=3)] == result + + +def test_mock_backend_save_table(): + mock_backend = MockBackend() + + mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) + + assert mock_backend.rows_written_for("a.b.c", "append") == [ + Row(first='aaa', second=True), + Row(first='bbb', second=False), + ] \ No newline at end of file diff --git a/tests/unit/test_deployment.py b/tests/unit/test_deployment.py new file mode 100644 index 00000000..4651ff1c --- /dev/null +++ b/tests/unit/test_deployment.py @@ -0,0 +1,45 @@ +from dataclasses import dataclass + +from databricks.labs.lsql.backends import MockBackend +from databricks.labs.lsql.deployment import SchemaDeployer +from . 
import views + + +def test_deploys_view(): + mock_backend = MockBackend() + deployment = SchemaDeployer( + sql_backend=mock_backend, + inventory_schema="inventory", + mod=views, + ) + + deployment.deploy_view('some', 'some.sql') + + assert mock_backend.queries == [ + 'CREATE OR REPLACE VIEW hive_metastore.inventory.some AS SELECT id, name FROM hive_metastore.inventory.something' + ] + + +@dataclass +class Foo: + first: str + second: bool + + +def test_deploys_dataclass(): + mock_backend = MockBackend() + deployment = SchemaDeployer( + sql_backend=mock_backend, + inventory_schema="inventory", + mod=views, + ) + deployment.deploy_schema() + deployment.deploy_table('foo', Foo) + deployment.delete_schema() + + assert mock_backend.queries == [ + 'CREATE SCHEMA IF NOT EXISTS hive_metastore.inventory', + 'CREATE TABLE IF NOT EXISTS hive_metastore.inventory.foo (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA', + 'DROP SCHEMA IF EXISTS hive_metastore.inventory CASCADE' + ] + diff --git a/tests/unit/views/__init__.py b/tests/unit/views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/views/some.sql b/tests/unit/views/some.sql new file mode 100644 index 00000000..166a1660 --- /dev/null +++ b/tests/unit/views/some.sql @@ -0,0 +1 @@ +SELECT id, name FROM $inventory.something \ No newline at end of file From 2749b8932d6cfa107188946ddb539e25c35268c2 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Mon, 11 Mar 2024 15:38:22 +0100 Subject: [PATCH 09/10] .. --- src/databricks/labs/lsql/core.py | 312 +++++++++++++++---------------- 1 file changed, 151 insertions(+), 161 deletions(-) diff --git a/src/databricks/labs/lsql/core.py b/src/databricks/labs/lsql/core.py index 9f5f9979..e1b7ed41 100644 --- a/src/databricks/labs/lsql/core.py +++ b/src/databricks/labs/lsql/core.py @@ -37,7 +37,9 @@ class Row(tuple): + """Row is a tuple with named fields that resembles PySpark's SQL Row API.""" def __new__(cls, *args, **kwargs): + """Create a new instance of Row.""" if args and kwargs: raise ValueError("cannot mix positional and keyword arguments") if kwargs: @@ -57,23 +59,26 @@ def __new__(cls, *args, **kwargs): @classmethod def factory(cls, col_names: list[str]) -> type: + """Create a new Row class with the given column names.""" return type("Row", (Row,), {"__columns__": col_names}) - # Python SDK convention def as_dict(self) -> dict[str, Any]: + """Convert the row to a dictionary with the same conventions as Databricks SDK.""" return dict(zip(self.__columns__, self, strict=True)) - # PySpark convention def __contains__(self, item): + """Check if the column is in the row.""" return item in self.__columns__ def __getitem__(self, col): + """Get the column by index or name.""" if isinstance(col, int | slice): return super().__getitem__(col) # if columns are named `2 + 2`, for example return self.__getattr__(col) def __getattr__(self, col): + """Get the column by name.""" if col.startswith("__"): raise AttributeError(col) try: @@ -85,48 +90,30 @@ def __getattr__(self, col): raise AttributeError(col) from None def __repr__(self): + """Get the string representation of the row.""" return f"Row({', '.join(f'{k}={v!r}' for (k, v) in zip(self.__columns__, self, strict=True))})" class StatementExecutionExt: - """ - Execute SQL statements in a stateless manner. + """Execute SQL statements in a stateless manner. 
- Primary use-case of :py:meth:`iterate_rows` and :py:meth:`execute` methods is oriented at executing SQL queries in + Primary use-case of :py:meth:`fetch_all` and :py:meth:`execute` methods is oriented at executing SQL queries in a stateless manner straight away from Databricks SDK for Python, without requiring any external dependencies. Results are fetched in JSON format through presigned external links. This is perfect for serverless applications like AWS Lambda, Azure Functions, or any other containerised short-lived applications, where container startup time is faster with the smaller dependency set. - .. code-block: + >>> for (pickup_zip, dropoff_zip) in see('SELECT pickup_zip, dropoff_zip FROM samples.nyctaxi.trips LIMIT 10'): + >>> print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') - for (pickup_zip, dropoff_zip) in w.statement_execution.iterate_rows(warehouse_id, - 'SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10', catalog='samples'): - print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') - - Method :py:meth:`iterate_rows` returns an iterator of objects, that resemble :class:`pyspark.sql.Row` APIs, but full - compatibility is not the goal of this implementation. - - .. code-block:: - - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') + Method :py:meth:`fetch_all` returns an iterator of objects, that resemble :class:`pyspark.sql.Row` APIs, but full + compatibility is not the goal of this implementation. See :class:`Row` for more details. When you only need to execute the query and have no need to iterate over results, use the :py:meth:`execute`. - .. code-block:: - - w.statement_execution.execute(warehouse_id, 'CREATE TABLE foo AS SELECT * FROM range(10)') - Applications, that need to a more traditional SQL Python APIs with cursors, efficient data transfer of hundreds of megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use - the stateful Databricks SQL Connector for Python. 
- """ + the stateful Databricks SQL Connector for Python.""" def __init__(self, ws: WorkspaceClient, disposition: Disposition | None = None, @@ -168,86 +155,6 @@ def __init__(self, ws: WorkspaceClient, ColumnInfoTypeName.TIMESTAMP: self._parse_timestamp, } - @staticmethod - def _parse_date(value: str) -> datetime.date: - """Parse date from string in ISO format.""" - year, month, day = value.split('-') - return datetime.date(int(year), int(month), int(day)) - - @staticmethod - def _parse_timestamp(value: str) -> datetime.datetime: - """Parse timestamp from string in ISO format.""" - # make it work with Python 3.7 to 3.10 as well - return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) - - @staticmethod - def _raise_if_needed(status: StatementStatus): - """Raise an exception if the statement status is failed, canceled, or closed.""" - if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: - return - status_error = status.error - if status_error is None: - status_error = ServiceError(message="unknown", error_code=ServiceErrorCode.UNKNOWN) - error_message = status_error.message - if error_message is None: - error_message = "" - if "SCHEMA_NOT_FOUND" in error_message: - raise NotFound(error_message) - if "TABLE_OR_VIEW_NOT_FOUND" in error_message: - raise NotFound(error_message) - if "DELTA_TABLE_NOT_FOUND" in error_message: - raise NotFound(error_message) - if "DELTA_MISSING_TRANSACTION_LOG" in error_message: - raise DataLoss(error_message) - mapping = { - ServiceErrorCode.ABORTED: errors.Aborted, - ServiceErrorCode.ALREADY_EXISTS: errors.AlreadyExists, - ServiceErrorCode.BAD_REQUEST: errors.BadRequest, - ServiceErrorCode.CANCELLED: errors.Cancelled, - ServiceErrorCode.DEADLINE_EXCEEDED: errors.DeadlineExceeded, - ServiceErrorCode.INTERNAL_ERROR: errors.InternalError, - ServiceErrorCode.IO_ERROR: errors.InternalError, - ServiceErrorCode.NOT_FOUND: errors.NotFound, - ServiceErrorCode.RESOURCE_EXHAUSTED: errors.ResourceExhausted, - ServiceErrorCode.SERVICE_UNDER_MAINTENANCE: errors.TemporarilyUnavailable, - ServiceErrorCode.TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, - ServiceErrorCode.UNAUTHENTICATED: errors.Unauthenticated, - ServiceErrorCode.UNKNOWN: errors.Unknown, - ServiceErrorCode.WORKSPACE_TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, - } - error_code = status_error.error_code - if error_code is None: - error_code = ServiceErrorCode.UNKNOWN - error_class = mapping.get(error_code, errors.Unknown) - raise error_class(error_message) - - def _default_warehouse(self) -> str: - """Get the default warehouse id from the workspace client configuration - or DATABRICKS_WAREHOUSE_ID environment variable. If not set, it will use - the first available warehouse that is not in the DELETED or DELETING state.""" - with self._lock: - if self._warehouse_id: - return self._warehouse_id - # if we create_autospec(WorkspaceClient), the warehouse_id is a MagicMock - if isinstance(self._ws.config.warehouse_id, str) and self._ws.config.warehouse_id: - self._warehouse_id = self._ws.config.warehouse_id - return self._ws.config.warehouse_id - ids = [] - for v in self._ws.warehouses.list(): - if v.state in [State.DELETED, State.DELETING]: - continue - elif v.state == State.RUNNING: - self._ws.config.warehouse_id = v.id - return self._ws.config.warehouse_id - ids.append(v.id) - if len(ids) > 0: - # otherwise - first warehouse - self._ws.config.warehouse_id = ids[0] - return self._ws.config.warehouse_id - raise ValueError("no warehouse_id=... 
given, " - "neither it is set in the WorkspaceClient(..., warehouse_id=...), " - "nor in the DATABRICKS_WAREHOUSE_ID environment variable") - def execute( self, statement: str, @@ -258,36 +165,40 @@ def execute( schema: str | None = None, timeout: timedelta | None = None, ) -> ExecuteStatementResponse: - """(Experimental) Execute a SQL statement and block until results are ready, - including starting the warehouse if needed. + """Execute a SQL statement and block until results are ready, including starting + the warehouse if needed. This is a high-level implementation that works with fetching records in JSON format. It can be considered as a quick way to run SQL queries by just depending on Databricks SDK for Python without the need of any other compiled library dependencies. - This method is a higher-level wrapper over :py:meth:`execute_statement` and fetches results + This method is a higher-level wrapper over Databricks SDK and fetches results in JSON format through the external link disposition, with client-side polling until the statement succeeds in execution. Whenever the statement is failed, cancelled, or - closed, this method raises `DatabricksError` with the state message and the relevant - error code. + closed, this method raises `DatabricksError` subclass with the state message and + the relevant error code. - To seamlessly iterate over the rows from query results, please use :py:meth:`iterate_rows`. + To seamlessly iterate over the rows from query results, please use :py:meth:`fetch_all`. - :param warehouse_id: str - Warehouse upon which to execute a statement. :param statement: str SQL statement to execute + :param warehouse_id: str (optional) + Warehouse upon which to execute a statement. If not given, it will use the warehouse specified + in the constructor or the first available warehouse that is not in the DELETED or DELETING state. :param byte_limit: int (optional) - Applies the given byte limit to the statement's result size. Byte counts are based on internal - representations and may not match measurable sizes in the JSON format. + Applies the given byte limit to the statement's result size. Byte counts are based on internal + representations and may not match measurable sizes in the JSON format. :param catalog: str (optional) - Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. + Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. If not given, + it will use the default catalog or the catalog specified in the constructor. :param schema: str (optional) - Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. + Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. If not given, + it will use the default schema or the schema specified in the constructor. :param timeout: timedelta (optional) - Timeout after which the query is cancelled. If timeout is less than 50 seconds, - it is handled on the server side. If the timeout is greater than 50 seconds, - Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. + Timeout after which the query is cancelled. If timeout is less than 50 seconds, + it is handled on the server side. If the timeout is greater than 50 seconds, + Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. + If not given, it will use the timeout specified in the constructor. 
:return: ExecuteStatementResponse """ # The wait_timeout field must be 0 seconds (disables wait), @@ -346,16 +257,8 @@ def execute( msg = f"timed out after {timeout}: {status_message}" raise TimeoutError(msg) - def _statement_timeouts(self, timeout): - if timeout is None: - timeout = self._timeout - wait_timeout = None - if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: - # set server-side timeout - wait_timeout = f"{timeout.total_seconds()}s" - return timeout, wait_timeout - def __call__(self, statement: str): + """Execute a SQL statement and block until results are ready.""" yield from self.fetch_all(statement) def fetch_all( @@ -375,31 +278,27 @@ def fetch_all( a returned iterator. Every row API resembles those of :class:`pyspark.sql.Row`, but full compatibility is not the goal of this implementation. - .. code-block:: + >>> ws = WorkspaceClient(...) + >>> see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"), catalog="samples") + >>> for row in see("SELECT * FROM nyctaxi.trips LIMIT 10"): + >>> pickup_time, dropoff_time = row[0], row[1] + >>> pickup_zip = row.pickup_zip + >>> dropoff_zip = row["dropoff_zip"] + >>> all_fields = row.as_dict() + >>> logger.info(f"{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}") - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') - - :param warehouse_id: str - Warehouse upon which to execute a statement. :param statement: str - SQL statement to execute + SQL statement to execute + :param warehouse_id: str (optional) + Warehouse upon which to execute a statement. See :py:meth:`execute` for more details. :param byte_limit: int (optional) - Applies the given byte limit to the statement's result size. Byte counts are based on internal - representations and may not match measurable sizes in the JSON format. + Result-size limit in bytes. See :py:meth:`execute` for more details. :param catalog: str (optional) - Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. + Catalog for statement execution. See :py:meth:`execute` for more details. :param schema: str (optional) - Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. + Schema for statement execution. See :py:meth:`execute` for more details. :param timeout: timedelta (optional) - Timeout after which the query is cancelled. If timeout is less than 50 seconds, - it is handled on the server side. If the timeout is greater than 50 seconds, - Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. + Timeout after which the query is cancelled. See :py:meth:`execute` for more details. :return: Iterator[Row] """ execute_response = self.execute(statement, @@ -436,20 +335,20 @@ def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Ro This method is a wrapper over :py:meth:`fetch_all` and fetches only the first row from the result set. If no records are available, it returns `None`. - .. 
code-block:: - - row = see.fetch_one('SELECT * FROM samples.nyctaxi.trips LIMIT 1') - if row: - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') + >>> row = see.fetch_one('SELECT * FROM samples.nyctaxi.trips LIMIT 1') + >>> if row: + >>> pickup_time, dropoff_time = row[0], row[1] + >>> pickup_zip = row.pickup_zip + >>> dropoff_zip = row['dropoff_zip'] + >>> all_fields = row.as_dict() + >>> print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') :param statement: str SQL statement to execute :param disable_magic: bool (optional) Disables the magic of adding `LIMIT 1` to the statement. By default, it is `False`. + :param kwargs: dict + Additional keyword arguments to pass to :py:meth:`fetch_all` :return: Row | None """ disable_magic = disable_magic or self._disable_magic @@ -465,9 +364,99 @@ def fetch_value(self, statement: str, **kwargs) -> Any | None: return v return None + def _statement_timeouts(self, timeout): + """Set server-side and client-side timeouts for statement execution.""" + if timeout is None: + timeout = self._timeout + wait_timeout = None + if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: + # set server-side timeout + wait_timeout = f"{timeout.total_seconds()}s" + return timeout, wait_timeout + + @staticmethod + def _parse_date(value: str) -> datetime.date: + """Parse date from string in ISO format.""" + year, month, day = value.split('-') + return datetime.date(int(year), int(month), int(day)) + + @staticmethod + def _parse_timestamp(value: str) -> datetime.datetime: + """Parse timestamp from string in ISO format.""" + # make it work with Python 3.7 to 3.10 as well + return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) + + @staticmethod + def _raise_if_needed(status: StatementStatus): + """Raise an exception if the statement status is failed, canceled, or closed.""" + if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: + return + status_error = status.error + if status_error is None: + status_error = ServiceError(message="unknown", error_code=ServiceErrorCode.UNKNOWN) + error_message = status_error.message + if error_message is None: + error_message = "" + if "SCHEMA_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "TABLE_OR_VIEW_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "DELTA_TABLE_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "DELTA_MISSING_TRANSACTION_LOG" in error_message: + raise DataLoss(error_message) + mapping = { + ServiceErrorCode.ABORTED: errors.Aborted, + ServiceErrorCode.ALREADY_EXISTS: errors.AlreadyExists, + ServiceErrorCode.BAD_REQUEST: errors.BadRequest, + ServiceErrorCode.CANCELLED: errors.Cancelled, + ServiceErrorCode.DEADLINE_EXCEEDED: errors.DeadlineExceeded, + ServiceErrorCode.INTERNAL_ERROR: errors.InternalError, + ServiceErrorCode.IO_ERROR: errors.InternalError, + ServiceErrorCode.NOT_FOUND: errors.NotFound, + ServiceErrorCode.RESOURCE_EXHAUSTED: errors.ResourceExhausted, + ServiceErrorCode.SERVICE_UNDER_MAINTENANCE: errors.TemporarilyUnavailable, + ServiceErrorCode.TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, + ServiceErrorCode.UNAUTHENTICATED: errors.Unauthenticated, + ServiceErrorCode.UNKNOWN: errors.Unknown, + ServiceErrorCode.WORKSPACE_TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, 
+ } + error_code = status_error.error_code + if error_code is None: + error_code = ServiceErrorCode.UNKNOWN + error_class = mapping.get(error_code, errors.Unknown) + raise error_class(error_message) + + def _default_warehouse(self) -> str: + """Get the default warehouse id from the workspace client configuration + or DATABRICKS_WAREHOUSE_ID environment variable. If not set, it will use + the first available warehouse that is not in the DELETED or DELETING state.""" + with self._lock: + if self._warehouse_id: + return self._warehouse_id + # if we create_autospec(WorkspaceClient), the warehouse_id is a MagicMock + if isinstance(self._ws.config.warehouse_id, str) and self._ws.config.warehouse_id: + self._warehouse_id = self._ws.config.warehouse_id + return self._ws.config.warehouse_id + ids = [] + for v in self._ws.warehouses.list(): + if v.state in [State.DELETED, State.DELETING]: + continue + elif v.state == State.RUNNING: + self._ws.config.warehouse_id = v.id + return self._ws.config.warehouse_id + ids.append(v.id) + if len(ids) > 0: + # otherwise - first warehouse + self._ws.config.warehouse_id = ids[0] + return self._ws.config.warehouse_id + raise ValueError("no warehouse_id=... given, " + "neither it is set in the WorkspaceClient(..., warehouse_id=...), " + "nor in the DATABRICKS_WAREHOUSE_ID environment variable") + @staticmethod def _add_limit(statement: str) -> str: - # parse with sqlglot if there's a limit statement + """Add a limit 1 to the statement if it does not have one already.""" statements = sqlglot.parse(statement, read="databricks") if not statements: raise ValueError(f"cannot parse statement: {statement}") @@ -481,6 +470,7 @@ def _add_limit(statement: str) -> str: return statement def _result_schema(self, execute_response: ExecuteStatementResponse): + """Get the result schema from the execute response.""" manifest = execute_response.manifest if not manifest: msg = f"missing manifest: {execute_response}" From 4f379076f8636d70a9e3bbd937a0213508a1b3b5 Mon Sep 17 00:00:00 2001 From: Serge Smertin Date: Mon, 11 Mar 2024 16:54:19 +0100 Subject: [PATCH 10/10] .. 
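
This patch is largely a formatting and typing pass (import ordering, black-style reflow, assertions for the type checker), but the reworked integration tests also show the intended top-level API of StatementExecutionExt. A minimal usage sketch, assuming a reachable workspace and a SQL warehouse; the environment variable name below is illustrative:

    import os

    from databricks.sdk import WorkspaceClient

    from databricks.labs.lsql.core import StatementExecutionExt

    ws = WorkspaceClient()  # credentials are resolved from the environment
    see = StatementExecutionExt(ws, warehouse_id=os.environ["DATABRICKS_WAREHOUSE_ID"], catalog="samples")

    # the instance is callable as a shorthand for fetch_all()
    for pickup_zip, dropoff_zip in see("SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10"):
        print(pickup_zip, dropoff_zip)

    row = see.fetch_one("SELECT * FROM nyctaxi.trips")             # LIMIT 1 is added unless disabled
    count = see.fetch_value("SELECT COUNT(*) FROM nyctaxi.trips")  # first value of the first row
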
--- src/databricks/labs/lsql/__init__.py | 1 - src/databricks/labs/lsql/backends.py | 21 ++- src/databricks/labs/lsql/core.py | 79 ++++++----- tests/integration/conftest.py | 6 +- tests/integration/test_integration.py | 29 ++-- tests/unit/test_backends.py | 182 ++++++++++++++------------ tests/unit/test_core.py | 79 +++++------ tests/unit/test_deployment.py | 14 +- 8 files changed, 218 insertions(+), 193 deletions(-) diff --git a/src/databricks/labs/lsql/__init__.py b/src/databricks/labs/lsql/__init__.py index a8cb9877..ba43b3c2 100644 --- a/src/databricks/labs/lsql/__init__.py +++ b/src/databricks/labs/lsql/__init__.py @@ -1,4 +1,3 @@ from .core import Row __all__ = ["Row"] - diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py index 9ae17b97..d9d0c64b 100644 --- a/src/databricks/labs/lsql/backends.py +++ b/src/databricks/labs/lsql/backends.py @@ -155,7 +155,7 @@ def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: D self.execute(sql) @staticmethod - def _row_to_sql(row: DataclassInstance, fields: list[dataclasses.Field]): + def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]): data = [] for f in fields: value = getattr(row, f.name) @@ -215,8 +215,11 @@ def __init__(self, debug_truncate_bytes: int | None = None): msg = "Not in the Databricks Runtime" raise RuntimeError(msg) try: - # pylint: disable-next=import-error,import-outside-toplevel - from pyspark.sql.session import SparkSession # type: ignore[import-not-found] + # pylint: disable-next=import-error,import-outside-toplevel,useless-suppression + from pyspark.sql.session import ( # type: ignore[import-not-found] + SparkSession, + ) + super().__init__(SparkSession.builder.getOrCreate(), debug_truncate_bytes) except ImportError as e: raise RuntimeError("pyspark is not available") from e @@ -225,7 +228,11 @@ def __init__(self, debug_truncate_bytes: int | None = None): class DatabricksConnectBackend(_SparkBackend): def __init__(self, ws: WorkspaceClient): try: - from databricks.connect import DatabricksSession + # pylint: disable-next=import-outside-toplevel + from databricks.connect import ( # type: ignore[import-untyped] + DatabricksSession, + ) + spark = DatabricksSession.builder().sdk_config(ws.config).getOrCreate() super().__init__(spark, ws.config.debug_truncate_bytes) except ImportError as e: @@ -233,7 +240,9 @@ def __init__(self, ws: WorkspaceClient): class MockBackend(SqlBackend): - def __init__(self, *, fails_on_first: dict[str,str] | None = None, rows: dict | None = None, debug_truncate_bytes=96): + def __init__( + self, *, fails_on_first: dict[str, str] | None = None, rows: dict | None = None, debug_truncate_bytes=96 + ): self._fails_on_first = fails_on_first if not rows: rows = {} @@ -286,4 +295,4 @@ def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance] @staticmethod def _row_factory(klass: Dataclass) -> type: - return Row.factory([f.name for f in dataclasses.fields(klass)]) \ No newline at end of file + return Row.factory([f.name for f in dataclasses.fields(klass)]) diff --git a/src/databricks/labs/lsql/core.py b/src/databricks/labs/lsql/core.py index e1b7ed41..fa048c09 100644 --- a/src/databricks/labs/lsql/core.py +++ b/src/databricks/labs/lsql/core.py @@ -1,15 +1,14 @@ import base64 import datetime -import functools import json import logging import random import threading import time import types -from collections.abc import Iterator +from collections.abc import Callable, Iterator from datetime import 
timedelta -from typing import Any, Callable +from typing import Any import requests import sqlglot @@ -20,11 +19,11 @@ Disposition, ExecuteStatementResponse, Format, - ResultData, ServiceError, ServiceErrorCode, + State, StatementState, - StatementStatus, State, + StatementStatus, ) MAX_SLEEP_PER_ATTEMPT = 10 @@ -38,6 +37,7 @@ class Row(tuple): """Row is a tuple with named fields that resembles PySpark's SQL Row API.""" + def __new__(cls, *args, **kwargs): """Create a new instance of Row.""" if args and kwargs: @@ -47,7 +47,7 @@ def __new__(cls, *args, **kwargs): row = tuple.__new__(cls, list(kwargs.values())) row.__columns__ = list(kwargs.keys()) return row - if len(args) == 1 and hasattr(cls, '__columns__') and isinstance(args[0], (types.GeneratorType, list, tuple)): + if len(args) == 1 and hasattr(cls, "__columns__") and isinstance(args[0], (types.GeneratorType, list, tuple)): # this type returned by Row.factory() and we already know the column names return cls(*args[0]) if len(args) == 2 and isinstance(args[0], (list, tuple)) and isinstance(args[1], (list, tuple)): @@ -115,19 +115,21 @@ class StatementExecutionExt: megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use the stateful Databricks SQL Connector for Python.""" - def __init__(self, ws: WorkspaceClient, - disposition: Disposition | None = None, - warehouse_id: str | None = None, - byte_limit: int | None = None, - catalog: str | None = None, - schema: str | None = None, - timeout: timedelta = timedelta(minutes=20), - disable_magic: bool = False, - http_session_factory: Callable[[], requests.Session] | None = None): + def __init__( # pylint: disable=too-many-arguments + self, + ws: WorkspaceClient, + disposition: Disposition | None = None, + warehouse_id: str | None = None, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta = timedelta(minutes=20), + disable_magic: bool = False, + http_session_factory: Callable[[], requests.Session] | None = None, + ): if not http_session_factory: http_session_factory = requests.Session self._ws = ws - self._api = ws.api_client self._http = http_session_factory() self._lock = threading.Lock() self._warehouse_id = warehouse_id @@ -301,15 +303,13 @@ def fetch_all( Timeout after which the query is cancelled. See :py:meth:`execute` for more details. 
:return: Iterator[Row] """ - execute_response = self.execute(statement, - warehouse_id=warehouse_id, - byte_limit=byte_limit, - catalog=catalog, - schema=schema, - timeout=timeout) + execute_response = self.execute( + statement, warehouse_id=warehouse_id, byte_limit=byte_limit, catalog=catalog, schema=schema, timeout=timeout + ) + assert execute_response.statement_id is not None result_data = execute_response.result if result_data is None: - return [] + return row_factory, col_conv = self._result_schema(execute_response) while True: if result_data.data_array: @@ -318,6 +318,7 @@ def fetch_all( next_chunk_index = result_data.next_chunk_index if result_data.external_links: for external_link in result_data.external_links: + assert external_link.external_link is not None next_chunk_index = external_link.next_chunk_index response = self._http.get(external_link.external_link) response.raise_for_status() @@ -326,8 +327,8 @@ def fetch_all( if not next_chunk_index: return result_data = self._ws.statement_execution.get_statement_result_chunk_n( - execute_response.statement_id, - next_chunk_index) + execute_response.statement_id, next_chunk_index + ) def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Row | None: """Execute a query and fetch the first available record. @@ -360,11 +361,11 @@ def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Ro def fetch_value(self, statement: str, **kwargs) -> Any | None: """Execute a query and fetch the first available value.""" - for v, in self.fetch_all(statement, **kwargs): + for (v,) in self.fetch_all(statement, **kwargs): return v return None - def _statement_timeouts(self, timeout): + def _statement_timeouts(self, timeout) -> tuple[timedelta, str | None]: """Set server-side and client-side timeouts for statement execution.""" if timeout is None: timeout = self._timeout @@ -372,19 +373,20 @@ def _statement_timeouts(self, timeout): if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: # set server-side timeout wait_timeout = f"{timeout.total_seconds()}s" + assert timeout is not None return timeout, wait_timeout @staticmethod def _parse_date(value: str) -> datetime.date: """Parse date from string in ISO format.""" - year, month, day = value.split('-') + year, month, day = value.split("-") return datetime.date(int(year), int(month), int(day)) @staticmethod def _parse_timestamp(value: str) -> datetime.datetime: """Parse timestamp from string in ISO format.""" # make it work with Python 3.7 to 3.10 as well - return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) + return datetime.datetime.fromisoformat(value.replace("Z", "+00:00")) @staticmethod def _raise_if_needed(status: StatementStatus): @@ -440,9 +442,10 @@ def _default_warehouse(self) -> str: return self._ws.config.warehouse_id ids = [] for v in self._ws.warehouses.list(): + assert v.id is not None if v.state in [State.DELETED, State.DELETING]: continue - elif v.state == State.RUNNING: + if v.state == State.RUNNING: self._ws.config.warehouse_id = v.id return self._ws.config.warehouse_id ids.append(v.id) @@ -450,9 +453,11 @@ def _default_warehouse(self) -> str: # otherwise - first warehouse self._ws.config.warehouse_id = ids[0] return self._ws.config.warehouse_id - raise ValueError("no warehouse_id=... given, " - "neither it is set in the WorkspaceClient(..., warehouse_id=...), " - "nor in the DATABRICKS_WAREHOUSE_ID environment variable") + raise ValueError( + "no warehouse_id=... 
given, " + "neither it is set in the WorkspaceClient(..., warehouse_id=...), " + "nor in the DATABRICKS_WAREHOUSE_ID environment variable" + ) @staticmethod def _add_limit(statement: str) -> str: @@ -463,10 +468,10 @@ def _add_limit(statement: str) -> str: statement_ast = statements[0] if isinstance(statement_ast, sqlglot.expressions.Select): if statement_ast.limit is not None: - limit = statement_ast.args.get('limit', None) - if limit and limit.text('expression') != '1': + limit = statement_ast.args.get("limit", None) + if limit and limit.text("expression") != "1": raise ValueError(f"limit is not 1: {limit.text('expression')}") - return statement_ast.limit(expression=1).sql('databricks') + return statement_ast.limit(expression=1).sql("databricks") return statement def _result_schema(self, execute_response: ExecuteStatementResponse): @@ -485,6 +490,8 @@ def _result_schema(self, execute_response: ExecuteStatementResponse): if not columns: columns = [] for col in columns: + assert col.name is not None + assert col.type_name is not None col_names.append(col.name) conv = self._type_converters.get(col.type_name, None) if conv is None: diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index d46a5952..c4db6489 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,14 +4,14 @@ import pathlib import string import sys -from typing import MutableMapping, Callable +from typing import Callable, MutableMapping import pytest +from databricks.labs.blueprint.logger import install_logger from databricks.sdk import WorkspaceClient from pytest import fixture from databricks.labs.lsql.__about__ import __version__ -from databricks.labs.blueprint.logger import install_logger install_logger() logging.getLogger("databricks").setLevel("DEBUG") @@ -89,4 +89,4 @@ def inner(var: str) -> str: skip(f"Environment variable {var} is missing") return debug_env[var] - return inner \ No newline at end of file + return inner diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py index d37b3df3..5514058f 100644 --- a/tests/integration/test_integration.py +++ b/tests/integration/test_integration.py @@ -1,9 +1,9 @@ import logging import pytest +from databricks.sdk.service.sql import Disposition from databricks.labs.lsql.core import StatementExecutionExt -from databricks.sdk.service.sql import Disposition logger = logging.getLogger(__name__) @@ -21,21 +21,20 @@ def test_sql_execution(ws, env_or_skip): results = [] see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID")) for pickup_zip, dropoff_zip in see.fetch_all( - "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", - catalog="samples" + "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", catalog="samples" ): results.append((pickup_zip, dropoff_zip)) assert results == [ - (10282, 10171), - (10110, 10110), - (10103, 10023), - (10022, 10017), - (10110, 10282), - (10009, 10065), - (10153, 10199), - (10112, 10069), - (10023, 10153), - (10012, 10003) + (10282, 10171), + (10110, 10110), + (10103, 10023), + (10022, 10017), + (10110, 10282), + (10009, 10065), + (10153, 10199), + (10112, 10069), + (10023, 10153), + (10012, 10003), ] @@ -59,7 +58,7 @@ def test_sql_execution_partial(ws, env_or_skip): (10153, 10199), (10112, 10069), (10023, 10153), - (10012, 10003) + (10012, 10003), ] @@ -83,4 +82,4 @@ def test_fetch_one_works(ws): def test_fetch_value(ws): see = StatementExecutionExt(ws) count = see.fetch_value("SELECT COUNT(*) FROM 
samples.nyctaxi.trips") - assert count == 21932 \ No newline at end of file + assert count == 21932 diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py index 7608f969..f6614663 100644 --- a/tests/unit/test_backends.py +++ b/tests/unit/test_backends.py @@ -2,14 +2,10 @@ import sys from dataclasses import dataclass from unittest import mock -from unittest.mock import create_autospec, MagicMock +from unittest.mock import MagicMock, create_autospec import pytest -from databricks.labs.lsql import Row -from databricks.labs.lsql.backends import StatementExecutionBackend, RuntimeBackend, MockBackend from databricks.sdk import WorkspaceClient -from databricks.sdk.service.sql import ExecuteStatementResponse, StatementStatus, StatementState, Disposition, Format, \ - ResultManifest, ResultSchema, ColumnInfo, ColumnInfoTypeName, ResultData from databricks.sdk.errors import ( BadRequest, DataLoss, @@ -17,6 +13,24 @@ PermissionDenied, Unknown, ) +from databricks.sdk.service.sql import ( + ColumnInfo, + ColumnInfoTypeName, + ExecuteStatementResponse, + Format, + ResultData, + ResultManifest, + ResultSchema, + StatementState, + StatementStatus, +) + +from databricks.labs.lsql import Row +from databricks.labs.lsql.backends import ( + MockBackend, + RuntimeBackend, + StatementExecutionBackend, +) # pylint: disable=protected-access @@ -67,9 +81,7 @@ def test_statement_execution_backend_fetch_happy(): ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), result=ResultData(data_array=[["1"], ["2"], ["3"]]), statement_id="bcd", ) @@ -101,7 +113,7 @@ def test_statement_execution_backend_save_table_empty_records(): ws.statement_execution.execute_statement.assert_called_with( warehouse_id="abc", statement="CREATE TABLE IF NOT EXISTS a.b.c " - "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA", + "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA", catalog=None, schema=None, disposition=None, @@ -122,28 +134,30 @@ def test_statement_execution_backend_save_table_two_records(): seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) - ws.statement_execution.execute_statement.assert_has_calls([ - mock.call( - warehouse_id="abc", - statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", - catalog=None, - schema=None, - disposition=None, - format=Format.JSON_ARRAY, - byte_limit=None, - wait_timeout=None, - ), - mock.call( - warehouse_id="abc", - statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", - catalog=None, - schema=None, - disposition=None, - format=Format.JSON_ARRAY, - byte_limit=None, - wait_timeout=None, - ), - ]) + ws.statement_execution.execute_statement.assert_has_calls( + [ + mock.call( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", + catalog=None, + schema=None, + disposition=None, + 
format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + ] + ) def test_statement_execution_backend_save_table_in_batches_of_two(mocker): @@ -157,38 +171,40 @@ def test_statement_execution_backend_save_table_in_batches_of_two(mocker): seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo) - ws.statement_execution.execute_statement.assert_has_calls([ - mock.call( - warehouse_id="abc", - statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", - catalog=None, - schema=None, - disposition=None, - format=Format.JSON_ARRAY, - byte_limit=None, - wait_timeout=None, - ), - mock.call( - warehouse_id="abc", - statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", - catalog=None, - schema=None, - disposition=None, - format=Format.JSON_ARRAY, - byte_limit=None, - wait_timeout=None, - ), - mock.call( - warehouse_id="abc", - statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)", - catalog=None, - schema=None, - disposition=None, - format=Format.JSON_ARRAY, - byte_limit=None, - wait_timeout=None, - ), - ]) + ws.statement_execution.execute_statement.assert_has_calls( + [ + mock.call( + warehouse_id="abc", + statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + mock.call( + warehouse_id="abc", + statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)", + catalog=None, + schema=None, + disposition=None, + format=Format.JSON_ARRAY, + byte_limit=None, + wait_timeout=None, + ), + ] + ) def test_runtime_backend_execute(): @@ -270,21 +286,25 @@ def test_save_table_with_not_null_constraint_violated(): runtime_backend = RuntimeBackend() - with pytest.raises(Exception, - match="Not null constraint violated for column key, row = {'key': None, 'value': 'value'}"): + with pytest.raises( + Exception, match="Not null constraint violated for column key, row = {'key': None, 'value': 'value'}" + ): runtime_backend.save_table("a.b.c", rows, DummyClass) -@pytest.mark.parametrize('msg,err_t', [ - ("SCHEMA_NOT_FOUND foo schema does not exist", NotFound), - (".. TABLE_OR_VIEW_NOT_FOUND ..", NotFound), - (".. UNRESOLVED_COLUMN.WITH_SUGGESTION ..", BadRequest), - ("DELTA_TABLE_NOT_FOUND foo table does not exist", NotFound), - ("DELTA_MISSING_TRANSACTION_LOG foo table does not exist", DataLoss), - ("PARSE_SYNTAX_ERROR foo", BadRequest), - ("foo Operation not allowed", PermissionDenied), - ("foo error failure", Unknown), -]) +@pytest.mark.parametrize( + "msg,err_t", + [ + ("SCHEMA_NOT_FOUND foo schema does not exist", NotFound), + (".. TABLE_OR_VIEW_NOT_FOUND ..", NotFound), + (".. 
UNRESOLVED_COLUMN.WITH_SUGGESTION ..", BadRequest), + ("DELTA_TABLE_NOT_FOUND foo table does not exist", NotFound), + ("DELTA_MISSING_TRANSACTION_LOG foo table does not exist", DataLoss), + ("PARSE_SYNTAX_ERROR foo", BadRequest), + ("foo Operation not allowed", PermissionDenied), + ("foo error failure", Unknown), + ], +) def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t): with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}): pyspark_sql_session = MagicMock() @@ -303,16 +323,14 @@ def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t def test_mock_backend_fails_on_first(): - mock_backend = MockBackend(fails_on_first={"CREATE": '.. DELTA_TABLE_NOT_FOUND ..'}) + mock_backend = MockBackend(fails_on_first={"CREATE": ".. DELTA_TABLE_NOT_FOUND .."}) with pytest.raises(NotFound): mock_backend.execute("CREATE TABLE foo") def test_mock_backend_rows(): - mock_backend = MockBackend(rows={ - r"SELECT id FROM range\(3\)": [Row(id=1), Row(id=2), Row(id=3)] - }) + mock_backend = MockBackend(rows={r"SELECT id FROM range\(3\)": [Row(id=1), Row(id=2), Row(id=3)]}) result = list(mock_backend.fetch("SELECT id FROM range(3)")) @@ -325,6 +343,6 @@ def test_mock_backend_save_table(): mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo) assert mock_backend.rows_written_for("a.b.c", "append") == [ - Row(first='aaa', second=True), - Row(first='bbb', second=False), - ] \ No newline at end of file + Row(first="aaa", second=True), + Row(first="bbb", second=False), + ] diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py index d90b6c33..f4115c3d 100644 --- a/tests/unit/test_core.py +++ b/tests/unit/test_core.py @@ -1,14 +1,13 @@ import datetime -import requests from unittest.mock import create_autospec import pytest -from databricks.sdk import WorkspaceClient -from databricks.sdk import errors +import requests +from databricks.sdk import WorkspaceClient, errors from databricks.sdk.service.sql import ( ColumnInfo, ColumnInfoTypeName, - Disposition, + EndpointInfo, ExecuteStatementResponse, ExternalLink, Format, @@ -18,18 +17,22 @@ ResultSchema, ServiceError, ServiceErrorCode, + State, StatementState, StatementStatus, - timedelta, EndpointInfo, State, + timedelta, ) from databricks.labs.lsql.core import Row, StatementExecutionExt -@pytest.mark.parametrize('row', [ - Row(foo="bar", enabled=True), - Row(['foo', 'enabled'], ['bar', True]), -]) +@pytest.mark.parametrize( + "row", + [ + Row(foo="bar", enabled=True), + Row(["foo", "enabled"], ["bar", True]), + ], +) def test_row_from_kwargs(row): assert row.foo == "bar" assert row["foo"] == "bar" @@ -55,7 +58,7 @@ def test_row_factory(): def test_row_factory_with_generator(): factory = Row.factory(["a", "b"]) - row = factory(_+1 for _ in range(2)) + row = factory(_ + 1 for _ in range(2)) a, b = row assert a == 1 assert b == 2 @@ -182,17 +185,20 @@ def test_execute_poll_succeeds(): ws.statement_execution.get_statement.assert_called_with("bcd") -@pytest.mark.parametrize("status_error,platform_error_type", [ - (None, errors.Unknown), - (ServiceError(), errors.Unknown), - (ServiceError(message="..."), errors.Unknown), - (ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="..."), errors.ResourceExhausted), - (ServiceError(message="... SCHEMA_NOT_FOUND ..."), errors.NotFound), - (ServiceError(message="... TABLE_OR_VIEW_NOT_FOUND ..."), errors.NotFound), - (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), - (ServiceError(message="... 
DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), - (ServiceError(message="... DELTA_MISSING_TRANSACTION_LOG ..."), errors.DataLoss), -]) +@pytest.mark.parametrize( + "status_error,platform_error_type", + [ + (None, errors.Unknown), + (ServiceError(), errors.Unknown), + (ServiceError(message="..."), errors.Unknown), + (ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="..."), errors.ResourceExhausted), + (ServiceError(message="... SCHEMA_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... TABLE_OR_VIEW_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound), + (ServiceError(message="... DELTA_MISSING_TRANSACTION_LOG ..."), errors.DataLoss), + ], +) def test_execute_fails(status_error, platform_error_type): ws = create_autospec(WorkspaceClient) @@ -206,6 +212,7 @@ def test_execute_fails(status_error, platform_error_type): with pytest.raises(platform_error_type): see.execute("SELECT 2+2") + def test_execute_poll_waits(): ws = create_autospec(WorkspaceClient) @@ -278,9 +285,7 @@ def test_fetch_all_no_chunks(): see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session) - rows = list( - see.fetch_all("SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)") - ) + rows = list(see.fetch_all("SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)")) assert len(rows) == 2 assert rows[0].id == 1 @@ -358,9 +363,7 @@ def test_fetch_one(): ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), result=ResultData(data_array=[["4"]]), statement_id="bcd", ) @@ -388,9 +391,7 @@ def test_fetch_one_none(): ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), statement_id="bcd", ) @@ -406,9 +407,7 @@ def test_fetch_one_disable_magic(): ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), result=ResultData(data_array=[["4"], ["5"], ["6"]]), statement_id="bcd", ) @@ -436,9 +435,7 @@ def test_fetch_value(): ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), result=ResultData(data_array=[["4"]]), statement_id="bcd", ) @@ -455,9 +452,7 @@ def test_fetch_value_none(): ws.statement_execution.execute_statement.return_value = 
ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), statement_id="bcd", ) @@ -473,9 +468,7 @@ def test_callable_returns_iterator(): ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse( status=StatementStatus(state=StatementState.SUCCEEDED), - manifest=ResultManifest( - schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)]) - ), + manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])), result=ResultData(data_array=[["4"], ["5"], ["6"]]), statement_id="bcd", ) @@ -485,4 +478,4 @@ def test_callable_returns_iterator(): rows = list(see("SELECT 2+2 AS id")) assert len(rows) == 3 - assert rows == [Row(id=4), Row(id=5), Row(id=6)] \ No newline at end of file + assert rows == [Row(id=4), Row(id=5), Row(id=6)] diff --git a/tests/unit/test_deployment.py b/tests/unit/test_deployment.py index 4651ff1c..cd6fcdbb 100644 --- a/tests/unit/test_deployment.py +++ b/tests/unit/test_deployment.py @@ -2,6 +2,7 @@ from databricks.labs.lsql.backends import MockBackend from databricks.labs.lsql.deployment import SchemaDeployer + from . import views @@ -13,10 +14,10 @@ def test_deploys_view(): mod=views, ) - deployment.deploy_view('some', 'some.sql') + deployment.deploy_view("some", "some.sql") assert mock_backend.queries == [ - 'CREATE OR REPLACE VIEW hive_metastore.inventory.some AS SELECT id, name FROM hive_metastore.inventory.something' + "CREATE OR REPLACE VIEW hive_metastore.inventory.some AS SELECT id, name FROM hive_metastore.inventory.something" ] @@ -34,12 +35,11 @@ def test_deploys_dataclass(): mod=views, ) deployment.deploy_schema() - deployment.deploy_table('foo', Foo) + deployment.deploy_table("foo", Foo) deployment.delete_schema() assert mock_backend.queries == [ - 'CREATE SCHEMA IF NOT EXISTS hive_metastore.inventory', - 'CREATE TABLE IF NOT EXISTS hive_metastore.inventory.foo (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA', - 'DROP SCHEMA IF EXISTS hive_metastore.inventory CASCADE' + "CREATE SCHEMA IF NOT EXISTS hive_metastore.inventory", + "CREATE TABLE IF NOT EXISTS hive_metastore.inventory.foo (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA", + "DROP SCHEMA IF EXISTS hive_metastore.inventory CASCADE", ] -
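
For readers following the test changes above, here is a minimal, self-contained sketch of the MockBackend pattern these reformatted unit tests exercise. It only uses APIs visible in the diffs (MockBackend, Row, save_table, rows_written_for); the Foo dataclass is a stand-in mirroring the one used in the tests, and it assumes the databricks-labs-lsql package is installed.

    from dataclasses import dataclass

    from databricks.labs.lsql import Row
    from databricks.labs.lsql.backends import MockBackend


    @dataclass
    class Foo:
        # Stand-in for the Foo dataclass used by the tests above.
        first: str
        second: bool


    # Queries passed to fetch() are matched against the regex keys given in rows=...
    backend = MockBackend(rows={r"SELECT id FROM range\(3\)": [Row(id=1), Row(id=2), Row(id=3)]})
    assert [r.id for r in backend.fetch("SELECT id FROM range(3)")] == [1, 2, 3]

    # save_table() records the written rows, which tests can assert on afterwards.
    backend.save_table("a.b.c", [Foo("aaa", True)], Foo)
    assert backend.rows_written_for("a.b.c", "append") == [Row(first="aaa", second=True)]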