Skip to content

Commit

Permalink
Merge pull request #25 from jerneju/refactoring-2
Browse files Browse the repository at this point in the history
[ENH] Datafusion: new signals (refactoring)
  • Loading branch information
kernc committed Jan 12, 2018
2 parents 643b8af + b5fa446 commit 909285c
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 84 deletions.
18 changes: 10 additions & 8 deletions orangecontrib/datafusion/widgets/owchaining.py
@@ -1,6 +1,7 @@
import numpy as np

from Orange.widgets import widget, gui, settings
from Orange.widgets.widget import Input, Output
from Orange.widgets.utils.itemmodels import PyTableModel

from skfusion import fusion
Expand All @@ -10,10 +11,6 @@

from AnyQt.QtCore import pyqtSignal

class Output:
RELATION = 'Relation'


class ChainingGraphView(GraphView):
def __init__(self, parent):
super().__init__(parent)
Expand All @@ -33,8 +30,12 @@ class OWChaining(widget.OWWidget):
"another type through chaining of latent factors."
priority = 30000
icon = "icons/LatentChaining.svg"
inputs = [("Fitted fusion graph", FittedFusionGraph, "on_fuser_change")]
outputs = [(Output.RELATION, Relation)]

class Inputs:
fitted_fusion_graph = Input("Fitted fusion graph", FittedFusionGraph)

class Outputs:
relation = Output("Relation", Relation)

pref_complete = settings.Setting(0) # Complete chaining to feature space

Expand Down Expand Up @@ -78,7 +79,7 @@ def selected_row(row):
self.graphview.clearSelection()
self._highlight_relations(chain)
data = self.fuser.compute_chain(chain, self.pref_complete)
self.send(Output.RELATION, data)
self.Outputs.relation.send(data)

table.selected_row.connect(selected_row)
box.layout().addWidget(table)
Expand Down Expand Up @@ -109,7 +110,7 @@ def _highlight_relations(self, relations):
edge.selected = True

def _populate_table(self, chains=[]):
self.send(Output.RELATION, None)
self.Outputs.relation.send(None)
model = []
for chain in chains:
columns = [str(self.startNode.name)]
Expand All @@ -123,6 +124,7 @@ def _populate_table(self, chains=[]):
self.model.wrap(model)
self.table.hideColumn(0)

@Inputs.fitted_fusion_graph
def on_fuser_change(self, fuser):
self.fuser = fuser
self._populate_table()
Expand Down
12 changes: 8 additions & 4 deletions orangecontrib/datafusion/widgets/owcompletionscoring.py
@@ -1,13 +1,15 @@
from collections import OrderedDict

import numpy as np

from skfusion import fusion

from AnyQt.QtCore import Qt
from AnyQt.QtGui import QFont
from AnyQt.QtWidgets import QTableWidgetItem, QTableWidget

from Orange.widgets import widget, gui
from Orange.widgets.widget import Input
from orangecontrib.datafusion.models import Relation, RelationCompleter
from orangecontrib.datafusion.widgets.owfusiongraph import \
relation_str
Expand Down Expand Up @@ -36,10 +38,10 @@ class OWCompletionScoring(widget.OWWidget):
"root mean squared error (RMSE)."
priority = 40000
icon = 'icons/CompletionScoring.svg'
inputs = [
('Fitted fusion graph', RelationCompleter, 'on_fuser_change', widget.Multiple),
('Relation', Relation, 'on_relation_change', widget.Multiple),
]

class Inputs:
fitted_fusion_graph = Input('Fitted fusion graph', RelationCompleter, multiple=True)
relation = Input('Relation', Relation, multiple=True)

want_main_area = True
want_control_area = False
Expand Down Expand Up @@ -101,12 +103,14 @@ def update_table(self, fusers, relations):
def update(self):
self.table.update_table(self.fusers, self.relations)

@Inputs.fitted_fusion_graph
def on_fuser_change(self, fuser, id):
if fuser:
self.fusers[id] = [fuser]
else: del self.fusers[id]
self.update()

@Inputs.relation
def on_relation_change(self, relation, id):
if relation:
self.relations[id] = relation.relation
Expand Down
29 changes: 14 additions & 15 deletions orangecontrib/datafusion/widgets/owfusiongraph.py
Expand Up @@ -2,6 +2,8 @@

from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.itemmodels import PyTableModel
from Orange.widgets.widget import Input, Output

from skfusion import fusion
from orangecontrib.datafusion.models import Relation, FusionGraph, FittedFusionGraph
from orangecontrib.datafusion.widgets.graphview import GraphView, Node, Edge
Expand All @@ -18,12 +20,6 @@
]


class Output:
RELATION = 'Relation'
FUSION_GRAPH = 'Fusion Graph'
FUSER = 'Fitted Fusion Graph'


def rel_shape(relation):
return '{}×{}'.format(*relation.shape)

Expand Down Expand Up @@ -58,12 +54,14 @@ class OWFusionGraph(widget.OWWidget):
"collective matrix factorization."
priority = 10000
icon = "icons/FusionGraph.svg"
inputs = [("Relation", Relation, "on_relation_change", widget.Multiple)]
outputs = [
(Output.RELATION, Relation),
(Output.FUSER, FittedFusionGraph, widget.Default),
(Output.FUSION_GRAPH, FusionGraph),
]

class Inputs:
relation = Input("Relation", Relation, multiple=True)

class Outputs:
relation = Output('Relation', Relation)
fuser = Output('Fitted Fusion Graph', FittedFusionGraph, default=True)
fusion_graph = Output('Fusion Graph', FusionGraph)

pref_algo_name = settings.Setting('')
pref_algorithm = settings.Setting(0)
Expand Down Expand Up @@ -124,7 +122,7 @@ def selectionChanged(self, selected, deselected):
assert len(selected) == 1
data = self._parent.tablemodel[selected[0].top()][0]
relation = Relation(data)
self._parent.send(Output.RELATION, relation)
self._parent.Outputs.relation.send(relation)

model = self.tablemodel = PyTableModel(parent=self)
table = self.table = TableView(self,
Expand Down Expand Up @@ -184,14 +182,15 @@ def commit(self):
finally:
self.progressbar.finish()
self.fuser.name = self.pref_algo_name
self.send(Output.FUSER, FittedFusionGraph(self.fuser))
self.Outputs.fuser.send(FittedFusionGraph(self.fuser))

def _populate_table(self, relations=None):
self.tablemodel.wrap([[rel, rel_shape(rel.data)] + rel_cols(rel)
for rel in relations or self.graph.relations])
self.table.hideColumn(0)
self.table.selectRow(0)

@Inputs.relation
def on_relation_change(self, relation, id):
def _on_remove_relation(id):
try: relation = self.relations.pop(id)
Expand Down Expand Up @@ -220,7 +219,7 @@ def _on_add_relation(relation, id):
for rel in self.graph.relations)
else
100)
self.send(Output.FUSION_GRAPH, FusionGraph(self.graph))
self.Outputs.fusion_graph.send(FusionGraph(self.graph))
# this ensures gui.label-s get updated
self.n_object_types = self.graph.n_object_types
self.n_relations = self.graph.n_relations
Expand Down
19 changes: 10 additions & 9 deletions orangecontrib/datafusion/widgets/owimdbactors.py
Expand Up @@ -3,15 +3,12 @@
from AnyQt.QtWidgets import QSizePolicy

from orangecontrib.datafusion.models import Relation
from Orange.widgets.widget import OWWidget
from Orange.widgets import widget, gui, settings
from Orange.widgets.widget import OWWidget, Input, Output
from orangecontrib.datafusion import movielens

from skfusion import fusion

MOVIE_ACTORS = "Movie Actors"
ACTORS_ACTORS = "Costarring Actors"


class OWIMDbActors(OWWidget):
name = "IMDb Actors"
Expand All @@ -21,9 +18,12 @@ class OWIMDbActors(OWWidget):
want_main_area = False
resizing_enabled = False

inputs = [("Filter", Relation, "set_data")]
outputs = [(MOVIE_ACTORS, Relation),
(ACTORS_ACTORS, Relation)]
class Inputs:
filter = Input("Filter", Relation)

class Outputs:
movie_actors = Output("Movie Actors", Relation)
actors_actors = Output("Costarring Actors", Relation)

percent = settings.Setting(10)

Expand All @@ -46,6 +46,7 @@ def __init__(self):

self.movies = None

@Inputs.filter
def set_data(self, relation):
if relation is not None:
assert isinstance(relation, Relation)
Expand All @@ -67,12 +68,12 @@ def send_output(self):
movies_actors = fusion.Relation(movie_actor_mat.T, name='play in',
row_type=movielens.ObjectType.Actors, row_names=actors,
col_type=movielens.ObjectType.Movies, col_names=self.movies)
self.send(MOVIE_ACTORS, Relation(movies_actors))
self.Outputs.movie_actors.send(Relation(movies_actors))

actors_actors = fusion.Relation(actor_actor_mat, name='costar with',
row_type=movielens.ObjectType.Actors, row_names=actors,
col_type=movielens.ObjectType.Actors, col_names=actors)
self.send(ACTORS_ACTORS, Relation(actors_actors))
self.Outputs.actors_actors.send(Relation(actors_actors))


if __name__ == "__main__":
Expand Down
19 changes: 11 additions & 8 deletions orangecontrib/datafusion/widgets/owlatentfactors.py
@@ -1,5 +1,7 @@
from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.itemmodels import PyTableModel
from Orange.widgets.widget import Input, Output

from skfusion import fusion
from orangecontrib.datafusion.widgets.owfusiongraph import rel_shape, rel_cols
from orangecontrib.datafusion.models import Relation, FittedFusionGraph
Expand All @@ -11,10 +13,6 @@ def is_constraint(relation):
return relation.row_type == relation.col_type


class Output:
RELATION = 'Relation'


class LatentGraphView(GraphView):
def itemClicked(self, item):
if isinstance(item, Edge) and item.source is item.dest: return
Expand All @@ -28,8 +26,12 @@ class OWLatentFactors(widget.OWWidget):
"select a latent factor for further analysis."
priority = 20000
icon = "icons/LatentFactors.svg"
inputs = [("Fitted fusion graph", FittedFusionGraph, "on_fuser_change")]
outputs = [(Output.RELATION, Relation)]

class Inputs:
fitted_fusion_graph = Input("Fitted fusion graph", FittedFusionGraph)

class Outputs:
relation = Output("Relation", Relation)

autorun = settings.Setting(True)

Expand Down Expand Up @@ -132,13 +134,13 @@ def _f(selected, deselected):

def commit(self, item):
data = Relation.create(*item, graph=self.fuser) if item else None
self.send(Output.RELATION, data)
self.Outputs.relation.send(data)

def _populate_tables(self, factors=None, backbones=None, reset=False):
if not self.fuser: return
self.model_factors.clear()
self.model_backbones.clear()
self.send(Output.RELATION, None)
self.Outputs.relation.send(None)
if factors or reset:
for otype, matrices in factors or self.fuser.factors_.items():
M = matrices[0]
Expand All @@ -156,6 +158,7 @@ def _populate_tables(self, factors=None, backbones=None, reset=False):
self.model_completions.append([(M, rel.row_type, rel.col_type), rel_shape(M.data)] + rel_cols(rel))
self.table_completions.hideColumn(0)

@Inputs.fitted_fusion_graph
def on_fuser_change(self, fuser):
self.fuser = fuser
self._populate_tables(reset=True)
Expand Down
28 changes: 13 additions & 15 deletions orangecontrib/datafusion/widgets/owmeanfuser.py
Expand Up @@ -3,6 +3,7 @@

from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.itemmodels import PyTableModel
from Orange.widgets.widget import Input, Output

from skfusion import fusion
from orangecontrib.datafusion.models import Relation, FusionGraph, RelationCompleter
Expand All @@ -11,11 +12,6 @@
import numpy as np


class Output:
FUSER = 'Mean-fitted fusion graph'
RELATION = 'Relation'


class MeanBy:
ROWS = 'Rows'
COLUMNS = 'Columns'
Expand Down Expand Up @@ -69,14 +65,14 @@ class OWMeanFuser(widget.OWWidget):
name = 'Mean Fuser'
priority = 55000
icon = 'icons/MeanFuser.svg'
inputs = [
('Fusion graph', FusionGraph, 'on_fusion_graph_change'),
('Relation', Relation, 'on_relation_change', widget.Multiple),
]
outputs = [
(Output.FUSER, MeanFuser, widget.Default),
(Output.RELATION, Relation)
]

class Inputs:
fusion_graph = Input('Fusion graph', FusionGraph)
relation = Input('Relation', Relation, multiple=True)

class Outputs:
fuser = Output('Mean-fitted fusion graph', MeanFuser, default=True)
relation = Output('Relation', Relation)

want_main_area = False

Expand Down Expand Up @@ -120,7 +116,7 @@ def selectionChanged(self, *args):

def commit(self, item=None):
self.fuser = MeanFuser(self.mean_by)
self.send(Output.FUSER, self.fuser)
self.Outputs.fuser.send(self.fuser)
rows = [i.row() for i in self.table.selectionModel().selectedRows()]
if self.model.rowCount() and rows:
relation = self.model[rows[0]][0]
Expand All @@ -130,7 +126,7 @@ def commit(self, item=None):
self.graph)
else:
data = None
self.send(Output.RELATION, data)
self.Outputs.relation.send(data)

def update_table(self):
self.model.wrap([([rel, rel_shape(rel.data)] +
Expand All @@ -147,6 +143,7 @@ def _remove_relation(self, relation):
if not self.relations[relation]:
del self.relations[relation]

@Inputs.fusion_graph
def on_fusion_graph_change(self, graph):
if graph:
self.graph = graph
Expand All @@ -159,6 +156,7 @@ def on_fusion_graph_change(self, graph):
self.update_table()
self.commit()

@Inputs.relation
def on_relation_change(self, relation, id):
try: self._remove_relation(self.id_relations.pop(id))
except KeyError: pass
Expand Down

0 comments on commit 909285c

Please sign in to comment.