From 49f09cc0819444fedd71a30743d1f9842dd1ca15 Mon Sep 17 00:00:00 2001
From: Tom Augspurger
Date: Fri, 2 Mar 2018 15:20:28 -0600
Subject: [PATCH] API: Added ExtensionArray constructor from scalars (#19913)

---
 pandas/core/arrays/base.py                  | 20 ++++++++++++++++++++
 pandas/core/arrays/categorical.py           |  4 ++++
 pandas/tests/extension/base/constructors.py |  5 +++++
 pandas/tests/extension/decimal/array.py     |  4 ++++
 pandas/tests/extension/json/array.py        | 10 ++++++++--
 5 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py
index cec881394a021..37074b563efbd 100644
--- a/pandas/core/arrays/base.py
+++ b/pandas/core/arrays/base.py
@@ -18,6 +18,7 @@ class ExtensionArray(object):
     The interface includes the following abstract methods that must be
     implemented by subclasses:
 
+    * _constructor_from_sequence
     * __getitem__
     * __len__
     * dtype
@@ -56,6 +57,25 @@ class ExtensionArray(object):
     # '_typ' is for pandas.core.dtypes.generic.ABCExtensionArray.
     # Don't override this.
     _typ = 'extension'
+
+    # ------------------------------------------------------------------------
+    # Constructors
+    # ------------------------------------------------------------------------
+    @classmethod
+    def _constructor_from_sequence(cls, scalars):
+        """Construct a new ExtensionArray from a sequence of scalars.
+
+        Parameters
+        ----------
+        scalars : Sequence
+            Each element will be an instance of the scalar type for this
+            array, ``cls.dtype.type``.
+        Returns
+        -------
+        ExtensionArray
+        """
+        raise AbstractMethodError(cls)
+
     # ------------------------------------------------------------------------
     # Must be a Sequence
     # ------------------------------------------------------------------------
diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py
index c6eeabf0148d0..e23dc3b3e5b89 100644
--- a/pandas/core/arrays/categorical.py
+++ b/pandas/core/arrays/categorical.py
@@ -364,6 +364,10 @@ def __init__(self, values, categories=None, ordered=None, dtype=None,
         self._dtype = self._dtype.update_dtype(dtype)
         self._codes = coerce_indexer_dtype(codes, dtype.categories)
 
+    @classmethod
+    def _constructor_from_sequence(cls, scalars):
+        return cls(scalars)
+
     @property
     def categories(self):
         """The categories of this categorical.
diff --git a/pandas/tests/extension/base/constructors.py b/pandas/tests/extension/base/constructors.py
index 2d5d747aec5a7..4ac04d71338fd 100644
--- a/pandas/tests/extension/base/constructors.py
+++ b/pandas/tests/extension/base/constructors.py
@@ -9,6 +9,11 @@
 
 class BaseConstructorsTests(BaseExtensionTests):
 
+    def test_array_from_scalars(self, data):
+        scalars = [data[0], data[1], data[2]]
+        result = data._constructor_from_sequence(scalars)
+        assert isinstance(result, type(data))
+
     def test_series_constructor(self, data):
         result = pd.Series(data)
         assert result.dtype == data.dtype
diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py
index 8b2eaadeca99e..736556e4be20d 100644
--- a/pandas/tests/extension/decimal/array.py
+++ b/pandas/tests/extension/decimal/array.py
@@ -32,6 +32,10 @@ def __init__(self, values):
 
         self.values = values
 
+    @classmethod
+    def _constructor_from_sequence(cls, scalars):
+        return cls(scalars)
+
     def __getitem__(self, item):
         if isinstance(item, numbers.Integral):
             return self.values[item]
diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py
index 90aac93c68f64..21addf9d1549f 100644
--- a/pandas/tests/extension/json/array.py
+++ b/pandas/tests/extension/json/array.py
@@ -33,11 +33,17 @@ def __init__(self, values):
                 raise TypeError
         self.data = values
 
+    @classmethod
+    def _constructor_from_sequence(cls, scalars):
+        return cls(scalars)
+
     def __getitem__(self, item):
         if isinstance(item, numbers.Integral):
             return self.data[item]
         elif isinstance(item, np.ndarray) and item.dtype == 'bool':
-            return type(self)([x for x, m in zip(self, item) if m])
+            return self._constructor_from_sequence([
+                x for x, m in zip(self, item) if m
+            ])
         else:
             return type(self)(self.data[item])
 
@@ -77,7 +83,7 @@ def isna(self):
     def take(self, indexer, allow_fill=True, fill_value=None):
         output = [self.data[loc] if loc != -1 else self._na_value
                   for loc in indexer]
-        return type(self)(output)
+        return self._constructor_from_sequence(output)
 
     def copy(self, deep=False):
         return type(self)(self.data[:])