forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add csr_neighborhood op * update neighborhood sample * Update csr_neighborhood_sample-inl.h * Update csr_neighborhood_sample-inl.h * Update csr_neighborhood_sample.cc
- Loading branch information
Showing
2 changed files
with
333 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file index_copy-inl.h | ||
* \brief implementation of neighbor_sample tensor operation | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_CONTRIB_CSR_NEIGHBORHOOD_SAMPLE_INL_H_ | ||
#define MXNET_OPERATOR_CONTRIB_CSR_NEIGHBORHOOD_SAMPLE_INL_H_ | ||
|
||
#include <mxnet/io.h> | ||
#include <mxnet/base.h> | ||
#include <mxnet/ndarray.h> | ||
#include <mxnet/operator.h> | ||
#include <mxnet/operator_util.h> | ||
#include <dmlc/logging.h> | ||
#include <dmlc/optional.h> | ||
#include "../operator_common.h" | ||
|
||
#include <vector> | ||
#include <cstdlib> | ||
#include <ctime> | ||
#include <unordered_map> | ||
#include <algorithm> | ||
#include <queue> | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
typedef int64_t dgl_id_t; | ||
|
||
//------------------------------------------------------------------------------ | ||
// input[0]: Graph | ||
// input[1]: seed_vertices | ||
// args[0]: num_hops | ||
// args[1]: num_neighbor | ||
// args[2]: max_num_vertices | ||
//------------------------------------------------------------------------------ | ||
|
||
// For BFS traversal | ||
struct ver_node { | ||
dgl_id_t vertex_id; | ||
int level; | ||
}; | ||
|
||
// How to set the default value? | ||
struct NeighborSampleParam : public dmlc::Parameter<NeighborSampleParam> { | ||
dgl_id_t num_hops, num_neighbor, max_num_vertices; | ||
DMLC_DECLARE_PARAMETER(NeighborSampleParam) { | ||
DMLC_DECLARE_FIELD(num_hops) | ||
.set_default(1) | ||
.describe("Number of hops."); | ||
DMLC_DECLARE_FIELD(num_neighbor) | ||
.set_default(2) | ||
.describe("Number of neighbor."); | ||
DMLC_DECLARE_FIELD(max_num_vertices) | ||
.set_default(100) | ||
.describe("Max number of vertices."); | ||
} | ||
}; | ||
|
||
static bool CSRNeighborSampleStorageType(const nnvm::NodeAttrs& attrs, | ||
const int dev_mask, | ||
DispatchMode* dispatch_mode, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2); | ||
CHECK_EQ(out_attrs->size(), 1); | ||
|
||
CHECK_EQ(in_attrs->at(0), mxnet::kCSRStorage); | ||
|
||
CHECK_EQ(in_attrs->at(1), mxnet::kDefaultStorage); | ||
|
||
bool success = true; | ||
if (!type_assign(&(*out_attrs)[0], mxnet::kDefaultStorage)) { | ||
success = false; | ||
} | ||
|
||
*dispatch_mode = DispatchMode::kFComputeEx; | ||
|
||
return success; | ||
} | ||
|
||
static bool CSRNeighborSampleShape(const nnvm::NodeAttrs& attrs, | ||
std::vector<TShape> *in_attrs, | ||
std::vector<TShape> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2); | ||
CHECK_EQ(out_attrs->size(), 1); | ||
|
||
CHECK_EQ(in_attrs->at(0).ndim(), 2U); | ||
CHECK_EQ(in_attrs->at(1).ndim(), 1U); | ||
// Check the graph shape | ||
CHECK_EQ(in_attrs->at(0)[0], in_attrs->at(0)[1]); | ||
|
||
const NeighborSampleParam& params = | ||
nnvm::get<NeighborSampleParam>(attrs.parsed); | ||
|
||
TShape out_shape(1); | ||
out_shape[0] = params.max_num_vertices; | ||
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); | ||
|
||
return out_attrs->at(0).ndim() != 0U && | ||
out_attrs->at(0).Size() != 0U; | ||
} | ||
|
||
static bool CSRNeighborSampleType(const nnvm::NodeAttrs& attrs, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2); | ||
CHECK_EQ(out_attrs->size(), 1); | ||
out_attrs->at(0) = in_attrs->at(0); | ||
|
||
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); | ||
return out_attrs->at(0) != -1; | ||
} | ||
|
||
static void GetSrcList(const dgl_id_t* col_list, | ||
const dgl_id_t* indptr, | ||
const dgl_id_t dst_id, | ||
std::vector<dgl_id_t>& src_list) { | ||
for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) { | ||
src_list.push_back(col_list[i]); | ||
} | ||
} | ||
|
||
static void GetSample(std::vector<dgl_id_t>& ver_list, | ||
const size_t max_num_neighbor, | ||
std::vector<dgl_id_t>& out) { | ||
// Copy ver_list to output | ||
if (ver_list.size() <= max_num_neighbor) { | ||
for (size_t i = 0; i < ver_list.size(); ++i) { | ||
out.push_back(ver_list[i]); | ||
} | ||
return; | ||
} | ||
// Make sample | ||
std::unordered_map<size_t, bool> mp; | ||
size_t sample_count = 0; | ||
for (;;) { | ||
// rand_num = [0, ver_list.size()-1] | ||
size_t rand_num = rand() % ver_list.size(); | ||
auto got = mp.find(rand_num); | ||
if (got != mp.end() && mp[rand_num]) { | ||
// re-sample | ||
continue; | ||
} | ||
mp[rand_num] = true; | ||
out.push_back(ver_list[rand_num]); | ||
sample_count++; | ||
if (sample_count == max_num_neighbor) { | ||
break; | ||
} | ||
} | ||
} | ||
|
||
static void CSRNeighborSampleComputeExCPU(const nnvm::NodeAttrs& attrs, | ||
const OpContext& ctx, | ||
const std::vector<NDArray>& inputs, | ||
const std::vector<OpReqType>& req, | ||
const std::vector<NDArray>& outputs) { | ||
CHECK_EQ(inputs.size(), 2U); | ||
CHECK_EQ(outputs.size(), 1U); | ||
|
||
const NeighborSampleParam& params = | ||
nnvm::get<NeighborSampleParam>(attrs.parsed); | ||
|
||
// set seed for random sampling | ||
srand(time(nullptr)); | ||
|
||
dgl_id_t num_hops = params.num_hops; | ||
dgl_id_t num_neighbor = params.num_neighbor; | ||
dgl_id_t max_num_vertices = params.max_num_vertices; | ||
|
||
size_t seed_num = inputs[1].data().Size(); | ||
|
||
CHECK_GE(max_num_vertices, seed_num); | ||
|
||
const dgl_id_t* col_list = inputs[0].aux_data(1).dptr<dgl_id_t>(); | ||
const dgl_id_t* indptr = inputs[0].aux_data(0).dptr<dgl_id_t>(); | ||
const dgl_id_t* seed = inputs[1].data().dptr<dgl_id_t>(); | ||
|
||
dgl_id_t* out = outputs[0].data().dptr<dgl_id_t>(); | ||
|
||
// BFS traverse the graph and sample vertices | ||
dgl_id_t sub_vertices_count = 0; | ||
std::unordered_map<dgl_id_t, bool> sub_ver_mp; | ||
std::queue<ver_node> node_queue; | ||
// add seed vertices | ||
for (size_t i = 0; i < seed_num; ++i) { | ||
ver_node node; | ||
node.vertex_id = seed[i]; | ||
node.level = 0; | ||
node_queue.push(node); | ||
sub_ver_mp[node.vertex_id] = true; | ||
sub_vertices_count++; | ||
} | ||
|
||
std::vector<dgl_id_t> tmp_src_list; | ||
std::vector<dgl_id_t> tmp_sampled_list; | ||
|
||
while (!node_queue.empty() && sub_vertices_count < max_num_vertices) { | ||
ver_node& cur_node = node_queue.front(); | ||
if (cur_node.level < num_hops) { | ||
dgl_id_t dst_id = cur_node.vertex_id; | ||
tmp_src_list.clear(); | ||
tmp_sampled_list.clear(); | ||
GetSrcList(col_list, indptr, dst_id, tmp_src_list); | ||
GetSample(tmp_src_list, num_neighbor, tmp_sampled_list); | ||
for (size_t i = 0; i < tmp_sampled_list.size(); ++i) { | ||
auto got = sub_ver_mp.find(tmp_sampled_list[i]); | ||
if (got == sub_ver_mp.end()) { | ||
sub_vertices_count++; | ||
sub_ver_mp[tmp_sampled_list[i]] = true; | ||
ver_node new_node; | ||
new_node.vertex_id = tmp_sampled_list[i]; | ||
new_node.level = cur_node.level + 1; | ||
node_queue.push(new_node); | ||
} | ||
if (sub_vertices_count >= max_num_vertices) { | ||
break; | ||
} | ||
} | ||
} | ||
node_queue.pop(); | ||
} | ||
|
||
// Copy sub_ver_mp to output | ||
dgl_id_t idx = 0; | ||
for (auto& data: sub_ver_mp) { | ||
if (data.second) { | ||
*(out+idx) = data.first; | ||
idx++; | ||
} | ||
} | ||
// The rest data will be set to -1 | ||
for (dgl_id_t i = idx; i < max_num_vertices; ++i) { | ||
*(out+i) = -1; | ||
} | ||
} | ||
|
||
} // op | ||
} // mxnet | ||
|
||
#endif // MXNET_OPERATOR_CONTRIB_CSR_NEIGHBORHOOD_SAMPLE_INL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file csr_neighborhood_sample.cc | ||
* \brief | ||
*/ | ||
#include "./csr_neighborhood_sample-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
/* | ||
Usage: | ||
import mxnet as mx | ||
import numpy as np | ||
shape = (5, 5) | ||
data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64) | ||
indices_np = np.array([1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3], dtype=np.int64) | ||
indptr_np = np.array([0, 4,8,12,16,20], dtype=np.int64) | ||
a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape) | ||
a.asnumpy() | ||
seed = mx.nd.array([0], dtype=np.int64) | ||
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=1, max_num_vertices=5) | ||
out.asnumpy() | ||
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5) | ||
out.asnumpy() | ||
seed = mx.nd.array([0,4], dtype=np.int64) | ||
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5) | ||
out.asnumpy() | ||
*/ | ||
|
||
DMLC_REGISTER_PARAMETER(NeighborSampleParam); | ||
|
||
NNVM_REGISTER_OP(_contrib_neighbor_sample) | ||
.MXNET_DESCRIBE("") | ||
.set_attr_parser(ParamParser<NeighborSampleParam>) | ||
.set_num_inputs(2) | ||
.set_num_outputs(1) | ||
.set_attr<FInferStorageType>("FInferStorageType", CSRNeighborSampleStorageType) | ||
.set_attr<nnvm::FInferShape>("FInferShape", CSRNeighborSampleShape) | ||
.set_attr<nnvm::FInferType>("FInferType", CSRNeighborSampleType) | ||
.set_attr<FComputeEx>("FComputeEx<cpu>", CSRNeighborSampleComputeExCPU) | ||
.add_argument("csr_matrix", "NDArray-or-Symbol", "csr matrix") | ||
.add_argument("seed_array", "NDArray-or-Symbol", "seed vertices") | ||
.add_arguments(NeighborSampleParam::__FIELDS__()); | ||
|
||
} // op | ||
} // mxnet |