Skip to content

Commit

Permalink
Merge pull request #6784 from VesnaT/pls_missing_target
Browse files Browse the repository at this point in the history
PLS: Handle missing target values
  • Loading branch information
lanzagar committed May 10, 2024
2 parents d8818b5 + 538e16d commit 4f4539a
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
11 changes: 10 additions & 1 deletion Orange/regression/pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,16 @@ def _transform_with_numpy_output(self, X, Y):
u = y_center @ pls.y_rotations_
"""
pls = self.pls_model.skl_model
t, u = pls.transform(X, Y)
mask = np.isnan(Y).any(axis=1)
n_comp = pls.n_components
t = np.full((len(X), n_comp), np.nan, dtype=float)
u = np.full((len(X), n_comp), np.nan, dtype=float)
if (~mask).sum() > 0:
t_, u_ = pls.transform(X[~mask], Y[~mask])
t[~mask] = t_
u[~mask] = u_
if mask.sum() > 0:
t[mask] = pls.transform(X[mask])
return np.hstack((t, u))

def __call__(self, data):
Expand Down
32 changes: 32 additions & 0 deletions Orange/regression/tests/test_pls.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,38 @@ def test_hash(self):
m = PLSRegressionLearner()(table(10, 5, 2))
self.assertNotEqual(hash(transformer), hash(_PLSCommonTransform(m)))

def test_missing_target(self):
data = table(10, 5, 1)
with data.unlocked(data.Y):
data.Y[::3] = np.nan
pls = PLSRegressionLearner()(data)
proj = pls.project(data)
self.assertFalse(np.isnan(proj.X).any())
self.assertFalse(np.isnan(proj.metas[1::3]).any())
self.assertFalse(np.isnan(proj.metas[2::3]).any())
self.assertTrue(np.isnan(proj.metas[::3]).all())

def test_missing_target_multitarget(self):
data = table(10, 5, 3)
with data.unlocked(data.Y):
data.Y[0] = np.nan
data.Y[1, 1] = np.nan

pls = PLSRegressionLearner()(data)
proj = pls.project(data)
self.assertFalse(np.isnan(proj.X).any())
self.assertFalse(np.isnan(proj.metas[2:]).any())
self.assertTrue(np.isnan(proj.metas[:2]).all())

def test_apply_domain_classless_data(self):
data = Table("housing")
pls = PLSRegressionLearner()(data)
classless_data = data.transform(Domain(data.domain.attributes))[:5]

proj = pls.project(classless_data)
self.assertFalse(np.isnan(proj.X).any())
self.assertTrue(np.isnan(proj.metas).all())


if __name__ == "__main__":
unittest.main()
16 changes: 16 additions & 0 deletions Orange/widgets/model/tests/test_owpls.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
import numpy as np

from Orange.data import Table, Domain, StringVariable
from Orange.widgets.model.owpls import OWPLS
Expand Down Expand Up @@ -87,6 +88,21 @@ def test_output_components_multi_target(self):
self.assertEqual(components.Y.shape, (2, 2))
self.assertEqual(components.metas.shape, (2, 1))

def test_missing_target(self):
data = self._data[:5].copy()
data.Y[[0, 4]] = np.nan
self.send_signal(self.widget.Inputs.data, data)
output = self.get_output(self.widget.Outputs.data)
self.assertFalse(np.isnan(output.metas[:, 3:].astype(float)).any())
self.assertTrue(np.isnan(output.metas[0, 1:3].astype(float)).all())
self.assertTrue(np.isnan(output.metas[4, 1:3].astype(float)).all())
self.assertFalse(np.isnan(output.metas[1:4, 1:3].astype(float)).any())

with data.unlocked(data.Y):
data.Y[:] = np.nan
self.send_signal(self.widget.Inputs.data, data)
self.assertIsNone(self.get_output(self.widget.Outputs.data))


if __name__ == "__main__":
unittest.main()

0 comments on commit 4f4539a

Please sign in to comment.