[Fix] fix multi device compile error (PaddlePaddle#57530)
Add a device_id directory when dumping compilation information (a sketch of the resulting dump layout follows the changed-file summary below).
Reduce the number of compilation threads during multi-card compilation.
BiynXu committed Sep 21, 2023
1 parent 0fc00cf commit 2170f07
Showing 5 changed files with 88 additions and 31 deletions.
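In effect, artifacts that were previously written to <base>/fusion_group_<idx>/ now go to <base>/device_<device_id>/fusion_group_<idx>/, so dumps from different cards no longer overwrite each other. A minimal sketch of the new path layout; DumpPathFor is a hypothetical stand-in for the utils::StringFormat call used by the real Dump():

#include <cstdio>
#include <string>

// Builds the per-device dump directory introduced by this patch.
std::string DumpPathFor(const std::string& base, int device_id, int group_idx) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), "%s/device_%d/fusion_group_%d",
                base.c_str(), device_id, group_idx);
  return buf;  // e.g. DumpPathFor("./dump", 1, 3) == "./dump/device_1/fusion_group_3"
}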
41 changes: 30 additions & 11 deletions paddle/cinn/backends/compiler.cc
@@ -45,7 +45,7 @@ using CompilationStatus = hlir::framework::CompilationStatus;
static constexpr int DebugLogMaxLen = 30000;

void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
const ir::LoweredFunc& lowered_func, const int gidx) {
const ir::LoweredFunc& lowered_func, const int gidx, const int device_id) {
if (FLAGS_cinn_dump_group_lowered_func.empty() ||
lowered_func.get() == nullptr) {
return;
@@ -54,34 +54,42 @@ void CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
content << lowered_func;
Dump(FLAGS_cinn_dump_group_lowered_func,
gidx,
device_id,
"lowered_function.txt",
content.str());
}

void CompilationInfoDumper::DumpSourceCodeByGroupIndex(
const std::string& source_code, const int gidx) {
const std::string& source_code, const int gidx, const int device_id) {
if (FLAGS_cinn_dump_group_source_code.empty()) {
return;
}
Dump(FLAGS_cinn_dump_group_source_code, gidx, "source_code.cu", source_code);
Dump(FLAGS_cinn_dump_group_source_code,
gidx,
device_id,
"source_code.cu",
source_code);
}

void CompilationInfoDumper::DumpPtxCodeByGroupIndex(
const std::string& source_ptx, const int gidx) {
const std::string& source_ptx, const int gidx, const int device_id) {
if (FLAGS_cinn_dump_group_ptx.empty()) {
return;
}
Dump(FLAGS_cinn_dump_group_ptx, gidx, "source_ptx.ptx", source_ptx);
Dump(
FLAGS_cinn_dump_group_ptx, gidx, device_id, "source_ptx.ptx", source_ptx);
}

void CompilationInfoDumper::DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx) {
const int gidx,
const int device_id) {
if (FLAGS_cinn_dump_group_instruction.empty() || instr.get() == nullptr) {
return;
}
Dump(FLAGS_cinn_dump_group_instruction,
gidx,
device_id,
"instruction.txt",
instr->DumpInstruction());
}
@@ -99,6 +107,7 @@ void CompilationInfoDumper::DumpLoweredFunc() {
}
Dump(FLAGS_cinn_dump_group_lowered_func,
idx,
device_id_,
"lowered_function.txt",
content.str());
}
@@ -115,7 +124,11 @@ void CompilationInfoDumper::DumpSourceCode() {
} else {
dump_str = "[No source code generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_source_code, idx, "source_code.cu", dump_str);
Dump(FLAGS_cinn_dump_group_source_code,
idx,
device_id_,
"source_code.cu",
dump_str);
}
}

@@ -130,7 +143,8 @@ void CompilationInfoDumper::DumpPtxCode() {
} else {
dump_str = "[No source ptxs generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_ptx, idx, "source_ptx.ptx", dump_str);
Dump(
FLAGS_cinn_dump_group_ptx, idx, device_id_, "source_ptx.ptx", dump_str);
}
}

@@ -145,16 +159,21 @@ void CompilationInfoDumper::DumpInstruction() {
} else {
dump_str = "[No instruction generated]\n\n" + info_.Message(idx);
}
Dump(FLAGS_cinn_dump_group_instruction, idx, "instruction.txt", dump_str);
Dump(FLAGS_cinn_dump_group_instruction,
idx,
device_id_,
"instruction.txt",
dump_str);
}
}

void CompilationInfoDumper::Dump(const std::string& base_path,
const int idx,
const int device_id,
const std::string& file_name,
const std::string& content) {
auto dump_path =
utils::StringFormat("%s/fusion_group_%d", base_path.c_str(), idx);
auto dump_path = utils::StringFormat(
"%s/device_%d/fusion_group_%d", base_path.c_str(), device_id, idx);
if (!hlir::framework::MakeDirectory(
dump_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
LOG(WARNING) << "Failed to make directory: \"" << dump_path
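All of these dumps remain opt-in: each Dump*ByGroupIndex helper returns immediately when its FLAGS_cinn_dump_group_* flag is empty. A hedged sketch of enabling one of them programmatically (assumes the flag is visible through its declaration in CINN; the path value is illustrative):

// Enable source-code dumps; after compilation, files appear under
//   ./cinn_dump/device_<device_id>/fusion_group_<gidx>/source_code.cu
FLAGS_cinn_dump_group_source_code = "./cinn_dump";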
19 changes: 13 additions & 6 deletions paddle/cinn/backends/compiler.h
@@ -43,23 +43,28 @@ namespace backends {
*/
class CompilationInfoDumper {
public:
explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info)
: info_(info) {
explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info,
const int device_id)
: info_(info), device_id_(device_id) {
DumpLoweredFunc();
DumpSourceCode();
DumpPtxCode();
DumpInstruction();
}

static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
const int gidx);
const int gidx,
const int device_id);
static void DumpSourceCodeByGroupIndex(const std::string& source_code,
const int gidx);
const int gidx,
const int device_id);
static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
const int gidx);
const int gidx,
const int device_id);
static void DumpInstructionByGroupIndex(
const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
const int gidx);
const int gidx,
const int device_id);

private:
void DumpLoweredFunc();
@@ -68,10 +73,12 @@ class CompilationInfoDumper {
void DumpInstruction();
static void Dump(const std::string& base_path,
const int idx,
const int device_id,
const std::string& file_name,
const std::string& content);

const hlir::framework::CompilationResult& info_;
const int device_id_;
};

class SourceCodePrint {
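The dumper itself is now bound to a device at construction, and the constructor immediately performs all four dumps. One plausible call pattern, shown as a sketch (compilation_result stands for a hlir::framework::CompilationResult produced by an earlier compile step):

int device_id = 0;
#ifdef CINN_WITH_CUDA
cudaGetDevice(&device_id);  // dump for the card this thread is bound to
#endif
backends::CompilationInfoDumper dumper(compilation_result, device_id);
// Lowered functions, source code, PTX and instructions are written under
// <flag base path>/device_<device_id>/fusion_group_<idx>/.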
10 changes: 9 additions & 1 deletion paddle/cinn/hlir/framework/graph.cc
@@ -18,6 +18,9 @@
#include <sstream>

#include "paddle/cinn/hlir/framework/visualize_helper.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#endif
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"

@@ -315,9 +318,14 @@ void Graph::VisualizeGroupedGraph(
const auto& group_dots = VisualizeGroups(groups, fetch_var_ids);
for (int idx = 0; idx < groups.size(); ++idx) {
// Create fusion_group_x folder
int device_id = 0;
#ifdef CINN_WITH_CUDA
cudaGetDevice(&device_id);
#endif
auto group_path =
utils::StringFormat("%s/fusion_group_%d",
utils::StringFormat("%s/device_%d/fusion_group_%d",
FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
device_id,
idx);
if (!MakeDirectory(group_path,
S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
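The device query is compile-time guarded so CPU-only builds keep working: the id defaults to 0 and cudaGetDevice is only reached when CINN_WITH_CUDA is defined. A condensed, self-contained restatement of the pattern:

#ifdef CINN_WITH_CUDA
#include <cuda_runtime.h>
#endif

// Returns the current CUDA device, or 0 when built without CUDA support.
inline int CurrentDeviceIdOrZero() {
  int device_id = 0;
#ifdef CINN_WITH_CUDA
  cudaGetDevice(&device_id);  // error handling omitted in this sketch
#endif
  return device_id;
}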
38 changes: 27 additions & 11 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
@@ -80,8 +80,13 @@ void ParallelCompiler::SplitTask() {
CHECK(context_->lowered_funcs.empty() ||
context_->graph->fusion_groups.size() ==
context_->lowered_funcs.size());
for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) {
tasks_.emplace_back(i, this, context_);
int device_id = 0;
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));
#endif
for (int group_id = 0; group_id < context_->graph->fusion_groups.size();
++group_id) {
tasks_.emplace_back(device_id, group_id, this, context_);
}
}

@@ -126,11 +131,20 @@ void ParallelCompiler::RunTask() {
}

void ParallelCompiler::LaunchTask() {
int device_id = 0;
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));
#endif
int num_threads = FLAGS_cinn_parallel_compile_thread;
#if defined(PADDLE_WITH_DISTRIBUTE)
if (device_id > 0) {
num_threads = 1;
}
#endif
// multi thread compilation
std::vector<std::thread> threads;
VLOG(4) << "Compile with " << FLAGS_cinn_parallel_compile_thread
<< " threads";
for (int idx = 1; idx < FLAGS_cinn_parallel_compile_thread; ++idx) {
VLOG(4) << "Compile with " << num_threads << " threads";
for (int idx = 1; idx < num_threads; ++idx) {
threads.emplace_back(&ParallelCompiler::RunTask, this);
}

@@ -208,7 +222,7 @@ void ParallelCompiler::Task::Lowering() {
pcompiler->result_.SetLoweredFuncs(group_id, lowered_funcs);
}
backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
pcompiler->result_.LoweredFuncs(group_id).front(), group_id);
pcompiler->result_.LoweredFuncs(group_id).front(), group_id, device_id);
}

void ParallelCompiler::Task::CodegenAndJit() {
Expand Down Expand Up @@ -239,8 +253,8 @@ void ParallelCompiler::Task::CodegenAndJit() {
}
CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
<< dmodule;
backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(cuda_c,
group_id);
backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(
cuda_c, group_id, device_id);
pcompiler->result_.SetSourceCode(group_id, cuda_c);

cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);
@@ -249,7 +263,8 @@ void ParallelCompiler::Task::CodegenAndJit() {
backends::nvrtc::Compiler compiler;
auto ptx = compiler(cuda_c);
CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(ptx, group_id);
backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(
ptx, group_id, device_id);
pcompiler->result_.SetSourcePtx(group_id, ptx);
// load cumodule
cumodule = std::make_unique<CUDAModule>(ptx,
@@ -260,7 +275,7 @@ void ParallelCompiler::Task::CodegenAndJit() {
// register kernel
backends::RuntimeSymbols symbols;
for (auto& fn : dmodule.functions()) {
auto cufunc = cumodule->GetFunction(0, fn->name);
auto cufunc = cumodule->GetFunction(device_id, fn->name);
CHECK(cufunc);
symbols.RegisterVar(fn->name + "_ptr_", reinterpret_cast<void*>(cufunc));
}
@@ -291,7 +306,8 @@ void ParallelCompiler::Task::BuildInstruction() {
instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName());

instr->Finalize();
backends::CompilationInfoDumper::DumpInstructionByGroupIndex(instr, group_id);
backends::CompilationInfoDumper::DumpInstructionByGroupIndex(
instr, group_id, device_id);
pcompiler->result_.SetInstruction(group_id, std::move(instr));
}

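The launch policy above can be read in isolation: on distributed builds, every card other than device 0 compiles with a single thread, which keeps concurrent multi-card compilation from oversubscribing the host; worker threads start at index 1 because the calling thread presumably runs tasks as well. A condensed restatement (the tail of LaunchTask, which joins the workers, sits outside the visible hunk and is elided):

int num_threads = FLAGS_cinn_parallel_compile_thread;
#if defined(PADDLE_WITH_DISTRIBUTE)
if (device_id > 0) {
  num_threads = 1;  // secondary cards compile single-threaded
}
#endif
std::vector<std::thread> threads;
VLOG(4) << "Compile with " << num_threads << " threads";
for (int idx = 1; idx < num_threads; ++idx) {
  threads.emplace_back(&ParallelCompiler::RunTask, this);
}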
11 changes: 9 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.h
@@ -36,8 +36,14 @@ namespace framework {
class ParallelCompiler {
public:
struct Task {
Task(int group_id, ParallelCompiler* compiler, CompilationContext* context)
: group_id(group_id), pcompiler(compiler), context(context) {}
Task(int device_id,
int group_id,
ParallelCompiler* compiler,
CompilationContext* context)
: device_id(device_id),
group_id(group_id),
pcompiler(compiler),
context(context) {}
void Lowering();
void CodegenAndJit();
void BuildInstruction();
@@ -48,6 +54,7 @@ class ParallelCompiler {
CompilationStatus status = CompilationStatus::SUCCESS;
std::string message;

const int device_id;
int group_id;

std::unique_ptr<backends::ExecutionEngine> engine;
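A Task is now bound to a device at construction time, and the field is const, so later steps (dump paths, the cuModule function lookup) cannot silently fall back to device 0. A minimal construction sketch (pc and ctx are assumed to be a valid ParallelCompiler* and CompilationContext*; the group id is illustrative):

int device_id = 0;
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));  // CUDA_CALL as used elsewhere in CINN
#endif
ParallelCompiler::Task task(device_id, /*group_id=*/0, pc, ctx);
// task.device_id is const and fixed for the lifetime of the task.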
