In [19]:
import Pkg; Pkg.activate(".")

[32m[1m  Activating[22m[39m project at `~/SAFT_ML`


In [41]:
using Flux
using Flux: onecold, onehotbatch, logitcrossentropy
using Flux: DataLoader
using GraphNeuralNetworks
using MLDatasets
using MLUtils
using LinearAlgebra, Random, Statistics

using Zygote
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"  # don't ask for dataset download confirmation
Random.seed!(17) # for reproducibility

TaskLocalRNG()

In [21]:
dataset = TUDataset("MUTAG")

dataset TUDataset:
  name        =>    MUTAG
  metadata    =>    Dict{String, Any} with 1 entry
  graphs      =>    188-element Vector{MLDatasets.Graph}
  graph_data  =>    (targets = "188-element Vector{Int64}",)
  num_nodes   =>    3371
  num_edges   =>    7442
  num_graphs  =>    188

In [22]:
dataset.graphs[1]

Graph:
  num_nodes   =>    17
  num_edges   =>    38
  edge_index  =>    ("38-element Vector{Int64}", "38-element Vector{Int64}")
  node_data   =>    (targets = "17-element Vector{Int64}",)
  edge_data   =>    (targets = "38-element Vector{Int64}",)

In [23]:
dataset.graph_data.targets |> union

2-element Vector{Int64}:
  1
 -1

In [24]:
g1, y1 = dataset[1] #get the first graph and target

(graphs = Graph(17, 38), targets = 1)

In [25]:
reduce(vcat, g.node_data.targets for (g, _) in dataset) |> union

7-element Vector{Int64}:
 0
 1
 2
 3
 4
 5
 6

In [26]:
reduce(vcat, g.edge_data.targets for (g, _) in dataset) |> union


4-element Vector{Int64}:
 0
 1
 2
 3

In [27]:
graphs = mldataset2gnngraph(dataset)

188-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}:
 GNNGraph(17, 38) with targets: 17-element, targets: 38-element data
 GNNGraph(13, 28) with targets: 13-element, targets: 28-element data
 GNNGraph(13, 28) with targets: 13-element, targets: 28-element data
 GNNGraph(19, 44) with targets: 19-element, targets: 44-element data
 GNNGraph(11, 22) with targets: 11-element, targets: 22-element data
 GNNGraph(28, 62) with targets: 28-element, targets: 62-element data
 GNNGraph(16, 34) with targets: 16-element, targets: 34-element data
 GNNGraph(20, 44) with targets: 20-element, targets: 44-element data
 GNNGraph(12, 26) with targets: 12-element, targets: 26-element data
 GNNGraph(17, 38) with targets: 17-element, targets: 38-element data
 ⋮
 GNNGraph(28, 66) with targets: 28-element, targets: 66-element data
 GNNGraph(11, 22) with targets: 11-element, targets: 22-element data
 GNNGraph(14, 30) with targets: 14-element, targets: 30-element data
 GNNGraph(22, 50) with t

In [28]:
# graphs = mldataset2gnngraph(dataset)
graphs = [GNNGraph(g,
                    ndata = Float32.(onehotbatch(g.ndata.targets, 0:6)),
                    edata = nothing)
            for g in graphs]
y = onehotbatch(dataset.graph_data.targets, [-1, 1])

2×188 OneHotMatrix(::Vector{UInt32}) with eltype Bool:
 ⋅  1  1  ⋅  1  ⋅  1  ⋅  1  ⋅  ⋅  ⋅  ⋅  …  ⋅  ⋅  1  ⋅  1  1  ⋅  ⋅  1  1  ⋅  1
 1  ⋅  ⋅  1  ⋅  1  ⋅  1  ⋅  1  1  1  1     1  1  ⋅  1  ⋅  ⋅  1  1  ⋅  ⋅  1  ⋅

In [29]:
g = graphs[1]
display((typeof(g), typeof(graphs)))
display(fieldnames(typeof(g)))

g.ndata.x

(GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}, Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}})

(:graph, :num_nodes, :num_edges, :num_graphs, :graph_indicator, :ndata, :edata, :gdata)

7×17 Matrix{Float32}:
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  …  1.0  1.0  1.0  1.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  1.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  1.0  1.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  …  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0     0.0  0.0  0.0  0.0  0.0  0.0  0.0

In [30]:
graphs[1].edata

DataStore(38)

In [31]:
fieldnames(typeof(graphs[1].ndata))

(:_n, :_data)

In [32]:
train_data, test_data = splitobs((graphs, y), at = 150, shuffle = true) |> getobs

((GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(23, 48) with x: 7×23 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(25, 56) with x: 7×25 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(23, 54) with x: 7×23 data, GNNGraph(12, 26) with x: 7×12 data, GNNGraph(24, 50) with x: 7×24 data, GNNGraph(16, 34) with x: 7×16 data, GNNGraph(28, 62) with x: 7×28 data, GNNGraph(13, 26) with x: 7×13 data  …  GNNGraph(20, 44) with x: 7×20 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(15, 34) with x: 7×15 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(19, 42) with x: 7×19 data, GNNGraph(12, 24) with x: 7×12 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(20, 46) with x: 7×20 data, GNNGraph(17, 38) with x: 7×17 data], Bool[0 1 … 0 0; 1 0 … 1 1]), (GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(20, 46) with x: 7×20 data, GNNGraph(14, 28) with x: 7×14 data, GNNGraph(16, 36) with x: 7×16 data, GNNGraph(19, 44) with x: 7

In [33]:
train_loader = DataLoader(train_data, batchsize = 32, shuffle = true)
test_loader = DataLoader(test_data, batchsize = 32, shuffle = false)

2-element DataLoader(::Tuple{Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, OneHotArrays.OneHotMatrix{UInt32, Vector{UInt32}}}, batchsize=32)
  with first element:
  (32-element Vector{GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}}, 2×32 OneHotMatrix(::Vector{UInt32}) with eltype Bool,)

In [34]:
vec_gs, _ = first(train_loader)

(GNNGraph{Tuple{Vector{Int64}, Vector{Int64}, Nothing}}[GNNGraph(23, 54) with x: 7×23 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(20, 44) with x: 7×20 data, GNNGraph(15, 34) with x: 7×15 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(18, 38) with x: 7×18 data, GNNGraph(14, 28) with x: 7×14 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(15, 34) with x: 7×15 data, GNNGraph(26, 56) with x: 7×26 data  …  GNNGraph(19, 44) with x: 7×19 data, GNNGraph(19, 44) with x: 7×19 data, GNNGraph(13, 26) with x: 7×13 data, GNNGraph(17, 38) with x: 7×17 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(22, 50) with x: 7×22 data, GNNGraph(19, 40) with x: 7×19 data, GNNGraph(13, 28) with x: 7×13 data, GNNGraph(12, 24) with x: 7×12 data], Bool[0 1 … 1 1; 1 0 … 0 0])

In [35]:
vec_gs[1]

GNNGraph:
  num_nodes: 23
  num_edges: 54
  ndata:
	x = 7×23 Matrix{Float32}

In [36]:
MLUtils.batch(vec_gs)

GNNGraph:
  num_nodes: 573
  num_edges: 1274
  num_graphs: 32
  ndata:
	x = 7×573 Matrix{Float32}

In [53]:
function create_model(nin, nh, nout)
    GNNChain(GraphConv(nin => nh, relu),
             GraphConv(nh => nh, relu),
             GraphConv(nh => nh),
             GlobalPool(mean), # Average the node features
             Dropout(0.5),
             Dense(nh, nout))
end

function eval_loss_accuracy(model, data_loader, device)
    loss = 0.0
    acc = 0.0
    ntot = 0
    for (g, y) in data_loader
        g, y = MLUtils.batch(g) |> device, y |> device
        n = length(y)
        ŷ = model(g, g.ndata.x)
        loss += logitcrossentropy(ŷ, y) * n
        acc += mean((ŷ .> 0) .== y) * n
        ntot += n
    end
    return (loss = round(loss / ntot, digits = 4),
            acc = round(acc * 100 / ntot, digits = 2))
end

function train!(model; epochs = 3, η = 1e-2, infotime = 1)
    # device = Flux.gpu # uncomment this for GPU training
    device = Flux.cpu
    model = model |> device
    # opt = Flux.setup(Adam(1e-3), model)
    opt = ADAM(1e-3)

    function report(epoch)
        train = eval_loss_accuracy(model, train_loader, device)
        test = eval_loss_accuracy(model, test_loader, device)
        @info (; epoch, train, test)
    end

    report(0)
    for epoch in 1:epochs
        for (g, y) in train_loader
            g, y = MLUtils.batch(g) |> device, y |> device

            loss = 0.0
            loss_fn() = begin
                ŷ = model(g, g.ndata.x)
                for (ŷᵢ, yᵢ) in zip(ŷ, y)
                    # loss += logitcrossentropy(ŷᵢ, yᵢ)
                    loss += (ŷᵢ - yᵢ)^2
                end
                # loss = logitcrossentropy(ŷ, y)
                loss /= length(y)
            end
            grads = Zygote.gradient(Flux.params(model)) do 
                loss_fn()
            end
            # grad = Flux.gradient(model) do model
            #     ŷ = model(g, g.ndata.x)
            #     logitcrossentropy(ŷ, y)
            # end
            Flux.update!(opt, Flux.params(model), grads)
        end
        epoch % infotime == 0 && report(epoch)
    end
end

train! (generic function with 1 method)

In [54]:
nin = 7
nh = 64
nout = 2
model = create_model(nin, nh, nout)
train!(model)

┌ Info: (epoch = 0, train = (loss = 0.8131, acc = 45.0), test = (loss = 0.8708, acc = 50.0))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:36
┌ Info: (epoch = 1, train = (loss = 1.0403, acc = 33.67), test = (loss = 1.042, acc = 34.21))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:36


┌ Info: (epoch = 2, train = (loss = 0.5818, acc = 53.0), test = (loss = 0.6046, acc = 52.63))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:36
┌ Info: (epoch = 3, train = (loss = 0.5406, acc = 50.0), test = (loss = 0.5751, acc = 50.0))
└ @ Main /home/luc/SAFT_ML/3_gnn_example_2.ipynb:36
