From ee8ef33a5591b534cf587d347af11e48ba7a15d4 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Wed, 17 Aug 2022 11:28:22 -0700
Subject: [PATCH] Minor fix for the debug interface of using PTX directly (#1917)

---
 torch/csrc/jit/codegen/cuda/executor.cpp | 5 +++--
 torch/csrc/jit/codegen/cuda/executor.h   | 3 ++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp
index d2299a0ce5497..2fd96e1313c7a 100644
--- a/torch/csrc/jit/codegen/cuda/executor.cpp
+++ b/torch/csrc/jit/codegen/cuda/executor.cpp
@@ -1098,7 +1098,8 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
 void FusionExecutor::compileRtc(
     const std::string& code,
     const std::string& name,
-    bool structured) {
+    bool structured,
+    CompileOptions options) {
   FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc");
   std::string scode;
   if (!structured) {
@@ -1107,7 +1108,7 @@ void FusionExecutor::compileRtc(
     scode = code;
   }
   fusion_id_ = 1;
-  options_ = CompileOptions();
+  options_ = options;
 
   std::tie(compiled_kernel_, last_compiler_log_) =
       executor_utils::nvrtcCompile(scode, name, fusion_id_);
diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h
index 8a56fe957fb8b..7ff6b7da3aaad 100644
--- a/torch/csrc/jit/codegen/cuda/executor.h
+++ b/torch/csrc/jit/codegen/cuda/executor.h
@@ -141,7 +141,8 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable {
   void compileRtc(
       const std::string& code,
       const std::string& name,
-      bool structured = false);
+      bool structured = false,
+      CompileOptions options = CompileOptions());
 
   //! Internal tests only. Runs the compiled CUDA kernel from compileRtc.
   void runRtc(