diff --git a/torch/csrc/jit/codegen/cuda/executor.cpp b/torch/csrc/jit/codegen/cuda/executor.cpp index d2299a0ce5497..2fd96e1313c7a 100644 --- a/torch/csrc/jit/codegen/cuda/executor.cpp +++ b/torch/csrc/jit/codegen/cuda/executor.cpp @@ -1098,7 +1098,8 @@ std::vector FusionExecutor::runFusion( void FusionExecutor::compileRtc( const std::string& code, const std::string& name, - bool structured) { + bool structured, + CompileOptions options) { FUSER_PERF_SCOPE("ExecutorRunFusion::compileRtc"); std::string scode; if (!structured) { @@ -1107,7 +1108,7 @@ void FusionExecutor::compileRtc( scode = code; } fusion_id_ = 1; - options_ = CompileOptions(); + options_ = options; std::tie(compiled_kernel_, last_compiler_log_) = executor_utils::nvrtcCompile(scode, name, fusion_id_); diff --git a/torch/csrc/jit/codegen/cuda/executor.h b/torch/csrc/jit/codegen/cuda/executor.h index 8a56fe957fb8b..7ff6b7da3aaad 100644 --- a/torch/csrc/jit/codegen/cuda/executor.h +++ b/torch/csrc/jit/codegen/cuda/executor.h @@ -141,7 +141,8 @@ class TORCH_CUDA_CU_API FusionExecutor : public NonCopyable { void compileRtc( const std::string& code, const std::string& name, - bool structured = false); + bool structured = false, + CompileOptions options = CompileOptions()); //! Internal tests only. Runs the compiled CUDA kernel from compileRtc. void runRtc(