[GpuGraph]direct access (PaddlePaddle#36)
* direct access

* format

* log level
Thunderbrook committed Jun 16, 2022
1 parent 0969831 commit 8ce3cf9
Showing 3 changed files with 78 additions and 48 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/distributed/ps/table/ctr_dymf_accessor.cc
@@ -296,7 +296,7 @@ std::string CtrDymfAccessor::ParseToString(const float* v, int param) {
auto score = ShowClickScore(show, click);
if (score >= _config.embedx_threshold() &&
param > common_feature_value.EmbedxG2SumIndex()) {
VLOG(0) << "common_feature_value.EmbedxG2SumIndex():"
VLOG(3) << "common_feature_value.EmbedxG2SumIndex():"
<< common_feature_value.EmbedxG2SumIndex();
for (auto i = common_feature_value.EmbedxG2SumIndex();
i < common_feature_value.Dim(); ++i) {
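Note: the only change in this file lowers the log line from VLOG(0) to VLOG(3). With glog-style verbose logging (which Paddle uses), a VLOG(n) message is emitted only when the configured verbosity (GLOG_v / --v) is at least n, so the message disappears from default runs. A tiny stand-alone sketch of that gating; MY_VLOG and verbosity() are illustrative stand-ins, not Paddle's real logging code:

// Toy model of VLOG-style verbosity gating: a message at level N is printed
// only when the configured verbosity is >= N, so moving a line from VLOG(0)
// to VLOG(3) silences it unless GLOG_v >= 3 (or similar) is set.
#include <cstdlib>
#include <iostream>

static int verbosity() {
  const char* v = std::getenv("GLOG_v");  // assumed glog-style env var
  return v ? std::atoi(v) : 0;
}

#define MY_VLOG(level) \
  if ((level) <= verbosity()) std::cout

int main() {
  MY_VLOG(0) << "printed at default verbosity\n";
  MY_VLOG(3) << "printed only when GLOG_v >= 3\n";
  return 0;
}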
4 changes: 3 additions & 1 deletion paddle/fluid/framework/fleet/heter_ps/heter_comm.h
@@ -218,7 +218,8 @@ class HeterComm {
#endif
}

void create_storage(int start_index, int end_index, size_t keylen, size_t vallen);
void create_storage(int start_index, int end_index, size_t keylen,
size_t vallen);
void destroy_storage(int start_index, int end_index);
void walk_to_dest(int start_index, int gpu_num, int* h_left, int* h_right,
KeyType* src_key, GradType* src_val);
@@ -238,6 +239,7 @@ class HeterComm {
std::vector<std::vector<Path>> path_;
float load_factor_{0.75};
int block_size_{256};
int direct_access_ = 1;
std::unique_ptr<HeterCommKernel> heter_comm_kernel_;

private:
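Note on the header change: direct_access_ is the new switch (defaulted to 1 here) that the heter_comm_inl.h changes below consult. With it off, pull/push keep the existing staged path (create_storage, walk_to_dest, walk_to_src, destroy_storage); with it on, each device's table is handed the caller's shard buffers in place, which presumably assumes the devices can address each other's memory directly. A minimal host-only sketch of that split; Table, staged_pull and direct_pull are hypothetical stand-ins, not the real HeterComm/HashTable API:

// Simplified model of the two paths toggled by a direct_access flag.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <unordered_map>
#include <vector>

struct Table {  // one per device: key -> fixed-size value blob
  std::unordered_map<uint64_t, std::vector<char>> data;
  void get(const uint64_t* keys, char* vals, int n, size_t val_size) const {
    for (int i = 0; i < n; ++i) {
      auto it = data.find(keys[i]);
      if (it != data.end())
        std::memcpy(vals + i * val_size, it->second.data(), val_size);
    }
  }
};

// Staged path: copy keys into temporary storage, let the table fill a
// temporary value buffer, then copy values back (mimics create_storage +
// walk_to_dest + get + walk_to_src + destroy_storage).
void staged_pull(const Table& t, const uint64_t* shard_keys, char* shard_vals,
                 int n, size_t val_size) {
  std::vector<uint64_t> key_storage(shard_keys, shard_keys + n);
  std::vector<char> val_storage(n * val_size);
  t.get(key_storage.data(), val_storage.data(), n, val_size);
  std::memcpy(shard_vals, val_storage.data(), val_storage.size());
}

// Direct path: hand the caller's buffers straight to the table.
void direct_pull(const Table& t, const uint64_t* shard_keys, char* shard_vals,
                 int n, size_t val_size) {
  t.get(shard_keys, shard_vals, n, val_size);
}

int main() {
  const size_t val_size = 8;
  Table t;
  t.data[42] = std::vector<char>(val_size, 7);
  uint64_t keys[1] = {42};
  std::vector<char> vals(val_size, 0);
  bool direct_access = true;  // mirrors direct_access_ = 1 in the header
  if (direct_access)
    direct_pull(t, keys, vals.data(), 1, val_size);
  else
    staged_pull(t, keys, vals.data(), 1, val_size);
  std::cout << static_cast<int>(vals[0]) << std::endl;  // prints 7
  return 0;
}

The sketch only models the control flow: the real code performs the copies and lookups on GPU streams and still takes the table's read/write lock in both modes.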
120 changes: 74 additions & 46 deletions paddle/fluid/framework/fleet/heter_ps/heter_comm_inl.h
@@ -772,27 +772,37 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);

for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
if (!direct_access_) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
create_storage(num, i, shard_len * sizeof(KeyType),
shard_len * val_type_size);
}
create_storage(num, i, shard_len * sizeof(KeyType),
shard_len * val_type_size);
walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);
}
walk_to_dest(num, total_device, h_left, h_right, d_shard_keys_ptr, NULL);

for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1) {
continue;
}
auto& node = path_[num][i].nodes_.back();
sync_stream(node.in_stream);
if (!direct_access_) {
sync_stream(node.in_stream);
}
AnyDeviceGuard guard(resource_->dev_id(i));
ptr_tables_[i]->rwlock_->RDLock();
ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
if (!direct_access_) {
ptr_tables_[i]->get(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
resource_->remote_stream(i, num));
} else {
ptr_tables_[i]->get(
d_shard_keys_ptr + h_left[i],
reinterpret_cast<char*>(d_shard_vals_ptr) + h_left[i] * val_type_size,
h_right[i] - h_left[i] + 1, resource_->remote_stream(i, num));
}
}

for (int i = 0; i < total_device; ++i) {
@@ -802,21 +812,25 @@ void HeterComm<KeyType, ValType, GradType>::pull_sparse(int num,
}
ptr_tables_[i]->rwlock_->UNLock();
}
walk_to_src(num, total_device, h_left, h_right,
reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
for (int i = 0; i < total_device; ++i) {
auto& node = path_[num][i].nodes_.front();
sync_stream(node.out_stream);
if (!direct_access_) {
walk_to_src(num, total_device, h_left, h_right,
reinterpret_cast<char*>(d_shard_vals_ptr), val_type_size);
for (int i = 0; i < total_device; ++i) {
auto& node = path_[num][i].nodes_.front();
sync_stream(node.out_stream);
}
}
heter_comm_kernel_->dy_mf_fill_dvals(d_shard_vals_ptr, d_vals, d_idx_ptr, len,
val_type_size, stream);

sync_stream(stream);
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
if (!direct_access_) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
destroy_storage(num, i);
}
destroy_storage(num, i);
}
}
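Note on the pull path above: direct access only changes where the table reads its inputs and writes its outputs. The keys have already been grouped by destination device, so device i's portion of the request is rows h_left[i] .. h_right[i] of the shard buffers; the direct branch therefore passes d_shard_keys_ptr + h_left[i] and the value buffer at byte offset h_left[i] * val_type_size, with length h_right[i] - h_left[i] + 1, and the staging buffers plus the final destroy_storage loop are skipped. A small host-only sketch of that slicing (device count, sizes and key values are made up):

// Illustrates the per-device slicing used by the direct pull path: device i
// owns rows [h_left[i], h_right[i]] of the grouped key buffer, and its values
// start at byte offset h_left[i] * val_type_size in a flat char buffer.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int total_device = 2;
  const size_t val_type_size = 16;  // bytes per value slot (made up)
  std::vector<uint64_t> shard_keys = {10, 11, 20, 21, 22};  // grouped by device
  std::vector<char> shard_vals(shard_keys.size() * val_type_size);
  int h_left[total_device] = {0, 2};   // first row owned by each device
  int h_right[total_device] = {1, 4};  // last row owned by each device

  for (int i = 0; i < total_device; ++i) {
    if (h_left[i] == -1 || h_right[i] == -1) continue;  // no keys for device i
    uint64_t* key_slice = shard_keys.data() + h_left[i];
    char* val_slice = shard_vals.data() + h_left[i] * val_type_size;
    int len = h_right[i] - h_left[i] + 1;
    std::printf("device %d: %d keys, first key %llu, val offset %zu bytes\n",
                i, len, (unsigned long long)key_slice[0],
                (size_t)(val_slice - shard_vals.data()));
  }
  return 0;
}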

@@ -912,34 +926,38 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
memory_copy(dst_place, h_right, src_place, d_right_ptr,
total_device * sizeof(int), stream);

for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
if (!direct_access_) {
for (int i = 0; i < total_device; ++i) {
int shard_len = h_right[i] - h_left[i] + 1;
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
if (!multi_mf_dim_) {
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType));
} else {
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * grad_value_size);
}
}

if (!multi_mf_dim_) {
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * sizeof(GradType));
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
} else {
create_storage(dev_num, i, shard_len * sizeof(KeyType),
shard_len * grad_value_size);
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
reinterpret_cast<char*>(d_shard_grads_ptr), grad_value_size);
}
}

if (!multi_mf_dim_) {
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
d_shard_grads_ptr);
} else {
walk_to_dest(dev_num, total_device, h_left, h_right, d_shard_keys_ptr,
reinterpret_cast<char*>(d_shard_grads_ptr), grad_value_size);
}

for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
auto& node = path_[dev_num][i].nodes_.back();
sync_stream(node.in_stream);
if (!direct_access_) {
sync_stream(node.in_stream);
}

AnyDeviceGuard guard(resource_->dev_id(i));
if (!multi_mf_dim_) {
@@ -950,9 +968,17 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
resource_->remote_stream(i, dev_num));
} else {
ptr_tables_[i]->rwlock_->WRLock();
ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, dev_num));
if (!direct_access_) {
ptr_tables_[i]->update(reinterpret_cast<KeyType*>(node.key_storage),
node.val_storage, h_right[i] - h_left[i] + 1,
sgd, resource_->remote_stream(i, dev_num));
} else {
ptr_tables_[i]->update(d_shard_keys_ptr + h_left[i],
reinterpret_cast<char*>(d_shard_grads_ptr) +
grad_value_size * h_left[i],
h_right[i] - h_left[i] + 1, sgd,
resource_->remote_stream(i, dev_num));
}
}
}

@@ -966,12 +992,14 @@ void HeterComm<KeyType, ValType, GradType>::push_sparse(int dev_num,
}
}
}

for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;

if (!direct_access_) {
for (int i = 0; i < total_device; ++i) {
if (h_left[i] == -1 || h_right[i] == -1) {
continue;
}
destroy_storage(dev_num, i);
}
destroy_storage(dev_num, i);
}
}
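Note on the push path: it mirrors the pull side. With direct access on, ptr_tables_[i]->update receives the per-device slice of the shard key and gradient buffers (d_shard_keys_ptr + h_left[i] and d_shard_grads_ptr offset by grad_value_size * h_left[i]) together with the optimizer, and the staged grad storage plus the trailing destroy_storage loop are skipped. A minimal host-only sketch of that call shape; ToyTable and ToySgd are hypothetical stand-ins for ptr_tables_[i] and sgd, not Paddle's real optimizer:

// Toy model of the direct push/update call: the per-device slice of the
// gradient buffer is handed straight to the table together with an optimizer,
// instead of being copied into per-path grad storage first.
#include <cstdint>
#include <cstdio>
#include <unordered_map>
#include <vector>

struct ToySgd {  // stand-in optimizer: plain gradient descent step
  float lr = 0.1f;
  void update(float& w, float g) const { w -= lr * g; }
};

struct ToyTable {
  std::unordered_map<uint64_t, float> w;
  void update(const uint64_t* keys, const float* grads, int n,
              const ToySgd& sgd) {
    for (int i = 0; i < n; ++i) sgd.update(w[keys[i]], grads[i]);
  }
};

int main() {
  ToyTable table;  // pretend this is the table owned by device i
  std::vector<uint64_t> shard_keys = {7, 8, 9};
  std::vector<float> shard_grads = {1.f, 2.f, 3.f};
  int h_left = 1, h_right = 2;  // device i owns rows [1, 2]
  ToySgd sgd;

  // Direct access: slice the caller's buffers, no staging / destroy_storage.
  table.update(shard_keys.data() + h_left, shard_grads.data() + h_left,
               h_right - h_left + 1, sgd);

  for (auto& kv : table.w)
    std::printf("key %llu -> w %.2f\n", (unsigned long long)kv.first, kv.second);
  return 0;
}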
