5 changes: 1 addition & 4 deletions src/databricks/sql/thrift_backend.py
@@ -613,10 +613,7 @@ def _create_arrow_table(self, t_row_set, lz4_compressed, schema_bytes, description):
num_rows,
) = convert_column_based_set_to_arrow_table(t_row_set.columns, description)
elif t_row_set.arrowBatches is not None:
(
arrow_table,
num_rows,
) = convert_arrow_based_set_to_arrow_table(
(arrow_table, num_rows,) = convert_arrow_based_set_to_arrow_table(
Contributor Author: This is just a formatting change unrelated to the bulk of the PR.

t_row_set.arrowBatches, lz4_compressed, schema_bytes
)
else:
1 change: 1 addition & 0 deletions src/databricks/sql/utils.py
@@ -529,6 +529,7 @@ def infer_types(params: list[DbSqlParameter]):
int: DbSqlType.INTEGER,
float: DbSqlType.FLOAT,
datetime.datetime: DbSqlType.TIMESTAMP,
datetime.date: DbSqlType.DATE,
bool: DbSqlType.BOOLEAN,
}
newParams = copy.deepcopy(params)
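For context, a minimal sketch of what the new datetime.date entry enables, assuming infer_types returns a copy of the parameter list with each parameter's type filled in from the mapping above (that return behaviour is not visible in this hunk):

import datetime

from databricks.sql.utils import DbSqlParameter, DbSqlType, infer_types

# A parameter created without an explicit type...
params = [DbSqlParameter(name="d", value=datetime.date(2023, 9, 6), type=None)]

# ...should now be inferred as DATE via the datetime.date entry added above.
inferred = infer_types(params)
assert inferred[0].type == DbSqlType.DATE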
4 changes: 3 additions & 1 deletion src/databricks/sqlalchemy/dialect/requirements.py
@@ -24,9 +24,11 @@ def __some_example_requirement(self):
import sqlalchemy.testing.exclusions

import logging

Contributor Author: This is a formatting change unrelated to the rest of this PR.

logger = logging.getLogger(__name__)

logger.warning("requirements.py is not currently employed by Databricks dialect")


class Requirements(sqlalchemy.testing.requirements.SuiteRequirements):
pass
pass
144 changes: 144 additions & 0 deletions tests/e2e/common/parameterized_query_tests.py
@@ -0,0 +1,144 @@
import datetime
from decimal import Decimal
from typing import Dict, List, Tuple, Union

import pytz

from databricks.sql.client import Connection
from databricks.sql.utils import DbSqlParameter, DbSqlType


class PySQLParameterizedQueryTestSuiteMixin:
"""Namespace for tests of server-side parameterized queries"""

QUERY = "SELECT :p AS col"

def _get_one_result(self, query: str, parameters: Union[Dict, List[Dict]]) -> Tuple:
with self.connection() as conn:
with conn.cursor() as cursor:
cursor.execute(query, parameters=parameters)
return cursor.fetchone()

def _quantize(self, input: Union[float, int], place_value=2) -> Decimal:

return Decimal(str(input)).quantize(Decimal("0." + "0" * place_value))

def test_primitive_inferred_bool(self):

params = {"p": True}
result = self._get_one_result(self.QUERY, params)
assert result.col == True

def test_primitive_inferred_integer(self):

params = {"p": 1}
result = self._get_one_result(self.QUERY, params)
assert result.col == 1

def test_primitive_inferred_double(self):

params = {"p": 3.14}
result = self._get_one_result(self.QUERY, params)
assert self._quantize(result.col) == self._quantize(3.14)

def test_primitive_inferred_date(self):

# DATE in Databricks is mapped into a datetime.date object in Python
date_value = datetime.date(2023, 9, 6)
params = {"p": date_value}
result = self._get_one_result(self.QUERY, params)
assert result.col == date_value

def test_primitive_inferred_timestamp(self):

# TIMESTAMP in Databricks is mapped into a datetime.datetime object in Python
date_value = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC)
params = {"p": date_value}
result = self._get_one_result(self.QUERY, params)
assert result.col == date_value

def test_primitive_inferred_string(self):

params = {"p": "Hello"}
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_dbsqlparam_inferred_bool(self):

params = [DbSqlParameter(name="p", value=True, type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == True

def test_dbsqlparam_inferred_integer(self):

params = [DbSqlParameter(name="p", value=1, type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == 1

def test_dbsqlparam_inferred_double(self):

params = [DbSqlParameter(name="p", value=3.14, type=None)]
result = self._get_one_result(self.QUERY, params)
assert self._quantize(result.col) == self._quantize(3.14)

def test_dbsqlparam_inferred_date(self):

# DATE in Databricks is mapped into a datetime.date object in Python
date_value = datetime.date(2023, 9, 6)
params = [DbSqlParameter(name="p", value=date_value, type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == date_value

def test_dbsqlparam_inferred_timestamp(self):

# TIMESTAMP in Databricks is mapped into a datetime.datetime object in Python
date_value = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC)
params = [DbSqlParameter(name="p", value=date_value, type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == date_value

def test_dbsqlparam_inferred_string(self):

params = [DbSqlParameter(name="p", value="Hello", type=None)]
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"

def test_dbsqlparam_explicit_bool(self):

params = [DbSqlParameter(name="p", value=True, type=DbSqlType.BOOLEAN)]
result = self._get_one_result(self.QUERY, params)
assert result.col == True

def test_dbsqlparam_explicit_integer(self):

params = [DbSqlParameter(name="p", value=1, type=DbSqlType.INTEGER)]
result = self._get_one_result(self.QUERY, params)
assert result.col == 1

def test_dbsqlparam_explicit_double(self):

params = [DbSqlParameter(name="p", value=3.14, type=DbSqlType.FLOAT)]
result = self._get_one_result(self.QUERY, params)
assert self._quantize(result.col) == self._quantize(3.14)

def test_dbsqlparam_explicit_date(self):

# DATE in Databricks is mapped into a datetime.date object in Python
date_value = datetime.date(2023, 9, 6)
params = [DbSqlParameter(name="p", value=date_value, type=DbSqlType.DATE)]
result = self._get_one_result(self.QUERY, params)
assert result.col == date_value

def test_dbsqlparam_explicit_timestamp(self):

# TIMESTAMP in Databricks is mapped into a datetime.datetime object in Python
date_value = datetime.datetime(2023, 9, 6, 3, 14, 27, 843, tzinfo=pytz.UTC)
params = [DbSqlParameter(name="p", value=date_value, type=DbSqlType.TIMESTAMP)]
result = self._get_one_result(self.QUERY, params)
assert result.col == date_value

def test_dbsqlparam_explicit_string(self):

params = [DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)]
result = self._get_one_result(self.QUERY, params)
assert result.col == "Hello"
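Putting the new test mixin in end-user terms, here is a rough usage sketch of server-side named parameters with this connector; the connection arguments are placeholders, and the inferred-dict and explicit DbSqlParameter forms mirror the tests above:

from databricks import sql
from databricks.sql.utils import DbSqlParameter, DbSqlType

# Placeholder credentials -- substitute real workspace values.
with sql.connect(server_hostname="...", http_path="...", access_token="...") as conn:
    with conn.cursor() as cursor:
        # Inferred typing: a plain dict of parameter name -> Python value.
        cursor.execute("SELECT :p AS col", parameters={"p": 42})
        print(cursor.fetchone().col)  # 42

        # Explicit typing: pin the server-side type with DbSqlParameter.
        cursor.execute(
            "SELECT :p AS col",
            parameters=[DbSqlParameter(name="p", value="Hello", type=DbSqlType.STRING)],
        )
        print(cursor.fetchone().col)  # Hello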
3 changes: 2 additions & 1 deletion tests/e2e/test_driver.py
@@ -28,6 +28,7 @@
from tests.e2e.common.retry_test_mixins import Client429ResponseMixin, Client503ResponseMixin
from tests.e2e.common.staging_ingestion_tests import PySQLStagingIngestionTestSuiteMixin
from tests.e2e.common.retry_test_mixins import PySQLRetryTestsMixin
from tests.e2e.common.parameterized_query_tests import PySQLParameterizedQueryTestSuiteMixin

log = logging.getLogger(__name__)

@@ -142,7 +143,7 @@ def test_cloud_fetch(self):
# Exclude Retry tests because they require specific setups, and LargeQueries too slow for core
# tests
class PySQLCoreTestSuite(SmokeTestMixin, CoreTestMixin, DecimalTestsMixin, TimestampTestsMixin,
PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin):
PySQLTestCase, PySQLStagingIngestionTestSuiteMixin, PySQLRetryTestsMixin, PySQLParameterizedQueryTestSuiteMixin):
validate_row_value_type = True
validate_result = True
