Skip to content

Commit

Permalink
Add function to refine FastSim DeepJet discriminators
Browse files Browse the repository at this point in the history
  • Loading branch information
wolfmor committed Jan 17, 2023
1 parent a5ca8f8 commit 4ca1787
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 3 deletions.
60 changes: 60 additions & 0 deletions PhysicsTools/NanoAOD/python/jetsAK4_CHS_cff.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,66 @@ def nanoAOD_addDeepInfoAK4CHS(process,addDeepBTag,addDeepFlavour):
## DeepInfoAK4CHS:End
#################################################

#
# ML-based FastSim refinement
#
from Configuration.Eras.Modifier_fastSim_cff import fastSim
def nanoAOD_refineFastSim_bTagDeepFlav(process):

fastSim.toModify( process.jetTable.variables,
btagDeepFlavBunrefined = process.jetTable.variables.btagDeepFlavB.clone(),
btagDeepFlavCvBunrefined = process.jetTable.variables.btagDeepFlavCvB.clone(),
btagDeepFlavCvLunrefined = process.jetTable.variables.btagDeepFlavCvL.clone(),
btagDeepFlavQGunrefined = process.jetTable.variables.btagDeepFlavQG.clone(),
)

fastSim.toModify( process.jetTable.variables,
btagDeepFlavB = None,
btagDeepFlavCvB = None,
btagDeepFlavCvL = None,
btagDeepFlavQG = None,
)

fastSim.toModify( process.jetTable.externalVariables,
btagDeepFlavB = ExtVar(cms.InputTag("btagDeepFlavRefineNN:btagDeepFlavBrefined"), float, doc="DeepJet b+bb+lepb tag discriminator", precision=10),
btagDeepFlavCvB = ExtVar(cms.InputTag("btagDeepFlavRefineNN:btagDeepFlavCvBrefined"), float, doc="DeepJet c vs b+bb+lepb discriminator", precision=10),
btagDeepFlavCvL = ExtVar(cms.InputTag("btagDeepFlavRefineNN:btagDeepFlavCvLrefined"), float, doc="DeepJet c vs uds+g discriminator", precision=10),
btagDeepFlavQG = ExtVar(cms.InputTag("btagDeepFlavRefineNN:btagDeepFlavQGrefined"), float, doc="DeepJet g vs uds discriminator", precision=10),
)

process.btagDeepFlavRefineNN= cms.EDProducer("JetBaseMVAValueMapProducer",
backend = cms.string("ONNX"),
batch_eval = cms.bool(True),
disableONNXGraphOpt = cms.bool(True),

src = cms.InputTag("linkedObjects","jets"),

weightFile=cms.FileInPath("PhysicsTools/NanoAOD/data/btagDeepFlavRefineNN_CHS.onnx"),
name = cms.string("btagDeepFlavRefineNN"),

isClassifier = cms.bool(False),
variablesOrder = cms.vstring(["GenJet_pt","GenJet_eta","Jet_hadronFlavour",
"Jet_btagDeepFlavB","Jet_btagDeepFlavCvB","Jet_btagDeepFlavCvL","Jet_btagDeepFlavQG"]),
variables = cms.PSet(
GenJet_pt = cms.string("?genJetFwdRef().backRef().isNonnull()?genJetFwdRef().backRef().pt():pt"),
GenJet_eta = cms.string("?genJetFwdRef().backRef().isNonnull()?genJetFwdRef().backRef().eta():eta"),
Jet_hadronFlavour = cms.string("hadronFlavour()"),
Jet_btagDeepFlavB = cms.string("bDiscriminator('pfDeepFlavourJetTags:probb')+bDiscriminator('pfDeepFlavourJetTags:probbb')+bDiscriminator('pfDeepFlavourJetTags:problepb')"),
Jet_btagDeepFlavCvB = cms.string("?(bDiscriminator('pfDeepFlavourJetTags:probc')+bDiscriminator('pfDeepFlavourJetTags:probb')+bDiscriminator('pfDeepFlavourJetTags:probbb')+bDiscriminator('pfDeepFlavourJetTags:problepb'))>0?bDiscriminator('pfDeepFlavourJetTags:probc')/(bDiscriminator('pfDeepFlavourJetTags:probc')+bDiscriminator('pfDeepFlavourJetTags:probb')+bDiscriminator('pfDeepFlavourJetTags:probbb')+bDiscriminator('pfDeepFlavourJetTags:problepb')):-1"),
Jet_btagDeepFlavCvL = cms.string("?(bDiscriminator('pfDeepFlavourJetTags:probc')+bDiscriminator('pfDeepFlavourJetTags:probuds')+bDiscriminator('pfDeepFlavourJetTags:probg'))>0?bDiscriminator('pfDeepFlavourJetTags:probc')/(bDiscriminator('pfDeepFlavourJetTags:probc')+bDiscriminator('pfDeepFlavourJetTags:probuds')+bDiscriminator('pfDeepFlavourJetTags:probg')):-1"),
Jet_btagDeepFlavQG = cms.string("?(bDiscriminator('pfDeepFlavourJetTags:probg')+bDiscriminator('pfDeepFlavourJetTags:probuds'))>0?bDiscriminator('pfDeepFlavourJetTags:probg')/(bDiscriminator('pfDeepFlavourJetTags:probg')+bDiscriminator('pfDeepFlavourJetTags:probuds')):-1"),
),
inputTensorName = cms.string("input"),
outputTensorName = cms.string("output"),
outputNames = cms.vstring(["btagDeepFlavBrefined","btagDeepFlavCvBrefined","btagDeepFlavCvLrefined","btagDeepFlavQGrefined"]),
outputFormulas = cms.vstring(["at(0)","at(1)","at(2)","at(3)"]),
)

fastSim.toModify(process.jetTablesTask, process.jetTablesTask.add(process.btagDeepFlavRefineNN))

return process


################################################################################
# JETS FOR MET type1
################################################################################
Expand Down
15 changes: 12 additions & 3 deletions PhysicsTools/PatAlgos/interface/BaseMVAValueMapProducer.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,19 @@

class BaseMVACache {
public:
BaseMVACache(const std::string& model_path, const std::string& backend) {
BaseMVACache(const std::string& model_path, const std::string& backend, const bool disableONNXGraphOpt) {
if (backend == "TF") {
graph_.reset(tensorflow::loadGraphDef(model_path));
tf_session_ = tensorflow::createSession(graph_.get());
} else if (backend == "ONNX") {
ort_ = std::make_unique<cms::Ort::ONNXRuntime>(model_path);
if (disableONNXGraphOpt) {
Ort::SessionOptions sess_opts;
sess_opts = cms::Ort::ONNXRuntime::defaultSessionOptions();
sess_opts.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
ort_ = std::make_unique<cms::Ort::ONNXRuntime>(model_path, &sess_opts);
} else {
ort_ = std::make_unique<cms::Ort::ONNXRuntime>(model_path);
}
}
}
~BaseMVACache() { tensorflow::closeSession(tf_session_); }
Expand Down Expand Up @@ -270,7 +277,8 @@ void BaseMVAValueMapProducer<T>::produce(edm::Event& iEvent, const edm::EventSet
template <typename T>
std::unique_ptr<BaseMVACache> BaseMVAValueMapProducer<T>::initializeGlobalCache(const edm::ParameterSet& cfg) {
return std::make_unique<BaseMVACache>(cfg.getParameter<edm::FileInPath>("weightFile").fullPath(),
cfg.getParameter<std::string>("backend"));
cfg.getParameter<std::string>("backend"),
cfg.getParameter<bool>("disableONNXGraphOpt"));
}

template <typename T>
Expand All @@ -295,6 +303,7 @@ edm::ParameterSetDescription BaseMVAValueMapProducer<T>::getDescription() {
desc.add<std::vector<std::string>>("outputFormulas", std::vector<std::string>())
->setComment("Formulas to be used to post process the output");
desc.add<bool>("batch_eval", false)->setComment("Run inference in batch instead of per-object");
desc.add<bool>("disableONNXGraphOpt", false)->setComment("Disable ONNX runtime graph optimization");

return desc;
}
Expand Down

0 comments on commit 4ca1787

Please sign in to comment.