From f19ca37c23c7cc63f9560079ee6f062ae230c2fd Mon Sep 17 00:00:00 2001 From: seemingwang Date: Thu, 16 Jun 2022 11:40:09 +0800 Subject: [PATCH] support graph inference (#34) * gpu_graph_infer * simplify infer * fix * remove logs * remove logs * change logs --- paddle/fluid/framework/data_feed.cu | 229 ++++++++++++++-------- paddle/fluid/framework/data_feed.h | 4 +- paddle/fluid/framework/data_feed.proto | 1 + paddle/fluid/framework/device_worker.cc | 68 +++++-- paddle/fluid/framework/device_worker.h | 4 +- paddle/fluid/framework/trainer.cc | 1 - paddle/fluid/framework/trainer_desc.proto | 2 +- python/paddle/fluid/dataset.py | 2 + python/paddle/fluid/trainer_desc.py | 3 + python/paddle/fluid/trainer_factory.py | 3 + 10 files changed, 217 insertions(+), 100 deletions(-) diff --git a/paddle/fluid/framework/data_feed.cu b/paddle/fluid/framework/data_feed.cu index 51293bcb1c547..2327ca6e9115a 100644 --- a/paddle/fluid/framework/data_feed.cu +++ b/paddle/fluid/framework/data_feed.cu @@ -157,6 +157,14 @@ __global__ void GraphFillCVMKernel(int64_t *tensor, int len) { CUDA_KERNEL_LOOP(idx, len) { tensor[idx] = 1; } } +__global__ void CopyDuplicateKeys(int64_t *dist_tensor, uint64_t *src_tensor, + int len) { + CUDA_KERNEL_LOOP(idx, len) { + dist_tensor[idx * 2] = src_tensor[idx]; + dist_tensor[idx * 2 + 1] = src_tensor[idx]; + } +} + int GraphDataGenerator::AcquireInstance(BufState *state) { // if (state->GetNextStep()) { @@ -174,8 +182,9 @@ int GraphDataGenerator::AcquireInstance(BufState *state) { // TODO opt __global__ void GraphFillFeatureKernel(int64_t *id_tensor, int *fill_ins_num, - int64_t *walk, int64_t *feature, int *row, int central_word, - int step, int len, int col_num, int slot_num) { + int64_t *walk, int64_t *feature, + int *row, int central_word, int step, + int len, int col_num, int slot_num) { __shared__ int32_t local_key[CUDA_NUM_THREADS * 16]; __shared__ int local_num; __shared__ int global_num; @@ -191,7 +200,8 @@ __global__ void GraphFillFeatureKernel(int64_t *id_tensor, int *fill_ins_num, size_t dst = atomicAdd(&local_num, 1); for (int i = 0; i < slot_num; ++i) { local_key[dst * 2 * slot_num + i * 2] = feature[src * slot_num + i]; - local_key[dst * 2 * slot_num + i * 2 + 1] = feature[(src + step) * slot_num + i]; + local_key[dst * 2 * slot_num + i * 2 + 1] = + feature[(src + step) * slot_num + i]; } } } @@ -203,8 +213,8 @@ __global__ void GraphFillFeatureKernel(int64_t *id_tensor, int *fill_ins_num, if (threadIdx.x < local_num) { for (int i = 0; i < slot_num; ++i) { - id_tensor[(global_num * 2 + 2 * threadIdx.x) * slot_num + i] - = local_key[(2 * threadIdx.x) * slot_num + i]; + id_tensor[(global_num * 2 + 2 * threadIdx.x) * slot_num + i] = + local_key[(2 * threadIdx.x) * slot_num + i]; id_tensor[(global_num * 2 + 2 * threadIdx.x + 1) * slot_num + i] = local_key[(2 * threadIdx.x + 1) * slot_num + i]; } @@ -247,27 +257,27 @@ __global__ void GraphFillIdKernel(int64_t *id_tensor, int *fill_ins_num, } } -__global__ void GraphFillSlotKernel(int64_t *id_tensor, int64_t* feature_buf, int len, - int total_ins, int slot_num) { +__global__ void GraphFillSlotKernel(int64_t *id_tensor, int64_t *feature_buf, + int len, int total_ins, int slot_num) { CUDA_KERNEL_LOOP(idx, len) { int slot_idx = idx / total_ins; int ins_idx = idx % total_ins; - ((int64_t*)(id_tensor[slot_idx]))[ins_idx] = feature_buf[ins_idx * slot_num + slot_idx]; + ((int64_t *)(id_tensor[slot_idx]))[ins_idx] = + feature_buf[ins_idx * slot_num + slot_idx]; } } -__global__ void GraphFillSlotLodKernelOpt(int64_t *id_tensor, int 
len, int total_ins) { +__global__ void GraphFillSlotLodKernelOpt(int64_t *id_tensor, int len, + int total_ins) { CUDA_KERNEL_LOOP(idx, len) { int slot_idx = idx / total_ins; int ins_idx = idx % total_ins; - ((int64_t*)(id_tensor[slot_idx]))[ins_idx] = ins_idx; + ((int64_t *)(id_tensor[slot_idx]))[ins_idx] = ins_idx; } } __global__ void GraphFillSlotLodKernel(int64_t *id_tensor, int len) { - CUDA_KERNEL_LOOP(idx, len) { - id_tensor[idx] = idx; - } + CUDA_KERNEL_LOOP(idx, len) { id_tensor[idx] = idx; } } int GraphDataGenerator::FillInsBuf() { @@ -296,20 +306,22 @@ int GraphDataGenerator::FillInsBuf() { if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { FillFeatureBuf(d_walk_, d_feature_); if (debug_mode_) { - int len = buf_size_ > 5000? 5000: buf_size_; + int len = buf_size_ > 5000 ? 5000 : buf_size_; uint64_t h_walk[len]; cudaMemcpy(h_walk, d_walk_->ptr(), len * sizeof(uint64_t), - cudaMemcpyDeviceToHost); + cudaMemcpyDeviceToHost); uint64_t h_feature[len * slot_num_]; - cudaMemcpy(h_feature, d_feature_->ptr(), len * slot_num_ * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - for(int i = 0; i < len; ++i) { + cudaMemcpy(h_feature, d_feature_->ptr(), + len * slot_num_ * sizeof(uint64_t), cudaMemcpyDeviceToHost); + for (int i = 0; i < len; ++i) { std::stringstream ss; for (int j = 0; j < slot_num_; ++j) { ss << h_feature[i * slot_num_ + j] << " "; } - VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i << "] = " << (uint64_t)h_walk[i] - << " feature[" << i * slot_num_ << ".." << (i + 1) * slot_num_ << "] = " << ss.str(); + VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i + << "] = " << (uint64_t)h_walk[i] << " feature[" + << i * slot_num_ << ".." << (i + 1) * slot_num_ + << "] = " << ss.str(); } } } @@ -334,10 +346,11 @@ int GraphDataGenerator::FillInsBuf() { int64_t *feature = reinterpret_cast(d_feature_->ptr()); cudaMemsetAsync(d_pair_num, 0, sizeof(int), stream_); int len = buf_state_.len; - VLOG(2) << "feature_buf start[" << ins_buf_pair_len_ * 2 * slot_num_ << "] len[" << len << "]"; + VLOG(2) << "feature_buf start[" << ins_buf_pair_len_ * 2 * slot_num_ + << "] len[" << len << "]"; GraphFillFeatureKernel<<>>( - feature_buf + ins_buf_pair_len_ * 2 * slot_num_, d_pair_num, walk, feature, - random_row + buf_state_.cursor, buf_state_.central_word, + feature_buf + ins_buf_pair_len_ * 2 * slot_num_, d_pair_num, walk, + feature, random_row + buf_state_.cursor, buf_state_.central_word, window_step_[buf_state_.step], len, walk_len_, slot_num_); } @@ -358,8 +371,9 @@ int GraphDataGenerator::FillInsBuf() { if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { int64_t *feature_buf = reinterpret_cast(d_feature_buf_->ptr()); int64_t h_feature_buf[(batch_size_ * 2 * 2) * slot_num_]; - cudaMemcpy(h_feature_buf, feature_buf, (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t), - cudaMemcpyDeviceToHost); + cudaMemcpy(h_feature_buf, feature_buf, + (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t), + cudaMemcpyDeviceToHost); for (int xx = 0; xx < (batch_size_ * 2 * 2) * slot_num_; xx++) { VLOG(2) << "h_feature_buf[" << xx << "]: " << h_feature_buf[xx]; } @@ -368,8 +382,44 @@ int GraphDataGenerator::FillInsBuf() { return ins_buf_pair_len_; } - int GraphDataGenerator::GenerateBatch() { + if (!gpu_graph_training_) { + while (cursor_ < h_device_keys_.size()) { + size_t device_key_size = h_device_keys_[cursor_]->size(); + if (infer_node_type_start_[cursor_] >= device_key_size) { + cursor_++; + continue; + } + int total_instance = + (infer_node_type_start_[cursor_] + 
batch_size_ <= device_key_size) + ? batch_size_ + : device_key_size - infer_node_type_start_[cursor_]; + uint64_t *d_type_keys = + reinterpret_cast(d_device_keys_[cursor_]->ptr()); + d_type_keys += infer_node_type_start_[cursor_]; + infer_node_type_start_[cursor_] += total_instance; + total_instance *= 2; + id_tensor_ptr_ = feed_vec_[0]->mutable_data({total_instance, 1}, + this->place_); + show_tensor_ptr_ = + feed_vec_[1]->mutable_data({total_instance}, this->place_); + clk_tensor_ptr_ = + feed_vec_[2]->mutable_data({total_instance}, this->place_); + /* + cudaMemcpyAsync(id_tensor_ptr_, d_type_keys, sizeof(int64_t) * total_instance, + cudaMemcpyDeviceToDevice, stream_); + */ + CopyDuplicateKeys<<>>(id_tensor_ptr_, d_type_keys, + total_instance / 2); + GraphFillCVMKernel<<>>(show_tensor_ptr_, total_instance); + GraphFillCVMKernel<<>>(clk_tensor_ptr_, total_instance); + return total_instance / 2; + } + return 0; + } platform::CUDADeviceGuard guard(gpuid_); int res = 0; while (ins_buf_pair_len_ < batch_size_) { @@ -393,20 +443,22 @@ int GraphDataGenerator::GenerateBatch() { clk_tensor_ptr_ = feed_vec_[2]->mutable_data({total_instance}, this->place_); - int64_t* slot_tensor_ptr_[slot_num_]; - int64_t* slot_lod_tensor_ptr_[slot_num_]; + int64_t *slot_tensor_ptr_[slot_num_]; + int64_t *slot_lod_tensor_ptr_[slot_num_]; if (slot_num_ > 0) { for (int i = 0; i < slot_num_; ++i) { - slot_tensor_ptr_[i] = - feed_vec_[3 + 2 * i]->mutable_data({total_instance, 1}, this->place_); - slot_lod_tensor_ptr_[i] = - feed_vec_[3 + 2 * i + 1]->mutable_data({total_instance + 1}, this->place_); + slot_tensor_ptr_[i] = feed_vec_[3 + 2 * i]->mutable_data( + {total_instance, 1}, this->place_); + slot_lod_tensor_ptr_[i] = feed_vec_[3 + 2 * i + 1]->mutable_data( + {total_instance + 1}, this->place_); } if (FLAGS_enable_opt_get_features) { - cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), slot_tensor_ptr_, sizeof(int64_t*) * slot_num_, - cudaMemcpyHostToDevice, stream_); - cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(), slot_lod_tensor_ptr_, sizeof(int64_t*) * slot_num_, - cudaMemcpyHostToDevice, stream_); + cudaMemcpyAsync(d_slot_tensor_ptr_->ptr(), slot_tensor_ptr_, + sizeof(int64_t *) * slot_num_, cudaMemcpyHostToDevice, + stream_); + cudaMemcpyAsync(d_slot_lod_tensor_ptr_->ptr(), slot_lod_tensor_ptr_, + sizeof(int64_t *) * slot_num_, cudaMemcpyHostToDevice, + stream_); } } @@ -429,41 +481,48 @@ int GraphDataGenerator::GenerateBatch() { if (debug_mode_) { uint64_t h_walk[total_instance]; cudaMemcpy(h_walk, ins_cursor, total_instance * sizeof(uint64_t), - cudaMemcpyDeviceToHost); + cudaMemcpyDeviceToHost); uint64_t h_feature[total_instance * slot_num_]; - cudaMemcpy(h_feature, feature_buf, total_instance * slot_num_ * sizeof(uint64_t), - cudaMemcpyDeviceToHost); - for(int i = 0; i < total_instance; ++i) { + cudaMemcpy(h_feature, feature_buf, + total_instance * slot_num_ * sizeof(uint64_t), + cudaMemcpyDeviceToHost); + for (int i = 0; i < total_instance; ++i) { std::stringstream ss; for (int j = 0; j < slot_num_; ++j) { ss << h_feature[i * slot_num_ + j] << " "; } - VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i << "] = " << (uint64_t)h_walk[i] - << " feature[" << i * slot_num_ << ".." << (i + 1) * slot_num_ << "] = " << ss.str(); + VLOG(2) << "aft FillFeatureBuf, gpu[" << gpuid_ << "] walk[" << i + << "] = " << (uint64_t)h_walk[i] << " feature[" + << i * slot_num_ << ".." 
<< (i + 1) * slot_num_ + << "] = " << ss.str(); } } GraphFillSlotKernel<<>>( - (int64_t*)d_slot_tensor_ptr_->ptr(), feature_buf, - total_instance * slot_num_, total_instance, slot_num_); + CUDA_NUM_THREADS, 0, stream_>>>( + (int64_t *)d_slot_tensor_ptr_->ptr(), feature_buf, + total_instance * slot_num_, total_instance, slot_num_); GraphFillSlotLodKernelOpt<<>>( - (int64_t*)d_slot_lod_tensor_ptr_->ptr(), (total_instance + 1) * slot_num_, - total_instance + 1); - } else { + CUDA_NUM_THREADS, 0, stream_>>>( + (int64_t *)d_slot_lod_tensor_ptr_->ptr(), + (total_instance + 1) * slot_num_, total_instance + 1); + } else { for (int i = 0; i < slot_num_; ++i) { - int feature_buf_offset = (ins_buf_pair_len_ * 2 - total_instance) * slot_num_ + i * 2; + int feature_buf_offset = + (ins_buf_pair_len_ * 2 - total_instance) * slot_num_ + i * 2; for (int j = 0; j < total_instance; j += 2) { VLOG(2) << "slot_tensor[" << i << "][" << j << "] <- feature_buf[" - << feature_buf_offset + j * slot_num_ << "]"; + << feature_buf_offset + j * slot_num_ << "]"; VLOG(2) << "slot_tensor[" << i << "][" << j + 1 << "] <- feature_buf[" - << feature_buf_offset + j * slot_num_ + 1 << "]"; - cudaMemcpyAsync(slot_tensor_ptr_[i] + j, &feature_buf[feature_buf_offset + j * slot_num_], - sizeof(int64_t) * 2, cudaMemcpyDeviceToDevice, stream_); + << feature_buf_offset + j * slot_num_ + 1 << "]"; + cudaMemcpyAsync(slot_tensor_ptr_[i] + j, + &feature_buf[feature_buf_offset + j * slot_num_], + sizeof(int64_t) * 2, cudaMemcpyDeviceToDevice, + stream_); } - GraphFillSlotLodKernel<<>>( - slot_lod_tensor_ptr_[i], total_instance + 1); + GraphFillSlotLodKernel<<>>(slot_lod_tensor_ptr_[i], + total_instance + 1); } } } @@ -487,18 +546,21 @@ int GraphDataGenerator::GenerateBatch() { int64_t h_slot_tensor[slot_num_][total_instance]; int64_t h_slot_lod_tensor[slot_num_][total_instance + 1]; for (int i = 0; i < slot_num_; ++i) { - cudaMemcpy(h_slot_tensor[i], slot_tensor_ptr_[i], total_instance * sizeof(int64_t), - cudaMemcpyDeviceToHost); - int len = total_instance > 5000? 5000: total_instance; - for(int j = 0; j < len; ++j) { - VLOG(2) << "gpu[" << gpuid_ << "] slot_tensor[" << i <<"][" << j << "] = " << h_slot_tensor[i][j]; + cudaMemcpy(h_slot_tensor[i], slot_tensor_ptr_[i], + total_instance * sizeof(int64_t), cudaMemcpyDeviceToHost); + int len = total_instance > 5000 ? 5000 : total_instance; + for (int j = 0; j < len; ++j) { + VLOG(2) << "gpu[" << gpuid_ << "] slot_tensor[" << i << "][" << j + << "] = " << h_slot_tensor[i][j]; } - cudaMemcpy(h_slot_lod_tensor[i], slot_lod_tensor_ptr_[i], (total_instance + 1) * sizeof(int64_t), - cudaMemcpyDeviceToHost); - len = total_instance + 1 > 5000? 5000: total_instance + 1; - for(int j = 0; j < len; ++j) { - VLOG(2) << "gpu[" << gpuid_ << "] slot_lod_tensor[" << i <<"][" << j << "] = " << h_slot_lod_tensor[i][j]; + cudaMemcpy(h_slot_lod_tensor[i], slot_lod_tensor_ptr_[i], + (total_instance + 1) * sizeof(int64_t), + cudaMemcpyDeviceToHost); + len = total_instance + 1 > 5000 ? 
5000 : total_instance + 1; + for (int j = 0; j < len; ++j) { + VLOG(2) << "gpu[" << gpuid_ << "] slot_lod_tensor[" << i << "][" << j + << "] = " << h_slot_lod_tensor[i][j]; } } } @@ -625,21 +687,25 @@ void GraphDataGenerator::FillOneStep(uint64_t *d_start_ids, uint64_t *walk, cur_sampleidx2row_ = 1 - cur_sampleidx2row_; } -int GraphDataGenerator::FillFeatureBuf(int64_t* d_walk, int64_t* d_feature, size_t key_num) { +int GraphDataGenerator::FillFeatureBuf(int64_t *d_walk, int64_t *d_feature, + size_t key_num) { platform::CUDADeviceGuard guard(gpuid_); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - int ret = gpu_graph_ptr->get_feature_of_nodes(gpuid_, d_walk, d_feature, key_num, slot_num_); + int ret = gpu_graph_ptr->get_feature_of_nodes(gpuid_, d_walk, d_feature, + key_num, slot_num_); return ret; } -int GraphDataGenerator::FillFeatureBuf(std::shared_ptr d_walk, - std::shared_ptr d_feature) { +int GraphDataGenerator::FillFeatureBuf( + std::shared_ptr d_walk, + std::shared_ptr d_feature) { platform::CUDADeviceGuard guard(gpuid_); auto gpu_graph_ptr = GraphGpuWrapper::GetInstance(); - int ret = gpu_graph_ptr->get_feature_of_nodes(gpuid_, (int64_t*)d_walk->ptr(), - (int64_t*)d_feature->ptr(), buf_size_, slot_num_); + int ret = gpu_graph_ptr->get_feature_of_nodes( + gpuid_, (int64_t *)d_walk->ptr(), (int64_t *)d_feature->ptr(), buf_size_, + slot_num_); return ret; } @@ -801,7 +867,7 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, // d_device_keys_.resize(h_device_keys_.size()); VLOG(2) << "h_device_keys size: " << h_device_keys_.size(); - + infer_node_type_start_ = std::vector(h_device_keys_.size(), 0); for (size_t i = 0; i < h_device_keys_.size(); i++) { for (size_t j = 0; j < h_device_keys_[i]->size(); j++) { VLOG(3) << "h_device_keys_[" << i << "][" << j @@ -832,8 +898,10 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, d_walk_ = memory::AllocShared(place_, buf_size_ * sizeof(uint64_t)); cudaMemsetAsync(d_walk_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_); if (!FLAGS_enable_opt_get_features && slot_num_ > 0) { - d_feature_ = memory::AllocShared(place_, buf_size_ * slot_num_ * sizeof(uint64_t)); - cudaMemsetAsync(d_feature_->ptr(), 0, buf_size_ * sizeof(uint64_t), stream_); + d_feature_ = + memory::AllocShared(place_, buf_size_ * slot_num_ * sizeof(uint64_t)); + cudaMemsetAsync(d_feature_->ptr(), 0, buf_size_ * sizeof(uint64_t), + stream_); } d_sample_keys_ = memory::AllocShared(place_, once_max_sample_keynum * sizeof(uint64_t)); @@ -862,13 +930,15 @@ void GraphDataGenerator::AllocResource(const paddle::platform::Place &place, d_ins_buf_ = memory::AllocShared(place_, (batch_size_ * 2 * 2) * sizeof(int64_t)); if (slot_num_ > 0) { - d_feature_buf_ = - memory::AllocShared(place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t)); + d_feature_buf_ = memory::AllocShared( + place_, (batch_size_ * 2 * 2) * slot_num_ * sizeof(int64_t)); } d_pair_num_ = memory::AllocShared(place_, sizeof(int)); if (FLAGS_enable_opt_get_features && slot_num_ > 0) { - d_slot_tensor_ptr_ = memory::AllocShared(place_, slot_num_ * sizeof(int64_t*)); - d_slot_lod_tensor_ptr_ = memory::AllocShared(place_, slot_num_ * sizeof(int64_t*)); + d_slot_tensor_ptr_ = + memory::AllocShared(place_, slot_num_ * sizeof(int64_t *)); + d_slot_lod_tensor_ptr_ = + memory::AllocShared(place_, slot_num_ * sizeof(int64_t *)); } cudaStreamSynchronize(stream_); @@ -887,6 +957,7 @@ void GraphDataGenerator::SetConfig( } else { batch_size_ = once_sample_startid_len_; } + 
gpu_graph_training_ = graph_config.gpu_graph_training(); repeat_time_ = graph_config.sample_times_one_chunk(); buf_size_ = once_sample_startid_len_ * walk_len_ * walk_degree_ * repeat_time_; diff --git a/paddle/fluid/framework/data_feed.h b/paddle/fluid/framework/data_feed.h index 29d045ea936ab..371c5e72850f5 100644 --- a/paddle/fluid/framework/data_feed.h +++ b/paddle/fluid/framework/data_feed.h @@ -897,7 +897,7 @@ class GraphDataGenerator { int FillWalkBuf(std::shared_ptr d_walk); int FillFeatureBuf(int64_t* d_walk, int64_t* d_feature, size_t key_num); int FillFeatureBuf(std::shared_ptr d_walk, - std::shared_ptr d_feature); + std::shared_ptr d_feature); void FillOneStep(uint64_t* start_ids, uint64_t* walk, int len, NeighborSampleResult& sample_res, int cur_degree, int step, int* len_per_row); @@ -944,6 +944,7 @@ class GraphDataGenerator { std::set finish_node_type_; std::unordered_map node_type_start_; + std::vector infer_node_type_start_; std::shared_ptr d_ins_buf_; std::shared_ptr d_feature_buf_; @@ -962,6 +963,7 @@ class GraphDataGenerator { int debug_mode_; std::vector first_node_type_; std::vector> meta_path_; + bool gpu_graph_training_; }; class DataFeed { diff --git a/paddle/fluid/framework/data_feed.proto b/paddle/fluid/framework/data_feed.proto index fe606630f9218..a7ab70948795f 100644 --- a/paddle/fluid/framework/data_feed.proto +++ b/paddle/fluid/framework/data_feed.proto @@ -37,6 +37,7 @@ message GraphConfig { optional int32 debug_mode = 7 [ default = 0 ]; optional string first_node_type = 8; optional string meta_path = 9; + optional bool gpu_graph_training = 10 [ default = true ]; } message DataFeedDesc { diff --git a/paddle/fluid/framework/device_worker.cc b/paddle/fluid/framework/device_worker.cc index 880261436831d..7df1808b85fe9 100644 --- a/paddle/fluid/framework/device_worker.cc +++ b/paddle/fluid/framework/device_worker.cc @@ -32,42 +32,62 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) { } template -std::string PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end) { +std::string PrintLodTensorType(Tensor* tensor, int64_t start, int64_t end, + char separator = ':', + bool need_leading_separator = true) { auto count = tensor->numel(); if (start < 0 || end > count) { VLOG(3) << "access violation"; return "access violation"; } + if (start >= end) return ""; std::ostringstream os; + if (!need_leading_separator) { + os << tensor->data()[start]; + start++; + } for (int64_t i = start; i < end; i++) { - os << ":" << tensor->data()[i]; + // os << ":" << tensor->data()[i]; + os << separator << tensor->data()[i]; } return os.str(); } -std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end) { +std::string PrintLodTensorIntType(Tensor* tensor, int64_t start, int64_t end, + char separator = ':', + bool need_leading_separator = true) { auto count = tensor->numel(); if (start < 0 || end > count) { VLOG(3) << "access violation"; return "access violation"; } + if (start >= end) return ""; std::ostringstream os; + if (!need_leading_separator) { + os << static_cast(tensor->data()[start]); + start++; + } for (int64_t i = start; i < end; i++) { - os << ":" << static_cast(tensor->data()[i]); + // os << ":" << static_cast(tensor->data()[i]); + os << separator << static_cast(tensor->data()[i]); } return os.str(); } -std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end) { +std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end, + char separator, bool need_leading_separator) { std::string out_val; if 
(framework::TransToProtoVarType(tensor->dtype()) == proto::VarType::FP32) { - out_val = PrintLodTensorType(tensor, start, end); + out_val = PrintLodTensorType(tensor, start, end, separator, + need_leading_separator); } else if (framework::TransToProtoVarType(tensor->dtype()) == proto::VarType::INT64) { - out_val = PrintLodTensorIntType(tensor, start, end); + out_val = PrintLodTensorIntType(tensor, start, end, separator, + need_leading_separator); } else if (framework::TransToProtoVarType(tensor->dtype()) == proto::VarType::FP64) { - out_val = PrintLodTensorType(tensor, start, end); + out_val = PrintLodTensorType(tensor, start, end, separator, + need_leading_separator); } else { out_val = "unsupported type"; } @@ -122,6 +142,11 @@ void DeviceWorker::DumpParam(const Scope& scope, const int batch_id) { } void DeviceWorker::InitRandomDumpConfig(const TrainerDesc& desc) { + bool is_dump_in_simple_mode = desc.is_dump_in_simple_mode(); + if (is_dump_in_simple_mode) { + dump_mode_ = 3; + return; + } bool enable_random_dump = desc.enable_random_dump(); if (!enable_random_dump) { dump_mode_ = 0; @@ -139,7 +164,7 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode, int dump_interval) { // dump_mode: 0: no random, // 1: random with insid hash, // 2: random with random - // number + // 3: simple mode size_t batch_size = device_reader_->GetCurBatchSize(); auto& ins_id_vec = device_reader_->GetInsIdVec(); auto& ins_content_vec = device_reader_->GetInsContentVec(); @@ -163,12 +188,15 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode, } hit[i] = true; } - for (size_t i = 0; i < ins_id_vec.size(); i++) { - if (!hit[i]) { - continue; + + if (dump_mode_ != 3) { + for (size_t i = 0; i < ins_id_vec.size(); i++) { + if (!hit[i]) { + continue; + } + ars[i] += ins_id_vec[i]; + ars[i] = ars[i] + "\t" + ins_content_vec[i]; } - ars[i] += ins_id_vec[i]; - ars[i] = ars[i] + "\t" + ins_content_vec[i]; } for (auto& field : *dump_fields_) { Variable* var = scope.FindVar(field); @@ -195,14 +223,20 @@ void DeviceWorker::DumpField(const Scope& scope, int dump_mode, "wrong "; continue; } + for (size_t i = 0; i < batch_size; ++i) { if (!hit[i]) { continue; } auto bound = GetTensorBound(tensor, i); - ars[i] = ars[i] + "\t" + field + ":" + - std::to_string(bound.second - bound.first); - ars[i] += PrintLodTensor(tensor, bound.first, bound.second); + if (dump_mode_ == 3) { + if (ars[i].size() > 0) ars[i] += "\t"; + ars[i] += PrintLodTensor(tensor, bound.first, bound.second, ' ', false); + } else { + ars[i] = ars[i] + "\t" + field + ":" + + std::to_string(bound.second - bound.first); + ars[i] += PrintLodTensor(tensor, bound.first, bound.second); + } } } // #pragma omp parallel for diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 895e459a37dd7..1952e40d5e310 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -59,7 +59,9 @@ class Scope; namespace paddle { namespace framework { -std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end); +std::string PrintLodTensor(Tensor* tensor, int64_t start, int64_t end, + char separator = ',', + bool need_leading_separator = false); std::pair GetTensorBound(LoDTensor* tensor, int index); bool CheckValidOutput(LoDTensor* tensor, size_t batch_size); diff --git a/paddle/fluid/framework/trainer.cc b/paddle/fluid/framework/trainer.cc index b033f9a99d6d9..28c8081409c1d 100644 --- a/paddle/fluid/framework/trainer.cc +++ b/paddle/fluid/framework/trainer.cc @@ -57,7 
+57,6 @@ void TrainerBase::DumpWork(int tid) { int err_no = 0; // GetDumpPath is implemented in each Trainer std::string path = GetDumpPath(tid); - std::shared_ptr fp = fs_open_write(path, &err_no, dump_converter_); while (1) { std::string out_str; diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 6fe33545aa22d..daded21ec62d9 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -68,7 +68,7 @@ message TrainerDesc { // add for gpu optional string fleet_desc = 37; - + optional bool is_dump_in_simple_mode = 38 [ default = false ]; // device worker parameters optional HogwildWorkerParameter hogwild_param = 101; optional DownpourWorkerParameter downpour_param = 103; diff --git a/python/paddle/fluid/dataset.py b/python/paddle/fluid/dataset.py index 70c7c0fb8c438..9b8aa28c48882 100644 --- a/python/paddle/fluid/dataset.py +++ b/python/paddle/fluid/dataset.py @@ -1062,6 +1062,8 @@ def set_graph_config(self, config): self.proto_desc.graph_config.first_node_type = config.get( "first_node_type", "") self.proto_desc.graph_config.meta_path = config.get("meta_path", "") + self.proto_desc.graph_config.gpu_graph_training = config.get( + "gpu_graph_training", True) class QueueDataset(DatasetBase): diff --git a/python/paddle/fluid/trainer_desc.py b/python/paddle/fluid/trainer_desc.py index cdc9b14b6e328..2e9658d422aa1 100644 --- a/python/paddle/fluid/trainer_desc.py +++ b/python/paddle/fluid/trainer_desc.py @@ -155,6 +155,9 @@ def _set_dump_fields(self, dump_fields): for field in dump_fields: self.proto_desc.dump_fields.append(field) + def _set_is_dump_in_simple_mode(self, is_dump_in_simple_mode): + self.proto_desc.is_dump_in_simple_mode = is_dump_in_simple_mode + def _set_dump_fields_path(self, path): self.proto_desc.dump_fields_path = path diff --git a/python/paddle/fluid/trainer_factory.py b/python/paddle/fluid/trainer_factory.py index d64f4f17ae323..1267c7b8d97a2 100644 --- a/python/paddle/fluid/trainer_factory.py +++ b/python/paddle/fluid/trainer_factory.py @@ -83,6 +83,9 @@ def _create_trainer(self, opt_info=None): trainer._set_worker_places(opt_info["worker_places"]) if opt_info.get("use_ps_gpu") is not None: trainer._set_use_ps_gpu(opt_info["use_ps_gpu"]) + if opt_info.get("is_dump_in_simple_mode") is not None: + trainer._set_is_dump_in_simple_mode(opt_info[ + "is_dump_in_simple_mode"]) if opt_info.get("enable_random_dump") is not None: trainer._set_enable_random_dump(opt_info[ "enable_random_dump"])
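
Usage sketch (assumptions stated inline): the two user-facing switches this patch introduces are the "gpu_graph_training" key accepted by set_graph_config() in python/paddle/fluid/dataset.py and the "is_dump_in_simple_mode" entry read from opt_info in python/paddle/fluid/trainer_factory.py. The snippet below is a minimal sketch, assuming an InMemoryDataset-style dataset created through fluid.DatasetFactory() and assuming opt_info is the dict eventually handed to TrainerFactory._create_trainer (for example, the program's fleet options when running train_from_dataset); all field values, node types, and dump paths are illustrative, not taken from this patch.

# Minimal sketch, assumptions noted in comments; not part of the patch itself.
import paddle.fluid as fluid

# Assumption: an InMemoryDataset-style dataset that exposes set_graph_config()
# (the method extended by this patch in python/paddle/fluid/dataset.py).
dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")

dataset.set_graph_config({
    "first_node_type": "user",   # illustrative value
    "meta_path": "u2i-i2u",      # illustrative value
    # Added by this patch (GraphConfig field 10, default true):
    # False switches GraphDataGenerator::GenerateBatch into the inference
    # branch that emits device keys instead of sampled walk pairs.
    "gpu_graph_training": False,
})

# Added by this patch: simple-mode dump. TrainerFactory forwards
# "is_dump_in_simple_mode" from opt_info into the trainer desc, and
# DeviceWorker then uses dump_mode_ == 3: the ins_id/ins_content prefix is
# skipped and tensor values are printed space-separated with no leading
# separator.
# Assumption: opt_info is the dict handed to TrainerFactory._create_trainer
# (e.g. the program's fleet options); the dump keys shown here are illustrative.
opt_info = {
    "dump_fields_path": "/tmp/graph_infer_dump",
    "is_dump_in_simple_mode": True,
}

For orientation, the inference branch added to GraphDataGenerator::GenerateBatch takes up to batch_size_ keys from h_device_keys_ for the current node type, copies each key twice into the id tensor via the new CopyDuplicateKeys kernel, fills the show and clk tensors with ones through GraphFillCVMKernel, and returns the number of keys consumed.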