[Feature][Sampler] Sort CSR by tag #1664
include/dgl/aten/csr.h
Outdated
/*!
 * \brief Sort the column index according to a node feature called tag
 *
 * \return the split positions of different tags
Also document the arguments. Please include the shape of each argument and of the return value.
A high-level question: should sort_csr and sort_csc, as well as their in-place variants, be exposed to users? I feel they are dedicated to BiasedSampler.
M2C: The first question is whether we can hide this inside BiasedSampler. The answer is yes: we can sort the graph in the init function and pass the returned segment positions to the sampling API. Then why should we still expose this API to end users? There are two reasons. First, BiasedSampler will be implemented in Python, so having an API docstring for this sorting routine makes it easier for users to understand what happens internally. Second, the functionality is general enough that it might help in cases other than BiasedSampler. These sorting APIs are intended for users who understand (1) sparse storage formats such as CSR/CSC and (2) why sorting the adjacency list helps performance; they let such users control the internal storage of a graph.
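To make the "sort once, then pass the segment positions to the sampling API" idea concrete, here is a minimal pure-Python sketch of how a sampler could use per-row tag offsets. It is not the BiasedSampler from this PR: the function name `sample_neighbor_by_tag`, the `tag_bias` argument, and the assumption that offsets are relative to the start of each row are all hypothetical and chosen only for illustration.

```python
import random

def sample_neighbor_by_tag(row_start, sorted_indices, row_offsets, tag_bias, rng=random):
    """Pick one neighbor so that each edge with tag j has unnormalized weight tag_bias[j].

    Hypothetical helper, not part of DGL. `row_offsets` is assumed to hold
    offsets relative to the start of the row, one segment per tag.
    """
    # Number of edges in each tag segment of this row.
    sizes = [row_offsets[j + 1] - row_offsets[j] for j in range(len(tag_bias))]
    # Weight of a segment = per-edge bias * number of edges in the segment.
    weights = [bias * size for bias, size in zip(tag_bias, sizes)]
    if sum(weights) == 0:
        return None  # no eligible neighbor
    # Step 1: choose a tag segment; step 2: choose uniformly inside it.
    j = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    pos = rng.randrange(row_offsets[j], row_offsets[j + 1])
    return sorted_indices[row_start + pos]

# Row 0 of the example used in the suggested docstring, already sorted by tag:
# neighbors [2, 4 | 0, 1 | 3] with tags [0, 0 | 1, 1 | 2].
sorted_indices = [2, 4, 0, 1, 3]
row_offsets = [0, 2, 4, 5]

only_tag_2 = sample_neighbor_by_tag(0, sorted_indices, row_offsets, [0.0, 0.0, 1.0])
only_tag_0 = sample_neighbor_by_tag(0, sorted_indices, row_offsets, [1.0, 0.0, 0.0])
nothing = sample_neighbor_by_tag(0, sorted_indices, row_offsets, [0.0, 0.0, 0.0])
```

Because the segments are contiguous, the sampler only needs the segment sizes rather than a per-edge weight array, which is the performance argument for sorting the adjacency list first.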
python/dgl/transform.py
Outdated
After sorting, edges whose destination shares the same
tag will be arranged in a consecutive range. Note that this
will not change the edge ID. It only changes the order in the
internal CSR storage. As such, the graph must allow CSR storage.
Suggested change
def sort_out_edges(g, tag, tag_offset_name='_TAG_OFFSET'):
"""Return a new graph which sorts the out edges of each node.
Sort the out edges according to the given integer destination node tags.
A typical use case is to sort the edges by the destination node types, where
the tags represent destination node types. After sorting, edges sharing
the same tag will be arranged in a consecutive range in
a node's adjacency list. Following is an example:
Consider a graph as follows:
0 -> 0, 1, 2, 3, 4
1 -> 0, 1, 2
Given node tags [1, 1, 0, 2, 0], each node's adjacency list
will be sorted as follows:
0 -> 2, 4, 0, 1, 3
1 -> 2, 0, 1
The function will also return the starting offsets of the tag
segments in a tensor of shape `(N, max_tag+2)`. For node `i`,
its out-edges connecting to nodes with tag `j` are stored between
`tag_offsets[i][j]` and `tag_offsets[i][j+1]`. Since the offsets
can be viewed as node data, we store them in the
`ndata` of the returned graph. Users can specify the
ndata name with the `tag_offset_name` argument.
Note that the function changes neither the edge IDs nor
how the edge features are stored. The input graph must
allow CSR format. The graph must be on CPU.
If the input graph is heterogeneous, it must have only one edge
type and two node types (i.e., source and destination node types).
In this case, the provided node tags are for the destination nodes,
and the tag offsets are stored in the source node data.
The sorted graph and the calculated tag offsets are needed by
certain operators that consider node tags. See xxx for an example.
Examples
-----------
...
Parameters
------------
g : DGLGraph
The input graph.
tag : Tensor
Integer tensor of shape `(N,)`, `N` being the number of (destination) nodes.
tag_offset_name : str
The name of the node feature to store tag offsets.
Returns
-------
g_sorted : DGLGraph
A new graph whose out edges are sorted. The node/edge features of the
input graph are shallow-copied over.
- `g_sorted.ndata[tag_offset_name]` : Tensor of shape `(N, max_tag + 2)`. If
`g` is heterogeneous, get from `g_sorted.srcdata`.
"""
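The worked example in the docstring can be checked with a small pure-Python reference. This is only a sketch of the described behavior, not the C++ kernel in this PR: the name `sort_out_edges_reference` is hypothetical, the graph is passed as raw CSR lists (`indptr`, `indices`), and the offsets are assumed to be relative to the start of each row.

```python
def sort_out_edges_reference(indptr, indices, tag):
    """Sort each row's column indices by destination tag (illustrative only).

    Returns the reordered column indices and, per row, a list of
    max_tag + 2 offsets marking where each tag segment starts.
    """
    num_rows = len(indptr) - 1
    num_tags = max(tag) + 1
    sorted_indices = []
    tag_offsets = []
    for row in range(num_rows):
        neighbors = indices[indptr[row]:indptr[row + 1]]
        # sorted() is stable, so neighbors with equal tags keep their order.
        sorted_indices.extend(sorted(neighbors, key=lambda v: tag[v]))
        # Count the neighbors per tag, then prefix-sum into segment offsets.
        counts = [0] * num_tags
        for v in neighbors:
            counts[tag[v]] += 1
        offsets = [0]
        for count in counts:
            offsets.append(offsets[-1] + count)
        tag_offsets.append(offsets)
    return sorted_indices, tag_offsets

# The graph from the docstring: 0 -> 0, 1, 2, 3, 4 and 1 -> 0, 1, 2.
# Nodes 2, 3 and 4 have no out-edges.
indptr = [0, 5, 8, 8, 8, 8]
indices = [0, 1, 2, 3, 4, 0, 1, 2]
tag = [1, 1, 0, 2, 0]

sorted_indices, tag_offsets = sort_out_edges_reference(indptr, indices, tag)
print(sorted_indices)   # [2, 4, 0, 1, 3, 2, 0, 1]
print(tag_offsets[0])   # [0, 2, 4, 5]
print(tag_offsets[1])   # [0, 1, 3, 3]
```

With `max_tag = 2`, each row has `max_tag + 2 = 4` offsets, matching the `(N, max_tag + 2)` shape stated in the docstring. Only the column order changes; `indptr` is left untouched.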
No need to support None tag at the moment. We could add it later when there is a request.
HeteroGraphRef hg = args[0];
NDArray tag = args[1];
int64_t num_tag = args[2];
const auto csr = hg->GetCSRMatrix(0);
Add a check for the device. Raise an error if the graph is on GPU.
Description
To implement an efficient tag-based biased sampler, we first need to sort the CSR matrix by tag.
Currently, this only supports homogeneous graphs and bipartite graphs.