Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
371 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
name: "EGM_DRN" | ||
platform: "pytorch_libtorch" | ||
max_batch_size: 0 | ||
input [ | ||
{ | ||
name: "x__0" | ||
data_type: TYPE_FP32 | ||
dims: [-1, 4] | ||
}, | ||
{ | ||
name: "batch__1" | ||
data_type: TYPE_INT64 | ||
dims: [-1] | ||
}, | ||
{ | ||
name: "graphx__2" | ||
data_type: TYPE_FP32 | ||
dims: [-1] | ||
} | ||
] | ||
output [ | ||
{ | ||
name: "dscb__0" | ||
data_type: TYPE_FP32 | ||
dims: [-1, 6] | ||
}] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
<use name="FWCore/Framework"/> | ||
<use name="FWCore/ParameterSet"/> | ||
<use name="FWCore/PluginManager"/> | ||
<use name="HeterogeneousCore/SonicTriton"/> | ||
<flags EDM_PLUGIN="1"/> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
#include "HeterogeneousCore/SonicTriton/interface/TritonEDProducer.h" | ||
|
||
#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h" | ||
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h" | ||
#include "FWCore/Framework/interface/MakerMacros.h" | ||
|
||
#include <sstream> | ||
#include <string> | ||
#include <vector> | ||
#include <random> | ||
|
||
class TritonGraphHelper { | ||
public: | ||
TritonGraphHelper(edm::ParameterSet const& cfg) | ||
: nodeMin_(cfg.getParameter<unsigned>("nodeMin")), | ||
nodeMax_(cfg.getParameter<unsigned>("nodeMax")), | ||
brief_(cfg.getParameter<bool>("brief")) {} | ||
void makeInput(edm::Event const& iEvent, TritonInputMap& iInput, const std::string& debugName) const { | ||
|
||
//get event-based seed for RNG | ||
unsigned int runNum_uint = static_cast<unsigned int>(iEvent.id().run()); | ||
unsigned int lumiNum_uint = static_cast<unsigned int>(iEvent.id().luminosityBlock()); | ||
unsigned int evNum_uint = static_cast<unsigned int>(iEvent.id().event()); | ||
std::uint32_t seed = (lumiNum_uint << 10) + (runNum_uint << 20) + evNum_uint; | ||
std::mt19937 rng(seed); | ||
|
||
std::uniform_int_distribution<int> randint1(nodeMin_, nodeMax_); | ||
int nnodes = randint1(rng); | ||
|
||
//set shapes | ||
auto& input1 = iInput.at("x__0"); | ||
input1.setShape(0, nnodes); | ||
auto data1 = input1.allocate<float>(); | ||
auto& vdata1 = (*data1)[0]; | ||
|
||
auto& input2 = iInput.at("batch__1"); | ||
input2.setShape(0, nnodes); | ||
auto data2 = input2.allocate<int64_t>(); | ||
auto& vdata2 = (*data2)[0]; | ||
|
||
auto& input3 = iInput.at("graphx__2"); | ||
input3.setShape(0, 2); | ||
auto data3 = input3.allocate<float>(); | ||
auto& vdata3 = (*data3)[0]; | ||
|
||
//fill | ||
std::normal_distribution<float> randx(-10, 4); | ||
for (unsigned i = 0; i < input1.sizeShape(); ++i) { | ||
vdata1.push_back(randx(rng)); | ||
} | ||
|
||
for (unsigned i = 0; i < input2.sizeShape(); ++i) { | ||
vdata2.push_back(0); | ||
} | ||
|
||
for (unsigned i = 0; i < input3.sizeShape(); ++i) { | ||
vdata3.push_back(randx(rng)); | ||
} | ||
|
||
// convert to server format | ||
input1.toServer(data1); | ||
input2.toServer(data2); | ||
input3.toServer(data3); | ||
|
||
edm::LogInfo(debugName) << "input X shape: " << input1.shape()[0] << ", " << input1.shape()[1]; | ||
edm::LogInfo(debugName) << "input batch shape: " << input2.shape()[0]; | ||
edm::LogInfo(debugName) << "input graphx shape: " << input3.shape()[0]; | ||
} | ||
void makeOutput(const TritonOutputMap& iOutput, const std::string& debugName) const { | ||
edm::LogInfo (debugName) << "top of output "; | ||
//check the results | ||
const auto& output1 = iOutput.begin()->second; | ||
// convert from server format | ||
const auto& tmp = output1.fromServer<float>(); | ||
//if (brief_) | ||
edm::LogInfo(debugName) << "output shape: " << output1.shape()[0] << "," << output1.shape()[1]; | ||
//else { | ||
// edm::LogInfo msg(debugName); | ||
// for (int i = 0; i < output1.shape()[0]; ++i) { | ||
// msg << "output " << i << ": "; | ||
// for (int j = 0; j < output1.shape()[1]; ++j) { | ||
// msg << tmp[0][output1.shape()[1] * i + j] << " "; | ||
// } | ||
// msg << "\n"; | ||
// } | ||
//} | ||
} | ||
static void fillPSetDescription(edm::ParameterSetDescription& desc) { | ||
desc.add<unsigned>("nodeMin", 1); | ||
desc.add<unsigned>("nodeMax", 200); | ||
desc.add<bool>("brief", false); | ||
} | ||
|
||
private: | ||
//members | ||
unsigned nodeMin_, nodeMax_; | ||
bool brief_; | ||
}; | ||
|
||
class DRNProducer : public TritonEDProducer<> { | ||
public: | ||
explicit DRNProducer(edm::ParameterSet const& cfg) | ||
: TritonEDProducer<>(cfg, "DRNProducer"), helper_(cfg) {} | ||
void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override { | ||
helper_.makeInput(iEvent, iInput, debugName_); | ||
} | ||
void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override { | ||
helper_.makeOutput(iOutput, debugName_); | ||
} | ||
~DRNProducer() override = default; | ||
|
||
static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) { | ||
edm::ParameterSetDescription desc; | ||
TritonClient::fillPSetDescription(desc); | ||
TritonGraphHelper::fillPSetDescription(desc); | ||
//to ensure distinct cfi names | ||
descriptions.addWithDefaultLabel(desc); | ||
} | ||
|
||
private: | ||
//member | ||
TritonGraphHelper helper_; | ||
}; | ||
|
||
DEFINE_FWK_MODULE(DRNProducer); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
<bin file="test_catch2_*.cc" name="testProgressionEGM_DRNTP"> | ||
<use name="FWCore/TestProcessor"/> | ||
<use name="catch2"/> | ||
</bin> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
#include "catch.hpp" | ||
#include "FWCore/TestProcessor/interface/TestProcessor.h" | ||
#include "FWCore/Utilities/interface/Exception.h" | ||
|
||
static constexpr auto s_tag = "[EGM_DRN]"; | ||
|
||
TEST_CASE("Standard checks of EGM_DRN", s_tag) { | ||
const std::string baseConfig{ | ||
R"_(from FWCore.TestProcessor.TestProcess import * | ||
process = TestProcess() | ||
process.toTest = cms.EDProducer("EGM_DRN" | ||
#necessary configuration parameters | ||
) | ||
process.moduleToTest(process.toTest) | ||
)_"}; | ||
|
||
edm::test::TestProcessor::Config config{baseConfig}; | ||
SECTION("base configuration is OK") { REQUIRE_NOTHROW(edm::test::TestProcessor(config)); } | ||
|
||
SECTION("No event data") { | ||
edm::test::TestProcessor tester(config); | ||
|
||
REQUIRE_THROWS_AS(tester.test(), cms::Exception); | ||
//If the module does not throw when given no data, substitute | ||
//REQUIRE_NOTHROW for REQUIRE_THROWS_AS | ||
} | ||
|
||
SECTION("beginJob and endJob only") { | ||
edm::test::TestProcessor tester(config); | ||
|
||
REQUIRE_NOTHROW(tester.testBeginAndEndJobOnly()); | ||
} | ||
|
||
SECTION("Run with no LuminosityBlocks") { | ||
edm::test::TestProcessor tester(config); | ||
|
||
REQUIRE_NOTHROW(tester.testRunWithNoLuminosityBlocks()); | ||
} | ||
|
||
SECTION("LuminosityBlock with no Events") { | ||
edm::test::TestProcessor tester(config); | ||
|
||
REQUIRE_NOTHROW(tester.testLuminosityBlockWithNoEvents()); | ||
} | ||
} | ||
|
||
//Add additional TEST_CASEs to exercise the modules capabilities |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
#define CATCH_CONFIG_MAIN | ||
#include "catch.hpp" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from FWCore.ParameterSet.VarParsing import VarParsing | ||
import FWCore.ParameterSet.Config as cms | ||
import os, sys, json | ||
|
||
# module/model correspondence | ||
models = { | ||
"TritonImageProducer": ["inception_graphdef", "densenet_onnx"], | ||
"TritonGraphProducer": ["gat_test"], | ||
"TritonGraphFilter": ["gat_test"], | ||
"TritonGraphAnalyzer": ["gat_test"], | ||
"DRNProducer" : ["EGM_DRN"] | ||
} | ||
|
||
# other choices | ||
allowed_modes = ["Async","PseudoAsync","Sync"] | ||
allowed_compression = ["none","deflate","gzip"] | ||
allowed_devices = ["auto","cpu","gpu"] | ||
|
||
options = VarParsing() | ||
options.register("maxEvents", -1, VarParsing.multiplicity.singleton, VarParsing.varType.int, "Number of events to process (-1 for all)") | ||
options.register("serverName", "default", VarParsing.multiplicity.singleton, VarParsing.varType.string, "name for server (used internally)") | ||
options.register("address", "", VarParsing.multiplicity.singleton, VarParsing.varType.string, "server address") | ||
options.register("port", 8001, VarParsing.multiplicity.singleton, VarParsing.varType.int, "server port") | ||
options.register("timeout", 30, VarParsing.multiplicity.singleton, VarParsing.varType.int, "timeout for requests") | ||
options.register("params", "", VarParsing.multiplicity.singleton, VarParsing.varType.string, "json file containing server address/port") | ||
options.register("threads", 1, VarParsing.multiplicity.singleton, VarParsing.varType.int, "number of threads") | ||
options.register("streams", 0, VarParsing.multiplicity.singleton, VarParsing.varType.int, "number of streams") | ||
options.register("modules", "TritonGraphProducer", VarParsing.multiplicity.list, VarParsing.varType.string, "list of modules to run (choices: {})".format(', '.join(models))) | ||
options.register("models","gat_test", VarParsing.multiplicity.list, VarParsing.varType.string, "list of models (same length as modules, or just 1 entry if all modules use same model)") | ||
options.register("mode","Async", VarParsing.multiplicity.singleton, VarParsing.varType.string, "mode for client (choices: {})".format(', '.join(allowed_modes))) | ||
options.register("verbose", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "enable verbose output") | ||
options.register("brief", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "briefer output for graph modules") | ||
options.register("unittest", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "unit test mode: reduce input sizes") | ||
options.register("testother", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "also test gRPC communication if shared memory enabled, or vice versa") | ||
options.register("shm", True, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "enable shared memory") | ||
options.register("compression", "", VarParsing.multiplicity.singleton, VarParsing.varType.string, "enable I/O compression (choices: {})".format(', '.join(allowed_compression))) | ||
options.register("ssl", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "enable SSL authentication for server communication") | ||
options.register("device","auto", VarParsing.multiplicity.singleton, VarParsing.varType.string, "specify device for fallback server (choices: {})".format(', '.join(allowed_devices))) | ||
options.register("docker", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "use Docker for fallback server") | ||
options.register("tries", 0, VarParsing.multiplicity.singleton, VarParsing.varType.int, "number of retries for failed request") | ||
options.parseArguments() | ||
|
||
if len(options.params)>0: | ||
with open(options.params,'r') as pfile: | ||
pdict = json.load(pfile) | ||
options.address = pdict["address"] | ||
options.port = int(pdict["port"]) | ||
print("server = "+options.address+":"+str(options.port)) | ||
|
||
# check models and modules | ||
if len(options.modules)!=len(options.models): | ||
# assigning to VarParsing.multiplicity.list actually appends to existing value(s) | ||
if len(options.models)==1: options.models = [options.models[0]]*(len(options.modules)-1) | ||
else: raise ValueError("Arguments for modules and models must have same length") | ||
for im,module in enumerate(options.modules): | ||
if module not in models: | ||
raise ValueError("Unknown module: {}".format(module)) | ||
model = options.models[im] | ||
if model not in models[module]: | ||
raise ValueError("Unsupported model {} for module {}".format(model,module)) | ||
|
||
# check modes | ||
if options.mode not in allowed_modes: | ||
raise ValueError("Unknown mode: {}".format(options.mode)) | ||
|
||
# check compression | ||
if len(options.compression)>0 and options.compression not in allowed_compression: | ||
raise ValueError("Unknown compression setting: {}".format(options.compression)) | ||
|
||
# check devices | ||
options.device = options.device.lower() | ||
if options.device not in allowed_devices: | ||
raise ValueError("Unknown device: {}".format(options.device)) | ||
|
||
from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton | ||
process = cms.Process('tritonTest',enableSonicTriton) | ||
|
||
process.load("HeterogeneousCore.SonicTriton.TritonService_cff") | ||
|
||
process.maxEvents = cms.untracked.PSet( input = cms.untracked.int32(options.maxEvents) ) | ||
|
||
process.source = cms.Source("EmptySource") | ||
|
||
process.TritonService.verbose = options.verbose | ||
process.TritonService.fallback.verbose = options.verbose | ||
process.TritonService.fallback.useDocker = options.docker | ||
if options.device != "auto": | ||
process.TritonService.fallback.useGPU = options.device=="gpu" | ||
if len(options.address)>0: | ||
process.TritonService.servers.append( | ||
cms.PSet( | ||
name = cms.untracked.string(options.serverName), | ||
address = cms.untracked.string(options.address), | ||
port = cms.untracked.uint32(options.port), | ||
useSsl = cms.untracked.bool(options.ssl), | ||
rootCertificates = cms.untracked.string(""), | ||
privateKey = cms.untracked.string(""), | ||
certificateChain = cms.untracked.string(""), | ||
) | ||
) | ||
|
||
# Let it run | ||
process.p = cms.Path() | ||
|
||
modules = { | ||
"Producer": cms.EDProducer, | ||
"Filter": cms.EDFilter, | ||
"Analyzer": cms.EDAnalyzer, | ||
} | ||
|
||
keepMsgs = ['TritonClient','TritonService'] | ||
|
||
for im,module in enumerate(options.modules): | ||
model = options.models[im] | ||
Module = [obj for name,obj in modules.items() if name in module][0] | ||
setattr(process, module, | ||
Module(module, | ||
Client = cms.PSet( | ||
mode = cms.string(options.mode), | ||
preferredServer = cms.untracked.string(""), | ||
timeout = cms.untracked.uint32(options.timeout), | ||
modelName = cms.string(model), | ||
modelVersion = cms.string(""), | ||
modelConfigPath = cms.FileInPath("Progression/EGM_DRN/data/models/{}/config.pbtxt".format(model)), | ||
verbose = cms.untracked.bool(options.verbose), | ||
allowedTries = cms.untracked.uint32(options.tries), | ||
useSharedMemory = cms.untracked.bool(options.shm), | ||
compression = cms.untracked.string(options.compression), | ||
) | ||
) | ||
) | ||
processModule = getattr(process, module) | ||
processModule.nodeMin = cms.uint32(1) | ||
processModule.nodeMax = cms.uint32(200) | ||
processModule.brief = cms.bool(options.brief) | ||
process.p += processModule | ||
keepMsgs.extend([module,module+':TritonClient']) | ||
if options.testother: | ||
# clone modules to test both gRPC and shared memory | ||
_module2 = module+"GRPC" if processModule.Client.useSharedMemory else "SHM" | ||
setattr(process, _module2, | ||
processModule.clone( | ||
Client = dict(useSharedMemory = not processModule.Client.useSharedMemory) | ||
) | ||
) | ||
processModule2 = getattr(process, _module2) | ||
process.p += processModule2 | ||
|
||
process.load('FWCore/MessageService/MessageLogger_cfi') | ||
process.MessageLogger.cerr.FwkReport.reportEvery = 500 | ||
for msg in keepMsgs: | ||
setattr(process.MessageLogger.cerr,msg, | ||
cms.untracked.PSet( | ||
limit = cms.untracked.int32(10000000), | ||
) | ||
) | ||
|
||
if options.threads>0: | ||
process.options.numberOfThreads = options.threads | ||
process.options.numberOfStreams = options.streams | ||
|