Skip to content

Commit

Permalink
refactor(sqlite): use lambda to define backend operations
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztof-kwitt authored and cpcloud committed Dec 31, 2022
1 parent dbd61a5 commit b937391
Showing 1 changed file with 52 additions and 95 deletions.
147 changes: 52 additions & 95 deletions ibis/backends/sqlite/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,29 +63,6 @@ def _default_cast_impl(arg, from_, to):
return sa.cast(arg, to_sqla_type(to))


# TODO(kszucs): don't dispatch on op.arg since that should be always an
# instance of ops.Value
def _cast(t, op):
arg = t.translate(op.arg)

return sqlite_cast(arg, op.arg.output_dtype, op.to)


def _string_right(t, op):
f = sa.func.substr

sa_arg = t.translate(op.arg)
sa_length = t.translate(op.nchars)

return f(sa_arg, -sa_length, sa_length)


def _strftime(t, op):
sa_arg = t.translate(op.arg)
sa_format = t.translate(op.format_str)
return sa.func.strftime(sa_format, sa_arg)


def _strftime_int(fmt):
def translator(t, op):
# TODO(kszucs): avoid expr roundtrip, should be done in rewrite phase
Expand All @@ -108,12 +85,6 @@ def _extract_quarter(t, op):
return sa.cast(t.translate(expr_new.op()), sa.Integer)


def _extract_epoch_seconds(t, op):
# example: (julianday('now') - 2440587.5) * 86400.0
sa_expr = (sa.func.julianday(t.translate(op.arg)) - 2440587.5) * 86400.0
return sa.cast(sa_expr, sa.BigInteger)


_truncate_modifiers = {
'Y': 'start of year',
'M': 'start of month',
Expand All @@ -136,26 +107,13 @@ def translator(t, op):
return translator


def _millisecond(t, op):
sa_arg = t.translate(op.arg)
fractional_second = sa.func.strftime('%f', sa_arg)
return (fractional_second * 1000) % 1000


def _log(t, op):
sa_arg = t.translate(op.arg)
if op.base is None:
return sa.func._ibis_sqlite_ln(sa_arg)
return sa.func._ibis_sqlite_log(sa_arg, t.translate(op.base))


def _repeat(t, op):
arg = t.translate(op.arg)
times = t.translate(op.times)
f = sa.func
return f.replace(f.substr(f.quote(f.zeroblob((times + 1) / 2)), 3, times), '0', arg)


def _generic_pad(arg, length, pad):
f = sa.func
arg_length = f.length(arg)
Expand All @@ -172,16 +130,6 @@ def _generic_pad(arg, length, pad):
)


def _lpad(t, op):
arg, length, pad = map(t.translate, op.args)
return _generic_pad(arg, length, pad) + arg


def _rpad(t, op):
arg, length, pad = map(t.translate, op.args)
return arg + _generic_pad(arg, length, pad)


def _extract_week_of_year(t, op):
"""ISO week of year.
Expand Down Expand Up @@ -235,73 +183,82 @@ def _string_join(t, op):
)


def _string_concat(t, op):
return functools.reduce(operator.add, map(t.translate, op.arg))


def _date_from_ymd(t, op):
ymdstr = sa.func.printf(
'%04d-%02d-%02d',
t.translate(op.year),
t.translate(op.month),
t.translate(op.day),
)
return sa.func.date(ymdstr)


def _timestamp_from_ymdhms(t, op):
y, mo, d, h, m, s = (t.translate(x) for x in op.args)
timestr = sa.func.printf('%04d-%02d-%02d %02d:%02d:%02d%s', y, mo, d, h, m, s)
return sa.func.datetime(timestr)


def _time_from_hms(t, op):
timestr = sa.func.printf(
'%02d:%02d:%02d',
t.translate(op.hours),
t.translate(op.minutes),
t.translate(op.seconds),
)
return sa.func.time(timestr)


operation_registry.update(
{
ops.Cast: _cast,
ops.DateFromYMD: _date_from_ymd,
ops.StrRight: _string_right,
# TODO(kszucs): don't dispatch on op.arg since that should be always an
# instance of ops.Value
ops.Cast: (
lambda t, op: sqlite_cast(t.translate(op.arg), op.arg.output_dtype, op.to)
),
ops.StrRight: fixed_arity(
lambda arg, nchars: sa.func.substr(arg, -nchars, nchars), 2
),
ops.StringFind: _gen_string_find(sa.func.instr),
ops.StringJoin: _string_join,
ops.StringConcat: _string_concat,
ops.StringConcat: (
lambda t, op: functools.reduce(operator.add, map(t.translate, op.arg))
),
ops.Least: varargs(sa.func.min),
ops.Greatest: varargs(sa.func.max),
ops.IfNull: fixed_arity(sa.func.ifnull, 2),
ops.DateFromYMD: _date_from_ymd,
ops.TimeFromHMS: _time_from_hms,
ops.TimestampFromYMDHMS: _timestamp_from_ymdhms,
ops.DateFromYMD: fixed_arity(
lambda y, m, d: sa.func.date(sa.func.printf('%04d-%02d-%02d', y, m, d)), 3
),
ops.TimeFromHMS: fixed_arity(
lambda h, m, s: sa.func.time(sa.func.printf('%02d:%02d:%02d', h, m, s)), 3
),
ops.TimestampFromYMDHMS: fixed_arity(
lambda y, mo, d, h, m, s: sa.func.datetime(
sa.func.printf('%04d-%02d-%02d %02d:%02d:%02d%s', y, mo, d, h, m, s)
),
6,
),
ops.DateTruncate: _truncate(sa.func.date),
ops.Date: unary(sa.func.date),
ops.Time: unary(sa.func.time),
ops.TimestampTruncate: _truncate(sa.func.datetime),
ops.Strftime: _strftime,
ops.Strftime: fixed_arity(
lambda arg, format_str: sa.func.strftime(format_str, arg), 2
),
ops.ExtractYear: _strftime_int('%Y'),
ops.ExtractMonth: _strftime_int('%m'),
ops.ExtractDay: _strftime_int('%d'),
ops.ExtractWeekOfYear: _extract_week_of_year,
ops.ExtractDayOfYear: _strftime_int('%j'),
ops.ExtractQuarter: _extract_quarter,
ops.ExtractEpochSeconds: _extract_epoch_seconds,
# example: (julianday('now') - 2440587.5) * 86400.0
ops.ExtractEpochSeconds: fixed_arity(
lambda arg: sa.cast(
(sa.func.julianday(arg) - 2440587.5) * 86400.0, sa.BigInteger
),
1,
),
ops.ExtractHour: _strftime_int('%H'),
ops.ExtractMinute: _strftime_int('%M'),
ops.ExtractSecond: _strftime_int('%S'),
ops.ExtractMillisecond: _millisecond,
ops.ExtractMillisecond: fixed_arity(
lambda arg: (sa.func.strftime('%f', arg) * 1000) % 1000, 1
),
ops.TimestampNow: fixed_arity(lambda: sa.func.datetime("now"), 0),
ops.RegexSearch: fixed_arity(sa.func._ibis_sqlite_regex_search, 2),
ops.RegexReplace: fixed_arity(sa.func._ibis_sqlite_regex_replace, 3),
ops.RegexExtract: fixed_arity(sa.func._ibis_sqlite_regex_extract, 3),
ops.LPad: _lpad,
ops.RPad: _rpad,
ops.Repeat: _repeat,
ops.LPad: fixed_arity(
lambda arg, length, pad: _generic_pad(arg, length, pad) + arg, 3
),
ops.RPad: fixed_arity(
lambda arg, length, pad: arg + _generic_pad(arg, length, pad), 3
),
ops.Repeat: fixed_arity(
lambda arg, times: sa.func.replace(
sa.func.substr(
sa.func.quote(sa.func.zeroblob((times + 1) / 2)), 3, times
),
'0',
arg,
),
2,
),
ops.Reverse: unary(sa.func._ibis_sqlite_reverse),
ops.StringAscii: unary(sa.func._ibis_sqlite_string_ascii),
ops.Capitalize: unary(sa.func._ibis_sqlite_capitalize),
Expand Down

0 comments on commit b937391

Please sign in to comment.