Skip to content

Commit

Permalink
Merge pull request #868 from markotoplak/vizrankregression
Browse files Browse the repository at this point in the history
scatterplot rank projections: R2 score for regression
  • Loading branch information
VesnaT committed Dec 17, 2015
2 parents 486ae5d + f6a3de1 commit e5b0753
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion Orange/widgets/visualize/owscatterplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from PyQt4.QtGui import QApplication, QTableView, QStandardItemModel, \
QStandardItem
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import r2_score

import Orange
from Orange.data import Table, Domain, StringVariable, ContinuousVariable, \
Expand Down Expand Up @@ -556,7 +557,11 @@ def run(self):
y = y_full[valid]
knn = NearestNeighbors(n_neighbors=self.k).fit(X)
ind = knn.kneighbors(return_distance=False)
score = norm * np.sum(y[ind] == y.reshape(-1, 1))
if isinstance(self.parent_widget.data.domain.class_var,
DiscreteVariable):
score = norm * np.sum(y[ind] == y.reshape(-1, 1))
else:
score = r2_score(y, np.mean(y[ind], axis=1))
pos = bisect_left(self.scores, score)
self.projectionTableModel.insertRow(
len(self.scores) - pos,
Expand Down

0 comments on commit e5b0753

Please sign in to comment.