diff --git a/.codegen.json b/.codegen.json new file mode 100644 index 00000000..43712dca --- /dev/null +++ b/.codegen.json @@ -0,0 +1,17 @@ +{ + "version": { + "src/databricks/labs/lsql/__about__.py": "__version__ = \"$VERSION\"" + }, + "toolchain": { + "required": ["python3"], + "pre_setup": [ + "python3 -m pip install hatch==1.7.0", + "python3 -m hatch env create" + ], + "prepend_path": ".venv/bin", + "acceptance_path": "tests/integration", + "test": [ + "pytest -n 4 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20" + ] + } +} \ No newline at end of file diff --git a/.github/workflows/acceptance.yml b/.github/workflows/acceptance.yml new file mode 100644 index 00000000..495f86d6 --- /dev/null +++ b/.github/workflows/acceptance.yml @@ -0,0 +1,45 @@ +name: acceptance + +on: + pull_request: + types: [ opened, synchronize, ready_for_review ] + +permissions: + id-token: write + contents: read + pull-requests: write + +concurrency: + group: single-acceptance-job-per-repo + +jobs: + integration: + if: github.event_name == 'pull_request' && github.event.pull_request.draft == false + environment: runtime + runs-on: larger + steps: + - name: Checkout Code + uses: actions/checkout@v2.5.0 + + - name: Unshallow + run: git fetch --prune --unshallow + + - name: Install Python + uses: actions/setup-python@v4 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Install hatch + run: pip install hatch==1.7.0 + + - name: Run integration tests + uses: databrickslabs/sandbox/acceptance@acceptance/v0.1.4 + with: + vault_uri: ${{ secrets.VAULT_URI }} + timeout: 45m + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + ARM_CLIENT_ID: ${{ secrets.ARM_CLIENT_ID }} + ARM_TENANT_ID: ${{ secrets.ARM_TENANT_ID }} diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index d8b0ad2b..0b54d8f7 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -14,14 +14,12 @@ on: branches: - main -env: - HATCH_VERSION: 1.7.0 - jobs: ci: strategy: + fail-fast: false matrix: - pyVersion: [ '3.8', '3.9', '3.10', '3.11', '3.12' ] + pyVersion: [ '3.10', '3.11', '3.12' ] runs-on: ubuntu-latest steps: - name: Checkout @@ -34,11 +32,10 @@ jobs: cache-dependency-path: '**/pyproject.toml' python-version: ${{ matrix.pyVersion }} - - name: Install hatch - run: pip install hatch==$HATCH_VERSION - - name: Run unit tests - run: hatch run unit:test + run: | + pip install hatch==1.7.0 + make test - name: Publish test coverage uses: codecov/codecov-action@v1 @@ -49,15 +46,8 @@ jobs: - name: Checkout uses: actions/checkout@v3 - - name: Install Python - uses: actions/setup-python@v4 - with: - cache: 'pip' - cache-dependency-path: '**/pyproject.toml' - python-version: 3.10.x - - - name: Install hatch - run: pip install hatch==$HATCH_VERSION + - name: Format all files + run: make dev fmt - - name: Verify linting - run: hatch run lint:verify + - name: Fail on differences + run: git diff --exit-code diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..d08f42cb --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,48 @@ +name: Release + +on: + push: + tags: + - 'v*' + +jobs: + publish: + runs-on: ubuntu-latest + environment: release + permissions: + # Used to authenticate to PyPI via OIDC and sign the release's artifacts with sigstore-python. + id-token: write + # Used to attach signing artifacts to the published release. 
+ contents: write + steps: + - uses: actions/checkout@v3 + + - uses: actions/setup-python@v4 + with: + cache: 'pip' + cache-dependency-path: '**/pyproject.toml' + python-version: '3.10' + + - name: Build wheels + run: | + pip install hatch==1.7.0 + hatch build + + - name: Draft release + uses: softprops/action-gh-release@v1 + with: + draft: true + files: | + dist/databricks_*.whl + dist/databricks_*.tar.gz + + - uses: pypa/gh-action-pypi-publish@release/v1 + name: Publish package distributions to PyPI + + - name: Sign artifacts with Sigstore + uses: sigstore/gh-action-sigstore-python@v2.1.1 + with: + inputs: | + dist/databricks_*.whl + dist/databricks_*.tar.gz + release-signing-artifacts: true \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..2aa8409e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,5 @@ +# Version changelog + +## 0.0.0 + +Initial commit diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 967fdb5d..90809b5c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,117 @@ -To setup local dev environment, you have to install `hatch` tooling: `pip install hatch`. +# Contributing -After, you have to configure your IDE with it: `hatch run python -c "import sys; print(sys.executable)" | pbcopy` \ No newline at end of file +## First Principles + +Favoring standard libraries over external dependencies, especially in specific contexts like Databricks, is a best practice in software +development. + +There are several reasons why this approach is encouraged: +- Standard libraries are typically well-vetted, thoroughly tested, and maintained by the official maintainers of the programming language or platform. This ensures a higher level of stability and reliability. +- External dependencies, especially lesser-known or unmaintained ones, can introduce bugs, security vulnerabilities, or compatibility issues that can be challenging to resolve. Adding external dependencies increases the complexity of your codebase. +- Each dependency may have its own set of dependencies, potentially leading to a complex web of dependencies that can be difficult to manage. This complexity can lead to maintenance challenges, increased risk, and longer build times. +- External dependencies can pose security risks. If a library or package has known security vulnerabilities and is widely used, it becomes an attractive target for attackers. Minimizing external dependencies reduces the potential attack surface and makes it easier to keep your code secure. +- Relying on standard libraries enhances code portability. It ensures your code can run on different platforms and environments without being tightly coupled to specific external dependencies. This is particularly important in settings like Databricks, where you may need to run your code on different clusters or setups. +- External dependencies may have their versioning schemes and compatibility issues. When using standard libraries, you have more control over versioning and can avoid conflicts between different dependencies in your project. +- Fewer external dependencies mean faster build and deployment times. Downloading, installing, and managing external packages can slow down these processes, especially in large-scale projects or distributed computing environments like Databricks. +- External dependencies can be abandoned or go unmaintained over time. This can lead to situations where your project relies on outdated or unsupported code. 
When you depend on standard libraries, you have confidence that the core functionality you rely on will continue to be maintained and improved.
+
+While minimizing external dependencies is essential, exceptions can be made on a case-by-case basis. There are situations where external dependencies are
+justified, such as when a well-established and actively maintained library provides significant benefits, like time savings, performance improvements,
+or specialized functionality unavailable in standard libraries.
+
+## Common fixes for `mypy` errors
+
+See https://mypy.readthedocs.io/en/stable/cheat_sheet_py3.html for more details.
+
+### ..., expression has type "None", variable has type "str"
+
+* Add `assert ... is not None` if it's in the body of a method. Example:
+
+```
+# error: Argument 1 to "delete" of "DashboardWidgetsAPI" has incompatible type "str | None"; expected "str"
+self._ws.dashboard_widgets.delete(widget.id)
+```
+
+after
+
+```
+assert widget.id is not None
+self._ws.dashboard_widgets.delete(widget.id)
+```
+
+* Add `... | None` if it's in a dataclass. Example: `cloud: str = None` -> `cloud: str | None = None`
+
+### ..., has incompatible type "Path"; expected "str"
+
+Add `.as_posix()` to convert a `Path` to a `str`.
+
+### Argument 2 to "get" of "dict" has incompatible type "None"; expected ...
+
+Add a valid default value for the dictionary return. Example:
+
+```python
+def viz_type(self) -> str:
+    return self.viz.get("type", None)
+```
+
+after:
+
+```python
+def viz_type(self) -> str:
+    return self.viz.get("type", "UNKNOWN")
+```
+
+## Local Setup
+
+This section provides a step-by-step guide for setting up your project environment and dependencies so you can start working on the project efficiently.
+
+To begin, run `make dev` to create the default environment and install development dependencies, assuming you've already cloned the GitHub repo.
+
+```shell
+make dev
+```
+
+Verify the installation with:
+```shell
+make test
+```
+
+Before every commit, apply consistent formatting, as we want our codebase to look consistent:
+```shell
+make fmt
+```
+
+Before every commit, run the automated bug detector (`make lint`) and the unit tests (`make test`) to ensure that the automated
+pull request checks pass before your code is reviewed by others:
+```shell
+make lint
+make test
+```
+
+## First contribution
+
+Here are the example steps to submit your first contribution:
+
+1. Fork the lsql repo (if you really want to contribute).
+2. `git clone`
+3. `git checkout main` (or `gcm` if you're using [ohmyzsh](https://ohmyz.sh/)).
+4. `git pull` (or `gl` if you're using [ohmyzsh](https://ohmyz.sh/)).
+5. `git checkout -b FEATURENAME` (or `gcb FEATURENAME` if you're using [ohmyzsh](https://ohmyz.sh/)).
+6. .. do the work
+7. `make fmt`
+8. `make lint`
+9. .. fix if any
+10. `make test`
+11. .. fix if any
+12. `git commit -a`. Make sure to enter a meaningful commit message title.
+13. `git push origin FEATURENAME`
+14. Go to the GitHub UI and create a PR. Alternatively, `gh pr create` (if you have [GitHub CLI](https://cli.github.com/) installed).
+    Use a meaningful pull request title because it'll appear in the release notes. Use `Resolves #NUMBER` in the pull
+    request description to [automatically link it](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/using-keywords-in-issues-and-pull-requests#linking-a-pull-request-to-an-issue)
+    to an existing issue.
+15.
announce PR for the review + +## Troubleshooting + +If you encounter any package dependency errors after `git pull`, run `make clean` diff --git a/LICENSE b/LICENSE index e02a93e6..c8d0d24a 100644 --- a/LICENSE +++ b/LICENSE @@ -1,25 +1,69 @@ -DB license + Databricks License + Copyright (2024) Databricks, Inc. -Copyright (2023) Databricks, Inc. + Definitions. + + Agreement: The agreement between Databricks, Inc., and you governing + the use of the Databricks Services, as that term is defined in + the Master Cloud Services Agreement (MCSA) located at + www.databricks.com/legal/mcsa. + + Licensed Materials: The source code, object code, data, and/or other + works to which this license applies. -Definitions. + Scope of Use. You may not use the Licensed Materials except in + connection with your use of the Databricks Services pursuant to + the Agreement. Your use of the Licensed Materials must comply at all + times with any restrictions applicable to the Databricks Services, + generally, and must be used in accordance with any applicable + documentation. You may view, use, copy, modify, publish, and/or + distribute the Licensed Materials solely for the purposes of using + the Licensed Materials within or connecting to the Databricks Services. + If you do not agree to these terms, you may not view, use, copy, + modify, publish, and/or distribute the Licensed Materials. + + Redistribution. You may redistribute and sublicense the Licensed + Materials so long as all use is in compliance with these terms. + In addition: + + - You must give any other recipients a copy of this License; + - You must cause any modified files to carry prominent notices + stating that you changed the files; + - You must retain, in any derivative works that you distribute, + all copyright, patent, trademark, and attribution notices, + excluding those notices that do not pertain to any part of + the derivative works; and + - If a "NOTICE" text file is provided as part of its + distribution, then any derivative works that you distribute + must include a readable copy of the attribution notices + contained within such NOTICE file, excluding those notices + that do not pertain to any part of the derivative works. -Agreement: The agreement between Databricks, Inc., and you governing the use of the Databricks Services, which shall be, with respect to Databricks, the Databricks Terms of Service located at www.databricks.com/termsofservice, and with respect to Databricks Community Edition, the Community Edition Terms of Service located at www.databricks.com/ce-termsofuse, in each case unless you have entered into a separate written agreement with Databricks governing the use of the applicable Databricks Services. + You may add your own copyright statement to your modifications and may + provide additional license terms and conditions for use, reproduction, + or distribution of your modifications, or for any such derivative works + as a whole, provided your use, reproduction, and distribution of + the Licensed Materials otherwise complies with the conditions stated + in this License. -Software: The source code and object code to which this license applies. + Termination. This license terminates automatically upon your breach of + these terms or upon the termination of your Agreement. Additionally, + Databricks may terminate this license at any time on notice. Upon + termination, you must permanently delete the Licensed Materials and + all copies thereof. -Scope of Use. 
You may not use this Software except in connection with your use of the Databricks Services pursuant to the Agreement. Your use of the Software must comply at all times with any restrictions applicable to the Databricks Services, generally, and must be used in accordance with any applicable documentation. You may view, use, copy, modify, publish, and/or distribute the Software solely for the purposes of using the code within or connecting to the Databricks Services. If you do not agree to these terms, you may not view, use, copy, modify, publish, and/or distribute the Software. + DISCLAIMER; LIMITATION OF LIABILITY. -Redistribution. You may redistribute and sublicense the Software so long as all use is in compliance with these terms. In addition: - -You must give any other recipients a copy of this License; -You must cause any modified files to carry prominent notices stating that you changed the files; -You must retain, in the source code form of any derivative works that you distribute, all copyright, patent, trademark, and attribution notices from the source code form, excluding those notices that do not pertain to any part of the derivative works; and -If the source code form includes a "NOTICE" text file as part of its distribution, then any derivative works that you distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the derivative works. -You may add your own copyright statement to your modifications and may provide additional license terms and conditions for use, reproduction, or distribution of your modifications, or for any such derivative works as a whole, provided your use, reproduction, and distribution of the Software otherwise complies with the conditions stated in this License. - -Termination. This license terminates automatically upon your breach of these terms or upon the termination of your Agreement. Additionally, Databricks may terminate this license at any time on notice. Upon termination, you must permanently delete the Software and all copies thereof. - -DISCLAIMER; LIMITATION OF LIABILITY. - -THE SOFTWARE IS PROVIDED “AS-IS” AND WITH ALL FAULTS. DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY DISCLAIMS ALL WARRANTIES RELATING TO THE SOURCE CODE, EXPRESS AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. DATABRICKS AND ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE SOURCE CODE SHALL BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. \ No newline at end of file + THE LICENSED MATERIALS ARE PROVIDED “AS-IS” AND WITH ALL FAULTS. + DATABRICKS, ON BEHALF OF ITSELF AND ITS LICENSORS, SPECIFICALLY + DISCLAIMS ALL WARRANTIES RELATING TO THE LICENSED MATERIALS, EXPRESS + AND IMPLIED, INCLUDING, WITHOUT LIMITATION, IMPLIED WARRANTIES, + CONDITIONS AND OTHER TERMS OF MERCHANTABILITY, SATISFACTORY QUALITY OR + FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. 
DATABRICKS AND + ITS LICENSORS TOTAL AGGREGATE LIABILITY RELATING TO OR ARISING OUT OF + YOUR USE OF OR DATABRICKS’ PROVISIONING OF THE LICENSED MATERIALS SHALL + BE LIMITED TO ONE THOUSAND ($1,000) DOLLARS. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR + OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, + ARISING FROM, OUT OF OR IN CONNECTION WITH THE LICENSED MATERIALS OR + THE USE OR OTHER DEALINGS IN THE LICENSED MATERIALS. diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..c2b0aa79 --- /dev/null +++ b/Makefile @@ -0,0 +1,28 @@ +all: clean lint fmt test coverage + +clean: + rm -fr .venv clean htmlcov .mypy_cache .pytest_cache .ruff_cache .coverage coverage.xml + rm -fr **/*.pyc + +.venv/bin/python: + pip install hatch==1.7.0 + hatch env create + +dev: .venv/bin/python + @hatch run which python + +lint: + hatch run verify + +fmt: + hatch run fmt + +test: + hatch run test + +integration: + hatch run integration + +coverage: + hatch run coverage && open htmlcov/index.html + diff --git a/NOTICE b/NOTICE index 04a26a7c..605bb201 100644 --- a/NOTICE +++ b/NOTICE @@ -1,4 +1,4 @@ -Copyright (2023) Databricks, Inc. +Copyright (2024) Databricks, Inc. This Software includes software developed at Databricks (https://www.databricks.com/) and its use is subject to the included LICENSE file. diff --git a/pyproject.toml b/pyproject.toml index 0de2dbe6..c39b6e4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,7 @@ -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - [project] name = "databricks-labs-lsql" dynamic = ["version"] -description = 'Lightweight stateless SQL execution for Databricks with minimal dependencies' +description = "Lightweight stateless SQL execution for Databricks with minimal dependencies" readme = "README.md" requires-python = ">=3.10" license-files = { paths = ["LICENSE", "NOTICE"] } @@ -15,17 +11,16 @@ authors = [ ] classifiers = [ "License :: Other/Proprietary License", - "Development Status :: 4 - Beta", + "Development Status :: 3 - Alpha", "Programming Language :: Python", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] dependencies = [ + "databricks-labs-blueprint~=0.2.5", "databricks-sdk~=0.21.0", "PyYAML>=6.0.0,<7.0.0", "sqlglot~=22.2.1" @@ -36,127 +31,648 @@ Documentation = "https://github.com/databrickslabs/databricks-labs-lsql#readme" Issues = "https://github.com/databrickslabs/databricks-labs-lsql/issues" Source = "https://github.com/databrickslabs/databricks-labs-lsql" +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build] +sources = ["src"] +include = ["src"] + [tool.hatch.version] path = "src/databricks/labs/lsql/__about__.py" [tool.hatch.envs.default] -python = "python3" dependencies = [ - "coverage[toml]>=6.5", - "pytest", + "coverage[toml]>=6.5", + "pytest", + "pylint", + "pytest-xdist", + "pytest-cov>=4.0.0,<5.0.0", + "pytest-mock>=3.0.0,<4.0.0", + "pytest-timeout", + "ruff>=0.0.243", + "isort>=2.5.0", + "mypy", + "types-PyYAML", + "types-requests", ] + +python="3.10" + +# store virtual env as the child of this folder. 
Helps VSCode (and PyCharm) to run better +path = ".venv" + [tool.hatch.envs.default.scripts] -test = "pytest {args:tests}" -test-cov = "coverage run -m pytest {args:tests}" -cov-report = [ - "- coverage combine", - "coverage report", -] -cov = [ - "test-cov", - "cov-report", -] +test = "pytest -n 2 --cov src --cov-report=xml --timeout 30 tests/unit --durations 20" +coverage = "pytest -n 2 --cov src tests/unit --timeout 30 --cov-report=html --durations 20" +integration = "pytest -n 10 --cov src tests/integration --durations 20" +fmt = ["isort .", + "ruff format", + "ruff . --fix", + "mypy .", + "pylint --output-format=colorized -j 0 src"] +verify = ["black --check .", + "isort . --check-only", + "ruff .", + "mypy .", + "pylint --output-format=colorized -j 0 src"] -[[tool.hatch.envs.all.matrix]] -python = ["3.7", "3.8", "3.9", "3.10", "3.11"] +[tool.isort] +profile = "black" -[tool.hatch.envs.lint] -detached = true -dependencies = [ - "black>=23.1.0", - "mypy>=1.0.0", - "ruff>=0.0.243", -] -[tool.hatch.envs.lint.scripts] -typing = "mypy --install-types --non-interactive {args:src/databricks/labs/lsql tests}" -style = [ - "ruff {args:.}", - "black --check --diff {args:.}", -] -fmt = [ - "black {args:.}", - "ruff --fix {args:.}", - "style", -] -all = [ - "style", - "typing", -] +[tool.pytest.ini_options] +addopts = "--no-header" +cache_dir = ".venv/pytest-cache" [tool.black] -target-version = ["py37"] +target-version = ["py310"] line-length = 120 skip-string-normalization = true [tool.ruff] -target-version = "py37" +cache-dir = ".venv/ruff-cache" +target-version = "py310" line-length = 120 -select = [ - "A", - "ARG", - "B", - "C", - "DTZ", - "E", - "EM", - "F", - "FBT", - "I", - "ICN", - "ISC", - "N", - "PLC", - "PLE", - "PLR", - "PLW", - "Q", - "RUF", - "S", - "T", - "TID", - "UP", - "W", - "YTT", -] -ignore = [ - # Allow non-abstract empty methods in abstract base classes - "B027", - # Allow boolean positional values in function calls, like `dict.get(... True)` - "FBT003", - # Ignore checks for possible passwords - "S105", "S106", "S107", - # Ignore complexity - "C901", "PLR0911", "PLR0912", "PLR0913", "PLR0915", -] -unfixable = [ - # Don't touch unused imports - "F401", -] [tool.ruff.isort] -known-first-party = ["databricks.labs"] - -[tool.ruff.flake8-tidy-imports] -ban-relative-imports = "all" - -[tool.ruff.per-file-ignores] -# Tests can use magic values, assertions, and relative imports -"tests/**/*" = ["PLR2004", "S101", "TID252"] +known-first-party = ["databricks.labs.blueprint"] [tool.coverage.run] branch = true parallel = true -omit = [ - "src/databricks/labs/lsql/__about__.py", -] - -[tool.coverage.paths] -databricks_labs_lsql = ["src/databricks/labs/lsql", "*/databricks/labs/lsql/src/databricks/labs/lsql"] -tests = ["tests", "*/databricks/labs/lsql/tests"] [tool.coverage.report] +omit = ["*/working-copy/*", 'src/databricks/labs/blueprint/__main__.py'] exclude_lines = [ - "no cov", - "if __name__ == .__main__.:", - "if TYPE_CHECKING:", + "no cov", + "if __name__ == .__main__.:", + "if TYPE_CHECKING:", ] + +[tool.pylint.main] +# PyLint configuration is adapted from Google Python Style Guide with modifications. +# Sources https://google.github.io/styleguide/pylintrc +# License: https://github.com/google/styleguide/blob/gh-pages/LICENSE + +# Analyse import fallback blocks. This can be used to support both Python 2 and 3 +# compatible code, which means that the block might have code that exists only in +# one or another interpreter, leading to false positives when analysed. 
+# analyse-fallback-blocks = + +# Clear in-memory caches upon conclusion of linting. Useful if running pylint in +# a server-like mode. +# clear-cache-post-run = + +# Always return a 0 (non-error) status code, even if lint errors are found. This +# is primarily useful in continuous integration scripts. +# exit-zero = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. +# extension-pkg-allow-list = + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code. (This is an alternative name to extension-pkg-allow-list +# for backward compatibility.) +# extension-pkg-whitelist = + +# Return non-zero exit code if any of these messages/categories are detected, +# even if score is above --fail-under value. Syntax same as enable. Messages +# specified are enabled, while categories only check already-enabled messages. +# fail-on = + +# Specify a score threshold under which the program will exit with error. +fail-under = 10.0 + +# Interpret the stdin as a python script, whose filename needs to be passed as +# the module_or_package argument. +# from-stdin = + +# Add files or directories matching the regular expressions patterns to the +# ignore-list. The regex matches against paths and can be in Posix or Windows +# format. Because '\\' represents the directory delimiter on Windows systems, it +# can't be used as an escape character. +# ignore-paths = + +# Files or directories matching the regular expression patterns are skipped. The +# regex matches against base names, not paths. The default value ignores Emacs +# file locks +ignore-patterns = ["^\\.#"] + +# List of module names for which member attributes should not be checked (useful +# for modules/projects where namespaces are manipulated during runtime and thus +# existing member attributes cannot be deduced by static analysis). It supports +# qualified module names, as well as Unix pattern matching. +# ignored-modules = + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +# init-hook = + +# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the +# number of processors available to use, and will cap the count on Windows to +# avoid hangs. +# jobs = + +# Control the amount of potential inferred values when inferring a single object. +# This can help the performance when dealing with large functions or complex, +# nested conditions. +limit-inference-results = 100 + +# List of plugins (as comma separated values of python module names) to load, +# usually to register additional checkers. +load-plugins = ["pylint.extensions.check_elif", "pylint.extensions.bad_builtin", "pylint.extensions.docparams", "pylint.extensions.for_any_all", "pylint.extensions.set_membership", "pylint.extensions.code_style", "pylint.extensions.overlapping_exceptions", "pylint.extensions.typing", "pylint.extensions.redefined_variable_type", "pylint.extensions.comparison_placement", "pylint.extensions.broad_try_clause", "pylint.extensions.dict_init_mutate", "pylint.extensions.consider_refactoring_into_while_condition"] + +# Pickle collected data for later comparisons. +persistent = true + +# Minimum Python version to use for version dependent checks. Will default to the +# version used to run pylint. 
+py-version = "3.10" + +# Discover python modules and packages in the file system subtree. +# recursive = + +# Add paths to the list of the source roots. Supports globbing patterns. The +# source root is an absolute path or a path relative to the current working +# directory used to determine a package namespace for modules located under the +# source root. +# source-roots = + +# When enabled, pylint would attempt to guess common misconfiguration and emit +# user-friendly hints instead of false-positive error messages. +suggestion-mode = true + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +# unsafe-load-any-extension = + +[tool.pylint.basic] +# Naming style matching correct argument names. +argument-naming-style = "snake_case" + +# Regular expression matching correct argument names. Overrides argument-naming- +# style. If left empty, argument names will be checked with the set naming style. +argument-rgx = "[a-z_][a-z0-9_]{2,30}$" + +# Naming style matching correct attribute names. +attr-naming-style = "snake_case" + +# Regular expression matching correct attribute names. Overrides attr-naming- +# style. If left empty, attribute names will be checked with the set naming +# style. +attr-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Bad variable names which should always be refused, separated by a comma. +bad-names = ["foo", "bar", "baz", "toto", "tutu", "tata"] + +# Bad variable names regexes, separated by a comma. If names match any regex, +# they will always be refused +# bad-names-rgxs = + +# Naming style matching correct class attribute names. +class-attribute-naming-style = "any" + +# Regular expression matching correct class attribute names. Overrides class- +# attribute-naming-style. If left empty, class attribute names will be checked +# with the set naming style. +class-attribute-rgx = "([A-Za-z_][A-Za-z0-9_]{2,30}|(__.*__))$" + +# Naming style matching correct class constant names. +class-const-naming-style = "UPPER_CASE" + +# Regular expression matching correct class constant names. Overrides class- +# const-naming-style. If left empty, class constant names will be checked with +# the set naming style. +# class-const-rgx = + +# Naming style matching correct class names. +class-naming-style = "PascalCase" + +# Regular expression matching correct class names. Overrides class-naming-style. +# If left empty, class names will be checked with the set naming style. +class-rgx = "[A-Z_][a-zA-Z0-9]+$" + +# Naming style matching correct constant names. +const-naming-style = "UPPER_CASE" + +# Regular expression matching correct constant names. Overrides const-naming- +# style. If left empty, constant names will be checked with the set naming style. +const-rgx = "(([A-Z_][A-Z0-9_]*)|(__.*__))$" + +# Minimum line length for functions/classes that require docstrings, shorter ones +# are exempt. +docstring-min-length = -1 + +# Naming style matching correct function names. +function-naming-style = "snake_case" + +# Regular expression matching correct function names. Overrides function-naming- +# style. If left empty, function names will be checked with the set naming style. +function-rgx = "[a-z_][a-z0-9_]{2,30}$" + +# Good variable names which should always be accepted, separated by a comma. +good-names = ["i", "j", "k", "ex", "Run", "_"] + +# Good variable names regexes, separated by a comma. 
If names match any regex, +# they will always be accepted +# good-names-rgxs = + +# Include a hint for the correct naming format with invalid-name. +# include-naming-hint = + +# Naming style matching correct inline iteration names. +inlinevar-naming-style = "any" + +# Regular expression matching correct inline iteration names. Overrides +# inlinevar-naming-style. If left empty, inline iteration names will be checked +# with the set naming style. +inlinevar-rgx = "[A-Za-z_][A-Za-z0-9_]*$" + +# Naming style matching correct method names. +method-naming-style = "snake_case" + +# Regular expression matching correct method names. Overrides method-naming- +# style. If left empty, method names will be checked with the set naming style. +method-rgx = "[a-z_][a-z0-9_]{2,}$" + +# Naming style matching correct module names. +module-naming-style = "snake_case" + +# Regular expression matching correct module names. Overrides module-naming- +# style. If left empty, module names will be checked with the set naming style. +module-rgx = "(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$" + +# Colon-delimited sets of names that determine each other's naming style when the +# name regexes allow several styles. +# name-group = + +# Regular expression which should only match function or class names that do not +# require a docstring. +no-docstring-rgx = "__.*__" + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. These +# decorators are taken in consideration only for invalid-name. +property-classes = ["abc.abstractproperty"] + +# Regular expression matching correct type alias names. If left empty, type alias +# names will be checked with the set naming style. +# typealias-rgx = + +# Regular expression matching correct type variable names. If left empty, type +# variable names will be checked with the set naming style. +# typevar-rgx = + +# Naming style matching correct variable names. +variable-naming-style = "snake_case" + +# Regular expression matching correct variable names. Overrides variable-naming- +# style. If left empty, variable names will be checked with the set naming style. +variable-rgx = "[a-z_][a-z0-9_]{2,30}$" + +[tool.pylint.broad_try_clause] +# Maximum number of statements allowed in a try clause +max-try-statements = 7 + +[tool.pylint.classes] +# Warn about protected attribute access inside special methods +# check-protected-access-in-special-methods = + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods = ["__init__", "__new__", "setUp", "__post_init__"] + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected = ["_asdict", "_fields", "_replace", "_source", "_make"] + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg = ["cls"] + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg = ["mcs"] + +[tool.pylint.deprecated_builtins] +# List of builtins function names that should not be used, separated by a comma +bad-functions = ["map", "input"] + +[tool.pylint.design] +# List of regular expressions of class ancestor names to ignore when counting +# public methods (see R0903) +# exclude-too-few-public-methods = + +# List of qualified class names to ignore when counting class parents (see R0901) +# ignored-parents = + +# Maximum number of arguments for function / method. 
+max-args = 9 + +# Maximum number of attributes for a class (see R0902). +max-attributes = 11 + +# Maximum number of boolean expressions in an if statement (see R0916). +max-bool-expr = 5 + +# Maximum number of branch for function / method body. +max-branches = 20 + +# Maximum number of locals for function / method body. +max-locals = 19 + +# Maximum number of parents for a class (see R0901). +max-parents = 7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods = 20 + +# Maximum number of return / yield for function / method body. +max-returns = 11 + +# Maximum number of statements in function / method body. +max-statements = 50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods = 2 + +[tool.pylint.exceptions] +# Exceptions that will emit a warning when caught. +overgeneral-exceptions = ["builtins.Exception"] + +[tool.pylint.format] +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +# expected-line-ending-format = + +# Regexp for a line that is allowed to be longer than the limit. +ignore-long-lines = "^\\s*(# )??$" + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren = 4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string = " " + +# Maximum number of characters on a single line. +max-line-length = 100 + +# Maximum number of lines in a module. +max-module-lines = 2000 + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +# single-line-class-stmt = + +# Allow the body of an if to be on the same line as the test if there is no else. +# single-line-if-stmt = + +[tool.pylint.imports] +# List of modules that can be imported at any level, not just the top level one. +# allow-any-import-level = + +# Allow explicit reexports by alias from a package __init__. +# allow-reexport-from-package = + +# Allow wildcard imports from modules that define __all__. +# allow-wildcard-with-all = + +# Deprecated modules which should not be used, separated by a comma. +deprecated-modules = ["regsub", "TERMIOS", "Bastion", "rexec"] + +# Output a graph (.gv or any supported image format) of external dependencies to +# the given file (report RP0402 must not be disabled). +# ext-import-graph = + +# Output a graph (.gv or any supported image format) of all (i.e. internal and +# external) dependencies to the given file (report RP0402 must not be disabled). +# import-graph = + +# Output a graph (.gv or any supported image format) of internal dependencies to +# the given file (report RP0402 must not be disabled). +# int-import-graph = + +# Force import order to recognize a module as part of the standard compatibility +# libraries. +# known-standard-library = + +# Force import order to recognize a module as part of a third party library. +known-third-party = ["enchant"] + +# Couples of modules and preferred modules, separated by a comma. +# preferred-modules = + +[tool.pylint.logging] +# The type of string formatting that logging methods do. `old` means using % +# formatting, `new` is for `{}` formatting. +logging-format-style = "old" + +# Logging modules to check that the string format arguments are in logging +# function parameter format. +logging-modules = ["logging"] + +[tool.pylint."messages control"] +# Only show warnings with the listed confidence levels. Leave empty to show all. +# Valid levels: HIGH, CONTROL_FLOW, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 
+confidence = ["HIGH", "CONTROL_FLOW", "INFERENCE", "INFERENCE_FAILURE", "UNDEFINED"] + +# Disable the message, report, category or checker with the given id(s). You can +# either give multiple identifiers separated by comma (,) or put this option +# multiple times (only on the command line, not in the configuration file where +# it should appear only once). You can also use "--disable=all" to disable +# everything first and then re-enable specific checks. For example, if you want +# to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W". +disable = ["raw-checker-failed", "bad-inline-option", "locally-disabled", "file-ignored", "suppressed-message", "deprecated-pragma", "use-implicit-booleaness-not-comparison-to-string", "use-implicit-booleaness-not-comparison-to-zero", "consider-using-augmented-assign", "prefer-typing-namedtuple", "attribute-defined-outside-init", "invalid-name", "missing-module-docstring", "missing-class-docstring", "missing-function-docstring", "protected-access", "too-few-public-methods", "line-too-long", "too-many-lines", "trailing-whitespace", "missing-final-newline", "trailing-newlines", "bad-indentation", "unnecessary-semicolon", "multiple-statements", "superfluous-parens", "mixed-line-endings", "unexpected-line-ending-format", "fixme", "consider-using-assignment-expr", "logging-fstring-interpolation", "consider-using-any-or-all"] + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifier separated by comma (,) or put this option +# multiple time (only on the command line, not in the configuration file where it +# should appear only once). See also the "--disable" option for examples. +enable = ["useless-suppression", "use-symbolic-message-instead"] + +[tool.pylint.method_args] +# List of qualified names (i.e., library.method) which require a timeout +# parameter e.g. 'requests.api.get,requests.api.post' +timeout-methods = ["requests.api.delete", "requests.api.get", "requests.api.head", "requests.api.options", "requests.api.patch", "requests.api.post", "requests.api.put", "requests.api.request"] + +[tool.pylint.miscellaneous] +# List of note tags to take in consideration, separated by a comma. +notes = ["FIXME", "XXX", "TODO"] + +# Regular expression of note tags to take in consideration. +# notes-rgx = + +[tool.pylint.parameter_documentation] +# Whether to accept totally missing parameter documentation in the docstring of a +# function that has parameters. +accept-no-param-doc = true + +# Whether to accept totally missing raises documentation in the docstring of a +# function that raises an exception. +accept-no-raise-doc = true + +# Whether to accept totally missing return documentation in the docstring of a +# function that returns a statement. +accept-no-return-doc = true + +# Whether to accept totally missing yields documentation in the docstring of a +# generator. +accept-no-yields-doc = true + +# If the docstring type cannot be guessed the specified docstring type will be +# used. +default-docstring-type = "default" + +[tool.pylint.refactoring] +# Maximum number of nested blocks for function / method body +max-nested-blocks = 5 + +# Complete name of functions that never returns. 
When checking for inconsistent- +# return-statements if a never returning function is called then it will be +# considered as an explicit return statement and no message will be printed. +never-returning-functions = ["sys.exit", "argparse.parse_error"] + +[tool.pylint.reports] +# Python expression which should return a score less than or equal to 10. You +# have access to the variables 'fatal', 'error', 'warning', 'refactor', +# 'convention', and 'info' which contain the number of messages in each category, +# as well as 'statement' which is the total number of statements analyzed. This +# score is used by the global evaluation report (RP0004). +evaluation = "max(0, 0 if fatal else 10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10))" + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details. +# msg-template = + +# Set the output format. Available formats are: text, parseable, colorized, json2 +# (improved json format), json (old json format) and msvs (visual studio). You +# can also give a reporter class, e.g. mypackage.mymodule.MyReporterClass. +# output-format = + +# Tells whether to display a full report or only the messages. +# reports = + +# Activate the evaluation score. +score = true + +[tool.pylint.similarities] +# Comments are removed from the similarity computation +ignore-comments = true + +# Docstrings are removed from the similarity computation +ignore-docstrings = true + +# Imports are removed from the similarity computation +ignore-imports = true + +# Signatures are removed from the similarity computation +ignore-signatures = true + +# Minimum lines number of a similarity. +min-similarity-lines = 6 + +[tool.pylint.spelling] +# Limits count of emitted suggestions for spelling mistakes. +max-spelling-suggestions = 2 + +# Spelling dictionary name. No available dictionaries : You need to install both +# the python package and the system dependency for enchant to work. +# spelling-dict = + +# List of comma separated words that should be considered directives if they +# appear at the beginning of a comment and should not be checked. +spelling-ignore-comment-directives = "fmt: on,fmt: off,noqa:,noqa,nosec,isort:skip,mypy:,pragma:,# noinspection" + +# List of comma separated words that should not be checked. +# spelling-ignore-words = + +# A path to a file that contains the private dictionary; one word per line. +spelling-private-dict-file = ".pyenchant_pylint_custom_dict.txt" + +# Tells whether to store unknown words to the private dictionary (see the +# --spelling-private-dict-file option) instead of raising a message. +# spelling-store-unknown-words = + +[tool.pylint.typecheck] +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators = ["contextlib.contextmanager"] + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members = "REQUEST,acl_users,aq_parent,argparse.Namespace" + +# Tells whether missing members accessed in mixin class should be ignored. A +# class is considered mixin if its name matches the mixin-class-rgx option. +# Tells whether to warn about missing members when the owner of the attribute is +# inferred to be None. 
+ignore-none = true + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference can +# return multiple potential results while evaluating a Python object, but some +# branches might not be evaluated, which results in partial inference. In that +# case, it might be useful to still emit no-member and other checks for the rest +# of the inferred objects. +ignore-on-opaque-inference = true + +# List of symbolic message names to ignore for Mixin members. +ignored-checks-for-mixins = ["no-member", "not-async-context-manager", "not-context-manager", "attribute-defined-outside-init"] + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes = ["SQLObject", "optparse.Values", "thread._local", "_thread._local"] + +# Show a hint with possible names when a member name was not found. The aspect of +# finding the hint is based on edit distance. +missing-member-hint = true + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance = 1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices = 1 + +# Regex pattern to define which classes are considered mixins. +mixin-class-rgx = ".*MixIn" + +# List of decorators that change the signature of a decorated function. +# signature-mutators = + +[tool.pylint.variables] +# List of additional names supposed to be defined in builtins. Remember that you +# should avoid defining new builtins when possible. +# additional-builtins = + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables = true + +# List of names allowed to shadow builtins +# allowed-redefined-builtins = + +# List of strings which can identify a callback function by name. A callback name +# must start or end with one of those strings. +callbacks = ["cb_", "_cb"] + +# A regular expression matching the name of dummy variables (i.e. expected to not +# be used). +dummy-variables-rgx = "_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_" + +# Argument names that match this expression will be ignored. +ignored-argument-names = "_.*|^ignored_|^unused_" + +# Tells whether we should check for unused import in __init__ files. +# init-import = + +# List of qualified module names which can have objects that can redefine +# builtins. 
+redefining-builtins-modules = ["six.moves", "past.builtins", "future.builtins", "builtins", "io"] diff --git a/src/databricks/labs/lsql/__about__.py b/src/databricks/labs/lsql/__about__.py index 97688a1d..6c8e6b97 100644 --- a/src/databricks/labs/lsql/__about__.py +++ b/src/databricks/labs/lsql/__about__.py @@ -1,4 +1 @@ -# SPDX-FileCopyrightText: 2023-present Serge Smertin -# -# SPDX-License-Identifier: MIT -__version__ = "0.0.1" +__version__ = "0.0.0" diff --git a/src/databricks/labs/lsql/__init__.py b/src/databricks/labs/lsql/__init__.py index 013f83dc..ba43b3c2 100644 --- a/src/databricks/labs/lsql/__init__.py +++ b/src/databricks/labs/lsql/__init__.py @@ -1 +1,3 @@ -from .lib import Row \ No newline at end of file +from .core import Row + +__all__ = ["Row"] diff --git a/src/databricks/labs/lsql/__main__.py b/src/databricks/labs/lsql/__main__.py index dbcf98df..e69de29b 100644 --- a/src/databricks/labs/lsql/__main__.py +++ b/src/databricks/labs/lsql/__main__.py @@ -1 +0,0 @@ -from databricks.sdk import Config diff --git a/src/databricks/labs/lsql/backends.py b/src/databricks/labs/lsql/backends.py new file mode 100644 index 00000000..d9d0c64b --- /dev/null +++ b/src/databricks/labs/lsql/backends.py @@ -0,0 +1,298 @@ +import dataclasses +import logging +import os +import re +from abc import ABC, abstractmethod +from collections.abc import Callable, Iterable, Iterator, Sequence +from types import UnionType +from typing import Any, ClassVar, Protocol, TypeVar + +from databricks.sdk import WorkspaceClient +from databricks.sdk.errors import ( + BadRequest, + DatabricksError, + DataLoss, + NotFound, + PermissionDenied, + Unknown, +) + +from databricks.labs.lsql.core import Row, StatementExecutionExt + +logger = logging.getLogger(__name__) + + +class DataclassInstance(Protocol): + __dataclass_fields__: ClassVar[dict] + + +Result = TypeVar("Result", bound=DataclassInstance) +Dataclass = type[DataclassInstance] +ResultFn = Callable[[], Iterable[Result]] + + +class SqlBackend(ABC): + @abstractmethod + def execute(self, sql: str) -> None: + raise NotImplementedError + + @abstractmethod + def fetch(self, sql: str) -> Iterator[Any]: + raise NotImplementedError + + @abstractmethod + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + raise NotImplementedError + + def create_table(self, full_name: str, klass: Dataclass): + ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA" + self.execute(ddl) + + _builtin_type_mapping: ClassVar[dict[type, str]] = { + str: "STRING", + int: "LONG", + bool: "BOOLEAN", + float: "FLOAT", + } + + @classmethod + def _schema_for(cls, klass: Dataclass): + fields = [] + for f in dataclasses.fields(klass): + field_type = f.type + if isinstance(field_type, UnionType): + field_type = field_type.__args__[0] + if field_type not in cls._builtin_type_mapping: + msg = f"Cannot auto-convert {field_type}" + raise SyntaxError(msg) + not_null = " NOT NULL" + if f.default is None: + not_null = "" + spark_type = cls._builtin_type_mapping[field_type] + fields.append(f"{f.name} {spark_type}{not_null}") + return ", ".join(fields) + + @classmethod + def _filter_none_rows(cls, rows, klass): + if len(rows) == 0: + return rows + + results = [] + class_fields = dataclasses.fields(klass) + for row in rows: + if row is None: + continue + for field in class_fields: + if not hasattr(row, field.name): + logger.debug(f"Field {field.name} not present in row {dataclasses.asdict(row)}") + continue + if 
field.default is not None and getattr(row, field.name) is None: + msg = f"Not null constraint violated for column {field.name}, row = {dataclasses.asdict(row)}" + raise ValueError(msg) + results.append(row) + return results + + _whitespace = re.compile(r"\s{2,}") + + @classmethod + def _only_n_bytes(cls, j: str, num_bytes: int = 96) -> str: + j = cls._whitespace.sub(" ", j) + diff = len(j.encode("utf-8")) - num_bytes + if diff > 0: + return f"{j[:num_bytes]}... ({diff} more bytes)" + return j + + @staticmethod + def _api_error_from_message(error_message: str) -> DatabricksError: + if "SCHEMA_NOT_FOUND" in error_message: + return NotFound(error_message) + if "TABLE_OR_VIEW_NOT_FOUND" in error_message: + return NotFound(error_message) + if "DELTA_TABLE_NOT_FOUND" in error_message: + return NotFound(error_message) + if "DELTA_MISSING_TRANSACTION_LOG" in error_message: + return DataLoss(error_message) + if "UNRESOLVED_COLUMN.WITH_SUGGESTION" in error_message: + return BadRequest(error_message) + if "PARSE_SYNTAX_ERROR" in error_message: + return BadRequest(error_message) + if "Operation not allowed" in error_message: + return PermissionDenied(error_message) + return Unknown(error_message) + + +class StatementExecutionBackend(SqlBackend): + def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000): + self._sql = StatementExecutionExt(ws, warehouse_id=warehouse_id) + self._warehouse_id = warehouse_id + self._max_records_per_batch = max_records_per_batch + debug_truncate_bytes = ws.config.debug_truncate_bytes + # while unit-testing, this value will contain a mock + self._debug_truncate_bytes = debug_truncate_bytes if isinstance(debug_truncate_bytes, int) else 96 + + def execute(self, sql: str) -> None: + logger.debug(f"[api][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + self._sql.execute(sql) + + def fetch(self, sql: str) -> Iterator[Row]: + logger.debug(f"[api][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + return self._sql.fetch_all(sql) + + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode="append"): + if mode == "overwrite": + msg = "Overwrite mode is not yet supported" + raise NotImplementedError(msg) + rows = self._filter_none_rows(rows, klass) + self.create_table(full_name, klass) + if len(rows) == 0: + return + fields = dataclasses.fields(klass) + field_names = [f.name for f in fields] + for i in range(0, len(rows), self._max_records_per_batch): + batch = rows[i : i + self._max_records_per_batch] + vals = "), (".join(self._row_to_sql(r, fields) for r in batch) + sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})' + self.execute(sql) + + @staticmethod + def _row_to_sql(row: DataclassInstance, fields: tuple[dataclasses.Field[Any], ...]): + data = [] + for f in fields: + value = getattr(row, f.name) + field_type = f.type + if isinstance(field_type, UnionType): + field_type = field_type.__args__[0] + if value is None: + data.append("NULL") + elif field_type == bool: + data.append("TRUE" if value else "FALSE") + elif field_type == str: + value = str(value).replace("'", "''") + data.append(f"'{value}'") + elif field_type == int: + data.append(f"{value}") + else: + msg = f"unknown type: {field_type}" + raise ValueError(msg) + return ", ".join(data) + + +class _SparkBackend(SqlBackend): + def __init__(self, spark, debug_truncate_bytes): + self._spark = spark + self._debug_truncate_bytes = debug_truncate_bytes if debug_truncate_bytes is not None else 96 + + 
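+    # execute() and fetch() below delegate to spark.sql() and convert any exception
+    # raised by Spark into a Databricks SDK error type via _api_error_from_message(),
+    # so callers can handle NotFound, BadRequest, DataLoss, etc. consistently across backends.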
def execute(self, sql: str) -> None: + logger.debug(f"[spark][execute] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + try: + self._spark.sql(sql) + except Exception as e: + error_message = str(e) + raise self._api_error_from_message(error_message) from None + + def fetch(self, sql: str) -> Iterator[Row]: + logger.debug(f"[spark][fetch] {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + try: + return self._spark.sql(sql).collect() + except Exception as e: + error_message = str(e) + raise self._api_error_from_message(error_message) from None + + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + rows = self._filter_none_rows(rows, klass) + + if len(rows) == 0: + self.create_table(full_name, klass) + return + # pyspark deals well with lists of dataclass instances, as long as schema is provided + df = self._spark.createDataFrame(rows, self._schema_for(klass)) + df.write.saveAsTable(full_name, mode=mode) + + +class RuntimeBackend(_SparkBackend): + def __init__(self, debug_truncate_bytes: int | None = None): + if "DATABRICKS_RUNTIME_VERSION" not in os.environ: + msg = "Not in the Databricks Runtime" + raise RuntimeError(msg) + try: + # pylint: disable-next=import-error,import-outside-toplevel,useless-suppression + from pyspark.sql.session import ( # type: ignore[import-not-found] + SparkSession, + ) + + super().__init__(SparkSession.builder.getOrCreate(), debug_truncate_bytes) + except ImportError as e: + raise RuntimeError("pyspark is not available") from e + + +class DatabricksConnectBackend(_SparkBackend): + def __init__(self, ws: WorkspaceClient): + try: + # pylint: disable-next=import-outside-toplevel + from databricks.connect import ( # type: ignore[import-untyped] + DatabricksSession, + ) + + spark = DatabricksSession.builder().sdk_config(ws.config).getOrCreate() + super().__init__(spark, ws.config.debug_truncate_bytes) + except ImportError as e: + raise RuntimeError("Please run `pip install databricks-connect`") from e + + +class MockBackend(SqlBackend): + def __init__( + self, *, fails_on_first: dict[str, str] | None = None, rows: dict | None = None, debug_truncate_bytes=96 + ): + self._fails_on_first = fails_on_first + if not rows: + rows = {} + self._rows = rows + self._save_table: list[tuple[str, Sequence[DataclassInstance], str]] = [] + self._debug_truncate_bytes = debug_truncate_bytes + self.queries: list[str] = [] + + def _sql(self, sql: str): + logger.debug(f"Mock backend.sql() received SQL: {self._only_n_bytes(sql, self._debug_truncate_bytes)}") + seen_before = sql in self.queries + self.queries.append(sql) + if not seen_before and self._fails_on_first is not None: + for match, failure in self._fails_on_first.items(): + if match in sql: + raise self._api_error_from_message(failure) from None + + def execute(self, sql): + self._sql(sql) + + def fetch(self, sql) -> Iterator[Row]: + self._sql(sql) + rows = [] + if self._rows: + for pattern in self._rows.keys(): + r = re.compile(pattern) + if r.search(sql): + logger.debug(f"Found match: {sql}") + rows.extend(self._rows[pattern]) + logger.debug(f"Returning rows: {rows}") + return iter(rows) + + def save_table(self, full_name: str, rows: Sequence[DataclassInstance], klass: Dataclass, mode: str = "append"): + if mode == "overwrite": + msg = "Overwrite mode is not yet supported" + raise NotImplementedError(msg) + rows = self._filter_none_rows(rows, klass) + if klass.__class__ == type: + row_factory = self._row_factory(klass) + rows = 
[row_factory(*dataclasses.astuple(r)) for r in rows] + self._save_table.append((full_name, rows, mode)) + + def rows_written_for(self, full_name: str, mode: str) -> list[DataclassInstance]: + rows: list[DataclassInstance] = [] + for stub_full_name, stub_rows, stub_mode in self._save_table: + if not (stub_full_name == full_name and stub_mode == mode): + continue + rows += stub_rows + return rows + + @staticmethod + def _row_factory(klass: Dataclass) -> type: + return Row.factory([f.name for f in dataclasses.fields(klass)]) diff --git a/src/databricks/labs/lsql/core.py b/src/databricks/labs/lsql/core.py new file mode 100644 index 00000000..fa048c09 --- /dev/null +++ b/src/databricks/labs/lsql/core.py @@ -0,0 +1,501 @@ +import base64 +import datetime +import json +import logging +import random +import threading +import time +import types +from collections.abc import Callable, Iterator +from datetime import timedelta +from typing import Any + +import requests +import sqlglot +from databricks.sdk import WorkspaceClient, errors +from databricks.sdk.errors import DataLoss, NotFound +from databricks.sdk.service.sql import ( + ColumnInfoTypeName, + Disposition, + ExecuteStatementResponse, + Format, + ServiceError, + ServiceErrorCode, + State, + StatementState, + StatementStatus, +) + +MAX_SLEEP_PER_ATTEMPT = 10 + +MAX_PLATFORM_TIMEOUT = 50 + +MIN_PLATFORM_TIMEOUT = 5 + +logger = logging.getLogger(__name__) + + +class Row(tuple): + """Row is a tuple with named fields that resembles PySpark's SQL Row API.""" + + def __new__(cls, *args, **kwargs): + """Create a new instance of Row.""" + if args and kwargs: + raise ValueError("cannot mix positional and keyword arguments") + if kwargs: + # PySpark's compatibility layer + row = tuple.__new__(cls, list(kwargs.values())) + row.__columns__ = list(kwargs.keys()) + return row + if len(args) == 1 and hasattr(cls, "__columns__") and isinstance(args[0], (types.GeneratorType, list, tuple)): + # this type returned by Row.factory() and we already know the column names + return cls(*args[0]) + if len(args) == 2 and isinstance(args[0], (list, tuple)) and isinstance(args[1], (list, tuple)): + # UCX's compatibility layer + row = tuple.__new__(cls, args[1]) + row.__columns__ = args[0] + return row + return tuple.__new__(cls, args) + + @classmethod + def factory(cls, col_names: list[str]) -> type: + """Create a new Row class with the given column names.""" + return type("Row", (Row,), {"__columns__": col_names}) + + def as_dict(self) -> dict[str, Any]: + """Convert the row to a dictionary with the same conventions as Databricks SDK.""" + return dict(zip(self.__columns__, self, strict=True)) + + def __contains__(self, item): + """Check if the column is in the row.""" + return item in self.__columns__ + + def __getitem__(self, col): + """Get the column by index or name.""" + if isinstance(col, int | slice): + return super().__getitem__(col) + # if columns are named `2 + 2`, for example + return self.__getattr__(col) + + def __getattr__(self, col): + """Get the column by name.""" + if col.startswith("__"): + raise AttributeError(col) + try: + idx = self.__columns__.index(col) + return self[idx] + except IndexError: + raise AttributeError(col) from None + except ValueError: + raise AttributeError(col) from None + + def __repr__(self): + """Get the string representation of the row.""" + return f"Row({', '.join(f'{k}={v!r}' for (k, v) in zip(self.__columns__, self, strict=True))})" + + +class StatementExecutionExt: + """Execute SQL statements in a stateless manner. 
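The Row class above mirrors a subset of pyspark.sql.Row. A short sketch of the access patterns it supports, with purely illustrative values:

    from databricks.labs.lsql import Row

    row = Row(foo="bar", enabled=True)        # keyword construction
    assert row.foo == "bar"                   # attribute access
    assert row["enabled"] is True             # item access by column name
    assert row[0] == "bar"                    # positional access, it is still a tuple
    assert row.as_dict() == {"foo": "bar", "enabled": True}

    # Row.factory() pre-binds column names; MockBackend uses it for saved rows
    Pair = Row.factory(["first", "second"])
    assert Pair(1, 2).second == 2
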
+ + Primary use-case of :py:meth:`fetch_all` and :py:meth:`execute` methods is oriented at executing SQL queries in + a stateless manner straight away from Databricks SDK for Python, without requiring any external dependencies. + Results are fetched in JSON format through presigned external links. This is perfect for serverless applications + like AWS Lambda, Azure Functions, or any other containerised short-lived applications, where container startup + time is faster with the smaller dependency set. + + >>> for (pickup_zip, dropoff_zip) in see('SELECT pickup_zip, dropoff_zip FROM samples.nyctaxi.trips LIMIT 10'): + >>> print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') + + Method :py:meth:`fetch_all` returns an iterator of objects, that resemble :class:`pyspark.sql.Row` APIs, but full + compatibility is not the goal of this implementation. See :class:`Row` for more details. + + When you only need to execute the query and have no need to iterate over results, use the :py:meth:`execute`. + + Applications, that need to a more traditional SQL Python APIs with cursors, efficient data transfer of hundreds of + megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use + the stateful Databricks SQL Connector for Python.""" + + def __init__( # pylint: disable=too-many-arguments + self, + ws: WorkspaceClient, + disposition: Disposition | None = None, + warehouse_id: str | None = None, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta = timedelta(minutes=20), + disable_magic: bool = False, + http_session_factory: Callable[[], requests.Session] | None = None, + ): + if not http_session_factory: + http_session_factory = requests.Session + self._ws = ws + self._http = http_session_factory() + self._lock = threading.Lock() + self._warehouse_id = warehouse_id + self._schema = schema + self._timeout = timeout + self._catalog = catalog + self._disable_magic = disable_magic + self._byte_limit = byte_limit + self._disposition = disposition + self._type_converters = { + ColumnInfoTypeName.ARRAY: json.loads, + ColumnInfoTypeName.BINARY: base64.b64decode, + ColumnInfoTypeName.BOOLEAN: bool, + ColumnInfoTypeName.CHAR: str, + ColumnInfoTypeName.DATE: self._parse_date, + ColumnInfoTypeName.DOUBLE: float, + ColumnInfoTypeName.FLOAT: float, + ColumnInfoTypeName.INT: int, + ColumnInfoTypeName.LONG: int, + ColumnInfoTypeName.MAP: json.loads, + ColumnInfoTypeName.NULL: lambda _: None, + ColumnInfoTypeName.SHORT: int, + ColumnInfoTypeName.STRING: str, + ColumnInfoTypeName.STRUCT: json.loads, + ColumnInfoTypeName.TIMESTAMP: self._parse_timestamp, + } + + def execute( + self, + statement: str, + *, + warehouse_id: str | None = None, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta | None = None, + ) -> ExecuteStatementResponse: + """Execute a SQL statement and block until results are ready, including starting + the warehouse if needed. + + This is a high-level implementation that works with fetching records in JSON format. + It can be considered as a quick way to run SQL queries by just depending on + Databricks SDK for Python without the need of any other compiled library dependencies. + + This method is a higher-level wrapper over Databricks SDK and fetches results + in JSON format through the external link disposition, with client-side polling until + the statement succeeds in execution. 
Whenever the statement is failed, cancelled, or + closed, this method raises `DatabricksError` subclass with the state message and + the relevant error code. + + To seamlessly iterate over the rows from query results, please use :py:meth:`fetch_all`. + + :param statement: str + SQL statement to execute + :param warehouse_id: str (optional) + Warehouse upon which to execute a statement. If not given, it will use the warehouse specified + in the constructor or the first available warehouse that is not in the DELETED or DELETING state. + :param byte_limit: int (optional) + Applies the given byte limit to the statement's result size. Byte counts are based on internal + representations and may not match measurable sizes in the JSON format. + :param catalog: str (optional) + Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. If not given, + it will use the default catalog or the catalog specified in the constructor. + :param schema: str (optional) + Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. If not given, + it will use the default schema or the schema specified in the constructor. + :param timeout: timedelta (optional) + Timeout after which the query is cancelled. If timeout is less than 50 seconds, + it is handled on the server side. If the timeout is greater than 50 seconds, + Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. + If not given, it will use the timeout specified in the constructor. + :return: ExecuteStatementResponse + """ + # The wait_timeout field must be 0 seconds (disables wait), + # or between 5 seconds and 50 seconds. + timeout, wait_timeout = self._statement_timeouts(timeout) + logger.debug(f"Executing SQL statement: {statement}") + + # technically, we can do Disposition.EXTERNAL_LINKS, but let's push it further away. + # format is limited to Format.JSON_ARRAY, but other iterations may include ARROW_STREAM. + immediate_response = self._ws.statement_execution.execute_statement( + statement=statement, + warehouse_id=warehouse_id or self._default_warehouse(), + catalog=catalog or self._catalog, + schema=schema or self._schema, + disposition=self._disposition, + byte_limit=byte_limit or self._byte_limit, + wait_timeout=wait_timeout, + format=Format.JSON_ARRAY, + ) + + status = immediate_response.status + if status is None: + status = StatementStatus(state=StatementState.FAILED) + if status.state == StatementState.SUCCEEDED: + return immediate_response + + self._raise_if_needed(status) + + attempt = 1 + status_message = "polling..." 
+ deadline = time.time() + timeout.total_seconds() + statement_id = immediate_response.statement_id + if not statement_id: + msg = f"No statement id: {immediate_response}" + raise ValueError(msg) + while time.time() < deadline: + res = self._ws.statement_execution.get_statement(statement_id) + result_status = res.status + if not result_status: + msg = f"Result status is none: {res}" + raise ValueError(msg) + state = result_status.state + if not state: + state = StatementState.FAILED + if state == StatementState.SUCCEEDED: + return ExecuteStatementResponse( + manifest=res.manifest, result=res.result, statement_id=statement_id, status=result_status + ) + status_message = f"current status: {state.value}" + self._raise_if_needed(result_status) + sleep = min(attempt, MAX_SLEEP_PER_ATTEMPT) + logger.debug(f"SQL statement {statement_id}: {status_message} (sleeping ~{sleep}s)") + time.sleep(sleep + random.random()) + attempt += 1 + self._ws.statement_execution.cancel_execution(statement_id) + msg = f"timed out after {timeout}: {status_message}" + raise TimeoutError(msg) + + def __call__(self, statement: str): + """Execute a SQL statement and block until results are ready.""" + yield from self.fetch_all(statement) + + def fetch_all( + self, + statement: str, + *, + warehouse_id: str | None = None, + byte_limit: int | None = None, + catalog: str | None = None, + schema: str | None = None, + timeout: timedelta | None = None, + ) -> Iterator[Row]: + """Execute a query and iterate over all available records. + + This method is a wrapper over :py:meth:`execute` with the handling of chunked result + processing and deserialization of those into separate rows, which are yielded from + a returned iterator. Every row API resembles those of :class:`pyspark.sql.Row`, + but full compatibility is not the goal of this implementation. + + >>> ws = WorkspaceClient(...) + >>> see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"), catalog="samples") + >>> for row in see("SELECT * FROM nyctaxi.trips LIMIT 10"): + >>> pickup_time, dropoff_time = row[0], row[1] + >>> pickup_zip = row.pickup_zip + >>> dropoff_zip = row["dropoff_zip"] + >>> all_fields = row.as_dict() + >>> logger.info(f"{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}") + + :param statement: str + SQL statement to execute + :param warehouse_id: str (optional) + Warehouse upon which to execute a statement. See :py:meth:`execute` for more details. + :param byte_limit: int (optional) + Result-size limit in bytes. See :py:meth:`execute` for more details. + :param catalog: str (optional) + Catalog for statement execution. See :py:meth:`execute` for more details. + :param schema: str (optional) + Schema for statement execution. See :py:meth:`execute` for more details. + :param timeout: timedelta (optional) + Timeout after which the query is cancelled. See :py:meth:`execute` for more details. 
+ :return: Iterator[Row] + """ + execute_response = self.execute( + statement, warehouse_id=warehouse_id, byte_limit=byte_limit, catalog=catalog, schema=schema, timeout=timeout + ) + assert execute_response.statement_id is not None + result_data = execute_response.result + if result_data is None: + return + row_factory, col_conv = self._result_schema(execute_response) + while True: + if result_data.data_array: + for data in result_data.data_array: + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + next_chunk_index = result_data.next_chunk_index + if result_data.external_links: + for external_link in result_data.external_links: + assert external_link.external_link is not None + next_chunk_index = external_link.next_chunk_index + response = self._http.get(external_link.external_link) + response.raise_for_status() + for data in response.json(): + yield row_factory(col_conv[i](value) for i, value in enumerate(data)) + if not next_chunk_index: + return + result_data = self._ws.statement_execution.get_statement_result_chunk_n( + execute_response.statement_id, next_chunk_index + ) + + def fetch_one(self, statement: str, disable_magic: bool = False, **kwargs) -> Row | None: + """Execute a query and fetch the first available record. + + This method is a wrapper over :py:meth:`fetch_all` and fetches only the first row + from the result set. If no records are available, it returns `None`. + + >>> row = see.fetch_one('SELECT * FROM samples.nyctaxi.trips LIMIT 1') + >>> if row: + >>> pickup_time, dropoff_time = row[0], row[1] + >>> pickup_zip = row.pickup_zip + >>> dropoff_zip = row['dropoff_zip'] + >>> all_fields = row.as_dict() + >>> print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') + + :param statement: str + SQL statement to execute + :param disable_magic: bool (optional) + Disables the magic of adding `LIMIT 1` to the statement. By default, it is `False`. 
+ :param kwargs: dict + Additional keyword arguments to pass to :py:meth:`fetch_all` + :return: Row | None + """ + disable_magic = disable_magic or self._disable_magic + if not disable_magic: + statement = self._add_limit(statement) + for row in self.fetch_all(statement, **kwargs): + return row + return None + + def fetch_value(self, statement: str, **kwargs) -> Any | None: + """Execute a query and fetch the first available value.""" + for (v,) in self.fetch_all(statement, **kwargs): + return v + return None + + def _statement_timeouts(self, timeout) -> tuple[timedelta, str | None]: + """Set server-side and client-side timeouts for statement execution.""" + if timeout is None: + timeout = self._timeout + wait_timeout = None + if MIN_PLATFORM_TIMEOUT <= timeout.total_seconds() <= MAX_PLATFORM_TIMEOUT: + # set server-side timeout + wait_timeout = f"{timeout.total_seconds()}s" + assert timeout is not None + return timeout, wait_timeout + + @staticmethod + def _parse_date(value: str) -> datetime.date: + """Parse date from string in ISO format.""" + year, month, day = value.split("-") + return datetime.date(int(year), int(month), int(day)) + + @staticmethod + def _parse_timestamp(value: str) -> datetime.datetime: + """Parse timestamp from string in ISO format.""" + # make it work with Python 3.7 to 3.10 as well + return datetime.datetime.fromisoformat(value.replace("Z", "+00:00")) + + @staticmethod + def _raise_if_needed(status: StatementStatus): + """Raise an exception if the statement status is failed, canceled, or closed.""" + if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: + return + status_error = status.error + if status_error is None: + status_error = ServiceError(message="unknown", error_code=ServiceErrorCode.UNKNOWN) + error_message = status_error.message + if error_message is None: + error_message = "" + if "SCHEMA_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "TABLE_OR_VIEW_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "DELTA_TABLE_NOT_FOUND" in error_message: + raise NotFound(error_message) + if "DELTA_MISSING_TRANSACTION_LOG" in error_message: + raise DataLoss(error_message) + mapping = { + ServiceErrorCode.ABORTED: errors.Aborted, + ServiceErrorCode.ALREADY_EXISTS: errors.AlreadyExists, + ServiceErrorCode.BAD_REQUEST: errors.BadRequest, + ServiceErrorCode.CANCELLED: errors.Cancelled, + ServiceErrorCode.DEADLINE_EXCEEDED: errors.DeadlineExceeded, + ServiceErrorCode.INTERNAL_ERROR: errors.InternalError, + ServiceErrorCode.IO_ERROR: errors.InternalError, + ServiceErrorCode.NOT_FOUND: errors.NotFound, + ServiceErrorCode.RESOURCE_EXHAUSTED: errors.ResourceExhausted, + ServiceErrorCode.SERVICE_UNDER_MAINTENANCE: errors.TemporarilyUnavailable, + ServiceErrorCode.TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, + ServiceErrorCode.UNAUTHENTICATED: errors.Unauthenticated, + ServiceErrorCode.UNKNOWN: errors.Unknown, + ServiceErrorCode.WORKSPACE_TEMPORARILY_UNAVAILABLE: errors.TemporarilyUnavailable, + } + error_code = status_error.error_code + if error_code is None: + error_code = ServiceErrorCode.UNKNOWN + error_class = mapping.get(error_code, errors.Unknown) + raise error_class(error_message) + + def _default_warehouse(self) -> str: + """Get the default warehouse id from the workspace client configuration + or DATABRICKS_WAREHOUSE_ID environment variable. 
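The _raise_if_needed helper above maps failure messages and ServiceErrorCode values onto the Databricks SDK error hierarchy, so callers can handle specific failures. A small sketch of the caller's side, where the table name is made up for illustration:

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.errors import NotFound

    from databricks.labs.lsql.core import StatementExecutionExt

    see = StatementExecutionExt(WorkspaceClient())
    try:
        see.execute("SELECT * FROM samples.nyctaxi.no_such_table")
    except NotFound as err:
        # TABLE_OR_VIEW_NOT_FOUND failures surface as the SDK's NotFound error
        print(f"missing table: {err}")
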
If not set, it will use + the first available warehouse that is not in the DELETED or DELETING state.""" + with self._lock: + if self._warehouse_id: + return self._warehouse_id + # if we create_autospec(WorkspaceClient), the warehouse_id is a MagicMock + if isinstance(self._ws.config.warehouse_id, str) and self._ws.config.warehouse_id: + self._warehouse_id = self._ws.config.warehouse_id + return self._ws.config.warehouse_id + ids = [] + for v in self._ws.warehouses.list(): + assert v.id is not None + if v.state in [State.DELETED, State.DELETING]: + continue + if v.state == State.RUNNING: + self._ws.config.warehouse_id = v.id + return self._ws.config.warehouse_id + ids.append(v.id) + if len(ids) > 0: + # otherwise - first warehouse + self._ws.config.warehouse_id = ids[0] + return self._ws.config.warehouse_id + raise ValueError( + "no warehouse_id=... given, " + "neither it is set in the WorkspaceClient(..., warehouse_id=...), " + "nor in the DATABRICKS_WAREHOUSE_ID environment variable" + ) + + @staticmethod + def _add_limit(statement: str) -> str: + """Add a limit 1 to the statement if it does not have one already.""" + statements = sqlglot.parse(statement, read="databricks") + if not statements: + raise ValueError(f"cannot parse statement: {statement}") + statement_ast = statements[0] + if isinstance(statement_ast, sqlglot.expressions.Select): + if statement_ast.limit is not None: + limit = statement_ast.args.get("limit", None) + if limit and limit.text("expression") != "1": + raise ValueError(f"limit is not 1: {limit.text('expression')}") + return statement_ast.limit(expression=1).sql("databricks") + return statement + + def _result_schema(self, execute_response: ExecuteStatementResponse): + """Get the result schema from the execute response.""" + manifest = execute_response.manifest + if not manifest: + msg = f"missing manifest: {execute_response}" + raise ValueError(msg) + manifest_schema = manifest.schema + if not manifest_schema: + msg = f"missing schema: {manifest}" + raise ValueError(msg) + col_names = [] + col_conv = [] + columns = manifest_schema.columns + if not columns: + columns = [] + for col in columns: + assert col.name is not None + assert col.type_name is not None + col_names.append(col.name) + conv = self._type_converters.get(col.type_name, None) + if conv is None: + msg = f"{col.name} has no {col.type_name.value} converter" + raise ValueError(msg) + col_conv.append(conv) + return Row.factory(col_names), col_conv diff --git a/src/databricks/labs/lsql/deployment.py b/src/databricks/labs/lsql/deployment.py new file mode 100644 index 00000000..7dd42a67 --- /dev/null +++ b/src/databricks/labs/lsql/deployment.py @@ -0,0 +1,40 @@ +import logging +import pkgutil +from typing import Any + +from databricks.labs.lsql.backends import Dataclass, SqlBackend + +logger = logging.getLogger(__name__) + + +class SchemaDeployer: + def __init__(self, sql_backend: SqlBackend, inventory_schema: str, mod: Any): + self._sql_backend = sql_backend + self._inventory_schema = inventory_schema + self._module = mod + + def deploy_schema(self): + logger.info(f"Ensuring {self._inventory_schema} database exists") + self._sql_backend.execute(f"CREATE SCHEMA IF NOT EXISTS hive_metastore.{self._inventory_schema}") + + def delete_schema(self): + logger.info(f"deleting {self._inventory_schema} database") + + self._sql_backend.execute(f"DROP SCHEMA IF EXISTS hive_metastore.{self._inventory_schema} CASCADE") + + def deploy_table(self, name: str, klass: Dataclass): + logger.info(f"Ensuring 
{self._inventory_schema}.{name} table exists") + self._sql_backend.create_table(f"hive_metastore.{self._inventory_schema}.{name}", klass) + + def deploy_view(self, name: str, relative_filename: str): + query = self._load(relative_filename) + logger.info(f"Ensuring {self._inventory_schema}.{name} view matches {relative_filename} contents") + ddl = f"CREATE OR REPLACE VIEW hive_metastore.{self._inventory_schema}.{name} AS {query}" + self._sql_backend.execute(ddl) + + def _load(self, relative_filename: str) -> str: + data = pkgutil.get_data(self._module.__name__, relative_filename) + assert data is not None + sql = data.decode("utf-8") + sql = sql.replace("$inventory", f"hive_metastore.{self._inventory_schema}") + return sql diff --git a/src/databricks/labs/lsql/lib.py b/src/databricks/labs/lsql/lib.py deleted file mode 100644 index a6eaee92..00000000 --- a/src/databricks/labs/lsql/lib.py +++ /dev/null @@ -1,326 +0,0 @@ -import base64 -import datetime -import functools -import json -import logging -import random -import time -from datetime import timedelta -from typing import Any, Dict, Iterator, List, Optional - -from databricks.sdk.core import DatabricksError -from databricks.sdk.service.sql import (ColumnInfoTypeName, Disposition, - ExecuteStatementResponse, Format, - StatementExecutionAPI, StatementState, - StatementStatus) - -_LOG = logging.getLogger(__name__) - - -class Row(tuple): - - def __new__(cls, columns: List[str], values: List[Any]) -> 'Row': - row = tuple.__new__(cls, values) - row.__columns__ = columns - return row - - # Python SDK convention - def as_dict(self) -> Dict[str, any]: - return dict(zip(self.__columns__, self)) - - # PySpark convention - asDict = as_dict - - # PySpark convention - def __contains__(self, item): - return item in self.__columns__ - - def __getitem__(self, col): - if isinstance(col, (int, slice)): - return super().__getitem__(col) - # if columns are named `2 + 2`, for example - return self.__getattr__(col) - - def __getattr__(self, col): - try: - idx = self.__columns__.index(col) - return self[idx] - except IndexError: - raise AttributeError(col) - except ValueError: - raise AttributeError(col) - - def __repr__(self): - return f"Row({', '.join(f'{k}={v}' for (k, v) in zip(self.__columns__, self))})" - - -class StatementExecutionExt(StatementExecutionAPI): - """ - Execute SQL statements in a stateless manner. - - Primary use-case of :py:meth:`iterate_rows` and :py:meth:`execute` methods is oriented at executing SQL queries in - a stateless manner straight away from Databricks SDK for Python, without requiring any external dependencies. - Results are fetched in JSON format through presigned external links. This is perfect for serverless applications - like AWS Lambda, Azure Functions, or any other containerised short-lived applications, where container startup - time is faster with the smaller dependency set. - - .. code-block: - - for (pickup_zip, dropoff_zip) in w.statement_execution.iterate_rows(warehouse_id, - 'SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10', catalog='samples'): - print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') - - Method :py:meth:`iterate_rows` returns an iterator of objects, that resemble :class:`pyspark.sql.Row` APIs, but full - compatibility is not the goal of this implementation. - - .. 
code-block:: - - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') - - When you only need to execute the query and have no need to iterate over results, use the :py:meth:`execute`. - - .. code-block:: - - w.statement_execution.execute(warehouse_id, 'CREATE TABLE foo AS SELECT * FROM range(10)') - - Applications, that need to a more traditional SQL Python APIs with cursors, efficient data transfer of hundreds of - megabytes or gigabytes of data serialized in Apache Arrow format, and low result fetching latency, should use - the stateful Databricks SQL Connector for Python. - """ - - def __init__(self, api_client): - super().__init__(api_client) - self._type_converters = { - ColumnInfoTypeName.ARRAY: json.loads, - ColumnInfoTypeName.BINARY: base64.b64decode, - ColumnInfoTypeName.BOOLEAN: bool, - ColumnInfoTypeName.CHAR: str, - ColumnInfoTypeName.DATE: self._parse_date, - ColumnInfoTypeName.DOUBLE: float, - ColumnInfoTypeName.FLOAT: float, - ColumnInfoTypeName.INT: int, - ColumnInfoTypeName.LONG: int, - ColumnInfoTypeName.MAP: json.loads, - ColumnInfoTypeName.NULL: lambda _: None, - ColumnInfoTypeName.SHORT: int, - ColumnInfoTypeName.STRING: str, - ColumnInfoTypeName.STRUCT: json.loads, - ColumnInfoTypeName.TIMESTAMP: self._parse_timestamp, - } - - @staticmethod - def _parse_date(value: str) -> datetime.date: - year, month, day = value.split('-') - return datetime.date(int(year), int(month), int(day)) - - @staticmethod - def _parse_timestamp(value: str) -> datetime.datetime: - # make it work with Python 3.7 to 3.10 as well - return datetime.datetime.fromisoformat(value.replace('Z', '+00:00')) - - @staticmethod - def _raise_if_needed(status: StatementStatus): - if status.state not in [StatementState.FAILED, StatementState.CANCELED, StatementState.CLOSED]: - return - err = status.error - if err is not None: - message = err.message.strip() - error_code = err.error_code.value - raise DatabricksError(message, error_code=error_code) - raise DatabricksError(status.state.value) - - def execute(self, - warehouse_id: str, - statement: str, - *, - byte_limit: Optional[int] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, - timeout: timedelta = timedelta(minutes=20), - ) -> ExecuteStatementResponse: - """(Experimental) Execute a SQL statement and block until results are ready, - including starting the warehouse if needed. - - This is a high-level implementation that works with fetching records in JSON format. - It can be considered as a quick way to run SQL queries by just depending on - Databricks SDK for Python without the need of any other compiled library dependencies. - - This method is a higher-level wrapper over :py:meth:`execute_statement` and fetches results - in JSON format through the external link disposition, with client-side polling until - the statement succeeds in execution. Whenever the statement is failed, cancelled, or - closed, this method raises `DatabricksError` with the state message and the relevant - error code. - - To seamlessly iterate over the rows from query results, please use :py:meth:`iterate_rows`. - - :param warehouse_id: str - Warehouse upon which to execute a statement. 
- :param statement: str - SQL statement to execute - :param byte_limit: int (optional) - Applies the given byte limit to the statement's result size. Byte counts are based on internal - representations and may not match measurable sizes in the JSON format. - :param catalog: str (optional) - Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. - :param schema: str (optional) - Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. - :param timeout: timedelta (optional) - Timeout after which the query is cancelled. If timeout is less than 50 seconds, - it is handled on the server side. If the timeout is greater than 50 seconds, - Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. - :return: ExecuteStatementResponse - """ - # The wait_timeout field must be 0 seconds (disables wait), - # or between 5 seconds and 50 seconds. - wait_timeout = None - if 5 <= timeout.total_seconds() <= 50: - # set server-side timeout - wait_timeout = f"{timeout.total_seconds()}s" - - _LOG.debug(f"Executing SQL statement: {statement}") - - # format is limited to Format.JSON_ARRAY, but other iterations may include ARROW_STREAM. - immediate_response = self.execute_statement(warehouse_id=warehouse_id, - statement=statement, - catalog=catalog, - schema=schema, - disposition=Disposition.EXTERNAL_LINKS, - format=Format.JSON_ARRAY, - byte_limit=byte_limit, - wait_timeout=wait_timeout) - - if immediate_response.status.state == StatementState.SUCCEEDED: - return immediate_response - - self._raise_if_needed(immediate_response.status) - - attempt = 1 - status_message = "polling..." - deadline = time.time() + timeout.total_seconds() - while time.time() < deadline: - res = self.get_statement(immediate_response.statement_id) - if res.status.state == StatementState.SUCCEEDED: - return ExecuteStatementResponse(manifest=res.manifest, - result=res.result, - statement_id=res.statement_id, - status=res.status) - status_message = f"current status: {res.status.state.value}" - self._raise_if_needed(res.status) - sleep = attempt - if sleep > 10: - # sleep 10s max per attempt - sleep = 10 - _LOG.debug(f"SQL statement {res.statement_id}: {status_message} (sleeping ~{sleep}s)") - time.sleep(sleep + random.random()) - attempt += 1 - self.cancel_execution(immediate_response.statement_id) - msg = f"timed out after {timeout}: {status_message}" - raise TimeoutError(msg) - - def iterate_rows(self, - warehouse_id: str, - statement: str, - *, - byte_limit: Optional[int] = None, - catalog: Optional[str] = None, - schema: Optional[str] = None, - timeout: timedelta = timedelta(minutes=20), - ) -> Iterator[Row]: - """(Experimental) Execute a query and iterate over all available records. - - This method is a wrapper over :py:meth:`execute` with the handling of chunked result - processing and deserialization of those into separate rows, which are yielded from - a returned iterator. Every row API resembles those of :class:`pyspark.sql.Row`, - but full compatibility is not the goal of this implementation. - - .. 
code-block:: - - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') - - :param warehouse_id: str - Warehouse upon which to execute a statement. - :param statement: str - SQL statement to execute - :param byte_limit: int (optional) - Applies the given byte limit to the statement's result size. Byte counts are based on internal - representations and may not match measurable sizes in the JSON format. - :param catalog: str (optional) - Sets default catalog for statement execution, similar to `USE CATALOG` in SQL. - :param schema: str (optional) - Sets default schema for statement execution, similar to `USE SCHEMA` in SQL. - :param timeout: timedelta (optional) - Timeout after which the query is cancelled. If timeout is less than 50 seconds, - it is handled on the server side. If the timeout is greater than 50 seconds, - Databricks SDK for Python cancels the statement execution and throws `TimeoutError`. - :return: Iterator[Row] - """ - execute_response = self.execute(warehouse_id, - statement, - byte_limit=byte_limit, - catalog=catalog, - schema=schema, - timeout=timeout) - if execute_response.result.external_links is None: - return [] - for row in self._iterate_external_disposition(execute_response): - yield row - - def _result_schema(self, execute_response: ExecuteStatementResponse): - col_names = [] - col_conv = [] - for col in execute_response.manifest.schema.columns: - col_names.append(col.name) - conv = self._type_converters.get(col.type_name, None) - if conv is None: - msg = f"{col.name} has no {col.type_name.value} converter" - raise ValueError(msg) - col_conv.append(conv) - row_factory = functools.partial(Row, col_names) - return row_factory, col_conv - - def _iterate_external_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: - # ensure that we close the HTTP session after fetching the external links - result_data = execute_response.result - row_factory, col_conv = self._result_schema(execute_response) - with self._api._new_session() as http: - while True: - for external_link in result_data.external_links: - response = http.get(external_link.external_link) - response.raise_for_status() - for data in response.json(): - yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - - if external_link.next_chunk_index is None: - return - - result_data = self.get_statement_result_chunk_n(execute_response.statement_id, - external_link.next_chunk_index) - - def _iterate_inline_disposition(self, execute_response: ExecuteStatementResponse) -> Iterator[Row]: - result_data = execute_response.result - row_factory, col_conv = self._result_schema(execute_response) - while True: - # case for Disposition.INLINE, where we get rows embedded into a response - for data in result_data.data_array: - # enumerate() + iterator + tuple constructor makes it more performant - # on larger humber of records for Python, even though it's less - # readable code. 
- yield row_factory(col_conv[i](value) for i, value in enumerate(data)) - - if result_data.next_chunk_index is None: - return - - result_data = self.get_statement_result_chunk_n(execute_response.statement_id, - result_data.next_chunk_index) \ No newline at end of file diff --git a/src/databricks/labs/lsql/py.typed b/src/databricks/labs/lsql/py.typed new file mode 100644 index 00000000..731adc94 --- /dev/null +++ b/src/databricks/labs/lsql/py.typed @@ -0,0 +1 @@ +# Marker file for PEP 561. The databricks-labs-lsql package uses inline types. diff --git a/src/databricks/labs/lsql/sql_backend.py b/src/databricks/labs/lsql/sql_backend.py deleted file mode 100644 index 1c3a9c27..00000000 --- a/src/databricks/labs/lsql/sql_backend.py +++ /dev/null @@ -1,194 +0,0 @@ -import dataclasses -import logging -import re -from abc import ABC, abstractmethod -from typing import Iterator, ClassVar - -from databricks.sdk import WorkspaceClient - -from databricks.labs.lsql.lib import StatementExecutionExt - -logger = logging.getLogger(__name__) - -class SqlBackend(ABC): - @abstractmethod - def execute(self, sql): - raise NotImplementedError - - @abstractmethod - def fetch(self, sql) -> Iterator[any]: - raise NotImplementedError - - @abstractmethod - def save_table(self, full_name: str, rows: list[any], klass: type, mode: str = "append"): - raise NotImplementedError - - def create_table(self, full_name: str, klass: type): - ddl = f"CREATE TABLE IF NOT EXISTS {full_name} ({self._schema_for(klass)}) USING DELTA" - self.execute(ddl) - - _builtin_type_mapping: ClassVar[dict[type, str]] = {str: "STRING", int: "INT", bool: "BOOLEAN", float: "FLOAT"} - - @classmethod - def _schema_for(cls, klass): - fields = [] - for f in dataclasses.fields(klass): - if f.type not in cls._builtin_type_mapping: - msg = f"Cannot auto-convert {f.type}" - raise SyntaxError(msg) - not_null = " NOT NULL" - if f.default is None: - not_null = "" - spark_type = cls._builtin_type_mapping[f.type] - fields.append(f"{f.name} {spark_type}{not_null}") - return ", ".join(fields) - - @classmethod - def _filter_none_rows(cls, rows, full_name): - if len(rows) == 0: - return rows - - results = [] - nullable_fields = set() - - for field in dataclasses.fields(rows[0]): - if field.default is None: - nullable_fields.add(field.name) - - for row in rows: - if row is None: - continue - row_contains_none = False - for column, value in dataclasses.asdict(row).items(): - if value is None and column not in nullable_fields: - logger.warning(f"[{full_name}] Field {column} is None, filtering row") - row_contains_none = True - break - - if not row_contains_none: - results.append(row) - return results - - -class StatementExecutionBackend(SqlBackend): - def __init__(self, ws: WorkspaceClient, warehouse_id, *, max_records_per_batch: int = 1000): - self._sql = StatementExecutionExt(ws.api_client) - self._warehouse_id = warehouse_id - self._max_records_per_batch = max_records_per_batch - - def execute(self, sql): - logger.debug(f"[api][execute] {sql}") - self._sql.execute(self._warehouse_id, sql) - - def fetch(self, sql) -> Iterator[any]: - logger.debug(f"[api][fetch] {sql}") - return self._sql.execute_fetch_all(self._warehouse_id, sql) - - def save_table(self, full_name: str, rows: list[any], klass: dataclasses.dataclass, mode="append"): - if mode == "overwrite": - msg = "Overwrite mode is not yet supported" - raise NotImplementedError(msg) - rows = self._filter_none_rows(rows, full_name) - self.create_table(full_name, klass) - if len(rows) == 0: - return - fields = 
dataclasses.fields(klass) - field_names = [f.name for f in fields] - for i in range(0, len(rows), self._max_records_per_batch): - batch = rows[i : i + self._max_records_per_batch] - vals = "), (".join(self._row_to_sql(r, fields) for r in batch) - sql = f'INSERT INTO {full_name} ({", ".join(field_names)}) VALUES ({vals})' - self.execute(sql) - - @staticmethod - def _row_to_sql(row, fields): - data = [] - for f in fields: - value = getattr(row, f.name) - if value is None: - data.append("NULL") - elif f.type == bool: - data.append("TRUE" if value else "FALSE") - elif f.type == str: - value = str(value).replace("'", "''") - data.append(f"'{value}'") - elif f.type == int: - data.append(f"{value}") - else: - msg = f"unknown type: {f.type}" - raise ValueError(msg) - return ", ".join(data) - - -class RuntimeBackend(SqlBackend): - def __init__(self): - from pyspark.sql.session import SparkSession - - if "DATABRICKS_RUNTIME_VERSION" not in os.environ: - msg = "Not in the Databricks Runtime" - raise RuntimeError(msg) - - self._spark = SparkSession.builder.getOrCreate() - - def execute(self, sql): - logger.debug(f"[spark][execute] {sql}") - self._spark.sql(sql) - - def fetch(self, sql) -> Iterator[any]: - logger.debug(f"[spark][fetch] {sql}") - return self._spark.sql(sql).collect() - - def save_table(self, full_name: str, rows: list[any], klass: dataclasses.dataclass, mode: str = "append"): - rows = self._filter_none_rows(rows, full_name) - - if len(rows) == 0: - self.create_table(full_name, klass) - return - # pyspark deals well with lists of dataclass instances, as long as schema is provided - df = self._spark.createDataFrame(rows, self._schema_for(rows[0])) - df.write.saveAsTable(full_name, mode=mode) - -class MockBackend(SqlBackend): - def __init__(self, *, fails_on_first: dict | None = None, rows: dict | None = None): - self._fails_on_first = fails_on_first - if not rows: - rows = {} - self._rows = rows - self._save_table = [] - self.queries = [] - - def _sql(self, sql): - logger.debug(f"Mock backend.sql() received SQL: {sql}") - seen_before = sql in self.queries - self.queries.append(sql) - if not seen_before and self._fails_on_first is not None: - for match, failure in self._fails_on_first.items(): - if match in sql: - raise RuntimeError(failure) - - def execute(self, sql): - self._sql(sql) - - def fetch(self, sql) -> Iterator[any]: - self._sql(sql) - rows = [] - if self._rows: - for pattern in self._rows.keys(): - r = re.compile(pattern) - if r.match(sql): - logger.debug(f"Found match: {sql}") - rows.extend(self._rows[pattern]) - logger.debug(f"Returning rows: {rows}") - return iter(rows) - - def save_table(self, full_name: str, rows: list[any], klass, mode: str = "append"): - if klass.__class__ == type: - self._save_table.append((full_name, rows, mode)) - - def rows_written_for(self, full_name: str, mode: str) -> list[any]: - rows = [] - for stub_full_name, stub_rows, stub_mode in self._save_table: - if not (stub_full_name == full_name and stub_mode == mode): - continue - rows += stub_rows - return rows diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index 255c9c76..00000000 --- a/tests/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# SPDX-FileCopyrightText: 2023-present Serge Smertin -# -# SPDX-License-Identifier: MIT diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..c4db6489 --- 
/dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,92 @@ +import json +import logging +import os +import pathlib +import string +import sys +from typing import Callable, MutableMapping + +import pytest +from databricks.labs.blueprint.logger import install_logger +from databricks.sdk import WorkspaceClient +from pytest import fixture + +from databricks.labs.lsql.__about__ import __version__ + +install_logger() +logging.getLogger("databricks").setLevel("DEBUG") + + +def _is_in_debug() -> bool: + return os.path.basename(sys.argv[0]) in {"_jb_pytest_runner.py", "testlauncher.py"} + + +@fixture # type: ignore[no-redef] +def debug_env_name(): + return "ucws" + + +@fixture +def debug_env(monkeypatch, debug_env_name) -> MutableMapping[str, str]: + if not _is_in_debug(): + return os.environ + conf_file = pathlib.Path.home() / ".databricks/debug-env.json" + if not conf_file.exists(): + return os.environ + with conf_file.open("r") as f: + conf = json.load(f) + if debug_env_name not in conf: + sys.stderr.write( + f"""{debug_env_name} not found in ~/.databricks/debug-env.json + + this usually means that you have to add the following fixture to + conftest.py file in the relevant directory: + + @fixture + def debug_env_name(): + return 'ENV_NAME' # where ENV_NAME is one of: {", ".join(conf.keys())} + """ + ) + msg = f"{debug_env_name} not found in ~/.databricks/debug-env.json" + raise KeyError(msg) + for k, v in conf[debug_env_name].items(): + monkeypatch.setenv(k, v) + return os.environ + + +@fixture +def make_random(): + import random + + def inner(k=16) -> str: + charset = string.ascii_uppercase + string.ascii_lowercase + string.digits + return "".join(random.choices(charset, k=int(k))) + + return inner + + +@fixture +def product_info(): + return "lsql", __version__ + + +@fixture +def ws(product_info, debug_env) -> WorkspaceClient: + # Use variables from Unified Auth + # See https://databricks-sdk-py.readthedocs.io/en/latest/authentication.html + product_name, product_version = product_info + return WorkspaceClient(host=debug_env["DATABRICKS_HOST"], product=product_name, product_version=product_version) + + +@pytest.fixture +def env_or_skip(debug_env) -> Callable[[str], str]: + skip = pytest.skip + if _is_in_debug(): + skip = pytest.fail # type: ignore[assignment] + + def inner(var: str) -> str: + if var not in debug_env: + skip(f"Environment variable {var} is missing") + return debug_env[var] + + return inner diff --git a/tests/integration/test_integration.py b/tests/integration/test_integration.py new file mode 100644 index 00000000..5514058f --- /dev/null +++ b/tests/integration/test_integration.py @@ -0,0 +1,85 @@ +import logging + +import pytest +from databricks.sdk.service.sql import Disposition + +from databricks.labs.lsql.core import StatementExecutionExt + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize("disposition", [None, Disposition.INLINE, Disposition.EXTERNAL_LINKS]) +def test_sql_execution_chunked(ws, disposition): + see = StatementExecutionExt(ws, disposition=disposition) + total = 0 + for (x,) in see("SELECT id FROM range(2000000)"): + total += x + assert total == 1999999000000 + + +def test_sql_execution(ws, env_or_skip): + results = [] + see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID")) + for pickup_zip, dropoff_zip in see.fetch_all( + "SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10", catalog="samples" + ): + results.append((pickup_zip, dropoff_zip)) + assert results == [ + (10282, 10171), + (10110, 10110), 
+ (10103, 10023), + (10022, 10017), + (10110, 10282), + (10009, 10065), + (10153, 10199), + (10112, 10069), + (10023, 10153), + (10012, 10003), + ] + + +def test_sql_execution_partial(ws, env_or_skip): + results = [] + see = StatementExecutionExt(ws, warehouse_id=env_or_skip("TEST_DEFAULT_WAREHOUSE_ID"), catalog="samples") + for row in see("SELECT * FROM nyctaxi.trips LIMIT 10"): + pickup_time, dropoff_time = row[0], row[1] + pickup_zip = row.pickup_zip + dropoff_zip = row["dropoff_zip"] + all_fields = row.as_dict() + logger.info(f"{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}") + results.append((pickup_zip, dropoff_zip)) + assert results == [ + (10282, 10171), + (10110, 10110), + (10103, 10023), + (10022, 10017), + (10110, 10282), + (10009, 10065), + (10153, 10199), + (10112, 10069), + (10023, 10153), + (10012, 10003), + ] + + +def test_fetch_one(ws): + see = StatementExecutionExt(ws) + assert see.fetch_one("SELECT 1") == (1,) + + +def test_fetch_one_fails_if_limit_is_bigger(ws): + see = StatementExecutionExt(ws) + with pytest.raises(ValueError): + see.fetch_one("SELECT * FROM samples.nyctaxi.trips LIMIT 100") + + +def test_fetch_one_works(ws): + see = StatementExecutionExt(ws) + row = see.fetch_one("SELECT * FROM samples.nyctaxi.trips LIMIT 1") + assert row.pickup_zip == 10282 + + +def test_fetch_value(ws): + see = StatementExecutionExt(ws) + count = see.fetch_value("SELECT COUNT(*) FROM samples.nyctaxi.trips") + assert count == 21932 diff --git a/tests/test_integration.py b/tests/test_integration.py deleted file mode 100644 index e9fff9e0..00000000 --- a/tests/test_integration.py +++ /dev/null @@ -1,31 +0,0 @@ -import functools - - -def test_sql_execution_chunked(w): - all_warehouses = w.warehouses.list() - assert len(w.warehouses.list()) > 0, 'at least one SQL warehouse required' - warehouse_id = all_warehouses[0].id - - total = 0 - fetch = functools.partial(w.statement_execution.iterate_rows, warehouse_id) - for (x, ) in fetch('SELECT id FROM range(2000000)'): - total += x - print(total) - - -def test_sql_execution(w, env_or_skip): - warehouse_id = env_or_skip('TEST_DEFAULT_WAREHOUSE_ID') - for (pickup_zip, dropoff_zip) in w.statement_execution.iterate_rows( - warehouse_id, 'SELECT pickup_zip, dropoff_zip FROM nyctaxi.trips LIMIT 10', catalog='samples'): - print(f'pickup_zip={pickup_zip}, dropoff_zip={dropoff_zip}') - - -def test_sql_execution_partial(w, env_or_skip): - warehouse_id = env_or_skip('TEST_DEFAULT_WAREHOUSE_ID') - iterate_rows = functools.partial(w.statement_execution.iterate_rows, warehouse_id, catalog='samples') - for row in iterate_rows('SELECT * FROM nyctaxi.trips LIMIT 10'): - pickup_time, dropoff_time = row[0], row[1] - pickup_zip = row.pickup_zip - dropoff_zip = row['dropoff_zip'] - all_fields = row.as_dict() - print(f'{pickup_zip}@{pickup_time} -> {dropoff_zip}@{dropoff_time}: {all_fields}') \ No newline at end of file diff --git a/tests/test_lib.py b/tests/test_lib.py deleted file mode 100644 index 4ea8ff8a..00000000 --- a/tests/test_lib.py +++ /dev/null @@ -1,219 +0,0 @@ -import datetime - -import pytest - -from databricks.sdk import WorkspaceClient -from databricks.sdk.core import DatabricksError -from databricks.labs.lsql import Row -from databricks.sdk.service.sql import (ColumnInfo, ColumnInfoTypeName, - Disposition, ExecuteStatementResponse, - ExternalLink, Format, - GetStatementResponse, ResultData, - ResultManifest, ResultSchema, - ServiceError, ServiceErrorCode, - StatementState, StatementStatus, - timedelta) - - -def 
test_execute_poll_succeeds(config, mocker): - w = WorkspaceClient(config=config) - - execute_statement = mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - status=StatementStatus(state=StatementState.PENDING), - statement_id='bcd', - )) - - get_statement = mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.get_statement', - return_value=GetStatementResponse( - manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id='bcd', - status=StatementStatus(state=StatementState.SUCCEEDED))) - - response = w.statement_execution.execute('abc', 'SELECT 2+2') - - assert response.status.state == StatementState.SUCCEEDED - assert response.result.byte_count == 100500 - execute_statement.assert_called_with(warehouse_id='abc', - statement='SELECT 2+2', - disposition=Disposition.EXTERNAL_LINKS, - format=Format.JSON_ARRAY, - byte_limit=None, - catalog=None, - schema=None, - wait_timeout=None) - get_statement.assert_called_with('bcd') - - -def test_execute_fails(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse( - statement_id='bcd', - status=StatementStatus(state=StatementState.FAILED, - error=ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, - message='oops...')))) - - with pytest.raises(DatabricksError) as err: - w.statement_execution.execute('abc', 'SELECT 2+2') - assert err.value.error_code == 'RESOURCE_EXHAUSTED' - assert str(err.value) == 'oops...' - - -def test_execute_poll_waits(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse(status=StatementStatus(state=StatementState.PENDING), - statement_id='bcd', - )) - - runs = [] - - def _get_statement(statement_id): - assert statement_id == 'bcd' - if len(runs) == 0: - runs.append(1) - return GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), - statement_id='bcd') - - return GetStatementResponse(manifest=ResultManifest(), - result=ResultData(byte_count=100500), - statement_id='bcd', - status=StatementStatus(state=StatementState.SUCCEEDED)) - - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.get_statement', wraps=_get_statement) - - response = w.statement_execution.execute('abc', 'SELECT 2+2') - - assert response.status.state == StatementState.SUCCEEDED - assert response.result.byte_count == 100500 - - -def test_execute_poll_timeouts_on_client(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - return_value=ExecuteStatementResponse(status=StatementStatus(state=StatementState.PENDING), - statement_id='bcd', - )) - - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.get_statement', - return_value=GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), - statement_id='bcd')) - - cancel_execution = mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.cancel_execution') - - with pytest.raises(TimeoutError): - w.statement_execution.execute('abc', 'SELECT 2+2', timeout=timedelta(seconds=1)) - - cancel_execution.assert_called_with('bcd') - - -def test_fetch_all_no_chunks(config, mocker): - w = WorkspaceClient(config=config) - - mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement', - 
-                 return_value=ExecuteStatementResponse(
-                     status=StatementStatus(state=StatementState.SUCCEEDED),
-                     manifest=ResultManifest(schema=ResultSchema(columns=[
-                         ColumnInfo(name='id', type_name=ColumnInfoTypeName.INT),
-                         ColumnInfo(name='since', type_name=ColumnInfoTypeName.DATE),
-                         ColumnInfo(name='now', type_name=ColumnInfoTypeName.TIMESTAMP),
-                     ])),
-                     result=ResultData(external_links=[ExternalLink(external_link='https://singed-url')]),
-                     statement_id='bcd',
-                 ))
-
-    raw_response = mocker.Mock()
-    raw_response.json = lambda: [["1", "2023-09-01", "2023-09-01T13:21:53Z"],
-                                 ["2", "2023-09-01", "2023-09-01T13:21:53Z"]]
-
-    http_get = mocker.patch('requests.sessions.Session.get', return_value=raw_response)
-
-    rows = list(
-        w.statement_execution.iterate_rows(
-            'abc', 'SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)'))
-
-    assert len(rows) == 2
-    assert rows[0].id == 1
-    assert isinstance(rows[0].since, datetime.date)
-    assert isinstance(rows[0].now, datetime.datetime)
-
-    http_get.assert_called_with('https://singed-url')
-
-
-def test_fetch_all_no_chunks_no_converter(config, mocker):
-    w = WorkspaceClient(config=config)
-
-    mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement',
-                 return_value=ExecuteStatementResponse(
-                     status=StatementStatus(state=StatementState.SUCCEEDED),
-                     manifest=ResultManifest(schema=ResultSchema(
-                         columns=[ColumnInfo(name='id', type_name=ColumnInfoTypeName.INTERVAL), ])),
-                     result=ResultData(external_links=[ExternalLink(external_link='https://singed-url')]),
-                     statement_id='bcd',
-                 ))
-
-    raw_response = mocker.Mock()
-    raw_response.json = lambda: [["1"], ["2"]]
-
-    mocker.patch('requests.sessions.Session.get', return_value=raw_response)
-
-    with pytest.raises(ValueError):
-        list(w.statement_execution.iterate_rows('abc', 'SELECT id FROM range(2)'))
-
-
-def test_fetch_all_two_chunks(config, mocker):
-    w = WorkspaceClient(config=config)
-
-    mocker.patch('databricks.sdk.service.sql.StatementExecutionAPI.execute_statement',
-                 return_value=ExecuteStatementResponse(
-                     status=StatementStatus(state=StatementState.SUCCEEDED),
-                     manifest=ResultManifest(schema=ResultSchema(columns=[
-                         ColumnInfo(name='id', type_name=ColumnInfoTypeName.INT),
-                         ColumnInfo(name='now', type_name=ColumnInfoTypeName.TIMESTAMP),
-                     ])),
-                     result=ResultData(
-                         external_links=[ExternalLink(external_link='https://first', next_chunk_index=1)]),
-                     statement_id='bcd',
-                 ))
-
-    next_chunk = mocker.patch(
-        'databricks.sdk.service.sql.StatementExecutionAPI.get_statement_result_chunk_n',
-        return_value=ResultData(external_links=[ExternalLink(external_link='https://second')]))
-
-    raw_response = mocker.Mock()
-    raw_response.json = lambda: [["1", "2023-09-01T13:21:53Z"], ["2", "2023-09-01T13:21:53Z"]]
-    http_get = mocker.patch('requests.sessions.Session.get', return_value=raw_response)
-
-    rows = list(w.statement_execution.iterate_rows('abc', 'SELECT id, NOW() AS now FROM range(2)'))
-
-    assert len(rows) == 4
-
-    assert http_get.call_args_list == [mocker.call('https://first'), mocker.call('https://second')]
-    next_chunk.assert_called_with('bcd', 1)
-
-
-def test_row():
-    row = Row(['a', 'b', 'c'], [1, 2, 3])
-
-    a, b, c = row
-    assert a == 1
-    assert b == 2
-    assert c == 3
-
-    assert 'a' in row
-    assert row.a == 1
-    assert row['a'] == 1
-
-    as_dict = row.as_dict()
-    assert as_dict == row.asDict()
-    assert as_dict['b'] == 2
-
-    assert 'Row(a=1, b=2, c=3)' == str(row)
-
-    with pytest.raises(AttributeError):
-        print(row.x)
\ No newline at end of file
diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 00000000..368c7c1d
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,6 @@
+import logging
+
+from databricks.labs.blueprint.logger import install_logger
+
+install_logger()
+logging.getLogger("databricks").setLevel("DEBUG")
diff --git a/tests/unit/test_backends.py b/tests/unit/test_backends.py
new file mode 100644
index 00000000..f6614663
--- /dev/null
+++ b/tests/unit/test_backends.py
@@ -0,0 +1,348 @@
+import os
+import sys
+from dataclasses import dataclass
+from unittest import mock
+from unittest.mock import MagicMock, create_autospec
+
+import pytest
+from databricks.sdk import WorkspaceClient
+from databricks.sdk.errors import (
+    BadRequest,
+    DataLoss,
+    NotFound,
+    PermissionDenied,
+    Unknown,
+)
+from databricks.sdk.service.sql import (
+    ColumnInfo,
+    ColumnInfoTypeName,
+    ExecuteStatementResponse,
+    Format,
+    ResultData,
+    ResultManifest,
+    ResultSchema,
+    StatementState,
+    StatementStatus,
+)
+
+from databricks.labs.lsql import Row
+from databricks.labs.lsql.backends import (
+    MockBackend,
+    RuntimeBackend,
+    StatementExecutionBackend,
+)
+
+# pylint: disable=protected-access
+
+
+@dataclass
+class Foo:
+    first: str
+    second: bool
+
+
+@dataclass
+class Baz:
+    first: str
+    second: str | None = None
+
+
+@dataclass
+class Bar:
+    first: str
+    second: bool
+    third: float
+
+
+def test_statement_execution_backend_execute_happy():
+    ws = create_autospec(WorkspaceClient)
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc")
+
+    seb.execute("CREATE TABLE foo")
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="abc",
+        statement="CREATE TABLE foo",
+        catalog=None,
+        schema=None,
+        disposition=None,
+        format=Format.JSON_ARRAY,
+        byte_limit=None,
+        wait_timeout=None,
+    )
+
+
+def test_statement_execution_backend_fetch_happy():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        result=ResultData(data_array=[["1"], ["2"], ["3"]]),
+        statement_id="bcd",
+    )
+
+    seb = StatementExecutionBackend(ws, "abc")
+
+    result = list(seb.fetch("SELECT id FROM range(3)"))
+
+    assert [Row(id=1), Row(id=2), Row(id=3)] == result
+
+
+def test_statement_execution_backend_save_table_overwrite(mocker):
+    seb = StatementExecutionBackend(mocker.Mock(), "abc")
+    with pytest.raises(NotImplementedError):
+        seb.save_table("a.b.c", [1, 2, 3], Bar, mode="overwrite")
+
+
+def test_statement_execution_backend_save_table_empty_records():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc")
+
+    seb.save_table("a.b.c", [], Bar)
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="abc",
+        statement="CREATE TABLE IF NOT EXISTS a.b.c "
+        "(first STRING NOT NULL, second BOOLEAN NOT NULL, third FLOAT NOT NULL) USING DELTA",
+        catalog=None,
+        schema=None,
+        disposition=None,
+        format=Format.JSON_ARRAY,
+        byte_limit=None,
+        wait_timeout=None,
+    )
+
+
+def test_statement_execution_backend_save_table_two_records():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc")
+
+    seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
+
+    ws.statement_execution.execute_statement.assert_has_calls(
+        [
+            mock.call(
+                warehouse_id="abc",
+                statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+        ]
+    )
+
+
+def test_statement_execution_backend_save_table_in_batches_of_two(mocker):
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED)
+    )
+
+    seb = StatementExecutionBackend(ws, "abc", max_records_per_batch=2)
+
+    seb.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False), Foo("ccc", True)], Foo)
+
+    ws.statement_execution.execute_statement.assert_has_calls(
+        [
+            mock.call(
+                warehouse_id="abc",
+                statement="CREATE TABLE IF NOT EXISTS a.b.c (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="INSERT INTO a.b.c (first, second) VALUES ('aaa', TRUE), ('bbb', FALSE)",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+            mock.call(
+                warehouse_id="abc",
+                statement="INSERT INTO a.b.c (first, second) VALUES ('ccc', TRUE)",
+                catalog=None,
+                schema=None,
+                disposition=None,
+                format=Format.JSON_ARRAY,
+                byte_limit=None,
+                wait_timeout=None,
+            ),
+        ]
+    )
+
+
+def test_runtime_backend_execute():
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        runtime_backend = RuntimeBackend()
+
+        runtime_backend.execute("CREATE TABLE foo")
+
+        spark.sql.assert_called_with("CREATE TABLE foo")
+
+
+def test_runtime_backend_fetch():
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        spark.sql().collect.return_value = [Row(id=1), Row(id=2), Row(id=3)]
+
+        runtime_backend = RuntimeBackend()
+
+        result = runtime_backend.fetch("SELECT id FROM range(3)")
+
+        assert [Row(id=1), Row(id=2), Row(id=3)] == result
+
+        spark.sql.assert_called_with("SELECT id FROM range(3)")
+
+
+def test_runtime_backend_save_table():
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        runtime_backend = RuntimeBackend()
+
+        runtime_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
+
+        spark.createDataFrame.assert_called_with(
+            [Foo(first="aaa", second=True), Foo(first="bbb", second=False)],
+            "first STRING NOT NULL, second BOOLEAN NOT NULL",
+        )
+        spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append")
+
+
+def test_runtime_backend_save_table_with_row_containing_none_with_nullable_class(mocker):
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        runtime_backend = RuntimeBackend()
+
+        runtime_backend.save_table("a.b.c", [Baz("aaa", "ccc"), Baz("bbb", None)], Baz)
+
+        spark.createDataFrame.assert_called_with(
+            [Baz(first="aaa", second="ccc"), Baz(first="bbb", second=None)],
+            "first STRING NOT NULL, second STRING",
+        )
+        spark.createDataFrame().write.saveAsTable.assert_called_with("a.b.c", mode="append")
+
+
+@dataclass
+class DummyClass:
+    key: str
+    value: str | None = None
+
+
+def test_save_table_with_not_null_constraint_violated():
+    rows = [DummyClass("1", "test"), DummyClass("2", None), DummyClass(None, "value")]
+
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+
+        runtime_backend = RuntimeBackend()
+
+        with pytest.raises(
+            Exception, match="Not null constraint violated for column key, row = {'key': None, 'value': 'value'}"
+        ):
+            runtime_backend.save_table("a.b.c", rows, DummyClass)
+
+
+@pytest.mark.parametrize(
+    "msg,err_t",
+    [
+        ("SCHEMA_NOT_FOUND foo schema does not exist", NotFound),
+        (".. TABLE_OR_VIEW_NOT_FOUND ..", NotFound),
+        (".. UNRESOLVED_COLUMN.WITH_SUGGESTION ..", BadRequest),
+        ("DELTA_TABLE_NOT_FOUND foo table does not exist", NotFound),
+        ("DELTA_MISSING_TRANSACTION_LOG foo table does not exist", DataLoss),
+        ("PARSE_SYNTAX_ERROR foo", BadRequest),
+        ("foo Operation not allowed", PermissionDenied),
+        ("foo error failure", Unknown),
+    ],
+)
+def test_runtime_backend_error_mapping_similar_to_statement_execution(msg, err_t):
+    with mock.patch.dict(os.environ, {"DATABRICKS_RUNTIME_VERSION": "14.0"}):
+        pyspark_sql_session = MagicMock()
+        sys.modules["pyspark.sql.session"] = pyspark_sql_session
+        spark = pyspark_sql_session.SparkSession.builder.getOrCreate()
+
+        spark.sql.side_effect = Exception(msg)
+
+        runtime_backend = RuntimeBackend()
+
+        with pytest.raises(err_t):
+            runtime_backend.execute("SELECT * from bar")
+
+        with pytest.raises(err_t):
+            list(runtime_backend.fetch("SELECT * from bar"))
+
+
+def test_mock_backend_fails_on_first():
+    mock_backend = MockBackend(fails_on_first={"CREATE": ".. DELTA_TABLE_NOT_FOUND .."})
+
+    with pytest.raises(NotFound):
+        mock_backend.execute("CREATE TABLE foo")
+
+
+def test_mock_backend_rows():
+    mock_backend = MockBackend(rows={r"SELECT id FROM range\(3\)": [Row(id=1), Row(id=2), Row(id=3)]})
+
+    result = list(mock_backend.fetch("SELECT id FROM range(3)"))
+
+    assert [Row(id=1), Row(id=2), Row(id=3)] == result
+
+
+def test_mock_backend_save_table():
+    mock_backend = MockBackend()
+
+    mock_backend.save_table("a.b.c", [Foo("aaa", True), Foo("bbb", False)], Foo)
+
+    assert mock_backend.rows_written_for("a.b.c", "append") == [
+        Row(first="aaa", second=True),
+        Row(first="bbb", second=False),
+    ]
diff --git a/tests/unit/test_core.py b/tests/unit/test_core.py
new file mode 100644
index 00000000..f4115c3d
--- /dev/null
+++ b/tests/unit/test_core.py
@@ -0,0 +1,481 @@
+import datetime
+from unittest.mock import create_autospec
+
+import pytest
+import requests
+from databricks.sdk import WorkspaceClient, errors
+from databricks.sdk.service.sql import (
+    ColumnInfo,
+    ColumnInfoTypeName,
+    EndpointInfo,
+    ExecuteStatementResponse,
+    ExternalLink,
+    Format,
+    GetStatementResponse,
+    ResultData,
+    ResultManifest,
+    ResultSchema,
+    ServiceError,
+    ServiceErrorCode,
+    State,
+    StatementState,
+    StatementStatus,
+    timedelta,
+)
+
+from databricks.labs.lsql.core import Row, StatementExecutionExt
+
+
+@pytest.mark.parametrize(
+    "row",
+    [
+        Row(foo="bar", enabled=True),
+        Row(["foo", "enabled"], ["bar", True]),
+    ],
+)
+def test_row_from_kwargs(row):
+    assert row.foo == "bar"
+    assert row["foo"] == "bar"
+    assert "foo" in row
+    assert len(row) == 2
+    assert list(row) == ["bar", True]
+    assert row.as_dict() == {"foo": "bar", "enabled": True}
+    foo, enabled = row
+    assert foo == "bar"
+    assert enabled is True
+    assert str(row) == "Row(foo='bar', enabled=True)"
+    with pytest.raises(AttributeError):
+        print(row.x)
+
+
+def test_row_factory():
+    factory = Row.factory(["a", "b"])
+    row = factory(1, 2)
+    a, b = row
+    assert a == 1
+    assert b == 2
+
+
+def test_row_factory_with_generator():
+    factory = Row.factory(["a", "b"])
+    row = factory(_ + 1 for _ in range(2))
+    a, b = row
+    assert a == 1
+    assert b == 2
+
+
+def test_selects_warehouse_from_config():
+    ws = create_autospec(WorkspaceClient)
+    ws.config.warehouse_id = "abc"
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws)
+    see.execute("SELECT 2+2")
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="abc",
+        statement="SELECT 2+2",
+        format=Format.JSON_ARRAY,
+        disposition=None,
+        byte_limit=None,
+        catalog=None,
+        schema=None,
+        wait_timeout=None,
+    )
+
+
+def test_selects_warehouse_from_existing_first_running():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.warehouses.list.return_value = [
+        EndpointInfo(id="abc", state=State.DELETING),
+        EndpointInfo(id="bcd", state=State.RUNNING),
+        EndpointInfo(id="cde", state=State.RUNNING),
+    ]
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+    )
+
+    see = StatementExecutionExt(ws)
+    see.execute("SELECT 2+2")
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="bcd",
+        statement="SELECT 2+2",
+        format=Format.JSON_ARRAY,
+        disposition=None,
+        byte_limit=None,
+        catalog=None,
+        schema=None,
+        wait_timeout=None,
+    )
+
+
+def test_selects_warehouse_from_existing_not_running():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.warehouses.list.return_value = [
+        EndpointInfo(id="efg", state=State.DELETING),
+        EndpointInfo(id="fgh", state=State.DELETED),
+        EndpointInfo(id="bcd", state=State.STOPPED),
+        EndpointInfo(id="cde", state=State.STARTING),
+    ]
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+    )
+
+    see = StatementExecutionExt(ws)
+    see.execute("SELECT 2+2")
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="bcd",
+        statement="SELECT 2+2",
+        format=Format.JSON_ARRAY,
+        disposition=None,
+        byte_limit=None,
+        catalog=None,
+        schema=None,
+        wait_timeout=None,
+    )
+
+
+def test_no_warehouse_given():
+    ws = create_autospec(WorkspaceClient)
+
+    see = StatementExecutionExt(ws)
+
+    with pytest.raises(ValueError):
+        see.execute("SELECT 2+2")
+
+
+def test_execute_poll_succeeds():
+    ws = create_autospec(WorkspaceClient)
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.PENDING),
+        statement_id="bcd",
+    )
+    ws.statement_execution.get_statement.return_value = GetStatementResponse(
+        manifest=ResultManifest(),
+        result=ResultData(byte_count=100500),
+        statement_id="bcd",
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+    )
+
+    see = StatementExecutionExt(ws)
+
+    response = see.execute("SELECT 2+2", warehouse_id="abc")
+
+    assert response.status.state == StatementState.SUCCEEDED
+    assert response.result.byte_count == 100500
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="abc",
+        statement="SELECT 2+2",
+        format=Format.JSON_ARRAY,
+        disposition=None,
+        byte_limit=None,
+        catalog=None,
+        schema=None,
+        wait_timeout=None,
+    )
+    ws.statement_execution.get_statement.assert_called_with("bcd")
+
+
+@pytest.mark.parametrize(
+    "status_error,platform_error_type",
+    [
+        (None, errors.Unknown),
+        (ServiceError(), errors.Unknown),
+        (ServiceError(message="..."), errors.Unknown),
+        (ServiceError(error_code=ServiceErrorCode.RESOURCE_EXHAUSTED, message="..."), errors.ResourceExhausted),
+        (ServiceError(message="... SCHEMA_NOT_FOUND ..."), errors.NotFound),
+        (ServiceError(message="... TABLE_OR_VIEW_NOT_FOUND ..."), errors.NotFound),
+        (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound),
+        (ServiceError(message="... DELTA_TABLE_NOT_FOUND ..."), errors.NotFound),
+        (ServiceError(message="... DELTA_MISSING_TRANSACTION_LOG ..."), errors.DataLoss),
+    ],
+)
+def test_execute_fails(status_error, platform_error_type):
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.FAILED, error=status_error),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    with pytest.raises(platform_error_type):
+        see.execute("SELECT 2+2")
+
+
+def test_execute_poll_waits():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.PENDING),
+        statement_id="bcd",
+    )
+
+    ws.statement_execution.get_statement.side_effect = [
+        GetStatementResponse(status=StatementStatus(state=StatementState.RUNNING), statement_id="bcd"),
+        GetStatementResponse(
+            manifest=ResultManifest(),
+            result=ResultData(byte_count=100500),
+            statement_id="bcd",
+            status=StatementStatus(state=StatementState.SUCCEEDED),
+        ),
+    ]
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    response = see.execute("SELECT 2+2")
+
+    assert response.status.state == StatementState.SUCCEEDED
+    assert response.result.byte_count == 100500
+
+
+def test_execute_poll_timeouts_on_client():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.PENDING),
+        statement_id="bcd",
+    )
+
+    ws.statement_execution.get_statement.return_value = GetStatementResponse(
+        status=StatementStatus(state=StatementState.RUNNING),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+    with pytest.raises(TimeoutError):
+        see.execute("SELECT 2+2", timeout=timedelta(seconds=1))
+
+    ws.statement_execution.cancel_execution.assert_called_with("bcd")
+
+
+def test_fetch_all_no_chunks():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(
+            schema=ResultSchema(
+                columns=[
+                    ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT),
+                    ColumnInfo(name="since", type_name=ColumnInfoTypeName.DATE),
+                    ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP),
+                ]
+            )
+        ),
+        result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]),
+        statement_id="bcd",
+    )
+
+    http_session = create_autospec(requests.Session)
+    http_session.get("https://singed-url").json.return_value = [
+        ["1", "2023-09-01", "2023-09-01T13:21:53Z"],
+        ["2", "2023-09-01", "2023-09-01T13:21:53Z"],
+    ]
+
+    see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session)
+
+    rows = list(see.fetch_all("SELECT id, CAST(NOW() AS DATE) AS since, NOW() AS now FROM range(2)"))
+
+    assert len(rows) == 2
+    assert rows[0].id == 1
+    assert isinstance(rows[0].since, datetime.date)
+    assert isinstance(rows[0].now, datetime.datetime)
+
+    http_session.get.assert_called_with("https://singed-url")
+
+
+def test_fetch_all_no_chunks_no_converter():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(
+            schema=ResultSchema(
+                columns=[
+                    ColumnInfo(name="id", type_name=ColumnInfoTypeName.INTERVAL),
+                ]
+            )
+        ),
+        result=ResultData(external_links=[ExternalLink(external_link="https://singed-url")]),
+        statement_id="bcd",
+    )
+
+    http_session = create_autospec(requests.Session)
+    http_session.get("https://singed-url").json.return_value = [["1"], ["2"]]
+
+    see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session)
+
+    with pytest.raises(ValueError, match="id has no INTERVAL converter"):
+        list(see.fetch_all("SELECT id FROM range(2)"))
+
+
+def test_fetch_all_two_chunks():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(
+            schema=ResultSchema(
+                columns=[
+                    ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT),
+                    ColumnInfo(name="now", type_name=ColumnInfoTypeName.TIMESTAMP),
+                ]
+            )
+        ),
+        result=ResultData(external_links=[ExternalLink(external_link="https://first", next_chunk_index=1)]),
+        statement_id="bcd",
+    )
+
+    ws.statement_execution.get_statement_result_chunk_n.return_value = ResultData(
+        external_links=[ExternalLink(external_link="https://second")]
+    )
+
+    http_session = create_autospec(requests.Session)
+    http_session.get(...).json.side_effect = [
+        # https://first
+        [["1", "2023-09-01T13:21:53Z"], ["2", "2023-09-01T13:21:53Z"]],
+        # https://second
+        [["3", "2023-09-01T13:21:53Z"], ["4", "2023-09-01T13:21:53Z"]],
+    ]
+
+    see = StatementExecutionExt(ws, warehouse_id="abc", http_session_factory=lambda: http_session)
+
+    rows = list(see.fetch_all("SELECT id, NOW() AS now FROM range(4)"))
+    assert len(rows) == 4
+    assert [_.id for _ in rows] == [1, 2, 3, 4]
+
+    ws.statement_execution.get_statement_result_chunk_n.assert_called_with("bcd", 1)
+
+
+def test_fetch_one():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        result=ResultData(data_array=[["4"]]),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    row = see.fetch_one("SELECT 2+2 AS id")
+
+    assert row.id == 4
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="abc",
+        statement="SELECT 2 + 2 AS id LIMIT 1",
+        format=Format.JSON_ARRAY,
+        disposition=None,
+        byte_limit=None,
+        catalog=None,
+        schema=None,
+        wait_timeout=None,
+    )
+
+
+def test_fetch_one_none():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    row = see.fetch_one("SELECT 2+2 AS id")
+
+    assert row is None
+
+
+def test_fetch_one_disable_magic():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        result=ResultData(data_array=[["4"], ["5"], ["6"]]),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    row = see.fetch_one("SELECT 2+2 AS id", disable_magic=True)
+
+    assert row.id == 4
+
+    ws.statement_execution.execute_statement.assert_called_with(
+        warehouse_id="abc",
+        statement="SELECT 2+2 AS id",
+        format=Format.JSON_ARRAY,
+        disposition=None,
+        byte_limit=None,
+        catalog=None,
+        schema=None,
+        wait_timeout=None,
+    )
+
+
+def test_fetch_value():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        result=ResultData(data_array=[["4"]]),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    value = see.fetch_value("SELECT 2+2 AS id")
+
+    assert value == 4
+
+
+def test_fetch_value_none():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    value = see.fetch_value("SELECT 2+2 AS id")
+
+    assert value is None
+
+
+def test_callable_returns_iterator():
+    ws = create_autospec(WorkspaceClient)
+
+    ws.statement_execution.execute_statement.return_value = ExecuteStatementResponse(
+        status=StatementStatus(state=StatementState.SUCCEEDED),
+        manifest=ResultManifest(schema=ResultSchema(columns=[ColumnInfo(name="id", type_name=ColumnInfoTypeName.INT)])),
+        result=ResultData(data_array=[["4"], ["5"], ["6"]]),
+        statement_id="bcd",
+    )
+
+    see = StatementExecutionExt(ws, warehouse_id="abc")
+
+    rows = list(see("SELECT 2+2 AS id"))
+
+    assert len(rows) == 3
+    assert rows == [Row(id=4), Row(id=5), Row(id=6)]
diff --git a/tests/unit/test_deployment.py b/tests/unit/test_deployment.py
new file mode 100644
index 00000000..cd6fcdbb
--- /dev/null
+++ b/tests/unit/test_deployment.py
@@ -0,0 +1,45 @@
+from dataclasses import dataclass
+
+from databricks.labs.lsql.backends import MockBackend
+from databricks.labs.lsql.deployment import SchemaDeployer
+
+from . import views
+
+
+def test_deploys_view():
+    mock_backend = MockBackend()
+    deployment = SchemaDeployer(
+        sql_backend=mock_backend,
+        inventory_schema="inventory",
+        mod=views,
+    )
+
+    deployment.deploy_view("some", "some.sql")
+
+    assert mock_backend.queries == [
+        "CREATE OR REPLACE VIEW hive_metastore.inventory.some AS SELECT id, name FROM hive_metastore.inventory.something"
+    ]
+
+
+@dataclass
+class Foo:
+    first: str
+    second: bool
+
+
+def test_deploys_dataclass():
+    mock_backend = MockBackend()
+    deployment = SchemaDeployer(
+        sql_backend=mock_backend,
+        inventory_schema="inventory",
+        mod=views,
+    )
+    deployment.deploy_schema()
+    deployment.deploy_table("foo", Foo)
+    deployment.delete_schema()
+
+    assert mock_backend.queries == [
+        "CREATE SCHEMA IF NOT EXISTS hive_metastore.inventory",
+        "CREATE TABLE IF NOT EXISTS hive_metastore.inventory.foo (first STRING NOT NULL, second BOOLEAN NOT NULL) USING DELTA",
+        "DROP SCHEMA IF EXISTS hive_metastore.inventory CASCADE",
+    ]
diff --git a/tests/unit/views/__init__.py b/tests/unit/views/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/views/some.sql b/tests/unit/views/some.sql
new file mode 100644
index 00000000..166a1660
--- /dev/null
+++ b/tests/unit/views/some.sql
@@ -0,0 +1 @@
+SELECT id, name FROM $inventory.something
\ No newline at end of file