Skip to content

Commit

Permalink
Merge pull request #6515 from janezd/freeviz-gravity
Browse files Browse the repository at this point in the history
FreeViz: Allow setting ratio btw attractive and repulsive forces
  • Loading branch information
VesnaT committed Sep 12, 2023
2 parents 24596fe + 8bb82de commit 484becf
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 8 deletions.
17 changes: 12 additions & 5 deletions Orange/projection/freeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class FreeViz(LinearProjector):
projection = FreeVizModel

def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1,
initial=None, maxiter=500, alpha=0.1, gravity=None,
atol=1e-5, preprocessors=None):
super().__init__(preprocessors=preprocessors)
self.weights = weights
Expand All @@ -33,6 +33,7 @@ def __init__(self, weights=None, center=True, scale=True, dim=2, p=1,
self.maxiter = maxiter
self.alpha = alpha
self.atol = atol
self.gravity = gravity
self.is_class_discrete = False
self.components_ = None

Expand All @@ -50,6 +51,7 @@ def get_components(self, X, Y):
X, Y, weights=self.weights, center=self.center, scale=self.scale,
dim=self.dim, p=self.p, initial=self.initial,
maxiter=self.maxiter, alpha=self.alpha, atol=self.atol,
gravity=self.gravity,
is_class_discrete=self.is_class_discrete)[1].T

@classmethod
Expand Down Expand Up @@ -104,7 +106,7 @@ def forces_regression(cls, distances, y, p=1):
return F

@classmethod
def forces_classification(cls, distances, y, p=1):
def forces_classification(cls, distances, y, p=1, gravity=None):
diffclass = scipy.spatial.distance.pdist(y.reshape(-1, 1), "hamming") != 0
# handle attractive force
if p == 1:
Expand All @@ -120,6 +122,8 @@ def forces_classification(cls, distances, y, p=1):
F[mask] = 1 / distances[mask]
else:
F[mask] = 1 / (distances[mask] ** p)
if gravity is not None:
F[mask] *= -np.sum(F[~mask]) / np.sum(F[mask]) / gravity
return F

@classmethod
Expand Down Expand Up @@ -180,7 +184,8 @@ def gradient(cls, X, embeddings, forces, embedding_dist=None, weights=None):
return G

@classmethod
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=False):
def freeviz_gradient(cls, X, y, embedding, p=1, weights=None,
gravity=None, is_class_discrete=False):
"""
Return the gradient for the FreeViz [1]_ projection.
Expand Down Expand Up @@ -214,7 +219,7 @@ def freeviz_gradient(cls, X, y, embedding, p=1, weights=None, is_class_discrete=
assert X.ndim == 2 and X.shape[0] == y.shape[0] == embedding.shape[0]
D = scipy.spatial.distance.pdist(embedding)
if is_class_discrete:
forces = cls.forces_classification(D, y, p=p)
forces = cls.forces_classification(D, y, p=p, gravity=gravity)
else:
forces = cls.forces_regression(D, y, p=p)
G = cls.gradient(X, embedding, forces, embedding_dist=D, weights=weights)
Expand All @@ -234,7 +239,8 @@ def _rotate(cls, A):

@classmethod
def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
initial=None, maxiter=500, alpha=0.1, atol=1e-5, is_class_discrete=False):
initial=None, maxiter=500, alpha=0.1, atol=1e-5, gravity=None,
is_class_discrete=False):
"""
FreeViz
Expand Down Expand Up @@ -341,6 +347,7 @@ def freeviz(cls, X, y, weights=None, center=True, scale=True, dim=2, p=1,
step_i = 0
while step_i < maxiter:
G = cls.freeviz_gradient(X, y, embeddings, p=p, weights=weights,
gravity=gravity,
is_class_discrete=is_class_discrete)

# Scale the changes (the largest anchor move is alpha * radius)
Expand Down
44 changes: 42 additions & 2 deletions Orange/widgets/visualize/owfreeviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from AnyQt.QtCore import Qt, QRectF, QLineF, QPoint
from AnyQt.QtGui import QPalette
from AnyQt.QtGui import QPalette, QFontMetrics
from AnyQt.QtWidgets import QSizePolicy

import pyqtgraph as pg
Expand Down Expand Up @@ -137,9 +137,13 @@ class OWFreeViz(OWAnchorProjectionWidget, ConcurrentWidgetMixin):

settings_version = 3
initialization = settings.Setting(InitType.Circular)
balance = settings.Setting(False)
gravity_index = settings.Setting(4)
GRAPH_CLASS = OWFreeVizGraph
graph = settings.SettingProvider(OWFreeVizGraph)

GravityValues = [0.1, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5]

class Error(OWAnchorProjectionWidget.Error):
no_class_var = widget.Msg("Data must have a target variable.")
multiple_class_vars = widget.Msg(
Expand All @@ -159,6 +163,7 @@ class Warning(OWAnchorProjectionWidget.Warning):
def __init__(self):
OWAnchorProjectionWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)
self.__optimized = False

def _add_controls(self):
self.__add_controls_start_box()
Expand All @@ -177,6 +182,20 @@ def __add_controls_start_box(self):
callback=self.__init_combo_changed,
sizePolicy=(QSizePolicy.MinimumExpanding, QSizePolicy.Fixed)
)
box2 = gui.hBox(box)
gui.checkBox(
box2, self, "balance", "Gravity",
callback=self.__gravity_changed)
self.grav_slider = gui.hSlider(
box2, self, "gravity_index",
minValue=0, maxValue=len(self.GravityValues) - 1,
callback=self.__gravity_dragged, createLabel=False)
self.gravity_label = gui.widgetLabel(box2)
self.gravity_label.setFixedWidth(
max(QFontMetrics(self.font()).horizontalAdvance(str(x))
for x in self.GravityValues))
self.gravity_label.setAlignment(Qt.AlignRight)
self.__update_gravity_label()
self.run_button = gui.button(box, self, "Start", self._toggle_run)

@property
Expand All @@ -189,6 +208,21 @@ def effective_data(self):
return self.data.transform(Domain(self.effective_variables,
self.data.domain.class_vars))

def __gravity_dragged(self):
self.balance = True
self.__gravity_changed()

def __update_gravity_label(self):
self.gravity_label.setText(str(self.GravityValues[self.gravity_index]))

def __gravity_changed(self):
gravity = self.GravityValues[self.gravity_index]
if self.projector is not None:
self.projector.gravity = gravity if self.balance else None
self.__update_gravity_label()
if self.task is None and self.__optimized:
self._run()

def __radius_slider_changed(self):
self.graph.update_radius()

Expand Down Expand Up @@ -232,6 +266,7 @@ def on_done(self, result: Result):
self.projection = result.projection
self.graph.set_sample_size(None)
self.run_button.setText("Start")
self.__optimized = True
self.commit.deferred()

def on_exception(self, ex: Exception):
Expand All @@ -253,14 +288,19 @@ def init_projection(self):
anchors = FreeViz.init_radial(len(self.effective_variables)) \
if self.initialization == InitType.Circular \
else FreeViz.init_random(len(self.effective_variables), 2)
if self.balance:
gravity = self.GravityValues[self.gravity_index]
else:
gravity = None
self.projector = FreeViz(scale=False, center=False,
initial=anchors, maxiter=10)
initial=anchors, maxiter=10, gravity=gravity)
data = self.projector.preprocess(self.effective_data)
self.projector.domain = data.domain
self.projector.components_ = anchors.T
self.projection = FreeVizModel(self.projector, self.projector.domain, 2)
self.projection.pre_domain = data.domain
self.projection.name = self.projector.name
self.__optimized = False

def check_data(self):
def error(err):
Expand Down
42 changes: 41 additions & 1 deletion Orange/widgets/visualize/tests/test_owfreeviz.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Test methods with long descriptive names can omit docstrings
# pylint: disable=missing-docstring
import unittest
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np

Expand Down Expand Up @@ -156,6 +156,46 @@ def test_discrete_attributes(self):
self.assertTrue(self.widget.Warning.removed_features.is_shown())
self.widget.run_button.click()

def test_gravity_slider(self):
w = self.widget

w.balance = False
w.gravity_index = 0

w.grav_slider.setValue(2)
self.assertTrue(w.balance)
self.assertEqual(w.gravity_label.text(), str(w.GravityValues[2]))

w.grav_slider.setValue(3)
self.assertTrue(w.balance)
self.assertEqual(w.gravity_label.text(), str(w.GravityValues[3]))

assert w.projector is None
self.send_signal(self.widget.Inputs.data, Table("zoo"))
self.wait_until_finished()
assert w.projector is not None

# w.projector.gravity has correct value if gravity was set before data
self.assertEqual(w.projector.gravity, w.GravityValues[3])

# ... and if set when the data is already present and projector exists
w.grav_slider.setValue(1)
self.assertEqual(w.projector.gravity, w.GravityValues[1])

# Check that optimization is restarted if the projection is optimized
with patch.object(w, "_run") as run, \
patch.object(w, "_OWFreeViz__optimized", new=True):
w.grav_slider.setValue(2)
self.assertEqual(w.projector.gravity, w.GravityValues[2])
run.assert_called_once()

# Also, check that checkbox also does all that
run.reset_mock()
w.controls.balance.click()
self.assertFalse(w.balance)
self.assertIsNone(w.projector.gravity)
run.assert_called_once()


class TestOWFreeVizRunner(unittest.TestCase):
@classmethod
Expand Down

0 comments on commit 484becf

Please sign in to comment.