diff --git a/torch_glow/src/CachingGraphRunner.cpp b/torch_glow/src/CachingGraphRunner.cpp
index abd621ccc1..cbe0b57130 100644
--- a/torch_glow/src/CachingGraphRunner.cpp
+++ b/torch_glow/src/CachingGraphRunner.cpp
@@ -351,8 +351,7 @@ int64_t CachingGraphRunner::runOnJit(torch::jit::Stack &stack) {
   std::lock_guard<std::mutex> guard(runJitLock);
   bool temp = getGlobalPyTorchLoaderSettingsMutable().fusionPassEnabled;
   getGlobalPyTorchLoaderSettingsMutable().fusionPassEnabled = false;
-  int64_t startTime;
-  startTime = TraceEvent::now();
+  int64_t startTime = TraceEvent::now();
   ptGraphExecutor_.run(stack);
   int64_t runTime = TraceEvent::now() - startTime;
   getGlobalPyTorchLoaderSettingsMutable().fusionPassEnabled = temp;
@@ -415,12 +414,6 @@ Error CachingGraphRunner::runImpl(const PerGlowGraphInfo &info,
   // Run the subgraph using JIT for comparison with Glow.
   torch::jit::Stack copyStack;
   if (settings.writeToOnnx || settings.jitVsGlowCompare) {
-
-    // We will use original graph for runOnJit, which means the first input
-    // should be module.
-    if (origGraph_ != nullptr) {
-      copyStack.push_back(module_);
-    }
     for (auto &ival : stack) {
       if (ival.isTensor()) {
         copyStack.push_back(ival.deepcopy());
@@ -886,18 +879,11 @@ Error CachingGraphRunner::warmCache(const std::vector<InputMeta> &inputMeta,
 CachingGraphRunner::CachingGraphRunner(
     std::shared_ptr<torch::jit::Graph> graph,
     std::shared_ptr<runtime::HostManager> hostManager,
-    PyTorchLoaderSettings defaultSettings, bool useRunOnly,
-    std::shared_ptr<torch::jit::Graph> origGraph, c10::IValue module)
-    : graph_(graph), origGraph_(origGraph), ptGraphExecutor_(graph, "forward"),
-      module_(module), hostManager_(hostManager),
+    PyTorchLoaderSettings defaultSettings, bool useRunOnly)
+    : graph_(graph), ptGraphExecutor_(graph, "forward"),
+      hostManager_(hostManager),
       backend_(*EXIT_ON_ERR(hostManager->getBackend())),
       defaultSettings_(std::move(defaultSettings)), useRunOnly_(useRunOnly) {
-
-  if (origGraph_ != nullptr) {
-    ptGraphExecutor_ = torch::jit::GraphExecutor(origGraph_, "forward");
-  } else {
-    ptGraphExecutor_ = torch::jit::GraphExecutor(graph_, "forward");
-  }
   mergedTraceContext_ = glow::make_unique<TraceContext>(TraceLevel::STANDARD);
 }
 
diff --git a/torch_glow/src/CachingGraphRunner.h b/torch_glow/src/CachingGraphRunner.h
index 34876f2ec9..630c21b323 100644
--- a/torch_glow/src/CachingGraphRunner.h
+++ b/torch_glow/src/CachingGraphRunner.h
@@ -60,17 +60,9 @@ class CachingGraphRunner {
   /// for.
   std::shared_ptr<torch::jit::Graph> graph_;
 
-  /// The PyTorch JIT Graph that this CachingGraphRunner caches for before
-  /// any preprocessing is done.
-  std::shared_ptr<torch::jit::Graph> origGraph_;
-
   /// GraphExecutor used to execute graph_ on PyTorch for debugging purposes.
   torch::jit::GraphExecutor ptGraphExecutor_;
 
-  /// The PyTorch module of the graph.
-  /// It is used as first input when running origGraph_ on JIT.
-  c10::IValue module_;
-
   /// The HostManager used to store and run Glow graphs.
   std::shared_ptr<runtime::HostManager> hostManager_;
 
@@ -191,9 +183,7 @@ class CachingGraphRunner {
 
 public:
   CachingGraphRunner(std::shared_ptr<torch::jit::Graph> graph,
                      std::shared_ptr<runtime::HostManager> hostManager,
-                     PyTorchLoaderSettings settings, bool useRunOnly = false,
-                     std::shared_ptr<torch::jit::Graph> origGraph = nullptr,
-                     c10::IValue module = c10::IValue());
+                     PyTorchLoaderSettings settings, bool useRunOnly = false);
 
   ~CachingGraphRunner();
diff --git a/torch_glow/src/TorchGlowBackend.cpp b/torch_glow/src/TorchGlowBackend.cpp
index d93540cce1..95c9e69594 100644
--- a/torch_glow/src/TorchGlowBackend.cpp
+++ b/torch_glow/src/TorchGlowBackend.cpp
@@ -484,11 +484,60 @@ static Error ProcessPackedParams(torch::jit::Graph &graph,
   return Error::success();
 }
 
+/// Implementation of the to_backend preprocess method for Glow. \returns the
+/// preprocessed Module if successful, or an Error otherwise, which is
+/// converted to an exception for handling within PyTorch.
+static Expected<torch::jit::Module>
+preprocessImpl(torch::jit::Module origModule,
+               c10::impl::GenericDict method_compile_spec) {
+  // Preprocess each method
+  for (const auto &kv : method_compile_spec) {
+    const auto &methodName = kv.key().toStringRef();
+    auto method = origModule.get_method(methodName);
+    auto graph = method.graph();
+
+    GraphOutputType graphOutputType;
+    ASSIGN_VALUE_OR_RETURN_ERR(graphOutputType,
+                               checkGraphInputsAndOutputs(*graph));
+
+    // Output lists not supported yet
+    if (graphOutputType == GraphOutputType::TENSOR_LIST) {
+      return MAKE_ERR("Tensor list output not supported.");
+    }
+
+    detail::fuseConcat(graph);
+    torch::jit::Inline(*graph);
+    RewriteQuantPackedParamOps(graph);
+    RETURN_IF_ERR(ProcessPackedParams(*graph, origModule._ivalue()));
+  }
+
+  // Freeze
+  auto preprocModule = torch::jit::freeze_module(origModule);
+
+  // Cleanup JIT graphs
+  for (const auto &kv : method_compile_spec) {
+    const auto &methodName = kv.key().toStringRef();
+    auto method = preprocModule.get_method(methodName);
+    auto graph = method.graph();
+    EliminateDeadCode(graph);
+    EliminateCommonSubexpression(graph);
+    ConstantPooling(graph);
+  }
+
+  return preprocModule;
+}
+
 c10::IValue
 TorchGlowBackend::preprocess(c10::IValue mod,
                              c10::impl::GenericDict method_compile_spec) {
-  // We do nothing in the preprocess, instead we do them in compile()
-  return mod;
+
+  torch::jit::Module origModule = mod.toModule();
+  origModule.eval();
+  auto resOrErr = preprocessImpl(origModule, method_compile_spec);
+  if (!resOrErr) {
+    throw std::runtime_error(ERR_TO_STRING(resOrErr.takeError()));
+  }
+  return resOrErr->_ivalue();
 }
 
 Error applySettingsOverrideFlagsToPyTorchLoaderSettings(
@@ -555,57 +604,17 @@ Error applyFuserSettingsToPyTorchLoaderSettings(
 static Expected<std::unordered_map<
     std::string, std::pair<std::unique_ptr<JITGraphRunner>,
                            std::unique_ptr<CachingGraphRunner>>>>
-compileImpl(const torch::jit::Module &origModule,
+compileImpl(const torch::jit::Module &module,
             const c10::impl::GenericDict &method_compile_spec) {
   std::unordered_map<std::string,
                      std::pair<std::unique_ptr<JITGraphRunner>,
                                std::unique_ptr<CachingGraphRunner>>>
       methodToRunnerMap;
-  std::unordered_map<std::string, std::shared_ptr<torch::jit::Graph>>
-      nameToOrigGraph;
-
-  for (const auto &kv : method_compile_spec) {
-    const auto &methodName = kv.key().toStringRef();
-    auto method = origModule.get_method(methodName);
-    auto graph = method.graph();
-    nameToOrigGraph[methodName] = graph->copy();
-
-    GraphOutputType graphOutputType;
-    ASSIGN_VALUE_OR_RETURN_ERR(graphOutputType,
-                               checkGraphInputsAndOutputs(*graph));
-
-    // Output lists no supported yet
-    if (graphOutputType == GraphOutputType::TENSOR_LIST) {
-      return MAKE_ERR("Tensor list output not supported.");
-    }
-
-    detail::fuseConcat(graph);
-    torch::jit::Inline(*graph);
-    RewriteQuantPackedParamOps(graph);
-    RETURN_IF_ERR(ProcessPackedParams(*graph, origModule._ivalue()));
-  }
-
-  // Freeze
-  auto preprocModule = torch::jit::freeze_module(origModule);
-
-  // Cleanup JIT graphs
-  for (const auto &kv : method_compile_spec) {
-    const auto &methodName = kv.key().toStringRef();
-    auto method = preprocModule.get_method(methodName);
-    auto graph = method.graph();
-    EliminateDeadCode(graph);
-    EliminateCommonSubexpression(graph);
-    ConstantPooling(graph);
-  }
 
   // Compile each method
   for (const auto &kv : method_compile_spec) {
     const auto methodName = kv.key().toString()->string();
-    const auto &method = preprocModule.get_method(methodName);
-    auto it = nameToOrigGraph.find(methodName);
-    CHECK(it != nameToOrigGraph.end())
-        << "Cannot find corresponding original graph for graph: " << methodName;
-    auto origGraph = it->second;
+    const auto &method = module.get_method(methodName);
 
     const CompilationSpec &spec = *kv.value().toCustomClass<CompilationSpec>();
     RETURN_IF_ERR(spec.validate());
@@ -648,7 +657,7 @@ compileImpl(const torch::jit::Module &origModule,
       // Run fusion flow using JIT graph runner
       std::unique_ptr<JITGraphRunner> runner =
           std::make_unique<JITGraphRunner>(
-              preprocModule._ivalue(), graph, baseSettings);
+              module._ivalue(), graph, baseSettings);
       methodToRunnerMap.emplace(methodName,
                                 std::make_pair(nullptr, std::move(runner)));
     } else {
@@ -664,8 +673,7 @@ compileImpl(const torch::jit::Module &origModule,
           graph,
           glow::getHostManager(baseSettings.backendName,
                                baseSettings.numDevices),
-          baseSettings, /*useRunOnly*/ true, origGraph,
-          origModule._ivalue());
+          baseSettings, /*useRunOnly*/ true);
 
       // Compile each compilation group
       for (const auto &compilationGroup : spec.compilation_groups) {
@@ -696,7 +704,6 @@ TorchGlowBackend::compile(c10::IValue processed,
                           c10::impl::GenericDict method_compile_spec) {
   auto module = processed.toModule().clone();
-  module.eval();
 
   auto runnersOrErr = compileImpl(module, method_compile_spec);
   if (!runnersOrErr) {
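
For reviewers who want to exercise the relocated preprocessing in isolation, below is a minimal sketch (not part of the patch) of driving `TorchGlowBackend::preprocess` directly. The include paths, the `lowerForGlow` helper, the default-constructed backend, and the use of `AnyType` for the spec dict's value type are illustrative assumptions, not code from this diff:

```cpp
#include <ATen/core/Dict.h>
#include <torch/csrc/jit/api/module.h>

#include "TorchGlowBackend.h" // assumed torch_glow include path

// Hypothetical helper: lower a module's "forward" method through the new
// preprocess() entry point, which now freezes the module and rewrites each
// method's graph (inlining, concat fusion, quantized packed-param rewrites)
// before compile() ever sees it.
torch::jit::Module lowerForGlow(torch::jit::Module module,
                                c10::IValue compileSpec) {
  // method_compile_spec maps method names to CompilationSpec custom-class
  // IValues; AnyType is used as the value type here for brevity.
  c10::impl::GenericDict methodCompileSpec(c10::StringType::get(),
                                           c10::AnyType::get());
  methodCompileSpec.insert("forward", compileSpec);

  glow::TorchGlowBackend backend; // assumes a public default constructor
  c10::IValue preprocessed =
      backend.preprocess(module._ivalue(), methodCompileSpec);
  return preprocessed.toModule();
}
```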