Skip to content

Commit

Permalink
Merge pull request #9674 from cms-btv-pog/TMVAEvaluator_from-CMSSW_7_4_5
Browse files Browse the repository at this point in the history
Added TMVAEvaluator class
  • Loading branch information
cmsbuild committed Jul 2, 2015
2 parents 1a46361 + 357f848 commit a844efa
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 91 deletions.
1 change: 1 addition & 0 deletions CommonTools/Utils/BuildFile.xml
@@ -1,4 +1,5 @@
<use name="FWCore/Utilities"/>
<use name="FWCore/MessageLogger"/>
<use name="boost"/>
<use name="roothistmatrix"/>
<use name="roottmva"/>
Expand Down
33 changes: 33 additions & 0 deletions CommonTools/Utils/interface/TMVAEvaluator.h
@@ -0,0 +1,33 @@
#ifndef CommonTools_Utils_TMVAEvaluator_h
#define CommonTools_Utils_TMVAEvaluator_h

#include <memory>
#include <string>
#include <vector>
#include <map>

#include "TMVA/Reader.h"


class TMVAEvaluator {

public:
TMVAEvaluator();
~TMVAEvaluator();

void initialize(const std::string & options, const std::string & method, const std::string & weightFile,
const std::vector<std::string> & variables, const std::vector<std::string> & spectators);
float evaluate(const std::map<std::string,float> & inputs);

private:
bool mIsInitialized;

std::string mMethod;
std::unique_ptr<TMVA::Reader> mReader;

std::map<std::string,float> mVariables;
std::map<std::string,float> mSpectators;
};

#endif // CommonTools_Utils_TMVAEvaluator_h

74 changes: 74 additions & 0 deletions CommonTools/Utils/src/TMVAEvaluator.cc
@@ -0,0 +1,74 @@
#include "CommonTools/Utils/interface/TMVAEvaluator.h"

#include "CommonTools/Utils/interface/TMVAZipReader.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"


TMVAEvaluator::TMVAEvaluator() :
mIsInitialized(false)
{
}


TMVAEvaluator::~TMVAEvaluator()
{
}


void TMVAEvaluator::initialize(const std::string & options, const std::string & method, const std::string & weightFile,
const std::vector<std::string> & variables, const std::vector<std::string> & spectators)
{
// initialize the TMVA reader
mReader.reset(new TMVA::Reader(options.c_str()));
mReader->SetVerbose(false);
mMethod = method;

// add input variables
for(std::vector<std::string>::const_iterator it = variables.begin(); it!=variables.end(); ++it)
{
mVariables.insert( std::pair<std::string,float>(*it,0.) );
mReader->AddVariable(it->c_str(), &mVariables.at(*it));
}

// add spectator variables
for(std::vector<std::string>::const_iterator it = spectators.begin(); it!=spectators.end(); ++it)
{
mSpectators.insert( std::pair<std::string,float>(*it,0.) );
mReader->AddSpectator(it->c_str(), &mSpectators.at(*it));
}

// load the TMVA weights
reco::details::loadTMVAWeights(mReader.get(), mMethod.c_str(), weightFile.c_str());

mIsInitialized = true;
}


float TMVAEvaluator::evaluate(const std::map<std::string,float> & inputs)
{
if(!mIsInitialized)
{
edm::LogError("InitializationError") << "TMVAEvaluator not properly initialized.";
return -99.;
}

if( inputs.size() < mVariables.size() )
{
edm::LogError("MissingInputVariable(s)") << "Too few input variables provided (" << inputs.size() << " provided but " << mVariables.size() << " expected).";
return -99.;
}

// set the input variable values
for(std::map<std::string,float>::iterator it = mVariables.begin(); it!=mVariables.end(); ++it)
{
if (inputs.count(it->first)>0)
it->second = inputs.at(it->first);
else
edm::LogError("MissingInputVariable") << "Variable " << it->first << " is missing from the list of input variables. The returned discriminator value might not be sensible.";
}

// evaluate the MVA
float value = mReader->EvaluateMVA(mMethod.c_str());

return value;
}
Expand Up @@ -2,12 +2,12 @@
#define RecoBTag_SecondaryVertex_CandidateBoostedDoubleSecondaryVertexComputer_h

#include "FWCore/ParameterSet/interface/ParameterSet.h"
#include "CommonTools/Utils/interface/TMVAEvaluator.h"
#include "RecoBTau/JetTagComputer/interface/JetTagComputer.h"
#include "DataFormats/JetReco/interface/JetCollection.h"
#include "DataFormats/VertexReco/interface/Vertex.h"
#include "DataFormats/VertexReco/interface/VertexFwd.h"
#include "DataFormats/Candidate/interface/VertexCompositePtrCandidate.h"
#include "RecoBTag/SecondaryVertex/interface/MvaBoostedDoubleSecondaryVertexEstimator.h"
#include "RecoBTag/SecondaryVertex/interface/TrackKinematics.h"

#include "fastjet/PseudoJet.hh"
Expand Down Expand Up @@ -38,7 +38,7 @@ class CandidateBoostedDoubleSecondaryVertexComputer : public JetTagComputer {

edm::FileInPath weightFile_;
mutable std::mutex m_mutex;
[[cms::thread_guard("m_mutex")]] std::unique_ptr<MvaBoostedDoubleSecondaryVertexEstimator> mvaID;
[[cms::thread_guard("m_mutex")]] std::unique_ptr<TMVAEvaluator> mvaID;
};

#endif // RecoBTag_SecondaryVertex_CandidateBoostedDoubleSecondaryVertexComputer_h

This file was deleted.

Expand Up @@ -20,7 +20,14 @@ CandidateBoostedDoubleSecondaryVertexComputer::CandidateBoostedDoubleSecondaryVe
uses(2, "muonTagInfos");
uses(3, "elecTagInfos");

mvaID.reset(new MvaBoostedDoubleSecondaryVertexEstimator(weightFile_.fullPath()));
mvaID.reset(new TMVAEvaluator());

// variable order needs to be the same as in the training
std::vector<std::string> variables({"PFLepton_ptrel", "z_ratio1", "tau_dot", "SV_mass_0", "SV_vtx_EnergyRatio_0",
"SV_vtx_EnergyRatio_1","PFLepton_IP2D", "tau2/tau1", "nSL", "jetNTracksEtaRel"});
std::vector<std::string> spectators({"massGroomed", "flavour", "nbHadrons", "ptGroomed", "etaGroomed"});

mvaID->initialize("Color:Silent:Error", "BDTG", weightFile_.fullPath(), variables, spectators);
}


Expand All @@ -35,7 +42,7 @@ float CandidateBoostedDoubleSecondaryVertexComputer::discriminator(const TagInfo
// default discriminator value
float value = -10.;

// MvaBoostedDoubleSecondaryVertexEstimator is not thread safe
// TMVAEvaluator is not thread safe
std::lock_guard<std::mutex> lock(m_mutex);

// default variable values
Expand Down Expand Up @@ -143,8 +150,20 @@ float CandidateBoostedDoubleSecondaryVertexComputer::discriminator(const TagInfo
}
}

std::map<std::string,float> inputs;
inputs["z_ratio1"] = z_ratio;
inputs["tau_dot"] = tau_dot;
inputs["SV_mass_0"] = SV_mass_0;
inputs["SV_vtx_EnergyRatio_0"] = SV_EnergyRatio_0;
inputs["SV_vtx_EnergyRatio_1"] = SV_EnergyRatio_1;
inputs["jetNTracksEtaRel"] = vertexNTracks;
inputs["PFLepton_ptrel"] = PFLepton_ptrel;
inputs["PFLepton_IP2D"] = PFLepton_IP2D;
inputs["nSL"] = nSL;
inputs["tau2/tau1"] = tau21;

// evaluate the MVA
value = mvaID->mvaValue(PFLepton_ptrel, z_ratio, tau_dot, SV_mass_0, SV_EnergyRatio_0, SV_EnergyRatio_1, PFLepton_IP2D, tau21, nSL, vertexNTracks);
value = mvaID->evaluate(inputs);

// return the final discriminator value
return value;
Expand Down

This file was deleted.

0 comments on commit a844efa

Please sign in to comment.