Skip to content

Commit

Permalink
Added SVM Classification widget.
Browse files Browse the repository at this point in the history
  • Loading branch information
ales-erjavec committed Feb 18, 2014
1 parent 64e63c2 commit 4ca9869
Show file tree
Hide file tree
Showing 2 changed files with 221 additions and 0 deletions.
53 changes: 53 additions & 0 deletions Orange/widgets/classify/icons/SVM.svg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
168 changes: 168 additions & 0 deletions Orange/widgets/classify/owsvmclassification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# -*- coding: utf-8 -*-

from PyQt4 import QtCore, QtGui
from PyQt4.QtCore import Qt


import Orange.data
from Orange.classification import svm

from Orange.widgets import widget, settings, gui


class OWSVMClassification(widget.OWWidget):
name = "SVM Classification"
description = ""
icon = "icons/SVM.svg"

inputs = [("Data", Orange.data.Table, "set_data")]
outputs = [("Learner", svm.SVMLearner),
("Classifier", svm.SVMClassifier)]

want_main_area = False

learner_name = settings.Setting("SVM Learner")

# 0: c_svc, 1: nu_svc
svmtype = settings.Setting(0)
C = settings.Setting(1.0)
nu = settings.Setting(0.5)
# 0: Linear, 1: Poly, 2: RBF, 3: Sigmoid
kernel_type = settings.Setting(0)
degree = settings.Setting(3)
gamma = settings.Setting(0.0)
coef0 = settings.Setting(0.0)
shrinking = settings.Setting(True),
probability = settings.Setting(False)
tol = settings.Setting(0.001)

def __init__(self, parent=None):
super().__init__(parent)

self.data = None

box = gui.widgetBox(self.controlArea, self.tr("Name"))
gui.lineEdit(box, self, "learner_name")

form = QtGui.QGridLayout()
typebox = gui.radioButtonsInBox(
self.controlArea, self, "svmtype", [],
box=self.tr("SVM Type"),
orientation=form,
)

c_svm = gui.appendRadioButton(typebox, "C-SVM", addToLayout=False)
form.addWidget(c_svm, 0, 0, Qt.AlignLeft)
form.addWidget(QtGui.QLabel(self.tr("Cost (C)")), 0, 1, Qt.AlignRight)
c_spin = gui.doubleSpin(
typebox, self, "C", 0.1, 512.0, 0.1,
decimals=2, addToLayout=False
)

form.addWidget(c_spin, 0, 2)

nu_svm = gui.appendRadioButton(typebox, "谓-SVM", addToLayout=False)
form.addWidget(nu_svm, 1, 0, Qt.AlignLeft)

form.addWidget(
QtGui.QLabel(self.trUtf8("Complexity bound (\u03bd)")),
1, 1, Qt.AlignRight
)

nu_spin = gui.doubleSpin(
typebox, self, "nu", 0.05, 1.0, 0.05,
decimals=2, addToLayout=False
)
form.addWidget(nu_spin, 1, 2)

box = gui.widgetBox(self.controlArea, self.tr("Kernel"))
buttonbox = gui.radioButtonsInBox(
box, self, "kernel_type",
btnLabels=["Linear, x鈭檡",
"Polynomial, (g x鈭檡 + c)^d",
"RBF, exp(-g|x-y|虏)",
"Sigmoid, tanh(g x鈭檡 + c)"],
callback=self._on_kernel_changed
)
parambox = gui.widgetBox(box, orientation="horizontal")
gamma = gui.doubleSpin(
parambox, self, "gamma", 0.0, 10.0, 0.0001,
label=" g: ", orientation="horizontal",
alignment=Qt.AlignRight
)
coef0 = gui.doubleSpin(
parambox, self, "coef0", 0.0, 10.0, 0.0001,
label=" c: ", orientation="horizontal",
alignment=Qt.AlignRight
)
degree = gui.doubleSpin(
parambox, self, "degree", 0.0, 10.0, 0.5,
label=" d: ", orientation="horizontal",
alignment=Qt.AlignRight
)
self._kernel_params = [gamma, coef0, degree]
box = gui.widgetBox(self.controlArea, "Numerical Tolerance")
gui.doubleSpin(box, self, "tol", 1e-7, 1e-3, 5e-7)

gui.button(self.controlArea, self, "&Apply",
callback=self.apply, default=True)

self.setSizePolicy(
QtGui.QSizePolicy(QtGui.QSizePolicy.Fixed,
QtGui.QSizePolicy.Fixed)
)

self.setMinimumWidth(300)

self._on_kernel_changed()

self.apply()

def set_data(self, data):

self.data = data

if data is not None:
self.data = data

self.apply()

def apply(self):
kernel = ["linear", "poly", "rbf", "sigmoid"][self.kernel_type]
common_args = dict(
kernel=kernel,
degree=self.degree,
gamma=self.gamma,
coef0=self.coef0,
tol=self.tol,
)
if self.svmtype == 0:
learner = svm.SVMLearner(C=self.C, **common_args)
else:
learner = svm.NuSVMLearner(nu=self.nu, **common_args)

classifier = None

if self.data is not None:
classifier = learner(self.data)

self.send("Learner", learner)
self.send("Classifier", classifier)

def _on_kernel_changed(self):
enabled = [[False, False, False], # linear
[True, True, True], # poly
[True, False, False], # rbf
[True, True, False]] # sigmoid

mask = enabled[self.kernel_type]
for spin, enabled in zip(self._kernel_params, mask):
spin.setEnabled(enabled)


if __name__ == "__main__":
app = QtGui.QApplication([])
w = OWSVMClassification()
w.set_data(Orange.data.Table("iris"))
w.show()
app.exec_()

0 comments on commit 4ca9869

Please sign in to comment.