[Feature][Sampler] Sort CSR by tag #1664

Merged Jun 1, 2021 (67 commits)

soodoshll
Contributor

Description

To implement an efficient tag-based biased sampler, we first need to sort the CSR matrix by tag.
Currently, this supports only homogeneous graphs and bipartite graphs.

Checklist

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [$CATEGORY] (such as [NN], [Model], [Doc], [Feature])
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented
  • To the best of my knowledge, examples are either not affected by this change,
    or have been fixed to be compatible with this change
  • Related issue is referred in this PR

Changes

include/dgl/aten/csr.h (outdated)
/*!
* \brief Sort the column index according to a node feature called tag
*
* \return the split positions of different tags
Member

Also document the arguments. Please include the shape of each argument and of the return value.

include/dgl/aten/csr.h (outdated)
include/dgl/base_heterograph.h (outdated)
python/dgl/transform.py (outdated)
src/array/cpu/spmat_op_impl.cc (outdated, 5 threads)
@BarclayII
Collaborator

BarclayII commented Jun 30, 2020

A high-level question: should sort_csr and sort_csc, as well as their in-place variants, be exposed to the users? I feel it is dedicated to BiasedSampler.

@jermainewang
Member

A high-level question: should sort_csr and sort_csc, as well as their in-place variants, be exposed to the users? I feel it is dedicated to BiasedSampler.

M2C:

The first question is whether we can hide this inside BiasedSampler. The answer is yes. We can sort the graph in the init function and pass the returned segment position to the sampling API.

Then why should we still expose this API to end users? There are two reasons. First, the BiasedSampler will be implemented in Python, so having an API docstring for this sorting routine makes it easier for users to understand what happens internally. Second, the functionality is general enough that it might help in other cases besides BiasedSampler. These sorting APIs are intended for users who understand (1) sparse storage formats such as CSR/CSC and (2) why sorting the adjacency list helps performance, and they let such users control the internal storage of a graph.
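To make the "sort once, then pass the segment positions to the sampling API" idea concrete, here is a minimal pure-Python sketch of how a sampler could use per-tag offsets. It is not DGL's BiasedSampler: the function name `sample_neighbors_biased`, its arguments, and the choice to sample with replacement are all assumptions made for illustration.

```python
import random


def sample_neighbors_biased(sorted_indices, row_start, offsets, bias, k, rng=random):
    """Sample k neighbors (with replacement) of one node using per-tag bias.

    Hypothetical helper, not a DGL API.
    sorted_indices: CSR column indices, already sorted by tag within each row.
    row_start: position in sorted_indices where this node's row begins.
    offsets: this node's tag offsets, relative to row_start (len = num_tags + 1).
    bias: one non-negative weight per tag.
    """
    # Because same-tag neighbors are contiguous, each segment size is O(1) to get.
    sizes = [offsets[t + 1] - offsets[t] for t in range(len(offsets) - 1)]
    # Probability of a tag is proportional to bias * number of neighbors with it.
    weights = [b * s for b, s in zip(bias, sizes)]
    if sum(weights) == 0:
        return []
    picks = []
    for _ in range(k):
        # Pick a tag segment first, then a neighbor uniformly inside it.
        t = rng.choices(range(len(sizes)), weights=weights)[0]
        pos = row_start + offsets[t] + rng.randrange(sizes[t])
        picks.append(sorted_indices[pos])
    return picks
```

Without the sort, picking a neighbor with a given tag would require scanning the whole adjacency list; with it, each draw only needs the offsets of the chosen segment.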

include/dgl/aten/csr.h (outdated)
After sorting, edges whose destination shares the same
tag will be arranged in a consecutive range. Note that this
will not change the edge ID. It only changes the order in the
internal CSR storage. As such, the graph must allow CSR storage.
Member

Suggested change

def sort_out_edges(g, tag, tag_offset_name='_TAG_OFFSET'):

"""Return a new graph which sorts the out edges of each node.

Sort the out edges according to the given integer tags of the destination nodes.
A typical use case is to sort the edges by the destination node types, where
the tags represent destination node types. After sorting, edges sharing
the same tag will be arranged in a consecutive range in
a node's adjacency list. The following is an example:

    Consider a graph as follows:

    0 -> 0, 1, 2, 3, 4
    1 -> 0, 1, 2

    Given node tags [1, 1, 0, 2, 0], each node's adjacency list
    will be sorted as follows:

    0 -> 2, 4, 0, 1, 3
    1 -> 2, 0, 1

The function also returns the starting offsets of the tag
segments in a tensor of shape `(N, max_tag+2)`. For node `i`,
its out-edges connecting to nodes with tag `j` are stored between
`tag_offsets[i][j]` and `tag_offsets[i][j+1]`. Since the offsets
can be viewed as node data, we store them in the
`ndata` of the returned graph. Users can specify the
ndata name with the `tag_offset_name` argument.

Note that the function changes neither the edge IDs nor
how the edge features are stored. The input graph must
allow CSR format and must be on CPU.

If the input graph is heterogeneous, it must have only one edge
type and two node types (i.e., source and destination node types).
In this case, the provided node tags are for the destination nodes,
and the tag offsets are stored in the source node data.

The sorted graph and the calculated tag offsets are needed by
certain operators that consider node tags. See xxx for an example.

Examples
--------
...

Parameters
----------
g : DGLGraph
    The input graph.
tag : Tensor
    Integer tensor of shape `(N,)`, `N` being the number of (destination) nodes.
tag_offset_name : str
    The name of the node feature to store tag offsets.

Returns
-------
g_sorted : DGLGraph
    A new graph whose out edges are sorted. The node/edge features of the
    input graph are shallow-copied over.
    - `g_sorted.ndata[tag_offset_name]` : Tensor of shape `(N, max_tag + 2)`. If
       `g` is heterogeneous, get it from `g_sorted.srcdata`.
"""
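
The behaviour described in the docstring above can be sketched in pure Python on plain-list CSR arrays. This is only an illustration under stated assumptions, not the C++ kernel from this PR: the name `sort_csr_by_tag`, the list-based CSR layout, and the choice to make offsets relative to each row's start are all hypothetical.

```python
def sort_csr_by_tag(indptr, indices, edge_ids, tag):
    """Sort each row's column indices by tag and compute per-row tag offsets.

    Hypothetical helper, not a DGL API.
    indptr, indices, edge_ids: CSR arrays stored as plain Python lists.
    tag: tag[v] is the non-negative integer tag of destination node v.
    Returns (new_indices, new_edge_ids, tag_offsets), where tag_offsets has
    one list of length max_tag + 2 per row, relative to that row's start.
    """
    num_tags = max(tag) + 1
    new_indices, new_edge_ids, tag_offsets = [], [], []
    for row in range(len(indptr) - 1):
        start, end = indptr[row], indptr[row + 1]
        # sorted() is stable, so entries with equal tags keep their order.
        order = sorted(range(start, end), key=lambda p: tag[indices[p]])
        new_indices.extend(indices[p] for p in order)
        # Edge IDs move together with their entries; the IDs do not change.
        new_edge_ids.extend(edge_ids[p] for p in order)
        # Count neighbors per tag, then prefix-sum the counts into offsets.
        counts = [0] * num_tags
        for p in range(start, end):
            counts[tag[indices[p]]] += 1
        offsets = [0]
        for c in counts:
            offsets.append(offsets[-1] + c)
        tag_offsets.append(offsets)
    return new_indices, new_edge_ids, tag_offsets


# The example from the docstring: 5 nodes, only nodes 0 and 1 have out edges.
indptr = [0, 5, 8, 8, 8, 8]
indices = [0, 1, 2, 3, 4, 0, 1, 2]
edge_ids = list(range(8))
tag = [1, 1, 0, 2, 0]
new_indices, new_edge_ids, tag_offsets = sort_csr_by_tag(indptr, indices, edge_ids, tag)
print(new_indices)      # [2, 4, 0, 1, 3, 2, 0, 1]
print(tag_offsets[:2])  # [[0, 2, 4, 5], [0, 1, 3, 3]]
```

Here `max_tag` is 2, so each row has 4 offsets, matching the `(N, max_tag + 2)` shape in the docstring.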

Member

No need to support None tag at the moment. We could add it later when there is a request.

HeteroGraphRef hg = args[0];
NDArray tag = args[1];
int64_t num_tag = args[2];
const auto csr = hg->GetCSRMatrix(0);
Member

Add a device check. Raise an error if the graph is on GPU.

@jermainewang jermainewang merged commit b8fe2b4 into dmlc:master Jun 1, 2021
@soodoshll soodoshll mentioned this pull request Jun 6, 2021
7 tasks
3 participants