Skip to content

Commit

Permalink
Added mapd to backend tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed May 4, 2018
1 parent 465fafd commit 53d4414
Show file tree
Hide file tree
Showing 10 changed files with 210 additions and 62 deletions.
4 changes: 3 additions & 1 deletion ibis/mapd/api.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from ibis.config import options
from ibis.mapd.compiler import dialect, compiles, rewrites # noqa: F401
from ibis.mapd.compiler import ( # noqa: F401
dialect, compiles, rewrites, unsupported_operations
)
from ibis.mapd.client import MapDClient, EXECUTION_TYPE_CURSOR

import ibis.common as com
Expand Down
2 changes: 1 addition & 1 deletion ibis/mapd/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class MapDDataType(object):
v: k for k, v in dtypes.items()
}

def __init__(self, typename, nullable=False):
def __init__(self, typename, nullable=True):
if typename not in self.dtypes:
raise com.UnsupportedBackendType(typename)
self.typename = typename
Expand Down
12 changes: 9 additions & 3 deletions ibis/mapd/compiler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from six import StringIO
from . import operations as mapd_ops
from .identifiers import quote_identifier # noqa: F401
from .operations import _type_to_sql_string # noqa: F401
from .operations import _type_to_sql_string, _unsupported_ops # noqa: F401
from ibis.expr.api import _add_methods, _unary_op, _binop_expr

import ibis.common as com
Expand Down Expand Up @@ -195,9 +195,15 @@ class MapDDialect(compiles.Dialect):
compiles = MapDExprTranslator.compiles
rewrites = MapDExprTranslator.rewrites

compiles(ops.Distance, mapd_ops.distance)

mapd_reg = mapd_ops._operation_registry
unsupported_operations = frozenset(_unsupported_ops.keys())

compiles(ops.Distance, mapd_ops.distance)
rewrites(ops.All, mapd_ops._all)
rewrites(ops.Any, mapd_ops._any)
rewrites(ops.NotAll, mapd_ops._not_all)
rewrites(ops.NotAny, mapd_ops._not_any)
rewrites(ops.IfNull, mapd_ops.raise_unsupported_expr_error)

_add_methods(
ir.NumericValue, dict(
Expand Down
30 changes: 10 additions & 20 deletions ibis/mapd/identifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,12 @@
'left', # select/join
'like', # comparison operator
'limit', # select
# 'logicalaggregate', # explain
# 'logicalcalc', # explain
# 'logicalchi', # explain
# 'logicalcorrelate', # explain
# 'logicaldelta', # explain
# 'logicalexchange', # explain
# 'logicalfilter', # explain
# 'logicalintersect', # explain
# 'logicaljoin', # explain
# 'logicalmatch', # explain
# 'logicalminus', # explain
# 'logicalproject', # explain
# 'logicalsort', # explain
# 'logicaltablefunctionscan', # explain
# 'logicaltablemodify', # explain
# 'logicaltablescan', # explain
# 'logicalunion', # explain
# 'logicalvalues', # explain
# 'logicalwindow', # explain
'max',
'min',
'std',
'count',
'mean',
'sum',
'nullif', # comparison operator
'nulls', # select/order
'not', # logical operator
Expand Down Expand Up @@ -148,7 +135,10 @@


def quote_identifier(name, quotechar='"', force=False):
if force or name.count(' ') or name in _identifiers:
if (
(force or name.count(' ') or name in _identifiers) and
quotechar not in name
):
return '{0}{1}{0}'.format(quotechar, name)
else:
return name
120 changes: 105 additions & 15 deletions ibis/mapd/operations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from datetime import date, datetime
from ibis.mapd.identifiers import quote_identifier
from six import StringIO
from ibis.mapd.identifiers import quote_identifier, _identifiers
from ibis.impala import compiler as impala_compiler
from six import StringIO


import ibis
import ibis.common as com
Expand Down Expand Up @@ -53,6 +54,46 @@ def _cast(translator, expr):
return 'CAST({0!s} AS {1!s})'.format(arg_, type_)


def _all(expr):
op = expr.op()
arg = op.args[0]

if isinstance(arg, ir.BooleanValue):
arg = arg.ifelse(1, 0)

return (1 - arg).sum() == 0


def _any(expr):
op = expr.op()
arg = op.args[0]

if isinstance(arg, ir.BooleanValue):
arg = arg.ifelse(1, 0)

return arg.sum() >= 0


def _not_any(expr):
op = expr.op()
arg = op.args[0]

if isinstance(arg, ir.BooleanValue):
arg = arg.ifelse(1, 0)

return arg.sum() == 0


def _not_all(expr):
op = expr.op()
arg = op.args[0]

if isinstance(arg, ir.BooleanValue):
arg = arg.ifelse(1, 0)

return (1 - arg).sum() != 0


def _parenthesize(translator, expr):
op = expr.op()
op_klass = type(op)
Expand Down Expand Up @@ -209,7 +250,19 @@ def _cov(translator, expr):
)


# String
# MATH

def _round(translator, expr):
op = expr.op()
arg, digits = op.args

if digits is not None:
return _call(translator, 'round', arg, digits)
else:
return _call(translator, 'round', arg)


# STRING

def _length(func_name='length', sql_func_name='CHAR_LENGTH'):
def __lenght(translator, expr):
Expand All @@ -222,19 +275,15 @@ def __lenght(translator, expr):
return __lenght


def _name_expr(formatted_expr, quoted_name):
return '{0!s} AS {1!s}'.format(formatted_expr, quoted_name)
def _contains(translator, expr):
arg, pattern = expr.op().args[:2]

pattern_ = '%{}%'.format(translator.translate(pattern)[1:-1])

def _round(translator, expr):
op = expr.op()
arg, digits = op.args
return _parenthesize(translator, arg.like(pattern_).ifelse(1, -1))

if digits is not None:
return _call(translator, 'round', arg, digits)
else:
return _call(translator, 'round', arg)

# GENERIC

def _value_list(translator, expr):
op = expr.op()
Expand Down Expand Up @@ -322,12 +371,25 @@ def literal(translator, expr):
raise NotImplementedError(type(expr))


def raise_unsupported_expr_error(expr):
msg = "MapD backend doesn't support {} operation!"
op = expr.op()
raise com.UnsupportedOperationError(msg.format(type(op)))


def raise_unsupported_op_error(translator, expr, *args):
msg = "MapD backend doesn't support {} operation!"
op = expr.op()
raise com.UnsupportedOperationError(msg.format(type(op)))


# translator
def _name_expr(formatted_expr, quoted_name):
if quoted_name in _identifiers:
quoted_name = '"{}"'.format(quoted_name)
return '{} AS {}'.format(formatted_expr, quoted_name)


class CaseFormatter(object):

def __init__(self, translator, base, cases, results, default):
Expand Down Expand Up @@ -403,6 +465,7 @@ def _timestamp_truncate(translator, expr):
def _table_column(translator, expr):
op = expr.op()
field_name = op.name

quoted_name = quote_identifier(field_name, force=True)

table = op.table
Expand Down Expand Up @@ -488,6 +551,7 @@ class ByteLength(ops.StringLength):
_binary_infix_ops = {
# math
ops.Power: fixed_arity('power', 2),
ops.NotEquals: impala_compiler._binary_infix_op('<>'),
}

_unary_ops = {}
Expand Down Expand Up @@ -537,6 +601,7 @@ class ByteLength(ops.StringLength):
ops.StringLength: _length(),
ByteLength: _length('byte_length', 'LENGTH'),
ops.StringSQLILike: binary_infix_op('ilike'),
ops.StringFind: _contains
}

# DATE
Expand Down Expand Up @@ -579,6 +644,7 @@ class ByteLength(ops.StringLength):
# UNSUPPORTED OPERATIONS
_unsupported_ops = [
# generic/aggregation
ops.CMSMedian,
ops.WindowOp,
ops.DecimalPrecision,
ops.DecimalScale,
Expand All @@ -602,22 +668,46 @@ class ByteLength(ops.StringLength):
ops.Lag,
ops.Lead,
ops.NTile,
ops.GroupConcat,
ops.Arbitrary,
ops.NullIf,
ops.NullIfZero,
ops.NullLiteral,
ops.IsNull,
ops.IsInf,
ops.IsNan,
# string
ops.Lowercase,
ops.Uppercase,
ops.StringFind,
ops.FindInSet,
ops.StringReplace,
ops.StringJoin,
ops.StringSplit,
ops.Translate,
ops.StringAscii,
ops.LPad,
ops.RPad,
ops.Strip,
ops.RStrip,
ops.LStrip,
ops.Capitalize,
ops.Substring,
ops.StrRight,
ops.Repeat,
ops.Reverse,
ops.RegexExtract,
ops.RegexReplace,
ops.ParseURL,
# Numeric
ops.Least,
ops.Greatest,
ops.Log2,
ops.Log,
# date/time/timestamp
ops.TimestampFromUNIX,
ops.Date,
ops.TimeTruncate
ops.TimeTruncate,
ops.TimestampDiff
]

_unsupported_ops = {k: raise_unsupported_op_error for k in _unsupported_ops}
Expand All @@ -636,4 +726,4 @@ class ByteLength(ops.StringLength):
_operation_registry.update(_string_ops)
_operation_registry.update(_date_ops)
_operation_registry.update(_agg_ops)
_operation_registry.update(_unsupported_ops)
_operation_registry.update(_unsupported_ops)
4 changes: 4 additions & 0 deletions ibis/tests/all/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,25 @@
lambda t, where: t.double_col.std(how='sample'),
lambda t, where: t.double_col.std(ddof=1),
id='double_col_std',
marks=pytest.mark.xfail,
),
param(
lambda t, where: t.double_col.var(how='sample'),
lambda t, where: t.double_col.var(ddof=1),
id='double_col_var',
marks=pytest.mark.xfail,
),
param(
lambda t, where: t.double_col.std(how='pop'),
lambda t, where: t.double_col.std(ddof=0),
id='double_col_std_pop',
marks=pytest.mark.xfail,
),
param(
lambda t, where: t.double_col.var(how='pop'),
lambda t, where: t.double_col.var(ddof=0),
id='double_col_var_pop',
marks=pytest.mark.xfail,
),
param(
lambda t, where: t.string_col.approx_nunique(),
Expand Down
Loading

0 comments on commit 53d4414

Please sign in to comment.