Skip to content

Commit

Permalink
Add neighborhood op (#37)
Browse files Browse the repository at this point in the history
* add csr_neighborhood op

* update neighborhood sample

* Update csr_neighborhood_sample-inl.h

* Update csr_neighborhood_sample-inl.h

* Update csr_neighborhood_sample.cc
  • Loading branch information
aksnzhy authored and zheng-da committed Nov 24, 2018
1 parent ba84765 commit 81e6281
Show file tree
Hide file tree
Showing 2 changed files with 333 additions and 0 deletions.
261 changes: 261 additions & 0 deletions src/operator/contrib/csr_neighborhood_sample-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file index_copy-inl.h
* \brief implementation of neighbor_sample tensor operation
*/

#ifndef MXNET_OPERATOR_CONTRIB_CSR_NEIGHBORHOOD_SAMPLE_INL_H_
#define MXNET_OPERATOR_CONTRIB_CSR_NEIGHBORHOOD_SAMPLE_INL_H_

#include <mxnet/io.h>
#include <mxnet/base.h>
#include <mxnet/ndarray.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <dmlc/logging.h>
#include <dmlc/optional.h>
#include "../operator_common.h"

#include <vector>
#include <cstdlib>
#include <ctime>
#include <unordered_map>
#include <algorithm>
#include <queue>

namespace mxnet {
namespace op {

typedef int64_t dgl_id_t;

//------------------------------------------------------------------------------
// input[0]: Graph
// input[1]: seed_vertices
// args[0]: num_hops
// args[1]: num_neighbor
// args[2]: max_num_vertices
//------------------------------------------------------------------------------

// For BFS traversal
struct ver_node {
dgl_id_t vertex_id;
int level;
};

// How to set the default value?
struct NeighborSampleParam : public dmlc::Parameter<NeighborSampleParam> {
dgl_id_t num_hops, num_neighbor, max_num_vertices;
DMLC_DECLARE_PARAMETER(NeighborSampleParam) {
DMLC_DECLARE_FIELD(num_hops)
.set_default(1)
.describe("Number of hops.");
DMLC_DECLARE_FIELD(num_neighbor)
.set_default(2)
.describe("Number of neighbor.");
DMLC_DECLARE_FIELD(max_num_vertices)
.set_default(100)
.describe("Max number of vertices.");
}
};

static bool CSRNeighborSampleStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);

CHECK_EQ(in_attrs->at(0), mxnet::kCSRStorage);

CHECK_EQ(in_attrs->at(1), mxnet::kDefaultStorage);

bool success = true;
if (!type_assign(&(*out_attrs)[0], mxnet::kDefaultStorage)) {
success = false;
}

*dispatch_mode = DispatchMode::kFComputeEx;

return success;
}

static bool CSRNeighborSampleShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);

CHECK_EQ(in_attrs->at(0).ndim(), 2U);
CHECK_EQ(in_attrs->at(1).ndim(), 1U);
// Check the graph shape
CHECK_EQ(in_attrs->at(0)[0], in_attrs->at(0)[1]);

const NeighborSampleParam& params =
nnvm::get<NeighborSampleParam>(attrs.parsed);

TShape out_shape(1);
out_shape[0] = params.max_num_vertices;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);

return out_attrs->at(0).ndim() != 0U &&
out_attrs->at(0).Size() != 0U;
}

static bool CSRNeighborSampleType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
out_attrs->at(0) = in_attrs->at(0);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
return out_attrs->at(0) != -1;
}

static void GetSrcList(const dgl_id_t* col_list,
const dgl_id_t* indptr,
const dgl_id_t dst_id,
std::vector<dgl_id_t>& src_list) {
for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
src_list.push_back(col_list[i]);
}
}

static void GetSample(std::vector<dgl_id_t>& ver_list,
const size_t max_num_neighbor,
std::vector<dgl_id_t>& out) {
// Copy ver_list to output
if (ver_list.size() <= max_num_neighbor) {
for (size_t i = 0; i < ver_list.size(); ++i) {
out.push_back(ver_list[i]);
}
return;
}
// Make sample
std::unordered_map<size_t, bool> mp;
size_t sample_count = 0;
for (;;) {
// rand_num = [0, ver_list.size()-1]
size_t rand_num = rand() % ver_list.size();
auto got = mp.find(rand_num);
if (got != mp.end() && mp[rand_num]) {
// re-sample
continue;
}
mp[rand_num] = true;
out.push_back(ver_list[rand_num]);
sample_count++;
if (sample_count == max_num_neighbor) {
break;
}
}
}

static void CSRNeighborSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);

const NeighborSampleParam& params =
nnvm::get<NeighborSampleParam>(attrs.parsed);

// set seed for random sampling
srand(time(nullptr));

dgl_id_t num_hops = params.num_hops;
dgl_id_t num_neighbor = params.num_neighbor;
dgl_id_t max_num_vertices = params.max_num_vertices;

size_t seed_num = inputs[1].data().Size();

CHECK_GE(max_num_vertices, seed_num);

const dgl_id_t* col_list = inputs[0].aux_data(1).dptr<dgl_id_t>();
const dgl_id_t* indptr = inputs[0].aux_data(0).dptr<dgl_id_t>();
const dgl_id_t* seed = inputs[1].data().dptr<dgl_id_t>();

dgl_id_t* out = outputs[0].data().dptr<dgl_id_t>();

// BFS traverse the graph and sample vertices
dgl_id_t sub_vertices_count = 0;
std::unordered_map<dgl_id_t, bool> sub_ver_mp;
std::queue<ver_node> node_queue;
// add seed vertices
for (size_t i = 0; i < seed_num; ++i) {
ver_node node;
node.vertex_id = seed[i];
node.level = 0;
node_queue.push(node);
sub_ver_mp[node.vertex_id] = true;
sub_vertices_count++;
}

std::vector<dgl_id_t> tmp_src_list;
std::vector<dgl_id_t> tmp_sampled_list;

while (!node_queue.empty() && sub_vertices_count < max_num_vertices) {
ver_node& cur_node = node_queue.front();
if (cur_node.level < num_hops) {
dgl_id_t dst_id = cur_node.vertex_id;
tmp_src_list.clear();
tmp_sampled_list.clear();
GetSrcList(col_list, indptr, dst_id, tmp_src_list);
GetSample(tmp_src_list, num_neighbor, tmp_sampled_list);
for (size_t i = 0; i < tmp_sampled_list.size(); ++i) {
auto got = sub_ver_mp.find(tmp_sampled_list[i]);
if (got == sub_ver_mp.end()) {
sub_vertices_count++;
sub_ver_mp[tmp_sampled_list[i]] = true;
ver_node new_node;
new_node.vertex_id = tmp_sampled_list[i];
new_node.level = cur_node.level + 1;
node_queue.push(new_node);
}
if (sub_vertices_count >= max_num_vertices) {
break;
}
}
}
node_queue.pop();
}

// Copy sub_ver_mp to output
dgl_id_t idx = 0;
for (auto& data: sub_ver_mp) {
if (data.second) {
*(out+idx) = data.first;
idx++;
}
}
// The rest data will be set to -1
for (dgl_id_t i = idx; i < max_num_vertices; ++i) {
*(out+i) = -1;
}
}

} // op
} // mxnet

#endif // MXNET_OPERATOR_CONTRIB_CSR_NEIGHBORHOOD_SAMPLE_INL_H_
72 changes: 72 additions & 0 deletions src/operator/contrib/csr_neighborhood_sample.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file csr_neighborhood_sample.cc
* \brief
*/
#include "./csr_neighborhood_sample-inl.h"

namespace mxnet {
namespace op {

/*
Usage:
import mxnet as mx
import numpy as np
shape = (5, 5)
data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64)
indices_np = np.array([1,2,3,4,0,2,3,4,0,1,3,4,0,1,2,4,0,1,2,3], dtype=np.int64)
indptr_np = np.array([0, 4,8,12,16,20], dtype=np.int64)
a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape)
a.asnumpy()
seed = mx.nd.array([0], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=1, max_num_vertices=5)
out.asnumpy()
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5)
out.asnumpy()
seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5)
out.asnumpy()
*/

DMLC_REGISTER_PARAMETER(NeighborSampleParam);

NNVM_REGISTER_OP(_contrib_neighbor_sample)
.MXNET_DESCRIBE("")
.set_attr_parser(ParamParser<NeighborSampleParam>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<FInferStorageType>("FInferStorageType", CSRNeighborSampleStorageType)
.set_attr<nnvm::FInferShape>("FInferShape", CSRNeighborSampleShape)
.set_attr<nnvm::FInferType>("FInferType", CSRNeighborSampleType)
.set_attr<FComputeEx>("FComputeEx<cpu>", CSRNeighborSampleComputeExCPU)
.add_argument("csr_matrix", "NDArray-or-Symbol", "csr matrix")
.add_argument("seed_array", "NDArray-or-Symbol", "seed vertices")
.add_arguments(NeighborSampleParam::__FIELDS__());

} // op
} // mxnet

0 comments on commit 81e6281

Please sign in to comment.