Skip to content

Commit

Permalink
Merge pull request #17 from PrimozGodec/polynomial-classification-nul…
Browse files Browse the repository at this point in the history
…l-fix

[FIX] Polynomial Regression widget crashing on None values in data
  • Loading branch information
ajdapretnar committed Sep 6, 2016
2 parents 3f82e7e + 1075a96 commit b368d9b
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 20 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import math

from PyQt4.QtGui import QColor, QSizePolicy, QPalette, QPen, QFont
from PyQt4.QtCore import Qt, QRectF

import sklearn.preprocessing as skl_preprocessing
import pyqtgraph as pg
import numpy as np

from Orange.widgets.widget import OWWidget, Msg
from Orange.data import Table, Domain
from Orange.data.variable import ContinuousVariable, StringVariable
from Orange.regression.linear import (RidgeRegressionLearner, PolynomialLearner,
LinearRegressionLearner, LinearModel)
LinearRegressionLearner)
from Orange.regression import Learner
from Orange.preprocess.preprocess import Preprocess
from Orange.widgets import settings, gui
from Orange.widgets.utils import itemmodels
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
Expand All @@ -28,6 +30,11 @@ class OWUnivariateRegression(OWBaseLearner):
outputs = [("Coefficients", Table),
("Data", Table)]

replaces = [
"Orange.widgets.regression.owunivariateregression."
"OWUnivariateRegression"
]

LEARNER = PolynomialLearner

learner_name = settings.Setting("Univariate Regression")
Expand All @@ -40,6 +47,12 @@ class OWUnivariateRegression(OWBaseLearner):
want_main_area = True
graph_name = 'Regression graph'

class Error(OWWidget.Error):
"""
Class used fro widget warnings.
"""
all_none = Msg("One of the features has no defined values")

def add_main_layout(self):

self.data = None
Expand All @@ -57,9 +70,8 @@ def add_main_layout(self):
self.x_var_model = itemmodels.VariableListModel()
self.comboBoxAttributesX = gui.comboBox(
box, self, value='x_var_index', label="Input: ",
orientation=Qt.Horizontal, callback=self.apply, contentsLength=12)
self.comboBoxAttributesX.setSizePolicy(
QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
orientation=Qt.Horizontal, callback=self.apply,
maximumContentsLength=15)
self.comboBoxAttributesX.setModel(self.x_var_model)
self.expansion_spin = gui.doubleSpin(
gui.indentedBox(box),
Expand All @@ -70,9 +82,8 @@ def add_main_layout(self):
self.y_var_model = itemmodels.VariableListModel()
self.comboBoxAttributesY = gui.comboBox(
box, self, value="y_var_index", label="Target: ",
orientation=Qt.Horizontal, callback=self.apply, contentsLength=12)
self.comboBoxAttributesY.setSizePolicy(
QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
orientation=Qt.Horizontal, callback=self.apply,
maximumContentsLength=15)
self.comboBoxAttributesY.setModel(self.y_var_model)

gui.rubber(self.controlArea)
Expand Down Expand Up @@ -135,7 +146,8 @@ def set_data(self, data):
self.data = data
if data is not None:
cvars = [var for var in data.domain.variables if var.is_continuous]
class_cvars = [var for var in data.domain.class_vars if var.is_continuous]
class_cvars = [var for var in data.domain.class_vars
if var.is_continuous]

self.x_var_model[:] = cvars
self.y_var_model[:] = cvars
Expand Down Expand Up @@ -187,17 +199,27 @@ def plot_regression_line(self, x_data, y_data):

def apply(self):
degree = int(self.polynomialexpansion)
learner = self.LEARNER(preprocessors=self.preprocessors,
degree=degree,
learner=LinearRegressionLearner() if self.learner is None
else self.learner)
learner = self.LEARNER(
preprocessors=self.preprocessors, degree=degree,
learner=LinearRegressionLearner() if self.learner is None
else self.learner)
learner.name = self.learner_name
predictor = None

self.Error.clear()

if self.data is not None:
attributes = self.x_var_model[self.x_var_index]
class_var = self.y_var_model[self.y_var_index]
data_table = Table(Domain([attributes], class_vars=[class_var]), self.data)
data_table = Table(
Domain([attributes], class_vars=[class_var]), self.data)

# all lines has nan
if sum(math.isnan(line[0]) or math.isnan(line.get_class())
for line in data_table) == len(data_table):
self.Error.all_none()
self.clear_plot()
return

predictor = learner(data_table)

Expand All @@ -209,7 +231,8 @@ def apply(self):
x = preprocessed_data.X.ravel()
y = preprocessed_data.Y.ravel()

linspace = np.linspace(min(x), max(x), 1000).reshape(-1,1)
linspace = np.linspace(
np.nanmin(x), np.nanmax(x), 1000).reshape(-1,1)
values = predictor(linspace, predictor.Value)

self.plot_scatter_points(x, y)
Expand Down Expand Up @@ -258,7 +281,9 @@ def send_data(self):
Domain([attributes], class_vars=[class_var]), self.data)
polyfeatures = skl_preprocessing.PolynomialFeatures(
int(self.polynomialexpansion))
x = polyfeatures.fit_transform(data_table.X)

x = data_table.X[~np.isnan(data_table.X).any(axis=1)]
x = polyfeatures.fit_transform(x)

x_label = data_table.domain.attributes[0].name
out_domain = Domain(
Expand Down Expand Up @@ -289,5 +314,3 @@ def add_bottom_buttons(self):
ow.show()
a.exec_()
ow.saveSettings()


Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from Orange.data import Domain, ContinuousVariable
from Orange.widgets.tests.base import WidgetTest
from orangecontrib.educational.widgets.owunivariateregression \
from orangecontrib.educational.widgets.owpolynomialregression \
import OWUnivariateRegression
from Orange.data.table import Table
from Orange.regression import (LinearRegressionLearner,
RandomForestRegressionLearner)
from Orange.preprocess.preprocess import Normalize

class TestOWUnivariateRegression(WidgetTest):
class TestOWPolynomialRegression(WidgetTest):

def setUp(self):
self.widget = self.create_widget(OWUnivariateRegression)
Expand Down Expand Up @@ -48,6 +49,15 @@ def test_set_data(self):
if len(class_variables) == 0
else len(continuous_variables) - len(class_variables))

# check with data with all none
data = Table(Domain([ContinuousVariable('a'),
ContinuousVariable('b')]),
[[None, None], [None, None]])
self.widget.set_data(data)
self.widget.apply()
self.assertIsNone(self.widget.plot_item)
self.assertIsNone(self.widget.scatterplot_item)

def test_add_main_layout(self):
self.assertEqual(self.widget.data, None)
self.assertEqual(self.widget.preprocessors, None)
Expand Down

0 comments on commit b368d9b

Please sign in to comment.