Skip to content

Commit

Permalink
Merge pull request #1 from kandrosov/deepTauId_v2_work
Browse files Browse the repository at this point in the history
Skeleton for DeepTau v2
  • Loading branch information
MRD2F committed Apr 18, 2019
2 parents 88fc537 + 8287204 commit 188b60a
Show file tree
Hide file tree
Showing 5 changed files with 255 additions and 34 deletions.
2 changes: 2 additions & 0 deletions RecoTauTag/RecoTau/interface/DeepTauBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class DeepTauBase : public edm::stream::EDProducer<edm::GlobalCache<DeepTauCache

protected:
edm::EDGetTokenT<TauCollection> tausToken_;
edm::EDGetTokenT<pat::PackedCandidateCollection> pfcand_token_;
edm::EDGetTokenT<reco::VertexCollection> vtx_token_;
std::map<std::string, WPMap> workingPoints_;
OutputCollection outputs_;
const DeepTauCache* cache_;
Expand Down
11 changes: 3 additions & 8 deletions RecoTauTag/RecoTau/plugins/DPFIsolation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,29 +67,26 @@ class DPFIsolation : public deep_tau::DeepTauBase {

explicit DPFIsolation(const edm::ParameterSet& cfg,const deep_tau::DeepTauCache* cache) :
DeepTauBase(cfg, GetOutputs(), cache),
pfcand_token(consumes<pat::PackedCandidateCollection>(cfg.getParameter<edm::InputTag>("pfcands"))),
vtx_token(consumes<reco::VertexCollection>(cfg.getParameter<edm::InputTag>("vertices"))),
graphVersion(cfg.getParameter<unsigned>("version"))
{
const auto& shape = cache_->getGraph().node(0).attr().at("shape").shape();

if(!(graphVersion == 1 || graphVersion == 0 ))
throw cms::Exception("DPFIsolation") << "unknown version of the graph_ file.";
throw cms::Exception("DPFIsolation") << "unknown version of the graph file.";

if(!(shape.dim(1).size() == getNumberOfParticles(graphVersion) && shape.dim(2).size() == GetNumberOfFeatures(graphVersion)))
throw cms::Exception("DPFIsolation") << "number of inputs does not match the expected inputs for the given version";

}

private:
tensorflow::Tensor getPredictions(edm::Event& event, const edm::EventSetup& es,
edm::Handle<TauCollection> taus) override
{
edm::Handle<pat::PackedCandidateCollection> pfcands;
event.getByToken(pfcand_token, pfcands);
event.getByToken(pfcand_token_, pfcands);

edm::Handle<reco::VertexCollection> vertices;
event.getByToken(vtx_token, vertices);
event.getByToken(vtx_token_, vertices);

tensorflow::Tensor tensor(tensorflow::DT_FLOAT, {1,
static_cast<int>(getNumberOfParticles(graphVersion)), static_cast<int>(GetNumberOfFeatures(graphVersion))});
Expand Down Expand Up @@ -394,8 +391,6 @@ class DPFIsolation : public deep_tau::DeepTauBase {
}

private:
edm::EDGetTokenT<pat::PackedCandidateCollection> pfcand_token;
edm::EDGetTokenT<reco::VertexCollection> vtx_token;
unsigned graphVersion;
};

Expand Down
Loading

0 comments on commit 188b60a

Please sign in to comment.