-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updating Weka implementations. #95
Changes from 2 commits
d7d050c
dfc9bed
8cee8cd
b5ab343
54c4287
5f39c92
6aaeec2
e7cc809
b6a4903
50da5db
54b27e6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
''' | ||
@file dtc.py | ||
Class to benchmark the weka Decision Stump Classifier method. | ||
''' | ||
|
||
import os | ||
import sys | ||
import inspect | ||
|
||
# Import the util path, this method even works if the path contains symlinks to | ||
# modules. | ||
cmd_subfolder = os.path.realpath(os.path.abspath(os.path.join( | ||
os.path.split(inspect.getfile(inspect.currentframe()))[0], "../../util"))) | ||
if cmd_subfolder not in sys.path: | ||
sys.path.insert(0, cmd_subfolder) | ||
|
||
#Import the metrics definitions path. | ||
metrics_folder = os.path.realpath(os.path.abspath(os.path.join( | ||
os.path.split(inspect.getfile(inspect.currentframe()))[0], "../metrics"))) | ||
if metrics_folder not in sys.path: | ||
sys.path.insert(0, metrics_folder) | ||
|
||
from log import * | ||
from profiler import * | ||
from definitions import * | ||
from misc import * | ||
|
||
import shlex | ||
import subprocess | ||
import re | ||
import collections | ||
import numpy as np | ||
|
||
''' | ||
This class implements the Decision Stump Classifier benchmark. | ||
''' | ||
class DTC(object): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably rename the class here, to avoid confusions with the |
||
|
||
''' | ||
Create the Decision Stump Classifier benchmark instance. | ||
@param dataset - Input dataset to perform DECISIONSTUMP on. | ||
@param timeout - The time until the timeout. Default no timeout. | ||
@param path - Path to the mlpack executable. | ||
@param verbose - Display informational messages. | ||
''' | ||
def __init__(self, dataset, timeout=0, path=os.environ["JAVAPATH"], | ||
verbose=True): | ||
self.verbose = verbose | ||
self.dataset = dataset | ||
self.path = path | ||
self.timeout = timeout | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should add a destructor here to clean up, e.g. remove the |
||
'''
Run the Weka Decision Stump Classifier and collect the benchmark metrics.

@param options - Extra options for the method. None are supported, so this
must be empty.
@return - Dictionary of metrics. It holds Runtime and, if a third dataset with
the true labels was given, ACC, MCC, Precision, Recall and MSE. The dictionary
is empty if the timer output could not be parsed. Returns -1 if fewer than two
datasets were given or the command could not be executed, and -2 if the
command timed out.
@raise Exception - If options is not empty.
'''
def RunMetrics(self, options):
    Log.Info("Perform DECISIONSTUMP.", self.verbose)

    if len(options) > 0:
        Log.Fatal("Unknown parameters: " + str(options))
        raise Exception("unknown parameters")

    if len(self.dataset) < 2:
        Log.Fatal("This method requires two or more datasets.")
        return -1

    # Split the command using shell-like syntax.
    cmd = shlex.split("java -classpath " + self.path + "/weka.jar" +
        ":methods/weka" + " DECISIONSTUMP -t " + self.dataset[0] + " -T " +
        self.dataset[1])

    # The constructor documents a timeout of 0 as "no timeout", but subprocess
    # treats 0 as "expire immediately"; it expects None to wait indefinitely.
    timeout = self.timeout if self.timeout else None

    # Run command with the necessary arguments and return its output as a byte
    # string. We have untrusted input so we disable all shell based features.
    try:
        s = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=False,
            timeout=timeout)
    except subprocess.TimeoutExpired as e:
        Log.Warn(str(e))
        return -2
    except Exception:
        Log.Fatal("Could not execute command: " + str(cmd))
        return -1

    # Datastructure to store the results.
    metrics = {}

    # Parse data: runtime.
    timer = self.parseTimer(s)
    if timer == -1:
        return metrics

    metrics['Runtime'] = timer.total_time

    # The classification metrics need the true labels, which are read from the
    # third dataset; without it only the runtime can be reported.
    if len(self.dataset) >= 3:
        # NOTE(review): weka_predicted.csv is presumably written into the
        # working directory by the Java program - confirm against its source.
        predictions = np.genfromtxt("weka_predicted.csv", delimiter=',')
        truelabels = np.genfromtxt(self.dataset[2], delimiter=',')

        # Use the local predictions array; the instance has no predictions
        # attribute.
        confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictions)
        metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix)
        metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix)
        metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix)
        metrics['Recall'] = Metrics.AvgRecall(confusionMatrix)
        metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions)

    Log.Info(("total time: %fs" % (metrics['Runtime'])), self.verbose)

    return metrics
|
||
'''
Extract the total runtime from the captured output of the benchmark command.

@param data - Byte string to parse the timer data from.
@return - Namedtuple with a single total_time field, or -1 if the output does
not contain a timer entry.
'''
def parseTimer(self, data):
    # The expression is anchored at the start of the output; the lazy prefix
    # together with DOTALL lets it skip any lines before the timer entry.
    timerExpression = re.compile(r"""
        .*?total_time: (?P<total_time>.*?)s.*?
        """, re.VERBOSE|re.MULTILINE|re.DOTALL)

    found = timerExpression.match(data.decode())
    if found is None:
        Log.Fatal("Can't parse the data: wrong format")
        return -1

    # A value without exactly one dot is taken to use a decimal comma.
    rawValue = found.group("total_time")
    if rawValue.count(".") != 1:
        rawValue = rawValue.replace(",", ".")

    # Wrap the parsed value in a namedtuple and return the timer data.
    timer = collections.namedtuple("timer", ["total_time"])
    return timer(float(rawValue))
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
''' | ||
@file dtc.py | ||
Class to benchmark the weka Decision Tree Classifier method. | ||
''' | ||
|
||
import os | ||
import sys | ||
import inspect | ||
|
||
# Import the util path, this method even works if the path contains symlinks to | ||
# modules. | ||
cmd_subfolder = os.path.realpath(os.path.abspath(os.path.join( | ||
os.path.split(inspect.getfile(inspect.currentframe()))[0], "../../util"))) | ||
if cmd_subfolder not in sys.path: | ||
sys.path.insert(0, cmd_subfolder) | ||
|
||
#Import the metrics definitions path. | ||
metrics_folder = os.path.realpath(os.path.abspath(os.path.join( | ||
os.path.split(inspect.getfile(inspect.currentframe()))[0], "../metrics"))) | ||
if metrics_folder not in sys.path: | ||
sys.path.insert(0, metrics_folder) | ||
|
||
from log import * | ||
from profiler import * | ||
from definitions import * | ||
from misc import * | ||
|
||
import shlex | ||
import subprocess | ||
import re | ||
import collections | ||
import numpy as np | ||
|
||
''' | ||
This class implements the Decision Tree Classifier benchmark. | ||
''' | ||
class DTC(object): | ||
|
||
''' | ||
Create the Decision Tree Classifier benchmark instance. | ||
@param dataset - Input dataset to perform DTC on. | ||
@param timeout - The time until the timeout. Default no timeout. | ||
@param path - Path to the mlpack executable. | ||
@param verbose - Display informational messages. | ||
''' | ||
def __init__(self, dataset, timeout=0, path=os.environ["JAVAPATH"], | ||
verbose=True): | ||
self.verbose = verbose | ||
self.dataset = dataset | ||
self.path = path | ||
self.timeout = timeout | ||
|
||
''' | ||
Decision Tree Classifier. If the method has been successfully completed return | ||
a dictionary with the runtime in seconds and the classification metrics. | ||
@param options - Extra options for the method. | ||
@return - Dictionary of metrics, -1 if the method was not successful or -2 if | ||
the command timed out. | ||
''' | ||
def RunMetrics(self, options): | ||
Log.Info("Perform DTC.", self.verbose) | ||
|
||
if len(options) > 0: | ||
Log.Fatal("Unknown parameters: " + str(options)) | ||
raise Exception("unknown parameters") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should at least add a minimum leaf size parameter here. I realize that the other decision tree benchmark implementations don't support that, but they should, so we can at least start with this one. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can I please know how do I use the options in the weka code. I see -M number is available to specify but where does this fit in the code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In case of J48 cModel = new J48();
cModel.setOptions(weka.core.Utils.splitOptions("-M 2")); You can also add more options e.g.: |
||
|
||
if len(self.dataset) < 2: | ||
Log.Fatal("This method requires two or more datasets.") | ||
return -1 | ||
|
||
# Split the command using shell-like syntax. | ||
cmd = shlex.split("java -classpath " + self.path + "/weka.jar" + | ||
":methods/weka" + " DTC -t " + self.dataset[0] + " -T " + | ||
self.dataset[1]) | ||
|
||
# Run command with the necessary arguments and return its output as a byte | ||
# string. We have untrusted input so we disable all shell based features. | ||
try: | ||
s = subprocess.check_output(cmd, stderr=subprocess.STDOUT, shell=False, | ||
timeout=self.timeout) | ||
except subprocess.TimeoutExpired as e: | ||
Log.Warn(str(e)) | ||
return -2 | ||
except Exception as e: | ||
Log.Fatal("Could not execute command: " + str(cmd)) | ||
return -1 | ||
|
||
# Datastructure to store the results. | ||
metrics = {} | ||
|
||
# Parse data: runtime. | ||
timer = self.parseTimer(s) | ||
|
||
if timer != -1: | ||
predictions = np.genfromtxt("weka_predicted.csv", delimiter=',') | ||
truelabels = np.genfromtxt(self.dataset[2], delimiter = ',') | ||
|
||
metrics['Runtime'] = timer.total_time | ||
confusionMatrix = Metrics.ConfusionMatrix(truelabels, predictions) | ||
metrics['ACC'] = Metrics.AverageAccuracy(confusionMatrix) | ||
metrics['MCC'] = Metrics.MCCMultiClass(confusionMatrix) | ||
metrics['Precision'] = Metrics.AvgPrecision(confusionMatrix) | ||
metrics['Recall'] = Metrics.AvgRecall(confusionMatrix) | ||
metrics['MSE'] = Metrics.SimpleMeanSquaredError(truelabels, predictions) | ||
Log.Info(("total time: %fs" % (metrics['Runtime'])), self.verbose) | ||
|
||
return metrics | ||
|
||
'''
Extract the total runtime from the captured output of the benchmark command.

@param data - Byte string to parse the timer data from.
@return - Namedtuple with a single total_time field, or -1 if the output does
not contain a timer entry.
'''
def parseTimer(self, data):
    # The expression is anchored at the start of the output; the lazy prefix
    # together with DOTALL lets it skip any lines before the timer entry.
    timerExpression = re.compile(r"""
        .*?total_time: (?P<total_time>.*?)s.*?
        """, re.VERBOSE|re.MULTILINE|re.DOTALL)

    found = timerExpression.match(data.decode())
    if found is None:
        Log.Fatal("Can't parse the data: wrong format")
        return -1

    # A value without exactly one dot is taken to use a decimal comma.
    rawValue = found.group("total_time")
    if rawValue.count(".") != 1:
        rawValue = rawValue.replace(",", ".")

    # Wrap the parsed value in a namedtuple and return the timer data.
    timer = collections.namedtuple("timer", ["total_time"])
    return timer(float(rawValue))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this called
LogisticRegression
in the other blocks?