Replacing GeometricFlux.jl to boost performance #186
Conversation
variableFeatures .+ mask, # F'xAxB
globalFeatures .+ mask, # G'xAxB
Smart way to generate matrices with appropriate dimensions! I double-checked the operations, and everything looks correct. What are the pros of using `.+` compared to the `repeat` method?
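A minimal sketch of the difference between the two approaches, with made-up sizes (the names `variableFeatures` and `mask` are reused from the diff; `Fp`, `A`, `B` are illustrative): broadcasting expands the singleton batch dimension lazily, while `repeat` materializes the full-size copy before adding.

```julia
# Hypothetical sizes: Fp features, A nodes, B graphs in the batch.
Fp, A, B = 3, 4, 2
variableFeatures = rand(Fp, A, 1)   # one graph's features
mask = zeros(Fp, A, B)              # batch-shaped zero mask

# Broadcasting: the singleton third dimension is expanded virtually,
# so no FpxAxB copy of variableFeatures is allocated as an input.
out_bc = variableFeatures .+ mask

# repeat: explicitly allocates the FpxAxB copy, then adds.
out_rp = repeat(variableFeatures, 1, 1, B) + mask

out_bc == out_rp   # same result, one fewer large allocation
```

So the results are identical; the broadcast version simply avoids one intermediate array of batch size.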
@@ -21,8 +21,7 @@ function (nn::NNStructure)(x::AbstractVector{<:NonTabularTrajectoryState})
    return hcat(qval...)
end

include("weighted_graph_gat.jl")
include("geometricflux.jl")
#include("weighted_graph_gat.jl")
Is this file still useful, now that we aim to be fully independent of GeometricFlux?
mutable struct FeaturedGraph{T <: AbstractMatrix, N <: AbstractMatrix, E <: AbstractArray, G <: AbstractVector} <: AbstractFeaturedGraph
    graph::T
    nf::N
    ef::E
    gf::G
    directed::Bool
What is the benefit of using parametric types `{T <: AbstractMatrix, N <: AbstractMatrix, E <: AbstractArray, G <: AbstractVector}` for the `FeaturedGraph` struct? Is it from GeometricFlux?
Yes, it's from GeometricFlux.
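For context, a sketch of why the parametric version helps (the `TypedFG`/`UntypedFG` names are made up for illustration, not from the PR): with a parametric struct, each instance records the concrete type of its fields, so field accesses are type-stable; with an abstractly-typed field, the compiler only knows `AbstractMatrix` and must fall back to dynamic dispatch.

```julia
abstract type AbstractFG end

# Parametric: the concrete matrix type is baked into the instance's type,
# so code touching .graph specializes and stays type-stable.
struct TypedFG{T <: AbstractMatrix} <: AbstractFG
    graph::T
end

# Non-parametric: the field type is abstract, so every access of .graph
# is a dynamically-typed value as far as the compiler is concerned.
struct UntypedFG <: AbstractFG
    graph::AbstractMatrix
end

g = rand(3, 3)
typeof(TypedFG(g))   # TypedFG{Matrix{Float64}} -- concrete field type
typeof(UntypedFG(g)) # UntypedFG -- field statically known only as AbstractMatrix
```

This also means a `TypedFG` can wrap a `CuArray` without boxing, which matters for the GPU goals of this PR.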
still reviewing
ngraphs = length(fgs)
maxNodes = Base.maximum(nv, fgs)
nfLength = size(fgs[1].nf, 1)
efLength = size(fgs[1].ef, 1)
gfLength = size(fgs[1].gf, 1)

graph = zeros(T, maxNodes, maxNodes, ngraphs)
nf = zeros(T, nfLength, maxNodes, ngraphs)
ef = zeros(T, efLength, maxNodes, maxNodes, ngraphs)
gf = zeros(T, gfLength, ngraphs)

for (i, fg) in enumerate(fgs)
    graph[1:nv(fg), 1:nv(fg), i] = fg.graph
    nf[:, 1:nv(fg), i] = fg.nf
    ef[:, 1:nv(fg), 1:nv(fg), i] = fg.ef
    gf[:, i] = fg.gf
end

return BatchedFeaturedGraph{T}(graph, nf, ef, gf)
Looks good to me!
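The zero-padding idea in the batching code above can be sketched with plain arrays (the `FeaturedGraph` wrapper and feature tensors are omitted; sizes are made up): graphs of different node counts are copied into the top-left corner of a common `maxNodes x maxNodes x ngraphs` tensor, and the padding stays zero.

```julia
# Two adjacency matrices with 2 and 3 nodes, padded to maxNodes = 3.
graphs = [ones(2, 2), ones(3, 3)]
maxNodes = maximum(size(g, 1) for g in graphs)
batched = zeros(maxNodes, maxNodes, length(graphs))
for (i, g) in enumerate(graphs)
    n = size(g, 1)
    batched[1:n, 1:n, i] = g   # copy each graph into its padded slot
end
size(batched)   # (3, 3, 2); entries outside each graph remain zero
```

Since the padding is zero, the padded rows/columns contribute nothing to the matrix products used later, which is what makes this safe.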
# ========== Accessing ==========
# code from GraphSignals.jl
Is the following code used elsewhere in the project?
mutable struct FeaturedGraph{T <: AbstractMatrix, N <: AbstractMatrix, E <: AbstractArray, G <: AbstractVector} <: AbstractFeaturedGraph
    graph::T
    nf::N
    ef::E
    gf::G
    directed::Bool
Yes, it's from GeometricFlux.
The PR is complete and all tests were passing. But there seems to be an issue with one of our dependencies that appeared only recently. I won't be able to investigate this issue further, but I don't think it's wise to merge a PR with a failing test. I don't think the problem is related to the new code, only to the new module versions that were previously restricted by GeometricFlux. Apart from this, SeaPearl is working just fine and I have been able to run many experiments with this code.
Hi people,
Hello Carlo,
function BatchedFeaturedGraph{T}(graph, nf, ef, gf) where T <: Real
    check_dimensions(graph, nf, ef, gf)
    return new{T}(graph, nf, ef, gf)
end
What is the utility of defining two different constructors instead of just keeping the second one?
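One common reason for the pair, sketched here with a hypothetical `Batched` type (not the PR's actual code): the inner constructor is the only way to build an instance, so it enforces the dimension check unconditionally, while the outer constructor is a convenience that infers the type parameter and forwards to it.

```julia
struct Batched{T <: Real}
    data::Array{T, 3}
    # Inner constructor: the sole entry point, so the shape check
    # can never be bypassed.
    function Batched{T}(data) where T <: Real
        ndims(data) == 3 || throw(DimensionMismatch("expected a 3-d array"))
        return new{T}(data)
    end
end

# Outer convenience constructor: infers T from the argument and
# forwards to the checked inner constructor.
Batched(data::Array{T, 3}) where T <: Real = Batched{T}(data)

b = Batched(zeros(2, 2, 4))   # T inferred as Float64, check runs
```

If only the explicitly-parameterized form existed, every call site would have to spell out `Batched{Float64}(...)` by hand.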
SeaPearl v0.4.0 on its way
The goal of this PR is to free SeaPearl from GeometricFlux.jl by providing our own implementation of GNNs, making it closer to our needs.
The new implementation remains deeply inspired by GraphSignals.jl for the handling of featured graphs, but the computations are now done purely through matrix products (and thus are GPU-friendly). GeometricFlux.jl used an iterative approach, which is more efficient for computing many operations on the same large graph; in our case, however, we work with many small graphs, sometimes at the same time, and a tensor-based solution is far more efficient.
We are still waiting for the first results, but this is the step that should put us within the same order of magnitude as deterministic/random heuristics time-wise.
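The matrix-product formulation mentioned above can be sketched for a single graph (made-up sizes; `A`, `H`, `W` are illustrative names, not SeaPearl identifiers): one dense message-passing step is just two matrix multiplications, with no per-node loop.

```julia
# One dense message-passing step: H*A sums each node's neighbour
# features (A is the adjacency matrix), then W applies a learnable
# linear transform -- a single chain of GPU-friendly matrix products.
A = Float32[0 1 0; 1 0 1; 0 1 0]   # adjacency of a 3-node path graph
H = rand(Float32, 4, 3)            # 4 features per node
W = rand(Float32, 4, 4)            # learnable weight matrix

H_next = W * H * A                 # aggregate + transform, no loops
size(H_next)                       # (4, 3)
```

For the batched 3-d tensors used in this PR, the same idea extends to a batched matrix multiply (e.g. `batched_mul` from NNlib.jl) over the zero-padded graphs.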
Disclaimer
I am no GPU expert and I'm pretty sure I am not using the full potential of CUDA.jl. Currently the computation times are comparable on CPU and GPU, but I am not able to tell whether that is due to the small size of our graphs or to the implementation.