Skip to content

Commit

Permalink
Some autopep8
Browse files Browse the repository at this point in the history
  • Loading branch information
mickypaganini committed May 15, 2016
1 parent bd48d63 commit 61e650e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 18 deletions.
8 changes: 4 additions & 4 deletions bbyy_jet_classifier/utils.py
@@ -1,13 +1,13 @@
import os
import os
import logging

def ensure_directory(directory):
    """Create ``directory`` (including missing parents) unless it already exists.

    Args:
        directory: path of the directory to create.

    Raises:
        OSError: if the directory cannot be created for any reason other than
            the path already existing (e.g. permissions, empty path).
    """
    # Local import so this function stays self-contained within the module.
    import errno

    # EAFP: attempt the creation and tolerate "already exists" instead of
    # checking first, which races with other processes creating the same path.
    try:
        os.makedirs(directory)
    except OSError as error:
        if error.errno != errno.EEXIST:
            raise

def configure_logging():
    """Set up the root logger format and colourise the standard level names.

    Each level name is padded to 8 characters and wrapped in an ANSI colour
    escape sequence. Calling this function more than once is safe: a level
    whose name already starts with an escape sequence is left untouched, so
    colour codes are never nested.
    """
    logging.basicConfig(format="%(levelname)-8s\033[1m%(name)-21s\033[0m: %(message)s")

    # (level, ANSI colour prefix) pairs, in the order they are registered.
    level_colours = (
        (logging.WARNING, "\033[1;31m"),
        (logging.ERROR, "\033[1;35m"),
        (logging.INFO, "\033[1;32m"),
        (logging.DEBUG, "\033[1;34m"),
    )
    for level, colour in level_colours:
        current_name = logging.getLevelName(level)
        if current_name.startswith("\033["):
            # Already colourised by an earlier call; do not wrap it again.
            continue
        logging.addLevelName(level, "{}{:8}\033[1;0m".format(colour, current_name))
54 changes: 40 additions & 14 deletions run_classifier.py
Expand Up @@ -5,34 +5,59 @@
from bbyy_jet_classifier import strategies, process_data, utils
from bbyy_jet_classifier.plotting import plot_inputs, plot_outputs, plot_roc


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``,
            which is the behaviour of a plain ``parse_args()`` call.

    Returns:
        argparse.Namespace with the attributes ``input``, ``correct_tree``,
        ``incorrect_tree``, ``exclude``, ``ftrain``, ``train_location`` and
        ``strategy``.
    """
    parser = argparse.ArgumentParser(
        description="Run ML algorithms over ROOT TTree input")

    parser.add_argument("--input", type=str,
                        help="input file name", required=True)

    parser.add_argument("--correct_tree", metavar="NAME_OF_TREE", type=str,
                        help="name of tree containing correctly identified pairs",
                        default="correct")

    parser.add_argument("--incorrect_tree", metavar="NAME_OF_TREE", type=str,
                        help="name of tree containing incorrectly identified pairs",
                        default="incorrect")

    parser.add_argument("--exclude", type=str, metavar="VARIABLE_NAME", nargs="+",
                        help="list of variables to exclude", default=[])

    parser.add_argument("--ftrain", type=float,
                        help="fraction of events to use for training", default=0.7)

    parser.add_argument("--train_location", type=str,
                        help="directory with training info")

    parser.add_argument("--strategy", nargs='+',
                        help="strategy to use. Options are: RootTMVA, sklBDT.",
                        default=["RootTMVA"])

    return parser.parse_args(argv)


def check_args(args):
    """Check the logical consistency of the parsed command-line arguments.

    Args:
        args: object exposing ``ftrain`` (a number) and ``train_location``
            (a string or ``None``).

    Raises:
        ValueError: if ``ftrain`` is outside [0, 1]; if ``ftrain`` is 0 and no
            ``train_location`` was given; or if ``ftrain`` is greater than 0
            and a ``train_location`` was given.
    """
    if args.ftrain < 0 or args.ftrain > 1:
        raise ValueError("ftrain can only be a float between 0.0 and 1.0")

    # Messages are assembled with implicit string concatenation rather than
    # backslash continuations inside the literal, which would embed the
    # source indentation in the text shown to the user.
    if args.ftrain == 0 and args.train_location is None:
        raise ValueError(
            "Training folder required when testing on 100% of the input file "
            "to specify which classifier to load. Pass --train_location.")

    if args.ftrain > 0 and args.train_location is not None:
        raise ValueError(
            "Training location is only a valid argument when ftrain == 0, "
            "because if you are using {}% of your input data for training, "
            "you should not be testing on a separate pre-trained classifier."
            .format(100 * args.ftrain))


if __name__ == "__main__":

# -- Configure logging
utils.configure_logging()
logger = logging.getLogger("RunClassifier")
logger = logging.getLogger("RunClassifier")

# -- Parse arguments
args = parse_args()
Expand All @@ -47,8 +72,8 @@ def check_args(args):
train_location = args.train_location if args.train_location is not None else fileID

# -- Load in root files and return literally everything about the data
classification_variables, variable2type, train_data, test_data, mHmatch_test, pThigh_test = \
process_data.load(args.input, args.correct_tree, args.incorrect_tree, args.exclude, args.ftrain)
classification_variables, variable2type, train_data, test_data, mHmatch_test, pThigh_test = process_data.load(
args.input, args.correct_tree, args.incorrect_tree, args.exclude, args.ftrain)

#-- Plot input distributions
utils.ensure_directory(os.path.join(fileID, "classification_variables"))
Expand All @@ -69,7 +94,8 @@ def check_args(args):
# -- Train classifier
ML_strategy.train(train_data, classification_variables, variable2type)

# -- Plot the classifier output as tested on the training set (only useful if you care to check the performance on the training set)
# -- Plot the classifier output as tested on the training set
# -- (only useful if you care to check the performance on the training set)
yhat_train = ML_strategy.test(train_data, classification_variables, process="training", train_location=fileID)
plot_outputs.classifier_output(ML_strategy, yhat_train, train_data, process="training", fileID=fileID)

Expand Down

0 comments on commit 61e650e

Please sign in to comment.