forked from statsmodels/statsmodels
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ENH: Add array_like function to simplify input checking
Add a lightweight check to simplify input validation and standardization
- Loading branch information
Showing
22 changed files
with
681 additions
and
265 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .validation import array_like, PandasWrapper | ||
|
||
|
||
__all__ = ['array_like', 'PandasWrapper'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from functools import wraps | ||
|
||
import numpy as np | ||
|
||
import statsmodels.tools.validation.validation as v | ||
|
||
|
||
def array_like(pos, name, dtype=np.double, ndim=None, maxdim=None,
               shape=None, order='C', contiguous=False):
    """
    Decorator that validates one argument of a function as array_like.

    The selected argument is passed through ``v.array_like`` before the
    decorated function is called, and the converted value is substituted
    for the original one.

    Parameters
    ----------
    pos : int
        Position of the argument when it is passed positionally.
    name : str
        Keyword name of the argument.  Used to find the value in ``kwargs``
        when it is not passed positionally, and forwarded to
        ``v.array_like``.
    dtype, ndim, maxdim, shape, order, contiguous
        Forwarded unchanged to ``v.array_like``.

    Returns
    -------
    callable
        A decorator that wraps a function with the validation step.

    Notes
    -----
    If the argument is supplied neither positionally nor by keyword, no
    validation is performed and the wrapped function is called as-is, so
    that it can apply its own default (or raise its own ``TypeError`` for a
    missing required argument).
    """
    def inner(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if pos < len(args):
                # Supplied positionally: swap in the validated value.
                # Slicing handles pos == 0 without a special case.
                arg = v.array_like(args[pos], name, dtype, ndim, maxdim,
                                   shape, order, contiguous)
                args = args[:pos] + (arg,) + args[pos + 1:]
            elif name in kwargs:
                # Supplied by keyword: validate in place.
                kwargs[name] = v.array_like(kwargs[name], name, dtype, ndim,
                                            maxdim, shape, order, contiguous)
            # Otherwise the argument was omitted; leave it to ``func``.

            return func(*args, **kwargs)
        return wrapper
    return inner
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,226 @@ | ||
import numpy as np | ||
import pandas as pd | ||
import pytest | ||
|
||
from statsmodels.tools.validation import array_like, PandasWrapper | ||
from statsmodels.tools.validation.validation import _right_squeeze | ||
|
||
|
||
@pytest.fixture(params=[True, False])
def use_pandas(request):
    """Parametrized fixture supplying ``True`` and then ``False``."""
    flag = request.param
    return flag
|
||
|
||
def gen_data(dim, use_pandas):
    """
    Build an uninitialized test array with ``dim`` dimensions.

    Parameters
    ----------
    dim : int
        Number of dimensions.  1 gives shape (10,), 2 gives shape (20, 10),
        anything else gives shape (5, 6, ..., 4 + dim).
    use_pandas : bool
        If True, 1-d data is returned as a Series and 2-d data as a
        DataFrame.  Ignored for any other ``dim``.

    Returns
    -------
    ndarray, Series or DataFrame
        Container holding values from ``np.empty`` (contents are arbitrary).
    """
    if dim == 1:
        values = np.empty(10, )
        return pd.Series(values) if use_pandas else values
    if dim == 2:
        values = np.empty((20, 10))
        return pd.DataFrame(values) if use_pandas else values
    # Higher dimensions are always plain ndarrays
    return np.empty(np.arange(5, 5 + dim))
|
||
|
||
class TestArrayLike(object):
    """Tests for ``array_like`` conversion and its shape/ndim checks."""

    def test_1d(self, use_pandas):
        """1-d input: default call, explicit ndim/shape, and padding to 2-d."""
        data = gen_data(1, use_pandas)
        a = array_like(data, 'a')
        assert a.ndim == 1
        assert a.shape == (10,)
        assert type(a) is np.ndarray

        a = array_like(data, 'a', ndim=1)
        assert a.ndim == 1
        a = array_like(data, 'a', shape=(10,))
        assert a.shape == (10,)
        a = array_like(data, 'a', ndim=1, shape=(None,))
        assert a.ndim == 1
        a = array_like(data, 'a', ndim=2, shape=(10, 1))
        assert a.ndim == 2
        assert a.shape == (10, 1)

        with pytest.raises(ValueError, match='a is required to have shape'):
            array_like(data, 'a', shape=(5,))

    def test_2d(self, use_pandas):
        """2-d input: partial shape specs, padding to 3-d, and failures."""
        data = gen_data(2, use_pandas)
        a = array_like(data, 'a', ndim=2)
        assert a.ndim == 2
        assert a.shape == (20, 10)
        assert type(a) is np.ndarray

        a = array_like(data, 'a', ndim=2)
        assert a.ndim == 2
        a = array_like(data, 'a', ndim=2, shape=(20, None))
        assert a.shape == (20, 10)
        a = array_like(data, 'a', ndim=2, shape=(20,))
        assert a.shape == (20, 10)
        a = array_like(data, 'a', ndim=2, shape=(None, 10))
        assert a.shape == (20, 10)

        a = array_like(data, 'a', ndim=2, shape=(None, None))
        assert a.ndim == 2
        a = array_like(data, 'a', ndim=3)
        assert a.ndim == 3
        assert a.shape == (20, 10, 1)

        with pytest.raises(ValueError, match='a is required to have shape'):
            array_like(data, 'a', ndim=2, shape=(10,))
        with pytest.raises(ValueError, match='a is required to have shape'):
            array_like(data, 'a', ndim=2, shape=(20, 20))
        with pytest.raises(ValueError, match='a is required to have shape'):
            array_like(data, 'a', ndim=2, shape=(None, 20))
        match = 'a is required to have ndim 1 but has ndim 2'
        with pytest.raises(ValueError, match=match):
            array_like(data, 'a', ndim=1)
        match = 'a must have ndim <= 1'
        with pytest.raises(ValueError, match=match):
            array_like(data, 'a', maxdim=1)

    def test_3d(self):
        """3-d ndarray input: shape specs, padding to 5-d, and failures."""
        data = gen_data(3, False)
        a = array_like(data, 'a', ndim=3)
        assert a.shape == (5, 6, 7)
        assert a.ndim == 3
        assert type(a) is np.ndarray

        a = array_like(data, 'a', ndim=3, shape=(5, None, 7))
        assert a.shape == (5, 6, 7)
        a = array_like(data, 'a', ndim=3, shape=(None, None, 7))
        assert a.shape == (5, 6, 7)
        a = array_like(data, 'a', ndim=5)
        assert a.shape == (5, 6, 7, 1, 1)
        with pytest.raises(ValueError, match='a is required to have shape'):
            array_like(data, 'a', ndim=3, shape=(10,))
        with pytest.raises(ValueError, match='a is required to have shape'):
            array_like(data, 'a', ndim=3, shape=(None, None, 5))
        match = 'a is required to have ndim 2 but has ndim 3'
        with pytest.raises(ValueError, match=match):
            array_like(data, 'a', ndim=2)
        match = 'a must have ndim <= 1'
        with pytest.raises(ValueError, match=match):
            array_like(data, 'a', maxdim=1)
        match = 'a must have ndim <= 2'
        with pytest.raises(ValueError, match=match):
            array_like(data, 'a', maxdim=2)

    def test_right_squeeze_and_pad(self):
        """Trailing singleton axes are dropped or added to reach ``ndim``."""
        data = np.empty((2, 1, 2))
        a = array_like(data, 'a', ndim=3)
        assert a.shape == (2, 1, 2)
        data = np.empty((2))
        a = array_like(data, 'a', ndim=3)
        assert a.shape == (2, 1, 1)
        data = np.empty((2, 1))
        a = array_like(data, 'a', ndim=3)
        assert a.shape == (2, 1, 1)

        data = np.empty((2, 1, 1, 1))
        a = array_like(data, 'a', ndim=3)
        assert a.shape == (2, 1, 1)

        # A non-singleton axis beyond ndim cannot be squeezed away
        data = np.empty((2, 1, 1, 2, 1, 1))
        with pytest.raises(ValueError):
            array_like(data, 'a', ndim=3)

    def test_contiguous(self):
        """``contiguous=True`` yields a C-contiguous result from a strided view."""
        x = np.arange(10)
        y = x[::2]
        a = array_like(y, 'a', contiguous=True)
        assert not y.flags['C_CONTIGUOUS']
        assert a.flags['C_CONTIGUOUS']

    def test_dtype(self):
        """The requested ``dtype`` is applied to the output."""
        x = np.arange(10)
        a = array_like(x, 'a', dtype=np.float32)
        assert a.dtype == np.float32

        a = array_like(x, 'a', dtype=np.uint8)
        assert a.dtype == np.uint8

    # NOTE(review): the marker is kept because this rewrite could not be run
    # against the project; drop it once the test is confirmed to pass.
    @pytest.mark.xfail(reason='Failing for now')
    def test_dot(self, use_pandas):
        """Products of a converted array are plain ndarrays."""
        data = gen_data(2, use_pandas)
        a = array_like(data, 'a', ndim=2)
        # ``array_like`` is a function, so it cannot be the second argument
        # of isinstance (that always raises TypeError).  Check the concrete
        # result type instead, as test_slice does.
        assert type(a.T.dot(data)) is np.ndarray
        assert type(a.T.dot(a)) is np.ndarray

    def test_slice(self, use_pandas):
        """Slicing a converted array gives a plain ndarray."""
        data = gen_data(2, use_pandas)
        a = array_like(data, 'a', ndim=2)
        assert type(a[1:]) is np.ndarray
|
||
|
||
def test_right_squeeze():
    """Only trailing singleton axes are removed; interior ones are kept."""
    cases = [
        ((10, 1, 10), (10, 1, 10)),
        ((10, 10, 1), (10, 10)),
        ((10, 10, 1, 1, 1, 1, 1), (10, 10)),
        ((10, 1, 10, 1, 1, 1, 1, 1), (10, 1, 10)),
    ]
    for in_shape, out_shape in cases:
        squeezed = _right_squeeze(np.empty(in_shape))
        assert squeezed.shape == out_shape
|
||
|
||
def test_wrap_pandas(use_pandas):
    """
    Check the container type and labels produced by ``PandasWrapper.wrap``.

    With a pandas original the result is expected to be a Series (1-d) or
    DataFrame (2-d) carrying the requested name/columns; with an ndarray
    original the result is expected to stay an ndarray.
    """
    a = gen_data(1, use_pandas)
    b = gen_data(1, False)

    wrapped = PandasWrapper(a).wrap(b)
    expected_type = pd.Series if use_pandas else np.ndarray
    assert isinstance(wrapped, expected_type)
    assert not use_pandas or wrapped.name is None

    wrapped = PandasWrapper(a).wrap(b, columns='name')
    assert isinstance(wrapped, expected_type)
    assert not use_pandas or wrapped.name == 'name'

    wrapped = PandasWrapper(a).wrap(b, columns=['name'])
    assert isinstance(wrapped, expected_type)
    assert not use_pandas or wrapped.name == 'name'

    expected_type = pd.DataFrame if use_pandas else np.ndarray
    wrapped = PandasWrapper(a).wrap(b[:, None])
    assert isinstance(wrapped, expected_type)
    assert not use_pandas or wrapped.columns[0] == 0

    wrapped = PandasWrapper(a).wrap(b[:, None], columns=['name'])
    assert isinstance(wrapped, expected_type)
    # Compare as a list: ``Index == list`` is elementwise and returns an
    # array, whose truth value is only defined for a single element.
    assert not use_pandas or list(wrapped.columns) == ['name']

    if use_pandas:
        match = 'Can only wrap 1 or 2-d array_like'
        with pytest.raises(ValueError, match=match):
            PandasWrapper(a).wrap(b[:, None, None])

        match = 'obj must have the same number of elements in axis 0 as'
        with pytest.raises(ValueError, match=match):
            PandasWrapper(a).wrap(b[:b.shape[0] // 2])
|
||
|
||
def test_wrap_pandas_append():
    """``append`` is joined to the original name/columns with an underscore."""
    # Series: the name gains the suffix
    series = gen_data(1, True)
    series.name = 'apple'
    values = gen_data(1, False)
    result = PandasWrapper(series).wrap(values, append='appended')
    assert result.name == 'apple_appended'

    # DataFrame: every column label gains the suffix
    frame = gen_data(2, True)
    frame.columns = ['apple_' + str(i) for i in range(frame.shape[1])]
    values = gen_data(2, False)
    result = PandasWrapper(frame).wrap(values, append='appended')
    assert list(result.columns) == [c + '_appended' for c in frame.columns]
Oops, something went wrong.