diff --git a/holoviews/core/data/grid.py b/holoviews/core/data/grid.py index 97ea668a1a..983433dbb3 100644 --- a/holoviews/core/data/grid.py +++ b/holoviews/core/data/grid.py @@ -419,6 +419,8 @@ def groupby(cls, dataset, dim_names, container_type, group_type, **kwargs): @classmethod def key_select_mask(cls, dataset, values, ind): + if util.pd and values.dtype.kind == 'M': + ind = util.parse_datetime_selection(ind) if isinstance(ind, tuple): ind = slice(*ind) if isinstance(ind, get_array_types()): diff --git a/holoviews/core/data/interface.py b/holoviews/core/data/interface.py index ff5cd25cd5..8f8945a256 100644 --- a/holoviews/core/data/interface.py +++ b/holoviews/core/data/interface.py @@ -291,30 +291,35 @@ def select_mask(cls, dataset, selection): have been selected. """ mask = np.ones(len(dataset), dtype=np.bool) - for dim, k in selection.items(): - if isinstance(k, tuple): - k = slice(*k) + for dim, sel in selection.items(): + if isinstance(sel, tuple): + sel = slice(*sel) arr = cls.values(dataset, dim) - if isinstance(k, slice): + if util.isdatetime(arr) and util.pd: + try: + sel = util.parse_datetime_selection(sel) + except: + pass + if isinstance(sel, slice): with warnings.catch_warnings(): warnings.filterwarnings('ignore', r'invalid value encountered') - if k.start is not None: - mask &= k.start <= arr - if k.stop is not None: - mask &= arr < k.stop - elif isinstance(k, (set, list)): + if sel.start is not None: + mask &= sel.start <= arr + if sel.stop is not None: + mask &= arr < sel.stop + elif isinstance(sel, (set, list)): iter_slcs = [] - for ik in k: + for ik in sel: with warnings.catch_warnings(): warnings.filterwarnings('ignore', r'invalid value encountered') iter_slcs.append(arr == ik) mask &= np.logical_or.reduce(iter_slcs) - elif callable(k): - mask &= k(arr) + elif callable(sel): + mask &= sel(arr) else: - index_mask = arr == k + index_mask = arr == sel if dataset.ndims == 1 and np.sum(index_mask) == 0: - data_index = np.argmin(np.abs(arr - k)) + data_index = np.argmin(np.abs(arr - sel)) mask = np.zeros(len(dataset), dtype=np.bool) mask[data_index] = True else: diff --git a/holoviews/core/util.py b/holoviews/core/util.py index 0027e04785..69b5976bba 100644 --- a/holoviews/core/util.py +++ b/holoviews/core/util.py @@ -1910,6 +1910,31 @@ def date_range(start, end, length, time_unit='us'): return start+step/2.+np.arange(length)*step +def parse_datetime(date): + """ + Parses dates specified as string or integer or pandas Timestamp + """ + if pd is None: + raise ImportError('Parsing dates from strings requires pandas') + return pd.to_datetime(date).to_datetime64() + + +def parse_datetime_selection(sel): + """ + Parses string selection specs as datetimes. + """ + if isinstance(sel, basestring) or isdatetime(sel): + sel = parse_datetime(sel) + if isinstance(sel, slice): + if isinstance(sel.start, basestring) or isdatetime(sel.start): + sel = slice(parse_datetime(sel.start), sel.stop) + if isinstance(sel.stop, basestring) or isdatetime(sel.stop): + sel = slice(sel.start, parse_datetime(sel.stop)) + if isinstance(sel, (set, list)): + sel = [parse_datetime(v) if isinstance(v, basestring) else v for v in sel] + return sel + + def dt_to_int(value, time_unit='us'): """ Converts a datetime type to an integer with the supplied time unit. diff --git a/holoviews/element/chart.py b/holoviews/element/chart.py index 0f55363034..e5127091fb 100644 --- a/holoviews/element/chart.py +++ b/holoviews/element/chart.py @@ -45,21 +45,7 @@ class Chart(Dataset, Element2D): __abstract = True def __getitem__(self, index): - sliced = super(Chart, self).__getitem__(index) - if not isinstance(sliced, Chart): - return sliced - - if not isinstance(index, tuple): index = (index,) - ndims = len(self.extents)//2 - lower_bounds, upper_bounds = [None]*ndims, [None]*ndims - for i, slc in enumerate(index[:ndims]): - if isinstance(slc, slice): - lbound = self.extents[i] - ubound = self.extents[ndims:][i] - lower_bounds[i] = lbound if slc.start is None else slc.start - upper_bounds[i] = ubound if slc.stop is None else slc.stop - sliced.extents = tuple(lower_bounds+upper_bounds) - return sliced + return super(Chart, self).__getitem__(index) class Scatter(Chart): @@ -69,7 +55,7 @@ class Scatter(Chart): location along the x-axis while the first value dimension represents the location of the point along the y-axis. """ - + group = param.String(default='Scatter', constant=True) diff --git a/holoviews/tests/element/testelementselect.py b/holoviews/tests/element/testelementselect.py index 4530587119..5331b393e1 100644 --- a/holoviews/tests/element/testelementselect.py +++ b/holoviews/tests/element/testelementselect.py @@ -1,8 +1,14 @@ from itertools import product +import datetime as dt import numpy as np +try: + import pandas as pd +except ImportError: + pd = None + from holoviews.core import HoloMap -from holoviews.element import Image, Contours +from holoviews.element import Image, Contours, Curve from holoviews.element.comparison import ComparisonTestCase class DimensionedSelectionTest(ComparisonTestCase): @@ -11,6 +17,11 @@ def setUp(self): self.img_fn = lambda: Image(np.random.rand(10, 10)) self.contour_fn = lambda: Contours([np.random.rand(10, 2) for i in range(2)]) + self.datetime_fn = lambda: Curve(( + [dt.datetime(2000,1,1), dt.datetime(2000,1,2), + dt.datetime(2000,1,3)], + np.random.rand(3) + ), 'time', 'x') params = [list(range(3)) for i in range(2)] self.sanitized_map = HoloMap({i: Image(i*np.random.rand(10,10)) for i in range(1,10)}, kdims=['A B']) @@ -85,3 +96,17 @@ def test_duplicate_dim_select(self): def test_overlap_select(self): selection = self.overlap_layout.select(Default=(6, None)) self.assertEqual(selection, self.overlap1.clone(shared_data=False) + self.overlap2[6:]) + + def test_datetime_select(self): + s, e = '1999-12-31', '2000-1-2' + curve = self.datetime_fn() + overlay = curve * self.datetime_fn() + for el in [curve, overlay]: + self.assertEqual(el.select(time=(s, e)), el[s:e]) + self.assertEqual(el.select(time= + (dt.datetime(1999, 12, 31), dt.datetime(2000, 1, 2))), el[s:e] + ) + if pd: + self.assertEqual(el.select( + time=(pd.Timestamp(s), pd.Timestamp(e)) + ), el[pd.Timestamp(s):pd.Timestamp(e)])