-
-
Notifications
You must be signed in to change notification settings - Fork 357
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Added supported version information * Not needed to use libmyplugins.so anymore
- Loading branch information
1 parent
5b45057
commit 470ed82
Showing
6 changed files
with
539 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
#ifndef __TRT_UTILS_H_ | ||
#define __TRT_UTILS_H_ | ||
|
||
#include <iostream> | ||
#include <vector> | ||
#include <algorithm> | ||
#include <cudnn.h> | ||
|
||
#ifndef CUDA_CHECK | ||
|
||
#define CUDA_CHECK(callstr) \ | ||
{ \ | ||
cudaError_t error_code = callstr; \ | ||
if (error_code != cudaSuccess) { \ | ||
std::cerr << "CUDA error " << error_code << " at " << __FILE__ << ":" << __LINE__; \ | ||
assert(0); \ | ||
} \ | ||
} | ||
|
||
#endif | ||
|
||
namespace Tn | ||
{ | ||
class Profiler : public nvinfer1::IProfiler | ||
{ | ||
public: | ||
void printLayerTimes(int itrationsTimes) | ||
{ | ||
float totalTime = 0; | ||
for (size_t i = 0; i < mProfile.size(); i++) | ||
{ | ||
printf("%-40.40s %4.3fms\n", mProfile[i].first.c_str(), mProfile[i].second / itrationsTimes); | ||
totalTime += mProfile[i].second; | ||
} | ||
printf("Time over all layers: %4.3f\n", totalTime / itrationsTimes); | ||
} | ||
private: | ||
typedef std::pair<std::string, float> Record; | ||
std::vector<Record> mProfile; | ||
|
||
virtual void reportLayerTime(const char* layerName, float ms) | ||
{ | ||
auto record = std::find_if(mProfile.begin(), mProfile.end(), [&](const Record& r){ return r.first == layerName; }); | ||
if (record == mProfile.end()) | ||
mProfile.push_back(std::make_pair(layerName, ms)); | ||
else | ||
record->second += ms; | ||
} | ||
}; | ||
|
||
//Logger for TensorRT info/warning/errors | ||
class Logger : public nvinfer1::ILogger | ||
{ | ||
public: | ||
|
||
Logger(): Logger(Severity::kWARNING) {} | ||
|
||
Logger(Severity severity): reportableSeverity(severity) {} | ||
|
||
void log(Severity severity, const char* msg) override | ||
{ | ||
// suppress messages with severity enum value greater than the reportable | ||
if (severity > reportableSeverity) return; | ||
|
||
switch (severity) | ||
{ | ||
case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break; | ||
case Severity::kERROR: std::cerr << "ERROR: "; break; | ||
case Severity::kWARNING: std::cerr << "WARNING: "; break; | ||
case Severity::kINFO: std::cerr << "INFO: "; break; | ||
default: std::cerr << "UNKNOWN: "; break; | ||
} | ||
std::cerr << msg << std::endl; | ||
} | ||
|
||
Severity reportableSeverity{Severity::kWARNING}; | ||
}; | ||
|
||
template<typename T> | ||
void write(char*& buffer, const T& val) | ||
{ | ||
*reinterpret_cast<T*>(buffer) = val; | ||
buffer += sizeof(T); | ||
} | ||
|
||
template<typename T> | ||
void read(const char*& buffer, T& val) | ||
{ | ||
val = *reinterpret_cast<const T*>(buffer); | ||
buffer += sizeof(T); | ||
} | ||
} | ||
|
||
#endif |
Oops, something went wrong.