Skip to content

Commit

Permalink
Working producer on random data
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrothman committed Oct 21, 2021
1 parent 5a29efb commit 922ad68
Show file tree
Hide file tree
Showing 8 changed files with 371 additions and 0 deletions.
Binary file not shown.
27 changes: 27 additions & 0 deletions Progression/EGM_DRN/data/models/EGM_DRN/config.pbtxt
@@ -0,0 +1,27 @@
name: "EGM_DRN"
platform: "pytorch_libtorch"
max_batch_size: 0
input [
{
name: "x__0"
data_type: TYPE_FP32
dims: [-1, 4]
},
{
name: "batch__1"
data_type: TYPE_INT64
dims: [-1]
},
{
name: "graphx__2"
data_type: TYPE_FP32
dims: [-1]
}
]
output [
{
name: "dscb__0"
data_type: TYPE_FP32
dims: [-1, 6]
}]

5 changes: 5 additions & 0 deletions Progression/EGM_DRN/plugins/BuildFile.xml
@@ -0,0 +1,5 @@
<use name="FWCore/Framework"/>
<use name="FWCore/ParameterSet"/>
<use name="FWCore/PluginManager"/>
<use name="HeterogeneousCore/SonicTriton"/>
<flags EDM_PLUGIN="1"/>
125 changes: 125 additions & 0 deletions Progression/EGM_DRN/plugins/TritonGraphModules.cc
@@ -0,0 +1,125 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonEDProducer.h"

#include "FWCore/ParameterSet/interface/ConfigurationDescriptions.h"
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "FWCore/Framework/interface/MakerMacros.h"

#include <sstream>
#include <string>
#include <vector>
#include <random>

class TritonGraphHelper {
public:
TritonGraphHelper(edm::ParameterSet const& cfg)
: nodeMin_(cfg.getParameter<unsigned>("nodeMin")),
nodeMax_(cfg.getParameter<unsigned>("nodeMax")),
brief_(cfg.getParameter<bool>("brief")) {}
void makeInput(edm::Event const& iEvent, TritonInputMap& iInput, const std::string& debugName) const {

//get event-based seed for RNG
unsigned int runNum_uint = static_cast<unsigned int>(iEvent.id().run());
unsigned int lumiNum_uint = static_cast<unsigned int>(iEvent.id().luminosityBlock());
unsigned int evNum_uint = static_cast<unsigned int>(iEvent.id().event());
std::uint32_t seed = (lumiNum_uint << 10) + (runNum_uint << 20) + evNum_uint;
std::mt19937 rng(seed);

std::uniform_int_distribution<int> randint1(nodeMin_, nodeMax_);
int nnodes = randint1(rng);

//set shapes
auto& input1 = iInput.at("x__0");
input1.setShape(0, nnodes);
auto data1 = input1.allocate<float>();
auto& vdata1 = (*data1)[0];

auto& input2 = iInput.at("batch__1");
input2.setShape(0, nnodes);
auto data2 = input2.allocate<int64_t>();
auto& vdata2 = (*data2)[0];

auto& input3 = iInput.at("graphx__2");
input3.setShape(0, 2);
auto data3 = input3.allocate<float>();
auto& vdata3 = (*data3)[0];

//fill
std::normal_distribution<float> randx(-10, 4);
for (unsigned i = 0; i < input1.sizeShape(); ++i) {
vdata1.push_back(randx(rng));
}

for (unsigned i = 0; i < input2.sizeShape(); ++i) {
vdata2.push_back(0);
}

for (unsigned i = 0; i < input3.sizeShape(); ++i) {
vdata3.push_back(randx(rng));
}

// convert to server format
input1.toServer(data1);
input2.toServer(data2);
input3.toServer(data3);

edm::LogInfo(debugName) << "input X shape: " << input1.shape()[0] << ", " << input1.shape()[1];
edm::LogInfo(debugName) << "input batch shape: " << input2.shape()[0];
edm::LogInfo(debugName) << "input graphx shape: " << input3.shape()[0];
}
void makeOutput(const TritonOutputMap& iOutput, const std::string& debugName) const {
edm::LogInfo (debugName) << "top of output ";
//check the results
const auto& output1 = iOutput.begin()->second;
// convert from server format
const auto& tmp = output1.fromServer<float>();
//if (brief_)
edm::LogInfo(debugName) << "output shape: " << output1.shape()[0] << "," << output1.shape()[1];
//else {
// edm::LogInfo msg(debugName);
// for (int i = 0; i < output1.shape()[0]; ++i) {
// msg << "output " << i << ": ";
// for (int j = 0; j < output1.shape()[1]; ++j) {
// msg << tmp[0][output1.shape()[1] * i + j] << " ";
// }
// msg << "\n";
// }
//}
}
static void fillPSetDescription(edm::ParameterSetDescription& desc) {
desc.add<unsigned>("nodeMin", 1);
desc.add<unsigned>("nodeMax", 200);
desc.add<bool>("brief", false);
}

private:
//members
unsigned nodeMin_, nodeMax_;
bool brief_;
};

class DRNProducer : public TritonEDProducer<> {
public:
explicit DRNProducer(edm::ParameterSet const& cfg)
: TritonEDProducer<>(cfg, "DRNProducer"), helper_(cfg) {}
void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override {
helper_.makeInput(iEvent, iInput, debugName_);
}
void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override {
helper_.makeOutput(iOutput, debugName_);
}
~DRNProducer() override = default;

static void fillDescriptions(edm::ConfigurationDescriptions& descriptions) {
edm::ParameterSetDescription desc;
TritonClient::fillPSetDescription(desc);
TritonGraphHelper::fillPSetDescription(desc);
//to ensure distinct cfi names
descriptions.addWithDefaultLabel(desc);
}

private:
//member
TritonGraphHelper helper_;
};

DEFINE_FWK_MODULE(DRNProducer);
4 changes: 4 additions & 0 deletions Progression/EGM_DRN/test/BuildFile.xml
@@ -0,0 +1,4 @@
<bin file="test_catch2_*.cc" name="testProgressionEGM_DRNTP">
<use name="FWCore/TestProcessor"/>
<use name="catch2"/>
</bin>
47 changes: 47 additions & 0 deletions Progression/EGM_DRN/test/test_catch2_EGM_DRN.cc
@@ -0,0 +1,47 @@
#include "catch.hpp"
#include "FWCore/TestProcessor/interface/TestProcessor.h"
#include "FWCore/Utilities/interface/Exception.h"

static constexpr auto s_tag = "[EGM_DRN]";

TEST_CASE("Standard checks of EGM_DRN", s_tag) {
const std::string baseConfig{
R"_(from FWCore.TestProcessor.TestProcess import *
process = TestProcess()
process.toTest = cms.EDProducer("EGM_DRN"
#necessary configuration parameters
)
process.moduleToTest(process.toTest)
)_"};

edm::test::TestProcessor::Config config{baseConfig};
SECTION("base configuration is OK") { REQUIRE_NOTHROW(edm::test::TestProcessor(config)); }

SECTION("No event data") {
edm::test::TestProcessor tester(config);

REQUIRE_THROWS_AS(tester.test(), cms::Exception);
//If the module does not throw when given no data, substitute
//REQUIRE_NOTHROW for REQUIRE_THROWS_AS
}

SECTION("beginJob and endJob only") {
edm::test::TestProcessor tester(config);

REQUIRE_NOTHROW(tester.testBeginAndEndJobOnly());
}

SECTION("Run with no LuminosityBlocks") {
edm::test::TestProcessor tester(config);

REQUIRE_NOTHROW(tester.testRunWithNoLuminosityBlocks());
}

SECTION("LuminosityBlock with no Events") {
edm::test::TestProcessor tester(config);

REQUIRE_NOTHROW(tester.testLuminosityBlockWithNoEvents());
}
}

//Add additional TEST_CASEs to exercise the modules capabilities
2 changes: 2 additions & 0 deletions Progression/EGM_DRN/test/test_catch2_main.cc
@@ -0,0 +1,2 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
161 changes: 161 additions & 0 deletions Progression/EGM_DRN/test/tritonTest_cfg.py
@@ -0,0 +1,161 @@
from FWCore.ParameterSet.VarParsing import VarParsing
import FWCore.ParameterSet.Config as cms
import os, sys, json

# module/model correspondence
models = {
"TritonImageProducer": ["inception_graphdef", "densenet_onnx"],
"TritonGraphProducer": ["gat_test"],
"TritonGraphFilter": ["gat_test"],
"TritonGraphAnalyzer": ["gat_test"],
"DRNProducer" : ["EGM_DRN"]
}

# other choices
allowed_modes = ["Async","PseudoAsync","Sync"]
allowed_compression = ["none","deflate","gzip"]
allowed_devices = ["auto","cpu","gpu"]

options = VarParsing()
options.register("maxEvents", -1, VarParsing.multiplicity.singleton, VarParsing.varType.int, "Number of events to process (-1 for all)")
options.register("serverName", "default", VarParsing.multiplicity.singleton, VarParsing.varType.string, "name for server (used internally)")
options.register("address", "", VarParsing.multiplicity.singleton, VarParsing.varType.string, "server address")
options.register("port", 8001, VarParsing.multiplicity.singleton, VarParsing.varType.int, "server port")
options.register("timeout", 30, VarParsing.multiplicity.singleton, VarParsing.varType.int, "timeout for requests")
options.register("params", "", VarParsing.multiplicity.singleton, VarParsing.varType.string, "json file containing server address/port")
options.register("threads", 1, VarParsing.multiplicity.singleton, VarParsing.varType.int, "number of threads")
options.register("streams", 0, VarParsing.multiplicity.singleton, VarParsing.varType.int, "number of streams")
options.register("modules", "TritonGraphProducer", VarParsing.multiplicity.list, VarParsing.varType.string, "list of modules to run (choices: {})".format(', '.join(models)))
options.register("models","gat_test", VarParsing.multiplicity.list, VarParsing.varType.string, "list of models (same length as modules, or just 1 entry if all modules use same model)")
options.register("mode","Async", VarParsing.multiplicity.singleton, VarParsing.varType.string, "mode for client (choices: {})".format(', '.join(allowed_modes)))
options.register("verbose", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "enable verbose output")
options.register("brief", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "briefer output for graph modules")
options.register("unittest", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "unit test mode: reduce input sizes")
options.register("testother", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "also test gRPC communication if shared memory enabled, or vice versa")
options.register("shm", True, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "enable shared memory")
options.register("compression", "", VarParsing.multiplicity.singleton, VarParsing.varType.string, "enable I/O compression (choices: {})".format(', '.join(allowed_compression)))
options.register("ssl", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "enable SSL authentication for server communication")
options.register("device","auto", VarParsing.multiplicity.singleton, VarParsing.varType.string, "specify device for fallback server (choices: {})".format(', '.join(allowed_devices)))
options.register("docker", False, VarParsing.multiplicity.singleton, VarParsing.varType.bool, "use Docker for fallback server")
options.register("tries", 0, VarParsing.multiplicity.singleton, VarParsing.varType.int, "number of retries for failed request")
options.parseArguments()

if len(options.params)>0:
with open(options.params,'r') as pfile:
pdict = json.load(pfile)
options.address = pdict["address"]
options.port = int(pdict["port"])
print("server = "+options.address+":"+str(options.port))

# check models and modules
if len(options.modules)!=len(options.models):
# assigning to VarParsing.multiplicity.list actually appends to existing value(s)
if len(options.models)==1: options.models = [options.models[0]]*(len(options.modules)-1)
else: raise ValueError("Arguments for modules and models must have same length")
for im,module in enumerate(options.modules):
if module not in models:
raise ValueError("Unknown module: {}".format(module))
model = options.models[im]
if model not in models[module]:
raise ValueError("Unsupported model {} for module {}".format(model,module))

# check modes
if options.mode not in allowed_modes:
raise ValueError("Unknown mode: {}".format(options.mode))

# check compression
if len(options.compression)>0 and options.compression not in allowed_compression:
raise ValueError("Unknown compression setting: {}".format(options.compression))

# check devices
options.device = options.device.lower()
if options.device not in allowed_devices:
raise ValueError("Unknown device: {}".format(options.device))

from Configuration.ProcessModifiers.enableSonicTriton_cff import enableSonicTriton
process = cms.Process('tritonTest',enableSonicTriton)

process.load("HeterogeneousCore.SonicTriton.TritonService_cff")

process.maxEvents = cms.untracked.PSet( input = cms.untracked.int32(options.maxEvents) )

process.source = cms.Source("EmptySource")

process.TritonService.verbose = options.verbose
process.TritonService.fallback.verbose = options.verbose
process.TritonService.fallback.useDocker = options.docker
if options.device != "auto":
process.TritonService.fallback.useGPU = options.device=="gpu"
if len(options.address)>0:
process.TritonService.servers.append(
cms.PSet(
name = cms.untracked.string(options.serverName),
address = cms.untracked.string(options.address),
port = cms.untracked.uint32(options.port),
useSsl = cms.untracked.bool(options.ssl),
rootCertificates = cms.untracked.string(""),
privateKey = cms.untracked.string(""),
certificateChain = cms.untracked.string(""),
)
)

# Let it run
process.p = cms.Path()

modules = {
"Producer": cms.EDProducer,
"Filter": cms.EDFilter,
"Analyzer": cms.EDAnalyzer,
}

keepMsgs = ['TritonClient','TritonService']

for im,module in enumerate(options.modules):
model = options.models[im]
Module = [obj for name,obj in modules.items() if name in module][0]
setattr(process, module,
Module(module,
Client = cms.PSet(
mode = cms.string(options.mode),
preferredServer = cms.untracked.string(""),
timeout = cms.untracked.uint32(options.timeout),
modelName = cms.string(model),
modelVersion = cms.string(""),
modelConfigPath = cms.FileInPath("Progression/EGM_DRN/data/models/{}/config.pbtxt".format(model)),
verbose = cms.untracked.bool(options.verbose),
allowedTries = cms.untracked.uint32(options.tries),
useSharedMemory = cms.untracked.bool(options.shm),
compression = cms.untracked.string(options.compression),
)
)
)
processModule = getattr(process, module)
processModule.nodeMin = cms.uint32(1)
processModule.nodeMax = cms.uint32(200)
processModule.brief = cms.bool(options.brief)
process.p += processModule
keepMsgs.extend([module,module+':TritonClient'])
if options.testother:
# clone modules to test both gRPC and shared memory
_module2 = module+"GRPC" if processModule.Client.useSharedMemory else "SHM"
setattr(process, _module2,
processModule.clone(
Client = dict(useSharedMemory = not processModule.Client.useSharedMemory)
)
)
processModule2 = getattr(process, _module2)
process.p += processModule2

process.load('FWCore/MessageService/MessageLogger_cfi')
process.MessageLogger.cerr.FwkReport.reportEvery = 500
for msg in keepMsgs:
setattr(process.MessageLogger.cerr,msg,
cms.untracked.PSet(
limit = cms.untracked.int32(10000000),
)
)

if options.threads>0:
process.options.numberOfThreads = options.threads
process.options.numberOfStreams = options.streams

0 comments on commit 922ad68

Please sign in to comment.