35 changes: 35 additions & 0 deletions ibis/backends/trino/registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import operator
from functools import partial, reduce
from typing import Literal

Expand All @@ -11,6 +12,7 @@

import ibis
import ibis.common.exceptions as com
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.base.sql.alchemy.registry import (
_literal as _alchemy_literal,
Expand All @@ -27,6 +29,7 @@
varargs,
)
from ibis.backends.postgres.registry import _corr, _covar
from ibis.backends.trino.datatypes import INTERVAL

operation_registry = sqlalchemy_operation_registry.copy()
operation_registry.update(sqlalchemy_window_functions_registry)
Expand Down Expand Up @@ -68,8 +71,19 @@ def _literal(t, op):
return sa.literal(float(value), type_=DOUBLE())
elif dtype.is_integer():
return sa.literal(int(value), type_=t.get_sqla_type(dtype))
elif dtype.is_timestamp():
return sa.cast(
sa.func.from_iso8601_timestamp(value.isoformat()), t.get_sqla_type(dtype)
)
elif dtype.is_date():
return sa.func.from_iso8601_date(value.isoformat())
elif dtype.is_time():
return sa.cast(sa.literal(str(value)), t.get_sqla_type(dtype))
elif dtype.is_interval():
return sa.literal_column(
f"INTERVAL '{value}' {dtype.resolution.upper()}", type_=INTERVAL
)

return _alchemy_literal(t, op)


Expand Down Expand Up @@ -324,6 +338,18 @@ def _array_intersect(t, op):
3,
)


def _interval_from_integer(t, op):
unit = op.unit.short
if unit in ("Y", "Q", "M", "W"):
raise com.UnsupportedOperationError(f"Interval unit {unit!r} not supported")
arg = sa.func.concat(
t.translate(ops.Cast(op.arg, dt.String(nullable=op.arg.dtype.nullable))),
unit.lower(),
)
return sa.type_coerce(sa.func.parse_duration(arg), INTERVAL)


operation_registry.update(
{
# conditional expressions
Expand Down Expand Up @@ -511,6 +537,15 @@ def _array_intersect(t, op):
ops.TimeDelta: _temporal_delta,
ops.DateDelta: _temporal_delta,
ops.TimestampDelta: _temporal_delta,
ops.TimestampAdd: fixed_arity(operator.add, 2),
ops.TimestampSub: fixed_arity(operator.sub, 2),
ops.TimestampDiff: fixed_arity(lambda x, y: sa.type_coerce(x - y, INTERVAL), 2),
ops.DateAdd: fixed_arity(operator.add, 2),
ops.DateSub: fixed_arity(operator.sub, 2),
ops.DateDiff: fixed_arity(lambda x, y: sa.type_coerce(x - y, INTERVAL), 2),
ops.IntervalAdd: fixed_arity(operator.add, 2),
ops.IntervalSubtract: fixed_arity(operator.sub, 2),
ops.IntervalFromInteger: _interval_from_integer,
}
)

Expand Down