Skip to content

Commit

Permalink
Merge 2d22429 into 1cf01c3
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed May 8, 2019
2 parents 1cf01c3 + 2d22429 commit b4ffa08
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 31 deletions.
2 changes: 2 additions & 0 deletions holoviews/core/data/grid.py
Expand Up @@ -419,6 +419,8 @@ def groupby(cls, dataset, dim_names, container_type, group_type, **kwargs):

@classmethod
def key_select_mask(cls, dataset, values, ind):
if util.pd and values.dtype.kind == 'M':
ind = util.parse_datetime_selection(ind)
if isinstance(ind, tuple):
ind = slice(*ind)
if isinstance(ind, get_array_types()):
Expand Down
33 changes: 19 additions & 14 deletions holoviews/core/data/interface.py
Expand Up @@ -283,30 +283,35 @@ def select_mask(cls, dataset, selection):
have been selected.
"""
mask = np.ones(len(dataset), dtype=np.bool)
for dim, k in selection.items():
if isinstance(k, tuple):
k = slice(*k)
for dim, sel in selection.items():
if isinstance(sel, tuple):
sel = slice(*sel)
arr = cls.values(dataset, dim)
if isinstance(k, slice):
if util.isdatetime(arr) and util.pd:
try:
sel = util.parse_datetime_selection(sel)
except:
pass
if isinstance(sel, slice):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'invalid value encountered')
if k.start is not None:
mask &= k.start <= arr
if k.stop is not None:
mask &= arr < k.stop
elif isinstance(k, (set, list)):
if sel.start is not None:
mask &= sel.start <= arr
if sel.stop is not None:
mask &= arr < sel.stop
elif isinstance(sel, (set, list)):
iter_slcs = []
for ik in k:
for ik in sel:
with warnings.catch_warnings():
warnings.filterwarnings('ignore', r'invalid value encountered')
iter_slcs.append(arr == ik)
mask &= np.logical_or.reduce(iter_slcs)
elif callable(k):
mask &= k(arr)
elif callable(sel):
mask &= sel(arr)
else:
index_mask = arr == k
index_mask = arr == sel
if dataset.ndims == 1 and np.sum(index_mask) == 0:
data_index = np.argmin(np.abs(arr - k))
data_index = np.argmin(np.abs(arr - sel))
mask = np.zeros(len(dataset), dtype=np.bool)
mask[data_index] = True
else:
Expand Down
25 changes: 25 additions & 0 deletions holoviews/core/util.py
Expand Up @@ -1916,6 +1916,31 @@ def date_range(start, end, length, time_unit='us'):
return start+step/2.+np.arange(length)*step


def parse_datetime(date):
"""
Parses dates specified as string or integer or pandas Timestamp
"""
if pd is None:
raise ImportError('Parsing dates from strings requires pandas')
return pd.to_datetime(date).to_datetime64()


def parse_datetime_selection(sel):
"""
Parses string selection specs as datetimes.
"""
if isinstance(sel, basestring) or isdatetime(sel):
sel = parse_datetime(sel)
if isinstance(sel, slice):
if isinstance(sel.start, basestring) or isdatetime(sel.start):
sel = slice(parse_datetime(sel.start), sel.stop)
if isinstance(sel.stop, basestring) or isdatetime(sel.stop):
sel = slice(sel.start, parse_datetime(sel.stop))
if isinstance(sel, (set, list)):
sel = [parse_datetime(v) if isinstance(v, basestring) else v for v in sel]
return sel


def dt_to_int(value, time_unit='us'):
"""
Converts a datetime type to an integer with the supplied time unit.
Expand Down
18 changes: 2 additions & 16 deletions holoviews/element/chart.py
Expand Up @@ -45,21 +45,7 @@ class Chart(Dataset, Element2D):
__abstract = True

def __getitem__(self, index):
sliced = super(Chart, self).__getitem__(index)
if not isinstance(sliced, Chart):
return sliced

if not isinstance(index, tuple): index = (index,)
ndims = len(self.extents)//2
lower_bounds, upper_bounds = [None]*ndims, [None]*ndims
for i, slc in enumerate(index[:ndims]):
if isinstance(slc, slice):
lbound = self.extents[i]
ubound = self.extents[ndims:][i]
lower_bounds[i] = lbound if slc.start is None else slc.start
upper_bounds[i] = ubound if slc.stop is None else slc.stop
sliced.extents = tuple(lower_bounds+upper_bounds)
return sliced
return super(Chart, self).__getitem__(index)


class Scatter(Chart):
Expand All @@ -69,7 +55,7 @@ class Scatter(Chart):
location along the x-axis while the first value dimension
represents the location of the point along the y-axis.
"""

group = param.String(default='Scatter', constant=True)


Expand Down
27 changes: 26 additions & 1 deletion holoviews/tests/element/testelementselect.py
@@ -1,8 +1,14 @@
from itertools import product
import datetime as dt
import numpy as np

try:
import pandas as pd
except ImportError:
pd = None

from holoviews.core import HoloMap
from holoviews.element import Image, Contours
from holoviews.element import Image, Contours, Curve
from holoviews.element.comparison import ComparisonTestCase

class DimensionedSelectionTest(ComparisonTestCase):
Expand All @@ -11,6 +17,11 @@ def setUp(self):
self.img_fn = lambda: Image(np.random.rand(10, 10))
self.contour_fn = lambda: Contours([np.random.rand(10, 2)
for i in range(2)])
self.datetime_fn = lambda: Curve((
[dt.datetime(2000,1,1), dt.datetime(2000,1,2),
dt.datetime(2000,1,3)],
np.random.rand(3)
), 'time', 'x')
params = [list(range(3)) for i in range(2)]
self.sanitized_map = HoloMap({i: Image(i*np.random.rand(10,10))
for i in range(1,10)}, kdims=['A B'])
Expand Down Expand Up @@ -85,3 +96,17 @@ def test_duplicate_dim_select(self):
def test_overlap_select(self):
selection = self.overlap_layout.select(Default=(6, None))
self.assertEqual(selection, self.overlap1.clone(shared_data=False) + self.overlap2[6:])

def test_datetime_select(self):
s, e = '1999-12-31', '2000-1-2'
curve = self.datetime_fn()
overlay = curve * self.datetime_fn()
for el in [curve, overlay]:
self.assertEqual(el.select(time=(s, e)), el[s:e])
self.assertEqual(el.select(time=
(dt.datetime(1999, 12, 31), dt.datetime(2000, 1, 2))), el[s:e]
)
if pd:
self.assertEqual(el.select(
time=(pd.Timestamp(s), pd.Timestamp(e))
), el[pd.Timestamp(s):pd.Timestamp(e)])

0 comments on commit b4ffa08

Please sign in to comment.