160 changes: 160 additions & 0 deletions ibis/clickhouse/compiler.py
@@ -0,0 +1,160 @@
from six import StringIO

import ibis.common as com
import ibis.util as util
import ibis.expr.operations as ops
import ibis.sql.compiler as comp

from .identifiers import quote_identifier
from .operations import _operation_registry, _name_expr


def build_ast(expr, context=None, params=None):
builder = ClickhouseQueryBuilder(expr, context=context, params=params)
return builder.get_result()


def _get_query(expr, context):
ast = build_ast(expr, context)
query = ast.queries[0]

return query


def to_sql(expr, context=None):
query = _get_query(expr, context)
return query.compile()


class ClickhouseSelectBuilder(comp.SelectBuilder):

@property
def _select_class(self):
return ClickhouseSelect

def _convert_group_by(self, exprs):
return exprs


class ClickhouseQueryBuilder(comp.QueryBuilder):

select_builder = ClickhouseSelectBuilder

@property
def _make_context(self):
return ClickhouseQueryContext


class ClickhouseQueryContext(comp.QueryContext):

def _to_sql(self, expr, ctx):
return to_sql(expr, context=ctx)


class ClickhouseSelect(comp.Select):

@property
def translator(self):
return ClickhouseExprTranslator

@property
def table_set_formatter(self):
return ClickhouseTableSetFormatter

def format_group_by(self):
if not len(self.group_by):
# There is no aggregation, nothing to see here
return None

lines = []
# the early return above guarantees at least one grouping key here
columns = ['`{0}`'.format(expr.get_name())
for expr in self.group_by]
lines.append('GROUP BY {0}'.format(', '.join(columns)))

if len(self.having) > 0:
trans_exprs = []
for expr in self.having:
translated = self._translate(expr)
trans_exprs.append(translated)
lines.append('HAVING {0}'.format(' AND '.join(trans_exprs)))

return '\n'.join(lines)
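
A hedged illustration of the return value: for a query grouped on two
columns with a filter on an aggregate, format_group_by yields the two
clauses joined by a newline (column names and condition invented):

    GROUP BY `string_col`, `int_col`
    HAVING count(*) > 10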


class ClickhouseTableSetFormatter(comp.TableSetFormatter):

_join_names = {
ops.InnerJoin: 'ALL INNER JOIN',
ops.LeftJoin: 'ALL LEFT JOIN',
ops.AnyInnerJoin: 'ANY INNER JOIN',
ops.AnyLeftJoin: 'ANY LEFT JOIN'
}

def get_result(self):
# Got to unravel the join stack; the nesting order could be
# arbitrary, so we do a depth first search and push the join tokens
# and predicates onto a flat list, then format them
op = self.expr.op()

if isinstance(op, ops.Join):
self._walk_join_tree(op)
else:
self.join_tables.append(self._format_table(self.expr))

# TODO: Now actually format the things
buf = StringIO()
buf.write(self.join_tables[0])
for jtype, table, preds in zip(self.join_types, self.join_tables[1:],
self.join_predicates):
buf.write('\n')
buf.write(util.indent('{0} {1}'.format(jtype, table), self.indent))

if len(preds):
buf.write('\n')
fmt_preds = map(self._format_predicate, preds)
fmt_preds = util.indent('USING ' + ', '.join(fmt_preds),
self.indent * 2)
buf.write(fmt_preds)

return buf.getvalue()

def _validate_join_predicates(self, predicates):
for pred in predicates:
op = pred.op()
if not isinstance(op, ops.Equals):
raise com.TranslationError('Non-equality join predicates are '
'not supported')

left_on, right_on = op.args
if left_on.get_name() != right_on.get_name():
raise com.TranslationError('Joining on different column names '
'is not supported')

def _format_predicate(self, predicate):
column = predicate.op().args[0]
return quote_identifier(column.get_name(), force=True)

def _quote_identifier(self, name):
return quote_identifier(name)
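
ClickHouse spells equi-joins with USING rather than ON, which is why the
validator above insists on equality predicates over identically named
columns. A minimal sketch of what passes validation (table and column
names invented):

    import ibis

    events = ibis.table([('user_id', 'int64'), ('ts', 'timestamp')], 'events')
    users = ibis.table([('user_id', 'int64'), ('name', 'string')], 'users')

    # Formats as: ALL INNER JOIN ... USING `user_id`
    ok = events.inner_join(users, events.user_id == users.user_id)

    # A non-equality predicate such as events.ts > users.ts, or a join on
    # differently named columns, raises com.TranslationError instead.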


class ClickhouseExprTranslator(comp.ExprTranslator):

_registry = _operation_registry
_context_class = ClickhouseQueryContext

def name(self, translated, name, force=True):
return _name_expr(translated,
quote_identifier(name, force=force))


compiles = ClickhouseExprTranslator.compiles
rewrites = ClickhouseExprTranslator.rewrites


@rewrites(ops.FloorDivide)
def _floor_divide(expr):
left, right = expr.op().args
return left.div(right).floor()
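
Because of the rewrite above, floor division never reaches the translator
as a distinct operation; it is decomposed into a division followed by a
floor. A sketch of the effect (table name invented, assuming the module
is importable as shown):

    import ibis
    from ibis.clickhouse.compiler import to_sql

    t = ibis.table([('a', 'int64'), ('b', 'int64')], 't')

    # t.a // t.b is rewritten to t.a.div(t.b).floor(), so the emitted
    # SQL contains floor(`a` / `b`) rather than an integer-division call.
    print(to_sql(t[(t.a // t.b).name('q')]))
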
115 changes: 115 additions & 0 deletions ibis/clickhouse/identifiers.py
@@ -0,0 +1,115 @@
_identifiers = frozenset({
'add',
'aggregate',
'all',
'alter',
'and',
'as',
'asc',
'between',
'by',
'cached',
'case',
'cast',
'change',
'class',
'column',
'columns',
'comment',
'create',
'cross',
'data',
'database',
'databases',
'date',
'datetime',
'desc',
'describe',
'distinct',
'div',
'double',
'drop',
'else',
'end',
'escaped',
'exists',
'explain',
'external',
'fields',
'fileformat',
'first',
'float',
'format',
'from',
'full',
'function',
'functions',
'group',
'having',
'if',
'in',
'inner',
'inpath',
'insert',
'int',
'integer',
'intermediate',
'interval',
'into',
'is',
'join',
'last',
'left',
'like',
'limit',
'lines',
'load',
'location',
'metadata',
'not',
'null',
'offset',
'on',
'or',
'order',
'outer',
'partition',
'partitioned',
'partitions',
'real',
'refresh',
'regexp',
'rename',
'replace',
'returns',
'right',
'row',
'schema',
'schemas',
'select',
'set',
'show',
'stats',
'stored',
'string',
'symbol',
'table',
'tables',
'then',
'to',
'union',
'use',
'using',
'values',
'view',
'when',
'where',
'with'
})


def quote_identifier(name, quotechar='`', force=False):
if force or name.count(' ') or name in _identifiers:
return '{0}{1}{0}'.format(quotechar, name)
else:
return name
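
Backtick quoting is applied only when the name is reserved or contains a
space, unless explicitly forced:

    from ibis.clickhouse.identifiers import quote_identifier

    assert quote_identifier('user_id') == 'user_id'
    assert quote_identifier('table') == '`table`'    # reserved word
    assert quote_identifier('has a space') == '`has a space`'
    assert quote_identifier('user_id', force=True) == '`user_id`'
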
646 changes: 646 additions & 0 deletions ibis/clickhouse/operations.py

Large diffs are not rendered by default.

Empty file.
41 changes: 41 additions & 0 deletions ibis/clickhouse/tests/conftest.py
@@ -0,0 +1,41 @@
import os
import ibis
import pytest


CLICKHOUSE_HOST = os.environ.get('IBIS_CLICKHOUSE_HOST', 'localhost')
CLICKHOUSE_PORT = int(os.environ.get('IBIS_CLICKHOUSE_PORT', 9000))
CLICKHOUSE_USER = os.environ.get('IBIS_CLICKHOUSE_USER', 'default')
CLICKHOUSE_PASS = os.environ.get('IBIS_CLICKHOUSE_PASS', '')
IBIS_TEST_CLICKHOUSE_DB = os.environ.get('IBIS_TEST_DATA_DB', 'ibis_testing')


@pytest.fixture(scope='module')
def con():
return ibis.clickhouse.connect(
host=CLICKHOUSE_HOST,
port=CLICKHOUSE_PORT,
user=CLICKHOUSE_USER,
password=CLICKHOUSE_PASS,
database=IBIS_TEST_CLICKHOUSE_DB,
)


@pytest.fixture(scope='module')
def db(con):
return con.database()


@pytest.fixture(scope='module')
def alltypes(db):
return db.functional_alltypes


@pytest.fixture(scope='module')
def df(alltypes):
return alltypes.execute()


@pytest.fixture
def translate():
from ibis.clickhouse.compiler import ClickhouseExprTranslator
return lambda expr: ClickhouseExprTranslator(expr).get_result()
330 changes: 330 additions & 0 deletions ibis/clickhouse/tests/test_aggregations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,330 @@
import pytest
import numpy as np
import pandas as pd
import pandas.util.testing as tm
from operator import methodcaller
from ibis import literal as L

pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse


@pytest.mark.parametrize(('reduction', 'func_translated'), [
('sum', 'sum'),
('count', 'count'),
('mean', 'avg'),
('max', 'max'),
('min', 'min'),
('std', 'stddevSamp'),
('var', 'varSamp')
])
def test_reduction_where(con, alltypes, translate, reduction, func_translated):
template = '{0}If(`double_col`, `bigint_col` < 70)'
expected = template.format(func_translated)

method = getattr(alltypes.double_col, reduction)
cond = alltypes.bigint_col < 70
expr = method(where=cond)

assert translate(expr) == expected
assert isinstance(con.execute(expr), (np.float, np.uint))


def test_std_var_pop(con, alltypes, translate):
cond = alltypes.bigint_col < 70
expr1 = alltypes.double_col.std(where=cond, how='pop')
expr2 = alltypes.double_col.var(where=cond, how='pop')

assert translate(expr1) == 'stddevPopIf(`double_col`, `bigint_col` < 70)'
assert translate(expr2) == 'varPopIf(`double_col`, `bigint_col` < 70)'
assert isinstance(con.execute(expr1), np.float)
assert isinstance(con.execute(expr2), np.float)


@pytest.mark.parametrize('reduction', [
'sum',
'count',
'max',
'min'
])
def test_reduction_invalid_where(con, alltypes, reduction):
condbad_literal = L('T')

with pytest.raises(TypeError):
fn = methodcaller(reduction, where=condbad_literal)
fn(alltypes.double_col)


# @pytest.mark.parametrize(
# ('func', 'pandas_func'),
# [
# # tier and histogram
# (
# lambda d: d.bucket([0, 10, 25, 50, 100]),
# lambda s: pd.cut(
# s, [0, 10, 25, 50, 100], right=False, labels=False,
# )
# ),
# (
# lambda d: d.bucket([0, 10, 25, 50], include_over=True),
# lambda s: pd.cut(
# s, [0, 10, 25, 50, np.inf], right=False, labels=False
# )
# ),
# (
# lambda d: d.bucket([0, 10, 25, 50], close_extreme=False),
# lambda s: pd.cut(s, [0, 10, 25, 50], right=False, labels=False),
# ),
# (
# lambda d: d.bucket(
# [0, 10, 25, 50], closed='right', close_extreme=False
# ),
# lambda s: pd.cut(
# s, [0, 10, 25, 50],
# include_lowest=False,
# right=True,
# labels=False,
# )
# ),
# (
# lambda d: d.bucket([10, 25, 50, 100], include_under=True),
# lambda s: pd.cut(
# s, [0, 10, 25, 50, 100], right=False, labels=False
# ),
# ),
# ]
# )
# def test_bucket(alltypes, df, func, pandas_func):
# expr = func(alltypes.double_col)
# result = expr.execute()
# expected = pandas_func(df.double_col)
# tm.assert_series_equal(result, expected, check_names=False)


# def test_category_label(alltypes, df):
# t = alltypes
# d = t.double_col

# bins = [0, 10, 25, 50, 100]
# labels = ['a', 'b', 'c', 'd']
# bucket = d.bucket(bins)
# expr = bucket.label(labels)
# result = expr.execute().astype('category', ordered=True)
# result.name = 'double_col'

# expected = pd.cut(df.double_col, bins, labels=labels, right=False)

# tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(('func', 'pandas_func'), [
(
lambda t, cond: t.bool_col.count(),
lambda df, cond: df.bool_col.count(),
),
# (
# lambda t, cond: t.bool_col.nunique(),
# lambda df, cond: df.bool_col.nunique(),
# ),
(
lambda t, cond: t.bool_col.approx_nunique(),
lambda df, cond: df.bool_col.nunique(),
),
# group_concat
# (
# lambda t, cond: t.bool_col.any(),
# lambda df, cond: df.bool_col.any(),
# ),
# (
# lambda t, cond: t.bool_col.all(),
# lambda df, cond: df.bool_col.all(),
# ),
# (
# lambda t, cond: t.bool_col.notany(),
# lambda df, cond: ~df.bool_col.any(),
# ),
# (
# lambda t, cond: t.bool_col.notall(),
# lambda df, cond: ~df.bool_col.all(),
# ),
(
lambda t, cond: t.double_col.sum(),
lambda df, cond: df.double_col.sum(),
),
(
lambda t, cond: t.double_col.mean(),
lambda df, cond: df.double_col.mean(),
),
(
lambda t, cond: t.int_col.approx_median(),
lambda df, cond: df.int_col.median(),
),
(
lambda t, cond: t.double_col.min(),
lambda df, cond: df.double_col.min(),
),
(
lambda t, cond: t.double_col.max(),
lambda df, cond: df.double_col.max(),
),
(
lambda t, cond: t.double_col.var(),
lambda df, cond: df.double_col.var(),
),
(
lambda t, cond: t.double_col.std(),
lambda df, cond: df.double_col.std(),
),
(
lambda t, cond: t.double_col.var(how='sample'),
lambda df, cond: df.double_col.var(ddof=1),
),
(
lambda t, cond: t.double_col.std(how='pop'),
lambda df, cond: df.double_col.std(ddof=0),
),
(
lambda t, cond: t.bool_col.count(where=cond),
lambda df, cond: df.bool_col[cond].count(),
),
# (
# lambda t, cond: t.bool_col.nunique(where=cond),
# lambda df, cond: df.bool_col[cond].nunique(),
# ),
# (
# lambda t, cond: t.bool_col.approx_nunique(where=cond),
# lambda df, cond: df.bool_col[cond].nunique(),
# ),
(
lambda t, cond: t.double_col.sum(where=cond),
lambda df, cond: df.double_col[cond].sum(),
),
(
lambda t, cond: t.double_col.mean(where=cond),
lambda df, cond: df.double_col[cond].mean(),
),
(
lambda t, cond: t.int_col.approx_median(where=cond),
lambda df, cond: df.int_col[cond].median(),
),
(
lambda t, cond: t.double_col.min(where=cond),
lambda df, cond: df.double_col[cond].min(),
),
(
lambda t, cond: t.double_col.max(where=cond),
lambda df, cond: df.double_col[cond].max(),
),
(
lambda t, cond: t.double_col.var(where=cond),
lambda df, cond: df.double_col[cond].var(),
),
(
lambda t, cond: t.double_col.std(where=cond),
lambda df, cond: df.double_col[cond].std(),
),
(
lambda t, cond: t.double_col.var(where=cond, how='sample'),
lambda df, cond: df.double_col[cond].var(),
),
(
lambda t, cond: t.double_col.std(where=cond, how='pop'),
lambda df, cond: df.double_col[cond].std(ddof=0),
)
])
def test_aggregations(alltypes, df, func, pandas_func, translate):
table = alltypes.limit(100)
count = table.count().execute()
df = df.head(int(count))

cond = table.string_col.isin(['1', '7'])
mask = cond.execute().astype('bool')
expr = func(table, cond)

result = expr.execute()
expected = pandas_func(df, mask)

np.testing.assert_allclose(result, expected)


# def test_group_concat(alltypes, df):
# expr = alltypes.string_col.group_concat()
# result = expr.execute()
# expected = ','.join(df.string_col.dropna())
# assert result == expected


# TODO: requires CountDistinct to support condition
# def test_distinct_aggregates(alltypes, df, translate):
# expr = alltypes.limit(100).double_col.nunique()
# result = expr.execute()

# assert translate(expr) == 'uniq(`double_col`)'
# assert result == df.head(100).double_col.nunique()


@pytest.mark.parametrize('op', [
methodcaller('sum'),
methodcaller('mean'),
methodcaller('min'),
methodcaller('max'),
methodcaller('std'),
methodcaller('var')
])
def test_boolean_reduction(alltypes, op, df):
result = op(alltypes.bool_col).execute()
assert result == op(df.bool_col)


def test_anonymous_aggregate(alltypes, df, translate):
t = alltypes
expr = t[t.double_col > t.double_col.mean()]
result = expr.execute().set_index('id')
expected = df[df.double_col > df.double_col.mean()].set_index('id')
tm.assert_frame_equal(result, expected, check_like=True)


# def test_rank(con):
# t = con.table('functional_alltypes')
# expr = t.double_col.rank()
# sqla_expr = expr.compile()
# result = str(sqla_expr.compile(compile_kwargs=dict(literal_binds=True)))
# expected = """\
# assert result == expected


# def test_percent_rank(con):
# t = con.table('functional_alltypes')
# expr = t.double_col.percent_rank()
# sqla_expr = expr.compile()
# result = str(sqla_expr.compile(compile_kwargs=dict(literal_binds=True)))
# expected = """\
# assert result == expected


# def test_ntile(con):
# t = con.table('functional_alltypes')
# expr = t.double_col.ntile(7)
# sqla_expr = expr.compile()
# result = str(sqla_expr.compile(compile_kwargs=dict(literal_binds=True)))
# expected = """\
# assert result == expected


def test_boolean_summary(alltypes):
expr = alltypes.bool_col.summary()
result = expr.execute()
expected = pd.DataFrame(
[[7300, 0, 0, 1, 3650, 0.5, 2]],
columns=[
'count',
'nulls',
'min',
'max',
'sum',
'mean',
'approx_nunique',
]
)
tm.assert_frame_equal(result, expected, check_column_type=False,
check_dtype=False)
172 changes: 172 additions & 0 deletions ibis/clickhouse/tests/test_client.py
@@ -0,0 +1,172 @@
import pytest
import pandas as pd

import ibis
import ibis.config as config
import ibis.expr.types as ir

from ibis import literal as L
from ibis.compat import StringIO


pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse


def test_get_table_ref(db):
table = db.functional_alltypes
assert isinstance(table, ir.TableExpr)

table = db['functional_alltypes']
assert isinstance(table, ir.TableExpr)


def test_run_sql(con, db):
query = 'SELECT * FROM {0}.`functional_alltypes`'.format(db.name)
table = con.sql(query)

fa = con.table('functional_alltypes')
assert isinstance(table, ir.TableExpr)
assert table.schema() == fa.schema()

expr = table.limit(10)
result = expr.execute()
assert len(result) == 10


def test_get_schema(con, db):
t = con.table('functional_alltypes')
schema = con.get_schema('functional_alltypes', database=db.name)
assert t.schema() == schema


def test_result_as_dataframe(con, alltypes):
expr = alltypes.limit(10)

ex_names = expr.schema().names
result = con.execute(expr)

assert isinstance(result, pd.DataFrame)
assert result.columns.tolist() == ex_names
assert len(result) == 10


def test_array_default_limit(con, alltypes):
result = con.execute(alltypes.float_col, limit=100)
assert len(result) == 100


def test_limit_overrides_expr(con, alltypes):
result = con.execute(alltypes.limit(10), limit=5)
assert len(result) == 5


def test_limit_equals_none_no_limit(alltypes):
with config.option_context('sql.default_limit', 10):
result = alltypes.execute(limit=None)
assert len(result) > 10


def test_verbose_log_queries(con, db):
queries = []

def logger(x):
queries.append(x)

with config.option_context('verbose', True):
with config.option_context('verbose_log', logger):
con.table('functional_alltypes', database=db.name)

expected = 'DESC {0}.`functional_alltypes`'.format(db.name)

assert len(queries) == 1
assert queries[0] == expected


def test_sql_query_limits(alltypes):
table = alltypes
with config.option_context('sql.default_limit', 100000):
# table has 7300 rows
assert len(table.execute()) == 7300
# comply with limit arg for TableExpr
assert len(table.execute(limit=10)) == 10
# state hasn't changed
assert len(table.execute()) == 7300
# non-TableExpr ignores default_limit
assert table.count().execute() == 7300
# non-TableExpr doesn't observe limit arg
assert table.count().execute(limit=10) == 7300
with config.option_context('sql.default_limit', 20):
# TableExpr observes default limit setting
assert len(table.execute()) == 20
# explicit limit= overrides default
assert len(table.execute(limit=15)) == 15
assert len(table.execute(limit=23)) == 23
# non-TableExpr ignores default_limit
assert table.count().execute() == 7300
# non-TableExpr doesn't observe limit arg
assert table.count().execute(limit=10) == 7300
# eliminating default_limit doesn't break anything
with config.option_context('sql.default_limit', None):
assert len(table.execute()) == 7300
assert len(table.execute(limit=15)) == 15
assert len(table.execute(limit=10000)) == 7300
assert table.count().execute() == 7300
assert table.count().execute(limit=10) == 7300


def test_expr_compile_verify(alltypes):
expr = alltypes.double_col.sum()

assert isinstance(expr.compile(), str)
assert expr.verify()


def test_api_compile_verify(alltypes):
t = alltypes.timestamp_col

supported = t.year()
unsupported = t.rank()

assert ibis.clickhouse.verify(supported)
assert not ibis.clickhouse.verify(unsupported)


def test_database_repr(db):
assert db.name in repr(db)


def test_database_default_current_database(con):
db = con.database()
assert db.name == con.current_database


def test_embedded_identifier_quoting(alltypes):
t = alltypes

expr = (t[[(t.double_col * 2).name('double(fun)')]]
['double(fun)'].sum())
expr.execute()


def test_table_info(alltypes):
buf = StringIO()
alltypes.info(buf=buf)

assert buf.getvalue() is not None


def test_execute_exprs_no_table_ref(con):
cases = [
(L(1) + L(2), 3)
]

for expr, expected in cases:
result = con.execute(expr)
assert result == expected

# ExprList
exlist = ibis.api.expr_list([L(1).name('a'),
ibis.now().name('b'),
L(2).log().name('c')])
con.execute(exlist)
664 changes: 664 additions & 0 deletions ibis/clickhouse/tests/test_functions.py

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions ibis/clickhouse/tests/test_identifiers.py
@@ -0,0 +1,28 @@
import ibis
import pytest


pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse


def test_column_ref_quoting(translate):
schema = [('has a space', 'double')]
table = ibis.table(schema)
assert translate(table['has a space']) == '`has a space`'


def test_identifier_quoting(translate):
schema = [('date', 'double'), ('table', 'string')]
table = ibis.table(schema)
assert translate(table['date']) == '`date`'
assert translate(table['table']) == '`table`'


# TODO: fix it
# def test_named_expression(alltypes, translate):
# a, b = alltypes.get_columns(['int_col', 'float_col'])
# expr = ((a - b) * a).name('expr')

# expected = '(`int_col` - `float_col`) * `int_col` AS `expr`'
# assert translate(expr) == expected
53 changes: 53 additions & 0 deletions ibis/clickhouse/tests/test_literals.py
@@ -0,0 +1,53 @@
import pytest
from pandas import Timestamp

import ibis
from ibis import literal as L


pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse


@pytest.mark.parametrize('expr', [
L(Timestamp('2015-01-01 12:34:56')),
L(Timestamp('2015-01-01 12:34:56').to_pydatetime()),
ibis.timestamp('2015-01-01 12:34:56')
])
def test_timestamp_literals(con, translate, expr):
expected = "toDateTime('2015-01-01 12:34:56')"

assert translate(expr) == expected
assert con.execute(expr) == Timestamp('2015-01-01 12:34:56')


@pytest.mark.parametrize(('value', 'expected'), [
('simple', "'simple'"),
('I can\'t', "'I can\\'t'"),
('An "escape"', "'An \"escape\"'")
])
def test_string_literals(con, translate, value, expected):
expr = ibis.literal(value)
assert translate(expr) == expected
# TODO clickhouse-driver escaping problem
# assert con.execute(expr) == expected


@pytest.mark.parametrize(('value', 'expected'), [
(5, '5'),
(1.5, '1.5'),
])
def test_number_literals(con, translate, value, expected):
expr = ibis.literal(value)
assert translate(expr) == expected
assert con.execute(expr) == value


@pytest.mark.parametrize(('value', 'expected'), [
(True, '1'),
(False, '0'),
])
def test_boolean_literals(con, translate, value, expected):
expr = ibis.literal(value)
assert translate(expr) == expected
assert con.execute(expr) == value
275 changes: 275 additions & 0 deletions ibis/clickhouse/tests/test_operators.py
@@ -0,0 +1,275 @@
import pytest
import operator
import numpy as np
import pandas as pd
import pandas.util.testing as tm
from datetime import date, datetime

import ibis
import ibis.expr.datatypes as dt
from ibis import literal as L


pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse


# def test_not(alltypes):
# t = alltypes.limit(10)
# expr = t.projection([(~t.double_col.isnull()).name('double_col')])
# result = expr.execute().double_col
# expected = ~t.execute().double_col.isnull()
# tm.assert_series_equal(result, expected)


# @pytest.mark.parametrize('op', [operator.invert, operator.neg])
# def test_not_and_negate_bool(con, op, df):
# t = con.table('functional_alltypes').limit(10)
# expr = t.projection([op(t.bool_col).name('bool_col')])
# result = expr.execute().bool_col
# expected = op(df.head(10).bool_col)
# tm.assert_series_equal(result, expected)


@pytest.mark.parametrize(('left', 'right', 'type'), [
(L('2017-04-01'), date(2017, 4, 2), dt.date),
(date(2017, 4, 2), L('2017-04-01'), dt.date),
(L('2017-04-01 01:02:33'), datetime(2017, 4, 1, 1, 3, 34), dt.timestamp),
(datetime(2017, 4, 1, 1, 3, 34), L('2017-04-01 01:02:33'), dt.timestamp)
])
@pytest.mark.parametrize('op', [
operator.eq,
operator.ne,
operator.lt,
operator.le,
operator.gt,
operator.ge,
])
def test_string_temporal_compare(con, op, left, right, type):
expr = op(left, right)
result = con.execute(expr)
left_raw = con.execute(L(left).cast(type))
right_raw = con.execute(L(right).cast(type))
expected = op(left_raw, right_raw)
assert result == expected


@pytest.mark.parametrize(('func', 'left', 'right', 'expected'), [
(operator.add, L(3), L(4), 7),
(operator.sub, L(3), L(4), -1),
(operator.mul, L(3), L(4), 12),
(operator.truediv, L(12), L(4), 3),
(operator.pow, L(12), L(2), 144),
(operator.mod, L(12), L(5), 2),
(operator.truediv, L(7), L(2), 3.5),
(operator.floordiv, L(7), L(2), 3),
(lambda x, y: x.floordiv(y), L(7), 2, 3),
(lambda x, y: x.rfloordiv(y), L(2), 7, 3)
])
def test_binary_arithmetic(con, func, left, right, expected):
expr = func(left, right)
result = con.execute(expr)
assert result == expected


@pytest.mark.parametrize(('op', 'expected'), [
(lambda a, b: a + b, '`int_col` + `tinyint_col`'),
(lambda a, b: a - b, '`int_col` - `tinyint_col`'),
(lambda a, b: a * b, '`int_col` * `tinyint_col`'),
(lambda a, b: a / b, '`int_col` / `tinyint_col`'),
(lambda a, b: a ** b, 'pow(`int_col`, `tinyint_col`)'),
(lambda a, b: a < b, '`int_col` < `tinyint_col`'),
(lambda a, b: a <= b, '`int_col` <= `tinyint_col`'),
(lambda a, b: a > b, '`int_col` > `tinyint_col`'),
(lambda a, b: a >= b, '`int_col` >= `tinyint_col`'),
(lambda a, b: a == b, '`int_col` = `tinyint_col`'),
(lambda a, b: a != b, '`int_col` != `tinyint_col`')
])
def test_binary_infix_operators(con, alltypes, translate, op, expected):
a, b = alltypes.int_col, alltypes.tinyint_col
expr = op(a, b)
assert translate(expr) == expected
assert len(con.execute(expr))


# TODO: test boolean operators
# (h & bool_col, '`h` AND (`a` > 0)'),
# (h | bool_col, '`h` OR (`a` > 0)'),
# (h ^ bool_col, 'xor(`h`, (`a` > 0))')


@pytest.mark.parametrize(('op', 'expected'), [
(lambda a, b, c: (a + b) + c,
'(`int_col` + `tinyint_col`) + `double_col`'),
(lambda a, b, c: a.log() + c,
'log(`int_col`) + `double_col`'),
(lambda a, b, c: (b + (-(a + c))),
'`tinyint_col` + (-(`int_col` + `double_col`))')
])
def test_binary_infix_parenthesization(con, alltypes, translate, op, expected):
a = alltypes.int_col
b = alltypes.tinyint_col
c = alltypes.double_col

expr = op(a, b, c)
assert translate(expr) == expected
assert len(con.execute(expr))


def test_between(con, alltypes, translate):
expr = alltypes.int_col.between(0, 10)
assert translate(expr) == '`int_col` BETWEEN 0 AND 10'
assert len(con.execute(expr))


@pytest.mark.parametrize(('left', 'right'), [
(L('2017-03-31').cast(dt.date), date(2017, 4, 2)),
(date(2017, 3, 31), L('2017-04-02').cast(dt.date))
])
def test_string_temporal_compare_between_dates(con, left, right):
expr = ibis.timestamp('2017-04-01').cast(dt.date).between(left, right)
result = con.execute(expr)
assert result


@pytest.mark.parametrize(('left', 'right'), [
(
L('2017-03-31 00:02:33').cast(dt.timestamp),
datetime(2017, 4, 1, 1, 3, 34),
),
(
datetime(2017, 3, 31, 0, 2, 33),
L('2017-04-01 01:03:34').cast(dt.timestamp),
)
])
def test_string_temporal_compare_between_datetimes(con, left, right):
expr = ibis.timestamp('2017-04-01 00:02:34').between(left, right)
result = con.execute(expr)
assert result


def test_field_in_literals(con, alltypes, translate):
expr = alltypes.string_col.isin(['foo', 'bar', 'baz'])
assert translate(expr) == "`string_col` IN ('foo', 'bar', 'baz')"
assert len(con.execute(expr))

expr = alltypes.string_col.notin(['foo', 'bar', 'baz'])
assert translate(expr) == "`string_col` NOT IN ('foo', 'bar', 'baz')"
assert len(con.execute(expr))


@pytest.mark.parametrize('column', [
'int_col',
'float_col',
'bool_col'
])
def test_negate(con, alltypes, translate, column):
# ClickHouse represents booleans as UInt8
expr = -getattr(alltypes, column)
assert translate(expr) == '-`{0}`'.format(column)
assert len(con.execute(expr))


@pytest.mark.parametrize('field', [
'tinyint_col',
'smallint_col',
'int_col',
'bigint_col',
'float_col',
'double_col',
'year',
'month',
])
def test_negate_non_boolean(con, alltypes, field, df):
t = alltypes.limit(10)
expr = t.projection([(-t[field]).name(field)])
result = expr.execute()[field]
expected = -df.head(10)[field]
tm.assert_series_equal(result, expected)


# def test_negate_boolean(con, alltypes, df):
# t = alltypes.limit(10)
# expr = t.projection([(~t.bool_col).name('bool_col')])
# result = expr.execute().bool_col
# print(result)
# expected = ~df.head(10).bool_col
# tm.assert_series_equal(result, expected)


def test_negate_literal(con):
expr = -L(5.245)
assert round(con.execute(expr), 3) == -5.245


@pytest.mark.parametrize(('op', 'pandas_op'), [
(
lambda t: (t.double_col > 20).ifelse(10, -20),
lambda df: pd.Series(np.where(df.double_col > 20, 10, -20),
dtype='int16')
),
(
lambda t: (t.double_col > 20).ifelse(10, -20).abs(),
lambda df: (pd.Series(np.where(df.double_col > 20, 10, -20))
.abs()
.astype('uint16'))
),
])
def test_ifelse(alltypes, df, op, pandas_op, translate):
expr = op(alltypes)
result = expr.execute()
result.name = None
expected = pandas_op(df)

tm.assert_series_equal(result, expected)


def test_simple_case(con, alltypes, translate):
t = alltypes
expr = (t.string_col.case()
.when('foo', 'bar')
.when('baz', 'qux')
.else_('default')
.end())

expected = """CASE `string_col`
WHEN 'foo' THEN 'bar'
WHEN 'baz' THEN 'qux'
ELSE 'default'
END"""
assert translate(expr) == expected
assert len(con.execute(expr))


def test_search_case(con, alltypes, translate):
t = alltypes
expr = (ibis.case()
.when(t.float_col > 0, t.int_col * 2)
.when(t.float_col < 0, t.int_col)
.else_(0)
.end())

expected = """CASE
WHEN `float_col` > 0 THEN `int_col` * 2
WHEN `float_col` < 0 THEN `int_col`
ELSE 0
END"""
assert translate(expr) == expected
assert len(con.execute(expr))


# TODO: Clickhouse raises incompatible type error
# def test_bucket_to_case(con, alltypes, translate):
# buckets = [0, 10, 25, 50]

# expr1 = alltypes.float_col.bucket(buckets)
# expected1 = """\
# CASE
# WHEN (`float_col` >= 0) AND (`float_col` < 10) THEN 0
# WHEN (`float_col` >= 10) AND (`float_col` < 25) THEN 1
# WHEN (`float_col` >= 25) AND (`float_col` <= 50) THEN 2
# ELSE Null
# END"""

# assert translate(expr1) == expected1
# assert len(con.execute(expr1))
471 changes: 471 additions & 0 deletions ibis/clickhouse/tests/test_select.py

Large diffs are not rendered by default.

17 changes: 17 additions & 0 deletions ibis/clickhouse/tests/test_types.py
@@ -0,0 +1,17 @@
import pytest
import pandas as pd


pytest.importorskip('clickhouse_driver')
pytestmark = pytest.mark.clickhouse


def test_column_types(alltypes):
df = alltypes.execute()
assert df.tinyint_col.dtype.name == 'int8'
assert df.smallint_col.dtype.name == 'int16'
assert df.int_col.dtype.name == 'int32'
assert df.bigint_col.dtype.name == 'int64'
assert df.float_col.dtype.name == 'float32'
assert df.double_col.dtype.name == 'float64'
assert pd.core.common.is_datetime64_dtype(df.timestamp_col.dtype)
81 changes: 81 additions & 0 deletions ibis/clickhouse/types.py
@@ -0,0 +1,81 @@
pandas_to_clickhouse = {
'object': 'String',
'uint64': 'UInt64',
'uint32': 'UInt32',
'uint16': 'UInt16',
'float64': 'Float64',
'float32': 'Float32',
'uint8': 'UInt8',
'int64': 'Int64',
'int32': 'Int32',
'int16': 'Int16',
'int8': 'Int8',
'bool': 'UInt8',
'datetime64[D]': 'Date',
'datetime64[ns]': 'DateTime'
}

clickhouse_to_pandas = {
'UInt8': 'uint8',
'UInt16': 'uint16',
'UInt32': 'uint32',
'UInt64': 'uint64',
'Int8': 'int8',
'Int16': 'int16',
'Int32': 'int32',
'Int64': 'int64',
'Float64': 'float64',
'Float32': 'float32',
'String': 'object',
'FixedString': 'object', # TODO
'Null': 'object',
'Date': 'datetime64[ns]',
'DateTime': 'datetime64[ns]',
'Nullable(UInt8)': 'float32',
'Nullable(UInt16)': 'float32',
'Nullable(UInt32)': 'float32',
'Nullable(UInt64)': 'float64',
'Nullable(Int8)': 'float32',
'Nullable(Int16)': 'float32',
'Nullable(Int32)': 'float32',
'Nullable(Int64)': 'float64',
'Nullable(Float32)': 'float32',
'Nullable(Float64)': 'float64',
'Nullable(String)': 'object',
'Nullable(FixedString)': 'object', # TODO
'Nullable(Date)': 'Date',
'Nullable(DateTime)': 'DateTime'
}

ibis_to_clickhouse = {
'null': 'Null',
'int8': 'Int8',
'int16': 'Int16',
'int32': 'Int32',
'int64': 'Int64',
'float': 'Float32',
'double': 'Float64',
'string': 'String',
'boolean': 'UInt8',
'date': 'Date',
'timestamp': 'DateTime',
'decimal': 'UInt64' # see yandex/clickhouse#253
}

clickhouse_to_ibis = {
'Null': 'null',
'UInt64': 'int64',
'UInt32': 'int32',
'UInt16': 'int16',
'UInt8': 'int8',
'Int64': 'int64',
'Int32': 'int32',
'Int16': 'int16',
'Int8': 'int8',
'Float64': 'double',
'Float32': 'float',
'String': 'string',
'FixedString': 'string',
'Date': 'date',
'DateTime': 'timestamp'
}
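
These tables match on the exact type string, and clickhouse_to_ibis
(unlike clickhouse_to_pandas) carries no Nullable(...) entries, so a
caller would need to unwrap parametrized types first. A hypothetical
helper, not part of this diff:

    import re

    def clickhouse_type_to_ibis(name):
        # Unwrap Nullable(...) before consulting the exact-match table.
        match = re.match(r'^Nullable\((.+)\)$', name)
        if match is not None:
            name = match.group(1)
        return clickhouse_to_ibis[name]

    assert clickhouse_type_to_ibis('Nullable(Int64)') == 'int64'
    assert clickhouse_type_to_ibis('String') == 'string'
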
27 changes: 11 additions & 16 deletions ibis/client.py
@@ -174,12 +174,7 @@ def sql(self, query):
"""
# Get the schema by adding a LIMIT 0 on to the end of the query. If
# there is already a limit in the query, we find and remove it
limited_query = """\
SELECT *
FROM (
{0}
) t0
LIMIT 0""".format(query)
limited_query = 'SELECT * FROM ({}) t0 LIMIT 0'.format(query)
schema = self._get_schema_using_query(limited_query)

node = ops.SQLQueryResult(query, schema, self)
@@ -226,7 +221,7 @@ def execute(self, expr, params=None, limit='default', async=False):
Array expressions: pandas.Series
Scalar expressions: Python scalar value
"""
ast = self._build_ast_ensure_limit(expr, limit)
ast = self._build_ast_ensure_limit(expr, limit, params=params)

if len(ast.queries) > 1:
raise NotImplementedError
@@ -245,12 +240,12 @@ def compile(self, expr, params=None, limit=None):
-------
output : single query or list of queries
"""
ast = self._build_ast_ensure_limit(expr, limit)
ast = self._build_ast_ensure_limit(expr, limit, params=params)
queries = [query.compile() for query in ast.queries]
return queries[0] if len(queries) == 1 else queries

def _build_ast_ensure_limit(self, expr, limit):
ast = self._build_ast(expr)
def _build_ast_ensure_limit(self, expr, limit, params=None):
ast = self._build_ast(expr, params=params)
# note: limit can still be None at this point, if the global
# default_limit is None
for query in reversed(ast.queries):
@@ -298,7 +293,7 @@ def explain(self, expr):
return 'Query:\n{0}\n\n{1}'.format(util.indent(query, 2),
'\n'.join(result))

def _build_ast(self, expr):
def _build_ast(self, expr, params=None):
# Implement in clients
raise NotImplementedError

@@ -313,14 +308,14 @@ class QueryPipeline(object):
pass


def execute(expr, limit='default', async=False):
def execute(expr, limit='default', async=False, params=None):
backend = find_backend(expr)
return backend.execute(expr, limit=limit, async=async)
return backend.execute(expr, limit=limit, async=async, params=params)


def compile(expr, limit=None):
def compile(expr, limit=None, params=None):
backend = find_backend(expr)
return backend.compile(expr, limit=limit)
return backend.compile(expr, limit=limit, params=params)


def find_backend(expr):
@@ -368,7 +363,7 @@ def __repr__(self):
def __dir__(self):
attrs = dir(type(self))
unqualified_tables = [self._unqualify(x) for x in self.tables]
return list(sorted(set(attrs + unqualified_tables)))
return list(frozenset(attrs + unqualified_tables))

def __contains__(self, key):
return key in self.tables
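
With params threaded through _build_ast_ensure_limit, a scalar parameter
can now be bound when a query is compiled or executed. A hedged sketch,
assuming `con` is any connected backend client and `t` one of its tables
with an integer column `x`:

    import ibis

    p = ibis.param('int64')  # a scalar parameter expression

    # The mapping is forwarded down to the backend's _build_ast:
    #   con.compile(t[t.x > p], params={p: 10})
    #   con.execute(t[t.x > p], params={p: 10})
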
14 changes: 14 additions & 0 deletions ibis/compat.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import itertools

import numpy as np
@@ -65,3 +67,15 @@ def dict_values(x):
range = xrange # noqa: F821

integer_types = six.integer_types + (np.integer,)


# pandas compat
try:
from pandas.api.types import DatetimeTZDtype # noqa: F401
except ImportError:
from pandas.types.dtypes import DatetimeTZDtype # noqa: F401

try:
from pandas.core.tools.datetimes import to_time # noqa: F401
except ImportError:
from pandas.tseries.tools import to_time # noqa: F401
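
Downstream modules can now pull these pandas symbols from a single place
regardless of the installed pandas version:

    from ibis.compat import DatetimeTZDtype, to_time
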
2 changes: 1 addition & 1 deletion ibis/config.py
@@ -185,7 +185,7 @@ def __getattr__(self, key):
return _get_option(prefix)

def __dir__(self):
return sorted(list(self.d.keys()))
return list(self.d.keys())


# For user convenience, we'd like to have the available options described
8 changes: 8 additions & 0 deletions ibis/config_init.py
@@ -42,3 +42,11 @@
cf.register_option('temp_db', '__ibis_tmp', impala_temp_db_doc)
cf.register_option('temp_hdfs_path', '/tmp/ibis',
impala_temp_hdfs_path_doc)


clickhouse_temp_db_doc = """
Database to use for temporary tables, views, functions, etc.
"""

with cf.config_prefix('clickhouse'):
cf.register_option('temp_db', '__ibis_tmp', clickhouse_temp_db_doc)
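
Once registered, the option is reachable through the pandas-style config
accessors ibis exposes; a small usage sketch:

    import ibis

    # Both resolve to the default registered above ('__ibis_tmp'):
    ibis.config.get_option('clickhouse.temp_db')
    ibis.options.clickhouse.temp_db
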
88 changes: 57 additions & 31 deletions ibis/expr/analysis.py
@@ -29,42 +29,68 @@

def sub_for(expr, substitutions):
mapping = {repr(k.op()): v for k, v in substitutions}
return _subs(expr, mapping)
substitutor = Substitutor()
return substitutor.substitute(expr, mapping)


def _expr_key(expr):
try:
return repr(expr.op())
name = expr.get_name()
except (AttributeError, ExpressionError):
name = None

try:
op = expr.op()
except AttributeError:
return expr
return expr, name
else:
return repr(op), name


@toolz.memoize(key=lambda args, kwargs: _expr_key(args[0]))
def _subs(expr, mapping):
"""Substitute expressions with other expressions
"""
node = expr.op()
key = repr(node)
if key in mapping:
return mapping[key]
if node.blocks():
return expr

new_args = list(node.args)
unchanged = True
for i, arg in enumerate(new_args):
if isinstance(arg, ir.Expr):
new_arg = _subs(arg, mapping)
unchanged = unchanged and new_arg is arg
new_args[i] = new_arg
if unchanged:
return expr
try:
new_node = type(node)(*new_args)
except IbisTypeError:
return expr
class Substitutor(object):

def __init__(self):
cache = toolz.memoize(key=lambda args, kwargs: _expr_key(args[0]))
self.substitute = cache(self._substitute)

def _substitute(self, expr, mapping):
"""Substitute expressions with other expressions.
Parameters
----------
expr : ibis.expr.types.Expr
mapping : Dict, OrderedDict
return expr._factory(new_node, name=getattr(expr, '_name', None))
Returns
-------
new_expr : ibis.expr.types.Expr
"""
node = expr.op()
key = repr(node)
if key in mapping:
return mapping[key]
if node.blocks():
return expr

new_args = list(node.args)
unchanged = True
for i, arg in enumerate(new_args):
if isinstance(arg, ir.Expr):
new_arg = self.substitute(arg, mapping)
unchanged = unchanged and new_arg is arg
new_args[i] = new_arg
if unchanged:
return expr
try:
new_node = type(node)(*new_args)
except IbisTypeError:
return expr

try:
name = expr.get_name()
except ExpressionError:
name = None
return expr._factory(new_node, name=name)


class ScalarAggregate(object):
@@ -83,7 +109,7 @@ def get_result(self):
try:
name = subbed_expr.get_name()
named_expr = subbed_expr
except:
except ExpressionError:
name = self.default_name
named_expr = subbed_expr.name(self.default_name)

@@ -139,7 +165,7 @@ def reduction_to_aggregation(expr, default_name='tmp'):
try:
name = expr.get_name()
named_expr = expr
except:
except ExpressionError:
name = default_name
named_expr = expr.name(default_name)

@@ -794,7 +820,7 @@ def _check_fusion(self, root):
def _maybe_resolve_exprs(table, exprs):
try:
return table._resolve(exprs)
except:
except AttributeError:
return None


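Scoping the memoized _substitute to a Substitutor instance, rather than a
module-level function, confines the cache to a single substitution pass.
The cache key pairs the op's repr with the expression's name, so
structurally identical but differently named expressions memoize
separately. A sketch of the key's shape (t is any table with column a):

    # _expr_key returns (repr(expr.op()), name), or (expr, name) when the
    # object has no .op(); name is None for unnamed expressions:
    #
    #   _expr_key(t.a.sum())            # (repr(<sum op>), None)
    #   _expr_key(t.a.sum().name('s'))  # (repr(<sum op>), 's')
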
53 changes: 28 additions & 25 deletions ibis/expr/analytics.py
@@ -19,13 +19,14 @@
import ibis.expr.operations as ops


class BucketLike(ir.ValueOp):
def _validate_closed(closed):
closed = closed.lower()
if closed not in {'left', 'right'}:
raise ValueError("closed must be 'left' or 'right'")
return closed


def _validate_closed(self, closed):
closed = closed.lower()
if closed not in ['left', 'right']:
raise ValueError("closed must be 'left' or 'right'")
return closed
class BucketLike(ir.ValueOp):

@property
def nbuckets(self):
@@ -42,36 +43,37 @@ def __init__(self, arg, buckets, closed='left', close_extreme=True,
include_under=False, include_over=False):
self.arg = arg
self.buckets = buckets
self.closed = self._validate_closed(closed)
self.closed = _validate_closed(closed)

self.close_extreme = bool(close_extreme)
self.include_over = bool(include_over)
self.include_under = bool(include_under)

if len(buckets) == 0:
if not len(buckets):
raise ValueError('Must be at least one bucket edge')
elif len(buckets) == 1:
if not self.include_under or not self.include_over:
raise ValueError('If one bucket edge provided, must have'
' include_under=True and include_over=True')
raise ValueError(
'If one bucket edge provided, must have '
'include_under=True and include_over=True'
)

ir.ValueOp.__init__(self, self.arg, self.buckets, self.closed,
self.close_extreme, self.include_under,
self.include_over)
super(Bucket, self).__init__(
arg, buckets, self.closed,
self.close_extreme, self.include_under, self.include_over
)

@property
def nbuckets(self):
k = len(self.buckets) - 1
k += int(self.include_over) + int(self.include_under)
return k
return len(self.buckets) - 1 + self.include_over + self.include_under


class Histogram(BucketLike):

def __init__(self, arg, nbins, binwidth, base, closed='left',
aux_hash=None):
def __init__(
self, arg, nbins, binwidth, base, closed='left', aux_hash=None
):
self.arg = arg

self.nbins = nbins
self.binwidth = binwidth
self.base = base
@@ -82,11 +84,12 @@ def __init__(self, arg, nbins, binwidth, base, closed='left',
elif self.binwidth is not None:
raise ValueError('nbins and binwidth are mutually exclusive')

self.closed = self._validate_closed(closed)

self.closed = _validate_closed(closed)
self.aux_hash = aux_hash
ir.ValueOp.__init__(self, self.arg, self.nbins, self.binwidth,
self.base, self.closed, self.aux_hash)

super(Histogram, self).__init__(
arg, nbins, binwidth, base, self.closed, aux_hash
)

def output_type(self):
# always undefined cardinality (for now)
@@ -101,12 +104,12 @@ def __init__(self, arg, labels, nulls):
self.labels = labels

card = self.arg.type().cardinality
if len(self.labels) != card:
if len(labels) != card:
raise ValueError('Number of labels must match number of '
'categories: %d' % card)

self.nulls = nulls
ir.ValueOp.__init__(self, self.arg, self.labels, self.nulls)
super(CategoryLabel, self).__init__(self.arg, labels, nulls)

def output_type(self):
return rules.shape_like(self.arg, 'string')
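
The nbuckets property now reads as a single expression: one bin per
adjacent pair of edges, plus an open-ended bin for each of
include_under/include_over. A hedged sketch using the public bucket API,
where d is any numeric column:

    # len(buckets) - 1 bins from the edges, plus the open-ended extras:
    #   d.bucket([0, 10, 25, 50]).op().nbuckets                     # 3
    #   d.bucket([0, 10, 25, 50], include_over=True).op().nbuckets  # 4
    #   d.bucket([0, 10, 25, 50], include_under=True,
    #            include_over=True).op().nbuckets                   # 5
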
429 changes: 348 additions & 81 deletions ibis/expr/api.py

Large diffs are not rendered by default.

383 changes: 249 additions & 134 deletions ibis/expr/datatypes.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions ibis/expr/format.py
@@ -113,6 +113,8 @@ def get_result(self):
text = 'Literal[{}]\n {}'.format(
self._get_type_display(), str(what.value)
)
elif isinstance(what, ir.ScalarParameter):
text = 'ScalarParameter[{}]'.format(self._get_type_display())
elif isinstance(what, ir.Node):
text = self._format_node(self.expr)

422 changes: 276 additions & 146 deletions ibis/expr/operations.py

Large diffs are not rendered by default.

267 changes: 189 additions & 78 deletions ibis/expr/rules.py

Large diffs are not rendered by default.

17 changes: 16 additions & 1 deletion ibis/expr/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -43,7 +43,7 @@ def schema_dict(schema):

@pytest.fixture
def table(schema):
return ibis.table(schema, name='schema')
return ibis.table(schema, name='table')


@pytest.fixture(params=list('abcdh'))
@@ -74,3 +74,18 @@ def col(request):
@pytest.fixture
def con():
return MockConnection()


@pytest.fixture
def alltypes(con):
return con.table('alltypes')


@pytest.fixture
def functional_alltypes(con):
return con.table('functional_alltypes')


@pytest.fixture
def lineitem(con):
return con.table('tpch_lineitem')
12 changes: 6 additions & 6 deletions ibis/expr/tests/mocks.py
@@ -345,19 +345,19 @@ def _get_table_schema(self, name):
name = name.replace('`', '')
return Schema.from_tuples(self._tables[name])

def _build_ast(self, expr):
def _build_ast(self, expr, params=None):
from ibis.impala.compiler import build_ast
return build_ast(expr)
return build_ast(expr, params=params)

def execute(self, expr, limit=None, async=False):
def execute(self, expr, limit=None, async=False, params=None):
if async:
raise NotImplementedError
ast = self._build_ast_ensure_limit(expr, limit)
ast = self._build_ast_ensure_limit(expr, limit, params=params)
for query in ast.queries:
self.executed_queries.append(query.compile())
return None

def compile(self, expr, limit=None):
ast = self._build_ast_ensure_limit(expr, limit)
def compile(self, expr, limit=None, params=None):
ast = self._build_ast_ensure_limit(expr, limit, params=params)
queries = [q.compile() for q in ast.queries]
return queries[0] if len(queries) == 1 else queries
11 changes: 10 additions & 1 deletion ibis/expr/tests/test_datatypes.py
@@ -320,9 +320,18 @@ def test_timestamp_with_invalid_timezone():

def test_timestamp_with_timezone_repr():
ts = dt.Timestamp('UTC')
assert repr(ts) == "Timestamp(timezone='UTC')"
assert repr(ts) == "Timestamp(timezone='UTC', nullable=True)"


def test_timestamp_with_timezone_str():
ts = dt.Timestamp('UTC')
assert str(ts) == "timestamp('UTC')"


def test_time():
ts = dt.time
assert str(ts) == "time"


def test_time_valid():
assert dt.validate_type('time').equals(dt.time)
29 changes: 9 additions & 20 deletions ibis/expr/tests/test_decimal.py
@@ -16,36 +16,24 @@
import ibis.expr.types as ir
import ibis.expr.operations as ops

from ibis.expr.tests.mocks import MockConnection

import pytest


@pytest.fixture
def con():
return MockConnection()


@pytest.fixture
def lineitem(con):
return con.table('tpch_lineitem')


def test_type_metadata(lineitem):
col = lineitem.l_extendedprice
assert isinstance(col, ir.DecimalColumn)

assert col._precision == 12
assert col._scale == 2
assert col.meta.precision == 12
assert col.meta.scale == 2


def test_cast_scalar_to_decimal():
val = api.literal('1.2345')

casted = val.cast('decimal(15,5)')
assert isinstance(casted, ir.DecimalScalar)
assert casted._precision == 15
assert casted._scale == 5
assert casted.meta.precision == 15
assert casted.meta.scale == 5


def test_decimal_aggregate_function_behavior(lineitem):
Expand All @@ -60,8 +48,8 @@ def test_decimal_aggregate_function_behavior(lineitem):
for func_name in functions:
result = getattr(col, func_name)()
assert isinstance(result, ir.DecimalScalar)
assert result._precision == col._precision
assert result._scale == 38
assert result.meta.precision == col.meta.precision
assert result.meta.scale == 38


def test_where(lineitem):
Expand Down Expand Up @@ -111,13 +99,14 @@ def test_invalid_precision_scale_combo():
def test_decimal_str(lineitem):
col = lineitem.l_extendedprice
t = col.type()
assert str(t) == 'decimal({0:d}, {1:d})'.format(t.precision, t.scale)
assert str(t) == 'decimal({:d}, {:d})'.format(t.precision, t.scale)


def test_decimal_repr(lineitem):
col = lineitem.l_extendedprice
t = col.type()
assert repr(t) == 'Decimal(precision={0:d}, scale={1:d})'.format(
expected = 'Decimal(precision={:d}, scale={:d}, nullable=True)'.format(
t.precision,
t.scale,
)
assert repr(t) == expected
341 changes: 168 additions & 173 deletions ibis/expr/tests/test_format.py
@@ -12,215 +12,203 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import re

import ibis

from ibis.expr.types import Expr
from ibis.expr.format import ExprFormatter
from ibis.expr.tests.mocks import MockConnection


class TestExprFormatting(unittest.TestCase):
# Uncertain about how much we want to commit to unit tests around the
# particulars of the output at the moment.

def setUp(self):
self.schema = [
('a', 'int8'),
('b', 'int16'),
('c', 'int32'),
('d', 'int64'),
('e', 'float'),
('f', 'double'),
('g', 'string'),
('h', 'boolean')
]
self.schema_dict = dict(self.schema)
self.table = ibis.table(self.schema)
self.con = MockConnection()

def test_format_custom_expr(self):

class CustomExpr(Expr):
def _type_display(self):
return 'my-custom'

op = ibis.literal(5).op()
expr = CustomExpr(op)

result = repr(expr)
expected = """Literal[my-custom]
5"""
assert result == expected

def test_format_table_column(self):
# GH #507
result = repr(self.table.f)
assert 'Column[double*]' in result

def test_format_projection(self):
# This should produce a ref to the projection
proj = self.table[['c', 'a', 'f']]
repr(proj['a'])

def test_table_type_output(self):
foo = ibis.table(
[
('job', 'string'),
('dept_id', 'string'),
('year', 'int32'),
('y', 'double')
], 'foo')

expr = foo.dept_id == foo.view().dept_id
result = repr(expr)
assert 'SelfReference[table]' in result
assert 'UnboundTable[table]' in result

def test_memoize_aggregate_correctly(self):
table = self.table

agg_expr = (table['c'].sum() / table['c'].mean() - 1).name('analysis')
metrics = [
table['a'].sum().name('sum(a)'),
table['b'].mean().name('mean(b)'),
agg_expr,
]

result = table.aggregate(metrics, by=['g'])

formatter = ExprFormatter(result)
formatted = formatter.get_result()

alias = formatter.memo.get_alias(table)
assert formatted.count(alias) == 7

def test_aggregate_arg_names(self):
# Not sure how to test this *well*

t = self.table
def test_format_custom_expr():

by_exprs = [t.g.name('key1'), t.f.round().name('key2')]
metrics = [t.c.sum().name('c'), t.d.mean().name('d')]

expr = self.table.group_by(by_exprs).aggregate(metrics)
result = repr(expr)
assert 'metrics' in result
assert 'by' in result

def test_format_multiple_join_with_projection(self):
# Star schema with fact table
table = ibis.table([
('c', 'int32'),
('f', 'double'),
('foo_id', 'string'),
('bar_id', 'string'),
], 'one')

table2 = ibis.table([
('foo_id', 'string'),
('value1', 'double')
], 'two')
class CustomExpr(Expr):
def _type_display(self):
return 'my-custom'

table3 = ibis.table([
('bar_id', 'string'),
('value2', 'double')
], 'three')
op = ibis.literal(5).op()
expr = CustomExpr(op)

result = repr(expr)
expected = 'Literal[my-custom]\n 5'
assert result == expected


def test_format_table_column(table):
# GH #507
result = repr(table.f)
assert 'Column[double*]' in result

filtered = table[table['f'] > 0]

pred1 = filtered['foo_id'] == table2['foo_id']
pred2 = filtered['bar_id'] == table3['bar_id']
def test_format_projection(table):
# This should produce a ref to the projection
proj = table[['c', 'a', 'f']]
repr(proj['a'])


def test_table_type_output():
foo = ibis.table(
[
('job', 'string'),
('dept_id', 'string'),
('year', 'int32'),
('y', 'double')
], 'foo')

expr = foo.dept_id == foo.view().dept_id
result = repr(expr)
assert 'SelfReference[table]' in result
assert 'UnboundTable[table]' in result

j1 = filtered.left_join(table2, [pred1])
j2 = j1.inner_join(table3, [pred2])

# Project out the desired fields
view = j2[[filtered, table2['value1'], table3['value2']]]
def test_memoize_aggregate_correctly(table):
agg_expr = (table['c'].sum() / table['c'].mean() - 1).name('analysis')
metrics = [
table['a'].sum().name('sum(a)'),
table['b'].mean().name('mean(b)'),
agg_expr,
]

# it works!
repr(view)
result = table.aggregate(metrics, by=['g'])

formatter = ExprFormatter(result)
formatted = formatter.get_result()

alias = formatter.memo.get_alias(table)
assert formatted.count(alias) == 7


def test_aggregate_arg_names(table):
# Not sure how to test this *well*

t = table

by_exprs = [t.g.name('key1'), t.f.round().name('key2')]
metrics = [t.c.sum().name('c'), t.d.mean().name('d')]

expr = t.group_by(by_exprs).aggregate(metrics)
result = repr(expr)
assert 'metrics' in result
assert 'by' in result

def test_memoize_database_table(self):
table = self.con.table('test1')
table2 = self.con.table('test2')

filter_pred = table['f'] > 0
table3 = table[filter_pred]
join_pred = table3['g'] == table2['key']
def test_format_multiple_join_with_projection():
# Star schema with fact table
table = ibis.table([
('c', 'int32'),
('f', 'double'),
('foo_id', 'string'),
('bar_id', 'string'),
], 'one')

joined = table2.inner_join(table3, [join_pred])
table2 = ibis.table([
('foo_id', 'string'),
('value1', 'double')
], 'two')

met1 = (table3['f'] - table2['value']).mean().name('foo')
result = joined.aggregate([met1, table3['f'].sum().name('bar')],
by=[table3['g'], table2['key']])
table3 = ibis.table([
('bar_id', 'string'),
('value2', 'double')
], 'three')

formatted = repr(result)
assert formatted.count('test1') == 1
assert formatted.count('test2') == 1
filtered = table[table['f'] > 0]

def test_memoize_filtered_table(self):
airlines = ibis.table([('dest', 'string'),
('origin', 'string'),
('arrdelay', 'int32')], 'airlines')
pred1 = filtered['foo_id'] == table2['foo_id']
pred2 = filtered['bar_id'] == table3['bar_id']

dests = ['ORD', 'JFK', 'SFO']
t = airlines[airlines.dest.isin(dests)]
delay_filter = t.dest.topk(10, by=t.arrdelay.mean())
j1 = filtered.left_join(table2, [pred1])
j2 = j1.inner_join(table3, [pred2])

result = repr(delay_filter)
assert result.count('Selection') == 1
# Project out the desired fields
view = j2[[filtered, table2['value1'], table3['value2']]]

def test_memoize_insert_sort_key(self):
table = self.con.table('airlines')
# it works!
repr(view)

t = table['arrdelay', 'dest']
expr = (t.group_by('dest')
.mutate(dest_avg=t.arrdelay.mean(),
dev=t.arrdelay - t.arrdelay.mean()))

worst = (expr[expr.dev.notnull()]
.sort_by(ibis.desc('dev'))
.limit(10))
def test_memoize_database_table(con):
table = con.table('test1')
table2 = con.table('test2')

result = repr(worst)
assert result.count('airlines') == 1
filter_pred = table['f'] > 0
table3 = table[filter_pred]
join_pred = table3['g'] == table2['key']

def test_named_value_expr_show_name(self):
expr = self.table.f * 2
expr2 = expr.name('baz')
joined = table2.inner_join(table3, [join_pred])

# it works!
repr(expr)
met1 = (table3['f'] - table2['value']).mean().name('foo')
result = joined.aggregate([met1, table3['f'].sum().name('bar')],
by=[table3['g'], table2['key']])

result2 = repr(expr2)
formatted = repr(result)
assert formatted.count('test1') == 1
assert formatted.count('test2') == 1

# not really committing to a particular output yet
assert 'baz' in result2

def test_memoize_filtered_tables_in_join(self):
# related: GH #667
purchases = ibis.table([('region', 'string'),
('kind', 'string'),
('user', 'int64'),
('amount', 'double')], 'purchases')
def test_memoize_filtered_table():
airlines = ibis.table([('dest', 'string'),
('origin', 'string'),
('arrdelay', 'int32')], 'airlines')

metric = purchases.amount.sum().name('total')
agged = (purchases.group_by(['region', 'kind'])
.aggregate(metric))
dests = ['ORD', 'JFK', 'SFO']
t = airlines[airlines.dest.isin(dests)]
delay_filter = t.dest.topk(10, by=t.arrdelay.mean())

left = agged[agged.kind == 'foo']
right = agged[agged.kind == 'bar']
result = repr(delay_filter)
assert result.count('Selection') == 1

cond = left.region == right.region
joined = (left.join(right, cond)
[left, right.total.name('right_total')])

result = repr(joined)
def test_memoize_insert_sort_key(con):
table = con.table('airlines')

# Join, and one for each aggregation
assert result.count('predicates') == 3
t = table['arrdelay', 'dest']
expr = (t.group_by('dest')
.mutate(dest_avg=t.arrdelay.mean(),
dev=t.arrdelay - t.arrdelay.mean()))

worst = (expr[expr.dev.notnull()]
.sort_by(ibis.desc('dev'))
.limit(10))

result = repr(worst)
assert result.count('airlines') == 1


def test_named_value_expr_show_name(table):
expr = table.f * 2
expr2 = expr.name('baz')

# it works!
repr(expr)

result2 = repr(expr2)

# not really committing to a particular output yet
assert 'baz' in result2


def test_memoize_filtered_tables_in_join():
# related: GH #667
purchases = ibis.table([('region', 'string'),
('kind', 'string'),
('user', 'int64'),
('amount', 'double')], 'purchases')

metric = purchases.amount.sum().name('total')
agged = (purchases.group_by(['region', 'kind'])
.aggregate(metric))

left = agged[agged.kind == 'foo']
right = agged[agged.kind == 'bar']

cond = left.region == right.region
joined = (left.join(right, cond)
[left, right.total.name('right_total')])

result = repr(joined)

# Join, and one for each aggregation
assert result.count('predicates') == 3


def test_argument_repr_shows_name():
@@ -241,3 +229,10 @@ def test_argument_repr_shows_name():
Literal[int8]
2"""
assert result == expected


def test_scalar_parameter_formatting():
value = ibis.param('array<date>')
assert re.match(
r'param\[\d+\] = ScalarParameter\[array<date>\]', str(value)
)
35 changes: 35 additions & 0 deletions ibis/expr/tests/test_rules.py
@@ -4,6 +4,7 @@
from ibis.common import IbisTypeError
import ibis.expr.operations as ops
import ibis.expr.types as ir
import ibis.expr.datatypes as dt
from ibis.expr import rules


@@ -135,3 +136,37 @@ class MyOp(ops.ValueOp):

assert MyOp(1).args[0].equals(ibis.literal(1))
assert MyOp(1.42).args[0].equals(ibis.literal(1.42))


def test_array_rule():

class MyOp(ops.ValueOp):

input_type = [rules.array(dt.double, name='value')]
output_type = rules.type_of_arg(0)

raw_value = [1.0, 2.0, 3.0]
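    # the array rule should coerce the plain Python list into an array literal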
op = MyOp(raw_value)
result = op.value
expected = ibis.literal(raw_value)
assert result.equals(expected)


def test_scalar_default_arg():
class MyOp(ops.ValueOp):

input_type = [
rules.scalar(
value_type=dt.boolean,
optional=True,
default=False,
name='value'
)
]
output_type = rules.type_of_arg(0)

op = MyOp()
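    # with no argument supplied, the declared default should be materialized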
assert op.value.equals(ibis.literal(False))

op = MyOp(True)
assert op.value.equals(ibis.literal(True))
56 changes: 20 additions & 36 deletions ibis/expr/tests/test_sql_builtins.py
@@ -19,22 +19,6 @@
import ibis.expr.types as ir

from ibis.tests.util import assert_equal
from ibis.expr.tests.mocks import MockConnection


@pytest.fixture
def con():
return MockConnection()


@pytest.fixture
def alltypes(con):
return con.table('functional_alltypes')


@pytest.fixture
def lineitem(con):
return con.table('tpch_lineitem')


@pytest.fixture
@@ -67,19 +51,19 @@ def function(request):
'double_col',
]
)
def test_abs(alltypes, lineitem, colname):
def test_abs(functional_alltypes, lineitem, colname):
fname = 'abs'
op = ops.Abs

expr = alltypes[colname]
expr = functional_alltypes[colname]
_check_unary_op(expr, fname, op, type(expr))

expr = lineitem.l_extendedprice
_check_unary_op(expr, fname, op, type(expr))


def test_group_concat(alltypes):
col = alltypes.string_col
def test_group_concat(functional_alltypes):
col = functional_alltypes.string_col

expr = col.group_concat()
assert isinstance(expr.op(), ops.GroupConcat)
@@ -91,9 +75,9 @@ def test_group_concat(alltypes):
    assert sep == '|'


def test_zeroifnull(alltypes):
dresult = alltypes.double_col.zeroifnull()
iresult = alltypes.int_col.zeroifnull()
def test_zeroifnull(functional_alltypes):
dresult = functional_alltypes.double_col.zeroifnull()
iresult = functional_alltypes.int_col.zeroifnull()

assert type(dresult.op()) == ops.ZeroIfNull
assert type(dresult) == ir.DoubleColumn
@@ -102,23 +86,23 @@ def test_zeroifnull(alltypes):
    assert type(iresult.op()) == ops.ZeroIfNull


def test_fillna(alltypes):
result = alltypes.double_col.fillna(5)
def test_fillna(functional_alltypes):
result = functional_alltypes.double_col.fillna(5)
assert isinstance(result, ir.DoubleColumn)

assert isinstance(result.op(), ops.IfNull)

result = alltypes.bool_col.fillna(True)
result = functional_alltypes.bool_col.fillna(True)
assert isinstance(result, ir.BooleanColumn)

# Highest precedence type
result = alltypes.int_col.fillna(alltypes.bigint_col)
result = functional_alltypes.int_col.fillna(functional_alltypes.bigint_col)
assert isinstance(result, ir.Int64Column)


def test_ceil_floor(alltypes, lineitem):
cresult = alltypes.double_col.ceil()
fresult = alltypes.double_col.floor()
def test_ceil_floor(functional_alltypes, lineitem):
cresult = functional_alltypes.double_col.ceil()
fresult = functional_alltypes.double_col.floor()
assert isinstance(cresult, ir.Int64Column)
assert isinstance(fresult, ir.Int64Column)
assert type(cresult.op()) == ops.Ceil
@@ -139,8 +123,8 @@ def test_ceil_floor(alltypes, lineitem):
assert fresult.meta == dec_col.meta


def test_sign(alltypes, lineitem):
result = alltypes.double_col.sign()
def test_sign(functional_alltypes, lineitem):
result = functional_alltypes.double_col.sign()
assert isinstance(result, ir.FloatColumn)
assert type(result.op()) == ops.Sign

@@ -152,18 +136,18 @@ def test_sign(alltypes, lineitem):
assert isinstance(result, ir.FloatColumn)


def test_round(alltypes, lineitem):
result = alltypes.double_col.round()
def test_round(functional_alltypes, lineitem):
result = functional_alltypes.double_col.round()
assert isinstance(result, ir.Int64Column)
assert result.op().args[1] is None

result = alltypes.double_col.round(2)
result = functional_alltypes.double_col.round(2)
assert isinstance(result, ir.DoubleColumn)
assert result.op().args[1].equals(ibis.literal(2))

    # Rounding an integer column yields a double (at least in Impala; check
    # against other database implementations)
result = alltypes.int_col.round(2)
result = functional_alltypes.int_col.round(2)
assert isinstance(result, ir.DoubleColumn)

dec = lineitem.l_extendedprice
30 changes: 30 additions & 0 deletions ibis/expr/tests/test_table.py
@@ -729,6 +729,24 @@ def test_join_no_predicate_list(con):
assert_equal(joined, expected)


def test_asof_join():
left = ibis.table([('time', 'int32'), ('value', 'double')])
right = ibis.table([('time', 'int32'), ('value2', 'double')])
joined = api.asof_join(left, right, 'time')
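    # the asof join should carry a single predicate relating the two 'time' columns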
pred = joined.op().predicates[0].op()
assert pred.left.op().name == pred.right.op().name == 'time'


def test_asof_join_with_by():
left = ibis.table(
[('time', 'int32'), ('key', 'int32'), ('value', 'double')])
right = ibis.table(
[('time', 'int32'), ('key', 'int32'), ('value2', 'double')])
joined = api.asof_join(left, right, 'time', by='key')
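    # 'by' should add an exact-match predicate on 'key' alongside the asof predicate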
by = joined.op().by_predicates[0].op()
assert by.left.op().name == by.right.op().name == 'key'


def test_equijoin_schema_merge():
table1 = ibis.table([('key1', 'string'), ('value1', 'double')])
table2 = ibis.table([('key2', 'string'), ('stuff', 'int32')])
@@ -964,6 +982,18 @@ def test_join_invalid_refs(con):
t1.inner_join(t2, [predicate])


def test_join_invalid_expr_type(con):
left = con.table('star1')
invalid_right = left.foo_id
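    # a bare column expression is not a valid join target, so the join below should raise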
join_key = ['bar_id']

with pytest.raises(TypeError) as e:
left.inner_join(invalid_right, join_key)

message = str(e)
assert type(invalid_right).__name__ in message


def test_join_non_boolean_expr(con):
t1 = con.table('star1')
t2 = con.table('star2')
51 changes: 17 additions & 34 deletions ibis/expr/tests/test_timestamp.py
@@ -23,26 +23,9 @@
import ibis.expr.types as ir
from ibis.expr.rules import highest_precedence_type

from ibis.expr.tests.mocks import MockConnection


@pytest.fixture
def con():
return MockConnection()


@pytest.fixture
def alltypes(con):
return con.table('alltypes')


@pytest.fixture
def col(alltypes):
return alltypes.i


def test_field_select(col):
assert isinstance(col, ir.TimestampColumn)
def test_field_select(alltypes):
assert isinstance(alltypes.i, ir.TimestampColumn)


def test_string_cast_to_timestamp(alltypes):
@@ -66,9 +49,9 @@ def test_string_cast_to_timestamp(alltypes):
('millisecond', ops.ExtractMillisecond, ir.Int32Column),
]
)
def test_extract_fields(field, expected_operation, expected_type, col):
def test_extract_fields(field, expected_operation, expected_type, alltypes):
# type-size may be database specific
result = getattr(col, field)()
result = getattr(alltypes.i, field)()
assert result.get_name() == field
assert isinstance(result, expected_type)
assert isinstance(result.op(), expected_operation)
@@ -103,34 +86,34 @@ def test_integer_to_timestamp():
assert False


def test_comparison_timestamp(col):
expr = col > (col.min() + ibis.day(3))
def test_comparison_timestamp(alltypes):
expr = alltypes.i > (alltypes.i.min() + ibis.day(3))
assert isinstance(expr, ir.BooleanColumn)


def test_comparisons_string(col):
def test_comparisons_string(alltypes):
val = '2015-01-01 00:00:00'
expr = col > val
expr = alltypes.i > val
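    # the string literal should be implicitly promoted to a timestamp scalar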
op = expr.op()
assert isinstance(op.right, ir.TimestampScalar)

expr2 = val < col
expr2 = val < alltypes.i
op = expr2.op()
assert isinstance(op, ops.Greater)
assert isinstance(op.right, ir.TimestampScalar)


def test_comparisons_pandas_timestamp(col):
def test_comparisons_pandas_timestamp(alltypes):
val = pd.Timestamp('2015-01-01 00:00:00')
expr = col > val
expr = alltypes.i > val
op = expr.op()
assert isinstance(op.right, ir.TimestampScalar)


@pytest.mark.xfail(raises=TypeError, reason='Upstream pandas bug')
def test_greater_comparison_pandas_timestamp(col):
def test_greater_comparison_pandas_timestamp(alltypes):
val = pd.Timestamp('2015-01-01 00:00:00')
expr2 = val < col
expr2 = val < alltypes.i
op = expr2.op()
assert isinstance(op, ops.Greater)
assert isinstance(op.right, ir.TimestampScalar)
@@ -151,9 +134,9 @@ def test_timestamp_precedence():
]
)
def test_timestamp_field_access_on_date(
field, expected_operation, expected_type, col
field, expected_operation, expected_type, alltypes
):
date_col = col.cast('date')
date_col = alltypes.i.cast('date')
result = getattr(date_col, field)()
assert isinstance(result, expected_type)
assert isinstance(result.op(), expected_operation)
@@ -169,9 +152,9 @@ def test_timestamp_field_access_on_date(
]
)
def test_timestamp_field_access_on_date_failure(
field, expected_operation, expected_type, col
field, expected_operation, expected_type, alltypes
):
date_col = col.cast('date')
date_col = alltypes.i.cast('date')
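    # time-of-day components are not defined on a pure date, so the attribute should be absent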
with pytest.raises(AttributeError):
getattr(date_col, field)

350 changes: 326 additions & 24 deletions ibis/expr/tests/test_value_exprs.py

Large diffs are not rendered by default.

13 changes: 3 additions & 10 deletions ibis/expr/tests/test_visualize.py
@@ -9,13 +9,6 @@
from ibis.expr import rules # noqa: E402


@pytest.fixture
def t():
return ibis.table(
[('a', 'int64'), ('b', 'double'), ('c', 'string')], name='t'
)


@pytest.mark.parametrize(
'expr_func',
[
@@ -31,10 +24,10 @@ def t():
)
]
)
def test_exprs(t, expr_func):
expr = expr_func(t)
def test_exprs(table, expr_func):
expr = expr_func(table)
graph = viz.to_graph(expr)
assert str(hash(repr(t.op()))) in graph.source
assert str(hash(repr(table.op()))) in graph.source
assert str(hash(repr(expr.op()))) in graph.source


30 changes: 13 additions & 17 deletions ibis/expr/tests/test_window_functions.py
@@ -19,13 +19,8 @@
from ibis.tests.util import assert_equal


@pytest.fixture
def t(con):
return con.table('alltypes')


def test_compose_group_by_apis(t):
t = t
def test_compose_group_by_apis(alltypes):
t = alltypes
w = ibis.window(group_by=t.g, order_by=t.f)

diff = t.d - t.d.lag()
@@ -43,8 +38,8 @@ def test_compose_group_by_apis(t):
assert_equal(expr, expr3)


def test_combine_windows(t):
t = t
def test_combine_windows(alltypes):
t = alltypes
w1 = ibis.window(group_by=t.g, order_by=t.f)
w2 = ibis.window(preceding=5, following=5)

@@ -61,9 +56,9 @@ def test_combine_windows(t):
assert_equal(w5, expected)


def test_over_auto_bind(t):
def test_over_auto_bind(alltypes):
# GH #542
t = t
t = alltypes

w = ibis.window(group_by='g', order_by='f')
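    # string keys should bind to columns of whatever table the expression derives from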

@@ -74,9 +69,9 @@ def test_over_auto_bind(t):
assert_equal(actual_window, expected)


def test_window_function_bind(t):
def test_window_function_bind(alltypes):
# GH #532
t = t
t = alltypes

w = ibis.window(group_by=lambda x: x.g,
order_by=lambda x: x.f)
@@ -121,17 +116,17 @@ def test_mutate_sorts_keys(con):
assert_equal(result, expected)


def test_window_bind_to_table(t):
def test_window_bind_to_table(alltypes):
t = alltypes
w = ibis.window(group_by='g', order_by=ibis.desc('f'))

w2 = w.bind(t)
w2 = w.bind(alltypes)
expected = ibis.window(group_by=t.g,
order_by=ibis.desc(t.f))

assert_equal(w2, expected)


def test_preceding_following_validate(t):
def test_preceding_following_validate(alltypes):
# these all work
[
ibis.window(preceding=0),
@@ -160,5 +156,5 @@


@pytest.mark.xfail(raises=AssertionError, reason='NYI')
def test_window_equals(t):
def test_window_equals(alltypes):
assert False