Skip to content

Commit

Permalink
Merge pull request #6508 from markotoplak/clean-sklimpute
Browse files Browse the repository at this point in the history
Clean up SklImpute
  • Loading branch information
JakaKokosar committed Jul 12, 2023
2 parents 711a863 + 1d076ee commit 929bd7a
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 8 deletions.
7 changes: 2 additions & 5 deletions Orange/preprocess/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,21 +158,18 @@ def __call__(self, data):
if isinstance(data, SqlTable):
return Impute()(data)
imputer = SimpleImputer(strategy=self.strategy)
X = imputer.fit_transform(data.X)
imputer.fit(data.X)
# Create new variables with appropriate `compute_value`, but
# drop the ones which do not have valid `imputer.statistics_`
# (i.e. all NaN columns). `sklearn.preprocessing.Imputer` already
# drops them from the transformed X.
features = [impute.Average()(data, var, value)
features = [var.copy(compute_value=impute.ReplaceUnknowns(var, value))
for var, value in zip(data.domain.attributes,
imputer.statistics_)
if not np.isnan(value)]
assert X.shape[1] == len(features)
domain = Orange.data.Domain(features, data.domain.class_vars,
data.domain.metas)
new_data = data.transform(domain)
with new_data.unlocked(new_data.X):
new_data.X = X
return new_data


Expand Down
43 changes: 42 additions & 1 deletion Orange/tests/test_impute.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scipy.sparse as sp

from Orange import preprocess
from Orange.preprocess import impute
from Orange.preprocess import impute, SklImpute
from Orange import data
from Orange.data import Unknown, Table

Expand Down Expand Up @@ -329,3 +329,44 @@ def test_imputer(self):
auto = data.Table(test_filename('datasets/imports-85.tab'))
auto2 = preprocess.Impute()(auto)
self.assertFalse(np.isnan(auto2.X).any())


class TestSklImpute(unittest.TestCase):

def setUp(self):
nan = np.nan
X = [
[1.0, nan, 0.0],
[2.0, 1.0, 3.0],
[nan, nan, nan]
]
self.imputed_mean = [
[1.0, 1.0, 0.0],
[2.0, 1.0, 3.0],
[1.5, 1.0, 1.5]
]
domain = data.Domain((data.ContinuousVariable(n) for n in "ABC"))
self.table = data.Table.from_numpy(domain, np.array(X))

def test_values(self):
imputed = SklImpute()(self.table)
np.testing.assert_equal(imputed.X, self.imputed_mean)

def test_sparse(self):
sparse = self.table.to_sparse()
self.assertTrue(sp.issparse(sparse.X))
imputed = SklImpute()(sparse)
self.assertTrue(sp.issparse(imputed.X))
np.testing.assert_equal(imputed.X.todense(), self.imputed_mean)

def test_transform(self):
imputed = SklImpute()(self.table)
transformed = self.table.transform(imputed.domain)
np.testing.assert_equal(transformed.X, self.imputed_mean)

def test_transform_sparse(self):
sparse = self.table.to_sparse()
imputed = SklImpute()(sparse)
self.assertTrue(sp.issparse(sparse.X))
transformed = sparse.transform(imputed.domain)
np.testing.assert_equal(transformed.X.todense(), self.imputed_mean)
14 changes: 12 additions & 2 deletions benchmark/bench_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import numpy as np

from Orange.data import Domain, Table, ContinuousVariable
from Orange.preprocess import Normalize
from Orange.preprocess import Normalize, SklImpute

from .base import Benchmark, benchmark


class BenchNormalize(Benchmark):
class SetUpData:

def setUp(self):
cols = 1000
Expand All @@ -21,6 +21,9 @@ def setUp(self):
np.random.RandomState(0).randint(0, 2, (rows, len(self.domain.variables))))
self.normalized_domain = Normalize()(self.table).domain


class BenchNormalize(SetUpData, Benchmark):

@benchmark(number=5)
def bench_normalize_only_transform(self):
self.table.transform(self.normalized_domain)
Expand All @@ -30,3 +33,10 @@ def bench_normalize_only_parameters(self):
# avoid benchmarking transformation
with patch("Orange.data.Table.transform", MagicMock()):
Normalize()(self.table)


class BenchSklImpute(SetUpData, Benchmark):

@benchmark(number=5)
def bench_sklimpute(self):
SklImpute()(self.table)

0 comments on commit 929bd7a

Please sign in to comment.