Skip to content

Commit 4a9ea67

Browse files
authored
Merge pull request #7153 from janezd/owsplit
Split: Move widget from Prototypes
2 parents 5cee704 + 559b5af commit 4a9ea67

File tree

5 files changed

+744
-0
lines changed

5 files changed

+744
-0
lines changed
Lines changed: 33 additions & 0 deletions
Loading

Orange/widgets/data/owsplit.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
from functools import partial
2+
3+
import numpy as np
4+
5+
from AnyQt.QtCore import Qt
6+
7+
from orangewidget.settings import Setting
8+
9+
from Orange.widgets import gui
10+
from Orange.widgets.settings import ContextSetting, DomainContextHandler
11+
from Orange.widgets.widget import OWWidget, Msg, Output, Input
12+
from Orange.widgets.utils.itemmodels import DomainModel
13+
from Orange.widgets.utils.widgetpreview import WidgetPreview
14+
from Orange.data import \
15+
Table, Domain, DiscreteVariable, StringVariable, ContinuousVariable
16+
from Orange.data.util import SharedComputeValue, get_unique_names
17+
18+
19+
def get_substrings(values, delimiter):
20+
return sorted({ss.strip() for s in values for ss in s.split(delimiter)}
21+
- {""})
22+
23+
24+
class SplitColumnBase:
25+
def __init__(self, data, attr, delimiter):
26+
self.attr = attr
27+
self.delimiter = delimiter
28+
column = set(data.get_column(self.attr))
29+
self.new_values = tuple(get_substrings(column, self.delimiter))
30+
31+
def __eq__(self, other):
32+
return self.attr == other.attr \
33+
and self.delimiter == other.delimiter \
34+
and self.new_values == other.new_values
35+
36+
def __hash__(self):
37+
return hash((self.attr, self.delimiter, self.new_values))
38+
39+
40+
class SplitColumnOneHot(SplitColumnBase):
41+
InheritEq = True
42+
43+
def __call__(self, data):
44+
column = data.get_column(self.attr)
45+
values = [{ss.strip() for ss in s.split(self.delimiter)}
46+
for s in column]
47+
return {v: np.array([i for i, xs in enumerate(values) if v in xs],
48+
dtype=int)
49+
for v in self.new_values}
50+
51+
52+
class SplitColumnCounts(SplitColumnBase):
53+
InheritEq = True
54+
55+
def __call__(self, data):
56+
column = data.get_column(self.attr)
57+
values = [[ss.strip() for ss in s.split(self.delimiter)]
58+
for s in column]
59+
return {v: np.array([xs.count(v) for xs in values], dtype=float)
60+
for v in self.new_values}
61+
62+
63+
class StringEncodingBase(SharedComputeValue):
64+
def __init__(self, fn, new_feature):
65+
super().__init__(fn)
66+
self.new_feature = new_feature
67+
68+
def __eq__(self, other):
69+
return super().__eq__(other) and self.new_feature == other.new_feature
70+
71+
def __hash__(self):
72+
return super().__hash__() ^ hash(self.new_feature)
73+
74+
def compute(self, data, shared_data):
75+
raise NotImplementedError # silence pylint
76+
77+
class OneHotStrings(StringEncodingBase):
78+
InheritEq = True
79+
80+
def compute(self, data, shared_data):
81+
indices = shared_data[self.new_feature]
82+
col = np.zeros(len(data))
83+
col[indices] = 1
84+
return col
85+
86+
87+
class CountStrings(StringEncodingBase):
88+
InheritEq = True
89+
90+
def compute(self, data, shared_data):
91+
return shared_data[self.new_feature]
92+
93+
94+
class DiscreteEncoding:
95+
def __init__(self, variable, delimiter, onehot, value):
96+
self.variable = variable
97+
self.delimiter = delimiter
98+
self.onehot = onehot
99+
self.value = value
100+
101+
def __call__(self, data):
102+
column = data.get_column(self.variable).astype(float)
103+
col = np.zeros(len(column))
104+
col[np.isnan(column)] = np.nan
105+
for val_idx, value in enumerate(self.variable.values):
106+
parts = value.split(self.delimiter)
107+
if self.onehot:
108+
col[column == val_idx] = int(self.value in parts)
109+
else:
110+
col[column == val_idx] = parts.count(self.value)
111+
return col
112+
113+
def __eq__(self, other):
114+
return self.variable == other.variable \
115+
and self.value == other.value \
116+
and self.delimiter == other.delimiter \
117+
and self.onehot == other.onehot
118+
119+
def __hash__(self):
120+
return hash((self.variable, self.value, self.delimiter, self.onehot))
121+
122+
123+
class OWSplit(OWWidget):
124+
name = "Split"
125+
description = "Split text or categorical variables into indicator variables"
126+
category = "Transform"
127+
icon = "icons/Split.svg"
128+
keywords = ["text to columns", "word encoding", "questionnaire", "survey",
129+
"term", "word presence", "word counts", "categorical encoding",
130+
"indicator variables"]
131+
priority = 700
132+
133+
class Inputs:
134+
data = Input("Data", Table)
135+
136+
class Outputs:
137+
data = Output("Data", Table)
138+
139+
class Warning(OWWidget.Warning):
140+
no_disc = Msg("Data contains only numeric variables.")
141+
142+
want_main_area = False
143+
resizing_enabled = False
144+
145+
Categorical, Numerical, Counts = range(3)
146+
OutputLabels = ("Categorical (No, Yes)", "Numerical (0, 1)", "Counts")
147+
148+
settingsHandler = DomainContextHandler()
149+
attribute = ContextSetting(None)
150+
delimiter = ContextSetting(";")
151+
output_type = ContextSetting(Categorical)
152+
auto_apply = Setting(True)
153+
154+
def __init__(self):
155+
super().__init__()
156+
self.data = None
157+
158+
variable_select_box = gui.vBox(self.controlArea, "Variable")
159+
160+
gui.comboBox(variable_select_box, self, "attribute",
161+
orientation=Qt.Horizontal, searchable=True,
162+
callback=self.apply.deferred,
163+
model=DomainModel(valid_types=(StringVariable,
164+
DiscreteVariable)))
165+
le = gui.lineEdit(
166+
variable_select_box, self, "delimiter", "Delimiter: ",
167+
orientation=Qt.Horizontal, callback=self.apply.deferred,
168+
controlWidth=20)
169+
le.box.layout().addStretch(1)
170+
le.setAlignment(Qt.AlignCenter)
171+
172+
gui.radioButtonsInBox(
173+
self.controlArea, self, "output_type", self.OutputLabels,
174+
box="Output Values",
175+
callback=self.apply.deferred)
176+
177+
gui.auto_apply(self.buttonsArea, self, commit=self.apply)
178+
179+
@Inputs.data
180+
def set_data(self, data):
181+
self.closeContext()
182+
self.data = data
183+
184+
model = self.controls.attribute.model()
185+
model.set_domain(data.domain if data is not None else None)
186+
self.Warning.no_disc(shown=data is not None and not model)
187+
if not model:
188+
self.attribute = None
189+
self.data = None
190+
return
191+
self.attribute = model[0]
192+
self.openContext(data)
193+
self.apply.now()
194+
195+
@gui.deferred
196+
def apply(self):
197+
if self.attribute is None:
198+
self.Outputs.data.send(None)
199+
return
200+
var = self.data.domain[self.attribute]
201+
values, computer = self._get_compute_value(var)
202+
new_columns = self._get_new_columns(values, computer)
203+
new_domain = Domain(
204+
self.data.domain.attributes + new_columns,
205+
self.data.domain.class_vars, self.data.domain.metas
206+
)
207+
extended_data = self.data.transform(new_domain)
208+
self.Outputs.data.send(extended_data)
209+
210+
def _get_compute_value(self, var):
211+
if var.is_discrete:
212+
values = get_substrings(var.values, self.delimiter)
213+
computer = partial(
214+
DiscreteEncoding,
215+
var, self.delimiter, self.output_type != self.Counts)
216+
else:
217+
if self.output_type == self.Counts:
218+
sc = SplitColumnCounts(self.data, var, self.delimiter)
219+
computer = partial(CountStrings, sc)
220+
else:
221+
sc = SplitColumnOneHot(self.data, var, self.delimiter)
222+
computer = partial(OneHotStrings, sc)
223+
values = sc.new_values
224+
return values, computer
225+
226+
def _get_new_columns(self, values, computer):
227+
names = get_unique_names(self.data.domain, values, equal_numbers=False)
228+
if self.output_type == self.Categorical:
229+
return tuple(
230+
DiscreteVariable(
231+
name, ("No", "Yes"), compute_value=computer(value))
232+
for value, name in zip(values, names))
233+
else:
234+
return tuple(
235+
ContinuousVariable(
236+
name, compute_value=computer(value))
237+
for value, name in zip(values, names))
238+
239+
240+
if __name__ == "__main__": # pragma: no cover
241+
WidgetPreview(OWSplit).run(Table.from_file("tests/orange-in-education.tab"))

0 commit comments

Comments
 (0)