
Memory consumption of the GGNN module #852

Closed
m09 opened this issue Sep 10, 2019 · 5 comments

Comments

m09 commented Sep 10, 2019

🐛 Bug

The memory consumption of the GGNN module prevents its application to medium/big graphs.

To Reproduce

Steps to reproduce the behavior:

  1. Run a GGNN on any graph with >10k edges and it will consume many gigabytes of memory.

Expected behavior

  1. The GGNN module should not require an excessive amount of memory.

Environment

Confirmed to consume excessive memory in the following environment:

  • DGL Version (e.g., 1.0): 0.3.1
  • Backend Library & Version (e.g., PyTorch 0.4.1, MXNet/Gluon 1.3): PyTorch 1.1.0
  • OS (e.g., Linux): Ubuntu 18.04
  • How you installed DGL (conda, pip, source): pip install dgl-cu100
  • Build command you used (if compiling from source):
  • Python version: 3.7
  • CUDA/cuDNN version (if applicable): 10
  • GPU models and configuration (e.g. V100): GTX 1070
  • Any other relevant information:

Additional context

The problem stems from the formulation used in the implementation, where each edge (which carries a type) is embedded: because the linear transformation for each type is modelled this way, the weights are replicated once for every edge of that type. See https://github.com/src-d/formatml/blob/master/formatml/modules/graph_encoders/ggnn.py for an implementation that maintains explicit Linears and doesn't have this problem (it can probably be improved a lot, by the way; the link is just to outline the difference in implementation).
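To make the difference concrete, here is a minimal sketch in plain PyTorch (not DGL's or formatml's actual code; all tensor and variable names are illustrative). The first variant materializes one weight matrix per edge, so its memory grows with the number of edges E; the second keeps one `nn.Linear` per edge type and applies it only to the edges of that type, so the weights are never replicated:

```python
import torch
import torch.nn as nn

num_nodes, num_edge_types, d = 5, 3, 4
src = torch.tensor([0, 1, 2, 3, 4, 0])      # source node of each edge
etype = torch.tensor([0, 1, 2, 0, 1, 2])    # type of each edge
h = torch.randn(num_nodes, d)               # node features

# Memory-hungry variant: replicate a (d x d) weight matrix per edge.
W = torch.randn(num_edge_types, d, d)
per_edge_W = W[etype]                       # (E, d, d) -- grows with E
msgs_a = torch.bmm(h[src].unsqueeze(1), per_edge_W).squeeze(1)

# Memory-friendly variant: one explicit Linear per edge type, applied
# only to the edges of that type (weights stored once per type).
linears = nn.ModuleList(nn.Linear(d, d, bias=False)
                        for _ in range(num_edge_types))
with torch.no_grad():
    for t, lin in enumerate(linears):
        lin.weight.copy_(W[t].t())          # same weights, so outputs match

msgs_b = torch.empty_like(msgs_a)
for t in range(num_edge_types):
    mask = etype == t
    msgs_b[mask] = linears[t](h[src][mask])

assert torch.allclose(msgs_a, msgs_b, atol=1e-5)
```

Both variants compute the same messages; the only difference is whether the per-type weight matrices are gathered into an (E, d, d) tensor or kept as E-independent parameters.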

yzh119 (Member) commented Sep 10, 2019

Thanks for reporting this.
There are two kinds of implementation. The first saves the weight on the edges, at a memory cost of O(E * d_in * d_out). The second first projects the node features with all relation matrices; while this saves memory, the time complexity is high (O(R * d_in * d_out) per projection), so when R is large this implementation is not efficient.
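The second strategy can be sketched as follows (an illustrative snippet, not DGL's internals): pre-project every node feature with all R relation matrices, then gather the right projection per edge. The intermediate tensor is (N, R, d_out), independent of the number of edges E, but every node pays for all R projections even if only a few relation types touch it:

```python
import torch

N, R, E, d_in, d_out = 6, 3, 10, 4, 4
h = torch.randn(N, d_in)                    # node features
W = torch.randn(R, d_in, d_out)             # one matrix per relation
src = torch.randint(0, N, (E,))             # source node of each edge
etype = torch.randint(0, R, (E,))           # relation type of each edge

# Project every node with all R relation matrices up front:
proj = torch.einsum('nd,rde->nre', h, W)    # (N, R, d_out)
msgs = proj[src, etype]                     # pick the right relation per edge

# Equivalent edge-wise computation, for checking:
ref = torch.stack([h[src[i]] @ W[etype[i]] for i in range(E)])
assert torch.allclose(msgs, ref, atol=1e-5)
```

This trades the O(E * d_in * d_out) edge-replicated weights for an O(N * R * d_out) projection tensor, which is exactly the time/space balance described above.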

We are considering balancing the time/space complexity with a set of new kernels, and we will notify you of our further updates.

m09 (Author) commented Sep 10, 2019

Thanks for the fast reply. If both use cases (a huge number of relations with a small enough graph, and any number of relations with a medium/big graph) are common enough, it may be worth splitting the module into two implementations and letting the user decide which one is better adapted.

yzh119 (Member) commented Sep 28, 2019

Yes; since we have no plans for new kernels in the near term, I'll refactor the code to use the implementation you mentioned. For DGL v0.5 we should have a better solution with the new segment ops.
Thanks!

yzh119 (Member) commented Sep 30, 2019

m09 (Author) commented Sep 30, 2019

Looks good to me 💯
