Skip to content

Commit

Permalink
Merge pull request #88 from Iron-Stark/patch-19
Browse files Browse the repository at this point in the history
Update Scikit SVM implementation.
  • Loading branch information
rcurtin committed Jul 13, 2017
2 parents 2c61fb9 + c141a37 commit c38c4fd
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 16 deletions.
32 changes: 17 additions & 15 deletions methods/scikit/svm.py
Expand Up @@ -47,6 +47,7 @@ def __init__(self, dataset, timeout=0, verbose=True):
self.dataset = dataset
self.timeout = timeout
self.model = None
self.predictions = None
self.opts = {}

'''
Expand All @@ -58,7 +59,7 @@ def __init__(self, dataset, timeout=0, verbose=True):
'''
def BuildModel(self, data, labels):
    """Create and train a support vector classifier.

    @param data - Training feature matrix.
    @param labels - Training labels, one per row of `data`.
    @return The fitted sklearn SVC model.
    """
    # Bug fix: the diff left both the old line (`**opts`, an undefined
    # name that raised NameError) and the corrected one in place; only
    # the instance-level option dict built in RunTimer is valid here.
    svm = ssvm.SVC(**self.opts)
    svm.fit(data, labels)
    return svm

Expand Down Expand Up @@ -91,8 +92,7 @@ def RunSVMScikit(q):
if "max_iterations" in options:
self.opts["max_iter"] = int(options.pop("max_iterations"))
if "decision_function_shape" in options:
self.opts["decision_function_shape"] =
str(options.pop("decision_function_shape"))
self.opts["decision_function_shape"] = str(options.pop("decision_function_shape"))

if len(options) > 0:
Log.Fatal("Unknown parameters: " + str(options))
Expand All @@ -102,18 +102,27 @@ def RunSVMScikit(q):
with totalTimer:
self.model = self.BuildModel(trainData, labels)
# Run Support vector machines on the test dataset.
self.model.predict(testData)
self.predictions = self.model.predict(testData)
except Exception as e:
Log.Debug(str(e))
q.put(-1)
return -1

time = totalTimer.ElapsedTime()
q.put(time)
if len(self.dataset) > 1:
q.put((time, self.predictions))
else:
q.put(time)

return time

return timeout(RunSVMScikit, self.timeout)
result = timeout(RunSVMScikit, self.timeout)
# Check for error, in this case the tuple doesn't contain extra information.
if len(result) > 1:
self.predictions = result[1]
return result[0]

return result

'''
Perform the Support vector machines. If the method has been
Expand All @@ -140,21 +149,14 @@ def RunMetrics(self, options):

if len(self.dataset) >= 3:

# Check if we need to create a model.
if not self.model:
trainData, labels = SplitTrainData(self.dataset)
self.model = self.BuildModel(trainData, labels)

testData = LoadDataset(self.dataset[1])
truelabels = LoadDataset(self.dataset[2])
predictedlabels = self.model.predict(testData)

confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictedlabels)
confusionMatrix = Metrics.ConfusionMatrix(truelabels, self.predictions)
metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix)
metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix)
metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix)
metrics['Recall'] = Metrics.AvgRecall(confusionMatrix)
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictedlabels)
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, self.predictions)

return metrics

56 changes: 56 additions & 0 deletions tests/benchmark_svm.py
@@ -0,0 +1,56 @@
'''
@file benchmark_svm.py
Test for the svm scripts.
'''

import unittest

import os, sys, inspect

# Make the shared util/ directory importable. Resolving through
# inspect.currentframe() (rather than a bare relative path) keeps this
# working even when the file is reached through a symlink.
this_file = inspect.getfile(inspect.currentframe())
util_dir = os.path.join(os.path.split(this_file)[0], '../util')
cmd_subfolder = os.path.realpath(os.path.abspath(util_dir))
if cmd_subfolder not in sys.path:
  sys.path.insert(0, cmd_subfolder)

from loader import *

'''
Test the scikit svm script.
'''

class SVM_SCIKIT_TEST(unittest.TestCase):
  """Unit tests for the scikit-learn SVM benchmark wrapper."""

  def setUp(self):
    """Load the scikit SVM benchmark module and build one instance."""
    self.dataset = ['datasets/iris_train.csv','datasets/iris_test.csv','datasets/iris_labels.csv']
    self.verbose = False
    self.timeout = 9000

    svm_module = Loader.ImportModuleFromPath("methods/scikit/svm.py")
    svm_class = getattr(svm_module, "SVM")
    self.instance = svm_class(self.dataset, verbose=self.verbose, timeout=self.timeout)

  def test_Constructor(self):
    """The constructor should store the supplied configuration unchanged."""
    self.assertEqual(self.instance.verbose, self.verbose)
    self.assertEqual(self.instance.timeout, self.timeout)
    self.assertEqual(self.instance.dataset, self.dataset)

  def test_RunMetrics(self):
    """RunMetrics should report a positive runtime and quality metrics."""
    result = self.instance.RunMetrics({})
    self.assertTrue(result["Runtime"] > 0)
    self.assertTrue(result["ACC"] > 0)
    self.assertTrue(result["Precision"] > 0)
    self.assertTrue(result["Recall"] > 0)

# Allow running this test file directly, outside the tests.py suite runner.
if __name__ == '__main__':
  unittest.main()
3 changes: 2 additions & 1 deletion tests/tests.py
Expand Up @@ -35,7 +35,8 @@
'benchmark_range_search',
'benchmark_sparse_coding',
'benchmark_svr',
'benchmark_adaboost'
'benchmark_adaboost',
'benchmark_svm'
]

def load_tests(loader, tests, pattern):
Expand Down

0 comments on commit c38c4fd

Please sign in to comment.