Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c711966
commit 7c247a3
Showing
6 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,3 +9,4 @@ | |
DummyEncoder, | ||
OrdinalEncoder, | ||
) | ||
from .label import LabelEncoder # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
from __future__ import division | ||
|
||
from operator import getitem | ||
|
||
import dask.array as da | ||
import dask.dataframe as dd | ||
import numpy as np | ||
from sklearn.preprocessing import label as sklabel | ||
from sklearn.utils.validation import check_is_fitted | ||
|
||
|
||
class LabelEncoder(sklabel.LabelEncoder): | ||
|
||
__doc__ = sklabel.LabelEncoder.__doc__ | ||
|
||
def _check_array(self, y): | ||
if isinstance(y, dd.Series): | ||
y = da.asarray(y) | ||
return y | ||
|
||
def fit(self, y): | ||
y = self._check_array(y) | ||
|
||
if isinstance(y, da.Array): | ||
classes_ = da.unique(y) | ||
classes_ = classes_.compute() | ||
else: | ||
classes_ = np.unique(y) | ||
|
||
self.classes_ = classes_ | ||
|
||
return self | ||
|
||
def fit_transform(self, y): | ||
return self.fit(y).transform(y) | ||
|
||
def transform(self, y): | ||
check_is_fitted(self, 'classes_') | ||
y = self._check_array(y) | ||
|
||
if isinstance(y, da.Array): | ||
return da.map_blocks(np.searchsorted, self.classes_, y, | ||
dtype=self.classes_.dtype) | ||
else: | ||
return np.searchsorted(self.classes_, y) | ||
|
||
def inverse_transform(self, y): | ||
check_is_fitted(self, 'classes_') | ||
y = self._check_array(y) | ||
|
||
if isinstance(y, da.Array): | ||
return da.map_blocks(getitem, self.classes_, y, | ||
dtype=self.classes_.dtype) | ||
else: | ||
y = np.asarray(y) | ||
return self.classes_[y] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pytest | ||
import sklearn.preprocessing as spp | ||
|
||
import dask.array as da | ||
import dask.dataframe as dd | ||
import numpy as np | ||
import pandas as pd | ||
from dask.array.utils import assert_eq as assert_eq_ar | ||
|
||
import dask_ml.preprocessing as dpp | ||
from dask_ml.utils import assert_estimator_equal | ||
|
||
|
||
choices = np.array(['a', 'b', 'c'], dtype=str) | ||
y = np.random.choice(choices, 100) | ||
y = da.from_array(y, chunks=13) | ||
s = dd.from_array(y) | ||
|
||
|
||
@pytest.fixture | ||
def pandas_series(): | ||
y = np.random.choice(['a', 'b', 'c'], 100) | ||
return pd.Series(y) | ||
|
||
|
||
@pytest.fixture | ||
def dask_array(pandas_series): | ||
return da.from_array(pandas_series, chunks=5) | ||
|
||
|
||
class TestLabelEncoder(object): | ||
def test_basic(self): | ||
a = dpp.LabelEncoder() | ||
b = spp.LabelEncoder() | ||
|
||
a.fit(y) | ||
b.fit(y.compute()) | ||
assert_estimator_equal(a, b) | ||
|
||
def test_input_types(self, dask_array, pandas_series): | ||
a = dpp.LabelEncoder() | ||
b = spp.LabelEncoder() | ||
|
||
assert_estimator_equal(a.fit(dask_array), | ||
b.fit(pandas_series)) | ||
|
||
assert_estimator_equal(a.fit(pandas_series), | ||
b.fit(pandas_series)) | ||
|
||
assert_estimator_equal(a.fit(pandas_series.values), | ||
b.fit(pandas_series)) | ||
|
||
assert_estimator_equal(a.fit(dask_array), | ||
b.fit(pandas_series.values)) | ||
|
||
@pytest.mark.parametrize('array', [y, s]) | ||
def test_transform(self, array): | ||
a = dpp.LabelEncoder() | ||
b = spp.LabelEncoder() | ||
|
||
a.fit(array) | ||
b.fit(array.compute()) | ||
|
||
assert_eq_ar(a.transform(array).compute(), | ||
b.transform(array.compute())) | ||
|
||
@pytest.mark.parametrize('array', [y, s]) | ||
def test_inverse_transform(self, array): | ||
|
||
a = dpp.LabelEncoder() | ||
assert_eq_ar(a.inverse_transform(a.fit_transform(array)), | ||
da.asarray(array)) |