# Graph Network Block
This is a simple implementation of the graph network block proposed in [Relational inductive biases, deep learning, and graph networks. Battaglia *et al.*, 2018.](https://arxiv.org/pdf/1806.01261.pdf)

## Definitions
A "graph" is defined as a directed, attributed multi-graph with a global attribute.
  - A node is denoted $\mathbf{v}_i$
  - An edge is denoted $\mathbf{e}_k$
  - The global attribute is $\mathbf{u}$
  - $s_k$ and $r_k$ are the indices of the sender and receiver nodes of edge $k$.
  
![multigraph](./multigraph.png)

A *graph* is a 3-tuple $G = (\mathbf{u}, V, E)$, where $\mathbf{u}$ is the global attribute, $V = \{\mathbf{v}_i\}_{i=1:N_v}$ is the set of nodes, and $E = \{(\mathbf{e}_k, r_k, s_k)\}_{k=1:N_e}$ is the set of edges.
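
As a minimal sketch of this definition, the toy graph below stores $G$ as plain Julia arrays. The layout is an assumption chosen to match the code later in this notebook: each edge is a `(sender, receiver, attributes)` tuple, so the tuple order differs from the $(\mathbf{e}_k, r_k, s_k)$ order in the formula. The names `senders`, `receivers`, `N_v`, and `N_e` are illustrative only.

```julia
# A toy graph with 3 nodes and 2 directed edges, both pointing into node 3.
# Assumed layout: each edge is a (sender, receiver, attributes) tuple.
u = [0.0]                           # global attribute
V = [[3.0], [-5.0], [0.0]]          # V[i] holds the attributes of node i
E = [(1, 3, [1.0, 0.0]),            # edge 1: node 1 -> node 3
     (2, 3, [1.0, 0.0])]            # edge 2: node 2 -> node 3

G = (u, V, E)                       # the 3-tuple G = (u, V, E)

N_v = length(V)                     # number of nodes
N_e = length(E)                     # number of edges
senders   = [s for (s, r, e) in E]  # s_k for every edge k
receivers = [r for (s, r, e) in E]  # r_k for every edge k
```

Because edges carry their own sender and receiver indices, two edges may connect the same pair of nodes, which is what makes this a multi-graph.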

## Internal Structures
A GN block contains 3 update functions, $\phi$,
$$
\begin{align}
\mathbf{e\prime}_k & = \phi^e(\mathbf{e}_k, \mathbf{v}_{r_k}, \mathbf{v}_{s_k}, \mathbf{u})
\\
\mathbf{v\prime}_i & = \phi^v(\mathbf{\bar e\prime}_i, \mathbf{v}_i, \mathbf{u})
\\
\mathbf{u\prime} & = \phi^u(\mathbf{\bar e\prime}, \mathbf{\bar v\prime}, \mathbf{u})
\end{align}
$$
and 3 aggregation functions, $\rho$,
$$
\begin{align}
\mathbf{\bar e\prime}_i & = \rho^{e \rightarrow v}(E\prime_i) \\
\mathbf{\bar e\prime} & = \rho^{e \rightarrow u}(E\prime) \\
\mathbf{\bar v\prime} & = \rho^{v \rightarrow u}(V\prime)
\end{align}
$$
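where $E\prime_i = \{(\mathbf{e\prime}_k, r_k, s_k)\}_{r_k=i,k=1:N_e}$ is the set of updated edges received by node $i$, $V\prime = \{\mathbf{v\prime}_i\}_{i=1:N_v}$ is the set of updated nodes, and $E\prime = \cup_iE\prime_i = \{(\mathbf{e\prime}_k, r_k, s_k)\}_{k=1:N_e}$ is the set of all updated edges.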

$\phi^e$ is mapped across all edges, $\phi^v$ is mapped across all nodes, and $\phi^u$ is applied once as a global update. Each $\rho$ function takes a set as input and reduces it to a single element that represents the aggregated information. The $\rho$ functions must be invariant to permutations of their inputs, and should take a variable number of arguments; elementwise sum, mean, and maximum all qualify.
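
The order of operations described above can be sketched in plain Julia without Flux. This is an illustration under stated assumptions, not the implementation used in this notebook: `rho`, `phi_e`, `phi_v`, `phi_u`, and `gn_pass` are hypothetical names, the $\phi$ rules are hand-written instead of learned, a single sum plays the role of all three $\rho$ functions, and edges are `(sender, receiver, attributes)` tuples.

```julia
# Sum is one valid ρ: it accepts any number of elements and ignores their order.
# `dim` gives the size of the zero vector returned for an empty set.
rho(xs, dim) = isempty(xs) ? zeros(dim) : sum(xs)

# Hypothetical hand-written rules standing in for the learned ϕ functions.
phi_e(e, v_r, v_s, u)  = e .* sum(v_s)                # scale the edge by its sender's value
phi_v(e_bar, v, u)     = fill(sum(e_bar), length(v))  # node becomes the total of its incoming edges
phi_u(e_bar, v_bar, u) = u .+ sum(v_bar)              # add the node aggregate to the global

function gn_pass(E, V, u)
    edge_dim = length(E[1][3])
    # ϕᵉ is mapped across all edges
    E2 = [(s, r, phi_e(e, V[r], V[s], u)) for (s, r, e) in E]
    # ρ(e→v): for each node i, aggregate the updated edges whose receiver is i
    e_bar_i = [rho([e for (s, r, e) in E2 if r == i], edge_dim) for i in 1:length(V)]
    # ϕᵛ is mapped across all nodes
    V2 = [phi_v(e_bar_i[i], V[i], u) for i in 1:length(V)]
    # ρ(e→u) and ρ(v→u): aggregate all updated edges and all updated nodes
    e_bar = rho([e for (s, r, e) in E2], edge_dim)
    v_bar = rho(V2, length(V[1]))
    # ϕᵘ is applied once
    u2 = phi_u(e_bar, v_bar, u)
    return E2, V2, u2
end

V = [[3.0], [-5.0], [0.0]]
E = [(1, 3, [1.0, 0.0]), (2, 3, [1.0, 0.0])]
u = [0.0]

E2, V2, u2 = gn_pass(E, V, u)

# Reordering the edges must not change the node or global updates.
_, V2_rev, u2_rev = gn_pass(reverse(E), V, u)
@assert V2 == V2_rev && u2 == u2_rev
```

On this toy graph the hand-picked rules leave nodes 1 and 2 at `[0.0]` and set both node 3 and the global attribute to `[-2.0]`. The final assertion checks the permutation-invariance requirement for one reordering of the edge list.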

In [1]:
# GraphNetwork struct
# - Needs to initialize the ϕ transition weights
# - Each ϕ is basically a dense neural network, but could be multiple layers
# - Each ρ on the other hand is a function that operates on a set; this could potentially
#   include parameters / weights...
# Start small: Simplest ρ and ϕ functions


In [2]:
using Flux
using Flux.Tracker

In [3]:
struct GN
    edge_units
    node_units
    global_units
    Wₑ
    bₑ
    Wᵥ
    bᵥ
    Wᵤ
    bᵤ
end
# Create the variables for the Graph Network block
GN(edge_units, node_units, global_units) = GN(edge_units, node_units, global_units,
                                              param(randn(edge_units, edge_units+2*node_units+global_units)), 
                                              param(zeros(edge_units)), 
                                              param(randn(node_units, edge_units+node_units+global_units)), 
                                              param(zeros(node_units)),
                                              param(randn(global_units, edge_units+node_units+global_units)), 
                                              param(zeros(global_units)))
"""
Graph network block.
E is an array of triples (i, j, e), where i and j are
    the indices of the sender and receiver nodes of
    an edge and e holds the edge attributes.
V is an array [[values]] where element i
    of the array represents the attributes of
    node i.
u is an array of global attributes.
"""
function (g::GN)(E, V, u)
    # compute edge updates
    E′ = [(i, j, ϕᵉ(g, eₖ, V[i], V[j], u)) for (i, j, eₖ) in E]
    # aggregate per-node edge attributes
    ēᵛ = ρᵛ(E′, V)
    # compute node updates
    V′ = vcat([ϕᵛ(g, V[i], ēᵛ[i], u) for i in 1:size(V)[1]])
    # aggregate edge attributes
    ē = ρᵉ(E′)
    # aggregate node attributes
    v̄ = ρᵘ(V′)
    # compute updated global attribute
    u′ = ϕᵘ(g, ē, v̄, u)
    
    return (E′, V′, u′)
end

g

In [4]:
"""
Update an edge
"""
function ϕᵉ(g::GN, eₖ, v₁, v₂, u)
    # right now, the input vectors are col vectors
    print("In ϕᵉ: eₖ, ", eₖ, "\n v₁: ", v₁, "\n v₂: ", v₂, "\n u: ", u)
    x = vcat(eₖ, v₁, v₂, u)
    eₖ′ = g.Wₑ * x + g.bₑ
    return eₖ′
end

ϕᵉ

In [11]:
"""
Update the nodes.
"""
function ϕᵛ(g::GN, vᵢ, ēᵢ, u)
    print("In ϕᵛ: vᵢ: ", vᵢ, "ēᵢ: ", ēᵢ, ", u: ", u)
    x = vcat(vᵢ, ēᵢ, u)
    println("x in ϕᵛ: ", x)
    vᵢ′ = g.Wᵥ * x + g.bᵥ
    
    return vᵢ′
end

ϕᵛ

In [6]:
"""
Update the global state attribute.
"""
function ϕᵘ(g::GN, ē, v̄, u)
    x = vcat(ē, v̄, u)
    u′ = g.Wᵤ * x + g.bᵤ
    return u′
end

ϕᵘ

In [15]:
"""
Aggregate edge attributes for each node.
"""
function ρᵛ(E′, V)
    # ex: calculate the sum of
    #     the edges going TO each node i
    ēᵛ = [TrackedArray(zeros(size(E′[1][3])[1])) for _ in 1:size(V)[1]]
    println("In ρᵛ, ēᵛ: ", ēᵛ)
    for (i, j, eₖ) in E′
        ēᵛ[j] += eₖ
    end
    return ēᵛ
end
# test ρᵛ
V = [[0.], [0.]]
E = [(1, 2, [5., 0.]), (2, 2, [3., 0.])]
ēᵛ = [[0., 0.], [8., 0.]]
out = ρᵛ(E, V)
out[1]

In ρᵛ, ēᵛ: TrackedArray{…,Array{Float64,1}}[param([0.0, 0.0]), param([0.0, 0.0])]


Tracked 2-element Array{Float64,1}:
 0.0
 0.0

In [8]:
"""
Aggregate the edge attributes globally.
"""
function ρᵉ(E′)
    # ex: calculate the mean of all edges
    return mean((e[3] for e in E′))
end
# test ρᵉ
E = [(1, 2, [5. -2.]), (2, 2, [3. -2.])]
ē = [4. -2.]
out = ρᵉ(E)
all(ē .== out)

true

In [9]:
"""
Aggregate all the node attributes globally.
"""
function ρᵘ(V′)
    # ex: sum all the attributes of the nodes
    return sum((V′[i] for i = 1:size(V′)[1]))
end
# test ρᵘ
V = [[0., 10., 3.], [0., -2., -3.]]
v̄ = [0., 8., 0.]
out = ρᵘ(V)
print(out)
all(v̄ .== out)

[0.0, 8.0, 0.0]

true

In [17]:
# an example graph where the target of the
#    network is the sum of the two nodes with edges
#    leading into the final node.
V = [[3.,], [-5.,], [0.,]]
E = [(1, 3, [1., 0.]), (2, 3, [1., 0.])]
u = [0.]
V′= [[0.], [0.], [-2.]]
E′ = [(1, 3, [0., 1.]), (2, 3, [0., 1.])]
u′ = [-2.,]
g = GN(2, 1, 1)
Ê, V̂, û = g(E, V, u)
println(V̂)
println(Ê)
println(û)

In ϕᵉ: eₖ, [1.0, 0.0]
 v₁: [3.0]
 v₂: [0.0]
 u: [0.0]In ϕᵉ: eₖ, [1.0, 0.0]
 v₁: [-5.0]
 v₂: [0.0]
 u: [0.0]In ρᵛ, ēᵛ: TrackedArray{…,Array{Float64,1}}[param([0.0, 0.0]), param([0.0, 0.0]), param([0.0, 0.0])]
In ϕᵛ: vᵢ: [3.0]ēᵢ: param([0.0, 0.0]), u: [0.0]

LoadError: MethodError: Cannot `convert` an object of type TrackedArray{…,Array{Float64,0}} to an object of type Float64
This may have arisen from a call to the constructor Float64(...),
since type constructors fall back to convert methods.