|
| 1 | +from functools import partial |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from AnyQt.QtCore import Qt |
| 6 | + |
| 7 | +from orangewidget.settings import Setting |
| 8 | + |
| 9 | +from Orange.widgets import gui |
| 10 | +from Orange.widgets.settings import ContextSetting, DomainContextHandler |
| 11 | +from Orange.widgets.widget import OWWidget, Msg, Output, Input |
| 12 | +from Orange.widgets.utils.itemmodels import DomainModel |
| 13 | +from Orange.widgets.utils.widgetpreview import WidgetPreview |
| 14 | +from Orange.data import \ |
| 15 | + Table, Domain, DiscreteVariable, StringVariable, ContinuousVariable |
| 16 | +from Orange.data.util import SharedComputeValue, get_unique_names |
| 17 | + |
| 18 | + |
def get_substrings(values, delimiter):
    """Return the sorted distinct substrings found in `values`.

    Each string in `values` is split on `delimiter`; the resulting pieces
    are stripped of surrounding whitespace and empty pieces are dropped.

    Args:
        values: iterable of strings to split.
        delimiter: separator passed to `str.split`.

    Returns:
        A sorted list of unique, non-empty, stripped substrings.
    """
    found = set()
    for text in values:
        found.update(piece.strip() for piece in text.split(delimiter))
    # An empty piece appears for empty strings and for doubled delimiters.
    found.discard("")
    return sorted(found)
| 22 | + |
| 23 | + |
class SplitColumnBase:
    """Common state for callables that split a column on a delimiter.

    On construction, the distinct substrings occurring in the column `attr`
    of `data` are collected into `new_values` (a sorted tuple).

    Args:
        data: object providing `get_column(attr)`; the returned values are
            split as strings.
        attr: the column (variable) to split.
        delimiter: separator on which the values are split.
    """

    def __init__(self, data, attr, delimiter):
        self.attr = attr
        self.delimiter = delimiter
        # De-duplicate the raw values first so each distinct string is
        # split only once.
        column = set(data.get_column(self.attr))
        self.new_values = tuple(get_substrings(column, self.delimiter))

    def __eq__(self, other):
        # Subclasses compute different results from the same state, so
        # instances of different concrete types must never compare equal.
        if type(self) is not type(other):
            return NotImplemented
        return self.attr == other.attr \
            and self.delimiter == other.delimiter \
            and self.new_values == other.new_values

    def __hash__(self):
        return hash(
            (type(self), self.attr, self.delimiter, self.new_values))
| 38 | + |
| 39 | + |
class SplitColumnOneHot(SplitColumnBase):
    """Find, for every substring in `new_values`, the rows that contain it."""

    InheritEq = True

    def __call__(self, data):
        """Return a dict mapping each new value to an int array of row indices.

        A row is listed for a value when the value is among the stripped
        parts obtained by splitting that row's string on the delimiter.
        """
        parts_per_row = []
        for text in data.get_column(self.attr):
            parts_per_row.append(
                {part.strip() for part in text.split(self.delimiter)})

        indices = {}
        for value in self.new_values:
            rows = [row for row, parts in enumerate(parts_per_row)
                    if value in parts]
            indices[value] = np.array(rows, dtype=int)
        return indices
| 50 | + |
| 51 | + |
class SplitColumnCounts(SplitColumnBase):
    """Count, for every substring in `new_values`, its occurrences per row."""

    InheritEq = True

    def __call__(self, data):
        """Return a dict mapping each new value to a float array of counts.

        The array has one element per row: the number of times the value
        appears among the stripped parts of that row's string.
        """
        tokens_per_row = [
            [part.strip() for part in text.split(self.delimiter)]
            for text in data.get_column(self.attr)
        ]

        counts = {}
        for value in self.new_values:
            per_row = [tokens.count(value) for tokens in tokens_per_row]
            counts[value] = np.array(per_row, dtype=float)
        return counts
| 61 | + |
| 62 | + |
class StringEncodingBase(SharedComputeValue):
    """Base for compute values that pick one feature out of shared data.

    Args:
        fn: callable handed to `SharedComputeValue`.
        new_feature: the key identifying the feature this instance computes.
    """

    def __init__(self, fn, new_feature):
        super().__init__(fn)
        self.new_feature = new_feature

    def __eq__(self, other):
        # Only look at `new_feature` once the base class considers the
        # two objects equal.
        if not super().__eq__(other):
            return False
        return self.new_feature == other.new_feature

    def __hash__(self):
        feature_hash = hash(self.new_feature)
        return super().__hash__() ^ feature_hash

    def compute(self, data, shared_data):
        raise NotImplementedError  # silence pylint
| 76 | + |
class OneHotStrings(StringEncodingBase):
    """Indicator column (0/1) for a single feature."""

    InheritEq = True

    def compute(self, data, shared_data):
        """Return a column of zeros with ones at the feature's row indices.

        `shared_data[self.new_feature]` is used as an index into the
        column; presumably it is the mapping produced by the shared
        callable passed to the constructor.
        """
        marked_rows = shared_data[self.new_feature]
        indicator = np.zeros(len(data))
        indicator[marked_rows] = 1
        return indicator
| 85 | + |
| 86 | + |
class CountStrings(StringEncodingBase):
    """Count column for a single feature."""

    InheritEq = True

    def compute(self, data, shared_data):
        """Return the entry of `shared_data` stored under this feature."""
        counts = shared_data[self.new_feature]
        return counts
| 92 | + |
| 93 | + |
class DiscreteEncoding:
    """Encode the presence or count of a substring in a categorical variable.

    Args:
        variable: categorical variable; its `values` are the strings that
            get split.
        delimiter: separator on which each categorical value is split.
        onehot: if true, the result is 1 when `value` occurs among the
            parts and 0 otherwise; if false, it is the number of
            occurrences.
        value: the (stripped) substring to look for.
    """

    def __init__(self, variable, delimiter, onehot, value):
        self.variable = variable
        self.delimiter = delimiter
        self.onehot = onehot
        self.value = value

    def __call__(self, data):
        """Return a float column with the encoding for each row of `data`.

        Rows whose value is missing (NaN) stay NaN in the result.
        """
        column = data.get_column(self.variable).astype(float)
        col = np.zeros(len(column))
        col[np.isnan(column)] = np.nan
        for val_idx, value in enumerate(self.variable.values):
            # Strip the parts: `self.value` comes from `get_substrings`,
            # which strips whitespace, so e.g. "a; b" must match "b".
            parts = [part.strip() for part in value.split(self.delimiter)]
            if self.onehot:
                encoded = int(self.value in parts)
            else:
                encoded = parts.count(self.value)
            col[column == val_idx] = encoded
        return col

    def __eq__(self, other):
        # Guard against unrelated objects, which lack the compared
        # attributes.
        if type(self) is not type(other):
            return NotImplemented
        return self.variable == other.variable \
            and self.value == other.value \
            and self.delimiter == other.delimiter \
            and self.onehot == other.onehot

    def __hash__(self):
        return hash((self.variable, self.value, self.delimiter, self.onehot))
| 121 | + |
| 122 | + |
class OWSplit(OWWidget):
    """Widget that splits a text or categorical variable into new columns.

    The selected variable's values are split on a delimiter; for every
    distinct substring a new attribute is appended to the data. Depending
    on the chosen output type the new attributes are categorical
    (No/Yes), numerical indicators (0/1), or occurrence counts.
    """

    name = "Split"
    description = "Split text or categorical variables into indicator variables"
    category = "Transform"
    icon = "icons/Split.svg"
    keywords = ["text to columns", "word encoding", "questionnaire", "survey",
                "term", "word presence", "word counts", "categorical encoding",
                "indicator variables"]
    priority = 700

    class Inputs:
        data = Input("Data", Table)

    class Outputs:
        data = Output("Data", Table)

    class Warning(OWWidget.Warning):
        no_disc = Msg("Data contains only numeric variables.")

    class Error(OWWidget.Error):
        no_delimiter = Msg("Delimiter must not be empty.")

    want_main_area = False
    resizing_enabled = False

    # Output types; the order matches `OutputLabels`.
    Categorical, Numerical, Counts = range(3)
    OutputLabels = ("Categorical (No, Yes)", "Numerical (0, 1)", "Counts")

    settingsHandler = DomainContextHandler()
    attribute = ContextSetting(None)
    delimiter = ContextSetting(";")
    output_type = ContextSetting(Categorical)
    auto_apply = Setting(True)

    def __init__(self):
        super().__init__()
        self.data = None

        variable_select_box = gui.vBox(self.controlArea, "Variable")

        gui.comboBox(variable_select_box, self, "attribute",
                     orientation=Qt.Horizontal, searchable=True,
                     callback=self.apply.deferred,
                     model=DomainModel(valid_types=(StringVariable,
                                                    DiscreteVariable)))
        le = gui.lineEdit(
            variable_select_box, self, "delimiter", "Delimiter: ",
            orientation=Qt.Horizontal, callback=self.apply.deferred,
            controlWidth=20)
        le.box.layout().addStretch(1)
        le.setAlignment(Qt.AlignCenter)

        gui.radioButtonsInBox(
            self.controlArea, self, "output_type", self.OutputLabels,
            box="Output Values",
            callback=self.apply.deferred)

        gui.auto_apply(self.buttonsArea, self, commit=self.apply)

    @Inputs.data
    def set_data(self, data):
        """Accept new input data, update the variable list and recompute.

        When there is no data, or the data has no text or categorical
        variable, the selection is reset and the output is cleared.
        """
        self.closeContext()
        self.data = data

        model = self.controls.attribute.model()
        model.set_domain(data.domain if data is not None else None)
        self.Warning.no_disc(shown=data is not None and not model)
        if model:
            self.attribute = model[0]
            self.openContext(data)
        else:
            self.attribute = None
            self.data = None
        # Always recompute: with no attribute selected, `apply` sends None,
        # so downstream widgets do not keep a stale table.
        self.apply.now()

    @gui.deferred
    def apply(self):
        """Build the extended table and send it to the output.

        Sends None when no variable is selected or the delimiter is empty.
        """
        # `str.split` raises ValueError for an empty separator; report it
        # as a widget error instead of letting the exception escape.
        missing_delimiter = self.attribute is not None and not self.delimiter
        self.Error.no_delimiter(shown=missing_delimiter)
        if self.attribute is None or missing_delimiter:
            self.Outputs.data.send(None)
            return

        var = self.data.domain[self.attribute]
        values, computer = self._get_compute_value(var)
        new_columns = self._get_new_columns(values, computer)
        new_domain = Domain(
            self.data.domain.attributes + new_columns,
            self.data.domain.class_vars, self.data.domain.metas
        )
        extended_data = self.data.transform(new_domain)
        self.Outputs.data.send(extended_data)

    def _get_compute_value(self, var):
        """Return the new values and a factory of compute values for `var`.

        The factory takes one of the returned values and gives the
        compute-value object for the corresponding new column.
        """
        if var.is_discrete:
            values = get_substrings(var.values, self.delimiter)
            computer = partial(
                DiscreteEncoding,
                var, self.delimiter, self.output_type != self.Counts)
        else:
            if self.output_type == self.Counts:
                sc = SplitColumnCounts(self.data, var, self.delimiter)
                computer = partial(CountStrings, sc)
            else:
                sc = SplitColumnOneHot(self.data, var, self.delimiter)
                computer = partial(OneHotStrings, sc)
            values = sc.new_values
        return values, computer

    def _get_new_columns(self, values, computer):
        """Return a tuple of new variables, one for each of `values`.

        Names are made unique with respect to the current domain. The
        variables are categorical ("No", "Yes") for the categorical output
        type and continuous otherwise.
        """
        names = get_unique_names(self.data.domain, values, equal_numbers=False)
        if self.output_type == self.Categorical:
            return tuple(
                DiscreteVariable(
                    name, ("No", "Yes"), compute_value=computer(value))
                for value, name in zip(values, names))
        return tuple(
            ContinuousVariable(
                name, compute_value=computer(value))
            for value, name in zip(values, names))
| 238 | + |
| 239 | + |
if __name__ == "__main__":  # pragma: no cover
    # Manual preview of the widget on a sample data file.
    preview_table = Table.from_file("tests/orange-in-education.tab")
    WidgetPreview(OWSplit).run(preview_table)
0 commit comments