# Graph Network Block
This is a simple implementation of the graph network (GN) block proposed in [Relational inductive biases, deep learning, and graph networks. Battaglia *et al.*, 2018.](https://arxiv.org/pdf/1806.01261.pdf)

## Definitions
A "graph" is defined as a directed, attributed multi-graph with a global attribute.
  - A node is denoted $\mathbf{v}_i$
  - An edge is denoted $\mathbf{e}_k$
  - The global attribute is $\mathbf{u}$
  - $s_k$ and $r_k$ are the indices of the sender and receiver nodes of edge $k$.
  
![multigraph](./multigraph.png)

A *graph* is a 3-tuple $G = (\mathbf{u}, V, E)$, where $\mathbf{u}$ is the global attribute, $V = \{\mathbf{v}_i\}_{i=1:N_v}$ is the set of nodes, and $E = \{(\mathbf{e}_k, r_k, s_k)\}_{k=1:N_e}$ is the set of edges.

## Internal Structures
A GN block contains three update functions, $\phi$,
$$
\begin{align}
\mathbf{e}'_k & = \phi^e(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u})
\\
\mathbf{v}'_i & = \phi^v(\bar{\mathbf{e}}'_i, \mathbf{v}_i, \mathbf{u})
\\
\mathbf{u}' & = \phi^u(\bar{\mathbf{e}}', \bar{\mathbf{v}}', \mathbf{u})
\end{align}
$$
and three aggregation functions, $\rho$,
$$
\begin{align}
\bar{\mathbf{e}}'_i & = \rho^{e \rightarrow v}(E'_i) \\
\bar{\mathbf{e}}' & = \rho^{e \rightarrow u}(E') \\
\bar{\mathbf{v}}' & = \rho^{v \rightarrow u}(V')
\end{align}
$$
where $E'_i = \{(\mathbf{e}'_k, r_k, s_k)\}_{r_k=i,\,k=1:N_e}$ is the set of updated edges received by node $i$, $V' = \{\mathbf{v}'_i\}_{i=1:N_v}$, and $E' = \bigcup_i E'_i = \{(\mathbf{e}'_k, r_k, s_k)\}_{k=1:N_e}$.

$\phi^e$ is mapped across all edges, $\phi^v$ is mapped across all nodes, and $\phi^u$ is applied once as a global update. The $\rho$ functions all take a set as input and reduce it to a single element which represents the aggregate information. The $\rho$ functions must be invariant to permutations of their inputs, and should take a variable number of arguments.

In [1]:
# GraphNetwork struct
# - Needs to initialize the ϕ transition weights
# - Each ϕ is basically a dense neural network, but could be multiple layers
# - Each ρ on the other hand is a function that operates on a set; this could potentially
#   include parameters / weights...
# Start small: Simplest ρ and ϕ functions


In [2]:
using Flux
using Flux.Tracker

In [3]:
struct GN
    edge_units      # length of each edge attribute vector
    node_units      # length of each node attribute vector
    global_units    # length of the global attribute vector
    Wₑ              # edge-update weights: edge_units × (edge + 2·node + global)
    bₑ              # edge-update bias, length edge_units
    Wᵥ              # node-update weights: node_units × (edge + node + global)
    bᵥ              # node-update bias, length node_units
    Wᵤ              # global-update weights: global_units × (edge + node + global)
    bᵤ              # global-update bias, length global_units
end

"""
    GN(edge_units, node_units, global_units)

Create the variables for a graph network block whose edge, node and global
attributes have `edge_units`, `node_units` and `global_units` components.

Weights are drawn with `randn` and biases start at zero; all of them are
wrapped in `param` so they are trainable. The edge update sees the edge, two
node attributes and the global attribute; the node and global updates each see
one edge-sized, one node-sized and one global-sized input.
"""
function GN(edge_units, node_units, global_units)
    edge_input_width = edge_units + 2 * node_units + global_units
    other_input_width = edge_units + node_units + global_units
    # Allocation order matches the field order, so randn draws stay in sequence.
    Wₑ = param(randn(edge_units, edge_input_width))
    bₑ = param(zeros(edge_units))
    Wᵥ = param(randn(node_units, other_input_width))
    bᵥ = param(zeros(node_units))
    Wᵤ = param(randn(global_units, other_input_width))
    bᵤ = param(zeros(global_units))
    return GN(edge_units, node_units, global_units, Wₑ, bₑ, Wᵥ, bᵥ, Wᵤ, bᵤ)
end
"""
graph network block
E is an array of triples=(i, j, e) indicates the 
    indicates for the sender, receiver nodes for
    an edge, and the edge attributes
V is an array [[values]] where each element i
    of the array representes the attributes for
    node i.
u is an array of attributes.
"""
function (g::GN)(x)
    E, V, u = x
    # compute edge updates
    E′ = [(i, j, ϕᵉ(g, eₖ, V[i], V[j], u)) for (i, j, eₖ) in E]
    # aggregate per-node edge attributes
    ēᵛ = ρᵛ(E′, V)
    # compute node updates
    V′ = vcat([ϕᵛ(g, V[i], ēᵛ[i], u) for i in 1:size(V)[1]])
    # aggregate edge attributes
    ē = ρᵉ(E′)
    # aggregate node attributes
    v̄ = ρᵘ(V′)
    # compute updated global attribute
    u′ = ϕᵘ(g, ē, v̄, u)
    return (E′, V′, u′)
end
Flux.treelike(GN)

mapchildren (generic function with 7 methods)

In [4]:
"""
Update an edge
"""
function ϕᵉ(g::GN, eₖ, v₁, v₂, u) 
    x = vcat(eₖ, vcat(v₁, vcat(v₂, u)))
    eₖ′ = g.Wₑ * x + g.bₑ
    return eₖ′
end

ϕᵉ

In [5]:
"""
Update the nodes.
"""
function ϕᵛ(g::GN, vᵢ, ēᵢ, u)
    x = vcat(vᵢ, vcat(ēᵢ, u))
    vᵢ′ = g.Wᵥ * x + g.bᵥ
    return vᵢ′
end

ϕᵛ

In [6]:
"""
Update the global state attribute.
"""
function ϕᵘ(g::GN, ē, v̄, u)
    x = vcat(ē, vcat(v̄, u))
    u′ = g.Wᵤ * x + g.bᵤ
    return u′
end

ϕᵘ

In [7]:
"""
Aggregate edge attributes for each node.
"""
function ρᵛ(E′, V)
    # outer loop: each target node j
    # inner loop: sum incoming edges (i, j) for j
    ēᵛ = [reduce(+, zeros(E′[1][3]), eₖ for (i, j_, eₖ) in E′ if j_ == j) for j in 1:size(V, 1)]
    return ēᵛ
end
# test ρᵛ
V = [[0.], [0.]]
E = [(1, 2, [5., 0.]), (2, 2, [3., 0.])]
ēᵛ = [[0., 0.], [8., 0.]]
out = ρᵛ(E, V)
out == ēᵛ

true

In [8]:
"""
Aggregate the edge attributes globally.
"""
function ρᵉ(E′)
    # ex: calculate the mean of all edges
    return mean((e[3] for e in E′))
end
# test ρᵉ
E = [(1, 2, [5., -2.]), (2, 2, [3., -2.])]
ē = [4., -2.]
out = ρᵉ(E)
all(ē .== out)

true

In [9]:
"""
Aggregate all the node attributes globally.
"""
function ρᵘ(V′)
    # ex: sum all the attributes of the nodes
    return sum((V′[i] for i = 1:size(V′, 1)))
end
# test ρᵘ
V = [[0., 10., 3.], [0., -2., -3.]]
v̄= [0., 8., 0.]
out = ρᵘ(V)
all(v̄ .== out)

true

In [10]:
"""
Calculate the loss for a graph.
"""
function piecewise_loss(g::GN, x, y)
    (Ê, V̂, û) = g(g(x))
    E, V, u = y
    # edge loss
    Lₑ = sum((E[i][3] .- Ê[i][3]).^2 for i in 1:size(E, 1))
    Lᵥ = sum((V[i] .- V̂[i]).^2 for i in 1:size(V, 1))
    Lᵤ = sum((u .- û).^2)
    return Lₑ, Lᵥ, Lᵤ
end
function loss(g::GN, x, y)
    Lₑ, Lᵥ, Lᵤ = piecewise_loss(g, x, y)
    return sum(Lₑ) + sum(Lᵥ) + sum(Lᵤ)
end


loss (generic function with 1 method)

In [15]:

# Toy input graph: three scalar-attributed nodes and two edges, both into node 3.
V = [[3.,], [-5.,], [0.,]]
E = [(1, 3, [1., 0.]), (2, 3, [-2., -1.])]
u = [0.]
x = (E, V, u)

# Target graph: edges and nodes 1–2 unchanged; node 3 and the global become 12.
V′ = [[3.], [-5.], [12.]]
E′ = [(1, 3, [1., 0.]), (2, 3, [-2., -1.])]
u′ = [12.,]
y = (E′, V′, u′)

g = GN(2, 1, 1)

# `Base.repeated` is deprecated; `Iterators.repeated` yields the same (x, y)
# pair 20000 times.
dataset = Iterators.repeated((x, y), 20000)
evalcb = () -> @show(piecewise_loss(g, x, y))
opt = ADAM(params(g))



(::#58) (generic function with 1 method)

.repeated instead.
  likely near In[15]:16


In [16]:
# Fit g on `dataset`; the evaluation callback is throttled to once per 0.5 s.
Flux.train!((xᵢ, yᵢ) -> loss(g, xᵢ, yᵢ), dataset, opt;
            cb = Flux.throttle(evalcb, 0.5))

piecewise_loss(g, x, y) = (param([14.5044, 84.0105]), param([1106.58]), param(653.14))
piecewise_loss(g, x, y) = (param([7.02636, 8.82641]), param([60.3937]), param(11.3887))
piecewise_loss(g, x, y) = (param([3.88931, 7.80544]), param([9.77416]), param(0.374435))
piecewise_loss(g, x, y) = (param([2.12563, 7.521]), param([1.7579]), param(0.0322459))
piecewise_loss(g, x, y) = (param([1.77897, 7.40941]), param([1.11471]), param(0.00608632))
piecewise_loss(g, x, y) = (param([1.61444, 7.03967]), param([0.970088]), param(0.00233593))
piecewise_loss(g, x, y) = (param([1.46557, 6.4733]), param([0.863143]), param(0.00184785))
piecewise_loss(g, x, y) = (param([1.30978, 5.76518]), param([0.743509]), param(0.00216715))
piecewise_loss(g, x, y) = (param([1.13482, 4.93591]), param([0.614246]), param(0.00292799))
piecewise_loss(g, x, y) = (param([0.926159, 3.98774]), param([0.483445]), param(0.0039915))
piecewise_loss(g, x, y) = (param([0.685601, 2.99452]), param([0.363819]), param(0.00506813))
piecew

In [17]:
# Print the block's trainable parameters after training.
params(g) |> print

Any[param([-0.866753 -1.36886 -1.05818 -0.15445 -1.18158; -0.168985 0.735997 -0.0914339 0.0569512 -0.324104]), param([1.00077, -0.425651]), param([1.0 1.27074 -2.32485 -0.934552]), param([-0.277853]), param([-2.74842 -0.603706 1.18517 -0.543579]), param([-1.85098])]

In [18]:
# Print the prediction for the training input: the block applied twice.
x |> g |> g |> print

(Tuple{Int64,Int64,TrackedArray{…,Array{Float64,1}}}[(1, 3, param([1.0, -2.54979e-9])), (2, 3, param([-2.0, -1.0]))], TrackedArray{…,Array{Float64,1}}[param([3.0]), param([-5.0]), param([12.0])], param([12.0]))