Skip to content

Commit

Permalink
fix(snowflake): implement working TimestampNow
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Jan 19, 2023
1 parent 57b1dd8 commit 42d95b0
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 4 deletions.
5 changes: 3 additions & 2 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,6 @@ def variance_compiler(t, op):
ops.DayOfWeekName: fixed_arity(
lambda arg: sa.func.trim(sa.func.to_char(arg, 'Day')), 1
),
# now is in the timezone of the server, but we want UTC
ops.TimestampNow: lambda *_: sa.func.timezone('UTC', sa.func.now()),
ops.TimeFromHMS: fixed_arity(sa.func.make_time, 3),
ops.CumulativeAll: unary(sa.func.bool_and),
ops.CumulativeAny: unary(sa.func.bool_or),
Expand All @@ -582,5 +580,8 @@ def variance_compiler(t, op):
ops.Mode: _mode,
ops.Quantile: _quantile,
ops.MultiQuantile: _quantile,
ops.TimestampNow: lambda t, op: sa.literal_column(
"CURRENT_TIMESTAMP", type_=t.get_sqla_type(op.output_dtype)
),
}
)
17 changes: 17 additions & 0 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,23 @@ def do_connect(
)
)

@contextlib.contextmanager
def begin(self):
with super().begin() as bind:
previous_timezone = (
bind.execute(sa.text("SHOW PARAMETERS LIKE 'TIMEZONE' IN SESSION"))
.mappings()
.fetchone()
.value
)
bind.execute(sa.text("ALTER SESSION SET TIMEZONE = 'UTC'"))
try:
yield bind
finally:
bind.execute(
sa.text(f"ALTER SESSION SET TIMEZONE = {previous_timezone!r}")
)

def _get_sqla_table(
self, name: str, schema: str | None = None, **_: Any
) -> sa.Table:
Expand Down
4 changes: 2 additions & 2 deletions ibis/backends/tests/test_temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ def test_day_of_week_column_group_by(
backend.assert_frame_equal(result, expected, check_dtype=False)


@pytest.mark.notimpl(["datafusion", "snowflake", "mssql"])
@pytest.mark.notimpl(["datafusion", "mssql"])
def test_now(con):
expr = ibis.now()
result = con.execute(expr.name("tmp"))
Expand All @@ -762,7 +762,7 @@ def test_now(con):


@pytest.mark.notimpl(["dask"], reason="Limit #2553")
@pytest.mark.notimpl(["datafusion", "snowflake", "polars"])
@pytest.mark.notimpl(["datafusion", "polars"])
def test_now_from_projection(alltypes):
n = 5
expr = alltypes[[ibis.now().name('ts')]].limit(n)
Expand Down

0 comments on commit 42d95b0

Please sign in to comment.