Skip to content

Commit

Permalink
Gpugraph01 (PaddlePaddle#30)
Browse files Browse the repository at this point in the history
* Optimize memory overhead for gpugraph.

* Optimize memory overhead for gpugraph.

* Add debug codes for HBM
  • Loading branch information
lxsbupt committed Jun 13, 2022
1 parent 77b007e commit a7ce0cf
Show file tree
Hide file tree
Showing 13 changed files with 105 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
}
#endif

DECLARE_bool(gpugraph_enable_hbm_table_collision_stat);

// TODO: can we do this more efficiently?
__inline__ __device__ int8_t atomicCAS(int8_t* address, int8_t compare,
int8_t val) {
Expand Down Expand Up @@ -313,8 +315,7 @@ __host__ __device__ bool operator!=(const cycle_iterator_adapter<T>& lhs,
template <typename Key, typename Element, Key unused_key,
typename Hasher = default_hash<Key>,
typename Equality = equal_to<Key>,
typename Allocator = managed_allocator<thrust::pair<Key, Element>>,
bool count_collisions = false>
typename Allocator = managed_allocator<thrust::pair<Key, Element>>>
class concurrent_unordered_map : public managed {
public:
using size_type = size_t;
Expand Down Expand Up @@ -346,9 +347,13 @@ class concurrent_unordered_map : public managed {
m_allocator(a),
m_hashtbl_size(n),
m_hashtbl_capacity(n),
m_collisions(0),
m_unused_element(
unused_element) { // allocate the raw data of hash table:
m_unused_element(unused_element),
m_enable_collision_stat(false),
m_insert_times(0),
m_insert_collisions(0),
m_query_times(0),
m_query_collisions(0)
{ // allocate the raw data of hash table:
// m_hashtbl_values,pre-alloc it on current GPU if UM.
m_hashtbl_values = m_allocator.allocate(m_hashtbl_capacity);
constexpr int block_size = 128;
Expand All @@ -373,9 +378,9 @@ class concurrent_unordered_map : public managed {
// Initialize kernel, set all entry to unused <K,V>
init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size>>>(
m_hashtbl_values, m_hashtbl_size, unused_key, m_unused_element);
// CUDA_RT_CALL( cudaGetLastError() );
CUDA_RT_CALL(cudaStreamSynchronize(0));
CUDA_RT_CALL(cudaGetLastError());
m_enable_collision_stat = FLAGS_gpugraph_enable_hbm_table_collision_stat;
}

~concurrent_unordered_map() {
Expand Down Expand Up @@ -549,11 +554,16 @@ class concurrent_unordered_map : public managed {
// TODO: How to handle data types less than 32 bits?
if (keys_equal(unused_key, old_key) || keys_equal(insert_key, old_key)) {
update_existing_value(existing_value, x, op);

insert_success = true;
if (m_enable_collision_stat) {
atomicAdd(&m_insert_times, 1);
}
break;
}

if (m_enable_collision_stat) {
atomicAdd(&m_insert_collisions, 1);
}
current_index = (current_index + 1) % hashtbl_size;
current_hash_bucket = &(hashtbl_values[current_index]);
}
Expand Down Expand Up @@ -591,9 +601,9 @@ std::numeric_limits<mapped_type>::is_integer && sizeof(unsigned long long int)
reinterpret_cast<unsigned long long
int*>(tmp_it), unused, value ); if ( old_val == unused ) { it = tmp_it;
}
else if ( count_collisions )
else if ( m_enable_collision_stat )
{
atomicAdd( &m_collisions, 1 );
atomicAdd( &m_insert_collisions, 1 );
}
} else {
const key_type old_key = atomicCAS( &(tmp_it->first), unused_key,
Expand All @@ -602,9 +612,9 @@ x.first );
(m_hashtbl_values+hash_tbl_idx)->second = x.second;
it = tmp_it;
}
else if ( count_collisions )
else if ( m_enable_collision_stat )
{
atomicAdd( &m_collisions, 1 );
atomicAdd( &m_insert_collisions, 1 );
}
}
#else
Expand All @@ -625,8 +635,8 @@ x.second );
}
*/

__forceinline__ __host__ __device__ const_iterator
find(const key_type& k) const {
__forceinline__ __device__ const_iterator
find(const key_type& k) {
size_type key_hash = m_hf(k);
size_type hash_tbl_idx = key_hash % m_hashtbl_size;

Expand All @@ -644,10 +654,17 @@ x.second );
begin_ptr = m_hashtbl_values + m_hashtbl_size;
break;
}
if (m_enable_collision_stat) {
atomicAdd(&m_query_collisions, 1);
}
hash_tbl_idx = (hash_tbl_idx + 1) % m_hashtbl_size;
++counter;
}

if (m_enable_collision_stat) {
atomicAdd(&m_query_times, 1);
}

return const_iterator(m_hashtbl_values, m_hashtbl_values + m_hashtbl_size,
begin_ptr);
}
Expand Down Expand Up @@ -743,7 +760,7 @@ x.second );

int assign_async(const concurrent_unordered_map& other,
cudaStream_t stream = 0) {
m_collisions = other.m_collisions;
m_insert_collisions = other.m_insert_collisions;
if (other.m_hashtbl_size <= m_hashtbl_capacity) {
m_hashtbl_size = other.m_hashtbl_size;
} else {
Expand All @@ -764,10 +781,15 @@ x.second );
init_hashtbl<<<((m_hashtbl_size - 1) / block_size) + 1, block_size, 0,
stream>>>(m_hashtbl_values, m_hashtbl_size, unused_key,
m_unused_element);
if (count_collisions) m_collisions = 0;
if (m_enable_collision_stat) {
m_insert_times = 0;
m_insert_collisions = 0;
m_query_times = 0;
m_query_collisions = 0;
}
}

unsigned long long get_num_collisions() const { return m_collisions; }
unsigned long long get_num_collisions() const { return m_insert_collisions; }

void print() {
for (size_type i = 0; i < 5; ++i) {
Expand Down Expand Up @@ -817,6 +839,17 @@ x.second );
return it;
}

__host__ void print_collision(int id) {
if (m_enable_collision_stat) {
printf("collision stat for hbm table %d, insert(%lu:%lu), query(%lu:%lu)\n",
id,
m_insert_times,
m_insert_collisions,
m_query_times,
m_query_collisions);
}
}

private:
const hasher m_hf;
const key_equal m_equal;
Expand All @@ -829,7 +862,11 @@ x.second );
size_type m_hashtbl_capacity;
value_type* m_hashtbl_values;

unsigned long long m_collisions;
bool m_enable_collision_stat;
uint64_t m_insert_times;
uint64_t m_insert_collisions;
uint64_t m_query_times;
uint64_t m_query_collisions;
};

#endif // CONCURRENT_UNORDERED_MAP_CUH
10 changes: 6 additions & 4 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_ps_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
#include "paddle/fluid/framework/fleet/heter_ps/heter_comm_kernel.h"
#include "paddle/fluid/platform/enforce.h"
#ifdef PADDLE_WITH_HETERPS

DECLARE_double(gpugraph_hbm_table_load_factor);

namespace paddle {
namespace framework {
enum GraphTableType { EDGE_TABLE, FEATURE_TABLE };
Expand All @@ -34,7 +37,8 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
GpuPsGraphTable(std::shared_ptr<HeterPsResource> resource, int topo_aware,
int graph_table_num)
: HeterComm<uint64_t, int64_t, int>(1, resource) {
load_factor_ = 0.25;
load_factor_ = FLAGS_gpugraph_hbm_table_load_factor;
VLOG(0) << "load_factor = " << load_factor_;
rw_lock.reset(new pthread_rwlock_t());
this->graph_table_num_ = graph_table_num;
this->feature_table_num_ = 1;
Expand Down Expand Up @@ -104,9 +108,6 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
}
}
~GpuPsGraphTable() {
// if (cpu_table_status != -1) {
// end_graph_sampling();
// }
}
void build_graph_on_single_gpu(GpuPsCommGraph &g, int gpu_id, int idx);
void build_graph_fea_on_single_gpu(GpuPsCommGraphFea &g, int gpu_id);
Expand Down Expand Up @@ -140,6 +141,7 @@ class GpuPsGraphTable : public HeterComm<uint64_t, int64_t, int> {
uint64_t *src_sample_res,
int *actual_sample_size);
int init_cpu_table(const paddle::distributed::GraphParameter &graph);

int gpu_num;
int graph_table_num_, feature_table_num_;
std::vector<GpuPsCommGraph> gpu_graph_list_;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.cu
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,10 @@ void GraphGpuWrapper::init_service() {
graph_table = (char *)g;
}

void GraphGpuWrapper::finalize() {
((GpuPsGraphTable *)graph_table)->show_table_collisions();
}

void GraphGpuWrapper::upload_batch(int idx,
std::vector<std::vector<uint64_t>> &ids) {
debug_gpu_memory_info("upload_batch node start");
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/graph_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class GraphGpuWrapper {
}
static std::shared_ptr<GraphGpuWrapper> s_instance_;
void initialize();
void finalize();
void set_device(std::vector<int> ids);
void init_service();
void set_up_types(std::vector<std::string>& edge_type,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class XPUCacheArray {
}

void print() {}
void print_collision(int i) {}

#if defined(__xpu__)
__device__ ValType* find(const KeyType& key) {
Expand Down Expand Up @@ -167,6 +168,8 @@ class HashTable {
<< " push value size: " << push_grad_value_size_;
}

void show_collision(int id) { return container_->print_collision(id); }

std::unique_ptr<phi::RWLock> rwlock_{nullptr};

private:
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class HeterComm {
size_t feature_value_size, size_t chunk_size, int stream_num);
void dump();
void show_one_table(int gpu_num);
void show_table_collisions();
int get_index_by_devid(int devid);

#if defined(PADDLE_WITH_CUDA)
Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
#endif

DECLARE_double(gpugraph_hbm_table_load_factor);

namespace paddle {
namespace framework {
template <typename KeyType, typename ValType, typename GradType>
Expand All @@ -30,6 +32,8 @@ HeterComm<KeyType, ValType, GradType>::HeterComm(
resource_ = resource;
storage_.resize(resource_->total_device());
multi_mf_dim_ = resource->multi_mf();
load_factor_ = FLAGS_gpugraph_hbm_table_load_factor;
VLOG(0) << "load_factor = " << load_factor_;
for (int i = 0; i < resource_->total_device(); ++i) {
#if defined(PADDLE_WITH_CUDA)
platform::CUDADeviceGuard guard(resource_->dev_id(i));
Expand Down Expand Up @@ -379,6 +383,22 @@ void HeterComm<KeyType, ValType, GradType>::show_one_table(int gpu_num) {
}
}

// Print hash-collision statistics for every allocated table: first the value
// tables, then the pointer tables, each group numbered independently from 0.
// Unallocated (null) slots are skipped and do not consume an id.
template <typename KeyType, typename ValType, typename GradType>
void HeterComm<KeyType, ValType, GradType>::show_table_collisions() {
  size_t table_id = 0;
  for (auto& value_table : tables_) {
    if (value_table == nullptr) {
      continue;
    }
    value_table->show_collision(table_id);
    ++table_id;
  }
  table_id = 0;
  for (auto& pointer_table : ptr_tables_) {
    if (pointer_table == nullptr) {
      continue;
    }
    pointer_table->show_collision(table_id);
    ++table_id;
  }
}

template <typename KeyType, typename ValType, typename GradType>
int HeterComm<KeyType, ValType, GradType>::log2i(int x) {
unsigned res = 0;
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps.cu
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ void HeterPs::set_multi_mf_dim(int multi_mf_dim, int max_mf_dim) {
comm_->set_multi_mf_dim(multi_mf_dim, max_mf_dim);
}

// Delegate to the underlying HeterComm, which prints collision statistics for
// each of its HBM hash tables.
void HeterPs::show_table_collisions() {
  comm_->show_table_collisions();
}

} // end namespace framework
} // end namespace paddle
#endif
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class HeterPs : public HeterPsBase {
void show_one_table(int gpu_num) override;
void push_sparse(int num, FeatureKey* d_keys, FeaturePushValue* d_grads,
size_t len) override;
void show_table_collisions() override;

private:
std::shared_ptr<HeterComm<FeatureKey, FeatureValue, FeaturePushValue>> comm_;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_ps_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class HeterPsBase {
#endif
virtual void end_pass() = 0;
virtual void show_one_table(int gpu_num) = 0;
virtual void show_table_collisions() = 0;
virtual void push_sparse(int num, FeatureKey* d_keys,
FeaturePushValue* d_grads, size_t len) = 0;

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class PSGPUWrapper {
pre_build_threads_.join();
s_instance_ = nullptr;
VLOG(3) << "PSGPUWrapper Finalize Finished.";
HeterPs_->show_table_collisions();
}

void InitializeGPU(const std::vector<int>& dev_ids) {
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/platform/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,13 @@ DEFINE_bool(enable_slotpool_wait_release, false,
// Fix typo in the user-facing help string ("obejct" -> "object").
DEFINE_bool(enable_slotrecord_reset_shrink, false,
            "enable slotrecord object reset shrink memory, default false");
DEFINE_bool(enable_ins_parser_file, false,
"enable parser ins file , default false");
"enable parser ins file, default false");
// Debug switch: when true, HBM hash tables count insert/query collisions
// (reported by show_table_collisions()).
PADDLE_DEFINE_EXPORTED_bool(
    gpugraph_enable_hbm_table_collision_stat, false,
    "enable hash collisions stat for hbm table, default false");
// Load factor read at construction by GpuPsGraphTable and HeterComm when
// sizing their HBM hash tables.
PADDLE_DEFINE_EXPORTED_double(
    gpugraph_hbm_table_load_factor, 0.75,
    "the load factor of hbm table, default 0.75");

/**
* ProcessGroupNCCL related FLAG
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/fleet_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@ void BindGraphGpuWrapper(py::module* m) {
.def("get_partition", &GraphGpuWrapper::get_partition)
.def("load_node_weight", &GraphGpuWrapper::load_node_weight)
.def("export_partition_files", &GraphGpuWrapper::export_partition_files)
.def("load_node_file", &GraphGpuWrapper::load_node_file);
.def("load_node_file", &GraphGpuWrapper::load_node_file)
.def("finalize", &GraphGpuWrapper::finalize);
}
#endif

Expand Down

0 comments on commit a7ce0cf

Please sign in to comment.