Skip to content

Commit

Permalink
Merge pull request #935 from janezd/bootstrap-sample
Browse files Browse the repository at this point in the history
OWDataSampler: Add bootstrap, minor fixes in GUI
  • Loading branch information
VesnaT committed Dec 21, 2015
2 parents 71483a1 + 990f958 commit 33ad543
Showing 1 changed file with 33 additions and 12 deletions.
45 changes: 33 additions & 12 deletions Orange/widgets/data/owdatasampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class OWDataSampler(widget.OWWidget):
resizing_enabled = False

RandomSeed = 42
FixedProportion, FixedSize, CrossValidation = range(3)
FixedProportion, FixedSize, CrossValidation, Bootstrap = range(4)
SqlTime, SqlProportion = range(2)

use_seed = Setting(False)
Expand Down Expand Up @@ -69,7 +69,8 @@ def f():
gui.indentedBox(sampling), self,
"sampleSizePercentage",
minValue=0, maxValue=99, ticks=10, labelFormat="%d %%",
callback=set_sampling_type(self.FixedProportion))
callback=set_sampling_type(self.FixedProportion),
addSpace=12)

gui.appendRadioButton(sampling, "Fixed sample size:")
ibox = gui.indentedBox(sampling)
Expand All @@ -79,10 +80,9 @@ def f():
callback=set_sampling_type(self.FixedSize))
gui.checkBox(
ibox, self, "replacement", "Sample with replacement",
callback=set_sampling_type(self.FixedSize))
gui.separator(sampling, 12)
callback=set_sampling_type(self.FixedSize),
addSpace=12)

gui.separator(sampling, 12)
gui.appendRadioButton(sampling, "Cross Validation:")
form = QtGui.QFormLayout(
formAlignment=Qt.AlignLeft | Qt.AlignTop,
Expand All @@ -99,6 +99,8 @@ def f():
addToLayout=False, callback=self.fold_changed)
form.addRow("Selected fold", self.selected_fold_spin)

gui.appendRadioButton(sampling, "Boostrap")

self.sql_box = gui.widgetBox(self.controlArea, "Sampling Type")
sampling = gui.radioButtons(self.sql_box, self, "sampling_type",
callback=self.sampling_type_changed)
Expand All @@ -115,6 +117,7 @@ def f():
spin.setSuffix(" %")
self.sql_box.setVisible(False)


self.options_box = gui.widgetBox(self.controlArea, "Options")
self.cb_seed = gui.checkBox(
self.options_box, self, "use_seed",
Expand All @@ -129,7 +132,8 @@ def f():
self.cb_sql_dl.setVisible(False)

gui.button(self.controlArea, self, "Sample Data",
callback=self.commit)
callback=self.commit, addSpace=8)
self.controlArea.layout().addWidget(self.report_button)

def sampling_type_changed(self):
self.settings_changed()
Expand Down Expand Up @@ -190,12 +194,13 @@ def commit(self):
self.updateindices()
if self.indices is None:
return
if self.sampling_type in [self.FixedProportion, self.FixedSize]:
if self.sampling_type in (
self.FixedProportion, self.FixedSize, self.Bootstrap):
remaining, sample = self.indices
self.outputInfoLabel.setText(
'Outputting %d instance%s.' %
(len(sample), "s" * (len(sample) != 1)))
else:
elif self.sampling_type == self.CrossValidation:
remaining, sample = self.indices[self.selectedFold - 1]
self.outputInfoLabel.setText(
'Outputting fold %d, %d instance%s.' %
Expand All @@ -215,16 +220,20 @@ def updateindices(self):
num_classes = len(self.data.domain.class_var.values) \
if self.data.domain.has_discrete_class else 0

size = None
if self.sampling_type == self.FixedSize:
size = self.sampleSizeNumber
repl = self.replacement
elif self.sampling_type == self.FixedProportion:
size = np.ceil(self.sampleSizePercentage / 100 * data_length)
repl = False
elif data_length < self.number_of_folds:
err_msg = "Number of folds exceeds the data size"
elif self.sampling_type == self.CrossValidation:
if data_length < self.number_of_folds:
err_msg = "Number of folds exceeds the data size"
else:
assert self.sampling_type == self.Bootstrap

if not repl and (data_length <= size):
if not repl and size is not None and (data_length <= size):
err_msg = "Sample must be smaller than data"
if not repl and data_length <= num_classes and self.stratify:
err_msg = "Not enough data for stratified sampling"
Expand All @@ -241,13 +250,15 @@ def updateindices(self):
self.data.domain.has_discrete_class)
if self.sampling_type == self.FixedSize:
self.indices = sample_random_n(
self.data, self.sampleSizeNumber,
self.data, size,
stratified=stratified, replace=self.replacement,
random_state=rnd)
elif self.sampling_type == self.FixedProportion:
self.indices = sample_random_p(
self.data, self.sampleSizePercentage / 100,
stratified=stratified, random_state=rnd)
elif self.sampling_type == self.Bootstrap:
self.indices = sample_bootstrap(data_length, random_state=rnd)
else:
self.indices = sample_fold_indices(
self.data, self.number_of_folds, stratified=stratified,
Expand Down Expand Up @@ -332,6 +343,16 @@ def sample_random_p(table, p, stratified=False, random_state=None):
return sample_random_n(table, n, stratified, False, random_state)


def sample_bootstrap(size, random_state=None):
    """Draw a bootstrap sample of indices from ``range(size)``.

    ``size`` indices are drawn with replacement from ``0 .. size - 1``.

    Args:
        size: number of items to sample from; also the number of draws.
        random_state: seed passed on to ``np.random.RandomState``.

    Returns:
        A ``(remaining, sample)`` tuple of index arrays. ``sample`` holds
        the ``size`` drawn indices in ascending order (duplicates kept);
        ``remaining`` holds the indices that were never drawn.
    """
    rgen = np.random.RandomState(random_state)
    sample = rgen.randint(0, size, size)
    sample.sort()  # not needed for the code below, just for the user
    # Builtin ``bool`` replaces the ``np.bool`` alias, which was removed
    # from NumPy and would raise AttributeError on current releases.
    insample = np.ones((size,), dtype=bool)
    insample[sample] = False
    remaining = np.flatnonzero(insample)
    return remaining, sample


def test_main():
app = QtGui.QApplication([])
data = Table("iris")
Expand Down

0 comments on commit 33ad543

Please sign in to comment.