Restructure output directories to test using trainings from SM sample
mickypaganini committed May 13, 2016
1 parent 0dba521 commit 344f306
Showing 4 changed files with 30 additions and 24 deletions.
10 changes: 5 additions & 5 deletions bbyy_jet_classifier/plotting/plot_roc.py
@@ -53,15 +53,15 @@ def signal_eff_bkg_rejection(ML_strategy, mHmatch_test, pThigh_test, yhat_test,
cPickle.dump(discrim_dict[ML_strategy.name], open(os.path.join(ML_strategy.output_directory, "pickle", "{}_ROC.pkl".format(ML_strategy.name)), "wb"), cPickle.HIGHEST_PROTOCOL)


def roc_comparison():
def roc_comparison(fileID):
"""
Definition:
------------
Quick script to load and compare ROC curves produced from different classifiers
"""
TMVABDT = cPickle.load(open(os.path.join("output", "RootTMVA", "pickle", "root_tmva_ROC.pkl"), "rb"))
sklBDT = cPickle.load(open(os.path.join("output", "sklBDT", "pickle", "skl_BDT_ROC.pkl"), "rb"))
dots = cPickle.load(open(os.path.join("output", "sklBDT", "pickle", "old_strategies_dict.pkl"), "rb"))
TMVABDT = cPickle.load(open(os.path.join(fileID, "RootTMVA", "pickle", "root_tmva_ROC.pkl"), "rb"))
sklBDT = cPickle.load(open(os.path.join(fileID, "sklBDT", "pickle", "skl_BDT_ROC.pkl"), "rb"))
dots = cPickle.load(open(os.path.join(fileID, "sklBDT", "pickle", "old_strategies_dict.pkl"), "rb"))

sklBDT["color"] = "green"
curves = {"sklBDT": sklBDT, "RootTMVA": TMVABDT}
@@ -74,4 +74,4 @@ def roc_comparison():
plt.plot(dots["eff_pT_signal"], 1.0 / dots["eff_pT_bkg"], marker="o", color="b", label=r"Highest p$_{T}$", linewidth=0) # add point for "pThigh" strategy
plt.legend()
plot_atlas.use_atlas_labels(plt.axes())
figure.savefig(os.path.join("output", "ROCcomparison.pdf"))
figure.savefig(os.path.join(fileID, "ROCcomparison.pdf"))
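
Since roc_comparison() now reads its pickled ROC curves from a caller-supplied directory instead of the hard-coded "output" folder, it can compare classifiers trained on any sample. A minimal usage sketch, assuming the pickles already exist under a per-sample directory; the directory name "SM_merged" is only an example, taken from the new --train_location default:

from bbyy_jet_classifier.plotting import plot_roc

# fileID names the per-sample output directory populated by earlier training/testing runs
fileID = "SM_merged"
# expects <fileID>/RootTMVA/pickle/root_tmva_ROC.pkl, <fileID>/sklBDT/pickle/skl_BDT_ROC.pkl
# and <fileID>/sklBDT/pickle/old_strategies_dict.pkl to exist
plot_roc.roc_comparison(fileID)  # writes <fileID>/ROCcomparison.pdf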
4 changes: 2 additions & 2 deletions bbyy_jet_classifier/strategies/root_tmva.py
@@ -56,7 +56,7 @@ def train(self, train_data, classification_variables, variable_dict):
shutil.rmtree(os.path.join(self.output_directory, "weights"))
shutil.move("weights", self.output_directory)

def test(self, data, classification_variables, process):
def test(self, data, classification_variables, process, train_location):
"""
Definition:
-----------
@@ -83,7 +83,7 @@ def test(self, data, classification_variables, process):
reader.AddVariable(v_name, array.array("f", [0]))

# -- Load TMVA results
reader.BookMVA("BDT", os.path.join(self.output_directory, "weights", "TMVAClassification_BDT.weights.xml"))
reader.BookMVA("BDT", os.path.join(train_location, self.default_output_subdir, "weights", "TMVAClassification_BDT.weights.xml"))

yhat = evaluate_reader(reader, "BDT", data['X'])
return yhat
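
The TMVA reader now books its BDT weights from the directory of whichever sample was used for training, rather than from the current run's output directory. A short sketch of the resulting path, assuming self.default_output_subdir resolves to "RootTMVA" (the attribute name comes from the diff; its value here is inferred from the directory names used in plot_roc.py):

import os

train_location = "SM_merged"         # fileID of the sample the BDT was trained on (example value)
default_output_subdir = "RootTMVA"   # assumed value of self.default_output_subdir
weights_file = os.path.join(train_location, default_output_subdir,
                            "weights", "TMVAClassification_BDT.weights.xml")
# -> "SM_merged/RootTMVA/weights/TMVAClassification_BDT.weights.xml"
# reader.BookMVA("BDT", weights_file) then loads a training produced by a different run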
8 changes: 3 additions & 5 deletions bbyy_jet_classifier/strategies/skl_BDT.py
@@ -37,7 +37,7 @@ def train(self, train_data, classification_variables, variable_dict):
self.ensure_directory(os.path.join(self.output_directory, "pickle"))
joblib.dump(classifier, os.path.join(self.output_directory, "pickle", "sklBDT_clf.pkl"), protocol=cPickle.HIGHEST_PROTOCOL)

def test(self, data, classification_variables, process):
def test(self, data, classification_variables, process, train_location):
"""
Definition:
-----------
@@ -51,6 +51,7 @@ def test(self, data, classification_variables, process):
w = array of dim (# examples) with event weights
process = string to identify whether we are evaluating performance on the train or test set, usually "training" or "testing"
classification_variables = list of names of variables used for classification
train_location = string that specifies the fileID of the sample to use as a training (e.g. 'SM_merged' or 'X350_hh')
Returns:
--------
@@ -59,14 +60,11 @@ def test(self, data, classification_variables, process):
logging.getLogger("sklBDT.test").info("Evaluating performance...")

# -- Load scikit classifier
classifier = joblib.load(os.path.join(self.output_directory, 'pickle', 'sklBDT_clf.pkl'))
classifier = joblib.load(os.path.join(train_location, self.default_output_subdir, 'pickle', 'sklBDT_clf.pkl'))

# -- Get classifier predictions
yhat = classifier.predict_proba(data['X'])[:, 1]

# -- Load scikit classifier
classifier = joblib.load(os.path.join(self.output_directory, "pickle", "sklBDT_clf.pkl"))

# -- Log classification scores
logging.getLogger("sklBDT.test").info("{} accuracy = {:.2f}%".format(process, 100 * classifier.score(data['X'], data['y'], sample_weight=data['w'])))
for output_line in classification_report(data['y'], classifier.predict(data['X']), target_names=["correct", "incorrect"], sample_weight=data['w']).splitlines():
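
The scikit-learn strategy follows the same pattern: test() now restores the persisted classifier from the training sample's directory, so a BDT trained on the SM sample can be evaluated on any other input. A sketch of that decoupled load, with the joblib import and directory names shown as assumptions (the repository's actual import is not visible in this hunk):

import os
from sklearn.externals import joblib  # 2016-era scikit-learn; plain "import joblib" also works

train_location = "SM_merged"  # sample whose training output we reuse (example value)
classifier = joblib.load(os.path.join(train_location, "sklBDT", "pickle", "sklBDT_clf.pkl"))
# yhat = classifier.predict_proba(data['X'])[:, 1]  # scores for events from a different sample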
32 changes: 20 additions & 12 deletions run_classifier.py
@@ -5,24 +5,32 @@
from bbyy_jet_classifier import strategies, process_data
from bbyy_jet_classifier.plotting import plot_inputs, plot_outputs, plot_roc

if __name__ == "__main__":
# -- Configure logging
def configure_logging():
logging.basicConfig(format="%(levelname)-8s\033[1m%(name)-21s\033[0m: %(message)s")
logging.addLevelName(logging.WARNING, "\033[1;31m{:8}\033[1;0m".format(logging.getLevelName(logging.WARNING)))
logging.addLevelName(logging.ERROR, "\033[1;35m{:8}\033[1;0m".format(logging.getLevelName(logging.ERROR)))
logging.addLevelName(logging.INFO, "\033[1;32m{:8}\033[1;0m".format(logging.getLevelName(logging.INFO)))
logging.addLevelName(logging.DEBUG, "\033[1;34m{:8}\033[1;0m".format(logging.getLevelName(logging.DEBUG)))

# -- Parse arguments
def parse_args():
parser = argparse.ArgumentParser(description="Run ML algorithms over ROOT TTree input")
parser.add_argument("--input", type=str, help="input file name", required=True)
parser.add_argument("--output", type=str, help="output directory", default="output")
#parser.add_argument("--output", type=str, help="output directory", default="output")
parser.add_argument("--train_location", type=str, help="directory with training info", default="SM_merged")
parser.add_argument("--correct_tree", metavar="NAME_OF_TREE", type=str, help="name of tree containing correctly identified pairs", default="correct")
parser.add_argument("--incorrect_tree", metavar="NAME_OF_TREE", type=str, help="name of tree containing incorrectly identified pairs", default="incorrect")
parser.add_argument("--exclude", type=str, metavar="VARIABLE_NAME", nargs="+", help="list of variables to exclude", default=[])
parser.add_argument("--ftrain", type=float, help="fraction of events to use for training", default=0.7)
parser.add_argument("--strategy", nargs='+', help="strategy to use. Options are: RootTMVA, sklBDT.", default="RootTMVA")
args = parser.parse_args()
return args

if __name__ == "__main__":

# -- Configure logging
configure_logging()
# -- Parse arguments
args = parse_args()

# -- Check that input file exists
if not os.path.isfile(args.input):
@@ -34,16 +42,16 @@
process_data.load(args.input, args.correct_tree, args.incorrect_tree, args.exclude, args.ftrain)

#-- Plot input distributions
strategies.BaseStrategy.ensure_directory(os.path.join(args.output, "classification_variables"))
plot_inputs.input_distributions(classification_variables, train_data, test_data, directory=os.path.join(args.output, "classification_variables"))
strategies.BaseStrategy.ensure_directory(os.path.join(fileID, "classification_variables"))
plot_inputs.input_distributions(classification_variables, train_data, test_data, directory=os.path.join(fileID, "classification_variables"))

# -- Sequentially evaluate all the desired strategies on the same train/test sample
for strategy_name in args.strategy:

# -- Construct dictionary of available strategies
if not strategy_name in strategies.__dict__.keys():
raise AttributeError("{} is not a valid strategy".format(args.strategy))
ML_strategy = getattr(strategies, strategy_name)(args.output)
ML_strategy = getattr(strategies, strategy_name)(fileID)

# -- Training!
if args.ftrain > 0:
@@ -53,7 +61,7 @@
ML_strategy.train(train_data, classification_variables, variable_dict)

# -- Plot the classifier output as tested on the training set (only useful if you care to check the performance on the training set)
yhat_train = ML_strategy.test(train_data, classification_variables, process="training")
yhat_train = ML_strategy.test(train_data, classification_variables, process="training", train_location=args.train_location)
plot_outputs.classifier_output(ML_strategy, yhat_train, train_data, process="training", fileID=fileID)

else:
@@ -62,7 +70,7 @@
# -- Testing!
if args.ftrain < 1:
# -- Test classifier
yhat_test = ML_strategy.test(test_data, classification_variables, process="testing")
yhat_test = ML_strategy.test(test_data, classification_variables, process="testing", train_location=args.train_location)

# -- Plot output testing distributions from classifier and old strategies
plot_outputs.classifier_output(ML_strategy, yhat_test, test_data, process="testing", fileID=args.input.replace(".root", "").split("/")[-1])
@@ -76,6 +84,6 @@
else:
logging.getLogger("RunClassifier").info("100% of the sample was used for training -- no independent testing can be performed.")

# -- if there is more than one strategy, plot the ROC comparison
if len(args.strategy) > 1:
plot_roc.roc_comparison()
# -- if there is more than one strategy and we aren't only training, plot the ROC comparison
if (len(args.strategy) > 1 and (args.ftrain < 1)):
plot_roc.roc_comparison(fileID)
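
The fileID used throughout the new main block is assigned in collapsed context, but judging from the expression passed to plot_outputs above it is presumably derived from the input file name, while --train_location points at the directory holding the previously trained models. A sketch under that assumption:

import os

input_path = "data/X350_hh.root"                          # illustrative --input value
fileID = input_path.replace(".root", "").split("/")[-1]   # -> "X350_hh", as in the plot_outputs call above
train_location = "SM_merged"                              # --train_location default
# new outputs land under the per-sample directory ...
print(os.path.join(fileID, "ROCcomparison.pdf"))          # X350_hh/ROCcomparison.pdf
# ... while trained classifiers are read back from the training sample's directory
print(os.path.join(train_location, "sklBDT", "pickle", "sklBDT_clf.pkl"))

In other words, run the script once on the SM sample to produce the trainings, then point --train_location at that per-sample directory when classifying other inputs, which is what the commit message describes.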
