-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
tensorrt_execution_provider.cc
4314 lines (3914 loc) · 207 KB
/
tensorrt_execution_provider.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <algorithm>
#include <cctype>
#include <cmath>
#include <fstream>
#include <list>
#include <unordered_set>
#include "core/providers/shared_library/provider_api.h"
#define ORT_API_MANUAL_INIT
#include "core/session/onnxruntime_cxx_api.h"
#include "core/common/common.h"
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "tensorrt_execution_provider.h"
#include "tensorrt_execution_provider_utils.h"
#include "tensorrt_execution_provider_custom_ops.h"
#include "onnx_ctx_model_helper.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#include "core/session/allocator_adapters.h"
#include "cuda_runtime_api.h"
#include <gsl/gsl>
#include <unordered_map>
#include <utility>
#include <limits>
#include <map>
#include <memory>
#include <filesystem>
// TODO: find a better way to share this
#include "core/providers/cuda/cuda_stream_handle.h"
#ifdef _WIN32
#include <windows.h>
#define LIBTYPE HINSTANCE
#define OPENLIB(libname) LoadLibrary(libname)
#define LIBFUNC(lib, fn) GetProcAddress((lib), (fn))
#else
#include <dlfcn.h>
#define LIBTYPE void*
#define OPENLIB(libname) dlopen((libname), RTLD_LAZY)
#define LIBFUNC(lib, fn) dlsym((lib), (fn))
#endif
#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr))
using namespace ONNX_NAMESPACE;
using namespace ::onnxruntime::logging;
namespace {
// Check if cycle exists in the graph after partitioning
// Depth-first search used to check whether the graph contains a cycle after partitioning.
//
// adjacency_map[i] lists the successors of node i. visited[] marks every node any search has reached, and st[]
// marks the nodes on the current recursion path.
//
// If a successor already on the path is found (a back edge), the function returns true. While unwinding it
// appends to `cycles` the back-edge target first, then each successor taken along the recursion path, innermost
// first. In that case the st[] entries of the path are left set.
//
// Otherwise it clears st[i] and returns false. It does the same immediately if i was already visited.
bool FindCycleHelper(size_t i, const std::list<size_t>* adjacency_map, bool visited[], bool* st, std::vector<size_t>& cycles) {
  if (visited[i]) {
    st[i] = false;
    return false;
  }
  visited[i] = true;
  st[i] = true;
  for (const size_t successor : adjacency_map[i]) {
    // Either an unexplored successor leads to a cycle, or the successor is already on the current path.
    const bool cycle_found =
        (!visited[successor] && FindCycleHelper(successor, adjacency_map, visited, st, cycles)) || st[successor];
    if (cycle_found) {
      cycles.push_back(successor);
      return true;
    }
  }
  st[i] = false;
  return false;
}
// Decodes an IEEE 754 binary16 (half precision) bit pattern into the value it represents.
double HalfBitsToDouble(uint16_t bits) {
  const bool negative = (bits & 0x8000u) != 0;
  const int exponent = (bits >> 10) & 0x1F;
  const int mantissa = bits & 0x3FF;
  double magnitude = 0.0;
  if (exponent == 0) {
    // Zero or subnormal: mantissa * 2^-24.
    magnitude = std::ldexp(static_cast<double>(mantissa), -24);
  } else if (exponent == 0x1F) {
    // Infinity or NaN.
    magnitude = mantissa == 0 ? std::numeric_limits<double>::infinity() : std::numeric_limits<double>::quiet_NaN();
  } else {
    // Normal number: (1 + mantissa / 1024) * 2^(exponent - 15) == (1024 + mantissa) * 2^(exponent - 25).
    magnitude = std::ldexp(static_cast<double>(mantissa | 0x400), exponent - 25);
  }
  return negative ? -magnitude : magnitude;
}

/*
 * Sets symmetric dynamic ranges on the tensors of `network`.
 *
 * - Network inputs and layer outputs whose names appear in dynamic_range_map get [-r, r], where r is the mapped value.
 * - Outputs of constant layers that are not in the map get [-m, m], where m is the largest absolute weight value
 *   of that layer.
 *
 * Returns false (after logging an error) if TensorRT rejects a range or a constant layer holds weights of an
 * unsupported data type; true otherwise.
 */
bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map<std::string, float>& dynamic_range_map) {
  // Set dynamic range for input tensors
  for (int i = 0; i < network.getNbInputs(); ++i) {
    const std::string tensor_name = network.getInput(i)->getName();
    auto dynamic_range_iter = dynamic_range_map.find(tensor_name);
    if (dynamic_range_iter != dynamic_range_map.end()) {
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
      if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) {
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
        LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name;
        return false;
      }
    }
  }
  // Set dynamic range for activations and weights
  for (int i = 0; i < network.getNbLayers(); ++i) {
    auto trt_layer = network.getLayer(i);
    for (int j = 0, e = trt_layer->getNbOutputs(); j < e; ++j) {
      const std::string tensor_name = trt_layer->getOutput(j)->getName();
      auto dynamic_range_iter = dynamic_range_map.find(tensor_name);
      if (dynamic_range_iter != dynamic_range_map.end()) {
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
        if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) {
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
          LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name;
          return false;
        }
      } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) {
        nvinfer1::IConstantLayer* const_layer = static_cast<nvinfer1::IConstantLayer*>(trt_layer);
        const std::string const_layer_name = const_layer->getName();
        auto trt_weights = const_layer->getWeights();
        // Largest absolute weight seen so far; absolute values are never negative, so 0 is a neutral start.
        double max_weight = 0.0;
        for (int64_t k = 0, end = trt_weights.count; k < end; ++k) {
          double weight{};
          switch (trt_weights.type) {
            case nvinfer1::DataType::kFLOAT:
              weight = static_cast<const float*>(trt_weights.values)[k];
              break;
            case nvinfer1::DataType::kBOOL:
              weight = static_cast<const bool*>(trt_weights.values)[k];
              break;
            case nvinfer1::DataType::kINT8:
              weight = static_cast<const int8_t*>(trt_weights.values)[k];
              break;
            case nvinfer1::DataType::kHALF:
              // Half-precision weights are stored as 16-bit IEEE bit patterns; decode them instead of using the
              // raw integer value (e.g. 1.0h is 0x3C00 == 15360).
              weight = HalfBitsToDouble(static_cast<const uint16_t*>(trt_weights.values)[k]);
              break;
            case nvinfer1::DataType::kINT32:
              weight = static_cast<const int32_t*>(trt_weights.values)[k];
              break;
#if NV_TENSORRT_MAJOR >= 10
            case nvinfer1::DataType::kINT64:
              weight = static_cast<double>(static_cast<const int64_t*>(trt_weights.values)[k]);
              break;
#endif  // NV_TENSORRT_MAJOR >= 10
            default:
              LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name;
              return false;
          }
          max_weight = std::max(max_weight, std::abs(weight));
        }
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
        if (!trt_layer->getOutput(j)->setDynamicRange(static_cast<float>(-max_weight), static_cast<float>(max_weight))) {
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
          LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name;
          return false;
        }
      }
    }
  }
  return true;
}
// Splits `s` on every occurrence of `separator`.
// Empty pieces between adjacent separators (or before a leading separator) are kept. A trailing separator does
// not produce a final empty piece, and an empty input yields an empty vector.
std::vector<std::string> SplitToStringVec(std::string const& s, char separator) {
  std::vector<std::string> pieces;
  size_t pos = 0;
  while (pos < s.length()) {
    const size_t hit = s.find(separator, pos);
    const size_t stop = (hit == std::string::npos) ? s.length() : hit;
    pieces.push_back(s.substr(pos, stop - pos));
    pos = stop + 1;
  }
  return pieces;
}
/*
 * Parses a comma-separated tactic-source list such as "+CUDNN,-JIT_CONVOLUTIONS" into a TacticSources bit mask.
 *
 * Each entry must start with '+' (enable) or '-' (disable); names are matched case-insensitively. The following
 * entries are ignored:
 * - empty entries;
 * - entries without a '+'/'-' prefix (a warning is logged);
 * - unknown names (a warning is logged);
 * - names of sources that do not exist in the TensorRT version being compiled against.
 *
 * @param tactic_string the user-supplied list; this function does not modify it.
 * @return bits of the sources enabled with '+' that are not also disabled with '-'. The mask carries enabled bits
 *         only, so a source listed solely with '-' is simply absent from the result.
 */
nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) {
  nvinfer1::TacticSources disabledTactics = 0;
  nvinfer1::TacticSources enabledTactics = 0;
  std::vector<std::string> tacticList = SplitToStringVec(tactic_string, ',');
  for (auto& t : tacticList) {
    // Empty entries come from inputs such as "a,,b" or ",a"; calling front() on them would be undefined behavior.
    if (t.empty()) {
      continue;
    }
    bool enable{false};
    if (t.front() == '+') {
      enable = true;
    } else if (t.front() != '-') {
      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source must be prefixed with + or - skipping: " << t;
      continue;  // Actually skip the entry, as the warning says.
    }
    t.erase(0, 1);
    // std::toupper requires a value representable as unsigned char; passing a negative char is undefined behavior.
    std::transform(t.begin(), t.end(), t.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    nvinfer1::TacticSource source{};
    // Set only when `source` names a tactic source available in the TensorRT version being built against.
    bool source_known = false;
    if (t == "CUBLAS") {
      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0";
#if NV_TENSORRT_MAJOR < 10
      source = nvinfer1::TacticSource::kCUBLAS;
      source_known = true;
#endif
    } else if (t == "CUBLASLT" || t == "CUBLAS_LT") {
      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0";
#if NV_TENSORRT_MAJOR < 9
      source = nvinfer1::TacticSource::kCUBLAS_LT;
      source_known = true;
#endif
    } else if (t == "CUDNN") {
      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0";
#if NV_TENSORRT_MAJOR < 10
      source = nvinfer1::TacticSource::kCUDNN;
      source_known = true;
#endif
    } else if (t == "EDGE_MASK_CONVOLUTIONS") {
      source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS;
      source_known = true;
    } else if (t == "JIT_CONVOLUTIONS") {
      source = nvinfer1::TacticSource::kJIT_CONVOLUTIONS;
      source_known = true;
    } else {
      LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source was not found with name: " << t;
    }
    // Without this check the value-initialized `source` (the enumerator with value 0) would be toggled for
    // unknown or unavailable names.
    if (!source_known) {
      continue;
    }
    const uint32_t sourceBit = 1U << static_cast<uint32_t>(source);
    if (enable) {
      enabledTactics |= sourceBit;
    } else {
      disabledTactics |= sourceBit;
    }
  }
  return enabledTactics & ~disabledTactics;
}
/*
 * Reads the whole timing cache file `inFileName` into memory.
 *
 * Returns an empty vector, after logging a warning, in three cases:
 * - the file cannot be opened;
 * - its size cannot be determined;
 * - it cannot be read completely.
 * A partially read cache is never returned.
 */
inline std::vector<char> loadTimingCacheFile(const std::string& inFileName) {
  std::ifstream iFile(inFileName, std::ios::in | std::ios::binary);
  if (!iFile) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName
                          << ". A new timing cache will be generated and written.";
    return std::vector<char>();
  }
  iFile.seekg(0, std::ifstream::end);
  // tellg() reports failure as -1; converting that to size_t would request an enormous allocation.
  const std::streamoff fsize = iFile.tellg();
  if (fsize < 0) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not determine the size of timing cache file: " << inFileName
                          << ". A new timing cache will be generated and written.";
    return std::vector<char>();
  }
  iFile.seekg(0, std::ifstream::beg);
  std::vector<char> content(static_cast<size_t>(fsize));
  if (!content.empty() && !iFile.read(content.data(), static_cast<std::streamsize>(fsize))) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read the complete timing cache from: " << inFileName
                          << ". A new timing cache will be generated and written.";
    return std::vector<char>();
  }
  return content;  // The stream is closed by its destructor.
}
/*
 * Writes the serialized timing cache `blob` to `outFileName`, replacing any existing file.
 *
 * Failures (no blob, file cannot be opened, write or flush fails) are logged as warnings and otherwise ignored;
 * this function does not throw for them.
 */
inline void saveTimingCacheFile(const std::string& outFileName, const nvinfer1::IHostMemory* blob) {
  if (blob == nullptr) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] No serialized timing cache to write to: " << outFileName;
    return;
  }
  std::ofstream oFile(outFileName, std::ios::out | std::ios::binary);
  if (!oFile) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName;
    return;
  }
  oFile.write(static_cast<const char*>(blob->data()), static_cast<std::streamsize>(blob->size()));
  // close() flushes the buffer; a failure there (e.g. disk full) is only visible after this call.
  oFile.close();
  if (!oFile) {
    LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName;
  }
}
} // namespace
// Declaration only; the definition comes from the protobuf library this provider links against.
namespace google::protobuf {
void ShutdownProtobufLibrary();
}  // namespace google::protobuf

// Global object whose destructor calls ShutdownProtobufLibrary() when g_protobuf is destroyed during
// static destruction.
struct ShutdownProtobuf {
  ~ShutdownProtobuf() { ::google::protobuf::ShutdownProtobufLibrary(); }
} g_protobuf;
namespace onnxruntime {
namespace cuda {
// Explicit specializations of cuda::Impl_Cast for the element-type pairs this provider needs.
// Each forwards the stream (as void*), the buffers and the element count to g_host->cuda__Impl_Cast.

// int64 -> int32
template <>
void Impl_Cast(cudaStream_t stream, const int64_t* input_data, int32_t* output_data, size_t count) {
  g_host->cuda__Impl_Cast(static_cast<void*>(stream), input_data, output_data, count);
}

// int32 -> int64
template <>
void Impl_Cast(cudaStream_t stream, const int32_t* input_data, int64_t* output_data, size_t count) {
  g_host->cuda__Impl_Cast(static_cast<void*>(stream), input_data, output_data, count);
}

// double -> float
template <>
void Impl_Cast(cudaStream_t stream, const double* input_data, float* output_data, size_t count) {
  g_host->cuda__Impl_Cast(static_cast<void*>(stream), input_data, output_data, count);
}

// float -> double
template <>
void Impl_Cast(cudaStream_t stream, const float* input_data, double* output_data, size_t count) {
  g_host->cuda__Impl_Cast(static_cast<void*>(stream), input_data, output_data, count);
}
}  // namespace cuda
// Explicit specializations of CudaCall for the status types used by this provider.
// The <..., false> variants return the Status produced by g_host->CudaCall_false.
// The <..., true> variants forward to g_host->CudaCall_true and return nothing.

template <>
Status CudaCall<cudaError, false>(cudaError retCode, const char* exprString, const char* libName,
                                  cudaError successCode, const char* msg, const char* file, const int line) {
  return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line);
}

template <>
void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const char* libName,
                               cudaError successCode, const char* msg, const char* file, const int line) {
  g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
}

#ifndef USE_CUDA_MINIMAL
// cuBLAS and cuDNN status codes are only handled when the full CUDA dependency set is available.
template <>
Status CudaCall<cublasStatus_t, false>(cublasStatus_t retCode, const char* exprString, const char* libName,
                                       cublasStatus_t successCode, const char* msg, const char* file, const int line) {
  return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line);
}

template <>
void CudaCall<cublasStatus_t, true>(cublasStatus_t retCode, const char* exprString, const char* libName,
                                    cublasStatus_t successCode, const char* msg, const char* file, const int line) {
  g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
}

template <>
Status CudaCall<cudnnStatus_t, false>(cudnnStatus_t retCode, const char* exprString, const char* libName,
                                      cudnnStatus_t successCode, const char* msg, const char* file, const int line) {
  return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line);
}

template <>
void CudaCall<cudnnStatus_t, true>(cudnnStatus_t retCode, const char* exprString, const char* libName,
                                   cudnnStatus_t successCode, const char* msg, const char* file, const int line) {
  g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
}
#endif
#if NV_TENSORRT_MAJOR >= 10
// Returns a device buffer of at least `size` bytes (minimum 1) for an output whose size is only known at runtime.
// The current buffer is reused when large enough. Otherwise it is freed and a new one is cudaMalloc'ed; on
// allocation failure nullptr is returned and allocated_size is 0.
void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
                                             uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept {
  // TensorRT requires a non-null pointer even for empty tensors, and some allocators return nullptr for
  // zero-byte requests, so never ask for fewer than one byte.
  const uint64_t requested = (size == 0) ? static_cast<uint64_t>(1) : size;
  if (requested <= allocated_size) {
    return outputPtr;  // Existing buffer is big enough.
  }
  cudaFree(outputPtr);
  outputPtr = nullptr;
  allocated_size = 0;
  if (cudaMalloc(&outputPtr, requested) == cudaSuccess) {
    allocated_size = requested;
  }
  return outputPtr;  // nullptr if cudaMalloc failed.
}
#else
// TensorRT <= 8.6 variant of the callback above (no stream parameter); same grow-only buffer policy.
void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
                                        uint64_t /*alignment*/) noexcept {
  // TensorRT requires a non-null pointer even for empty tensors, and some allocators return nullptr for
  // zero-byte requests, so never ask for fewer than one byte.
  const uint64_t requested = (size == 0) ? static_cast<uint64_t>(1) : size;
  if (requested <= allocated_size) {
    return outputPtr;  // Existing buffer is big enough.
  }
  cudaFree(outputPtr);
  outputPtr = nullptr;
  allocated_size = 0;
  if (cudaMalloc(&outputPtr, requested) == cudaSuccess) {
    allocated_size = requested;
  }
  return outputPtr;  // nullptr if cudaMalloc failed.
}
#endif
// Records the final shape of a runtime-sized output: output_shapes is replaced by dims.d[0 .. nbDims).
void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept {
  output_shapes.clear();
  // nbDims is signed. A negative count would become a huge size_t in reserve() and throw std::length_error,
  // which calls std::terminate inside this noexcept callback; treat it as an empty shape instead.
  if (dims.nbDims <= 0) {
    return;
  }
  output_shapes.reserve(dims.nbDims);
  for (int i = 0; i < dims.nbDims; i++) {
    output_shapes.push_back(dims.d[i]);
  }
}
// Kernel that copies input 0 into output 0 (same shape).
// The copy uses the data transfer registered for the (source device, destination device) pair and is queued
// asynchronously on the kernel's compute stream.
class Memcpy final : public OpKernel {
 public:
  Memcpy(const OpKernelInfo& info) : OpKernel(info) {}

  Status Compute(OpKernelContext* ctx) const override {
    const auto* source = ctx->Input<Tensor>(0);
    ORT_ENFORCE(source != nullptr, "Memcpy: Input tensor is nullptr.");
    Tensor* destination = ctx->Output(0, source->Shape());
    ORT_ENFORCE(destination != nullptr, "Memcpy: Failed to allocate output tensor.");

    auto* transfer = Info().GetDataTransferManager().GetDataTransfer(source->Location().device,
                                                                     destination->Location().device);
    if (transfer == nullptr) {
      return Status(common::ONNXRUNTIME, common::EP_FAIL, "gpu data transfer is missing in TRT EP.");
    }

    auto* compute_stream = ctx->GetComputeStream();
    if (compute_stream == nullptr) {
      return Status(common::ONNXRUNTIME, common::EP_FAIL, "Compute Stream is missing in TRT MemCpy kernel's context.");
    }
    return transfer->CopyTensorAsync(*source, *destination, *compute_stream);
  }
};
// Declaration of the per-kernel registration helper. InitializeRegistry() below instantiates it for the kernel
// classes named by ONNX_OPERATOR_KERNEL_CLASS_NAME; the matching specializations are presumably generated by the
// ONNX_OPERATOR_KERNEL_EX expansions that follow (macro defined outside this file — confirm there).
template <typename T>
KernelCreateInfo BuildKernelCreateInfo();
// MemcpyFromHost (onnx domain, since opset 1): uses the Memcpy kernel; input 0 is declared to be in CPU memory.
// Accepts every fixed-size tensor element type for "T".
ONNX_OPERATOR_KERNEL_EX(
MemcpyFromHost,
kOnnxDomain,
1,
kTensorrtExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Memcpy);
// MemcpyToHost (onnx domain, since opset 1): uses the Memcpy kernel; output 0 is declared to be in CPU memory.
// Accepts every fixed-size tensor element type for "T".
ONNX_OPERATOR_KERNEL_EX(
MemcpyToHost,
kOnnxDomain,
1,
kTensorrtExecutionProvider,
(*KernelDefBuilder::Create())
.OutputMemoryType(OrtMemTypeCPUOutput, 0)
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Memcpy);
// Class declarations for the two kernels above, used as template arguments in InitializeRegistry().
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
static std::shared_ptr<KernelRegistry> s_kernel_registry;
void InitializeRegistry() {
s_kernel_registry = KernelRegistry::Create();
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kTensorrtExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
};
for (auto& function_table_entry : function_table) {
ORT_THROW_IF_ERROR(s_kernel_registry->Register(function_table_entry()));
}
}
void DeleteRegistry() {
s_kernel_registry.reset();
}
std::shared_ptr<KernelRegistry> TensorrtExecutionProvider::GetKernelRegistry() const {
return s_kernel_registry;
}
// Per TensorRT documentation, logger needs to be a singleton: one function-local instance is shared by every
// caller. Each call switches its severity to kVERBOSE (verbose_log == true) or kWARNING (otherwise).
TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
  using Severity = nvinfer1::ILogger::Severity;
  const Severity wanted_level = verbose_log ? Severity::kVERBOSE : Severity::kWARNING;
  static TensorrtLogger trt_logger(wanted_level);
  if (trt_logger.get_level() != wanted_level) {
    trt_logger.set_level(wanted_level);
  }
  return trt_logger;
}
std::unique_lock<OrtMutex> TensorrtExecutionProvider::GetApiLock() const {
static OrtMutex singleton;
return std::unique_lock<OrtMutex>(singleton);
}
/*
 * Get the shape of "shape tensor" input
 *
 * Copies the first `shape_size` elements of type T from the device-side `input_tensor` into the host buffer
 * `shape_values` on `stream`, then synchronizes the stream so the host data is ready when this returns.
 * T is the shape tensor's element type (int32_t or int64_t at the call sites in this file).
 * Returns a failed Status if either CUDA call fails.
 */
template <typename T>
Status GetShapeOfShapeTensor(Ort::ConstValue& input_tensor,
                             void* shape_values,
                             int shape_size,
                             cudaStream_t stream) {
  const auto* device_values = input_tensor.GetTensorData<T>();
  const size_t byte_count = shape_size * sizeof(T);
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values, device_values, byte_count, cudaMemcpyDeviceToHost, stream));
  // The host buffer is only valid once the asynchronous copy has completed.
  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
  return Status::OK();
}
/*
* Apply TensorRT optimization profile shapes from provider options.
*
* This function supports single/multiple profile(s).
* (Note: An optimization profile describes a range of dimensions for each network input)
*
*/
bool ApplyProfileShapesFromProviderOptions(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
nvinfer1::ITensor* input,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_min_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_max_shapes,
std::unordered_map<std::string, std::vector<std::vector<int64_t>>>& profile_opt_shapes,
ShapeRangesMap& input_explicit_shape_ranges) {
if (trt_profiles.size() == 0) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Number of optimization profiles should be greater than 0, but it's 0.";
return false;
}
const std::string& input_name = input->getName();
if (profile_min_shapes.find(input_name) == profile_min_shapes.end()) {
return false;
}
if (input_explicit_shape_ranges.find(input_name) == input_explicit_shape_ranges.end()) {
std::unordered_map<size_t, std::vector<std::vector<int64_t>>> inner_map;
input_explicit_shape_ranges[input_name] = inner_map;
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Begin to apply profile shapes ...";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Input tensor name is '" << input_name << "', number of profiles found is " << trt_profiles.size();
for (size_t i = 0; i < trt_profiles.size(); i++) {
nvinfer1::Dims dims = input->getDimensions();
int nb_dims = dims.nbDims;
auto trt_profile = trt_profiles[i];
// Shape tensor
if (input->isShapeTensor()) {
int shape_size = nb_dims == 0 ? 1 : static_cast<int>(profile_min_shapes[input_name][i].size());
std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shape size of this shape tensor is " << shape_size;
for (int j = 0; j < shape_size; j++) {
auto min_value = profile_min_shapes[input_name][i][j];
auto max_value = profile_max_shapes[input_name][i][j];
auto opt_value = profile_opt_shapes[input_name][i][j];
shapes_min[j] = static_cast<int32_t>(min_value);
shapes_max[j] = static_cast<int32_t>(max_value);
shapes_opt[j] = static_cast<int32_t>(opt_value);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_min.d[" << j << "] is " << shapes_min[j];
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_max.d[" << j << "] is " << shapes_max[j];
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] shapes_opt.d[" << j << "] is " << shapes_opt[j];
if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) {
std::vector<std::vector<int64_t>> profile_vector(trt_profiles.size());
input_explicit_shape_ranges[input_name][j] = profile_vector;
}
input_explicit_shape_ranges[input_name][static_cast<int64_t>(j)][i].push_back(min_value);
input_explicit_shape_ranges[input_name][static_cast<int64_t>(j)][i].push_back(max_value);
input_explicit_shape_ranges[input_name][static_cast<int64_t>(j)][i].push_back(opt_value);
}
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size);
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size);
trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size);
}
// Execution tensor
else {
nvinfer1::Dims dims_min, dims_opt, dims_max;
dims_min.nbDims = nb_dims;
dims_max.nbDims = nb_dims;
dims_opt.nbDims = nb_dims;
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] number of dimension of this execution tensor is " << nb_dims;
for (int j = 0; j < nb_dims; j++) {
if (dims.d[j] == -1) {
auto min_value = profile_min_shapes[input_name][i][j];
auto max_value = profile_max_shapes[input_name][i][j];
auto opt_value = profile_opt_shapes[input_name][i][j];
dims_min.d[j] = static_cast<int32_t>(min_value);
dims_max.d[j] = static_cast<int32_t>(max_value);
dims_opt.d[j] = static_cast<int32_t>(opt_value);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_min.d[" << j << "] is " << dims_min.d[j];
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_max.d[" << j << "] is " << dims_max.d[j];
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dims_opt.d[" << j << "] is " << dims_opt.d[j];
if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) {
std::vector<std::vector<int64_t>> profile_vector(trt_profiles.size());
input_explicit_shape_ranges[input_name][j] = profile_vector;
}
input_explicit_shape_ranges[input_name][static_cast<int64_t>(j)][i].push_back(min_value);
input_explicit_shape_ranges[input_name][static_cast<int64_t>(j)][i].push_back(max_value);
input_explicit_shape_ranges[input_name][static_cast<int64_t>(j)][i].push_back(opt_value);
} else {
dims_min.d[j] = dims.d[j];
dims_max.d[j] = dims.d[j];
dims_opt.d[j] = dims.d[j];
}
}
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
}
}
return true;
}
/*
 * Apply TensorRT optimization profile shapes from input tensor value.
 *
 * This function supports single/multiple profile(s).
 * (Note: An optimization profile describes a range of dimensions for each network input)
 *
 * For profile 0 the [min, max, opt] range recorded in shape_ranges[input name] is widened to include the values
 * seen in this inference run. For a shape tensor the range covers each shape value; for an execution tensor it
 * covers each dimension. When the maximum grows, opt is raised together with it. Profiles after the first copy
 * profile 0's settings.
 *
 * @param trt_profiles optimization profiles to update
 * @param ctx kernel context the current input tensor is read from
 * @param input TensorRT network input whose profile is updated
 * @param shape_ranges per-input, per-dimension [min, max, opt] ranges; read and updated in place
 * @param input_indexes maps an input name to its index in ctx (index 0 is used if the name is absent)
 * @param shape_tensor_values holds "shape tensor -> shape values" for the INT32 shape tensor input across this inference run
 * @param shape_tensor_values_int64 holds "shape tensor -> shape values" for the INT64 shape tensor input across this inference run
 * @param stream CUDA stream used to copy shape tensor values to the host (synchronized before they are used)
 * @param engine_update set to true whenever a range changes; this function never sets it to false
 * @return EP_FAIL if a shape tensor has an unsupported element type or its values cannot be copied; OK otherwise
 */
Status ApplyProfileShapesFromInputTensorValue(std::vector<nvinfer1::IOptimizationProfile*>& trt_profiles,
                                              Ort::KernelContext ctx,
                                              nvinfer1::ITensor* input,
                                              ShapeRangesMap& shape_ranges,
                                              const std::unordered_map<std::string, size_t>& input_indexes,
                                              std::unordered_map<std::string, std::vector<int32_t>>& shape_tensor_values,
                                              std::unordered_map<std::string, std::vector<int64_t>>& shape_tensor_values_int64,
                                              cudaStream_t stream,
                                              bool* engine_update) {
  for (size_t i = 0; i < trt_profiles.size(); i++) {
    const std::string& input_name = input->getName();
    nvinfer1::Dims dims = input->getDimensions();
    int nb_dims = dims.nbDims;
    size_t input_index = 0;
    const auto iter = input_indexes.find(input_name);
    if (iter != input_indexes.end()) {
      input_index = iter->second;
    }
    auto input_tensor = ctx.GetInput(input_index);
    auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
    const auto tensor_shapes = tensor_info.GetShape();
    auto& shape_ranges_per_input = shape_ranges[input_name];
    auto trt_profile = trt_profiles[i];
    // If there are multiple profiles, for second and rest of profiles, simply copy the min/max/opt profile values from the first profile.
    // Following "if statement" won't be executed since TRT EP currently only allows single profile for non-explicit profiles case.
    if (i > 0) {
      if (input->isShapeTensor()) {
        // shape tensor
        int shape_size = nb_dims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
        std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
        // getShapeValues() returns a pointer to profile 0's array of shape values; copy element j of each array
        // (dereferencing the pointer directly would copy element 0 into every slot).
        const auto* first_min = trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN);
        const auto* first_max = trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX);
        const auto* first_opt = trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT);
        for (int j = 0; j < shape_size; j++) {
          shapes_min[j] = first_min[j];
          shapes_max[j] = first_max[j];
          shapes_opt[j] = first_opt[j];
        }
        trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, shapes_min.data(), shape_size);
        trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, shapes_max.data(), shape_size);
        trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, shapes_opt.data(), shape_size);
      } else {
        // execution tensor
        nvinfer1::Dims dims_min, dims_opt, dims_max;
        dims_min = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN);
        dims_max = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX);
        dims_opt = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT);
        trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
        trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
        trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
      }
      continue;
    }
    // Create shape profile
    if (input->isShapeTensor()) {
      // Get shape values for shape tensor input
      const auto tensor_type = tensor_info.GetElementType();
      // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension
      int shape_size = dims.nbDims == 0 ? 1 : static_cast<int>(tensor_shapes[0]);
      // For setting TRT optimization profile. (Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10)
      std::vector<int32_t> values(shape_size);
      switch (tensor_type) {
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
          auto buffer = std::make_unique<int32_t[]>(shape_size);
          auto status = GetShapeOfShapeTensor<int32_t>(input_tensor, buffer.get(), shape_size, stream);
          if (!status.IsOK()) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
          }
          shape_tensor_values[input_name].resize(shape_size);
          for (int j = 0; j < shape_size; ++j) {
            shape_tensor_values[input_name][j] = buffer[j];
            values[j] = buffer[j];
          }
          break;
        }
        case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
          auto buffer = std::make_unique<int64_t[]>(shape_size);
          auto status = GetShapeOfShapeTensor<int64_t>(input_tensor, buffer.get(), shape_size, stream);
          if (!status.IsOK()) {
            return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
          }
          shape_tensor_values_int64[input_name].resize(shape_size);
          for (int j = 0; j < shape_size; ++j) {
            shape_tensor_values_int64[input_name][j] = buffer[j];
            values[j] = static_cast<int32_t>(buffer[j]);
          }
          break;
        }
        default: {
          return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                                 "TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.");
        }
      }
      // Update shape ranges
      std::vector<int32_t> shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size);
      int shape_range_size = static_cast<int>(shape_ranges_per_input.size());
      if (shape_size == shape_range_size) {
        // If shape size matches, check/update shape range
        for (int j = 0; j < shape_size; ++j) {
          auto& shape_range = shape_ranges_per_input[j][0];  // only has one profile
          shapes_min[j] = static_cast<int32_t>(shape_range[0]);
          shapes_max[j] = static_cast<int32_t>(shape_range[1]);
          shapes_opt[j] = static_cast<int32_t>(shape_range[2]);
          const auto& tensor_shape_value = values[j];
          // Update shape range lower bound
          if (tensor_shape_value < shape_range[0]) {
            shape_range[0] = tensor_shape_value;
            shapes_min[j] = tensor_shape_value;
            *engine_update = true;
          }
          // Update shape range upper bound
          if (tensor_shape_value > shape_range[1]) {
            shape_range[1] = tensor_shape_value;
            shape_range[2] = tensor_shape_value;
            shapes_max[j] = tensor_shape_value;
            shapes_opt[j] = tensor_shape_value;
            *engine_update = true;
          }
        }
      } else {
        // If shape size doesn't match, initialize shape_range with the new shape value
        shape_ranges_per_input.clear();
        for (int j = 0; j < shape_size; ++j) {
          const auto& tensor_shape_value = values[j];
          std::vector<std::vector<int64_t>> profile_vector;
          std::vector<int64_t> shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value};
          profile_vector.push_back(shape_vector);  // only one profile needed
          shape_ranges_per_input[j] = profile_vector;
          shapes_min[j] = tensor_shape_value;
          shapes_opt[j] = tensor_shape_value;
          shapes_max[j] = tensor_shape_value;
        }
        *engine_update = true;
      }
      // data() instead of &v[0]: the latter is undefined behavior when shape_size is 0.
      trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, shapes_min.data(), shape_size);
      trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, shapes_max.data(), shape_size);
      trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, shapes_opt.data(), shape_size);
    } else {  // Execution tensor
      nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims);
      for (int j = 0, end = nb_dims; j < end; ++j) {
        const auto& tensor_shape = tensor_shapes[j];
        if (shape_ranges_per_input.find(j) != shape_ranges_per_input.end()) {
          auto& shape_range = shape_ranges_per_input[j][0];  // only has one profile
          dims_min.d[j] = static_cast<int32_t>(shape_range[0]);
          dims_max.d[j] = static_cast<int32_t>(shape_range[1]);
          dims_opt.d[j] = static_cast<int32_t>(shape_range[2]);
          // Update minimum dimension
          if (tensor_shape < shape_range[0]) {
            shape_range[0] = tensor_shape;
            dims_min.d[j] = static_cast<int32_t>(tensor_shape);
            *engine_update = true;
          }
          // Update maximum dimension
          if (tensor_shape > shape_range[1]) {
            shape_range[1] = tensor_shape;
            shape_range[2] = tensor_shape;
            dims_max.d[j] = static_cast<int32_t>(tensor_shape);
            dims_opt.d[j] = static_cast<int32_t>(tensor_shape);
            *engine_update = true;
          }
        }
      }
      trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min);
      trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max);
      trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt);
    }
  }
  return Status::OK();
}
// Switch-case fragment used by BindContextInput for input element types that TensorRT consumes as-is.
// The expansion site must provide `input_tensor`, `elem_cnt`, `data`, `scratch_buffers` and `alloc`.
// When the ORT tensor has a non-null buffer and at least one element, `data` points straight at the ORT
// buffer (const_cast because `data` is a non-const void*). Otherwise a 1-byte dummy buffer is allocated
// and kept alive in `scratch_buffers`, so an empty tensor is still bound to a non-null address.
#define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT) \
case DATA_TYPE: { \
auto input_tensor_ptr = input_tensor.GetTensorData<SrcT>(); \
if (input_tensor_ptr != nullptr && elem_cnt > 0) { \
data = const_cast<SrcT*>(input_tensor_ptr); \
} else { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1)); \
data = scratch_buffers.back().get(); \
} \
break; \
}
// Switch-case fragment used by BindContextInput for input element types that must be converted before
// TensorRT can consume them (DOUBLE -> float, and INT64 -> int32 when built against TensorRT < 10).
// The expansion site must provide `input_tensor`, `elem_cnt`, `data`, `scratch_buffers`, `alloc` and `stream`.
// For a non-empty tensor it allocates `elem_cnt * sizeof(DstT)` bytes, enqueues cuda::Impl_Cast<SrcT, DstT>
// on `stream` to fill them, and points `data` at that scratch buffer. An empty tensor gets a 1-byte dummy buffer.
// NOTE(review): the cast is only enqueued here, not synchronized; presumably TensorRT later runs on the
// same `stream` so ordering is preserved. Confirm at the enqueue site.
#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \
case DATA_TYPE: { \
auto input_tensor_ptr = input_tensor.GetTensorData<SrcT>(); \
if (input_tensor_ptr != nullptr && elem_cnt > 0) { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, elem_cnt * sizeof(DstT))); \
data = scratch_buffers.back().get(); \
cuda::Impl_Cast<SrcT, DstT>(stream, input_tensor_ptr, reinterpret_cast<DstT*>(data), elem_cnt); \
} else { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1)); \
data = scratch_buffers.back().get(); \
} \
break; \
}
// Switch-case fragment for an output whose element type needs no conversion.
// The expansion site must provide `output_tensor`, `elem_cnt`, `buffers`, `output_name`,
// `scratch_buffers` and `alloc`. It records the ORT output buffer in `buffers[output_name]`. For an empty
// output (or a null buffer) it records a 1-byte dummy allocation instead, kept alive in `scratch_buffers`.
#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \
case DATA_TYPE: { \
auto output_tensor_ptr = output_tensor.GetTensorMutableData<SrcT>(); \
if (output_tensor_ptr != nullptr && elem_cnt > 0) { \
buffers[output_name] = output_tensor_ptr; \
} else { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1)); \
buffers[output_name] = scratch_buffers.back().get(); \
} \
break; \
}
// Switch-case fragment for an output whose ORT element type (SrcT) differs from the element type of the
// buffer recorded in `buffers[output_name]` (DstT).
// The expansion site must provide `output_tensor`, `elem_cnt`, `buffers`, `output_name`, `output_dim_sizes`,
// `i`, `scratch_buffers` and `alloc`. For a non-empty output it allocates an `elem_cnt * sizeof(DstT)` scratch
// buffer and stores `elem_cnt` in `output_dim_sizes[i]`. For an empty output it uses a 1-byte dummy buffer
// and stores 1.
// NOTE(review): presumably the scratch contents are cast back into the ORT output after execution
// (cf. CASE_CAST_TENSOR). Confirm against the caller.
#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \
case DATA_TYPE: { \
auto output_tensor_ptr = output_tensor.GetTensorMutableData<SrcT>(); \
if (output_tensor_ptr != nullptr && elem_cnt > 0) { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, elem_cnt * sizeof(DstT))); \
buffers[output_name] = scratch_buffers.back().get(); \
output_dim_sizes[i] = static_cast<int>(elem_cnt); \
} else { \
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, 1)); \
buffers[output_name] = scratch_buffers.back().get(); \
output_dim_sizes[i] = 1; \
} \
break; \
}
// Switch-case fragment that copies `elem_cnt * sizeof(DstT)` bytes from `allocator->getBuffer()` into the ORT
// output tensor. It uses an asynchronous device-to-device cudaMemcpyAsync on `stream` and returns early via
// CUDA_RETURN_IF_ERROR if the call fails. Nothing is copied for an empty output.
// The expansion site must provide `output_tensor`, `elem_cnt`, `allocator` and `stream`.
// NOTE(review): `allocator` presumably is the output allocator of a data-dependent-shape output. Confirm.
#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \
case DATA_TYPE: { \
auto output_tensor_ptr = output_tensor.GetTensorMutableData<DstT>(); \
if (output_tensor_ptr != nullptr && elem_cnt > 0) { \
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \
} \
break; \
}
// Switch-case fragment that converts `elem_cnt` SrcT values read from `allocator->getBuffer()` into the ORT
// output tensor's DstT buffer, by enqueueing cuda::Impl_Cast<SrcT, DstT> on `stream`. Nothing happens for an
// empty output. The expansion site must provide `output_tensor`, `elem_cnt`, `allocator` and `stream`.
#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \
case DATA_TYPE: { \
auto output_tensor_ptr = output_tensor.GetTensorMutableData<DstT>(); \
if (output_tensor_ptr != nullptr && elem_cnt > 0) { \
cuda::Impl_Cast<SrcT, DstT>(stream, reinterpret_cast<SrcT*>(allocator->getBuffer()), reinterpret_cast<DstT*>(output_tensor_ptr), elem_cnt); \
} \
break; \
}
/*
 * Bind the values of a "shape tensor" input to the TensorRT execution context.
 *
 * The values are read with GetShapeOfShapeTensor only the first time `input_name` is seen in
 * `shape_tensor_values`; later calls reuse the cached copy. TensorRT is handed the address of the cached
 * host-side std::vector buffer, so that map entry must stay alive and unmodified until TensorRT has read it.
 *
 * @tparam T        Element type of the shape tensor (int32_t or int64_t)
 * @tparam TensorT  Type of the ORT input value returned by the kernel context
 * @return Status::OK() on success; an EP_FAIL status if reading the values or binding the address fails.
 */
template <typename T, typename TensorT>
static Status BindShapeTensorInputValues(TensorT& input_tensor,
                                         nvinfer1::IExecutionContext* trt_context,
                                         const char* input_name,
                                         int shape_size,
                                         std::unordered_map<std::string, std::vector<T>>& shape_tensor_values,
                                         cudaStream_t stream) {
  // Single lookup: indexing the map with a `const char*` builds and hashes a temporary std::string per call.
  auto it = shape_tensor_values.find(input_name);
  if (it == shape_tensor_values.end()) {
    std::vector<T> values(static_cast<size_t>(shape_size));
    auto status = GetShapeOfShapeTensor<T>(input_tensor, values.data(), shape_size, stream);
    if (!status.IsOK()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
    }
    // Only cache the values once they have been read successfully.
    it = shape_tensor_values.emplace(input_name, std::move(values)).first;
  }

  // data() instead of &v[0]: subscripting an empty vector (a 1-D shape tensor of length 0) is undefined behavior.
  if (!trt_context->setTensorAddress(input_name, it->second.data())) {
    std::string error_input_name = input_name;
    return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                           "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" +
                               error_input_name + "'");
  }
  return Status::OK();
}

/*
 * Set TensorRT execution context input.
 *
 * There are two types of input tensor: (1) shape tensor and (2) execution tensor.
 * The input buffer binding needs to be handled differently:
 *   - Shape tensor: its values are copied into a host-side cache (once per inference run) and the cached
 *     buffer is bound.
 *   - Execution tensor: its runtime shape is set with setInputShape(). Its buffer is then bound directly,
 *     through a cast scratch buffer (DOUBLE, and INT64 on TensorRT < 10), or through a 1-byte dummy
 *     allocation when the tensor is empty.
 *
 * @param ctx ORT kernel context that owns the input
 * @param trt_engine TensorRT engine, used to decide whether the input is a shape tensor
 * @param trt_context TensorRT execution context to bind the input on
 * @param input_name Name of the input
 * @param input_index Index of the input in the ORT kernel context
 * @param shape_tensor_values holds "shape tensor -> shape values" for the INT32 shape tensor input across this inference run
 * @param shape_tensor_values_int64 holds "shape tensor -> shape values" for the INT64 shape tensor input across this inference run
 * @param scratch_buffers receives buffers allocated here; bound addresses may point into them, so they must be
 *        kept alive until TensorRT has consumed the input
 * @param alloc allocator used for scratch buffers
 * @param stream CUDA stream used for shape-value reads and input casts
 * @return Status::OK() on success; an EP_FAIL status for an unsupported data type, a rank mismatch between the
 *         ORT input and the engine, or a failed TensorRT binding call.
 */
Status BindContextInput(Ort::KernelContext& ctx,
                        nvinfer1::ICudaEngine* trt_engine,
                        nvinfer1::IExecutionContext* trt_context,
                        const char* input_name,
                        size_t input_index,
                        std::unordered_map<std::string, std::vector<int32_t>>& shape_tensor_values,
                        std::unordered_map<std::string, std::vector<int64_t>>& shape_tensor_values_int64,
                        std::vector<IAllocatorUniquePtr<void>>& scratch_buffers,
                        OrtAllocator* alloc,
                        cudaStream_t stream) {
  auto input_tensor = ctx.GetInput(input_index);
  auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
  const auto tensor_shapes = tensor_info.GetShape();
  const auto tensor_type = tensor_info.GetElementType();
  /*
   * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other).
   * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1.
   *
   * Examples:<br>
   * [] = 1<br>
   * [1,3,4] = 12<br>
   * [2,0,4] = 0<br>
   * [-1,3,4] = -1<br>
   */
  const auto elem_cnt = tensor_info.GetElementCount();

  if (trt_engine->isShapeInferenceIO(input_name)) {
    // Bind "shape tensor" input buffer.
    // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension.
    int shape_size = 1;
    if (trt_engine->getTensorShape(input_name).nbDims != 0) {
      // Guard the tensor_shapes[0] read below: a scalar ORT input has no first dimension.
      if (tensor_shapes.empty()) {
        std::string error_input_name = input_name;
        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                               "TensorRT EP expects shape input '" + error_input_name +
                                   "' to be 1-dimensional, but the ORT input is a scalar");
      }
      shape_size = static_cast<int>(tensor_shapes[0]);
    }

    switch (tensor_type) {
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
        return BindShapeTensorInputValues<int32_t>(input_tensor, trt_context, input_name, shape_size,
                                                   shape_tensor_values, stream);
      case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
        return BindShapeTensorInputValues<int64_t>(input_tensor, trt_context, input_name, shape_size,
                                                   shape_tensor_values_int64, stream);
      default: {
        std::string error_input_name = input_name;
        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                               "The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name);
      }
    }
  } else {
    // Set shape for input tensor which is execution tensor
    nvinfer1::Dims dims = trt_context->getTensorShape(input_name);
    const int nb_dims = dims.nbDims;
    // Guard the tensor_shapes[j] reads below against an ORT input of lower rank than the engine expects.
    if (static_cast<int64_t>(tensor_shapes.size()) < nb_dims) {
      std::string error_input_name = input_name;
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                             "TensorRT EP input '" + error_input_name + "' has rank " +
                                 std::to_string(tensor_shapes.size()) + " but the engine expects rank " +
                                 std::to_string(nb_dims));
    }
    for (int j = 0; j < nb_dims; ++j) {
      dims.d[j] = static_cast<int32_t>(tensor_shapes[j]);
    }
    if (!trt_context->setInputShape(input_name, dims)) {
      std::string error_input_name = input_name;
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                             "TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'");
    }

    // Bind "execution tensor" input buffer
    //
    // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses.
    // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte.
    // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors
    //
    // The CASE_* macros below read `input_tensor`, `elem_cnt`, `scratch_buffers`, `alloc`, `stream` and write `data`.
    void* data = nullptr;
    switch (tensor_type) {
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float)
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t)
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool)
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t)
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t)
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t)
#if NV_TENSORRT_MAJOR >= 10
      CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t)
#else
      // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64
      CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t)
#endif
      // Cast double input to float because TensorRT doesn't support double
      CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float)
      default: {
        return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                               "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.");
      }
    }
    // The original code ignored this result; a failed bind would otherwise surface only later at enqueue time.
    if (!trt_context->setTensorAddress(input_name, data)) {
      std::string error_input_name = input_name;
      return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
                             "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for input '" +
                                 error_input_name + "'");
    }
  }
  return Status::OK();
}
/*
* Bind TensorRT execution context output.
*
* Please note that a "data-dependent shape" (DDS) output needs a corresponding allocator to be provided.
*
*
* param ctx - ORT kernel context
* param trt_context - A pointer to TensorRT Execution context object
* param output_name - Output tensor name
* param output_index - The index of the output to the ORT kernel context
* param output_type - Data type of the output
* param i - Output iteration index
* param output_tensors - Output iteration index to output's ORT value
* param output_dim_sizes - Output iteration index to the product of its shape's dimensions
* param dds_output_set - DDS output set
* param dds_output_allocator_map - DDS output to its allocator
* param scratch_buffer - The allocation buffer created by TRT EP