# Graph Network Block
This is a simple implementation of the graph network block proposed in [Relational inductive biases, deep learning, and graph networks. Battaglia *et al*, 2018.](https://arxiv.org/pdf/1806.01261.pdf)

## Definitions
A "graph" is defined as a directed, attributed multi-graph with a global attribute.
  - A node is denoted $\mathbf{v}_i$
  - An edge is denoted $\mathbf{e}_k$
  - The global attribute is $\mathbf{u}$
  - $s_k$ and $r_k$ indicate the indices of the sender and receiver nodes.
  
![multigraph](./multigraph.png)

A *graph* is a 3-tuple $G = (\mathbf{u}, V, E)$, where $\mathbf{u}$ is the global attribute, $V = \{\mathbf{v}_i\}_{i=1:N_v}$ is the set of vertices, and $E = \{(\mathbf{e}_k, r_k, s_k)\}_{k=1:N_e}$ is the set of edges.

## Internal Structures
A GN block contains three update functions, $\phi$,
$$
\begin{align}
\mathbf{e}'_k & = \phi^e(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u})
\\
\mathbf{v}'_i & = \phi^v(\bar{\mathbf{e}}'_i, \mathbf{v}_i, \mathbf{u})
\\
\mathbf{u}' & = \phi^u(\bar{\mathbf{e}}', \bar{\mathbf{v}}', \mathbf{u})
\end{align}
$$
and three aggregation functions, $\rho$,
$$
\begin{align}
\bar{\mathbf{e}}'_i & = \rho^{e \rightarrow v}(E'_i) \\
\bar{\mathbf{e}}' & = \rho^{e \rightarrow u}(E') \\
\bar{\mathbf{v}}' & = \rho^{v \rightarrow u}(V')
\end{align}
$$
where $E'_i = \{(\mathbf{e}'_k, r_k, s_k)\}_{r_k=i,\, k=1:N_e}$, $V' = \{\mathbf{v}'_i\}_{i=1:N_v}$, and $E' = \bigcup_i E'_i = \{(\mathbf{e}'_k, r_k, s_k)\}_{k=1:N_e}$.

$\phi^e$ is mapped across all edges, $\phi^v$ is mapped across all nodes, and $\phi^u$ is applied once as a global update. The $\rho$ functions all take a set as input and reduce it to a single element which represents the aggregate information. The $\rho$ functions must be invariant to permutations of their inputs, and should take a variable number of arguments.

In [1]:
using Flux
using Flux.Tracker

In [2]:
# Parameters of one graph network (GN) block: the attribute sizes and the
# weights/biases of the three affine update functions ϕᵉ, ϕᵛ and ϕᵘ.
# NOTE(review): the fields are untyped (Any); type parameters would let Julia
# specialize on the concrete array types.
struct GN
    edge_units    # length of an edge attribute vector (rows of Wₑ)
    node_units    # length of a node attribute vector (rows of Wᵥ)
    global_units  # length of the global attribute vector (rows of Wᵤ)
    Wₑ            # ϕᵉ weights: edge_units × (edge_units + 2node_units + global_units) from GN(e, n, g)
    bₑ            # ϕᵉ bias, length edge_units
    Wᵥ            # ϕᵛ weights: node_units × (edge_units + node_units + global_units) from GN(e, n, g)
    bᵥ            # ϕᵛ bias, length node_units
    Wᵤ            # ϕᵘ weights: global_units × (edge_units + node_units + global_units) from GN(e, n, g)
    bᵤ            # ϕᵘ bias, length global_units
end
"""
    GN(edge_units, node_units, global_units)

Create the trainable variables of a graph network block whose edge, node and
global attributes have `edge_units`, `node_units` and `global_units` entries.
Weight matrices are drawn from `randn` and biases start at zero; every array is
wrapped with `param` so it is tracked for training.
"""
function GN(edge_units, node_units, global_units)
    # input widths of the three affine updates:
    #   ϕᵉ sees [edge; sender node; receiver node; global]
    #   ϕᵛ and ϕᵘ each see one edge-, one node- and one global-sized piece
    edge_in = edge_units + 2 * node_units + global_units
    node_in = edge_units + node_units + global_units
    global_in = edge_units + node_units + global_units

    # random draws happen in field order: Wₑ, then Wᵥ, then Wᵤ
    Wₑ = param(randn(edge_units, edge_in))
    bₑ = param(zeros(edge_units))
    Wᵥ = param(randn(node_units, node_in))
    bᵥ = param(zeros(node_units))
    Wᵤ = param(randn(global_units, global_in))
    bᵤ = param(zeros(global_units))

    return GN(edge_units, node_units, global_units, Wₑ, bₑ, Wᵥ, bᵥ, Wᵤ, bᵤ)
end
"""
graph network block
E is an array of triples=(i, j, e) indicates the 
    indicates for the sender, receiver nodes for
    an edge, and the edge attributes
V is an array [[values]] where each element i
    of the array representes the attributes for
    node i.
u is an array of attributes.
"""
function (g::GN)(x)
    E, V, u = x
    # compute edge updates
    E′ = [(i, j, ϕᵉ(g, eₖ, V[i], V[j], u)) for (i, j, eₖ) in E]
    # aggregate per-node edge attributes
    ēᵛ = ρᵛ(E′, V)
    # compute node updates
    V′ = vcat([ϕᵛ(g, V[i], ēᵛ[i], u) for i in 1:size(V)[1]])
    # aggregate edge attributes
    ē = ρᵉ(E′)
    # aggregate node attributes
    v̄ = ρᵘ(V′)
    # compute updated global attribute
    u′ = ϕᵘ(g, ē, v̄, u)
    return (E′, V′, u′)
end
Flux.treelike(GN)

mapchildren (generic function with 7 methods)

In [3]:
"""
Update an edge
"""
function ϕᵉ(g::GN, eₖ, v₁, v₂, u) 
    x = vcat(eₖ, vcat(v₁, vcat(v₂, u)))
    eₖ′ = g.Wₑ * x + g.bₑ
    return eₖ′
end

ϕᵉ

In [4]:
"""
Update the nodes.
"""
function ϕᵛ(g::GN, vᵢ, ēᵢ, u)
    x = vcat(vᵢ, vcat(ēᵢ, u))
    vᵢ′ = g.Wᵥ * x + g.bᵥ
    return vᵢ′
end

ϕᵛ

In [5]:
"""
Update the global state attribute.
"""
function ϕᵘ(g::GN, ē, v̄, u)
    x = vcat(ē, vcat(v̄, u))
    u′ = g.Wᵤ * x + g.bᵤ
    return u′
end

ϕᵘ

In [6]:
"""
Aggregate edge attributes for each node.
"""
function ρᵛ(E′, V)
    # outer loop: each target node j
    # inner loop: sum incoming edges (i, j) for j
    ēᵛ = [reduce(+, zeros(E′[1][3]), eₖ for (i, j_, eₖ) in E′ if j_ == j) for j in 1:size(V, 1)]
    return ēᵛ
end
# test ρᵛ: node 1 has no incoming edge, so its aggregate is all zeros;
# node 2 receives both edges, so its aggregate is [5, 0] + [3, 0] = [8, 0]
V = [[0.], [0.]]
E = [(1, 2, [5., 0.]), (2, 2, [3., 0.])]
ēᵛ = [[0., 0.], [8., 0.]]
out = ρᵛ(E, V)
out == ēᵛ

true

In [7]:
"""
Aggregate the edge attributes globally.
"""
function ρᵉ(E′)
    # ex: calculate the mean of all edges
    return mean((e[3] for e in E′))
end
# test ρᵉ: the mean of the edge attributes [5, -2] and [3, -2] is [4, -2]
E = [(1, 2, [5., -2.]), (2, 2, [3., -2.])]
ē = [4., -2.]
out = ρᵉ(E)
all(ē .== out)

true

In [8]:
"""
Aggregate all the node attributes globally.
"""
function ρᵘ(V′)
    # ex: sum all the attributes of the nodes
    return sum((V′[i] for i = 1:size(V′, 1)))
end
# test ρᵘ: the element-wise sum of [0, 10, 3] and [0, -2, -3] is [0, 8, 0]
V = [[0., 10., 3.], [0., -2., -3.]]
v̄= [0., 8., 0.]
out = ρᵘ(V)
all(v̄ .== out)

true

In [9]:
"""
Calculate the loss for a graph.
"""
function piecewise_loss(g::GN, x, y)
    (Ê, V̂, û) = g(x)
    E, V, u = y
    # edge loss
    Lₑ = sum((E[i][3] .- Ê[i][3]).^2 for i in 1:size(E, 1))
    Lᵥ = sum((V[i] .- V̂[i]).^2 for i in 1:size(V, 1))
    Lᵤ = sum((u .- û).^2)
    return Lₑ, Lᵥ, Lᵤ
end
# Scalar training loss: total of every entry of the edge, node and global
# components returned by `piecewise_loss`.
function loss(g::GN, x, y)
    edge_part, node_part, global_part = piecewise_loss(g, x, y)
    total = sum(edge_part) + sum(node_part)
    return total + sum(global_part)
end


loss (generic function with 1 method)

In [10]:
"""
Generate simple synthetic graphs 
"""
function gen_data()
    
    n1, n2 = rand(-20:20), rand(-20:20)
    e1, e2 = rand(Float32, 2)*2-1, rand(Float32, 2)*2-1
    n3 = (n1*e1[1] + e1[2]) + (n2*e2[1] + e2[2])
    
    Vₓ = [[float(n1)], [float(n2)], [zero(Float32)]]
    Eₓ = [(1, 3, e1), (2, 3, e2)]
    uₓ = [0.]
    
    Vₜ = [[float(n1)], [float(n2)], [float(n3)]]
    Eₜ = Eₓ
    uₜ = [float(n3)]
    
    return (Eₓ, Vₓ, uₓ), (Eₜ, Vₜ, uₜ)    
end
# Draw one (input, target) pair to inspect the structure gen_data produces.
gen_data()

((Tuple{Int64,Int64,Array{Float32,1}}[(1, 3, Float32[-0.6799, 0.56471]), (2, 3, Float32[0.611156, 0.907565])], Array{Float64,1}[[0.0], [16.0], [0.0]], [0.0]), (Tuple{Int64,Int64,Array{Float32,1}}[(1, 3, Float32[-0.6799, 0.56471]), (2, 3, Float32[0.611156, 0.907565])], Array{Float64,1}[[0.0], [16.0], [11.2508]], Float32[11.2508]))

In [11]:
# A fixed example graph: nodes 1 and 2 each feed node 3 through one edge.
V = [[3.,], [-5.,], [0.,]]
E = [(1, 3, [1., 0.]), (2, 3, [-2., -1.])]
u = [0.]
x = (E, V, u)

# Target: node 3 and the global attribute hold
# (3 * 1 + 0) + (-5 * -2 + -1) = 12, the same rule gen_data uses.
V′= [[3.], [-5.], [12.]]
E′ = [(1, 3, [1., 0.]), (2, 3, [-2., -1.])]
u′ = [12.,]
y = (E′, V′, u′)

# edge, node and global attribute sizes 2, 1 and 1
g = GN(2, 1, 1)

# the same (x, y) pair repeated 20000 times
dataset = repeated((x, y), 20000)
# print the per-component loss on the fixed example
evalcb = () -> @show(piecewise_loss(g, x, y))
opt = ADAM(params(g))

  likely near In[11]:13


(::#58) (generic function with 1 method)

In [12]:
@show piecewise_loss(g, gen_data()...)
# Draw a fresh synthetic graph for every training step. `repeated(gen_data(), n)`
# evaluates gen_data() only once and would feed that single sample n times.
dataset = (gen_data() for _ in 1:50000)
# monitor training on the fixed example (x, y) defined earlier in the notebook
evalcb = () -> @show loss(g, x, y)

piecewise_loss(g, gen_data()...) = (param([1299.34, 689.549]), param([95.8415]), param(7055.79))




(::#27) (generic function with 1 method)

In [20]:
# Train g on `dataset`; Flux.throttle limits evalcb to at most one call every 0.5 s.
Flux.train!((x, y) -> loss(g, x, y), dataset, opt, cb=Flux.throttle(evalcb, .5))

loss(g, x, y) = param(783.092)
loss(g, x, y) = param(782.901)
loss(g, x, y) = param(782.615)
loss(g, x, y) = param(782.334)
loss(g, x, y) = param(782.179)
loss(g, x, y) = param(781.776)
loss(g, x, y) = param(781.534)
loss(g, x, y) = param(781.211)
loss(g, x, y) = param(780.885)
loss(g, x, y) = param(780.615)
loss(g, x, y) = param(780.356)
loss(g, x, y) = param(780.037)
loss(g, x, y) = param(779.816)
loss(g, x, y) = param(779.557)
loss(g, x, y) = param(779.148)
loss(g, x, y) = param(778.983)
loss(g, x, y) = param(779.011)
loss(g, x, y) = param(778.171)
loss(g, x, y) = param(778.15)
loss(g, x, y) = param(777.866)
loss(g, x, y) = param(777.654)
loss(g, x, y) = param(777.298)
loss(g, x, y) = param(777.057)
loss(g, x, y) = param(776.791)
loss(g, x, y) = param(776.515)
loss(g, x, y) = param(776.29)
loss(g, x, y) = param(775.904)
loss(g, x, y) = param(775.674)
loss(g, x, y) = param(775.423)
loss(g, x, y) = param(775.099)
loss(g, x, y) = param(774.889)
loss(g, x, y) = param(774.588)
loss(g, x,

In [19]:
# Print the trainable parameters collected from g.
print(params(g))

Any[param([1.13876 1.88894 -0.0485112 -0.483028 1.41985; -2.46764 -1.96162 0.106373 -1.17107 -0.298738]), param([0.0089207, -0.825596]), param([1.0 -9.81669 1.47196 1.2392]), param([3.34748e-7]), param([-4.94778 -0.0147858 0.205698 0.565071]), param([3.6433])]

In [17]:
# Run the block on a fresh random graph.
# NOTE: this rebinds the globals `x` and `y`, which evalcb also reads.
@show x, y = gen_data()
print(g(x))

(x, y) = gen_data() = ((Tuple{Int64,Int64,Array{Float32,1}}[(1, 3, Float32[0.24423, 0.42264]), (2, 3, Float32[-0.112241, 0.946124])], Array{Float64,1}[[-12.0], [17.0], [0.0]], [0.0]), (Tuple{Int64,Int64,Array{Float32,1}}[(1, 3, Float32[0.24423, 0.42264]), (2, 3, Float32[-0.112241, 0.946124])], Array{Float64,1}[[-12.0], [17.0], [-3.4701]], Float32[-3.4701]))
(Tuple{Int64,Int64,TrackedArray{…,Array{Float64,1}}}[(1, 3, param([1.66752, -3.53381])), (2, 3, param([0.843587, -0.596213]))], TrackedArray{…,Array{Float64,1}}[param([-12.0]), param([17.0]), param([-30.9646])], param([-8.03855]))

0.9299999999999994