REF: back IntervalArray by a single ndarray (pandas-dev#37047)
jbrockmendel committed Oct 12, 2020
1 parent b526620 commit 9cb3723
Showing 4 changed files with 160 additions and 115 deletions.
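The gist of the refactor: IntervalArray previously carried separate _left and _right backing arrays; after this commit both bounds live in a single two-column _combined array, and _simple_new derives left/right as column views of it. The sketch below is not part of the commit: it covers only the numeric ndarray case, and the variable names merely mirror the patch for illustration.

    import numpy as np

    left = np.array([0.0, 1.0, 2.0])
    right = np.array([1.0, 2.0, 3.0])

    # Both bounds stored in one (N, 2) array, as _get_combined_data
    # does for plain ndarrays in this patch.
    combined = np.concatenate([left.reshape(-1, 1), right.reshape(-1, 1)], axis=1)

    # Bounds recovered as column views, as the new _simple_new does.
    assert (combined[:, 0] == left).all()
    assert (combined[:, 1] == right).all()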
263 changes: 152 additions & 111 deletions pandas/core/arrays/interval.py
@@ -1,5 +1,6 @@
from operator import le, lt
import textwrap
from typing import TYPE_CHECKING, Optional, Tuple, Union, cast

import numpy as np

@@ -11,14 +12,17 @@
IntervalMixin,
intervals_to_interval_bounds,
)
from pandas._typing import ArrayLike, Dtype
from pandas.compat.numpy import function as nv
from pandas.util._decorators import Appender

from pandas.core.dtypes.cast import maybe_convert_platform
from pandas.core.dtypes.common import (
is_categorical_dtype,
is_datetime64_any_dtype,
is_dtype_equal,
is_float_dtype,
is_integer,
is_integer_dtype,
is_interval_dtype,
is_list_like,
@@ -45,6 +49,10 @@
from pandas.core.indexers import check_array_indexer
from pandas.core.indexes.base import ensure_index

if TYPE_CHECKING:
from pandas import Index
from pandas.core.arrays import DatetimeArray, TimedeltaArray

_interval_shared_docs = {}

_shared_docs_kwargs = dict(
@@ -169,6 +177,17 @@ def __new__(
left = data._left
right = data._right
closed = closed or data.closed

if dtype is None or data.dtype == dtype:
# This path will preserve id(result._combined)
# TODO: could also validate dtype before going to simple_new
combined = data._combined
if copy:
combined = combined.copy()
result = cls._simple_new(combined, closed=closed)
if verify_integrity:
result._validate()
return result
else:

# don't allow scalars
@@ -186,83 +205,22 @@ def __new__(
)
closed = closed or infer_closed

return cls._simple_new(
left,
right,
closed,
copy=copy,
dtype=dtype,
verify_integrity=verify_integrity,
)
closed = closed or "right"
left, right = _maybe_cast_inputs(left, right, copy, dtype)
combined = _get_combined_data(left, right)
result = cls._simple_new(combined, closed=closed)
if verify_integrity:
result._validate()
return result

@classmethod
def _simple_new(
cls, left, right, closed=None, copy=False, dtype=None, verify_integrity=True
):
def _simple_new(cls, data, closed="right"):
result = IntervalMixin.__new__(cls)

closed = closed or "right"
left = ensure_index(left, copy=copy)
right = ensure_index(right, copy=copy)

if dtype is not None:
# GH 19262: dtype must be an IntervalDtype to override inferred
dtype = pandas_dtype(dtype)
if not is_interval_dtype(dtype):
msg = f"dtype must be an IntervalDtype, got {dtype}"
raise TypeError(msg)
elif dtype.subtype is not None:
left = left.astype(dtype.subtype)
right = right.astype(dtype.subtype)

# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
elif is_float_dtype(right) and is_integer_dtype(left):
left = left.astype(right.dtype)

if type(left) != type(right):
msg = (
f"must not have differing left [{type(left).__name__}] and "
f"right [{type(right).__name__}] types"
)
raise ValueError(msg)
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
# GH 19016
msg = (
"category, object, and string subtypes are not supported "
"for IntervalArray"
)
raise TypeError(msg)
elif isinstance(left, ABCPeriodIndex):
msg = "Period dtypes are not supported, use a PeriodIndex instead"
raise ValueError(msg)
elif isinstance(left, ABCDatetimeIndex) and str(left.tz) != str(right.tz):
msg = (
"left and right must have the same time zone, got "
f"'{left.tz}' and '{right.tz}'"
)
raise ValueError(msg)

# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

left = maybe_upcast_datetimelike_array(left)
left = extract_array(left, extract_numpy=True)
right = maybe_upcast_datetimelike_array(right)
right = extract_array(right, extract_numpy=True)

lbase = getattr(left, "_ndarray", left).base
rbase = getattr(right, "_ndarray", right).base
if lbase is not None and lbase is rbase:
# If these share data, then setitem could corrupt our IA
right = right.copy()

result._left = left
result._right = right
result._combined = data
result._left = data[:, 0]
result._right = data[:, 1]
result._closed = closed
if verify_integrity:
result._validate()
return result

@classmethod
@@ -397,10 +355,16 @@ def from_breaks(cls, breaks, closed="right", copy=False, dtype=None):
def from_arrays(cls, left, right, closed="right", copy=False, dtype=None):
left = maybe_convert_platform_interval(left)
right = maybe_convert_platform_interval(right)
if len(left) != len(right):
raise ValueError("left and right must have the same length")

return cls._simple_new(
left, right, closed, copy=copy, dtype=dtype, verify_integrity=True
)
closed = closed or "right"
left, right = _maybe_cast_inputs(left, right, copy, dtype)
combined = _get_combined_data(left, right)

result = cls._simple_new(combined, closed)
result._validate()
return result

_interval_shared_docs["from_tuples"] = textwrap.dedent(
"""
@@ -506,19 +470,6 @@ def _validate(self):
msg = "left side of interval must be <= right side"
raise ValueError(msg)

def _shallow_copy(self, left, right):
"""
Return a new IntervalArray with the replacement attributes
Parameters
----------
left : Index
Values to be used for the left-side of the intervals.
right : Index
Values to be used for the right-side of the intervals.
"""
return self._simple_new(left, right, closed=self.closed, verify_integrity=False)

# ---------------------------------------------------------------------
# Descriptive

@@ -546,18 +497,20 @@ def __len__(self) -> int:

def __getitem__(self, key):
key = check_array_indexer(self, key)
left = self._left[key]
right = self._right[key]

if not isinstance(left, (np.ndarray, ExtensionArray)):
# scalar
if is_scalar(left) and isna(left):
result = self._combined[key]

if is_integer(key):
left, right = result[0], result[1]
if isna(left):
return self._fill_value
return Interval(left, right, self.closed)
if np.ndim(left) > 1:

# TODO: need to watch out for incorrectly-reducing getitem
if np.ndim(result) > 2:
# GH#30588 multi-dimensional indexer disallowed
raise ValueError("multi-dimensional indexing not allowed")
return self._shallow_copy(left, right)
return type(self)._simple_new(result, closed=self.closed)

def __setitem__(self, key, value):
value_left, value_right = self._validate_setitem_value(value)
@@ -651,7 +604,8 @@ def fillna(self, value=None, method=None, limit=None):

left = self.left.fillna(value=value_left)
right = self.right.fillna(value=value_right)
return self._shallow_copy(left, right)
combined = _get_combined_data(left, right)
return type(self)._simple_new(combined, closed=self.closed)

def astype(self, dtype, copy=True):
"""
@@ -693,7 +647,9 @@ def astype(self, dtype, copy=True):
f"Cannot convert {self.dtype} to {dtype}; subtypes are incompatible"
)
raise TypeError(msg) from err
return self._shallow_copy(new_left, new_right)
# TODO: do astype directly on self._combined
combined = _get_combined_data(new_left, new_right)
return type(self)._simple_new(combined, closed=self.closed)
elif is_categorical_dtype(dtype):
return Categorical(np.asarray(self))
elif isinstance(dtype, StringDtype):
@@ -734,9 +690,11 @@ def _concat_same_type(cls, to_concat):
raise ValueError("Intervals must all be closed on the same side.")
closed = closed.pop()

# TODO: will this mess up on dt64tz?
left = np.concatenate([interval.left for interval in to_concat])
right = np.concatenate([interval.right for interval in to_concat])
return cls._simple_new(left, right, closed=closed, copy=False)
combined = _get_combined_data(left, right) # TODO: 1-stage concat
return cls._simple_new(combined, closed=closed)

def copy(self):
"""
@@ -746,11 +704,8 @@ def copy(self):
-------
IntervalArray
"""
left = self._left.copy()
right = self._right.copy()
closed = self.closed
# TODO: Could skip verify_integrity here.
return type(self).from_arrays(left, right, closed=closed)
combined = self._combined.copy()
return type(self)._simple_new(combined, closed=self.closed)

def isna(self) -> np.ndarray:
return isna(self._left)
@@ -843,7 +798,8 @@ def take(self, indices, allow_fill=False, fill_value=None, axis=None, **kwargs):
self._right, indices, allow_fill=allow_fill, fill_value=fill_right
)

return self._shallow_copy(left_take, right_take)
combined = _get_combined_data(left_take, right_take)
return type(self)._simple_new(combined, closed=self.closed)

def _validate_listlike(self, value):
# list-like of intervals
@@ -1170,10 +1126,7 @@ def set_closed(self, closed):
if closed not in VALID_CLOSED:
msg = f"invalid option for 'closed': {closed}"
raise ValueError(msg)

return type(self)._simple_new(
left=self._left, right=self._right, closed=closed, verify_integrity=False
)
return type(self)._simple_new(self._combined, closed=closed)

_interval_shared_docs[
"is_non_overlapping_monotonic"
@@ -1314,9 +1267,8 @@ def to_tuples(self, na_tuple=True):
@Appender(_extension_array_shared_docs["repeat"] % _shared_docs_kwargs)
def repeat(self, repeats, axis=None):
nv.validate_repeat(tuple(), dict(axis=axis))
left_repeat = self.left.repeat(repeats)
right_repeat = self.right.repeat(repeats)
return self._shallow_copy(left=left_repeat, right=right_repeat)
combined = self._combined.repeat(repeats, 0)
return type(self)._simple_new(combined, closed=self.closed)

_interval_shared_docs["contains"] = textwrap.dedent(
"""
@@ -1399,3 +1351,92 @@ def maybe_convert_platform_interval(values):
values = np.asarray(values)

return maybe_convert_platform(values)


def _maybe_cast_inputs(
left_orig: Union["Index", ArrayLike],
right_orig: Union["Index", ArrayLike],
copy: bool,
dtype: Optional[Dtype],
) -> Tuple["Index", "Index"]:
left = ensure_index(left_orig, copy=copy)
right = ensure_index(right_orig, copy=copy)

if dtype is not None:
# GH#19262: dtype must be an IntervalDtype to override inferred
dtype = pandas_dtype(dtype)
if not is_interval_dtype(dtype):
msg = f"dtype must be an IntervalDtype, got {dtype}"
raise TypeError(msg)
dtype = cast(IntervalDtype, dtype)
if dtype.subtype is not None:
left = left.astype(dtype.subtype)
right = right.astype(dtype.subtype)

# coerce dtypes to match if needed
if is_float_dtype(left) and is_integer_dtype(right):
right = right.astype(left.dtype)
elif is_float_dtype(right) and is_integer_dtype(left):
left = left.astype(right.dtype)

if type(left) != type(right):
msg = (
f"must not have differing left [{type(left).__name__}] and "
f"right [{type(right).__name__}] types"
)
raise ValueError(msg)
elif is_categorical_dtype(left.dtype) or is_string_dtype(left.dtype):
# GH#19016
msg = (
"category, object, and string subtypes are not supported "
"for IntervalArray"
)
raise TypeError(msg)
elif isinstance(left, ABCPeriodIndex):
msg = "Period dtypes are not supported, use a PeriodIndex instead"
raise ValueError(msg)
elif isinstance(left, ABCDatetimeIndex) and not is_dtype_equal(
left.dtype, right.dtype
):
left_arr = cast("DatetimeArray", left._data)
right_arr = cast("DatetimeArray", right._data)
msg = (
"left and right must have the same time zone, got "
f"'{left_arr.tz}' and '{right_arr.tz}'"
)
raise ValueError(msg)

return left, right


def _get_combined_data(
left: Union["Index", ArrayLike], right: Union["Index", ArrayLike]
) -> Union[np.ndarray, "DatetimeArray", "TimedeltaArray"]:
# For dt64/td64 we want DatetimeArray/TimedeltaArray instead of ndarray
from pandas.core.ops.array_ops import maybe_upcast_datetimelike_array

left = maybe_upcast_datetimelike_array(left)
left = extract_array(left, extract_numpy=True)
right = maybe_upcast_datetimelike_array(right)
right = extract_array(right, extract_numpy=True)

lbase = getattr(left, "_ndarray", left).base
rbase = getattr(right, "_ndarray", right).base
if lbase is not None and lbase is rbase:
# If these share data, then setitem could corrupt our IA
right = right.copy()

if isinstance(left, np.ndarray):
assert isinstance(right, np.ndarray) # for mypy
combined = np.concatenate(
[left.reshape(-1, 1), right.reshape(-1, 1)],
axis=1,
)
else:
left = cast(Union["DatetimeArray", "TimedeltaArray"], left)
right = cast(Union["DatetimeArray", "TimedeltaArray"], right)
combined = type(left)._concat_same_type(
[left.reshape(-1, 1), right.reshape(-1, 1)],
axis=1,
)
return combined
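
Since the change is internal, the public IntervalArray API is unchanged; whole-array operations such as repeat and set_closed now act on the single backing array. A short usage sketch, assuming a pandas build with this patch applied:

    import pandas as pd

    arr = pd.arrays.IntervalArray.from_arrays([0, 1, 2], [1, 2, 3], closed="right")
    print(arr.repeat(2))           # element-wise repeat via one repeat on the (N, 2) backing array
    print(arr.set_closed("both"))  # same backing data reused with a new closed value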