Skip to content

Commit

Permalink
ENH: stat reductions for timedelta and datetime add
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Jevnik committed Jan 12, 2016
1 parent ad67522 commit 887b241
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 8 deletions.
14 changes: 14 additions & 0 deletions blaze/compute/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
from __future__ import absolute_import, division, print_function

from datetime import timedelta
import fnmatch
import itertools
from distutils.version import LooseVersion
Expand Down Expand Up @@ -218,9 +219,22 @@ def compute_up(t, s, **kwargs):

@dispatch((std, var), (Series, SeriesGroupBy))
def compute_up(t, s, **kwargs):
    """Compute a ``std`` or ``var`` reduction against pandas data.

    When the measure of the expression's schema is a ``TimeDelta`` (looked
    up through the measure's ``ty`` attribute when it has one), the data is
    converted to integer seconds before reducing and the reduced value is
    converted back into a ``datetime.timedelta``.

    Parameters
    ----------
    t : std or var
        The reduction expression; ``t.symbol`` names the pandas method to
        call, ``t.unbiased`` is passed as ``ddof`` and ``t.keepdims``
        controls whether the result is wrapped in a ``Series``.
    s : Series or SeriesGroupBy
        The data to reduce.

    Returns
    -------
    The reduced value, wrapped in a one-element ``Series`` named after
    ``s`` when ``t.keepdims`` is true.
    """
    measure = t.schema.measure
    is_timedelta = isinstance(
        getattr(measure, 'ty', measure),
        datashape.TimeDelta,
    )
    if is_timedelta:
        # part 1 of 2 to work around the fact that pandas does not have
        # timedelta var or std: reduce over integer seconds instead
        s = s.astype('timedelta64[s]').astype('int64')
    result = get_scalar(getattr(s, t.symbol)(ddof=t.unbiased))
    if is_timedelta:
        # part 2 of 2 to work around the fact that pandas does not have
        # timedelta var or std: cast back from seconds by creating a
        # timedelta. This has to happen before the keepdims wrapping below
        # because ``timedelta(seconds=...)`` does not accept a Series.
        result = timedelta(seconds=result)
    if t.keepdims:
        result = Series([result], name=s.name)
    return result


Expand Down
15 changes: 14 additions & 1 deletion blaze/compute/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""
from __future__ import absolute_import, division, print_function

import datetime
import itertools
from itertools import chain

Expand Down Expand Up @@ -597,11 +598,23 @@ def compute_up(t, s, **kwargs):

@dispatch((std, var), sql.elements.ColumnElement)
def compute_up(t, s, **kwargs):
    """Compute a ``std`` or ``var`` reduction against a SQL column element.

    When the measure of the expression's schema is a ``TimeDelta`` (looked
    up through the measure's ``ty`` attribute when it has one), the column
    is converted to seconds with ``extract('epoch', ...)`` before the
    aggregate is applied, and the aggregate is multiplied by a one second
    ``timedelta`` afterwards to turn it back into an interval.

    Parameters
    ----------
    t : std or var
        The reduction expression.
    s : sqlalchemy ColumnElement
        The column to reduce.

    Returns
    -------
    The aggregate expression labelled with ``t._name``.

    Raises
    ------
    ValueError
        If ``t.axis`` is anything other than ``(0,)``.
    """
    measure = t.schema.measure
    is_timedelta = isinstance(getattr(measure, 'ty', measure), TimeDelta)
    if is_timedelta:
        # part 1 of 2 to work around the fact that postgres does not have
        # timedelta var or std: cast to a double which is seconds
        s = sa.extract('epoch', s)
    if t.axis != (0,):
        raise ValueError('axis not equal to 0 not defined for SQL reductions')
    funcname = 'samp' if t.unbiased else 'pop'
    full_funcname = '%s_%s' % (prefixes[type(t)], funcname)
    # NOTE: do not return here; the timedelta case still needs to be cast
    # back below before the expression is labelled.
    ret = getattr(sa.func, full_funcname)(s)
    if is_timedelta:
        # part 2 of 2 to work around the fact that postgres does not have
        # timedelta var or std: cast back from seconds by
        # multiplying by a 1 second timedelta
        ret = ret * datetime.timedelta(seconds=1)
    return ret.label(t._name)


@dispatch(count, Selectable)
Expand Down
10 changes: 10 additions & 0 deletions blaze/compute/tests/test_pandas_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,6 +805,16 @@ def test_timedelta_arith():
).all()


@pytest.mark.parametrize('func,expected', (
    ('var', timedelta(days=0, seconds=8, microseconds=250000)),
    ('std', timedelta(days=0, seconds=2, microseconds=872281)),
))
def test_timedelta_stat_reduction(func, expected):
    """``var`` and ``std`` over a timedelta Series compute to a timedelta."""
    # ten deltas: 0s, 1s, ..., 9s
    data = pd.Series([timedelta(seconds=secs) for secs in range(10)])
    reduction = getattr(symbol('s', discover(data)), func)
    computed = compute(reduction(), data)
    assert computed == expected


def test_coerce_series():
s = pd.Series(list('123'), name='a')
t = symbol('t', discover(s))
Expand Down
26 changes: 26 additions & 0 deletions blaze/compute/tests/test_postgresql_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,20 @@ def sql_with_dts(url):
drop(t)


@pytest.yield_fixture
def sql_with_timedeltas(url):
    """Yield a SQL table with one timedelta column ``N`` holding 0s..9s.

    The test is skipped when creating the resource raises an
    ``OperationalError``; the table is dropped once the test has finished.
    """
    try:
        table = resource(url % next(names), dshape='var * {N: timedelta}')
    except sa.exc.OperationalError as err:
        # pytest.skip raises, so nothing below runs on this path
        pytest.skip(str(err))

    rows = [(timedelta(seconds=n),) for n in range(10)]
    table = odo(rows, table)
    try:
        yield table
    finally:
        drop(table)


@pytest.yield_fixture
def sql_two_tables(url):
dshape = 'var * {a: int32}'
Expand Down Expand Up @@ -309,6 +323,18 @@ def test_timedelta_arith(sql_with_dts):
).all()


@pytest.mark.parametrize('func', ('var', 'std'))
def test_timedelta_stat_reduction(sql_with_timedeltas, func):
    """``var``/``std`` of a SQL timedelta column match the pandas result."""
    table = sql_with_timedeltas
    expr = getattr(symbol('s', discover(table)).N, func)()

    # Build the reference value with pandas: the same ten deltas the fixture
    # loads, viewed as nanosecond integers and scaled down to seconds.
    as_seconds = pd.Series(
        [timedelta(seconds=n) for n in range(10)]
    ).astype('int64') / 1e9
    reference = getattr(as_seconds, func)(ddof=expr.unbiased)
    expected = timedelta(seconds=reference)

    assert odo(compute(expr, table), timedelta) == expected


def test_coerce_bool_and_sum(sql):
n = sql.name
t = symbol(n, discover(sql))
Expand Down
59 changes: 52 additions & 7 deletions blaze/expr/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,24 @@
from toolz import first
import numpy as np
import pandas as pd
from datashape import dshape, var, DataShape, Option, datetime_, timedelta_
from dateutil.parser import parse as dt_parse
from datashape import (
DataShape,
DateTime,
Option,
TimeDelta,
coretypes as ct,
datetime_,
discover,
dshape,
optionify,
promote,
timedelta_,
unsigned,
var,
)
from datashape.predicates import isscalar, isboolean, isnumeric, isdatelike
from datashape import coretypes as ct, discover, unsigned, promote, optionify
from dateutil.parser import parse as dt_parse


from .core import parenthesize, eval_str
from .expressions import Expr, shape, ElemWise
Expand Down Expand Up @@ -159,6 +173,30 @@ class Add(Arithmetic):
symbol = '+'
op = operator.add

@property
def _dtype(self):
    """Result measure of the addition, with datetime operands special-cased.

    Both operand measures are looked through their ``ty`` attribute when
    they have one, so wrapped (e.g. optional) measures are classified by
    the type they wrap.

    Returns
    -------
    ``optionify(lmeasure, rmeasure, datetime_)`` when exactly one operand
    is a datetime and the other is a ``TimeDelta``; otherwise whatever the
    parent class reports.

    Raises
    ------
    TypeError
        If both operands are datetimes, or if one operand is a datetime
        and the other is not a ``TimeDelta``.
    """
    lmeasure = discover(self.lhs).measure
    lty = getattr(lmeasure, 'ty', lmeasure)
    rmeasure = discover(self.rhs).measure
    rty = getattr(rmeasure, 'ty', rmeasure)

    # Compare the unwrapped types so that wrapped datetimes get the same
    # "datetime + datetime" error as plain ones.
    l_is_datetime = lty == datetime_
    r_is_datetime = rty == datetime_
    if l_is_datetime and r_is_datetime:
        raise TypeError('cannot add datetime to datetime')

    if l_is_datetime or r_is_datetime:
        other = rty if l_is_datetime else lty
        if isinstance(other, TimeDelta):
            return optionify(lmeasure, rmeasure, datetime_)
        raise TypeError(
            'can only add timedeltas to datetimes',
        )

    return super(Add, self)._dtype


class Mult(Arithmetic):
symbol = '*'
Expand All @@ -178,11 +216,18 @@ class Sub(Arithmetic):
@property
def _dtype(self):
    """Result measure of the subtraction, with datetime lhs special-cased.

    Both operand measures are looked through their ``ty`` attribute when
    they have one, so wrapped (e.g. optional) measures are classified by
    the type they wrap.

    Returns
    -------
    With a datetime on the left: ``optionify(lmeasure, rmeasure,
    timedelta_)`` when the right side is a ``DateTime``, or
    ``optionify(lmeasure, rmeasure, datetime_)`` when it is a
    ``TimeDelta``. Otherwise whatever the parent class reports.

    Raises
    ------
    TypeError
        If the left side is a datetime and the right side is neither a
        ``DateTime`` nor a ``TimeDelta``.
    """
    lmeasure = discover(self.lhs).measure
    lty = getattr(lmeasure, 'ty', lmeasure)
    rmeasure = discover(self.rhs).measure
    rty = getattr(rmeasure, 'ty', rmeasure)

    if lty == datetime_:
        if isinstance(rty, DateTime):
            # datetime - datetime is a duration
            return optionify(lmeasure, rmeasure, timedelta_)
        if isinstance(rty, TimeDelta):
            # datetime - timedelta is another datetime
            return optionify(lmeasure, rmeasure, datetime_)
        raise TypeError(
            'can only subtract timedelta or datetime from datetime',
        )

    return super(Sub, self)._dtype

Expand Down
13 changes: 13 additions & 0 deletions blaze/expr/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
h = symbol('h', 'datetime')
i = symbol('i', '?datetime')
j = symbol('j', '?datetime')
# timedelta operands (plain and optional) for the datetime +/- timedelta
# dshape assertions in this module
k = symbol('k', 'timedelta')
l = symbol('l', '?timedelta')
optionals = {d, e, f}


Expand Down Expand Up @@ -115,3 +117,14 @@ def test_datetime_sub():
assert (g - h).dshape == dshape('timedelta')
assert (g - i).dshape == dshape('?timedelta')
assert (i - j).dshape == dshape('?timedelta')
assert (g - k).dshape == dshape('datetime')
assert (g - l).dshape == dshape('?datetime')
assert (i - k).dshape == dshape('?datetime')
assert (i - l).dshape == dshape('?datetime')


def test_datetime_add():
    """datetime + timedelta is a datetime, optional if either operand is."""
    cases = (
        (g + k, 'datetime'),
        (g + l, '?datetime'),
        (i + k, '?datetime'),
        (i + l, '?datetime'),
    )
    for expr, expected in cases:
        assert expr.dshape == dshape(expected)
16 changes: 16 additions & 0 deletions docs/source/whatsnew/0.9.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ Improved Expressions
New Backends
~~~~~~~~~~~~

None

Improved Backends
~~~~~~~~~~~~~~~~~

Expand All @@ -27,12 +29,26 @@ Improved Backends
Experimental Features
~~~~~~~~~~~~~~~~~~~~~

None

API Changes
~~~~~~~~~~~

None

Bug Fixes
~~~~~~~~~

* Fixed a type issue where ``datetime - datetime :: datetime`` instead of
``timedelta`` (:issue:`1382`).
* Fixed a bug that caused :func:`~blaze.expr.expressions.coerce` to fail when
computing against ``ColumnElement``\s. This would break ``coerce`` for many
sql operations (:issue:`1382`).
* Fixed reductions over ``timedelta`` returning ``float`` (:issue:`1382`).
* Fixed interactive repr for ``timedelta`` not coercing to ``timedelta``
objects (:issue:`1382`).

Miscellaneous
~~~~~~~~~~~~~

None

0 comments on commit 887b241

Please sign in to comment.