145 changes: 99 additions & 46 deletions ibis/bigquery/client.py
@@ -1,11 +1,13 @@
"""BigQuery ibis client implementation."""

import datetime

from collections import OrderedDict
from pkg_resources import parse_version
from typing import Optional, Tuple

import regex as re

import six

import pandas as pd

from google.api_core.exceptions import NotFound
@@ -21,7 +23,6 @@
import ibis.expr.datatypes as dt
import ibis.expr.lineage as lin

from ibis.compat import parse_version
from ibis.client import Database, Query, SQLClient
from ibis.bigquery import compiler as comp
from ibis.bigquery.datatypes import ibis_type_to_bigquery_type
@@ -53,6 +54,7 @@

@dt.dtype.register(bq.schema.SchemaField)
def bigquery_field_to_ibis_dtype(field):
"""Convert BigQuery `field` to an ibis type."""
typ = field.field_type
if typ == 'RECORD':
fields = field.fields
@@ -70,6 +72,7 @@ def bigquery_field_to_ibis_dtype(field):

@sch.infer.register(bq.table.Table)
def bigquery_schema(table):
"""Infer the schema of a BigQuery `table` object."""
fields = OrderedDict((el.name, dt.dtype(el)) for el in table.schema)
partition_info = table._properties.get('timePartitioning', None)

@@ -82,39 +85,58 @@ def bigquery_schema(table):
return sch.schema(fields)


class BigQueryCursor(object):
"""Cursor to allow the BigQuery client to reuse machinery in ibis/client.py
class BigQueryCursor:
"""BigQuery cursor.
This allows the BigQuery client to reuse machinery in
:file:`ibis/client.py`.
"""

def __init__(self, query):
"""Construct a BigQueryCursor with query `query`."""
self.query = query

def fetchall(self):
"""Fetch all rows."""
result = self.query.result()
return [row.values() for row in result]

@property
def columns(self):
"""Return the columns of the result set."""
result = self.query.result()
return [field.name for field in result.schema]

@property
def description(self):
"""Get the fields of the result set's schema."""
result = self.query.result()
return [field for field in result.schema]

def __enter__(self):
# For compatibility when constructed from Query.execute()
"""No-op for compatibility.
See Also
--------
ibis.client.Query.execute
"""
return self

def __exit__(self, exc_type, exc_value, traceback):
pass
"""No-op for compatibility.
See Also
--------
ibis.client.Query.execute
"""


def _find_scalar_parameter(expr):
""":func:`~ibis.expr.lineage.traverse` function to find all
:class:`~ibis.expr.types.ScalarParameter` instances and yield the operation
and the parent expresssion's resolved name.
"""Find all :class:`~ibis.expr.types.ScalarParameter` instances.
Parameters
----------
@@ -123,6 +145,8 @@ def _find_scalar_parameter(expr):
Returns
-------
Tuple[bool, object]
The operation and the parent expression's resolved name.
"""
op = expr.op()

@@ -134,17 +158,18 @@


class BigQueryQuery(Query):

def __init__(self, client, ddl, query_parameters=None):
super(BigQueryQuery, self).__init__(client, ddl)
super().__init__(client, ddl)

# self.expr comes from the parent class
query_parameter_names = dict(
lin.traverse(_find_scalar_parameter, self.expr))
lin.traverse(_find_scalar_parameter, self.expr)
)
self.query_parameters = [
bigquery_param(
param.to_expr().name(query_parameter_names[param]), value
) for param, value in (query_parameters or {}).items()
)
for param, value in (query_parameters or {}).items()
]

def _fetch(self, cursor):
@@ -157,15 +182,15 @@ def execute(self):
with self.client._execute(
self.compiled_sql,
results=True,
query_parameters=self.query_parameters
query_parameters=self.query_parameters,
) as cur:
result = self._fetch(cur)

return self._wrap_result(result)


class BigQueryDatabase(Database):
pass
"""A BigQuery dataset."""


bigquery_param = Dispatcher('bigquery_param')
@@ -199,29 +224,30 @@ def bq_param_array(param, value):
else:
query_value = value
result = bq.ArrayQueryParameter(
param.get_name(), bigquery_type, query_value)
param.get_name(), bigquery_type, query_value
)
return result


@bigquery_param.register(
ir.TimestampScalar,
six.string_types + (datetime.datetime, datetime.date)
ir.TimestampScalar, (str, datetime.datetime, datetime.date)
)
def bq_param_timestamp(param, value):
assert isinstance(param.type(), dt.Timestamp), str(param.type())

# TODO(phillipc): Not sure if this is the correct way to do this.
timestamp_value = pd.Timestamp(value, tz='UTC').to_pydatetime()
return bq.ScalarQueryParameter(
param.get_name(), 'TIMESTAMP', timestamp_value)
param.get_name(), 'TIMESTAMP', timestamp_value
)


@bigquery_param.register(ir.StringScalar, six.string_types)
@bigquery_param.register(ir.StringScalar, str)
def bq_param_string(param, value):
return bq.ScalarQueryParameter(param.get_name(), 'STRING', value)


@bigquery_param.register(ir.IntegerScalar, six.integer_types)
@bigquery_param.register(ir.IntegerScalar, int)
def bq_param_integer(param, value):
return bq.ScalarQueryParameter(param.get_name(), 'INT64', value)

@@ -236,7 +262,7 @@ def bq_param_boolean(param, value):
return bq.ScalarQueryParameter(param.get_name(), 'BOOL', value)


@bigquery_param.register(ir.DateScalar, six.string_types)
@bigquery_param.register(ir.DateScalar, str)
def bq_param_date_string(param, value):
return bigquery_param(param, pd.Timestamp(value).to_pydatetime().date())

@@ -279,21 +305,21 @@ def rename_partitioned_column(table_expr, bq_table):
return table_expr.relabel({NATIVE_PARTITION_COL: col})


def parse_project_and_dataset(project, dataset):
"""Figure out the project id under which queries will run versus the
project of where the data live as well as what dataset to use.
def parse_project_and_dataset(
project: str, dataset: Optional[str] = None
) -> Tuple[str, str, Optional[str]]:
"""Compute the billing project, data project, and dataset if available.
This function figures out the project under which queries will run versus
the project where the data live, as well as which dataset to use.
Parameters
----------
project : str
A project name
dataset : str
dataset : Optional[str]
A ``<project>.<dataset>`` string or just a dataset name
Returns
-------
data_project, billing_project, dataset : str, str, str
Examples
--------
>>> data_project, billing_project, dataset = parse_project_and_dataset(
@@ -316,40 +342,57 @@ def parse_project_and_dataset(project, dataset):
'ibis-gbq'
>>> dataset
'my_dataset'
>>> data_project, billing_project, dataset = parse_project_and_dataset(
... 'ibis-gbq'
... )
>>> data_project
'ibis-gbq'
>>> print(dataset)
None
"""
try:
data_project, dataset = dataset.split('.')
except ValueError:
except (ValueError, AttributeError):
billing_project = data_project = project
else:
billing_project = project

return data_project, billing_project, dataset


class BigQueryClient(SQLClient):
"""An ibis BigQuery client implementation."""

query_class = BigQueryQuery
database_class = BigQueryDatabase
table_class = BigQueryTable
dialect = comp.BigQueryDialect

def __init__(self, project_id, dataset_id, credentials=None):
"""
def __init__(self, project_id, dataset_id=None, credentials=None):
"""Construct a BigQueryClient.
Parameters
----------
project_id : str
A project name
dataset_id : str
dataset_id : Optional[str]
A ``<project_id>.<dataset_id>`` string or just a dataset name
credentials : google.auth.credentials.Credentials, optional
credentials : google.auth.credentials.Credentials
"""
(self.data_project,
self.billing_project,
self.dataset) = parse_project_and_dataset(project_id, dataset_id)
self.client = bq.Client(project=self.data_project,
credentials=credentials)
(
self.data_project,
self.billing_project,
self.dataset,
) = parse_project_and_dataset(project_id, dataset_id)
self.client = bq.Client(
project=self.data_project, credentials=credentials
)

def _parse_project_and_dataset(self, dataset):
if not dataset and not self.dataset:
raise ValueError("Unable to determine BigQuery dataset.")
project, _, dataset = parse_project_and_dataset(
self.billing_project,
dataset or '{}.{}'.format(self.data_project, self.dataset),
@@ -365,7 +408,7 @@ def dataset_id(self):
return self.dataset

def table(self, name, database=None):
t = super(BigQueryClient, self).table(name, database=database)
t = super().table(name, database=database)
project, dataset, name = t.op().name.split('.')
dataset_ref = self.client.dataset(dataset, project=project)
table_ref = dataset_ref.table(name)
@@ -377,16 +420,18 @@ def _build_ast(self, expr, context):
return result

def _execute_query(self, dml):
query = self.query_class(self, dml,
query_parameters=dml.context.params)
query = self.query_class(
self, dml, query_parameters=dml.context.params
)
return query.execute()

def _fully_qualified_name(self, name, database):
project, dataset = self._parse_project_and_dataset(database)
return '{}.{}.{}'.format(project, dataset, name)
return "{}.{}.{}".format(project, dataset, name)

def _get_table_schema(self, qualified_name):
dataset, table = qualified_name.rsplit('.', 1)
assert dataset is not None, "dataset is None"
return self.get_schema(table, database=dataset)

def _get_schema_using_query(self, limited_query):
@@ -415,6 +460,12 @@ def _execute(self, stmt, results=True, query_parameters=None):
return BigQueryCursor(query)

def database(self, name=None):
if name is None and self.dataset is None:
raise ValueError(
"Unable to determine BigQuery dataset. Call "
"client.database('my_dataset') or set_database('my_dataset') "
"to assign your client a dataset."
)
return self.database_class(name or self.dataset, self)

@property
Expand All @@ -441,7 +492,8 @@ def list_databases(self, like=None):
]
if like:
results = [
dataset_name for dataset_name in results
dataset_name
for dataset_name in results
if re.match(like, dataset_name) is not None
]
return results
@@ -466,7 +518,8 @@ def list_tables(self, like=None, database=None):
]
if like:
result = [
table_name for table_name in result
table_name
for table_name in result
if re.match(like, table_name) is not None
]
return result
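 
Taken together, the client changes above make the dataset optional at connection time. A minimal usage sketch, assuming ibis.bigquery.connect forwards its arguments to BigQueryClient and using placeholder project and dataset names:
 
import ibis

client = ibis.bigquery.connect(project_id='ibis-gbq')  # no dataset yet
client.list_databases()        # listing datasets still works
client.database('my_dataset')  # pick a dataset explicitly
client.database()              # now raises ValueError with the hint shown above
 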
189 changes: 92 additions & 97 deletions ibis/bigquery/compiler.py
@@ -4,7 +4,6 @@

import regex as re

import six

import toolz

@@ -23,7 +22,11 @@
import ibis.expr.lineage as lin

from ibis.impala.compiler import (
ImpalaSelect, unary, fixed_arity, ImpalaTableSetFormatter, _reduction
ImpalaSelect,
unary,
fixed_arity,
ImpalaTableSetFormatter,
_reduction,
)
from ibis.impala import compiler as impala_compiler

@@ -35,14 +38,12 @@ class BigQueryUDFNode(ops.ValueOp):


class BigQuerySelectBuilder(comp.SelectBuilder):

@property
def _select_class(self):
return BigQuerySelect


class BigQueryUDFDefinition(comp.DDL):

def __init__(self, expr, context):
self.expr = expr
self.context = context
@@ -73,13 +74,14 @@ class BigQueryQueryBuilder(comp.QueryBuilder):
def generate_setup_queries(self):
queries = map(
partial(BigQueryUDFDefinition, context=self.context),
lin.traverse(find_bigquery_udf, self.expr)
lin.traverse(find_bigquery_udf, self.expr),
)

# UDFs are uniquely identified by the name of the Node subclass we
# generate.
return list(
toolz.unique(queries, key=lambda x: type(x.expr.op()).__name__))
toolz.unique(queries, key=lambda x: type(x.expr.op()).__name__)
)


def build_ast(expr, context):
@@ -94,7 +96,6 @@ def to_sql(expr, context):


class BigQueryContext(comp.QueryContext):

def _to_sql(self, expr, ctx):
return to_sql(expr, context=ctx)

@@ -104,18 +105,19 @@ def extract_field_formatter(translator, expr):
op = expr.op()
arg = translator.translate(op.args[0])
return 'EXTRACT({} from {})'.format(sql_attr, arg)

return extract_field_formatter


bigquery_cast = Dispatcher('bigquery_cast')


@bigquery_cast.register(six.string_types, dt.Timestamp, dt.Integer)
@bigquery_cast.register(str, dt.Timestamp, dt.Integer)
def bigquery_cast_timestamp_to_integer(compiled_arg, from_, to):
return 'UNIX_MICROS({})'.format(compiled_arg)


@bigquery_cast.register(six.string_types, dt.DataType, dt.DataType)
@bigquery_cast.register(str, dt.DataType, dt.DataType)
def bigquery_cast_generate(compiled_arg, from_, to):
sql_type = ibis_type_to_bigquery_type(to)
return 'CAST({} AS {})'.format(compiled_arg, sql_type)
@@ -156,8 +158,7 @@ def _string_find(translator, expr):
raise NotImplementedError('end not implemented for string find')

return 'STRPOS({}, {}) - 1'.format(
translator.translate(haystack),
translator.translate(needle)
translator.translate(haystack), translator.translate(needle)
)


@@ -179,9 +180,7 @@ def _regex_extract(translator, expr):
arg, pattern, index = expr.op().args
regex = _translate_pattern(translator, pattern)
result = 'REGEXP_EXTRACT_ALL({}, {})[SAFE_OFFSET({})]'.format(
translator.translate(arg),
regex,
translator.translate(index)
translator.translate(arg), regex, translator.translate(index)
)
return result

@@ -190,9 +189,7 @@ def _regex_replace(translator, expr):
arg, pattern, replacement = expr.op().args
regex = _translate_pattern(translator, pattern)
result = 'REGEXP_REPLACE({}, {}, {})'.format(
translator.translate(arg),
regex,
translator.translate(replacement),
translator.translate(arg), regex, translator.translate(replacement)
)
return result

@@ -206,8 +203,7 @@ def _string_concat(translator, expr):
def _string_join(translator, expr):
sep, args = expr.op().args
return 'ARRAY_TO_STRING([{}], {})'.format(
', '.join(map(translator.translate, args)),
translator.translate(sep)
', '.join(map(translator.translate, args)), translator.translate(sep)
)


@@ -221,8 +217,7 @@ def _string_ascii(translator, expr):
def _string_right(translator, expr):
arg, nchars = map(translator.translate, expr.op().args)
return 'SUBSTR({arg}, -LEAST(LENGTH({arg}), {nchars}))'.format(
arg=arg,
nchars=nchars,
arg=arg, nchars=nchars
)


@@ -278,7 +273,7 @@ def _arbitrary(translator, expr):
if where is not None:
arg = where.ifelse(arg, ibis.NA)

if how != 'first':
if how not in (None, 'first'):
raise com.UnsupportedOperationError(
'{!r} value not supported for arbitrary in BigQuery'.format(how)
)
@@ -317,6 +312,7 @@ def truncator(translator, expr):
'{!r}'.format(arg.type(), unit)
)
return '{}_TRUNC({}, {})'.format(kind, trans_arg, valid_unit)

return truncator


@@ -347,78 +343,71 @@ def _formatter(translator, expr):


_operation_registry = impala_compiler._operation_registry.copy()
_operation_registry.update({
ops.ExtractYear: _extract_field('year'),
ops.ExtractMonth: _extract_field('month'),
ops.ExtractDay: _extract_field('day'),
ops.ExtractHour: _extract_field('hour'),
ops.ExtractMinute: _extract_field('minute'),
ops.ExtractSecond: _extract_field('second'),
ops.ExtractMillisecond: _extract_field('millisecond'),

ops.StringReplace: fixed_arity('REPLACE', 3),
ops.StringSplit: fixed_arity('SPLIT', 2),
ops.StringConcat: _string_concat,
ops.StringJoin: _string_join,
ops.StringAscii: _string_ascii,
ops.StringFind: _string_find,
ops.StrRight: _string_right,
ops.Repeat: fixed_arity('REPEAT', 2),
ops.RegexSearch: _regex_search,
ops.RegexExtract: _regex_extract,
ops.RegexReplace: _regex_replace,

ops.GroupConcat: fixed_arity('STRING_AGG', 2),

ops.IfNull: fixed_arity('IFNULL', 2),
ops.Cast: _cast,

ops.StructField: _struct_field,

ops.ArrayCollect: unary('ARRAY_AGG'),
ops.ArrayConcat: _array_concat,
ops.ArrayIndex: _array_index,
ops.ArrayLength: unary('ARRAY_LENGTH'),

ops.HLLCardinality: _reduction('APPROX_COUNT_DISTINCT'),
ops.Log: _log,
ops.Sign: unary('SIGN'),
ops.Modulus: fixed_arity('MOD', 2),

ops.Date: unary('DATE'),

# BigQuery doesn't have these operations built in.
# ops.ArrayRepeat: _array_repeat,
# ops.ArraySlice: _array_slice,
ops.Literal: _literal,
ops.Arbitrary: _arbitrary,

ops.TimestampTruncate: _truncate('TIMESTAMP', _timestamp_units),
ops.DateTruncate: _truncate('DATE', _date_units),
ops.TimeTruncate: _truncate('TIME', _timestamp_units),

ops.Time: unary('TIME'),

ops.TimestampAdd: _timestamp_op(
'TIMESTAMP_ADD', {'h', 'm', 's', 'ms', 'us'}),
ops.TimestampSub: _timestamp_op(
'TIMESTAMP_DIFF', {'h', 'm', 's', 'ms', 'us'}),

ops.DateAdd: _timestamp_op('DATE_ADD', {'D', 'W', 'M', 'Q', 'Y'}),
ops.DateSub: _timestamp_op('DATE_SUB', {'D', 'W', 'M', 'Q', 'Y'}),
ops.TimestampNow: fixed_arity('CURRENT_TIMESTAMP', 0),
})
_operation_registry.update(
{
ops.ExtractYear: _extract_field('year'),
ops.ExtractMonth: _extract_field('month'),
ops.ExtractDay: _extract_field('day'),
ops.ExtractHour: _extract_field('hour'),
ops.ExtractMinute: _extract_field('minute'),
ops.ExtractSecond: _extract_field('second'),
ops.ExtractMillisecond: _extract_field('millisecond'),
ops.StringReplace: fixed_arity('REPLACE', 3),
ops.StringSplit: fixed_arity('SPLIT', 2),
ops.StringConcat: _string_concat,
ops.StringJoin: _string_join,
ops.StringAscii: _string_ascii,
ops.StringFind: _string_find,
ops.StrRight: _string_right,
ops.Repeat: fixed_arity('REPEAT', 2),
ops.RegexSearch: _regex_search,
ops.RegexExtract: _regex_extract,
ops.RegexReplace: _regex_replace,
ops.GroupConcat: fixed_arity('STRING_AGG', 2),
ops.IfNull: fixed_arity('IFNULL', 2),
ops.Cast: _cast,
ops.StructField: _struct_field,
ops.ArrayCollect: unary('ARRAY_AGG'),
ops.ArrayConcat: _array_concat,
ops.ArrayIndex: _array_index,
ops.ArrayLength: unary('ARRAY_LENGTH'),
ops.HLLCardinality: _reduction('APPROX_COUNT_DISTINCT'),
ops.Log: _log,
ops.Sign: unary('SIGN'),
ops.Modulus: fixed_arity('MOD', 2),
ops.Date: unary('DATE'),
# BigQuery doesn't have these operations built in.
# ops.ArrayRepeat: _array_repeat,
# ops.ArraySlice: _array_slice,
ops.Literal: _literal,
ops.Arbitrary: _arbitrary,
ops.TimestampTruncate: _truncate('TIMESTAMP', _timestamp_units),
ops.DateTruncate: _truncate('DATE', _date_units),
ops.TimeTruncate: _truncate('TIME', _timestamp_units),
ops.Time: unary('TIME'),
ops.TimestampAdd: _timestamp_op(
'TIMESTAMP_ADD', {'h', 'm', 's', 'ms', 'us'}
),
ops.TimestampSub: _timestamp_op(
'TIMESTAMP_DIFF', {'h', 'm', 's', 'ms', 'us'}
),
ops.DateAdd: _timestamp_op('DATE_ADD', {'D', 'W', 'M', 'Q', 'Y'}),
ops.DateSub: _timestamp_op('DATE_SUB', {'D', 'W', 'M', 'Q', 'Y'}),
ops.TimestampNow: fixed_arity('CURRENT_TIMESTAMP', 0),
}
)

_invalid_operations = {
ops.Translate,
ops.FindInSet,
ops.Capitalize,
ops.DateDiff,
ops.TimestampDiff
ops.TimestampDiff,
}

_operation_registry = {
k: v for k, v in _operation_registry.items()
k: v
for k, v in _operation_registry.items()
if k not in _invalid_operations
}

@@ -470,12 +459,10 @@ def compiles_strftime(translator, expr):
strftime_format_func_name,
fmt_string,
arg_formatted,
arg_type.timezone if arg_type.timezone is not None else 'UTC'
arg_type.timezone if arg_type.timezone is not None else 'UTC',
)
return 'FORMAT_{}({}, {})'.format(
strftime_format_func_name,
fmt_string,
arg_formatted
strftime_format_func_name, fmt_string, arg_formatted
)


@@ -487,9 +474,7 @@ def compiles_string_to_timestamp(translator, expr):
if timezone_arg is not None:
timezone_str = translator.translate(timezone_arg)
return 'PARSE_TIMESTAMP({}, {}, {})'.format(
fmt_string,
arg_formatted,
timezone_str
fmt_string, arg_formatted, timezone_str
)
return 'PARSE_TIMESTAMP({}, {})'.format(fmt_string, arg_formatted)

@@ -566,11 +551,7 @@ def bq_mean(expr):
return expr


UNIT_FUNCS = {
's': 'SECONDS',
'ms': 'MILLIS',
'us': 'MICROS',
}
UNIT_FUNCS = {'s': 'SECONDS', 'ms': 'MILLIS', 'us': 'MICROS'}


@compiles(ops.TimestampFromUNIX)
@@ -586,6 +567,20 @@ def compiles_floor(t, e):
return 'CAST(FLOOR({}) AS {})'.format(t.translate(arg), bigquery_type)


@compiles(ops.CMSMedian)
def compiles_approx(translator, expr):
expr = expr.op()
arg = expr.arg
where = expr.where

if where is not None:
arg = where.ifelse(arg, ibis.NA)

return 'APPROX_QUANTILES({}, 2)[OFFSET(1)]'.format(
translator.translate(arg)
)


class BigQueryDialect(impala_compiler.ImpalaDialect):

translator = BigQueryExprTranslator
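 
The new ops.CMSMedian rule gives the BigQuery backend an approximate median. A hedged sketch of the SQL it should produce, where t is a hypothetical table expression with a float column value and approx_median is the ibis API assumed to lower to CMSMedian:
 
import ibis

expr = t.value.approx_median()
print(ibis.bigquery.compile(expr))
# ... APPROX_QUANTILES(`value`, 2)[OFFSET(1)] ...
 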
18 changes: 9 additions & 9 deletions ibis/bigquery/datatypes.py
@@ -1,11 +1,9 @@
import six

from multipledispatch import Dispatcher

import ibis.expr.datatypes as dt


class TypeTranslationContext(object):
class TypeTranslationContext:
"""A tag class to allow alteration of the way a particular type is
translated.
@@ -15,6 +13,7 @@ class TypeTranslationContext(object):
avoid surprising results due to BigQuery's handling of INT64 types in
JavaScript UDFs.
"""

__slots__ = ()


@@ -25,7 +24,7 @@ class UDFContext(TypeTranslationContext):
ibis_type_to_bigquery_type = Dispatcher('ibis_type_to_bigquery_type')


@ibis_type_to_bigquery_type.register(six.string_types)
@ibis_type_to_bigquery_type.register(str)
def trans_string_default(datatype):
return ibis_type_to_bigquery_type(dt.dtype(datatype))

@@ -35,7 +34,7 @@ def trans_default(t):
return ibis_type_to_bigquery_type(t, TypeTranslationContext())


@ibis_type_to_bigquery_type.register(six.string_types, TypeTranslationContext)
@ibis_type_to_bigquery_type.register(str, TypeTranslationContext)
def trans_string_context(datatype, context):
return ibis_type_to_bigquery_type(dt.dtype(datatype), context)

@@ -62,17 +61,18 @@ def trans_lossy_integer(t, context):
@ibis_type_to_bigquery_type.register(dt.Array, TypeTranslationContext)
def trans_array(t, context):
return 'ARRAY<{}>'.format(
ibis_type_to_bigquery_type(t.value_type, context))
ibis_type_to_bigquery_type(t.value_type, context)
)


@ibis_type_to_bigquery_type.register(dt.Struct, TypeTranslationContext)
def trans_struct(t, context):
return 'STRUCT<{}>'.format(
', '.join(
'{} {}'.format(
name,
ibis_type_to_bigquery_type(dt.dtype(type), context)
) for name, type in zip(t.names, t.types)
name, ibis_type_to_bigquery_type(dt.dtype(type), context)
)
for name, type in zip(t.names, t.types)
)
)

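 
A short illustration of the translation rules above; the expected strings mirror the tests later in this diff, and the TypeError for INT64 applies only inside a UDF context:
 
import ibis.expr.datatypes as dt
from ibis.bigquery.datatypes import (
    ibis_type_to_bigquery_type,
    TypeTranslationContext,
    UDFContext,
)

ctx = TypeTranslationContext()
ibis_type_to_bigquery_type(dt.Array(dt.string), ctx)         # 'ARRAY<STRING>'
ibis_type_to_bigquery_type('array<struct<a: string>>', ctx)  # 'ARRAY<STRUCT<a STRING>>'
ibis_type_to_bigquery_type(dt.int64, UDFContext())           # raises TypeError
 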
41 changes: 24 additions & 17 deletions ibis/bigquery/tests/conftest.py
@@ -11,14 +11,34 @@

def connect(project_id, dataset_id):
ga = pytest.importorskip('google.auth')
google_application_credentials = os.environ.get(
"GOOGLE_APPLICATION_CREDENTIALS", None
)
if google_application_credentials is None:
pytest.skip(
'Environment variable GOOGLE_APPLICATION_CREDENTIALS is '
'not defined'
)
elif not google_application_credentials:
pytest.skip(
'Environment variable GOOGLE_APPLICATION_CREDENTIALS is empty'
)
elif not os.path.exists(google_application_credentials):
pytest.skip(
'Environment variable GOOGLE_APPLICATION_CREDENTIALS points '
'to {}, which does not exist'.format(
google_application_credentials
)
)

skip_message = (
'No BigQuery credentials found using project_id={}, '
'dataset_id={}. Skipping BigQuery tests.'
).format(project_id, dataset_id)
try:
return ibis.bigquery.connect(project_id, dataset_id)
except ga.exceptions.DefaultCredentialsError:
pytest.skip(
'no BigQuery credentials found (project_id={}, dataset_id={}), '
'skipping'.format(project_id, dataset_id)
)
pytest.skip(skip_message)


@pytest.fixture(scope='session')
@@ -31,19 +51,6 @@ def client():
return connect(PROJECT_ID, DATASET_ID)


@pytest.fixture(scope='session')
def client_no_credentials():
ga = pytest.importorskip('google.auth')

try:
return ibis.bigquery.connect(PROJECT_ID, DATASET_ID, credentials=None)
except ga.exceptions.DefaultCredentialsError:
pytest.skip(
'no BigQuery credentials found (project_id={}, dataset_id={}), '
'skipping'.format(PROJECT_ID, DATASET_ID)
)


@pytest.fixture(scope='session')
def client2():
return connect(PROJECT_ID, DATASET_ID)
321 changes: 176 additions & 145 deletions ibis/bigquery/tests/test_client.py

Large diffs are not rendered by default.

228 changes: 134 additions & 94 deletions ibis/bigquery/tests/test_compiler.py

Large diffs are not rendered by default.

61 changes: 33 additions & 28 deletions ibis/bigquery/tests/test_datatypes.py
@@ -1,10 +1,14 @@
import pytest

from pytest import param

from multipledispatch.conflict import ambiguities

import ibis.expr.datatypes as dt
from ibis.bigquery.datatypes import (
ibis_type_to_bigquery_type, UDFContext, TypeTranslationContext
ibis_type_to_bigquery_type,
UDFContext,
TypeTranslationContext,
)


@@ -29,27 +33,29 @@ def test_no_ambiguities():
(dt.Array(dt.int64), 'ARRAY<INT64>'),
(dt.Array(dt.string), 'ARRAY<STRING>'),
(
dt.Struct.from_tuples([
('a', dt.int64),
('b', dt.string),
('c', dt.Array(dt.string)),
]),
'STRUCT<a INT64, b STRING, c ARRAY<STRING>>'
dt.Struct.from_tuples(
[('a', dt.int64), ('b', dt.string), ('c', dt.Array(dt.string))]
),
'STRUCT<a INT64, b STRING, c ARRAY<STRING>>',
),
(dt.date, 'DATE'),
(dt.timestamp, 'TIMESTAMP'),
pytest.mark.xfail(
(dt.timestamp(timezone='US/Eastern'), 'TIMESTAMP'),
raises=TypeError,
reason='Not supported in BigQuery'
param(
dt.Timestamp(timezone='US/Eastern'),
'TIMESTAMP',
marks=pytest.mark.xfail(
raises=TypeError, reason='Not supported in BigQuery'
),
),
('array<struct<a: string>>', 'ARRAY<STRUCT<a STRING>>'),
pytest.mark.xfail(
(dt.Decimal(38, 9), 'NUMERIC'),
raises=TypeError,
reason='Not supported in BigQuery'
param(
dt.Decimal(38, 9),
'NUMERIC',
marks=pytest.mark.xfail(
raises=TypeError, reason='Not supported in BigQuery'
),
),
]
],
)
def test_simple(datatype, expected):
context = TypeTranslationContext()
@@ -65,19 +71,18 @@ def test_simple_failure_mode(datatype):
@pytest.mark.parametrize(
('type', 'expected'),
[
pytest.mark.xfail((dt.int64, 'INT64'), raises=TypeError),
pytest.mark.xfail(
(dt.Array(dt.int64), 'ARRAY<INT64>'),
raises=TypeError
param(dt.int64, 'INT64', marks=pytest.mark.xfail(raises=TypeError)),
param(
dt.Array(dt.int64),
'ARRAY<INT64>',
marks=pytest.mark.xfail(raises=TypeError),
),
pytest.mark.xfail(
(
dt.Struct.from_tuples([('a', dt.Array(dt.int64))]),
'STRUCT<a ARRAY<INT64>>'
),
raises=TypeError,
)
]
param(
dt.Struct.from_tuples([('a', dt.Array(dt.int64))]),
'STRUCT<a ARRAY<INT64>>',
marks=pytest.mark.xfail(raises=TypeError),
),
],
)
def test_ibis_type_to_bigquery_type_udf(type, expected):
context = UDFContext()
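 
The test updates above follow the same migration applied throughout this PR: marks are attached through pytest.param(..., marks=...) instead of wrapping the argument tuple in pytest.mark.xfail. A generic before/after sketch with illustrative values:
 
import pytest
from pytest import param

@pytest.mark.parametrize(
    ('value', 'expected'),
    [
        (1, 2),
        # old style, rejected by newer pytest:
        #   pytest.mark.xfail((1, 3), raises=AssertionError)
        param(1, 3, marks=pytest.mark.xfail(raises=AssertionError)),
    ],
)
def test_increment(value, expected):
    assert value + 1 == expected
 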
53 changes: 29 additions & 24 deletions ibis/bigquery/udf/api.py
@@ -1,11 +1,11 @@
import collections
import functools
import inspect
import itertools

import ibis.expr.rules as rlz
import ibis.expr.datatypes as dt

from ibis.compat import functools
from ibis.expr.signature import Argument as Arg

from ibis.bigquery.compiler import BigQueryUDFNode, compiles
@@ -14,7 +14,7 @@
from ibis.bigquery.datatypes import ibis_type_to_bigquery_type, UDFContext


__all__ = 'udf',
__all__ = ('udf',)


_udf_name_cache = collections.defaultdict(itertools.count)
@@ -68,7 +68,7 @@ def udf(input_type, output_type, strict=True, libraries=None):
Examples
--------
>>> from ibis.bigquery.api import udf
>>> from ibis.bigquery import udf
>>> import ibis.expr.datatypes as dt
>>> @udf(input_type=[dt.double], output_type=dt.double)
... def add_one(x):
@@ -169,37 +169,42 @@ def wrapper(f):
signature = inspect.signature(f)
parameter_names = signature.parameters.keys()

udf_node_fields = collections.OrderedDict([
(name, Arg(rlz.value(type)))
for name, type in zip(parameter_names, input_type)
] + [
(
'output_type',
lambda self, output_type=output_type: rlz.shape_like(
self.args, dtype=output_type
)
),
('__slots__', ('js',)),
])
udf_node_fields = collections.OrderedDict(
[
(name, Arg(rlz.value(type)))
for name, type in zip(parameter_names, input_type)
]
+ [
(
'output_type',
lambda self, output_type=output_type: rlz.shape_like(
self.args, dtype=output_type
),
),
('__slots__', ('js',)),
]
)

udf_node = create_udf_node(f.__name__, udf_node_fields)

@compiles(udf_node)
def compiles_udf_node(t, expr):
return '{}({})'.format(
udf_node.__name__,
', '.join(map(t.translate, expr.op().args))
udf_node.__name__, ', '.join(map(t.translate, expr.op().args))
)

type_translation_context = UDFContext()
return_type = ibis_type_to_bigquery_type(
dt.dtype(output_type), type_translation_context)
dt.dtype(output_type), type_translation_context
)
bigquery_signature = ', '.join(
'{name} {type}'.format(
name=name,
type=ibis_type_to_bigquery_type(
dt.dtype(type), type_translation_context)
) for name, type in zip(parameter_names, input_type)
dt.dtype(type), type_translation_context
),
)
for name, type in zip(parameter_names, input_type)
)
source = PythonToJavaScriptTranslator(f).compile()
js = '''\
@@ -217,10 +222,10 @@ def compiles_udf_node(t, expr):
strict=repr('use strict') + ';\n' if strict else '',
args=', '.join(parameter_names),
libraries=(
'\nOPTIONS (\n library={}\n)'.format(
repr(list(libraries))
) if libraries else ''
)
'\nOPTIONS (\n library={}\n)'.format(repr(list(libraries)))
if libraries
else ''
),
)

@functools.wraps(f)
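 
For reference, a usage sketch of the decorator assembled above, modeled on the library test later in this diff; the bucket path is a placeholder:
 
import ibis.expr.datatypes as dt
from ibis.bigquery import udf

@udf(
    input_type=[dt.Array(dt.string)],
    output_type=dt.double,
    libraries=['gs://my-bucket/lodash.min.js'],  # placeholder URI
)
def string_length(strings):
    # lodash's _ is made visible inside the UDF by the library option
    return _.sum(_.map(strings, lambda x: x.length))  # noqa: F821
 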
63 changes: 29 additions & 34 deletions ibis/bigquery/udf/core.py
@@ -9,7 +9,6 @@
import inspect
import textwrap

import six

import ibis.expr.datatypes as dt

@@ -26,6 +25,7 @@ class SymbolTable(ChainMap):
shove a "let" at the beginning of every variable name if it doesn't already
exist in the current scope.
"""

def __getitem__(self, key):
if key not in self:
self[key] = key
@@ -47,7 +47,7 @@ def indent(lines, spaces=4):
-------
indented_lines : str
"""
if isinstance(lines, six.string_types):
if isinstance(lines, str):
text = [lines]
text = '\n'.join(lines)
return textwrap.indent(text, ' ' * spaces)
@@ -60,9 +60,11 @@ def semicolon(f):
----------
f : callable
"""

@functools.wraps(f)
def wrapper(*args, **kwargs):
return f(*args, **kwargs) + ';'

return wrapper


@@ -72,7 +74,7 @@ def rewrite_print(node):
func=ast.Attribute(
value=ast.Name(id='console', ctx=ast.Load()),
attr='log',
ctx=ast.Load()
ctx=ast.Load(),
),
args=node.args,
keywords=node.keywords,
@@ -88,11 +90,7 @@ def rewrite_len(node):
@rewrite.register(ast.Call(func=ast.Attribute(attr='append')))
def rewrite_append(node):
return ast.Call(
func=ast.Attribute(
value=node.func.value,
attr='push',
ctx=ast.Load(),
),
func=ast.Attribute(value=node.func.value, attr='push', ctx=ast.Load()),
args=node.args,
keywords=node.keywords,
)
@@ -114,14 +112,11 @@ class PythonToJavaScriptTranslator:
'list': 'Array',
'Array': 'Array',
'Date': 'Date',

'dict': 'Object',
'Map': 'Map',
'WeakMap': 'WeakMap',

'str': 'String',
'String': 'String',

'set': 'Set',
'Set': 'Set',
'WeakSet': 'WeakSet',
@@ -184,19 +179,17 @@ def visit_Assign(self, node):

is_name = isinstance(target, ast.Name)
compiled_target = self.visit(target)
if (not is_name or
(self.current_class is not None and
compiled_target.startswith('this.'))):
if not is_name or (
self.current_class is not None
and compiled_target.startswith('this.')
):
self.scope[compiled_target] = compiled_target
return '{} = {}'.format(
self.scope[compiled_target],
self.visit(node.value)
self.scope[compiled_target], self.visit(node.value)
)

def translate_special_method(self, name):
return {
'__init__': 'constructor'
}.get(name, name)
return {'__init__': 'constructor'}.get(name, name)

def visit_FunctionDef(self, node):
self.current_function = node
@@ -226,11 +219,11 @@ def visit_FunctionDef(self, node):
prefix += ' ' * (self.current_class is None)

lines = [
prefix +
self.translate_special_method(node.name) +
'({}) {{'.format(self.visit(node.args)),
prefix
+ self.translate_special_method(node.name)
+ '({}) {{'.format(self.visit(node.args)),
body,
'}'
'}',
]

self.current_function = None
@@ -289,10 +282,11 @@ def visit_NameConstant(self, node):
return 'true'
elif value is False:
return 'false'
assert value is None, \
'value is not True and is not False, must be None, got {}'.format(
value
)
assert (
value is None
), 'value is not True and is not False, must be None, got {}'.format(
value
)
return 'null'

def visit_Str(self, node):
@@ -404,9 +398,7 @@ def visit_Compare(self, node):
for op, right in zip(ops, rights):
comparisons.append(
'({} {} {})'.format(
self.visit(left),
self.visit(op),
self.visit(right)
self.visit(left), self.visit(op), self.visit(right)
)
)
left = right
@@ -521,7 +513,7 @@ def visit_ListComp(self, node):
kwonlyargs=[],
kw_defaults=[],
kwarg=None,
defaults=[]
defaults=[],
)
else:
signature = ast.List(elts=argslist, ctx=ast.Load())
@@ -536,11 +528,13 @@ def visit_ListComp(self, node):
method = ast.Attribute(value=array, attr='filter', ctx=ast.Load())
# array.filter(func)
array = ast.Call(
func=method, args=[lam_sig(body=filt)], keywords=[])
func=method, args=[lam_sig(body=filt)], keywords=[]
)

method = ast.Attribute(value=array, attr='map', ctx=ast.Load())
mapped = ast.Call(
func=method, args=[lam_sig(body=node.elt)], keywords=[])
func=method, args=[lam_sig(body=node.elt)], keywords=[]
)
result = self.visit(mapped)
return result

@@ -556,7 +550,7 @@ def visit_Delete(self, node):
@udf(
input_type=[dt.double, dt.double, dt.int64],
output_type=dt.Array(dt.double),
strict=False
strict=False,
)
def my_func(a, b, n):
class Rectangle:
@@ -635,4 +629,5 @@ def range(n):
foo = Rectangle(1, 2)
nnn = len(values)
return [sum(values) - a + b * y ** -x, z, foo.width, nnn]

print(my_func.js)
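 
The translator can also be exercised directly; a small sketch of the expected shape of its output (exact whitespace may differ):
 
from ibis.bigquery.udf.core import PythonToJavaScriptTranslator

def add_one(x):
    return x + 1

print(PythonToJavaScriptTranslator(add_one).compile())
# function add_one(x) {
#     return (x + 1);
# }
 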
16 changes: 9 additions & 7 deletions ibis/bigquery/udf/find.py
@@ -36,9 +36,9 @@ def find_Call(self, node):
fields = node._fields
else:
fields = [field for field in node._fields if field != 'func']
return toolz.concat(map(
self.find, (getattr(node, field) for field in fields)
))
return toolz.concat(
map(self.find, (getattr(node, field) for field in fields))
)


def find_names(node):
@@ -64,7 +64,9 @@ def find_names(node):
>>> names[1].id
'b'
"""
return list(toolz.unique(
filter(None, NameFinder().find(node)),
key=lambda node: (node.id, type(node.ctx))
))
return list(
toolz.unique(
filter(None, NameFinder().find(node)),
key=lambda node: (node.id, type(node.ctx)),
)
)
5 changes: 4 additions & 1 deletion ibis/bigquery/udf/rewrite.py
@@ -23,7 +23,8 @@ def matches(value, pattern):

fields = [
(field, getattr(pattern, field))
for field in pattern._fields if hasattr(pattern, field)
for field in pattern._fields
if hasattr(pattern, field)
]
for field_name, field_value in fields:
if not matches(getattr(value, field_name), field_value):
@@ -38,13 +39,15 @@ class Rewriter:
----------
funcs : List[Tuple[ast.AST, Callable[ast.expr, [ast.expr]]]]
"""

def __init__(self):
self.funcs = []

def register(self, pattern):
def wrapper(f):
self.funcs.append((pattern, f))
return f

return wrapper

def __call__(self, node):
23 changes: 18 additions & 5 deletions ibis/bigquery/udf/tests/test_core.py
@@ -65,9 +65,11 @@ def test_yield_from():
d = {}

with tempfile.NamedTemporaryFile('r+') as f:
f.write("""\
f.write(
"""\
def f(a):
yield from [1, 2, 3]""")
yield from [1, 2, 3]"""
)
f.seek(0)
code = builtins.compile(f.read(), f.name, 'exec')
exec(code, d)
@@ -115,15 +117,16 @@ def div(x, y):


@pytest.mark.parametrize(
('op', 'expected'),
[(add, '+'), (sub, '-'), (mul, '*'), (div, '/')]
('op', 'expected'), [(add, '+'), (sub, '-'), (mul, '*'), (div, '/')]
)
def test_binary_operators(op, expected):
js = compile(op)
expected = """\
function {}(x, y) {{
return (x {} y);
}}""".format(op.__name__, expected)
}}""".format(
op.__name__, expected
)
assert expected == js


@@ -193,6 +196,7 @@ def f():
b = False
c = None
return a if c != None else b # noqa: E711

expected = """\
function f() {
let a = true;
@@ -376,6 +380,7 @@ def f(a):
y = '2'
x[y] = y
return x

expected = """\
function f(a) {
let x = {};
@@ -395,6 +400,7 @@ def f(a):
del x[0 + 3]
del y.a
return 1

expected = """\
function f(a) {
let x = [a, 1, 2, 3];
@@ -451,6 +457,7 @@ def test_list_comp():
def f():
x = [a + b for a, b in [(1, 2), (3, 4), (5, 6)] if a > 1 if b > 2]
return x

expected = """\
function f() {
let x = [[1, 2], [3, 4], [5, 6]].filter((([a, b]) => ((a > 1) && (b > 2)))).map((([a, b]) => (a + b)));
@@ -476,6 +483,7 @@ def f():
if c > 3
]
return x

expected = """\
function f() {
let x = [1, 4, 7].map(
@@ -495,8 +503,10 @@ def test_splat():
def f(x, y, z):
def g(a, b, c):
return a - b - c

args = [x, y, z]
return g(*args)

expected = """\
function f(x, y, z) {
function g(a, b, c) {
@@ -512,6 +522,7 @@ def g(a, b, c):
def test_varargs():
def f(*args):
return sum(*args)

expected = """\
function f(...args) {
return sum(...args);
@@ -523,6 +534,7 @@ def f(*args):
def test_missing_vararg():
def my_range(n):
return [1 for x in [n]]

js = compile(my_range)
expected = """\
function my_range(n) {
@@ -534,6 +546,7 @@ def my_range(n):
def test_len_rewrite():
def my_func(a):
return len(a)

js = compile(my_func)
expected = """\
function my_func(a) {
3 changes: 1 addition & 2 deletions ibis/bigquery/udf/tests/test_find.py
@@ -71,8 +71,7 @@ def test_find_Compare():
found = find_names(expr)
assert len(found) == 6
assert eq(
found,
[var('a'), var('b'), var('c'), var('e'), var('f'), var('gh')]
found, [var('a'), var('b'), var('c'), var('e'), var('f'), var('gh')]
)


Expand Down
76 changes: 42 additions & 34 deletions ibis/bigquery/udf/tests/test_udf_execute.py
@@ -2,6 +2,8 @@

import pytest

from pytest import param

import pandas as pd
import pandas.util.testing as tm

@@ -12,20 +14,18 @@

pytestmark = pytest.mark.bigquery

from ibis.bigquery.api import udf # noqa: E402
from ibis.bigquery.tests.conftest import (
connect as bigquery_connect,
) # noqa: E402
from ibis.bigquery import udf # noqa: E402

PROJECT_ID = os.environ.get('GOOGLE_BIGQUERY_PROJECT_ID', 'ibis-gbq')
DATASET_ID = 'testing'


@pytest.fixture(scope='module')
def client():
ga = pytest.importorskip('google.auth')

try:
return ibis.bigquery.connect(PROJECT_ID, DATASET_ID)
except ga.exceptions.DefaultCredentialsError:
pytest.skip("no credentials found, skipping")
return bigquery_connect(PROJECT_ID, DATASET_ID)


@pytest.fixture(scope='module')
@@ -52,26 +52,28 @@ def my_add(a, b):
expected = (df.double_col + df.double_col).rename('tmp')
tm.assert_series_equal(
result.value_counts().sort_index(),
expected.value_counts().sort_index()
expected.value_counts().sort_index(),
)


def test_udf_with_struct(client, alltypes, df):
@udf(
input_type=[dt.double, dt.double],
output_type=dt.Struct.from_tuples([
('width', dt.double),
('height', dt.double)
])
output_type=dt.Struct.from_tuples(
[('width', dt.double), ('height', dt.double)]
),
)
def my_struct_thing(a, b):
class Rectangle:
def __init__(self, width, height):
self.width = width
self.height = height

return Rectangle(a, b)

assert my_struct_thing.js == '''\
assert (
my_struct_thing.js
== '''\
CREATE TEMPORARY FUNCTION my_struct_thing_0(a FLOAT64, b FLOAT64)
RETURNS STRUCT<width FLOAT64, height FLOAT64>
LANGUAGE js AS """
@@ -87,14 +89,14 @@ class Rectangle {
}
return my_struct_thing(a, b);
""";'''
)

expr = my_struct_thing(alltypes.double_col, alltypes.double_col)
result = expr.execute()
assert not result.empty

expected = pd.Series(
[{'width': c, 'height': c} for c in df.double_col],
name='tmp'
[{'width': c, 'height': c} for c in df.double_col], name='tmp'
)
tm.assert_series_equal(result, expected)

@@ -126,7 +128,6 @@ def my_add(x, y):


def test_multiple_calls_has_one_definition(client):

@udf([dt.string], dt.double)
def my_str_len(s):
return s.length
@@ -157,7 +158,7 @@ def test_udf_libraries(client):
dt.double,
# whatever symbols are exported in the library are visible inside the
# UDF, in this case lodash defines _ and we use that here
libraries=['gs://ibis-testing-libraries/lodash.min.js']
libraries=['gs://ibis-testing-libraries/lodash.min.js'],
)
def string_length(strings):
return _.sum(_.map(strings, lambda x: x.length)) # noqa: F821
@@ -184,7 +185,6 @@ def my_array_len(x):


def test_multiple_calls_redefinition(client):

@udf([dt.string], dt.double)
def my_len(s):
return s.length
@@ -195,6 +195,7 @@ def my_len(s):
@udf([dt.string], dt.double)
def my_len(s):
return s.length + 1

expr = expr + my_len(s)

sql = client.compile(expr)
@@ -226,26 +227,33 @@ def my_len(s):
@pytest.mark.parametrize(
('argument_type', 'return_type'),
[
pytest.mark.xfail((dt.int64, dt.float64), raises=TypeError),
pytest.mark.xfail((dt.float64, dt.int64), raises=TypeError),
param(dt.int64, dt.float64, marks=pytest.mark.xfail(raises=TypeError)),
param(dt.float64, dt.int64, marks=pytest.mark.xfail(raises=TypeError)),
# complex argument type, valid return type
pytest.mark.xfail((dt.Array(dt.int64), dt.float64), raises=TypeError),
param(
dt.Array(dt.int64),
dt.float64,
marks=pytest.mark.xfail(raises=TypeError),
),
# valid argument type, complex invalid return type
pytest.mark.xfail(
(dt.float64, dt.Array(dt.int64)), raises=TypeError),
param(
dt.float64,
dt.Array(dt.int64),
marks=pytest.mark.xfail(raises=TypeError),
),
# both invalid
pytest.mark.xfail(
(dt.Array(dt.Array(dt.int64)), dt.int64), raises=TypeError),
param(
dt.Array(dt.Array(dt.int64)),
dt.int64,
marks=pytest.mark.xfail(raises=TypeError),
),
# struct type with nested integer, valid return type
pytest.mark.xfail(
(dt.Struct.from_tuples([('x', dt.Array(dt.int64))]), dt.float64),
raises=TypeError,
)
]
param(
dt.Struct.from_tuples([('x', dt.Array(dt.int64))]),
dt.float64,
marks=pytest.mark.xfail(raises=TypeError),
),
],
)
def test_udf_int64(client, argument_type, return_type):
# invalid argument type, valid return type
25 changes: 20 additions & 5 deletions ibis/clickhouse/api.py
@@ -10,6 +10,7 @@

try:
import lz4 # noqa: F401

_default_compression = 'lz4'
except ImportError:
_default_compression = False
@@ -25,6 +26,7 @@ def compile(expr, params=None):
compiled : string
"""
from ibis.clickhouse.compiler import to_sql

return to_sql(expr, dialect.make_context(params=params))


@@ -40,8 +42,15 @@ def verify(expr, params=None):
return False


def connect(host='localhost', port=9000, database='default', user='default',
password='', client_name='ibis', compression=_default_compression):
def connect(
host='localhost',
port=9000,
database='default',
user='default',
password='',
client_name='ibis',
compression=_default_compression,
):
"""Create an ClickhouseClient for use with Ibis.
Parameters
@@ -83,9 +92,15 @@ def connect(host='localhost', port=9000, database='default', user='default',
-------
ClickhouseClient
"""
client = ClickhouseClient(host, port=port, database=database, user=user,
password=password, client_name=client_name,
compression=compression)
client = ClickhouseClient(
host,
port=port,
database=database,
user=user,
password=password,
client_name=client_name,
compression=compression,
)
if options.default_backend is None:
options.default_backend = client

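 
A call matching the reformatted signature above; the host is a placeholder and every keyword shown is one of connect's own defaults:
 
import ibis

con = ibis.clickhouse.connect(
    host='clickhouse.example.com',
    port=9000,
    database='default',
    user='default',
    password='',
    client_name='ibis',
)
 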
63 changes: 35 additions & 28 deletions ibis/clickhouse/client.py
@@ -1,8 +1,10 @@
import re
import numpy as np
import pandas as pd

from collections import OrderedDict
from pkg_resources import parse_version

import numpy as np
import pandas as pd

import ibis.common as com
import ibis.expr.types as ir
@@ -11,7 +13,6 @@
import ibis.expr.operations as ops

from ibis.config import options
from ibis.compat import zip as czip, parse_version
from ibis.client import Query, Database, DatabaseEntity, SQLClient
from ibis.clickhouse.compiler import ClickhouseDialect, build_ast
from ibis.util import log
@@ -40,13 +41,13 @@
'String': dt.String,
'FixedString': dt.String,
'Date': dt.Date,
'DateTime': dt.Timestamp
'DateTime': dt.Timestamp,
}
_ibis_dtypes = {v: k for k, v in _clickhouse_dtypes.items()}
_ibis_dtypes[dt.String] = 'String'


class ClickhouseDataType(object):
class ClickhouseDataType:

__slots__ = 'typename', 'nullable'

@@ -96,27 +97,28 @@ class ClickhouseDatabase(Database):


class ClickhouseQuery(Query):

def _external_tables(self):
tables = []
for name, df in self.extra_options.get('external_tables', {}).items():
if not isinstance(df, pd.DataFrame):
raise TypeError('External table is not an instance of pandas '
'dataframe')
raise TypeError(
'External table is not an instance of pandas ' 'dataframe'
)

schema = sch.infer(df)
chtypes = map(ClickhouseDataType.from_ibis, schema.types)
structure = list(zip(schema.names, map(str, chtypes)))

tables.append(dict(name=name,
data=df.to_dict('records'),
structure=structure))
tables.append(
dict(
name=name, data=df.to_dict('records'), structure=structure
)
)
return tables

def execute(self):
cursor = self.client._execute(
self.compiled_sql,
external_tables=self._external_tables()
self.compiled_sql, external_tables=self._external_tables()
)
result = self._fetch(cursor)
return self._wrap_result(result)
@@ -127,9 +129,7 @@ def _fetch(self, cursor):
# handle empty resultset
return pd.DataFrame([], columns=colnames)

df = pd.DataFrame.from_dict(
OrderedDict(zip(colnames, data))
)
df = pd.DataFrame.from_dict(OrderedDict(zip(colnames, data)))
return self.schema().apply_to(df)


@@ -151,8 +151,11 @@ def _client(self):
def _match_name(self):
m = fully_qualified_re.match(self._qualified_name)
if not m:
raise com.IbisError('Cannot determine database name from {0}'
.format(self._qualified_name))
raise com.IbisError(
'Cannot determine database name from {0}'.format(
self._qualified_name
)
)
db, quoted, unquoted = m.groups()
return db, quoted or unquoted

@@ -184,14 +187,16 @@ def _execute(self, stmt):

def insert(self, obj, **kwargs):
from .identifiers import quote_identifier

schema = self.schema()

assert isinstance(obj, pd.DataFrame)
assert set(schema.names) >= set(obj.columns)

columns = ', '.join(map(quote_identifier, obj.columns))
query = 'INSERT INTO {table} ({columns}) VALUES'.format(
table=self._qualified_name, columns=columns)
table=self._qualified_name, columns=columns
)

# convert data columns with datetime64 pandas dtype to native date
# because clickhouse-driver 0.0.10 does arithmetic operations on it
@@ -201,7 +206,7 @@ def insert(self, obj, **kwargs):
obj[col] = obj[col].dt.date

data = obj.to_dict('records')
return self._client.con.process_insert_query(query, data, **kwargs)
return self._client.con.execute(query, data, **kwargs)


class ClickhouseDatabaseTable(ops.DatabaseTable):
@@ -240,15 +245,17 @@ def _execute(self, query, external_tables=(), results=True):
query = query.compile()
self.log(query)

response = self.con.process_ordinary_query(
query, columnar=True, with_column_types=True,
external_tables=external_tables
response = self.con.execute(
query,
columnar=True,
with_column_types=True,
external_tables=external_tables,
)
if not results:
return response

data, columns = response
colnames, typenames = czip(*columns)
colnames, typenames = zip(*columns)
coltypes = list(map(ClickhouseDataType.parse, typenames))

return data, colnames, coltypes
@@ -382,7 +389,7 @@ def exists_table(self, name, database=None):
return len(self.list_tables(like=name, database=database)) > 0

def _ensure_temp_db_exists(self):
name = options.clickhouse.temp_db,
name = (options.clickhouse.temp_db,)
if not self.exists_database(name):
self.create_database(name, force=True)

@@ -410,9 +417,9 @@ def version(self):

try:
server = self.con.connection.server_info
vstring = '{}.{}.{}'.format(server.version_major,
server.version_minor,
server.revision)
vstring = '{}.{}.{}'.format(
server.version_major, server.version_minor, server.revision
)
except Exception:
self.con.connection.disconnect()
raise
35 changes: 18 additions & 17 deletions ibis/clickhouse/compiler.py
@@ -1,4 +1,4 @@
from six import StringIO
from io import StringIO

import ibis.common as com
import ibis.util as util
@@ -27,7 +27,6 @@ def to_sql(expr, context=None):


class ClickhouseSelectBuilder(comp.SelectBuilder):

@property
def _select_class(self):
return ClickhouseSelect
@@ -42,13 +41,11 @@ class ClickhouseQueryBuilder(comp.QueryBuilder):


class ClickhouseQueryContext(comp.QueryContext):

def _to_sql(self, expr, ctx):
return to_sql(expr, context=ctx)


class ClickhouseSelect(comp.Select):

@property
def translator(self):
return ClickhouseExprTranslator
@@ -64,8 +61,9 @@ def format_group_by(self):

lines = []
if len(self.group_by) > 0:
columns = ['`{0}`'.format(expr.get_name())
for expr in self.group_by]
columns = [
'`{0}`'.format(expr.get_name()) for expr in self.group_by
]
clause = 'GROUP BY {0}'.format(', '.join(columns))
lines.append(clause)

@@ -98,7 +96,7 @@ class ClickhouseTableSetFormatter(comp.TableSetFormatter):
ops.InnerJoin: 'ALL INNER JOIN',
ops.LeftJoin: 'ALL LEFT JOIN',
ops.AnyInnerJoin: 'ANY INNER JOIN',
ops.AnyLeftJoin: 'ANY LEFT JOIN'
ops.AnyLeftJoin: 'ANY LEFT JOIN',
}

def get_result(self):
@@ -115,16 +113,18 @@ def get_result(self):
# TODO: Now actually format the things
buf = StringIO()
buf.write(self.join_tables[0])
for jtype, table, preds in zip(self.join_types, self.join_tables[1:],
self.join_predicates):
for jtype, table, preds in zip(
self.join_types, self.join_tables[1:], self.join_predicates
):
buf.write('\n')
buf.write(util.indent('{0} {1}'.format(jtype, table), self.indent))

if len(preds):
buf.write('\n')
fmt_preds = map(self._format_predicate, preds)
fmt_preds = util.indent('USING ' + ', '.join(fmt_preds),
self.indent * 2)
fmt_preds = util.indent(
'USING ' + ', '.join(fmt_preds), self.indent * 2
)
buf.write(fmt_preds)

return buf.getvalue()
@@ -133,13 +133,15 @@ def _validate_join_predicates(self, predicates):
for pred in predicates:
op = pred.op()
if not isinstance(op, ops.Equals):
raise com.TranslationError('Non-equality join predicates are '
'not supported')
raise com.TranslationError(
'Non-equality join predicates are ' 'not supported'
)

left_on, right_on = op.args
if left_on.get_name() != right_on.get_name():
raise com.TranslationError('Joining on different column names '
'is not supported')
raise com.TranslationError(
'Joining on different column names ' 'is not supported'
)

def _format_predicate(self, predicate):
column = predicate.op().args[0]
@@ -155,8 +157,7 @@ class ClickhouseExprTranslator(comp.ExprTranslator):
context_class = ClickhouseQueryContext

def name(self, translated, name, force=True):
return _name_expr(translated,
quote_identifier(name, force=force))
return _name_expr(translated, quote_identifier(name, force=force))


class ClickhouseDialect(comp.Dialect):
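 
The USING formatting and the validation above mean ClickHouse joins must be equality predicates on identically named columns; a sketch with hypothetical tables events and users:
 
joined = events.join(users, events.user_id == users.user_id)
# compiles to: ALL INNER JOIN ... USING `user_id`
# events.join(users, events.uid == users.user_id)       -> TranslationError
# events.join(users, events.user_id >= users.user_id)   -> TranslationError
 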
218 changes: 110 additions & 108 deletions ibis/clickhouse/identifiers.py
@@ -1,111 +1,113 @@
_identifiers = frozenset({
'add',
'aggregate',
'all',
'alter',
'and',
'as',
'asc',
'between',
'by',
'cached',
'case',
'cast',
'change',
'class',
'column',
'columns',
'comment',
'create',
'cross',
'data',
'database',
'databases',
'date',
'datetime',
'desc',
'describe',
'distinct',
'div',
'double',
'drop',
'else',
'end',
'escaped',
'exists',
'explain',
'external',
'fields',
'fileformat',
'first',
'float',
'format',
'from',
'full',
'function',
'functions',
'group',
'having',
'if',
'in',
'inner',
'inpath',
'insert',
'int',
'integer',
'intermediate',
'interval',
'into',
'is',
'join',
'last',
'left',
'like',
'limit',
'lines',
'load',
'location',
'metadata',
'not',
'null',
'offset',
'on',
'or',
'order',
'outer',
'partition',
'partitioned',
'partitions',
'real',
'refresh',
'regexp',
'rename',
'replace',
'returns',
'right',
'row',
'schema',
'schemas',
'select',
'set',
'show',
'stats',
'stored',
'string',
'symbol',
'table',
'tables',
'then',
'to',
'union',
'use',
'using',
'values',
'view',
'when',
'where',
'with'
})
_identifiers = frozenset(
{
'add',
'aggregate',
'all',
'alter',
'and',
'as',
'asc',
'between',
'by',
'cached',
'case',
'cast',
'change',
'class',
'column',
'columns',
'comment',
'create',
'cross',
'data',
'database',
'databases',
'date',
'datetime',
'desc',
'describe',
'distinct',
'div',
'double',
'drop',
'else',
'end',
'escaped',
'exists',
'explain',
'external',
'fields',
'fileformat',
'first',
'float',
'format',
'from',
'full',
'function',
'functions',
'group',
'having',
'if',
'in',
'inner',
'inpath',
'insert',
'int',
'integer',
'intermediate',
'interval',
'into',
'is',
'join',
'last',
'left',
'like',
'limit',
'lines',
'load',
'location',
'metadata',
'not',
'null',
'offset',
'on',
'or',
'order',
'outer',
'partition',
'partitioned',
'partitions',
'real',
'refresh',
'regexp',
'rename',
'replace',
'returns',
'right',
'row',
'schema',
'schemas',
'select',
'set',
'show',
'stats',
'stored',
'string',
'symbol',
'table',
'tables',
'then',
'to',
'union',
'use',
'using',
'values',
'view',
'when',
'where',
'with',
}
)


def quote_identifier(name, quotechar='`', force=False):
104 changes: 43 additions & 61 deletions ibis/clickhouse/operations.py
@@ -1,4 +1,4 @@
from six import StringIO
from io import StringIO
from datetime import date, datetime

import ibis.common as com
@@ -65,18 +65,19 @@ def formatter(translator, expr):
msg = 'Incorrect number of args {0} instead of {1}'
raise com.UnsupportedOperationError(msg.format(arg_count, arity))
return _call(translator, func_name, *op.args)

return formatter


def agg(func):
def formatter(translator, expr):
return _aggregate(translator, func, *expr.op().args)

return formatter


def agg_variance_like(func):
variants = {'sample': '{0}Samp'.format(func),
'pop': '{0}Pop'.format(func)}
variants = {'sample': '{0}Samp'.format(func), 'pop': '{0}Pop'.format(func)}

def formatter(translator, expr):
arg, how, where = expr.op().args
Expand All @@ -94,6 +95,7 @@ def formatter(translator, expr):
right_ = _parenthesize(translator, right)

return '{0!s} {1!s} {2!s}'.format(left_, infix_sym, right_)

return formatter


@@ -124,14 +126,18 @@ def varargs(func_name):
def varargs_formatter(translator, expr):
op = expr.op()
return _call(translator, func_name, *op.arg)

return varargs_formatter


def _arbitrary(translator, expr):
arg, how, where = expr.op().args
functions = {'first': 'any',
'last': 'anyLast',
'heavy': 'anyHeavy'}
functions = {
None: 'any',
'first': 'any',
'last': 'anyLast',
'heavy': 'anyHeavy',
}
return _aggregate(translator, functions[how], arg, where=where)
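 
The added None entry means .arbitrary() called without a how argument now maps to ClickHouse's any() instead of failing on the lookup; a sketch with a hypothetical column t.value:
 
t.value.arbitrary()             # -> any(`value`)
t.value.arbitrary(how='last')   # -> anyLast(`value`)
t.value.arbitrary(how='heavy')  # -> anyHeavy(`value`)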


@@ -228,11 +234,18 @@ def _hash(translator, expr):
op = expr.op()
arg, how = op.args

algorithms = {'MD5', 'halfMD5',
'SHA1', 'SHA224', 'SHA256',
'intHash32', 'intHash64',
'cityHash64',
'sipHash64', 'sipHash128'}
algorithms = {
'MD5',
'halfMD5',
'SHA1',
'SHA224',
'SHA256',
'intHash32',
'intHash64',
'cityHash64',
'sipHash64',
'sipHash128',
}

if how not in algorithms:
raise com.UnsupportedOperationError(
@@ -268,7 +281,8 @@ def _interval_format(translator, expr):
dtype = expr.type()
if dtype.unit in {'ms', 'us', 'ns'}:
raise com.UnsupportedOperationError(
"Clickhouse doesn't support subsecond interval resolutions")
"Clickhouse doesn't support subsecond interval resolutions"
)

return 'INTERVAL {} {}'.format(expr.op().value, dtype.resolution.upper())

@@ -280,7 +294,8 @@ def _interval_from_integer(translator, expr):
dtype = expr.type()
if dtype.unit in {'ms', 'us', 'ns'}:
raise com.UnsupportedOperationError(
"Clickhouse doesn't support subsecond interval resolutions")
"Clickhouse doesn't support subsecond interval resolutions"
)

arg_ = translator.translate(arg)
return 'INTERVAL {} {}'.format(arg_, dtype.resolution.upper())
@@ -315,8 +330,7 @@ def literal(translator, expr):
raise NotImplementedError(type(expr))


class CaseFormatter(object):

class CaseFormatter:
def __init__(self, translator, base, cases, results, default):
self.translator = translator
self.base = base
@@ -367,15 +381,17 @@ def _next_case(self):

def _simple_case(translator, expr):
op = expr.op()
formatter = CaseFormatter(translator, op.base, op.cases, op.results,
op.default)
formatter = CaseFormatter(
translator, op.base, op.cases, op.results, op.default
)
return formatter.get_result()


def _searched_case(translator, expr):
op = expr.op()
formatter = CaseFormatter(translator, None, op.cases, op.results,
op.default)
formatter = CaseFormatter(
translator, None, op.cases, op.results, op.default
)
return formatter.get_result()


@@ -407,7 +423,7 @@ def _truncate(translator, expr):
'D': 'toDate',
'h': 'toStartOfHour',
'm': 'toStartOfMinute',
's': 'toDateTime'
's': 'toDateTime',
}

try:
@@ -466,15 +482,15 @@ def _table_column(translator, expr):
def _string_split(translator, expr):
value, sep = expr.op().args
return 'splitByString({}, {})'.format(
translator.translate(sep),
translator.translate(value)
translator.translate(sep), translator.translate(value)
)


def _string_join(translator, expr):
sep, elements = expr.op().args
assert isinstance(elements.op(), ops.ValueList), \
'elements must be a ValueList, got {}'.format(type(elements.op()))
assert isinstance(
elements.op(), ops.ValueList
), 'elements must be a ValueList, got {}'.format(type(elements.op()))
return 'arrayStringConcat([{}], {})'.format(
', '.join(map(translator.translate, elements)),
translator.translate(sep),
@@ -507,50 +523,39 @@ def _string_like(translator, expr):
ops.Divide: binary_infix_op('/'),
ops.Power: fixed_arity('pow', 2),
ops.Modulus: binary_infix_op('%'),

# Comparisons
ops.Equals: binary_infix_op('='),
ops.NotEquals: binary_infix_op('!='),
ops.GreaterEqual: binary_infix_op('>='),
ops.Greater: binary_infix_op('>'),
ops.LessEqual: binary_infix_op('<='),
ops.Less: binary_infix_op('<'),

# Boolean comparisons
ops.And: binary_infix_op('AND'),
ops.Or: binary_infix_op('OR'),
ops.Xor: _xor,
}

_unary_ops = {
ops.Negate: _negate,
ops.Not: _not
}
_unary_ops = {ops.Negate: _negate, ops.Not: _not}


_operation_registry = {
# Unary operations
ops.TypeOf: unary('toTypeName'),

ops.IsNan: unary('isNaN'),
ops.IsInf: unary('isInfinite'),

ops.Abs: unary('abs'),
ops.Ceil: unary('ceil'),
ops.Floor: unary('floor'),
ops.Exp: unary('exp'),
ops.Round: _round,

ops.Sign: _sign,
ops.Sqrt: unary('sqrt'),

ops.Hash: _hash,

ops.Log: _log,
ops.Ln: unary('log'),
ops.Log2: unary('log2'),
ops.Log10: unary('log10'),

# Unary aggregates
ops.CMSMedian: agg('median'),
# TODO: there is also a `uniq` function which is the
@@ -560,16 +565,12 @@ def _string_like(translator, expr):
ops.Sum: agg('sum'),
ops.Max: agg('max'),
ops.Min: agg('min'),

ops.StandardDev: agg_variance_like('stddev'),
ops.Variance: agg_variance_like('var'),

# ops.GroupConcat: fixed_arity('group_concat', 2),

ops.Count: agg('count'),
ops.CountDistinct: agg('uniq'),
ops.Arbitrary: _arbitrary,

# string operations
ops.StringLength: unary('length'),
ops.Lowercase: unary('lower'),
@@ -583,67 +584,50 @@ def _string_like(translator, expr):
ops.StringSplit: _string_split,
ops.StringSQLLike: _string_like,
ops.Repeat: _string_repeat,

ops.RegexSearch: fixed_arity('match', 2),
# TODO: extractAll(haystack, pattern)[index + 1]
ops.RegexExtract: _regex_extract,
ops.RegexReplace: fixed_arity('replaceRegexpAll', 3),
ops.ParseURL: _parse_url,

# Temporal operations
ops.Date: unary('toDate'),
ops.DateTruncate: _truncate,

ops.TimestampNow: lambda *args: 'now()',
ops.TimestampTruncate: _truncate,

ops.TimeTruncate: _truncate,

ops.IntervalFromInteger: _interval_from_integer,

ops.ExtractYear: unary('toYear'),
ops.ExtractMonth: unary('toMonth'),
ops.ExtractDay: unary('toDayOfMonth'),
ops.ExtractHour: unary('toHour'),
ops.ExtractMinute: unary('toMinute'),
ops.ExtractSecond: unary('toSecond'),

# Other operations
ops.E: lambda *args: 'e()',

ops.Literal: literal,
ops.ValueList: _value_list,

ops.Cast: _cast,

# for more than 2 args this should be arrayGreatest|Least(array([]))
# because clickhouse's greatest and least doesn't support varargs
ops.Greatest: varargs('greatest'),
ops.Least: varargs('least'),

ops.Where: fixed_arity('if', 3),

ops.Between: _between,
ops.Contains: binary_infix_op('IN'),
ops.NotContains: binary_infix_op('NOT IN'),

ops.SimpleCase: _simple_case,
ops.SearchedCase: _searched_case,

ops.TableColumn: _table_column,
ops.TableArrayView: _table_array_view,

ops.DateAdd: binary_infix_op('+'),
ops.DateSub: binary_infix_op('-'),
ops.DateDiff: binary_infix_op('-'),
ops.TimestampAdd: binary_infix_op('+'),
ops.TimestampSub: binary_infix_op('-'),
ops.TimestampDiff: binary_infix_op('-'),
ops.TimestampFromUNIX: _timestamp_from_unix,

transforms.ExistsSubquery: _exists_subquery,
transforms.NotExistsSubquery: _exists_subquery,

ops.ArrayLength: unary('length'),
}
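(The registry above maps ibis operation types to formatter callables; the dispatch itself lives in the ClickHouse translator, which is not part of this diff. A minimal sketch of how such a lookup plausibly works -- the class and method names are assumptions:)

# Hypothetical dispatch sketch; the real translator is defined in
# ibis/clickhouse/compiler.py and is not shown in this diff.
class _SketchTranslator:
    _registry = _operation_registry  # maps op type -> formatter(translator, expr)

    def translate(self, expr):
        op_type = type(expr.op())
        try:
            formatter = self._registry[op_type]
        except KeyError:
            raise com.UnsupportedOperationError(
                '{} is not supported'.format(op_type.__name__)
            )
        return formatter(self, expr)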

@@ -680,7 +664,7 @@ def _zero_if_null(translator, expr):
ops.NullIf: fixed_arity('nullIf', 2),
ops.Coalesce: varargs('coalesce'),
ops.NullIfZero: _null_if_zero,
ops.ZeroIfNull: _zero_if_null
ops.ZeroIfNull: _zero_if_null,
}


@@ -696,18 +680,16 @@ def _zero_if_null(translator, expr):
ops.CumulativeAny,
ops.CumulativeAll,
ops.IdenticalTo,

ops.RowNumber,
ops.DenseRank,
ops.MinRank,
ops.PercentRank,

ops.FirstValue,
ops.LastValue,
ops.NthValue,
ops.Lag,
ops.Lead,
ops.NTile
ops.NTile,
]
_unsupported_ops = {k: raise_error for k in _unsupported_ops}

1 change: 1 addition & 0 deletions ibis/clickhouse/tests/conftest.py
@@ -39,6 +39,7 @@ def df(alltypes):
@pytest.fixture
def translate():
from ibis.clickhouse.compiler import ClickhouseDialect

dialect = ClickhouseDialect()
context = dialect.make_context()
return lambda expr: dialect.translator(expr, context).get_result()
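(A hedged example of how this fixture is consumed in the translation tests; the expected SQL string is an assumption, not taken from this diff:)

# Hypothetical usage of the `translate` fixture defined above.
def test_sum_translation(alltypes, translate):
    expr = alltypes.double_col.sum()
    assert translate(expr) == 'sum(`double_col`)'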
229 changes: 117 additions & 112 deletions ibis/clickhouse/tests/test_aggregations.py
@@ -9,15 +9,18 @@
pytestmark = pytest.mark.clickhouse


@pytest.mark.parametrize(('reduction', 'func_translated'), [
('sum', 'sum'),
('count', 'count'),
('mean', 'avg'),
('max', 'max'),
('min', 'min'),
('std', 'stddevSamp'),
('var', 'varSamp')
])
@pytest.mark.parametrize(
('reduction', 'func_translated'),
[
('sum', 'sum'),
('count', 'count'),
('mean', 'avg'),
('max', 'max'),
('min', 'min'),
('std', 'stddevSamp'),
('var', 'varSamp'),
],
)
def test_reduction_where(con, alltypes, translate, reduction, func_translated):
template = '{0}If(`double_col`, `bigint_col` < 70)'
expected = template.format(func_translated)
@@ -40,12 +43,7 @@ def test_std_var_pop(con, alltypes, translate):
assert isinstance(con.execute(expr2), np.float)


@pytest.mark.parametrize('reduction', [
'sum',
'count',
'max',
'min'
])
@pytest.mark.parametrize('reduction', ['sum', 'count', 'max', 'min'])
def test_reduction_invalid_where(con, alltypes, reduction):
condbad_literal = L('T')

@@ -54,92 +52,95 @@ def test_reduction_invalid_where(con, alltypes, reduction):
fn(alltypes.double_col)


@pytest.mark.parametrize(('func', 'pandas_func'), [
(
lambda t, cond: t.bool_col.count(),
lambda df, cond: df.bool_col.count(),
),
(
lambda t, cond: t.bool_col.approx_nunique(),
lambda df, cond: df.bool_col.nunique(),
),
(
lambda t, cond: t.double_col.sum(),
lambda df, cond: df.double_col.sum(),
),
(
lambda t, cond: t.double_col.mean(),
lambda df, cond: df.double_col.mean(),
),
(
lambda t, cond: t.int_col.approx_median(),
lambda df, cond: np.int32(df.int_col.median()),
),
(
lambda t, cond: t.double_col.min(),
lambda df, cond: df.double_col.min(),
),
(
lambda t, cond: t.double_col.max(),
lambda df, cond: df.double_col.max(),
),
(
lambda t, cond: t.double_col.var(),
lambda df, cond: df.double_col.var(),
),
(
lambda t, cond: t.double_col.std(),
lambda df, cond: df.double_col.std(),
),
(
lambda t, cond: t.double_col.var(how='sample'),
lambda df, cond: df.double_col.var(ddof=1),
),
(
lambda t, cond: t.double_col.std(how='pop'),
lambda df, cond: df.double_col.std(ddof=0),
),
(
lambda t, cond: t.bool_col.count(where=cond),
lambda df, cond: df.bool_col[cond].count(),
),
(
lambda t, cond: t.double_col.sum(where=cond),
lambda df, cond: df.double_col[cond].sum(),
),
(
lambda t, cond: t.double_col.mean(where=cond),
lambda df, cond: df.double_col[cond].mean(),
),
(
lambda t, cond: t.float_col.approx_median(where=cond),
lambda df, cond: df.float_col[cond].median(),
),
(
lambda t, cond: t.double_col.min(where=cond),
lambda df, cond: df.double_col[cond].min(),
),
(
lambda t, cond: t.double_col.max(where=cond),
lambda df, cond: df.double_col[cond].max(),
),
(
lambda t, cond: t.double_col.var(where=cond),
lambda df, cond: df.double_col[cond].var(),
),
(
lambda t, cond: t.double_col.std(where=cond),
lambda df, cond: df.double_col[cond].std(),
),
(
lambda t, cond: t.double_col.var(where=cond, how='sample'),
lambda df, cond: df.double_col[cond].var(),
),
(
lambda t, cond: t.double_col.std(where=cond, how='pop'),
lambda df, cond: df.double_col[cond].std(ddof=0),
)
])
@pytest.mark.parametrize(
('func', 'pandas_func'),
[
(
lambda t, cond: t.bool_col.count(),
lambda df, cond: df.bool_col.count(),
),
(
lambda t, cond: t.bool_col.approx_nunique(),
lambda df, cond: df.bool_col.nunique(),
),
(
lambda t, cond: t.double_col.sum(),
lambda df, cond: df.double_col.sum(),
),
(
lambda t, cond: t.double_col.mean(),
lambda df, cond: df.double_col.mean(),
),
(
lambda t, cond: t.int_col.approx_median(),
lambda df, cond: np.int32(df.int_col.median()),
),
(
lambda t, cond: t.double_col.min(),
lambda df, cond: df.double_col.min(),
),
(
lambda t, cond: t.double_col.max(),
lambda df, cond: df.double_col.max(),
),
(
lambda t, cond: t.double_col.var(),
lambda df, cond: df.double_col.var(),
),
(
lambda t, cond: t.double_col.std(),
lambda df, cond: df.double_col.std(),
),
(
lambda t, cond: t.double_col.var(how='sample'),
lambda df, cond: df.double_col.var(ddof=1),
),
(
lambda t, cond: t.double_col.std(how='pop'),
lambda df, cond: df.double_col.std(ddof=0),
),
(
lambda t, cond: t.bool_col.count(where=cond),
lambda df, cond: df.bool_col[cond].count(),
),
(
lambda t, cond: t.double_col.sum(where=cond),
lambda df, cond: df.double_col[cond].sum(),
),
(
lambda t, cond: t.double_col.mean(where=cond),
lambda df, cond: df.double_col[cond].mean(),
),
(
lambda t, cond: t.float_col.approx_median(where=cond),
lambda df, cond: df.float_col[cond].median(),
),
(
lambda t, cond: t.double_col.min(where=cond),
lambda df, cond: df.double_col[cond].min(),
),
(
lambda t, cond: t.double_col.max(where=cond),
lambda df, cond: df.double_col[cond].max(),
),
(
lambda t, cond: t.double_col.var(where=cond),
lambda df, cond: df.double_col[cond].var(),
),
(
lambda t, cond: t.double_col.std(where=cond),
lambda df, cond: df.double_col[cond].std(),
),
(
lambda t, cond: t.double_col.var(where=cond, how='sample'),
lambda df, cond: df.double_col[cond].var(),
),
(
lambda t, cond: t.double_col.std(where=cond, how='pop'),
lambda df, cond: df.double_col[cond].std(ddof=0),
),
],
)
def test_aggregations(alltypes, df, func, pandas_func, translate):
table = alltypes.limit(100)
count = table.count().execute()
@@ -155,14 +156,17 @@ def test_aggregations(alltypes, df, func, pandas_func, translate):
np.testing.assert_allclose(result, expected)


@pytest.mark.parametrize('op', [
methodcaller('sum'),
methodcaller('mean'),
methodcaller('min'),
methodcaller('max'),
methodcaller('std'),
methodcaller('var')
])
@pytest.mark.parametrize(
'op',
[
methodcaller('sum'),
methodcaller('mean'),
methodcaller('min'),
methodcaller('max'),
methodcaller('std'),
methodcaller('var'),
],
)
def test_boolean_reduction(alltypes, op, df):
result = op(alltypes.bool_col).execute()
assert result == op(df.bool_col)
Expand All @@ -189,7 +193,8 @@ def test_boolean_summary(alltypes):
'sum',
'mean',
'approx_nunique',
]
],
)
tm.assert_frame_equal(
result, expected, check_column_type=False, check_dtype=False
)
tm.assert_frame_equal(result, expected, check_column_type=False,
check_dtype=False)
37 changes: 21 additions & 16 deletions ibis/clickhouse/tests/test_client.py
@@ -1,3 +1,5 @@
from io import StringIO

import pytest
import pandas as pd

@@ -7,7 +9,6 @@
import pandas.util.testing as tm

from ibis import literal as L
from ibis.compat import StringIO

pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse
@@ -85,7 +86,7 @@ def logger(x):

def test_sql_query_limits(alltypes):
table = alltypes
with config.option_context('sql.default_limit', 100000):
with config.option_context('sql.default_limit', 100_000):
# functional_alltypes has 7300 rows
assert len(table.execute()) == 7300
# comply with limit arg for TableExpr
@@ -144,8 +145,7 @@ def test_database_default_current_database(con):
def test_embedded_identifier_quoting(alltypes):
t = alltypes

expr = (t[[(t.double_col * 2).name('double(fun)')]]
['double(fun)'].sum())
expr = t[[(t.double_col * 2).name('double(fun)')]]['double(fun)'].sum()
expr.execute()


@@ -157,25 +157,26 @@ def test_table_info(alltypes):


def test_execute_exprs_no_table_ref(con):
cases = [
(L(1) + L(2), 3)
]
cases = [(L(1) + L(2), 3)]

for expr, expected in cases:
result = con.execute(expr)
assert result == expected

# ExprList
exlist = ibis.api.expr_list([L(1).name('a'),
ibis.now().name('b'),
L(2).log().name('c')])
exlist = ibis.api.expr_list(
[L(1).name('a'), ibis.now().name('b'), L(2).log().name('c')]
)
con.execute(exlist)


@pytest.mark.skip(reason="FIXME: it is raising KeyError: 'Unnamed: 0'")
def test_insert(con, alltypes, df):
drop = 'DROP TABLE IF EXISTS temporary_alltypes'
create = ('CREATE TABLE IF NOT EXISTS '
'temporary_alltypes AS functional_alltypes')
create = (
'CREATE TABLE IF NOT EXISTS '
'temporary_alltypes AS functional_alltypes'
)

con.raw_sql(drop)
con.raw_sql(create)
@@ -191,8 +192,10 @@ def test_insert(con, alltypes, df):

def test_insert_with_less_columns(con, alltypes, df):
drop = 'DROP TABLE IF EXISTS temporary_alltypes'
create = ('CREATE TABLE IF NOT EXISTS '
'temporary_alltypes AS functional_alltypes')
create = (
'CREATE TABLE IF NOT EXISTS '
'temporary_alltypes AS functional_alltypes'
)

con.raw_sql(drop)
con.raw_sql(create)
@@ -207,8 +210,10 @@ def test_insert_with_less_columns(con, alltypes, df):

def test_insert_with_more_columns(con, alltypes, df):
drop = 'DROP TABLE IF EXISTS temporary_alltypes'
create = ('CREATE TABLE IF NOT EXISTS '
'temporary_alltypes AS functional_alltypes')
create = (
'CREATE TABLE IF NOT EXISTS '
'temporary_alltypes AS functional_alltypes'
)

con.raw_sql(drop)
con.raw_sql(create)