Skip to content

Commit

Permalink
refactor(sqlite): remove roundtripping for DayOfWeekIndex and DayOfWe…
Browse files Browse the repository at this point in the history
…ekName
  • Loading branch information
mesejo authored and cpcloud committed Apr 28, 2023
1 parent 69cdee5 commit b5a2bc5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 33 deletions.
30 changes: 0 additions & 30 deletions ibis/backends/sqlite/compiler.py
Expand Up @@ -16,9 +16,7 @@
import sqlalchemy as sa
from sqlalchemy.dialects import sqlite

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy import (
AlchemyCompiler,
AlchemyExprTranslator,
Expand All @@ -36,34 +34,6 @@ class SQLiteExprTranslator(AlchemyExprTranslator):
rewrites = SQLiteExprTranslator.rewrites


@rewrites(ops.DayOfWeekIndex)
def day_of_week_index(op):
# TODO(kszucs): avoid expr roundtrip
expr = op.arg.to_expr()
new_expr = ((expr.strftime('%w').cast(dt.int16) + 6) % 7).cast(dt.int16)
return new_expr.op()


@rewrites(ops.DayOfWeekName)
def day_of_week_name(op):
# TODO(kszucs): avoid expr roundtrip
expr = op.arg.to_expr()
new_expr = (
expr.day_of_week.index()
.case()
.when(0, 'Monday')
.when(1, 'Tuesday')
.when(2, 'Wednesday')
.when(3, 'Thursday')
.when(4, 'Friday')
.when(5, 'Saturday')
.when(6, 'Sunday')
.else_(ibis.NA)
.end()
)
return new_expr.op()


class SQLiteCompiler(AlchemyCompiler):
translator_class = SQLiteExprTranslator
support_values_syntax_in_select = False
Expand Down
23 changes: 20 additions & 3 deletions ibis/backends/sqlite/registry.py
Expand Up @@ -72,9 +72,7 @@ def _default_cast_impl(arg, from_, to, translator=None):

def _strftime_int(fmt):
def translator(t, op):
# TODO(kszucs): avoid expr roundtrip, should be done in rewrite phase
new_expr = op.arg.to_expr().strftime(fmt).cast(dt.int32)
return t.translate(new_expr.op())
return sa.cast(sa.func.strftime(fmt, t.translate(op.arg)), sa.INT)

return translator

Expand Down Expand Up @@ -271,6 +269,18 @@ def translate(t, op: ops.ArgMin | ops.ArgMax):
return translate


def _day_of_the_week_name(arg):
return sa.case(
(sa.func.strftime('%w', arg) == '0', 'Sunday'),
(sa.func.strftime('%w', arg) == '1', 'Monday'),
(sa.func.strftime('%w', arg) == '2', 'Tuesday'),
(sa.func.strftime('%w', arg) == '3', 'Wednesday'),
(sa.func.strftime('%w', arg) == '4', 'Thursday'),
(sa.func.strftime('%w', arg) == '5', 'Friday'),
(sa.func.strftime('%w', arg) == '6', 'Saturday'),
)


operation_registry.update(
{
# TODO(kszucs): don't dispatch on op.arg since that should be always an
Expand Down Expand Up @@ -332,6 +342,13 @@ def translate(t, op: ops.ArgMin | ops.ArgMax):
ops.ExtractMillisecond: fixed_arity(
lambda arg: (sa.func.strftime('%f', arg) * 1000) % 1000, 1
),
ops.DayOfWeekIndex: fixed_arity(
lambda arg: sa.cast(
sa.cast(sa.func.strftime('%w', arg) + 6, sa.SMALLINT) % 7, sa.SMALLINT
),
1,
),
ops.DayOfWeekName: fixed_arity(_day_of_the_week_name, 1),
ops.TimestampNow: fixed_arity(lambda: sa.func.datetime("now"), 0),
ops.RegexSearch: fixed_arity(sa.func._ibis_sqlite_regex_search, 2),
ops.RegexReplace: fixed_arity(sa.func._ibis_sqlite_regex_replace, 3),
Expand Down

0 comments on commit b5a2bc5

Please sign in to comment.